In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.insert(0, '../')

import fewshot.data
import fewshot.trainer
import fewshot.focal

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision import models

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Training and results on large dataset

This notebook contains the training and results for the fashion dataset, as well as all the different "tricks" to improve classification accuracy on the rare classes.

NB: I did a little bit of hyperparameter optimization "offline" - mostly around the learning rate, the gamma of the focal loss, and the appropriate number of epochs.

In [3]:
# Produce `styles_quoted.csv` from the raw `styles.csv`.
# NOTE(review): presumably this repairs quoting so the CSV parses cleanly —
# confirm against fewshot.data.fix_quotes.
fewshot.data.fix_quotes(
    '../data/fashion-dataset/styles.csv',
    '../data/fashion-dataset/styles_quoted.csv',
)

In [4]:
# Training pipeline: resize (an int size scales the shorter image side to 300),
# take a random 224x224 crop, randomly mirror horizontally, convert to a tensor
# and normalise with the channel statistics used by torchvision's pretrained models.
train_transform = transforms.Compose(
    [
        transforms.Resize(300),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Evaluation pipeline: same resize and normalisation, but a deterministic centre
# crop and no flipping, so repeated evaluations see identical inputs.
test_transform = transforms.Compose(
    [
        transforms.Resize(300),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [8]:
# Dataset used for the loss comparison below (top20=True).
# NOTE(review): this reads from '../data/small/', whereas fix_quotes above wrote
# its output under '../data/fashion-dataset/' — confirm that the quoted CSV exists
# in '../data/small/' and that the small dataset is the intended one here.
data = fewshot.data.FashionData(
    '../data/small/styles_quoted.csv',
    '../data/small/images/',
    train_transform=train_transform,
    test_transform=test_transform,
    top20=True,
)

In [9]:
# Start from a pretrained ResNet-50 backbone.
model = models.resnet50(pretrained=True)

# Width of the features feeding the original classifier head.
num_ftrs = model.fc.in_features

# Swap the stock head for a freshly initialised linear layer with one output
# per class in `data`.
model.fc = nn.Linear(in_features=num_ftrs, out_features=data.n_classes)

# Which loss performs better on top20?
### Plain Cross-Entropy:

In [11]:
# Baseline run: unweighted cross-entropy optimised with Adam (lr = 1e-3).
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# The positional 20 is presumably the number of epochs (the log below counts
# epochs) — TODO confirm. Checkpoints are later read back from under 'traces/'.
fewshot.trainer.run_training(
    model,
    20,
    data,
    optimizer,
    loss_func,
    gpu_id=0,
    root='./traces/',
)

------------- epoch:  0
Training loss: 0.732
Validation loss: 1.064
Validation accuracy: 0.737
------------- epoch:  1
Training loss: 0.450
Validation loss: 0.975
Validation accuracy: 0.801
------------- epoch:  2
Training loss: 0.386
Validation loss: 0.976
Validation accuracy: 0.795
------------- epoch:  3
Training loss: 0.349
Validation loss: 1.057
Validation accuracy: 0.816
------------- epoch:  4
Training loss: 0.320
Validation loss: 0.821
Validation accuracy: 0.817
------------- epoch:  5
Training loss: 0.281
Validation loss: 1.153
Validation accuracy: 0.792
------------- epoch:  6
Training loss: 0.284
Validation loss: 0.853
Validation accuracy: 0.840
------------- epoch:  7
Training loss: 0.259
Validation loss: 0.959
Validation accuracy: 0.837
------------- epoch:  8
Training loss: 0.241
Validation loss: 0.900
Validation accuracy: 0.842
------------- epoch:  9
Training loss: 0.243
Validation loss: 0.930
Validation accuracy: 0.832
------------- epoch:  10
Training loss: 0.225
Vali

In [15]:
# Restore the weights saved at epoch 18 of the cross-entropy run and print the
# per-category Top-1 / Top-5 table.
# NOTE(review): the run folder name looks timestamped, so re-running the training
# cell presumably writes elsewhere — update this path after retraining.
best_checkpoint = torch.load('traces/gpu_0_Jun09_0014/model_epoch_18.pth')
model.load_state_dict(best_checkpoint['model_state_dict'])
# The positional 0 is presumably the GPU id (training used gpu_id=0); c1 / c5 / occ
# presumably hold top-1 hits, top-5 hits and per-class occurrences — TODO confirm.
c1, c5, occ = fewshot.trainer.evaluate(model.cuda(), data, 0, data.n_classes)
fewshot.trainer.pretty_print_eval(c1, c5, occ, data)

Category                    Top-1     Top-5
-------------------------------------------
Tshirts                      89.1      99.2
Shirts                       98.3     100.0
Casual Shoes                 86.6      98.3
Watches                     100.0     100.0
Sports Shoes                 73.9     100.0
Kurtas                       94.9      99.1
Tops                         66.7     100.0
Handbags                     95.0      99.2
Heels                        89.7      99.1
Sunglasses                  100.0     100.0
Wallets                      95.7      99.1
Flip Flops                   87.6      96.5
Sandals                      78.4     100.0
Briefs                       95.5     100.0
Belts                        97.3      99.1
Backpacks                    84.2      99.1
Socks                        77.3      93.8
Formal Shoes                 77.8      99.1
Perfume and Body Mist         0.0       0.0
Jeans                        99.1     100.0
Average                      84.

### Focal Loss 

In [16]:
# Fresh pretrained backbone with a new classification head, so this run does not
# inherit weights from the previous experiment.
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=data.n_classes)

# Focal loss with focusing parameter gamma = 2, no per-class weighting.
loss_func = fewshot.focal.FocalLoss(gamma=2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Same training settings as the cross-entropy baseline.
fewshot.trainer.run_training(
    model,
    20,
    data,
    optimizer,
    loss_func,
    gpu_id=0,
    root='./traces/',
)

------------- epoch:  0


  logpt = F.log_softmax(input)


Training loss: 0.463
Validation loss: 0.766
Validation accuracy: 0.690
------------- epoch:  1
Training loss: 0.274
Validation loss: 0.628
Validation accuracy: 0.772
------------- epoch:  2
Training loss: 0.209
Validation loss: 0.604
Validation accuracy: 0.779
------------- epoch:  3
Training loss: 0.173
Validation loss: 0.640
Validation accuracy: 0.788
------------- epoch:  4
Training loss: 0.159
Validation loss: 0.620
Validation accuracy: 0.774
------------- epoch:  5
Training loss: 0.155
Validation loss: 0.636
Validation accuracy: 0.791
------------- epoch:  6
Training loss: 0.133
Validation loss: 0.528
Validation accuracy: 0.831
------------- epoch:  7
Training loss: 0.133
Validation loss: 0.625
Validation accuracy: 0.810
------------- epoch:  8
Training loss: 0.117
Validation loss: 0.631
Validation accuracy: 0.814
------------- epoch:  9
Training loss: 0.127
Validation loss: 0.581
Validation accuracy: 0.833
------------- epoch:  10
Training loss: 0.106
Validation loss: 0.588
Valid

In [17]:
# Restore the weights saved at epoch 17 of the focal-loss run and print the
# per-category Top-1 / Top-5 table.
# NOTE(review): hard-coded run folder — update after retraining.
best_checkpoint = torch.load('traces/gpu_0_Jun09_0053/model_epoch_17.pth')
model.load_state_dict(best_checkpoint['model_state_dict'])
# The positional 0 is presumably the GPU id — TODO confirm against fewshot.trainer.
c1, c5, occ = fewshot.trainer.evaluate(model.cuda(), data, 0, data.n_classes)
fewshot.trainer.pretty_print_eval(c1, c5, occ, data)

Category                    Top-1     Top-5
-------------------------------------------
Tshirts                      95.0     100.0
Shirts                       98.3     100.0
Casual Shoes                 84.0      99.2
Watches                     100.0     100.0
Sports Shoes                 89.1     100.0
Kurtas                       96.6      99.1
Tops                         48.7      99.1
Handbags                     84.9      99.2
Heels                        93.2     100.0
Sunglasses                  100.0     100.0
Wallets                      96.6     100.0
Flip Flops                   83.2      98.2
Sandals                      69.0     100.0
Briefs                      100.0     100.0
Belts                        99.1     100.0
Backpacks                    97.4     100.0
Socks                        87.6     100.0
Formal Shoes                 74.1     100.0
Perfume and Body Mist         0.0       0.0
Jeans                        99.1     100.0
Average                      84.

### Weighted Cross Entropy

In [18]:
# Fresh pretrained backbone with a new classification head for this experiment.
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=data.n_classes)

# Cross-entropy with per-class weights taken from the dataset object; the weight
# tensor is cast to float32 and placed on the GPU before being given to the loss.
loss_func = nn.CrossEntropyLoss(
    weight=torch.tensor(data.weights).float().cuda()
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

fewshot.trainer.run_training(
    model,
    20,
    data,
    optimizer,
    loss_func,
    gpu_id=0,
    root='./traces/',
)

------------- epoch:  0
Training loss: 0.913
Validation loss: 0.793
Validation accuracy: 0.728
------------- epoch:  1
Training loss: 0.545
Validation loss: 0.656
Validation accuracy: 0.698
------------- epoch:  2
Training loss: 0.446
Validation loss: 0.570
Validation accuracy: 0.767
------------- epoch:  3
Training loss: 0.387
Validation loss: 0.699
Validation accuracy: 0.787
------------- epoch:  4
Training loss: 0.370
Validation loss: 0.493
Validation accuracy: 0.808
------------- epoch:  5
Training loss: 0.312
Validation loss: 0.489
Validation accuracy: 0.781
------------- epoch:  6
Training loss: 0.315
Validation loss: 0.514
Validation accuracy: 0.810
------------- epoch:  7
Training loss: 0.306
Validation loss: 0.506
Validation accuracy: 0.802
------------- epoch:  8
Training loss: 0.271
Validation loss: 0.474
Validation accuracy: 0.808
------------- epoch:  9
Training loss: 0.270
Validation loss: 0.524
Validation accuracy: 0.806
------------- epoch:  10
Training loss: 0.255
Vali

In [19]:
# Restore the weights saved at epoch 15 of the weighted cross-entropy run and
# print the per-category Top-1 / Top-5 table.
# NOTE(review): hard-coded run folder — update after retraining.
best_checkpoint = torch.load('traces/gpu_0_Jun09_0136/model_epoch_15.pth')
model.load_state_dict(best_checkpoint['model_state_dict'])
# The positional 0 is presumably the GPU id — TODO confirm against fewshot.trainer.
c1, c5, occ = fewshot.trainer.evaluate(model.cuda(), data, 0, data.n_classes)
fewshot.trainer.pretty_print_eval(c1, c5, occ, data)

Category                    Top-1     Top-5
-------------------------------------------
Tshirts                      94.1     100.0
Shirts                       97.5     100.0
Casual Shoes                 78.2      98.3
Watches                     100.0     100.0
Sports Shoes                 88.2     100.0
Kurtas                       93.2      99.1
Tops                         61.5      99.1
Handbags                     93.3      98.3
Heels                        81.2      99.1
Sunglasses                  100.0     100.0
Wallets                      94.0     100.0
Flip Flops                   95.6      97.3
Sandals                      69.0     100.0
Briefs                       93.3     100.0
Belts                        99.1      99.1
Backpacks                    87.7      98.2
Socks                        86.6      94.8
Formal Shoes                 83.3      99.1
Perfume and Body Mist         0.0       0.0
Jeans                        99.1     100.0
Average                      84.

### Weighted Focal Loss

In [20]:
# Fresh pretrained backbone with a new classification head for this experiment.
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=data.n_classes)

# Rescale the dataset's class weights so that they sum to n_classes.
normalized_weights = data.n_classes * data.weights / sum(data.weights)

# Focal loss with gamma = 0.5 and the rescaled class weights passed as alpha.
loss_func = fewshot.focal.FocalLoss(
    gamma=0.5,
    alpha=torch.tensor(normalized_weights).float().cuda(),
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# This run uses 25 (rather than 20) as the second argument to run_training.
fewshot.trainer.run_training(
    model,
    25,
    data,
    optimizer,
    loss_func,
    gpu_id=0,
    root='./traces/',
)

------------- epoch:  0


  logpt = F.log_softmax(input)


Training loss: 0.575
Validation loss: 0.498
Validation accuracy: 0.682
------------- epoch:  1
Training loss: 0.322
Validation loss: 0.403
Validation accuracy: 0.736
------------- epoch:  2
Training loss: 0.254
Validation loss: 0.321
Validation accuracy: 0.748
------------- epoch:  3
Training loss: 0.229
Validation loss: 0.272
Validation accuracy: 0.803
------------- epoch:  4
Training loss: 0.220
Validation loss: 0.242
Validation accuracy: 0.822
------------- epoch:  5
Training loss: 0.194
Validation loss: 0.296
Validation accuracy: 0.800
------------- epoch:  6
Training loss: 0.179
Validation loss: 0.255
Validation accuracy: 0.792
------------- epoch:  7
Training loss: 0.166
Validation loss: 0.239
Validation accuracy: 0.824
------------- epoch:  8
Training loss: 0.151
Validation loss: 0.276
Validation accuracy: 0.785
------------- epoch:  9
Training loss: 0.155
Validation loss: 0.417
Validation accuracy: 0.748
------------- epoch:  10
Training loss: 0.161
Validation loss: 0.258
Valid

In [26]:
# Restore the weights saved at epoch 19 of the weighted focal-loss run and print
# the per-category Top-1 / Top-5 table. The same checkpoint is reused further down
# as the starting point for finetuning.
# NOTE(review): hard-coded run folder — update after retraining.
best_checkpoint = torch.load('traces/gpu_0_Jun09_0231/model_epoch_19.pth')
model.load_state_dict(best_checkpoint['model_state_dict'])
# The positional 0 is presumably the GPU id — TODO confirm against fewshot.trainer.
c1, c5, occ = fewshot.trainer.evaluate(model.cuda(), data, 0, data.n_classes)
fewshot.trainer.pretty_print_eval(c1, c5, occ, data)

Category                    Top-1     Top-5
-------------------------------------------
Tshirts                      87.4     100.0
Shirts                       99.2     100.0
Casual Shoes                 81.5      99.2
Watches                     100.0     100.0
Sports Shoes                 82.4      98.3
Kurtas                       88.0     100.0
Tops                         76.1     100.0
Handbags                     84.9      99.2
Heels                        74.4      98.3
Sunglasses                  100.0     100.0
Wallets                      95.7     100.0
Flip Flops                   96.5      97.3
Sandals                      61.2     100.0
Briefs                       95.5     100.0
Belts                       100.0     100.0
Backpacks                    97.4     100.0
Socks                        90.7      99.0
Formal Shoes                 95.4      98.1
Perfume and Body Mist         0.0       0.0
Jeans                        98.1      99.1
Average                      85.

We find that the *Weighted Focal Loss* performs best here. Although the difference between the models could still be due to noise, we're going to run with it and use this model to finetune on the rare classes.

## Finetuning on the rare classes

In [27]:
# Switch to the full fashion dataset with top20=False.
# NOTE(review): presumably top20=False selects the classes outside the 20 most
# frequent ones (the "rare" classes) — confirm against fewshot.data.FashionData.
data = fewshot.data.FashionData(
    '../data/fashion-dataset/styles_quoted.csv',
    '../data/fashion-dataset/images/',
    train_transform=train_transform,
    test_transform=test_transform,
    top20=False,
)

In [34]:
# Build the finetuning model: weighted-focal-loss backbone + a new head for the
# classes in the current `data`.
#
# The architecture is rebuilt here instead of reusing whatever `model` is left in
# the kernel. load_state_dict requires matching parameter shapes, and the last line
# of this cell replaces `fc` with a head of a different size — so loading into the
# existing `model` only worked on the first execution and failed with a shape
# mismatch when the cell was re-run.
best_checkpoint = torch.load('traces/gpu_0_Jun09_0231/model_epoch_19.pth')

# pretrained=False: every weight is overwritten by the checkpoint right below.
model = models.resnet50(pretrained=False)
num_ftrs = model.fc.in_features
# Size the temporary head from the checkpoint itself so the shapes always agree.
model.fc = nn.Linear(num_ftrs, best_checkpoint['model_state_dict']['fc.weight'].shape[0])
model.load_state_dict(best_checkpoint['model_state_dict'])

## replacing the last layer with the dense layer we need
model.fc = nn.Linear(num_ftrs, data.n_classes)

In [35]:
# Rescale the class weights of the current dataset so that they sum to n_classes.
normalized_weights = data.n_classes * data.weights / sum(data.weights)

# Same loss configuration that was selected above: focal loss, gamma = 0.5,
# class weights passed as alpha.
loss_func = fewshot.focal.FocalLoss(
    gamma=0.5,
    alpha=torch.tensor(normalized_weights).float().cuda(),
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Finetune; 35 is passed as the second argument to run_training for this run.
fewshot.trainer.run_training(
    model,
    35,
    data,
    optimizer,
    loss_func,
    gpu_id=0,
    root='./traces/',
)

------------- epoch:  0


  logpt = F.log_softmax(input)


Training loss: 0.773
Validation loss: 0.778
Validation accuracy: 0.184
------------- epoch:  1
Training loss: 0.499
Validation loss: 0.632
Validation accuracy: 0.235
------------- epoch:  2
Training loss: 0.367
Validation loss: 0.577
Validation accuracy: 0.305
------------- epoch:  3
Training loss: 0.304
Validation loss: 0.561
Validation accuracy: 0.312
------------- epoch:  4
Training loss: 0.249
Validation loss: 0.531
Validation accuracy: 0.308
------------- epoch:  5
Training loss: 0.210
Validation loss: 0.573
Validation accuracy: 0.249
------------- epoch:  6
Training loss: 0.203
Validation loss: 0.641
Validation accuracy: 0.241
------------- epoch:  7
Training loss: 0.197
Validation loss: 0.615
Validation accuracy: 0.301
------------- epoch:  8
Training loss: 0.180
Validation loss: 0.590
Validation accuracy: 0.322
------------- epoch:  9
Training loss: 0.135
Validation loss: 0.610
Validation accuracy: 0.396
------------- epoch:  10
Training loss: 0.121
Validation loss: 0.632
Valid

In [39]:
# Restore the weights saved at epoch 34 of the finetuning run and print the
# per-category Top-1 / Top-5 table.
# NOTE(review): hard-coded run folder — update after retraining.
best_checkpoint = torch.load('traces/gpu_0_Jun09_1537/model_epoch_34.pth')
model.load_state_dict(best_checkpoint['model_state_dict'])
# The positional 0 is presumably the GPU id — TODO confirm against fewshot.trainer.
c1, c5, occ = fewshot.trainer.evaluate(model.cuda(), data, 0, data.n_classes)
fewshot.trainer.pretty_print_eval(c1, c5, occ, data)

Category                    Top-1     Top-5
-------------------------------------------
Jeans                        89.3     100.0
Trousers                     60.7     100.0
Flats                        85.7     100.0
Bra                          92.9      96.4
Dresses                      28.0      68.0
Earrings                    100.0     100.0
Track Pants                  77.8     100.0
Deodorant                     0.0       0.0
Nail Polish                   0.0       0.0
Sweatshirts                  67.9      92.9
Clutches                     75.0     100.0
Innerwear Vests               0.0     100.0
Lipstick                      0.0       0.0
Sweaters                     39.3      78.6
Jackets                      44.0      88.0
Ties                         95.2      95.2
Caps                         85.2     100.0
Kurtis                       40.0      92.0
Tunics                       27.8     100.0
Capris                       55.0      85.0
Pendant                      47.

## Compare this to not finetuning:

In [40]:
# Re-create the dataset object for the "no finetuning" baseline. The arguments
# are identical to the ones used for the finetuning experiment.
data = fewshot.data.FashionData(
    '../data/fashion-dataset/styles_quoted.csv',
    '../data/fashion-dataset/images/',
    train_transform=train_transform,
    test_transform=test_transform,
    top20=False,
)

In [42]:
# Baseline without finetuning: start directly from the pretrained ResNet-50
# weights (no checkpoint from the earlier runs) and attach a new head sized for
# the classes in `data`.
model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(in_features=num_ftrs, out_features=data.n_classes)

In [43]:
# Same loss / optimiser configuration as the finetuning run, so the two
# experiments differ only in how the model was initialised.
normalized_weights = data.n_classes * data.weights / sum(data.weights)
loss_func = fewshot.focal.FocalLoss(
    gamma=0.5,
    alpha=torch.tensor(normalized_weights).float().cuda(),
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [44]:
# Train the non-finetuned baseline with the same settings (35, GPU 0, traces
# written under './traces/') as the finetuning run.
fewshot.trainer.run_training(
    model,
    35,
    data,
    optimizer,
    loss_func,
    gpu_id=0,
    root='./traces/',
)

------------- epoch:  0


  logpt = F.log_softmax(input)


Training loss: 0.900
Validation loss: 17.175
Validation accuracy: 0.010
------------- epoch:  1
Training loss: 0.912
Validation loss: 1.157
Validation accuracy: 0.003
------------- epoch:  2
Training loss: 0.826
Validation loss: 0.951
Validation accuracy: 0.050
------------- epoch:  3
Training loss: 0.788
Validation loss: 0.937
Validation accuracy: 0.005
------------- epoch:  4
Training loss: 0.756
Validation loss: 0.939
Validation accuracy: 0.005
------------- epoch:  5
Training loss: 0.743
Validation loss: 0.942
Validation accuracy: 0.045
------------- epoch:  6
Training loss: 0.727
Validation loss: 1.918
Validation accuracy: 0.036
------------- epoch:  7
Training loss: 0.692
Validation loss: 0.878
Validation accuracy: 0.045
------------- epoch:  8
Training loss: 0.628
Validation loss: 0.829
Validation accuracy: 0.109
------------- epoch:  9
Training loss: 0.615
Validation loss: 0.883
Validation accuracy: 0.087
------------- epoch:  10
Training loss: 0.577
Validation loss: 0.946
Vali

In [45]:
# Restore the weights saved at epoch 34 of the non-finetuned baseline run and
# print the per-category Top-1 / Top-5 table.
# NOTE(review): hard-coded run folder — update after retraining.
best_checkpoint = torch.load('traces/gpu_0_Jun10_1108/model_epoch_34.pth')
model.load_state_dict(best_checkpoint['model_state_dict'])
# The positional 0 is presumably the GPU id — TODO confirm against fewshot.trainer.
c1, c5, occ = fewshot.trainer.evaluate(model.cuda(), data, 0, data.n_classes)
fewshot.trainer.pretty_print_eval(c1, c5, occ, data)

Category                    Top-1     Top-5
-------------------------------------------
Jeans                        64.3      92.9
Trousers                     78.6      96.4
Flats                        53.6      96.4
Bra                          57.1      96.4
Dresses                      36.0      64.0
Earrings                     61.1     100.0
Track Pants                  59.3      88.9
Deodorant                     0.0       0.0
Nail Polish                   0.0       0.0
Sweatshirts                  25.0      89.3
Clutches                     75.0      83.3
Innerwear Vests               0.0       0.0
Lipstick                      0.0       0.0
Sweaters                     21.4      71.4
Jackets                      12.0      68.0
Ties                         90.5      95.2
Caps                         59.3      81.5
Kurtis                        4.0      80.0
Tunics                       33.3      88.9
Capris                       10.0      70.0
Pendant                      47.

In [98]:
# Exploration: display the label table held by the dataset object.
data.label_data

Unnamed: 0,id,gender,masterCategory,subCategory,articleType,baseColour,season,year,usage,productDisplayName
1,39386,Men,Apparel,Bottomwear,Jeans,Blue,Summer,2012.0,Casual,Peter England Men Party Blue Jeans
18,51832,Women,Apparel,Innerwear,Bra,Beige,Summer,2016.0,Casual,Bwitch Beige Full-Coverage Bra BW335
28,56019,Women,Personal Care,Lips,Lipstick,Brown,Spring,2017.0,Casual,Colorbar Soft Touch Show Stopper Copper Lipsti...
34,59051,Women,Footwear,Shoes,Flats,Black,Winter,2012.0,Casual,Carlton London Women Black & Gold Toned Flats
41,2886,Women,Footwear,Shoes,Flats,Brown,Winter,2015.0,Casual,Catwalk Women Leather Brown Flats
44,8580,Men,Apparel,Topwear,Waistcoat,Grey,Fall,2011.0,Casual,Scullers Men Grey Waistcoat
60,59607,Women,Apparel,Saree,Sarees,Grey,Fall,2012.0,Ethnic,FNF Pink & Grey Wedding Collection Sari
74,23876,Men,Apparel,Topwear,Sweatshirts,Blue,Fall,2011.0,Casual,ADIDAS Men Blue Sweatshirt
83,17885,Men,Apparel,Innerwear,Innerwear Vests,Grey,Summer,2016.0,Casual,Levis Men Comfort Style Grey Innerwear Vest
85,48781,Women,Accessories,Jewellery,Pendant,Silver,Summer,2012.0,Casual,Lucera Women Silver Pendant


In [106]:
# Draw 20 distinct article types from the training labels to form one "episode".
# A locally seeded generator is used (instead of the unseeded global np.random
# state) so that Restart & Run All selects the same classes every time; change
# the seed to draw a different episode.
episode_classes = np.random.RandomState(0).choice(
    data.train_ds.label_data.articleType.unique(),
    size=20,
    replace=False,
)


In [107]:
# Show the article types sampled for the episode.
episode_classes

array(['Tunics', 'Bracelet', 'Mufflers', 'Rain Jacket', 'Night suits',
       'Ipad', 'Suspenders', 'Sweaters', 'Skirts', 'Caps', 'Key chain',
       'Stoles', 'Ring', 'Mobile Pouch', 'Clutches', 'Trolley Bag',
       'Pendant', 'Track Pants', 'Salwar and Dupatta', 'Duffel Bag'],
      dtype=object)

In [112]:
# Index labels of the training rows whose articleType is one of the sampled
# episode classes.
data.train_ds.label_data[data.train_ds.label_data.articleType.isin(episode_classes)].index.values

array([   5,    7,   10,   15,   23,   27,   28,   31,   42,   48,   49,
         55,   56,   58,   60,   78,   87,   89,   95,   96,  105,  112,
        113,  114,  122,  127,  130,  132,  135,  137,  138,  147,  154,
        155,  156,  162,  170,  173,  176,  182,  184,  189,  199,  201,
        204,  207,  208,  211,  223,  227,  228,  232,  233,  235,  240,
        248,  249,  252,  256,  258,  259,  261,  262,  264,  267,  268,
        269,  271,  275,  284,  287,  294,  298,  300,  301,  304,  308,
        313,  317,  319,  323,  326,  331,  332,  336,  337,  338,  341,
        342,  345,  353,  361,  365,  366,  368,  373,  386,  391,  393,
        398,  409,  413,  415,  424,  425,  426,  427,  431,  441,  444,
        452,  457,  470,  482,  487,  491,  492,  501,  508,  509,  512,
        513,  516,  522,  524,  530,  535,  537,  539,  540,  551,  556,
        557,  561,  565,  571,  573,  578,  586,  591,  598,  601,  603,
        617,  626,  627,  644,  645,  648,  649,  6

In [113]:
# Scratch check: a 2x3 tensor of zeros.
torch.zeros((2, 3))

tensor([[0., 0., 0.],
        [0., 0., 0.]])

In [117]:
# Shape of the first element of the first training sample (the displayed result
# matches the 224x224 crop produced by the transforms: 3 x 224 x 224).
data.train_ds[0][0].shape

torch.Size([3, 224, 224])

In [124]:
# Number of training rows per article type, most frequent first.
counts = data.train_ds.label_data['articleType'].value_counts()

In [125]:
# Display the per-class training counts.
counts

Sarees                 292
Dresses                227
Jeans                  221
Flats                  201
Earrings               192
Innerwear Vests        186
Trousers               177
Clutches               139
Ties                   115
Trunk                  110
Dupatta                100
Track Pants             98
Tunics                  96
Bra                     92
Caps                    84
Necklace and Chains     81
Capris                  79
Leggings                76
Kurtis                  75
Stoles                  68
Scarves                 59
Jackets                 55
Pendant                 55
Free Gifts              51
Skirts                  50
Night suits             47
Boxers                  41
Bangle                  41
Ring                    41
Suspenders              40
                      ... 
Rain Jacket              8
Rompers                  8
Waistcoat                8
Shapewear                8
Tracksuits               7
Water Bottle             7
M