In [1]:
from __future__ import print_function
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler, SubsetRandomSampler
from torchvision import models, transforms
import os
import numpy as np
import sys

from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity #, download_and_extract_archive
import pandas
import csv
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
import time
import copy

In [2]:
sys.path.append(os.path.expanduser("~/few-shot-learning/"))
from few_shot.datasets import FashionProductImages, FashionProductImagesSmall

In [3]:
# Per-split preprocessing. Every split ends at a fixed 80x60 (height, width)
# size and is normalised with the standard ImageNet channel mean/std used by
# torchvision's pretrained models.
IMAGE_SIZE = (80, 60)
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _eval_transform():
    """Deterministic resize + normalise pipeline shared by the val and test splits."""
    return transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ])


data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
        transforms.RandomRotation(degrees=15),
        # NOTE(review): ColorJitter() with no arguments uses brightness, contrast,
        # saturation and hue of 0, i.e. it leaves the image unchanged. Pass
        # explicit ranges (e.g. brightness=0.2) if colour augmentation is intended.
        transforms.ColorJitter(),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    # val and test get no augmentation, only the deterministic pipeline.
    'val': _eval_transform(),
    'test': _eval_transform(),
}

# NOTE: the DataLoader setup cell below reads `datasets['top'][...]`;
# uncomment this block before running that cell. 'val' is built from the
# 'train' split (split='train' for both) and only differs in its transform.
#datasets = {
#    classes: {
#        split: FashionProductImages(
#            "~/data",
#            split='train' if split in ['train', 'val'] else 'test',
#            classes=classes,
#            transform=data_transforms[split]
#        ) for split in ["train", "test", "val"]
#    } for classes in ["top", "bottom"]
#}

In [4]:
# Load a single split of the full-size dataset to inspect individual samples.
# Smaller alternative:
#   FashionProductImagesSmall("~/data", classes="top", split="train", transform=data_transforms["train"])
fashion = FashionProductImages(
    "~/data",
    split="test",
    classes="bottom",
    # NOTE(review): the *test* split is loaded with the *train* (augmenting)
    # transform; use data_transforms["test"] if un-augmented samples are wanted.
    transform=data_transforms["train"],
)

# Integer indexing only -- slicing (e.g. fashion[10:15]) raises.
X, y = fashion[10]

len(fashion)

# NOTE(review): presumably the sizes of the four (classes, split) combinations;
# 5505 is this split's length. Confirm.
# 18000 + 15149 + 5787 + 5505

5505

In [None]:
# Sanity check: every sample should leave the transform pipeline with a spatial
# size of 80x60. ToTensor produces (C, H, W), so dims 1 and 2 are height/width.
# Offending indices are collected instead of dropping into pdb mid-loop.
bad_indices = []
for i in range(len(fashion)):
    X, y = fashion[i]
    if not (X.shape[1] == 80 and X.shape[2] == 60):
        bad_indices.append(i)
assert not bad_indices, "unexpected image size at indices {}".format(bad_indices[:20])

# Count the samples the training loader actually yields.
# NOTE(review): `train_loader` is created in the DataLoader setup cell further
# down; run that cell first, otherwise this raises NameError on a fresh kernel.
n_loaded = 0
for X_batch, _ in train_loader:
    # Use the real batch length rather than a fixed 64: the last batch can be smaller.
    n_loaded += len(X_batch)
print(n_loaded)

In [None]:
# Class balance of the loaded split, and how the encoded targets relate to the
# raw `articleType` metadata column.
class_sample_count = np.bincount(fashion.target_indices, minlength=fashion.n_classes)
for summary in (class_sample_count, len(class_sample_count), class_sample_count.sum()):
    print(summary)

unique_targets = np.unique(fashion.targets)
print(unique_targets)
print(len(unique_targets))

unique_article_types = np.unique(fashion.df_meta["articleType"])
print(unique_article_types)
print(len(unique_article_types))

# all perfumes are from 2017, which means they're all in the test set
is_perfume = fashion.df_meta["articleType"] == "Perfume and Body Mist"
is_2017 = fashion.df_meta["year"] == 2017.0
fashion.df_meta[is_perfume & is_2017]

In [None]:
batch_size = 64

