In [None]:
import shutil
import os
from tqdm import tqdm

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!cp -u /content/drive/MyDrive/merged_Aug9_17_20.zip ./
!unzip -n merged_Aug9_17_20.zip

In [None]:
!cp -u /content/drive/MyDrive/img_dict.csv ./
!cp -u /content/drive/MyDrive/cat_names_dict.csv ./

import csv

def read_two_column_csv(path):
    """Read a headerless two-column CSV into a dict ``{column0: column1}``.

    Both keys and values are kept as strings.  Rows with fewer than two
    fields (e.g. blank lines, which csv.reader yields as ``[]``) are skipped
    instead of raising IndexError.
    """
    mapping = dict()
    # newline='' is what the csv module expects; utf-8 keeps the Russian
    # category names intact regardless of the platform default encoding.
    with open(path, 'r', newline='', encoding='utf-8') as csvfile:
        for row in csv.reader(csvfile):
            if len(row) >= 2:
                mapping[row[0]] = row[1]
    return mapping


# key -- img filename, value -- img category
img_cat = read_two_column_csv('img_dict.csv')
# key -- category number, value -- category name in Russian
cat_names = read_two_column_csv('cat_names_dict.csv')


In [None]:
def copy_shoe_imgs(img_cat, cat_names, src_dir='merged', dst_dir='shoes'):
    """Create folder ``dst_dir`` with shoe images only.

    Copies every image of ``src_dir/M`` and ``src_dir/O`` whose category is one
    of the hard-coded shoe categories into ``dst_dir/M`` / ``dst_dir/O``, then
    prints how many images each destination folder holds.  Re-running is safe:
    existing destination folders are reused and files are overwritten.

    Parameters
    ----------
    img_cat : dict
        key -- img filename, value -- img category (anything ``int()`` accepts).
        Every file in the source folders must have an entry (KeyError otherwise).
    cat_names : dict
        key -- category number (str), value -- category name in Russian;
        only used to print the names of the chosen categories.
    src_dir, dst_dir : str
        Source and destination root folders; the defaults reproduce the
        original 'merged' -> 'shoes' behaviour.
    """
    # categories with shoe images (duplicate 5232 entry removed)
    chosen_cats = [5232, 1542, 5233, 4106, 1532, 579, 2351, 2367, 2133, 2132,
                   689, 3338, 692, 954, 3710, 3704, 4418, 2550, 4419, 2129,
                   2130, 2127, 5234]
    chosen_set = set(chosen_cats)  # O(1) membership tests in the copy loop

    subdirs = ('M', 'O')
    for sub in subdirs:
        # exist_ok: os.mkdir used to fail when the cell was run a second time
        os.makedirs(os.path.join(dst_dir, sub), exist_ok=True)

    for cat in chosen_cats:
        if str(cat) in cat_names:
            print(cat, cat_names[str(cat)])

    for sub in subdirs:
        src_sub = os.path.join(src_dir, sub)
        dst_sub = os.path.join(dst_dir, sub)
        for img in os.listdir(src_sub):
            if int(img_cat[img]) in chosen_set:
                shutil.copyfile(os.path.join(src_sub, img),
                                os.path.join(dst_sub, img))

    print('shoes Main images', len(os.listdir(os.path.join(dst_dir, 'M'))),
          '\nshoes Other images', len(os.listdir(os.path.join(dst_dir, 'O'))))

In [None]:
!pip install torch
!pip install torchvision
!pip install scikit-learn
!pip install pytorch-lightning
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import torch

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.metrics.functional import accuracy
from torch import nn

from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
import torch
from torch.nn import functional as F
from pytorch_lightning.metrics.functional import accuracy
import os

import ntpath

