In [None]:
# !ls
# !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id=1gsajB2iQN6_5pAWSu3MubsdR1-Mus4PH' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1gsajB2iQN6_5pAWSu3MubsdR1-Mus4PH" -O fashion_full_upd.zip && rm -rf /tmp/cookies.txt
# !unzip -qq /content/fashion_full_upd.zip -d fashion_full
# !mkdir checkpoints

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# --- Standard library ---
import importlib
import os

# --- Third-party ---
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Subset
from torchvision.transforms import Compose, ToTensor, Resize, Normalize

# --- Project-local helper modules ---
import data_utils
import nn_utils

# Reload the helper modules BEFORE pulling names out of them.
# `from module import name` copies the current function object into this
# namespace; reloading the module afterwards does not update those copies,
# so edits to data_utils.py / nn_utils.py would be silently ignored until the
# kernel is restarted. Reloading first makes the from-imports below bind to
# the freshly executed module contents.
importlib.reload(data_utils)
importlib.reload(nn_utils)

from data_utils import load_data, train_test_split_fashion, get_dataloaders
from nn_utils import correct_top_k, get_class_weights, plot_losses, get_accuracy, train, compute_loss

In [None]:
# ---------------------------------------------------------------------------
# Configuration: paths, loader settings, preprocessing and hyper-parameters.
# ---------------------------------------------------------------------------
data_path = './fashion_full_upd/'
#data_path = '/content/fashion_full/fashion_full_upd/'

# Settings forwarded to get_dataloaders (last positional argument below).
dataloader_params = dict(batch_size=64, shuffle=True, num_workers=8)

# Preprocessing pipeline: resize to 224x224, convert to a tensor, then
# normalise each of the three channels with mean 0.5 and std 0.5.
resize_normalize_transform = Compose([
    Resize((224, 224)),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

val_split = 0.2
n_epochs = 20
learning_rate = 1e-4
seed = 0

# Seed both RNGs used in this notebook before any data is loaded or split.
np.random.seed(seed)
torch.manual_seed(seed)

# ---------------------------------------------------------------------------
# Data: load the styles table, split it, and build the dataloaders.
# ---------------------------------------------------------------------------
styles_df = load_data(data_path)
styles_df_train, styles_df_test = train_test_split_fashion(styles_df)

# Class names ordered from the most to the least frequent articleType.
article_type_counts = styles_df.groupby(['articleType']).size()
sorted_class_names = list(article_type_counts.sort_values(ascending=False).index)

# The returned mapping is indexed later in this notebook with the keys
# 'train_top20', 'val_top20', 'train_others', 'val_others' and 'test'.
dataloaders = get_dataloaders(
    styles_df_train,
    styles_df_test,
    sorted_class_names,
    data_path,
    val_split,
    resize_normalize_transform,
    dataloader_params,
)

In [None]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using {}'.format(device))

# Per-class weights for the loss, converted to a float tensor on the chosen
# device. The weights come from the training split; presumably they
# compensate for class imbalance -- confirm against nn_utils.get_class_weights.
class_weights = torch.tensor(
    get_class_weights(styles_df_train, sorted_class_names),
    dtype=torch.float,
).to(device)

# Class-weighted cross-entropy, shared by both training stages below.
criterion = nn.CrossEntropyLoss(weight=class_weights)

In [None]:
# ResNet-50 backbone requested with pretrained weights.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` -- confirm the installed version before switching.
model = torchvision.models.resnet50(pretrained=True)

# Replace the final fully connected layer so the network emits one logit per
# entry in class_weights.
model.fc = nn.Linear(
    in_features=model.fc.in_features,
    out_features=len(class_weights),
)
model = model.to(device)

# Adam over every parameter of the model (backbone and the new head).
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Optional warm start: flip the condition to True to restore model weights
# from a previously saved checkpoint before training.
if False:
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU machine can still be loaded on a CPU-only one.
    checkpoint = torch.load('./checkpoints/checkpoint.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Optimizer state is deliberately left untouched; uncomment to restore it.
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Stage 1: train on the 'train_top20' loader and validate on 'val_top20'.
# train() hands back the per-stage train losses, validation losses and the
# model; checkpoint_save_path is presumably where train() writes its
# checkpoint -- confirm in nn_utils.train.
train_losses, val_losses, model = train(
    dataloader_train=dataloaders['train_top20'],
    dataloader_val=dataloaders['val_top20'],
    n_epochs=n_epochs,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    load_checkpoint_path=None,
    checkpoint_save_path='./checkpoints/checkpoint.pt',
)

In [None]:
# Visualise the stage-1 training and validation losses returned by train()
# (presumably one value per epoch -- confirm in nn_utils.plot_losses).
plot_losses(train_losses, val_losses)

In [None]:
# Evaluate the stage-1 model on the 'test' loader for k = 1 and k = 5.
# NOTE(review): presumably correct_top_k puts the model in eval mode and
# returns per-class correct counts plus per-class sample counts -- confirm
# in nn_utils.correct_top_k.
class_correct_topk, class_counts = correct_top_k(
    dataloader=dataloaders['test'],
    model=model,
    k_list=[1, 5],
    n_classes=len(class_weights),
    device=device,
)

In [None]:
# Display the accuracy summary for the stage-1 model, derived from the counts
# collected by correct_top_k above (exact output format is defined in nn_utils).
get_accuracy(class_correct_topk, class_counts)

In [None]:
# Stage 2: continue training the same model on the 'train_others' loader,
# validating on 'val_others', and save to a separate checkpoint file.
# NOTE(review): the optimizer (with whatever state it accumulated in stage 1)
# and the class-weighted criterion are reused as-is -- confirm this is the
# intended fine-tuning setup.
others_train_losses, others_val_losses, model = train(
    dataloader_train=dataloaders['train_others'],
    dataloader_val=dataloaders['val_others'],
    n_epochs=n_epochs,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    load_checkpoint_path=None,
    checkpoint_save_path='./checkpoints/checkpoint_finetuned.pt',
)

In [None]:
# Restore the weights stored in the stage-2 checkpoint before the final
# evaluation (set the condition to False to evaluate the in-memory model
# exactly as train() returned it).
if True:
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint written on a GPU machine can still be loaded on a CPU-only one.
    checkpoint = torch.load('./checkpoints/checkpoint_finetuned.pt', map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    # Optimizer state is not needed for evaluation; uncomment to restore it.
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [None]:
# Evaluate the fine-tuned model on the same 'test' loader used after stage 1,
# again for k = 1 and k = 5, so the two sets of results are comparable.
class_correct_topk_others, class_counts_others = correct_top_k(
    dataloader=dataloaders['test'],
    model=model,
    k_list=[1, 5],
    n_classes=len(class_weights),
    device=device,
)

In [None]:
# Display the accuracy summary for the fine-tuned model, derived from the
# counts collected in the evaluation cell directly above.
get_accuracy(class_correct_topk_others, class_counts_others)

In [None]:
# Visualise the stage-2 (fine-tuning) training and validation losses
# returned by the second train() call.
plot_losses(others_train_losses, others_val_losses)