# Train/val datasets for the "top" classes.
# NOTE(review): neither `datasets` (commented out in the transforms cell) nor
# `get_train_and_val_sampler` is defined in any visible cell -- define both first.
trainset, valset = (datasets['top'][split] for split in ('train', 'val'))

train_sampler, train_indices, val_sampler, val_indices = get_train_and_val_sampler(
    trainset, balanced_training=True
)

# Both loaders share batch size and worker count; only dataset and sampler differ.
loader_options = {"batch_size": batch_size, "num_workers": 4}
dataloaders = {
    "train": DataLoader(trainset, sampler=train_sampler, **loader_options),
    "val": DataLoader(valset, sampler=val_sampler, **loader_options),
}
train_loader, val_loader = dataloaders["train"], dataloaders["val"]
# dataset_sizes = {"train": len(train_indices), "val": len(val_indices)}

In [None]:
def count_labels(loader, n_classes):
    """Count how often each class label is yielded by ``loader``.

    Args:
        loader: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
            ``labels`` must be a 1-D array-like of non-negative ints.
        n_classes: number of classes; fixes the length of the result.

    Returns:
        float ndarray of length ``n_classes`` with per-class counts.

    Raises:
        ValueError: if a label is ``>= n_classes`` (the bincount is then longer
            than the accumulator and cannot be added to it).
    """
    counts = np.zeros(n_classes)
    for _, labels in loader:
        # minlength must equal n_classes (it was hard-coded to 20); otherwise the
        # += fails whenever the dataset does not have exactly 20 classes.
        counts += np.bincount(np.asarray(labels), minlength=n_classes)
    return counts


def print_label_distribution(counts):
    """Print raw counts, relative frequencies and the total of ``counts``."""
    total = counts.sum()
    print(counts)
    # An empty loader gives total == 0; skip the 0/0 division (NaNs + warning).
    print(counts / total if total > 0 else counts)
    print(total)


train_label_counts = count_labels(train_loader, trainset.n_classes)
print_label_distribution(train_label_counts)

val_label_counts = count_labels(val_loader, valset.n_classes)
print_label_distribution(val_label_counts)