In [None]:
class MyDataModule(LightningDataModule):  # TODO: rename
    """Lightning data module with a per-category, per-class split.

    ``setupdir`` is read with ``torchvision.datasets.ImageFolder`` (one
    subfolder per class; class index 0 is treated as "main", any other index
    as "other").  Each image's basename is looked up in ``img_cat`` to get its
    category.  For every category, the main images and the other images are
    each shuffled and split into ``tr`` train, ``va`` val and ``te`` test
    images; every such group must contain exactly ``tr + va + te`` images.

    Args:
        setupdir: root directory of the ImageFolder dataset.
        tr, va, te: images per (category, class) for train / val / test.
        img_cat: dict mapping image basename -> category.
        seed: passed to ``torch.manual_seed`` at the start of ``setup``.
        batch_size: batch size of all three dataloaders.
    """

    def __init__(self, setupdir, tr, va, te, img_cat, seed=0, batch_size=64):
        super().__init__()

        self.batch_size = batch_size

        self.setupdir = setupdir

        # Resize, 224x224 centre crop and ImageNet mean/std normalisation.
        self.transform = transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

        self.te = te
        self.tr = tr
        self.va = va

        self.img_cat = img_cat

        # Basenames of all images in ImageFolder order (index i here is
        # index i of the underlying dataset); filled by setup().
        self.img_filenames = []

        self.seed = seed

    def _split_group(self, indices, kind):
        """Shuffle one (category, class) group of dataset indices and split it.

        Returns ``(train, val, test)`` index lists of lengths ``tr``, ``va``
        and ``te``.  Raises ValueError when the group does not hold exactly
        ``tr + va + te`` images; ``kind`` ('main' / 'other') only appears in
        that message.
        """
        expected = self.tr + self.va + self.te
        if len(indices) != expected:
            raise ValueError(
                'amount of {} images in any category must be equal to '
                'tr+va+te ({}), got {}'.format(kind, expected, len(indices)))

        # One randperm per group from the global RNG seeded in setup(), so the
        # split is reproducible for a fixed seed and directory listing.
        perm = torch.randperm(len(indices)).tolist()
        shuffled = [indices[j] for j in perm]

        train_end = self.tr
        val_end = self.tr + self.va
        # Slice by position: the former shuffled[-te:] returned the whole
        # group when te == 0.
        return shuffled[:train_end], shuffled[train_end:val_end], shuffled[val_end:]

    def setup(self, stage=None):
        """Build the ``self.train`` / ``self.val`` / ``self.test`` Subsets.

        ``stage`` is accepted because Lightning calls ``setup(stage)``; the
        same split is produced for every stage.

        Raises:
            ValueError: if a (category, class) group does not contain exactly
                ``tr + va + te`` images.
        """
        torch.manual_seed(self.seed)

        dataset = datasets.ImageFolder(self.setupdir, transform=self.transform)
        self.num_classes = len(dataset.classes)

        # dataset.samples holds (path, class_index) pairs and dataset.targets
        # the class indices, so no image has to be decoded to build the split.
        self.img_filenames = [ntpath.basename(path) for path, _ in dataset.samples]

        indices_dict = dict()
        for i, (img_name, target) in enumerate(zip(self.img_filenames, dataset.targets)):
            cat = self.img_cat[img_name]
            if cat not in indices_dict:
                indices_dict[cat] = {'main': [], 'other': []}
            indices_dict[cat]['main' if target == 0 else 'other'].append(i)

        train_indcs, val_indcs, test_indcs = [], [], []
        for groups in indices_dict.values():
            for kind in ('main', 'other'):
                tr_part, va_part, te_part = self._split_group(groups[kind], kind)
                train_indcs += tr_part
                val_indcs += va_part
                test_indcs += te_part

        # All three subsets share `dataset`, which already carries the transform.
        self.train = torch.utils.data.Subset(dataset, train_indcs)
        self.val = torch.utils.data.Subset(dataset, val_indcs)
        self.test = torch.utils.data.Subset(dataset, test_indcs)

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size)

