In [None]:
import os
from pathlib import Path
import copy
import numpy as np
from numpy import *
import pandas as pd
pd.options.display.max_columns = None
import pickle
import time
from datetime import datetime
import warnings
warnings.filterwarnings("ignore")
import tqdm.notebook as tq
import csv

#Plots
import matplotlib.pyplot as plt
import seaborn as sns

#Random
import random
from random import choice
from random import shuffle

#Pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.models as models 
from torch.utils.data import DataLoader,Dataset
from torch import optim
import torch.autograd as autograd
from PIL import Image

#Dim reduction
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.metrics import average_precision_score

#Livelossplot
!pip install livelossplot --quiet
from livelossplot import PlotLosses

# get the label of the image of Versace
def path_to_label(path):
    """Build an image label from the directory components of its path.

    The path is split on '/' and the components at positions -6 to -2
    (the five entries that precede the last one) are joined with '-'.

    Args:
        path: '/'-separated path string of an image.

    Returns:
        The five components joined by '-'.

    Raises:
        IndexError: if the path has fewer than six '/'-separated components.
    """
    parts = path.split('/')
    # Explicit indexing (not a slice) so that a too-short path still raises IndexError.
    return '-'.join(parts[offset] for offset in range(-6, -1))
# Report the installed PyTorch version and whether CUDA is usable in this runtime.
print(torch.__version__)
print(torch.cuda.is_available())
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is the module-level target that the model and tensors below are moved to.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1.8.1+cu101
True


In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive/.
from google.colab import drive
# force_remount=True requests a fresh mount even when the drive is already mounted.
drive.mount("/content/drive/", force_remount=True)

Mounted at /content/drive/


In [None]:
# !unzip -u "/content/drive/MyDrive/PoleS8/few_shot_learning_brands.zip" -d "/content/drive/My Drive/PoleS8/Brands"

In [None]:
# Load the CSV files that describe the training, support and query image sets.
# The code further down reads the `path` and `label` columns of these frames.
train_df = pd.read_csv('/content/drive/MyDrive/PoleS8/train_100_categories.csv')
support_df = pd.read_csv('/content/drive/MyDrive/PoleS8/support_50_categories.csv')
query_df = pd.read_csv('/content/drive/MyDrive/PoleS8/query_50_categories.csv')


def set_path_replace(set_df):
    """Point the `path` column of `set_df` at the unzipped Brands folder, in place.

    Every occurrence of '/content/drive/MyDrive/' in a path is replaced by
    '/content/drive/MyDrive/PoleS8/Brands/'. Paths that already start with the
    new prefix are left untouched, so calling the function again on the same
    frame does not insert 'PoleS8/Brands/' a second time.

    Args:
        set_df: DataFrame with a string `path` column; it is modified in place.

    Returns:
        None.
    """
    old_prefix = '/content/drive/MyDrive/'
    new_prefix = '/content/drive/MyDrive/PoleS8/Brands/'

    def relocate(path):
        # The new prefix contains the old one, so an unguarded replace is not idempotent.
        if path.startswith(new_prefix):
            return path
        return path.replace(old_prefix, new_prefix)

    set_df.path = set_df.path.apply(relocate)

# Rewrite the `path` column of all three frames in place so the image paths
# point below /content/drive/MyDrive/PoleS8/Brands/.
set_path_replace(train_df)
set_path_replace(support_df)
set_path_replace(query_df)

In [None]:
def enforce_all_seeds(seed):
    """Seed the NumPy, `random` and PyTorch generators with the same value.

    Seeds NumPy's legacy global generator, the stdlib `random` module and
    PyTorch, plus the CUDA generator when CUDA is available.

    Args:
        seed: integer seed shared by all generators.

    Returns:
        A new `np.random.Generator` created from `seed`.
    """
    # Same order as before: NumPy global state, stdlib random, then PyTorch.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    return np.random.default_rng(seed)

# Seed every random number generator with 42; `rgen` holds the returned NumPy Generator.
rgen = enforce_all_seeds(42)

