In [14]:
import os
import torch
import torchvision
import numpy as np
import pandas as pd
from PIL import Image
from torch import optim
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

%matplotlib inline

# Explore and clean the data
In this section we explore the data and prepare it for image classification.

In [15]:
# Load the product metadata from styles.csv into a DataFrame.
# Some rows have more fields than the 10-column header; they are skipped
# (with a warning per line) instead of aborting the whole read.
# The dataset root can be overridden with the FASHION_DATASET_PATH env var.
DATASET_PATH = os.environ.get('FASHION_DATASET_PATH', '/mnt/fashion-dataset-1/fashion-dataset/')
STYLES_CSV = os.path.join(DATASET_PATH, "styles.csv")
try:
    # pandas >= 1.3; `error_bad_lines` is deprecated there and removed in 2.0
    styles = pd.read_csv(STYLES_CSV, on_bad_lines='warn')
except TypeError:
    # older pandas that does not know `on_bad_lines`
    styles = pd.read_csv(STYLES_CSV, error_bad_lines=False)

b'Skipping line 6044: expected 10 fields, saw 11\nSkipping line 6569: expected 10 fields, saw 11\nSkipping line 7399: expected 10 fields, saw 11\nSkipping line 7939: expected 10 fields, saw 11\nSkipping line 9026: expected 10 fields, saw 11\nSkipping line 10264: expected 10 fields, saw 11\nSkipping line 10427: expected 10 fields, saw 11\nSkipping line 10905: expected 10 fields, saw 11\nSkipping line 11373: expected 10 fields, saw 11\nSkipping line 11945: expected 10 fields, saw 11\nSkipping line 14112: expected 10 fields, saw 11\nSkipping line 14532: expected 10 fields, saw 11\nSkipping line 15076: expected 10 fields, saw 12\nSkipping line 29906: expected 10 fields, saw 11\nSkipping line 31625: expected 10 fields, saw 11\nSkipping line 33020: expected 10 fields, saw 11\nSkipping line 35748: expected 10 fields, saw 11\nSkipping line 35962: expected 10 fields, saw 11\nSkipping line 37770: expected 10 fields, saw 11\nSkipping line 38105: expected 10 fields, saw 11\nSkipping line 38275: ex

In [16]:
# Preview the first rows (rich HTML display instead of plain-text print).
styles.head()

      id gender masterCategory subCategory  articleType baseColour  season  \
0  15970    Men        Apparel     Topwear       Shirts  Navy Blue    Fall   
1  39386    Men        Apparel  Bottomwear        Jeans       Blue  Summer   
2  59263  Women    Accessories     Watches      Watches     Silver  Winter   
3  21379    Men        Apparel  Bottomwear  Track Pants      Black    Fall   
4  53759    Men        Apparel     Topwear      Tshirts       Grey  Summer   

     year   usage                             productDisplayName  
0  2011.0  Casual               Turtle Check Men Navy Blue Shirt  
1  2012.0  Casual             Peter England Men Party Blue Jeans  
2  2016.0  Casual                       Titan Women Silver Watch  
3  2011.0  Casual  Manchester United Men Solid Black Track Pants  
4  2012.0  Casual                          Puma Men Grey T-shirt  


In [17]:
# Number of metadata rows that were parsed successfully.
len(styles)

44424


In [18]:
# Names of the image files available on disk.
# os.path.join avoids the double slash from concatenating onto DATASET_PATH,
# which already ends with '/'.
imgs_available = os.listdir(os.path.join(DATASET_PATH, 'images'))
len(imgs_available)

44442


Next, we check whether each entry in `styles.csv` has a corresponding image file; entries without one are removed from the dataframe.

In [19]:
# Keep only the styles.csv rows whose image file exists on disk; print the
# path of every missing image.
# `missing_img` holds index *labels*, so drop by label. The previous
# `styles.index[missing_img]` treated them as positions, which is only correct
# while the index is still an untouched 0..n-1 RangeIndex.
img_dir = os.path.join(DATASET_PATH, 'images')
missing_img = []
for idx, img_id in styles['id'].items():
    img_path = os.path.join(img_dir, f'{img_id}.jpg')
    if not os.path.exists(img_path):
        print(img_path)
        missing_img.append(idx)

styles = styles.drop(index=missing_img)

/mnt/fashion-dataset-1/fashion-dataset/images/39403.jpg
/mnt/fashion-dataset-1/fashion-dataset/images/39410.jpg
/mnt/fashion-dataset-1/fashion-dataset/images/39401.jpg
/mnt/fashion-dataset-1/fashion-dataset/images/39425.jpg
/mnt/fashion-dataset-1/fashion-dataset/images/12347.jpg


In [20]:
# Rows left after removing entries without an image.
len(styles)

44419


## Getting the top articleTypes 

In [21]:
# Number of distinct article types (missing values are not counted).
styles['articleType'].nunique()

142


There are 142 distinct article types. The next cell lists the 20 most frequent ones.

In [22]:
# The N_TOP_CLASSES most frequent article types with their row counts,
# displayed in ascending order of count.
N_TOP_CLASSES = 20
top_classes = styles.groupby(['articleType']).size().nlargest(N_TOP_CLASSES).sort_values()
top_classes

articleType
Jeans                     608
Perfume and Body Mist     613
Formal Shoes              637
Socks                     686
Backpacks                 724
Belts                     813
Briefs                    849
Sandals                   897
Flip Flops                914
Wallets                   936
Sunglasses               1073
Heels                    1323
Handbags                 1759
Tops                     1762
Kurtas                   1844
Sports Shoes             2036
Watches                  2542
Casual Shoes             2845
Shirts                   3215
Tshirts                  7066
dtype: int64


# Transfer Learning / Fine tuning

### First, we create master train and test splits of the valid image data: everything from even years goes into the training set, and everything from odd years into the test split.
Before doing that, we remove any entries with a missing `articleType` or `year`.

In [23]:
# Rows without a year or articleType cannot be assigned to a split, so drop
# them. Reassign instead of mutating in place.
styles = styles.dropna(subset=['year', 'articleType'])
len(styles)

44418

In [24]:
# Even years -> training set, odd years -> test set.
# `year` holds whole numbers stored as floats, so cast before taking parity.
is_even_year = styles['year'].astype('int') % 2 == 0
training_data = styles[is_even_year]
test = styles[~is_even_year]

### Next, we create sub-splits of the training data for pre-training and fine-tuning
The split is:
* pre-training: the top 20 classes (see above) — about 3/4 of the data;
* fine-tuning: all other classes — about 1/4 of the data.

In [25]:
# Plain Python list of the most frequent article-type names.
top_classes_names = top_classes.index.tolist()

In [27]:
# Rows of the top classes go to pre-training; all remaining classes go to
# fine-tuning. Negate the boolean mask with `~`: unary `-` on a boolean array
# is an error in NumPy and is not the idiomatic pandas negation.
in_top_classes = training_data['articleType'].isin(top_classes_names)
pre_training = training_data[in_top_classes]
fine_tuning = training_data[~in_top_classes]