In [None]:
class MyModel(LightningModule):
    """Pretrained ResNet-34 followed by a linear layer, trained with NLL loss.

    ``forward`` returns log-probabilities (log_softmax) over ``num_classes``.
    During testing, the argmax prediction of every test sample is appended to
    ``self.predictions`` in dataloader order.

    Args:
        input_shape: shape of one input image, e.g. ``(3, 224, 224)``; used
            to probe the backbone's output size.
        num_classes: number of output classes.
        learning_rate: Adam learning rate.
        batch_size: stored and saved as a hyperparameter; not used by the
            model itself.
    """

    def __init__(self, input_shape, num_classes, learning_rate=1e-4, batch_size=64):
        super().__init__()

        self.batch_size = batch_size

        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes

        # NOTE(review): the whole torchvision ResNet-34 (including its final
        # `fc` layer) is the backbone, so the classifier below presumably sees
        # ImageNet class scores rather than pooled features. Swapping `fc` for
        # nn.Identity() would change the architecture and old checkpoints.
        self.feature_extractor = models.resnet34(pretrained=True)
        self.feature_extractor.eval()  # not necessary for training; affects the size probe

        #for param in self.feature_extractor.parameters():
            #param.requires_grad = False

        n_sizes = self._get_conv_output(input_shape)

        # Filled by test_step in the order of the test dataloader: the i-th
        # entry is the predicted class (0-d tensor) of the i-th test sample.
        self.predictions = []

        self.classifier = torch.nn.Linear(n_sizes, num_classes)

    def _get_conv_output(self, shape):
        """Return the flattened backbone output size for one input of ``shape``."""
        batch_size = 1
        # Random dummy input: only the output size matters.
        dummy = torch.rand(batch_size, *shape)
        with torch.no_grad():  # size probe only; no autograd graph needed
            output_feat = self._forward_features(dummy)
        return output_feat.reshape(batch_size, -1).size(1)

    def _forward_features(self, x):
        """Run the backbone."""
        return self.feature_extractor(x)

    def forward(self, x):
        """Return class log-probabilities of shape ``(batch, num_classes)``."""
        # NOTE: could return raw logits and make the softmax part of the loss
        # (F.cross_entropy) instead of log_softmax + nll_loss.
        x = self._forward_features(x)
        x = torch.flatten(x, 1)
        return F.log_softmax(self.classifier(x), dim=1)

    def _shared_step(self, batch):
        """Return ``(loss, predictions, targets)`` for one batch."""
        x, y = batch
        log_probs = self(x)  # log-probabilities, not raw logits
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        return loss, preds, y

    def training_step(self, batch, batch_idx):
        loss, preds, y = self._shared_step(batch)
        # Plain accuracy says little when the classes are imbalanced.
        acc = accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, preds, y = self._shared_step(batch)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, preds, y = self._shared_step(batch)
        acc = accuracy(preds, y)

        # One device->CPU copy per batch; yields one 0-d tensor per sample.
        self.predictions.extend(preds.detach().cpu())

        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# Optional one-off step that builds the 'shoes' subset of 'merged':
#copy_shoe_imgs(img_cat, cat_names)

batch_size = 64
# Per category and per class (M / O): 15 train, 1 val and 4 test images.
dm = MyDataModule(
    setupdir='merged',
    tr=15,
    va=1,
    te=4,
    img_cat=img_cat,
    seed=0,
    batch_size=batch_size,
)
dm.setup()

In [None]:
# Number of classes found by ImageFolder in dm.setupdir (set by dm.setup()).
num_classes = dm.num_classes
print(num_classes)

In [None]:
# input_shape is (channels, height, width); it matches the 224x224 CenterCrop
# applied by MyDataModule.
model = MyModel(input_shape=(3, 224, 224),
                num_classes=num_classes,
                learning_rate=2e-4,
                batch_size=batch_size)

In [None]:
# Keep only the checkpoint with the lowest validation loss, saved on Google Drive.
checkpoint = ModelCheckpoint(
    dirpath='drive/MyDrive/checkpoints',
    monitor='val_loss',
    save_top_k=1,
)

trainer_kwargs = dict(
    max_epochs=4,
    progress_bar_refresh_rate=1,
    gpus=[0],
    callbacks=[checkpoint],
)
trainer = Trainer(**trainer_kwargs)  # 16-bit floating point could be used here

# Notes / TODO (translated from Russian):
# - metric for the checkpoint: val loss
# - early stopping: stop training when the train loss decreases but the
#   validation loss does not
# - save checkpoints directly to Google Drive
# - if training time did not drop, increase the batch size two-three times and
#   the learning rate by the same factor
# - the checkpoint has a best_model_path attribute; check that model's quality
#   with trainer.test
# - look at which images are misclassified

trainer.fit(model, dm)

trainer.test()

In [None]:
# checkpoint_callback = ModelCheckpoint(dirpath='drive/MyDrive/checkpoints')
# trainer = Trainer(callbacks=[checkpoint_callback])
# trainer.fit(model, dm)
# checkpoint_callback.best_model_path

