In [1]:
from __future__ import print_function
import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms
from PIL import Image
import os
import os.path
import numpy as np
import sys

from torchvision.datasets import VisionDataset, CelebA
from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_file_from_google_drive, verify_str_arg

from functools import partial
import torch
import os
import PIL
import pandas
import csv
from sklearn.preprocessing import LabelEncoder
import time
import copy

if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [2]:
data_transforms = {
    'train': transforms.Compose([
        # transforms.RandomResizedCrop((80,60)),
        transforms.Resize((80,60)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize((80,60)),
        # transforms.CenterCrop(224),
        transforms.ToTensor(),
        # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

In [3]:
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

In [4]:
class FashionProductImages(VisionDataset):
    """TODO
    """
    base_folder = 'fashion-product-images-small' # 'fashion-dataset'
    # base_folder_small_dataset = 'fashion-product-images-small'
    url = None
    filename = "fashion-product-images-dataset.zip" # 'fashion-product-images-small.zip'
    tgz_md5 = None

    file_list = [
        # File ID                         MD5 Hash                            Filename
        ("xxx", "xxx", "fashion-product-images-dataset.zip"),
        ("xxx", "xxx", "fashion-product-images-small.zip"),
        ("xxx", "xxx", "fashion-dataset"),
        ("xxx", "xxx", "fashion-product-images-small"),
        ("xxx", "xxx", "styles.csv"),
        ("xxx", "xxx", "images.csv"),
    ]
    
    top20_classes = [
        "Jeans", "Perfume and Body Mist", "Formal Shoes",
        "Socks", "Backpacks", "Belts", "Briefs",
        "Sandals", "Flip Flops", "Wallets", "Sunglasses",
        "Heels", "Handbags", "Tops", "Kurtas",
        "Sports Shoes", "Watches", "Casual Shoes", "Shirts",
        "Tshirts"]  # actually 21?

    def __init__(self, root, split="train", target_type="articleType", transform=None,
                 target_transform=None, download=False, small_dataset=False):
        super(FashionProductImages, self).__init__(
            root, transform=transform, target_transform=target_transform)

        self.split = split
        self.target_type = target_type

        # TODO.extension: should a list of target types be allowed?
        self.target_type = target_type
        
        # TODO.extension: allow for usage of small dataset
        #if small_dataset:
        #    base_folder = self.base_folder_small_dataset
        #else:
        #    base_folder = self.base_folder

        if download:
            self.download()

        # TODO.not_implemented
        if not self._check_integrity():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')
        
        fn = partial(os.path.join, self.root, self.base_folder)

        # TODO.refactor: refer to class attributes instead of "styles.csv"
        # TODO.edge_case: see whether the styles.csv are identical/have the same corrupt rows for small and full dataset
        with open(fn("styles.csv")) as file:
            csv_reader = csv.reader(file)
            column_names = next(csv_reader)
            
        column_names.append(column_names[-1] + '2')
        
        # TODO.refactor: clean up column names, potentially merge last to columns
        # TODO.refactor: column year is not parsed as integer - why?
        self.df_meta = pandas.read_csv(fn("styles.csv"), names=column_names, skiprows=1)
        
        # check if the images from 'styles.csv' are actually present in the 'images' folder
        images = os.listdir(fn("images"))
        self.samples = self.df_meta.loc[
            (self.df_meta[self.target_type].isin(self.top20_classes))
            & (self.df_meta["id"].apply(lambda x: str(x) +".jpg").isin(images))
        ]
                
        self.targets = self.samples[self.target_type]
        self.target_codec = LabelEncoder()
        self.target_codec.fit(self.targets)
        
        # TODO.decision: are additional codecs necessary? 
        # self.article_codec = LabelEncoder()
        # self.gender_codec = LabelEncoder()
        # self.master_category_codec = LabelEncoder()
        # self.season_codec = LabelEncoder()
        # self.article_codec.fit(self.metadata.loc[:, "articleType"])
        
        # TODO.decision: prepare indices here or when __getitem__ is called?
        self.target_indices = self.target_codec.transform(self.targets)
        
        self.n_classes = len(self.target_codec.classes_)
        
        # TODO.goal: test and training set split
        # TODO.goal: fine-tune vs transfer classes
        
    def __getitem__(self, index):
        # TODO.check: which images are not RGB?
        # TODO.check: which images are not 80x60?
        sample = str(self.samples["id"].iloc[index]) + ".jpg"
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "images", sample)).convert("RGB")
        target = self.target_indices[index]
        
        # TODO.extension: allow returning one-hot representation of target
        
        if self.transform is not None:
            X = self.transform(X)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return X, target
    
    def __len__(self):
        return len(self.samples)

    def download(self):
        # TODO.recycle: check this and compare to e.g. MNIST/CIFAR
        # TODO.tailored: how to download from Kaggle
        #if self._check_integrity():
        #    print('Files already downloaded and verified')
        #    return

        #for (file_id, md5, filename) in self.file_list:
        #    download_file_from_google_drive(file_id, os.path.join(self.root, self.base_folder), filename, md5)

        #with zipfile.ZipFile(os.path.join(self.root, self.base_folder, "img_align_celeba.zip"), "r") as f:
        #    f.extractall(os.path.join(self.root, self.base_folder))

        raise NotImplementedError

    def _check_integrity(self):

        # TODO.recycle: check this and compare to e.g. MNIST/CIFAR

        #for (_, md5, filename) in self.file_list:
        #    fpath = os.path.join(self.root, self.base_folder, filename)
        #    _, ext = os.path.splitext(filename)
        #    # Allow original archive to be deleted (zip and 7z)
        #    # Only need the extracted images
        #    if ext not in [".zip", ".7z"] and not check_integrity(fpath, md5):
        #        return False

        # Should check a hash of the images
        #return os.path.isdir(os.path.join(self.root, self.base_folder, "img_align_celeba"))

        return True
    
    def extra_repr(self):
        # lines = ["Target type: {target_type}", "Split: {split}"]
        # return '\n'.join(lines).format(**self.__dict__)
        pass


In [5]:
fashion = FashionProductImages("~/data", transform=data_transforms["train"])

In [6]:
X, y = fashion[10]
# X, y = fashion[10:15] # fails

In [7]:
len(fashion)

33149

In [8]:
train_size = int(len(fashion) * 0.75)
trainset, valset = random_split(fashion, [train_size, len(fashion) - train_size])

train_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(valset, batch_size=64, shuffle=True, num_workers=4)

dataloaders = {"train": train_loader, "val": val_loader}
dataset_sizes = {"train": len(trainset), "val": len(valset)}

# data_loader = DataLoader(fashion, batch_size=128, shuffle=True, num_workers=4)

In [9]:
for i in range(len(fashion)):
    X, y = fashion[i]
    # print(y.shape)
    # if not isinstance(y, int):
    #    import pdb; pdb.set_trace()
    if not (X.shape[1]==80 and X.shape[2]==60):
        import pdb; pdb.set_trace()

In [11]:
counter = 0
for (i, batch) in enumerate(train_loader):
    counter += 64
    print(counter)
    # import pdb; pdb.set_trace()

64
128
192
256
320
384
448
512
576
640
704
768
832
896
960
1024
1088
1152
1216
1280
1344
1408
1472
1536
1600
1664
1728
1792
1856
1920
1984
2048
2112
2176
2240
2304
2368
2432
2496
2560
2624
2688
2752
2816
2880
2944
3008
3072
3136
3200
3264
3328
3392
3456
3520
3584
3648
3712
3776
3840
3904
3968
4032
4096
4160
4224
4288
4352
4416
4480
4544
4608
4672
4736
4800
4864
4928
4992
5056
5120
5184
5248
5312
5376
5440
5504
5568
5632
5696
5760
5824
5888
5952
6016
6080
6144
6208
6272
6336
6400
6464
6528
6592
6656
6720
6784
6848
6912
6976
7040
7104
7168
7232
7296
7360
7424
7488
7552
7616
7680
7744
7808
7872
7936
8000
8064
8128
8192
8256
8320
8384
8448
8512
8576
8640
8704
8768
8832
8896
8960
9024
9088
9152
9216
9280
9344
9408
9472
9536
9600
9664
9728
9792
9856
9920
9984
10048
10112
10176
10240
10304
10368
10432
10496
10560
10624
10688
10752
10816
10880
10944
11008
11072
11136
11200
11264
11328
11392
11456
11520
11584
11648
11712
11776
11840
11904
11968
12032
12096
12160
12224
12288
12352
12416
12480
12

In [12]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [13]:
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format(class_names[preds[j]]))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

In [14]:
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs, fashion.n_classes)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.001, betas=(0.9, 0.999))

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [15]:
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=25)