In [None]:
# ResNet50 network
class ResNet50(nn.Module):
    """Embedding network: frozen pretrained ResNet-50 backbone plus a trainable head.

    The torchvision ResNet-50 is loaded with `pretrained=True`, all of its
    parameters are frozen and its final classifier is replaced by an identity,
    so the backbone outputs 2048 features per image. Two trainable linear
    layers (2048 -> 1024 -> 512) follow; the output passes through a sigmoid,
    so every embedding component lies in (0, 1).

    The backbone is kept in evaluation mode at all times. Its weights are
    frozen, so its BatchNorm layers must use the pretrained running statistics;
    in training mode they would normalise with the statistics of the current
    (single-image) batch and overwrite the running statistics on every forward
    pass, including forward passes made for evaluation.
    """

    def __init__(self):
        super(ResNet50, self).__init__()

        # Load the pretrained ResNet-50 and freeze all of its weights.
        backbone = models.resnet50(pretrained=True)
        for param in backbone.parameters():
            param.requires_grad = False

        # Remove the last fully connected layer so the backbone returns its 2048 features.
        backbone.fc = nn.Identity()
        self.conv = backbone

        # Redefine the fully connected head; these are the only trainable layers.
        self.fc1 = nn.Linear(2048, 1024)
        self.fc2 = nn.Linear(1024, 512)

        # Modules are created in training mode; switch the frozen backbone to eval.
        self.conv.eval()

    def train(self, mode=True):
        """Set the training mode of the head while keeping the backbone in eval mode.

        Args:
            mode: True for training mode, False for evaluation mode.

        Returns:
            self, like `nn.Module.train`.
        """
        super(ResNet50, self).train(mode)
        self.conv.eval()
        return self

    def forward(self, x):
        """Return the embeddings of an image batch.

        Args:
            x: input batch accepted by the ResNet-50 backbone.

        Returns:
            Tensor with 512 sigmoid-activated features per input image.
        """
        x = self.conv(x)
        x = F.relu(self.fc1(x))
        # torch.sigmoid replaces the deprecated F.sigmoid.
        return torch.sigmoid(self.fc2(x))