In [None]:
class ImitateModel:
    """Stand-in for a trained model that only carries test predictions.

    Reads a CSV whose rows are ``[img_name, predicted_label]`` (presumably the
    raw_predictions.csv written by the export cell) and stores each label as
    a 0-d integer tensor in ``self.predictions``, the same form that
    MyModel.test_step collects, so the analysis cells can run without
    retraining.

    Args:
        path: CSV file to read; defaults to the original Google Drive location.
    """

    def __init__(self, path='/content/drive/MyDrive/raw_predictions.csv'):
        self.predictions = []
        # newline='' is the mode the csv module expects for its input files.
        with open(path, 'r', newline='') as csvinput:
            for row in csv.reader(csvinput):
                self.predictions.append(torch.tensor(int(row[1])))
# NOTE: rebinds `model`. The trained MyModel is replaced by a stand-in holding
# only predictions loaded from CSV, so the analysis below can run without
# retraining.
model = ImitateModel()

In [None]:
print(len(model.predictions))
# NOTE(review): this indexes the full ImageFolder (dm.test.dataset), not the
# test subset. `.targets` gives the same label as dm.test.dataset[-20][1]
# without loading and transforming the image.
dm.test.dataset.targets[-20]

In [None]:
import ntpath

# Per-category test statistics.
# key -- cat,
# value -- list of 9 numbers:
# [right predicted main images, total main images, m img accuracy,
# right predicted other images, total other images, o img accuracy,
# right predictions in total, prediction in total, total accuracy]
# An accuracy stays -1. when its total is 0.
prdctns_by_cat = dict()

name_ds = dm.img_filenames
# Class index of every image in the full ImageFolder. `.targets[idx]` equals
# dm.test.dataset[idx][1] but does not decode and transform the image; the
# previous loop did that twice per test image.
test_targets = dm.test.dataset.targets

for i in tqdm(range(len(model.predictions))):
    complete_ds_indx = dm.test.indices[i]
    img_name = ntpath.basename(name_ds[complete_ds_indx])
    cat = img_cat[img_name]

    real_label = test_targets[complete_ds_indx]
    is_main = real_label == 0  # class index 0 is the main ('M') class
    pred_correct = model.predictions[i].item() == real_label

    if cat not in prdctns_by_cat:
        prdctns_by_cat[cat] = [0, 0, -1.,
                               0, 0, -1.,
                               0, 0, -1.]
    stats = prdctns_by_cat[cat]

    # main-image counters are at positions 0/1, other-image counters at 3/4
    offset = 0 if is_main else 3
    stats[offset + 1] += 1
    if pred_correct:
        stats[offset] += 1

for stats in prdctns_by_cat.values():
    stats[6] = stats[0] + stats[3]
    stats[7] = stats[1] + stats[4]
    for correct_pos, total_pos, acc_pos in ((0, 1, 2), (3, 4, 5), (6, 7, 8)):
        if stats[total_pos] != 0:
            stats[acc_pos] = stats[correct_pos] / stats[total_pos]

In [None]:
import csv

# Export the per-category statistics and the raw per-image predictions.
# newline='' is what csv.writer expects (avoids blank lines on Windows);
# utf-8 keeps the Russian category names intact.
with open('drive/MyDrive/predictions.csv', 'w', newline='', encoding='utf-8') as output:
    writer = csv.writer(output)

    writer.writerow(['cat name', 'base type',
                     '# right predictions of main images', '# main images',
                     'main images accuracy',
                     '# right predictions of other images', '# other images',
                     'other images accuracy',
                     '# right predictions in total', '# images in total',
                     'accuracy in total'])

    for cat, pr_list in tqdm(prdctns_by_cat.items()):
        # A category missing from cat_names gets an empty name instead of
        # aborting the export with a KeyError.
        writer.writerow([cat_names.get(cat, ''), cat] + pr_list)

with open('drive/MyDrive/raw_predictions.csv', 'w', newline='', encoding='utf-8') as raw_predictions_csv:
    writer = csv.writer(raw_predictions_csv)

    # One row per test image: [img_name, predicted label], in test order.
    for i in tqdm(range(len(model.predictions))):
        complete_ds_indx = dm.test.indices[i]
        img_name = ntpath.basename(dm.img_filenames[complete_ds_indx])
        writer.writerow([img_name, model.predictions[i].item()])