In [None]:
def visualize_model(model, num_images=6):
    """Plot validation images titled with the model's predicted class.

    Runs ``model`` in eval mode under ``torch.no_grad()`` over
    ``dataloaders['val']`` until ``num_images`` images have been drawn or the
    loader is exhausted. Images are laid out in a two-column grid. The model's
    original train/eval mode is restored afterwards, even if plotting raises.

    Relies on the notebook globals ``dataloaders``, ``device``, ``class_names``
    and ``imshow``.

    Args:
        model: classifier whose output has class scores along dim 1
            (predictions are taken with ``torch.max(outputs, 1)``).
        num_images: number of images to draw; values < 1 draw nothing.
    """
    # Local import: matplotlib is otherwise only imported in a later cell.
    import matplotlib.pyplot as plt

    if num_images < 1:
        return

    was_training = model.training
    model.eval()
    # Round up so an odd num_images still fits in the 2-column grid
    # (num_images // 2 rows would overflow on the last image).
    n_rows = (num_images + 1) // 2
    images_so_far = 0
    plt.figure()

    try:
        with torch.no_grad():
            for inputs, _ in dataloaders['val']:
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)
                inputs_cpu = inputs.cpu()  # one host copy per batch, not per image

                for j in range(inputs_cpu.size(0)):
                    images_so_far += 1
                    ax = plt.subplot(n_rows, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title('predicted: {}'.format(class_names[preds[j].item()]))
                    imshow(inputs_cpu[j])

                    if images_so_far == num_images:
                        return
    finally:
        model.train(mode=was_training)

In [None]:
import matplotlib.pyplot as plt

# Confusion matrix of the transfer-learned model. The stray `conf_matrix_tr`
# expression that used to be here raised NameError and has been removed.
# NOTE(review): `conf_matrix_transfer` is not defined in any visible cell;
# presumably it is sklearn's confusion_matrix output (rows = true label,
# columns = predicted label). Confirm before trusting the axis labels.
fig, ax = plt.subplots()
image = ax.imshow(conf_matrix_transfer)
fig.colorbar(image, ax=ax)
ax.set(title="Confusion matrix (transfer)", xlabel="Predicted label", ylabel="True label");

In [1]:
from __future__ import print_function
import torch
import sys
import os

sys.path.append(os.path.expanduser("~/few-shot-learning/"))
from few_shot.transfer import main

In [2]:
# Fine-tune a pretrained ResNet-50 through few_shot.transfer.main.
# Arguments are collected in one dict so the run configuration reads as a table.
transfer_kwargs = dict(
    datadir='~/data',
    architecture='resnet50',
    num_workers=8,
    epochs=100,
    start_epoch=0,
    batch_size=64,
    learning_rate=1e-3,
    optimizer=torch.optim.Adam,
    print_freq=10,
    resume=False,
    evaluate=False,
    # NOTE(review): seed=None presumably leaves the run unseeded; set an int
    # for reproducible results.
    seed=None,
    gpu=0,
    device=None,
    # dtype=torch.float16
    distributed=True,
)
main(**transfer_kwargs)

=> using pre-trained model 'resnet50'
=> Running 100 epochs of fine-tuning (top20)
Epoch: [0][  0/282]	Loss 2.9906e+00 (avg: 2.9906e+00)	Acc@1   4.69 (avg:   4.69)	Acc@5  26.56 (avg:  26.56)	Time  8.209 (avg:  8.209)	Data  2.656 (avg:  2.656)
Epoch: [0][ 10/282]	Loss 7.6899e-01 (avg: 1.5784e+00)	Acc@1  76.56 (avg:  54.69)	Acc@5  95.31 (avg:  83.81)	Time  1.863 (avg:  2.431)	Data  0.000 (avg:  0.242)
Epoch: [0][ 20/282]	Loss 8.5276e-01 (avg: 1.2136e+00)	Acc@1  73.44 (avg:  64.06)	Acc@5  98.44 (avg:  90.92)	Time  1.855 (avg:  2.160)	Data  0.000 (avg:  0.127)
Epoch: [0][ 30/282]	Loss 6.4250e-01 (avg: 1.0028e+00)	Acc@1  76.56 (avg:  70.46)	Acc@5  98.44 (avg:  93.70)	Time  1.875 (avg:  2.065)	Data  0.000 (avg:  0.087)
Epoch: [0][ 40/282]	Loss 3.6353e-01 (avg: 8.7969e-01)	Acc@1  90.62 (avg:  73.93)	Acc@5  98.44 (avg:  95.08)	Time  1.865 (avg:  2.019)	Data  0.000 (avg:  0.066)
Epoch: [0][ 50/282]	Loss 4.9074e-01 (avg: 8.0449e-01)	Acc@1  87.50 (avg:  76.07)	Acc@5  98.44 (avg:  95.93)	Time  1.8

Epoch: [1][200/282]	Loss 1.1775e-01 (avg: 2.1660e-01)	Acc@1  98.44 (avg:  93.02)	Acc@5 100.00 (avg:  99.88)	Time  1.876 (avg:  1.885)	Data  0.000 (avg:  0.017)
Epoch: [1][210/282]	Loss 2.1442e-01 (avg: 2.1496e-01)	Acc@1  92.19 (avg:  93.10)	Acc@5 100.00 (avg:  99.87)	Time  1.871 (avg:  1.885)	Data  0.000 (avg:  0.016)
Epoch: [1][220/282]	Loss 2.0315e-01 (avg: 2.1686e-01)	Acc@1  93.75 (avg:  93.04)	Acc@5 100.00 (avg:  99.87)	Time  1.864 (avg:  1.884)	Data  0.000 (avg:  0.015)
Epoch: [1][230/282]	Loss 1.6213e-01 (avg: 2.1520e-01)	Acc@1  96.88 (avg:  93.13)	Acc@5 100.00 (avg:  99.88)	Time  1.861 (avg:  1.883)	Data  0.000 (avg:  0.015)
Epoch: [1][240/282]	Loss 1.9137e-01 (avg: 2.1342e-01)	Acc@1  92.19 (avg:  93.19)	Acc@5 100.00 (avg:  99.88)	Time  1.867 (avg:  1.883)	Data  0.000 (avg:  0.014)
Epoch: [1][250/282]	Loss 1.8515e-01 (avg: 2.1298e-01)	Acc@1  92.19 (avg:  93.17)	Acc@5 100.00 (avg:  99.88)	Time  1.859 (avg:  1.882)	Data  0.000 (avg:  0.014)
Epoch: [1][260/282]	Loss 2.4850e-01 (avg

Epoch: [3][100/282]	Loss 1.0951e-01 (avg: 1.5057e-01)	Acc@1  93.75 (avg:  95.10)	Acc@5 100.00 (avg:  99.95)	Time  1.888 (avg:  1.900)	Data  0.000 (avg:  0.030)
Epoch: [3][110/282]	Loss 1.5334e-01 (avg: 1.4849e-01)	Acc@1  95.31 (avg:  95.16)	Acc@5 100.00 (avg:  99.96)	Time  1.887 (avg:  1.897)	Data  0.000 (avg:  0.027)
Epoch: [3][120/282]	Loss 5.1227e-02 (avg: 1.4699e-01)	Acc@1 100.00 (avg:  95.20)	Acc@5 100.00 (avg:  99.96)	Time  1.872 (avg:  1.894)	Data  0.000 (avg:  0.025)
Epoch: [3][130/282]	Loss 9.6352e-02 (avg: 1.4581e-01)	Acc@1  95.31 (avg:  95.21)	Acc@5 100.00 (avg:  99.96)	Time  1.868 (avg:  1.892)	Data  0.000 (avg:  0.024)
Epoch: [3][140/282]	Loss 1.6042e-01 (avg: 1.4760e-01)	Acc@1  95.31 (avg:  95.15)	Acc@5 100.00 (avg:  99.96)	Time  1.870 (avg:  1.891)	Data  0.000 (avg:  0.022)
Epoch: [3][150/282]	Loss 1.9454e-01 (avg: 1.4797e-01)	Acc@1  95.31 (avg:  95.15)	Acc@5 100.00 (avg:  99.95)	Time  1.872 (avg:  1.890)	Data  0.000 (avg:  0.021)
Epoch: [3][160/282]	Loss 9.1155e-02 (avg

Test: [20/29]	Loss 1.9200e-01 (avg: 1.8660e-01)	Acc@1  92.19 (avg:  92.93)	Acc@5 100.00 (avg:  99.93)
 * Acc@1 92.611 Acc@5 99.889
Epoch: [5][  0/282]	Loss 2.7068e-02 (avg: 2.7068e-02)	Acc@1 100.00 (avg: 100.00)	Acc@5 100.00 (avg: 100.00)	Time  4.712 (avg:  4.712)	Data  2.851 (avg:  2.851)
Epoch: [5][ 10/282]	Loss 9.7999e-02 (avg: 1.2105e-01)	Acc@1  96.88 (avg:  96.88)	Acc@5 100.00 (avg: 100.00)	Time  1.873 (avg:  2.122)	Data  0.000 (avg:  0.261)
Epoch: [5][ 20/282]	Loss 1.1010e-01 (avg: 1.1415e-01)	Acc@1  98.44 (avg:  96.88)	Acc@5 100.00 (avg:  99.93)	Time  1.865 (avg:  2.002)	Data  0.000 (avg:  0.137)
Epoch: [5][ 30/282]	Loss 2.5010e-02 (avg: 1.1366e-01)	Acc@1 100.00 (avg:  96.82)	Acc@5 100.00 (avg:  99.95)	Time  1.874 (avg:  1.959)	Data  0.000 (avg:  0.093)
Epoch: [5][ 40/282]	Loss 5.7810e-02 (avg: 1.0930e-01)	Acc@1  98.44 (avg:  96.95)	Acc@5 100.00 (avg:  99.96)	Time  1.877 (avg:  1.938)	Data  0.000 (avg:  0.071)
Epoch: [5][ 50/282]	Loss 1.5856e-01 (avg: 1.0763e-01)	Acc@1  95.31 (a

Epoch: [6][200/282]	Loss 9.3172e-03 (avg: 5.9064e-02)	Acc@1 100.00 (avg:  98.26)	Acc@5 100.00 (avg:  99.99)	Time  1.858 (avg:  1.885)	Data  0.000 (avg:  0.016)
Epoch: [6][210/282]	Loss 9.3039e-02 (avg: 5.8451e-02)	Acc@1  95.31 (avg:  98.28)	Acc@5 100.00 (avg:  99.99)	Time  1.877 (avg:  1.885)	Data  0.000 (avg:  0.015)
Epoch: [6][220/282]	Loss 2.6254e-02 (avg: 5.7446e-02)	Acc@1  98.44 (avg:  98.30)	Acc@5 100.00 (avg:  99.99)	Time  1.873 (avg:  1.884)	Data  0.000 (avg:  0.015)
Epoch: [6][230/282]	Loss 4.8599e-02 (avg: 5.6898e-02)	Acc@1  96.88 (avg:  98.32)	Acc@5 100.00 (avg:  99.99)	Time  1.870 (avg:  1.883)	Data  0.000 (avg:  0.014)
Epoch: [6][240/282]	Loss 7.6089e-02 (avg: 5.6904e-02)	Acc@1  96.88 (avg:  98.32)	Acc@5 100.00 (avg:  99.99)	Time  1.872 (avg:  1.883)	Data  0.000 (avg:  0.014)
Epoch: [6][250/282]	Loss 4.0791e-02 (avg: 5.6552e-02)	Acc@1  96.88 (avg:  98.33)	Acc@5 100.00 (avg:  99.99)	Time  1.881 (avg:  1.882)	Data  0.000 (avg:  0.013)
Epoch: [6][260/282]	Loss 7.4171e-02 (avg

Epoch: [8][100/282]	Loss 2.2696e-02 (avg: 4.6984e-02)	Acc@1  98.44 (avg:  98.65)	Acc@5 100.00 (avg: 100.00)	Time  1.874 (avg:  1.901)	Data  0.000 (avg:  0.032)
Epoch: [8][110/282]	Loss 3.5098e-02 (avg: 4.8905e-02)	Acc@1  98.44 (avg:  98.52)	Acc@5 100.00 (avg: 100.00)	Time  1.867 (avg:  1.898)	Data  0.000 (avg:  0.029)
Epoch: [8][120/282]	Loss 1.7307e-02 (avg: 4.9331e-02)	Acc@1 100.00 (avg:  98.51)	Acc@5 100.00 (avg: 100.00)	Time  1.876 (avg:  1.896)	Data  0.000 (avg:  0.027)
Epoch: [8][130/282]	Loss 8.8901e-02 (avg: 4.8602e-02)	Acc@1  95.31 (avg:  98.52)	Acc@5 100.00 (avg: 100.00)	Time  1.873 (avg:  1.894)	Data  0.000 (avg:  0.025)
Epoch: [8][140/282]	Loss 8.8435e-02 (avg: 4.7720e-02)	Acc@1  96.88 (avg:  98.56)	Acc@5 100.00 (avg: 100.00)	Time  1.875 (avg:  1.892)	Data  0.000 (avg:  0.023)
Epoch: [8][150/282]	Loss 5.7923e-02 (avg: 4.6341e-02)	Acc@1  98.44 (avg:  98.61)	Acc@5 100.00 (avg: 100.00)	Time  1.867 (avg:  1.891)	Data  0.000 (avg:  0.022)
Epoch: [8][160/282]	Loss 3.2745e-02 (avg

Test: [20/29]	Loss 1.0872e-01 (avg: 2.2209e-01)	Acc@1  98.44 (avg:  94.20)	Acc@5 100.00 (avg:  99.78)
 * Acc@1 94.500 Acc@5 99.833
Epoch: [10][  0/282]	Loss 3.1891e-02 (avg: 3.1891e-02)	Acc@1  98.44 (avg:  98.44)	Acc@5 100.00 (avg: 100.00)	Time 12.540 (avg: 12.540)	Data 10.701 (avg: 10.701)
Epoch: [10][ 10/282]	Loss 1.9893e-02 (avg: 5.9865e-02)	Acc@1 100.00 (avg:  98.30)	Acc@5 100.00 (avg: 100.00)	Time  1.865 (avg:  2.837)	Data  0.000 (avg:  0.974)
Epoch: [10][ 20/282]	Loss 3.7427e-02 (avg: 5.3981e-02)	Acc@1  98.44 (avg:  98.36)	Acc@5 100.00 (avg: 100.00)	Time  1.869 (avg:  2.377)	Data  0.000 (avg:  0.511)
Epoch: [10][ 30/282]	Loss 2.5999e-02 (avg: 4.9844e-02)	Acc@1  98.44 (avg:  98.44)	Acc@5 100.00 (avg: 100.00)	Time  1.876 (avg:  2.214)	Data  0.000 (avg:  0.347)
Epoch: [10][ 40/282]	Loss 8.7287e-03 (avg: 4.4974e-02)	Acc@1 100.00 (avg:  98.59)	Acc@5 100.00 (avg: 100.00)	Time  1.877 (avg:  2.130)	Data  0.000 (avg:  0.262)
Epoch: [10][ 50/282]	Loss 1.6286e-02 (avg: 4.3971e-02)	Acc@1 100

Epoch: [11][190/282]	Loss 3.1639e-03 (avg: 1.5077e-02)	Acc@1 100.00 (avg:  99.63)	Acc@5 100.00 (avg: 100.00)	Time  1.866 (avg:  1.886)	Data  0.000 (avg:  0.018)
Epoch: [11][200/282]	Loss 2.9570e-03 (avg: 1.5026e-02)	Acc@1 100.00 (avg:  99.63)	Acc@5 100.00 (avg: 100.00)	Time  1.882 (avg:  1.886)	Data  0.000 (avg:  0.017)
Epoch: [11][210/282]	Loss 5.3847e-02 (avg: 1.5011e-02)	Acc@1  98.44 (avg:  99.64)	Acc@5 100.00 (avg: 100.00)	Time  1.880 (avg:  1.885)	Data  0.000 (avg:  0.016)
Epoch: [11][220/282]	Loss 8.4640e-03 (avg: 1.4566e-02)	Acc@1 100.00 (avg:  99.65)	Acc@5 100.00 (avg: 100.00)	Time  1.869 (avg:  1.884)	Data  0.000 (avg:  0.016)
Epoch: [11][230/282]	Loss 4.2741e-03 (avg: 1.4327e-02)	Acc@1 100.00 (avg:  99.67)	Acc@5 100.00 (avg: 100.00)	Time  1.867 (avg:  1.884)	Data  0.000 (avg:  0.015)
Epoch: [11][240/282]	Loss 8.8226e-03 (avg: 1.3886e-02)	Acc@1 100.00 (avg:  99.68)	Acc@5 100.00 (avg: 100.00)	Time  1.858 (avg:  1.883)	Data  0.000 (avg:  0.015)
Epoch: [11][250/282]	Loss 6.8424e-

Epoch: [13][ 80/282]	Loss 2.1968e-03 (avg: 1.4657e-02)	Acc@1 100.00 (avg:  99.61)	Acc@5 100.00 (avg: 100.00)	Time  1.863 (avg:  1.910)	Data  0.000 (avg:  0.041)
Epoch: [13][ 90/282]	Loss 6.9413e-03 (avg: 1.4282e-02)	Acc@1 100.00 (avg:  99.61)	Acc@5 100.00 (avg: 100.00)	Time  1.875 (avg:  1.906)	Data  0.000 (avg:  0.037)
Epoch: [13][100/282]	Loss 2.8960e-03 (avg: 1.3500e-02)	Acc@1 100.00 (avg:  99.64)	Acc@5 100.00 (avg: 100.00)	Time  1.871 (avg:  1.902)	Data  0.000 (avg:  0.033)
Epoch: [13][110/282]	Loss 4.0861e-02 (avg: 1.3043e-02)	Acc@1  98.44 (avg:  99.66)	Acc@5 100.00 (avg: 100.00)	Time  1.874 (avg:  1.899)	Data  0.000 (avg:  0.031)
Epoch: [13][120/282]	Loss 1.0729e-02 (avg: 1.3548e-02)	Acc@1 100.00 (avg:  99.66)	Acc@5 100.00 (avg: 100.00)	Time  1.866 (avg:  1.897)	Data  0.000 (avg:  0.028)
Epoch: [13][130/282]	Loss 3.8884e-03 (avg: 1.4379e-02)	Acc@1 100.00 (avg:  99.65)	Acc@5 100.00 (avg: 100.00)	Time  1.861 (avg:  1.895)	Data  0.000 (avg:  0.026)
Epoch: [13][140/282]	Loss 2.1020e-

KeyboardInterrupt: 