In [None]:
# Prototypical network
class Prototypical(Dataset):
    """Episodic trainer and evaluator for a prototypical network built on `ResNet50`.

    The class subclasses `Dataset` but does not define `__getitem__` or
    `__len__`; it is driven through `train()` and `mean_average_precision()`.

    Args:
        df_paths_labels: training dataframe; its `path` and `label` columns are read.
        support_df: support dataframe (`path`, `label`) used by `mean_average_precision`.
        query_df: query dataframe (`path`, `label`) used by `mean_average_precision`.
        length: stored as `self.length`; not used elsewhere in this class.
        Nc: number of classes in one episode.
        Ns: number of support images per class in one episode.
        Nq: number of query images per class in one episode.
        step: checkpoint index; only used to build the file names that are
            loaded when `trainval` is True.
        learning_rate: learning rate of the Adam optimizer.
        chosen_labels: labels the episode classes are taken from; defaults to
            the unique labels of `df_paths_labels`.
        transform: callable applied to the PIL image in `getFeature`.
        trainval: when True, load the model and the centers from the log
            folder instead of creating a new `ResNet50`.
    """

    def __init__(self, df_paths_labels, support_df, query_df, length, Nc, Ns, Nq, step, learning_rate, chosen_labels=None, transform=None, trainval=False):
        self.df = df_paths_labels
        self.len_df = len(df_paths_labels)
        self.support_df = support_df
        self.query_df = query_df
        if chosen_labels is not None:
            self.chosen_labels = chosen_labels
        else:
            self.chosen_labels = df_paths_labels.label.unique()
        self.class_number = len(self.chosen_labels)
        self.length = length
        self.transform = transform
        self.Ns = Ns
        self.Nq = Nq
        self.Nc = Nc
        self.learning_rate = learning_rate

        # label -> class center (prototype), a tensor of shape (1, embedding_dim)
        self.center = {}
        # Created on the first call to train() and then reused.
        self.optimizer = None

        if not trainval:
            # initiate the Resnet50 network
            self.model = ResNet50().to(device)
        else:
            # load the Resnet50 network and the centers saved at `step`
            # NOTE: torch.load unpickles the file; only load checkpoints you trust.
            self.model = torch.load('/content/drive/MyDrive/PoleS8/log/model_net_'+str(step)+'.pkl', map_location=device)
            self.load_center('/content/drive/MyDrive/PoleS8/log/model_center_'+str(step)+'.csv')

    # get the support set and query set randomly
    def randomSample(self):
        """Draw the support and query image paths of one episode.

        For each episode class the image paths are shuffled; the first `Ns`
        become the support set and the next `Nq` the query set.

        Returns:
            (sup_set, que_set): two dicts mapping label -> list of image paths.
        """
        # NOTE(review): this always takes the first Nc labels, so every episode
        # uses the same classes; confirm whether per-episode class sampling was intended.
        episode_labels = self.chosen_labels[:self.Nc]
        sup_set = {}
        que_set = {}
        for label in episode_labels:
            candidates = self.df.loc[self.df.label.isin([label])].path.values
            # The episode key is rebuilt from a randomly picked path of this label.
            episode_key = path_to_label(random.choice(candidates))
            # A list copy is shuffled, so the dataframe's own data is never touched.
            label_paths = list(self.df.loc[(self.df.label == episode_key)].path.values)
            random.shuffle(label_paths)
            sup_set[episode_key] = label_paths[:self.Ns]
            que_set[episode_key] = label_paths[self.Ns:(self.Ns + self.Nq)]
        return sup_set, que_set

    # get the feature of image using Resnet50 network
    def getFeature(self, img_path):
        """Return the model output for the image stored at `img_path`.

        Args:
            img_path: path of the image file.

        Returns:
            The model output for a batch holding this single image.
        """
        # The context manager closes the file; convert() returns an independent copy.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        else:
            # Without a transform the PIL image still has to become a tensor.
            img = transforms.ToTensor()(img)
        return self.model(img.unsqueeze(0).to(device))

    # compute the center of support set for every category
    def computeCenter(self, sup_set):
        """Store the mean support feature of every class in `self.center`.

        Args:
            sup_set: dict mapping label -> iterable of support image paths.
        """
        for label, img_paths in sup_set.items():
            features = [self.getFeature(img_path) for img_path in img_paths]
            if not features:
                # A mean over zero images is undefined: keep no center for this class.
                self.center.pop(label, None)
                continue
            # Average over the images actually used (equals Ns for a full support set).
            self.center[label] = torch.cat(features, dim=0).mean(dim=0, keepdim=True)

    # compute the distance of two tensors
    def eucli_tensor(self, x, y):
        """Return the Euclidean distance between `x` and `y` as a 0-dim tensor."""
        return torch.sqrt(torch.sum((x - y) * (x - y)))

    # compute the loss
    def loss(self, que_set):
        """Return the prototypical-network loss of the query set.

        For a query feature f of class k the contribution is
            d(f, c_k) + log(sum over all centers c_j of exp(-d(f, c_j))),
        i.e. the negative log-softmax over negative distances, divided by
        `Nc * Nq`. The sum runs over every stored center including the true
        class; leaving the true class out would not give a negative
        log-likelihood.

        Args:
            que_set: dict mapping label -> iterable of query image paths.

        Returns:
            Tensor of shape (1,) holding the episode loss.

        Raises:
            KeyError: if a query label has no stored center.
        """
        loss_train = torch.zeros(1, device=device)
        center_labels = list(self.center.keys())
        position = {center_label: index for index, center_label in enumerate(center_labels)}
        for label, img_paths in que_set.items():
            true_index = position[label]
            for img_path in img_paths:
                feature = self.getFeature(img_path)
                distances = torch.stack([self.eucli_tensor(feature, self.center[center_label])
                                         for center_label in center_labels])
                # logsumexp is the numerically stable form of log(sum(exp(.))).
                nll = distances[true_index] + torch.logsumexp(-distances, dim=0)
                loss_train = loss_train + nll / (self.Nc * self.Nq)
        return loss_train

    # save centers
    def save_center(self, path):
        """Write the centers to a CSV file, one row per label.

        Each row is the label followed by the flattened center values.

        Args:
            path: destination file path.
        """
        rows = [[label] + center.detach().cpu().reshape(-1).tolist()
                for label, center in self.center.items()]
        with open(path, "w", newline="") as datacsv:
            csvwriter = csv.writer(datacsv, dialect=("excel"))
            csvwriter.writerows(rows)

    # load centers
    def load_center(self, path):
        """Read centers written by `save_center` into `self.center`.

        Labels are kept as strings; each center is restored as a float32
        tensor of shape (1, embedding_dim) on `device`.

        Args:
            path: CSV file path.
        """
        with open(path, newline="") as datacsv:
            for line in csv.reader(datacsv):
                if not line:
                    continue
                values = [float(value) for value in line[1:]]
                self.center[line[0]] = torch.tensor(values, dtype=torch.float32, device=device).reshape(1, -1)

    # train the Prototypical network
    def train(self):
        """Run one training episode: sample, compute centers, take one Adam step.

        Returns:
            The episode loss as a Python float.
        """
        sup_set, que_set = self.randomSample()
        self.computeCenter(sup_set)
        # One optimizer for the whole run: a new Adam per step would discard its moment estimates.
        if getattr(self, 'optimizer', None) is None:
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
        self.optimizer.zero_grad()
        loss_train = self.loss(que_set)
        # An empty query set yields a constant loss with nothing to back-propagate.
        if loss_train.requires_grad:
            loss_train.backward()
            self.optimizer.step()
        # Drop the autograd history of the stored centers; it is not needed after the step.
        self.center = {label: center.detach() for label, center in self.center.items()}
        return loss_train.item()

    # compute the mean average precision
    def mean_average_precision(self):
        """Return the mean average precision of retrieving support images per query image.

        Every support and query image is embedded with the current model. For
        each query image the support images are scored by negative Euclidean
        distance to that image's own embedding, and the average precision is
        computed against "support label == query label". The columns
        `embedded_images` (both frames) and `AP` (query frame) are written as
        a side effect.

        Returns:
            Mean of the per-query average precisions.
        """
        # Built once and reused for every image.
        preprocess = transforms.Compose([transforms.Resize((256, 256)),
                                         transforms.ToTensor(),
                                         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

        def forward_pass(path):
            with Image.open(path) as raw:
                img = preprocess(raw.convert("RGB"))
            # No gradients are needed for evaluation.
            with torch.no_grad():
                y = self.model(img.unsqueeze(0).to(device))
            return y.detach().cpu().numpy()

        self.support_df['embedded_images'] = self.support_df.path.apply(forward_pass)
        self.query_df['embedded_images'] = self.query_df.path.apply(forward_pass)

        support_labels = self.support_df.label.values
        support_matrix = np.concatenate([embedding.reshape(1, -1)
                                         for embedding in self.support_df.embedded_images.values], axis=0)

        def calculate_AP(label, img_embedded):
            # AP of ranking the support images for one query image.
            y_ground = (support_labels == label).astype(int)
            y_distances = -np.linalg.norm(support_matrix - img_embedded.reshape(1, -1), axis=1)
            return average_precision_score(y_ground, y_distances)

        # Each query row is scored with its own embedding, not the first one of its label.
        self.query_df['AP'] = [calculate_AP(label, embedding)
                               for label, embedding in zip(self.query_df.label.values,
                                                           self.query_df.embedded_images.values)]
        mAP = np.mean(self.query_df['AP'].values)

        return mAP
        

In [None]:
# Training-time preprocessing: resize to 256x256, random horizontal flip, random rotation
# (RandomRotation(10)), a 256x256 random crop, tensor conversion and per-channel normalisation.
# NOTE(review): the crop size equals the resized image size — confirm the crop is intended.
# NOTE(review): mean/std are presumably the ImageNet channel statistics — confirm.
transformer = transforms.Compose([transforms.Resize((256, 256)),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.RandomRotation(10),
                                  transforms.RandomCrop(256),
                                  transforms.ToTensor(),
                                  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# Build the prototypical network from scratch (trainval=False, so no checkpoint is loaded).
# NOTE(review): `step` is only read by the constructor when trainval=True.
protonets = Prototypical(df_paths_labels=train_df, support_df=support_df, query_df=query_df, length=5000, Nc=60, Ns=5, Nq=5, step=60, learning_rate=0.0001, chosen_labels=None, transform=transformer, trainval=False)

# Run 1000 training episodes.
for n in range(1000):
    protonets.train()
    # Save the model and the centers every 50 iterations, skipping iteration 0.
    if n%50 ==0 and n!=0:
         torch.save(protonets.model, '/content/drive/MyDrive/PoleS8/log/model_net_'+str(n)+'.pkl')
         protonets.save_center('/content/drive/MyDrive/PoleS8/log/model_center_'+str(n)+'.csv')
    # Evaluate the mean average precision after every iteration and print it.
    # NOTE(review): `map` shadows the Python builtin of the same name for the rest of the notebook.
    map = protonets.mean_average_precision()
    print(map)


0.24738492518125682
0.24947924938227475
0.24972491737592115
0.25328205995481523
0.24306394237599335
0.2434391340196522
0.23315662065614606
0.22688728579675907