In [None]:
# (category, total accuracy) pairs ordered from lowest to highest accuracy;
# sorted() is stable, so ties keep the insertion order of prdctns_by_cat.
cats_worst_to_best = sorted(
    ((cat, stats[8]) for cat, stats in prdctns_by_cat.items()),
    key=lambda pair: pair[1],
)

In [None]:
# Test image names grouped by category and by (true label, predicted label);
# label 0 -- main, 1 -- other.
# `.targets[idx]` equals dm.test.dataset[idx][1] without loading the image.
test_targets = dm.test.dataset.targets

prdctns_sorted = dict()
for i, prediction in enumerate(model.predictions):
    complete_ds_indx = dm.test.indices[i]
    img_name = ntpath.basename(dm.img_filenames[complete_ds_indx])
    real_label = test_targets[complete_ds_indx]
    pred = prediction.item()

    cat = img_cat[img_name]
    if cat not in prdctns_sorted:
        prdctns_sorted[cat] = {(0, 0): [],
                               (0, 1): [],
                               (1, 0): [],
                               (1, 1): []}
    prdctns_sorted[cat][(real_label, pred)].append(img_name)

In [None]:
# Up to `num_examples` training image names per category and label
# (0 -- main, 1 -- other), shown as references next to the test results.
num_examples = 4
examples = dict()

# `.targets[idx]` equals dm.train.dataset[idx][1] without decoding every
# training image.
train_targets = dm.train.dataset.targets

for indx in dm.train.indices:
    img_name = ntpath.basename(dm.img_filenames[indx])
    img_label = train_targets[indx]

    cat = img_cat[img_name]
    if cat not in examples:
        examples[cat] = {0: [], 1: []}

    if len(examples[cat][img_label]) < num_examples:
        examples[cat][img_label].append(img_name)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image
import math
import numpy as np

In [None]:
def grid_alignment(imgs, grid_height=1, grid_width=None, size=20.):
    """Show ``imgs`` in a ``grid_height`` x ``grid_width`` matplotlib grid.

    Images are placed row by row; all axes, including unused trailing cells,
    are hidden.  Nothing is drawn when ``imgs`` is empty.

    Args:
        imgs: sequence of images accepted by ``Axes.imshow`` (e.g. PIL images).
        grid_height: number of rows.
        grid_width: number of columns; when falsy it becomes
            ``ceil(len(imgs) / grid_height)``.
        size: currently unused, kept for backward compatibility; the figure
            is always 15x15 inches.

    Raises:
        ValueError: if the grid has fewer cells than there are images.
    """
    if not imgs:
        return

    if not grid_width:
        grid_width = math.ceil(len(imgs) / grid_height)

    if grid_width * grid_height < len(imgs):
        raise ValueError('grid {}x{} is too small for {} images'.format(
            grid_height, grid_width, len(imgs)))

    # squeeze=False always returns a 2-D axes array, so 1xN, Nx1 and 1x1 grids
    # are handled alike (an Nx1 grid used to crash on 2-index access).
    fig, axarr = plt.subplots(grid_height, grid_width, figsize=(15, 15),
                              squeeze=False)
    flat_axes = axarr.ravel()  # row-major: image i -> (i // w, i % w)

    for ax, img in zip(flat_axes, imgs):
        ax.imshow(img)
    for ax in flat_axes:
        ax.axis('off')

    plt.show()
    plt.close(fig)  # these figures are created in loops; free them right away

In [None]:
# Number of categories present in the test set (displayed as the cell output).
len(cats_worst_to_best)

In [None]:
def costyl(img_name):
    """Return the path of ``img_name`` inside ``merged/M`` or ``merged/O``.

    (The name transliterates the Russian "kostyl", i.e. a quick workaround.)
    The folder is inferred from the last character of the filename stem:
    names like ``...M.jpg`` live in ``merged/M``, everything else in
    ``merged/O``.  Using the stem instead of ``img_name[-5]`` also works for
    extensions of any length and for very short names.
    """
    stem = os.path.splitext(img_name)[0]
    if stem.endswith('M'):
        return 'merged/M/' + img_name
    return 'merged/O/' + img_name