Epoch 0/24
----------
train Loss: 0.4583 Acc: 0.8455
val Loss: 0.3618 Acc: 0.8743

Epoch 1/24
----------
train Loss: 0.2945 Acc: 0.8990
val Loss: 0.3084 Acc: 0.8938

Epoch 2/24
----------
train Loss: 0.2383 Acc: 0.9148
val Loss: 0.2789 Acc: 0.9048

Epoch 3/24
----------
train Loss: 0.2158 Acc: 0.9236
val Loss: 0.2584 Acc: 0.9139

Epoch 4/24
----------
train Loss: 0.1872 Acc: 0.9337
val Loss: 0.2625 Acc: 0.9118

Epoch 5/24
----------
train Loss: 0.1727 Acc: 0.9376
val Loss: 0.2493 Acc: 0.9216

Epoch 6/24
----------
train Loss: 0.1543 Acc: 0.9456
val Loss: 0.2745 Acc: 0.9111

Epoch 7/24
----------
train Loss: 0.0944 Acc: 0.9653
val Loss: 0.1934 Acc: 0.9382

Epoch 8/24
----------
train Loss: 0.0764 Acc: 0.9736
val Loss: 0.1966 Acc: 0.9414

Epoch 9/24
----------
train Loss: 0.0645 Acc: 0.9779
val Loss: 0.2009 Acc: 0.9368

Epoch 10/24
----------
train Loss: 0.0519 Acc: 0.9817
val Loss: 0.2134 Acc: 0.9382

Epoch 11/24
----------
train Loss: 0.0453 Acc: 0.9854
val Loss: 0.2272 Acc: 0.9365

Ep