# Visual report for the 20 categories with the lowest test accuracy: test
# images grouped by (true, predicted) label, then training examples.
# The loop variable must not be called `accuracy`: that name is the
# pytorch_lightning metric function MyModel's steps call.
to_print = {(0,1): "Main images predicted as Other",
            (1,0): "Other images predicted as Main",
            (0,0): "Right predicted Main images",
            (1,1): "Right predicted Other images"}

for cat, cat_accuracy in cats_worst_to_best[:20]:
    print(cat, cat_names.get(cat, ''), cat_accuracy)
    for k, title in to_print.items():
        img_names = prdctns_sorted[cat][k]
        if img_names:
            print(title)

            imgs = [Image.open(costyl(img_name)) for img_name in img_names]
            grid_alignment(imgs)

            print('. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . '+
                '. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ')

    print('-------------------------------------------------------------------'+
          '-------------------------------------------------------------------')
    print('examples')
    for k, img_names in examples.get(cat, {}).items():
        print('Main' if k == 0 else 'Other')

        imgs = [Image.open(costyl(img_name)) for img_name in img_names]
        grid_alignment(imgs)

        print('. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . '+
            '. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ')

    print()
    print('==================================================================='+
          '===================================================================')
    print('==================================================================='+
          '===================================================================')
    print()

In [None]:
# Visual report for the 20 categories with the highest test accuracy: test
# images grouped by (true, predicted) label, then training examples.
# The loop variable must not be called `accuracy`: that name is the
# pytorch_lightning metric function MyModel's steps call.
to_print = {(0,1): "Main images predicted as Other",
            (1,0): "Other images predicted as Main",
            (0,0): "Right predicted Main images",
            (1,1): "Right predicted Other images"}

for cat, cat_accuracy in cats_worst_to_best[-20:]:
    print(cat, cat_names.get(cat, ''), cat_accuracy)
    for k, title in to_print.items():
        img_names = prdctns_sorted[cat][k]
        if img_names:
            print(title)

            imgs = [Image.open(costyl(img_name)) for img_name in img_names]
            grid_alignment(imgs)

            print('. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . '+
                '. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ')

    print('-------------------------------------------------------------------'+
          '-------------------------------------------------------------------')
    print('examples')
    for k, img_names in examples.get(cat, {}).items():
        print('Main' if k == 0 else 'Other')

        imgs = [Image.open(costyl(img_name)) for img_name in img_names]
        grid_alignment(imgs)

        print('. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . '+
            '. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ')

    print()
    print('==================================================================='+
          '===================================================================')
    print('==================================================================='+
          '===================================================================')
    print()

In [None]:
from random import randint
import random

# Visual report for 20 randomly chosen categories from the middle of the
# ranking, i.e. excluding the 20 worst (cats_worst_to_best[:20]) and the
# 20 best (cats_worst_to_best[-20:]) categories.
# - The old randint(21, len - 21) bounds skipped indices 20 and len - 21.
# - The old loop never terminated when fewer than 20 middle categories existed.
# - A seeded generator makes the selection reproducible.
# - The loop variable is not called `accuracy`, so the pytorch_lightning
#   metric function is not shadowed.
n_shown = 20
middle_inds = range(n_shown, len(cats_worst_to_best) - n_shown)
random_inds = random.Random(0).sample(middle_inds,
                                      min(n_shown, len(middle_inds)))

random_cats = [cats_worst_to_best[ind] for ind in random_inds]

to_print = {(0,1): "Main images predicted as Other",
            (1,0): "Other images predicted as Main",
            (0,0): "Right predicted Main images",
            (1,1): "Right predicted Other images"}

for cat, cat_accuracy in random_cats:
    print(cat, cat_names.get(cat, ''), cat_accuracy)
    for k, title in to_print.items():
        img_names = prdctns_sorted[cat][k]
        if img_names:
            print(title)

            imgs = [Image.open(costyl(img_name)) for img_name in img_names]
            grid_alignment(imgs)

            print('. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . '+
                '. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ')

    print('-------------------------------------------------------------------'+
          '-------------------------------------------------------------------')
    print('examples')
    for k, img_names in examples.get(cat, {}).items():
        print('Main' if k == 0 else 'Other')

        imgs = [Image.open(costyl(img_name)) for img_name in img_names]
        grid_alignment(imgs)

        print('. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . '+
            '. . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . ')

    print()
    print('==================================================================='+
          '===================================================================')
    print('==================================================================='+
          '===================================================================')
    print()

In [None]:
!