In [1]:
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score
from torch.optim import Adam, lr_scheduler
from torchvision import transforms as T
import torch.nn.functional as F
import torchvision.models as models
import matplotlib.pyplot as plt
import os
from os.path import join
from glob import glob
from PIL import Image
from torch import nn
import pandas as pd
import numpy as np
import torchvision
import torch
import joblib
from sklearn import svm
import random
import tarfile
import io
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
from sklearn.manifold import TSNE
import encoding

In [2]:
# Preferred device is the second GPU. Fall back to the first GPU, then to the
# CPU, so the notebook still runs on hosts with fewer than two CUDA devices.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [3]:
# Load the train and test ground-truth tables (CSV files with a header row).
train_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/data/wikiart/wikiart_Painting100k/MultitaskPainting100k_Dataset_groundtruth/groundtruth_multiloss_train_header.csv',
        '/data/wikiart/wikiart_Painting100k/MultitaskPainting100k_Dataset_groundtruth/groundtruth_multiloss_test_header.csv',
    )
)

In [4]:
# Add an `img_path` column holding the full path of each painting, built by
# prefixing the image directory to the `filename` column.
for frame in (train_df, test_df):
    frame['img_path'] = frame['filename'].map(
        lambda name: join('/data/wikiart/wikiart_Painting100k/images_256minside', name))

In [5]:
# Merge fine-grained genre labels into broader ones. `DataFrame.replace`
# substitutes every cell that equals the old label, in any column.
for old_label, new_label in [
        ('wildlife painting', 'animal painting'),
        ('self-portrait', 'portrait'),
        ('poster', 'design'),
        ('advertisement', 'illustration'),
        ('cloudscape', 'landscape'),
        ('literary painting', 'mythological painting'),
        ('battle painting', 'history painting'),
        ('bird-and-flower painting', 'animal painting')]:
    train_df = train_df.replace(old_label, new_label)

# Drop the rows whose genre is one of the excluded labels.
train_df = train_df[~train_df.genre.isin(['shan shui', 'panorama', 'miniature', 'pastorale', 'quadratura',
                                          'vanitas', 'bijinga', 'calligraphy', 'yakusha-e'])]
#train_df = train_df[train_df.genre.isin(['history painting','allegorical painting','interior','capriccio','veduta','caricature','tessellation'])==False]

In [6]:
# Apply the same genre merging to the test table. `DataFrame.replace`
# substitutes every cell that equals the old label, in any column.
for old_label, new_label in [
        ('wildlife painting', 'animal painting'),
        ('self-portrait', 'portrait'),
        ('poster', 'design'),
        ('advertisement', 'illustration'),
        ('cloudscape', 'landscape'),
        ('literary painting', 'mythological painting'),
        ('battle painting', 'history painting'),
        ('bird-and-flower painting', 'animal painting')]:
    test_df = test_df.replace(old_label, new_label)

# Drop the rows whose genre is one of the excluded labels.
test_df = test_df[~test_df.genre.isin(['shan shui', 'panorama', 'miniature', 'pastorale', 'quadratura',
                                        'vanitas', 'bijinga', 'calligraphy', 'yakusha-e'])]
#valid_df = valid_df[valid_df.genre.isin(['history painting','allegorical painting','interior','capriccio','veduta','caricature','tessellation'])==False]

In [7]:
#dict_genre = {v: k for k, v in class_dict_genre.items()}

In [8]:
#train_df_medium=pd.read_pickle('./train_medium.pkl')
#valid_df_medium=pd.read_pickle('./valid_medium.pkl')

from collections import Counter

# Keep only the ten most frequent genres.
# NOTE(review): `train_df_original` is not defined anywhere in the visible
# notebook -- presumably a snapshot of `train_df` from a deleted cell. Confirm
# where it should come from before running this cell on a fresh kernel.
# The counter used to be bound to the name `dict`, which shadowed the builtin
# for the rest of the session; it now has its own name.
genre_counts = Counter(train_df_original.genre)
sorted_dict = sorted(genre_counts.items(), key=lambda item: item[1])  # ascending by count
first_dict = sorted_dict[-10:]                                        # ten largest counts
list_genre = [name for name, _count in first_dict]

train_df = train_df_original.loc[train_df_original['genre'].isin(list_genre)]  # 55545 images
train_df = train_df.reset_index(drop=True)

## Dataset

In [9]:
# Build an integer class index per artist / style / genre. Indices follow the
# sorted unique labels of the TRAIN table; the same index is written to both
# tables in a `class_<attribute>` column. Test rows whose label never occurs in
# the train table are not assigned an index by this loop.
class_dict_artist, class_dict_style, class_dict_genre = {}, {}, {}

for attribute, index_to_label in (('artist', class_dict_artist),
                                  ('style', class_dict_style),
                                  ('genre', class_dict_genre)):
    class_column = f'class_{attribute}'
    for class_index, label in enumerate(np.sort(train_df[attribute].unique())):
        train_df.loc[train_df[attribute] == label, class_column] = class_index
        test_df.loc[test_df[attribute] == label, class_column] = class_index
        index_to_label[class_index] = label

In [10]:
#class_dict_medium = {}
#for i, medium in enumerate(np.sort(train_df_medium.medium.unique())):
#    train_df_medium.loc[train_df_medium.medium==medium, 'class_medium'] = i
#    test_df_medium.loc[test_df_medium.medium==medium, 'class_medium'] = i
#    class_dict_medium.update({i:medium})

In [11]:
# Final steps shared by both pipelines: tensor conversion followed by
# per-channel normalisation with the mean/std values below.
to_normalized_tensor = [
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
]

# Training: random square crop plus random horizontal flip.
train_transforms = T.Compose([
    T.Resize(256),
    # size 384. `scale` gives the lower and upper bounds of the random crop area.
    T.RandomResizedCrop(size=224, scale=(0.3, 1), ratio=(1, 1)),
    T.RandomHorizontalFlip(p=0.5),  # p: probability of flipping the image
    *to_normalized_tensor,
])

# Evaluation: deterministic resize + centre crop.
test_transforms = T.Compose([
    T.Resize(224),  # for EfficientNet: T.Resize((331,331)), T.CenterCrop(300)
    T.CenterCrop(224),
    *to_normalized_tensor,
])

In [12]:
class TripletCSNDataset(Dataset): #genre, artist, style
    """Triplet dataset for three similarity conditions (genre, artist, style).

    Each item is an anchor image together with, for every condition, one image
    that differs from the anchor in that attribute (negative) and one that
    shares it (positive), plus the integer id of the condition.

    Args:
        df: DataFrame with `img_path`, `genre`, `artist` and `style` columns.
        transform: callable applied to every loaded RGB PIL image.
        sample: if truthy, keep only `sample` randomly drawn rows per genre
            (`DataFrame.groupby('genre').sample(sample)`).
    """

    # Condition ids returned with each triplet.
    COND_GENRE, COND_ARTIST, COND_STYLE = 0, 1, 2

    def __init__(self, df, transform, sample=None):
        if sample:
            self.df = df.groupby(by='genre').sample(sample)
        else:
            self.df = df

        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _load(self, path):
        """Open `path`, convert to RGB, apply the transform and close the file."""
        # The context manager releases the file handle deterministically; the
        # previous code left every opened image to the garbage collector.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def _sample_pair(self, column, value):
        """Draw one positive row (same `column` value) and one negative row."""
        same = self.df[column] == value
        positive = self.df[same].sample(1).squeeze()
        negative = self.df[~same].sample(1).squeeze()
        return positive, negative

    def __getitem__(self, idx):
        """Return (anchor, neg_genre, pos_genre, 0, neg_artist, pos_artist, 1,
        neg_style, pos_style, 2); images are transformed.

        The positive is drawn from all rows sharing the attribute, so it may be
        the anchor itself.
        """
        anc = self.df.iloc[idx]

        # Rows are drawn in the same order as before (genre, artist, style;
        # positive before negative) so seeded runs stay reproducible.
        pos_genre, neg_genre = self._sample_pair('genre', anc['genre'])
        pos_artist, neg_artist = self._sample_pair('artist', anc['artist'])
        pos_style, neg_style = self._sample_pair('style', anc['style'])

        # triplet (anchor, far, close); transforms run in output order.
        anchor_img = self._load(anc['img_path'])
        far_genre = self._load(neg_genre['img_path'])
        close_genre = self._load(pos_genre['img_path'])
        far_artist = self._load(neg_artist['img_path'])
        close_artist = self._load(pos_artist['img_path'])
        far_style = self._load(neg_style['img_path'])
        close_style = self._load(pos_style['img_path'])

        return (anchor_img,
                far_genre, close_genre, self.COND_GENRE,
                far_artist, close_artist, self.COND_ARTIST,
                far_style, close_style, self.COND_STYLE)

class TripletCSNDataset(Dataset): #genre, artist, style, medium
    """Triplet dataset variant that adds a `medium` condition from a second table.

    NOTE(review): this class reuses the name of the three-condition dataset
    defined just above; whichever definition runs last wins. It also takes
    different constructor arguments and returns a 15-element tuple, so it does
    not work with `train_func` / `valid_func` as written.

    Args:
        df: DataFrame with `img_path`, `class_genre`, `class_artist` and
            `class_style` columns.
        df2: DataFrame with `img_path` and `class_medium` columns.
        transform: callable applied to every loaded RGB PIL image.

    Both frames are addressed with `.loc[idx]`, so they need a 0..n-1 index.
    """

    def __init__(self, df, df2, transform):
        self.df = df
        self.df2 = df2
        self.transform = transform

    def __len__(self):
        # The same idx addresses both frames, so only indices valid in both are
        # usable. The previous len(df) + len(df2) produced out-of-range indices.
        return min(len(self.df), len(self.df2))

    def _load(self, path):
        """Open `path`, convert to RGB, apply the transform and close the file."""
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def _sample_pair(self, frame, column, row):
        """Return transformed (positive, negative) images for `row` w.r.t. `column`."""
        same = frame[column] == row[column]
        positives = frame[same].drop(row.name)  # never pick the anchor itself...
        if positives.empty:
            # ...unless it is the only member of its class.
            positives = frame.loc[[row.name]]
        positive_path = positives.sample(1).iloc[0].img_path
        negative_path = frame[~same].sample(1).iloc[0].img_path
        return self._load(positive_path), self._load(negative_path)

    def __getitem__(self, idx):
        """Return anchor, (neg, pos, condition id) for genre/artist/style, the
        anchor row, then the medium anchor, its neg/pos images and its row.

        NOTE(review): `row` and `row_medium` are pandas Series; confirm the
        DataLoader collate function in use can batch them.
        """
        cond_genre = 0
        cond_artist = 1
        cond_style = 2
        # The medium condition would be id 3, but the original tuple never
        # returned it; kept that way to leave the output layout unchanged.

        row = self.df.loc[idx]
        row_medium = self.df2.loc[idx]
        anchor = self._load(row.img_path)
        anchor_medium = self._load(row_medium.img_path)

        # Previously the loaded images were passed to the transform under names
        # that did not exist (g_positive, a_negative, ...), raising NameError.
        g_pos, g_neg = self._sample_pair(self.df, 'class_genre', row)
        a_pos, a_neg = self._sample_pair(self.df, 'class_artist', row)
        s_pos, s_neg = self._sample_pair(self.df, 'class_style', row)
        # The medium label lives in df2, so it is read from `row_medium`, not `row`.
        m_pos, m_neg = self._sample_pair(self.df2, 'class_medium', row_medium)

        # `anchor_medium` stays the image tensor; it used to be overwritten by
        # a one-element label list just before being returned.
        #triplet (anchor, far, close)
        return anchor, g_neg, g_pos, cond_genre, a_neg, a_pos, cond_artist, s_neg, s_pos, cond_style, row, anchor_medium, m_neg, m_pos, row_medium

In [13]:
# Training uses every row; evaluation uses a random subset of 50 rows per genre.
train_dataset = TripletCSNDataset(df=train_df.reset_index(drop=True), transform=train_transforms)
test_dataset = TripletCSNDataset(df=test_df.reset_index(drop=True), transform=test_transforms, sample=50)

# Instantiation for the (df, df2, transform) dataset variant that adds a
# medium table.
# NOTE(review): `train_df_medium` / `test_df_medium` are not created anywhere in
# the visible notebook (their loading code is commented out), so this only runs
# if they are defined elsewhere. Running it also replaces the datasets above.
train_dataset = TripletCSNDataset(train_df.reset_index(drop=True), train_df_medium.reset_index(drop=True), train_transforms)    
test_dataset = TripletCSNDataset(test_df.reset_index(drop=True), test_df_medium.reset_index(drop=True), test_transforms)

In [14]:
batch_size = 2 #16, 256

# Training batches are shuffled; evaluation batches keep dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                          shuffle=True, num_workers=6, pin_memory=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                         shuffle=False, num_workers=6, pin_memory=False)

## Functions CSN

In [15]:
class CS_Tripletnet(nn.Module):
    """Embeds a triplet under one condition and returns the two distances."""

    def __init__(self, embeddingnet):
        super().__init__()
        self.embeddingnet = embeddingnet

    def forward(self, x, y, z, c):
        """Embed a triplet and measure anchor-negative / anchor-positive distance.

        Args:
            x: anchor image.
            y: distant (negative) image.
            z: close (positive) image.
            c: integer indicating the notion of similarity used for the comparison.

        Returns:
            (dist_a, dist_b, mask_norm, embed_norm, mask_embed_norm, embedded_x):
            the L2 anchor-negative distance, the L2 anchor-positive distance,
            the three norms reported by the embedding net averaged over the
            three images, and the anchor embedding.
        """
        # Same embedding net for all three images, in anchor/negative/positive order.
        per_image = [self.embeddingnet(image, c) for image in (x, y, z)]
        embeddings, mask_norms, embed_norms, masked_norms = zip(*per_image)
        embedded_x, embedded_y, embedded_z = embeddings

        mask_norm = (mask_norms[0] + mask_norms[1] + mask_norms[2]) / 3
        embed_norm = (embed_norms[0] + embed_norms[1] + embed_norms[2]) / 3
        mask_embed_norm = (masked_norms[0] + masked_norms[1] + masked_norms[2]) / 3

        dist_a = F.pairwise_distance(embedded_x, embedded_y, 2)  # anchor-negative distance (2-norm)
        dist_b = F.pairwise_distance(embedded_x, embedded_z, 2)  # anchor-positive distance
        return dist_a, dist_b, mask_norm, embed_norm, mask_embed_norm, embedded_x

In [16]:
class ConditionalSimNet(nn.Module):
    """Embedding network whose output is multiplied by a per-condition mask."""

    def __init__(self, embeddingnet, n_conditions, embedding_size, learnedmask=True, prein=False): #embeddingnet=resnet18
        """ embeddingnet: The network that projects the inputs into an embedding of embedding_size
            n_conditions: Integer defining number of different similarity notions
            embedding_size: Number of dimensions of the embedding output from the embeddingnet
            learnedmask: Boolean indicating whether masks are learned or fixed
            prein: Boolean indicating whether masks are initialized in equally sized disjoint 
                sections or random otherwise"""
        super().__init__()
        self.learnedmask = learnedmask
        self.embeddingnet = embeddingnet
        # One mask row per condition, stored as an embedding table.
        self.masks = torch.nn.Embedding(n_conditions, embedding_size)

        if learnedmask and not prein:
            # Learned masks with random initialisation.
            self.masks.weight.data.normal_(0.9, 0.7) # 0.1, 0.005
        else:
            # Disjoint-section initialisation: condition i gets ones on its own
            # slice. Elsewhere the value is 0.1 for learned masks (trainable)
            # and 0 for fixed masks (no gradients).
            background = 0.1 if learnedmask else 0.0
            mask_array = np.full([n_conditions, embedding_size], background)
            section = int(embedding_size / n_conditions)
            for cond in range(n_conditions):
                mask_array[cond, cond * section:(cond + 1) * section] = 1
            self.masks.weight = torch.nn.Parameter(torch.Tensor(mask_array), requires_grad=learnedmask)

    def forward(self, x, c):
        """Return (masked embedding, L1 norm of the mask, L2 norm of the raw
        embedding, L2 norm of the masked embedding) for inputs `x` under
        condition ids `c`. The mask actually applied is kept in `self.mask`."""
        embedded_x = self.embeddingnet(x)
        mask = self.masks(c)
        if self.learnedmask:
            # Learned masks are clamped to non-negative values before use.
            mask = torch.nn.functional.relu(mask)
        self.mask = mask
        masked_embedding = embedded_x * mask
        return masked_embedding, mask.norm(1), embedded_x.norm(2), masked_embedding.norm(2)

In [17]:
class Triplet_model(nn.Module):
    """Backbone followed by a two-layer projection head.

    Args:
        base_model: module producing feature vectors of size `in_features`.
        in_features: width of the backbone output (default 2048).
        hidden_features: width of the hidden layer (default 256).
        out_features: size of the final embedding (default 128).
        dropout: dropout probability applied after the PReLU (default 0.6).

    The defaults reproduce the previously hard-coded head, so existing calls
    with only `base_model` behave as before.
    """

    def __init__(self, base_model, in_features=2048, hidden_features=256, out_features=128, dropout=0.6):
        super(Triplet_model, self).__init__()
        self.base_model = base_model
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.PReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_features, out_features))

    def forward(self, x):
        """Return the projected embedding of `x`."""
        features = self.base_model(x)
        return self.fc(features)

In [18]:
class AverageMeter(object):
    """Keeps the latest value and a weighted running average.

    Attributes:
        val: most recent value passed to `update`.
        sum: sum of `val * n` over all updates.
        count: sum of `n` over all updates.
        avg: `sum / count`.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val`, weighted by `n` (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [19]:
def save_loss_fig(loss_train, loss_valid, accuracy_train, accuracy_valid, epoch):
    """Plot the loss and accuracy histories and save them as one JPEG.

    Args:
        loss_train, loss_valid: per-epoch loss values.
        accuracy_train, accuracy_valid: per-epoch accuracy values.
        epoch: current epoch number, shown in the figure title.

    The figure is written to results_csn/loss_plot_csn_SGD.jpg, overwriting any
    previous file; the directory is created if needed.
    """
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(15,4))

    ax_loss.plot(range(len(loss_train)), loss_train, label='train_loss')
    ax_loss.plot(range(len(loss_valid)), loss_valid, label='valid_loss')
    ax_loss.set(xlabel='epoch', ylabel='loss')
    ax_loss.legend()

    ax_acc.plot(range(len(accuracy_train)), accuracy_train, label='train_accuracy')
    ax_acc.plot(range(len(accuracy_valid)), accuracy_valid, label='valid_accuracy')
    ax_acc.set(xlabel='epoch', ylabel='accuracy')
    ax_acc.legend()

    fig.suptitle(f'EPOCH {epoch}', fontsize=16)

    # Make sure the target directory exists, and save before closing so the
    # figure is still fully managed when it is rendered.
    os.makedirs('results_csn', exist_ok=True)
    fig.savefig(os.path.join('results_csn', "loss_plot_csn_SGD.jpg"), pad_inches=0)
    plt.close(fig)  # free the figure; this function is called once per epoch

In [20]:
def accuracy(dist_an, dist_ap):
    """Score triplets: one is correct when anchor-negative distance exceeds
    anchor-positive distance.

    Returns:
        (fraction of correct triplets, number of correct triplets), both as
        floating-point tensors.
    """
    margin = 0
    gap = (dist_an - dist_ap - margin).detach().cpu()
    n_correct = (gap > 0).sum() * 1.0
    return n_correct / dist_an.size()[0], n_correct

def accuracy_id(dist_an, dist_ap, c, c_id):
    """Fraction of correct triplets among those whose condition equals `c_id`.

    A triplet is correct when its anchor-negative distance exceeds its
    anchor-positive distance. `c` holds the condition id of each triplet.
    """
    margin = 0
    gap = (dist_an - dist_ap - margin).detach().cpu()
    selected = c.detach().cpu() == c_id
    return ((gap > 0) * selected).sum() * 1.0 / selected.sum()

In [21]:
def imshow(image):
    """Display an image with matplotlib, axes hidden.

    Args:
        image: a (C, H, W) torch tensor, or anything `np.asarray` turns into
            an image array (e.g. a PIL image or an (H, W, C) ndarray).

    Tensors are min-max rescaled to [0, 1] so that normalised images render
    without clipping; other inputs are shown as they are.
    """
    if torch.is_tensor(image):
        npimg = image.detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
        low, high = npimg.min(), npimg.max()
        if high > low:
            # The previous x / (max - min) + 0.5 only landed in [0, 1] when the
            # values were symmetric around zero.
            npimg = (npimg - low) / (high - low)
        else:
            # Constant image: avoid dividing by zero.
            npimg = np.zeros_like(npimg)
    else:
        npimg = np.asarray(image)
    plt.imshow(npimg)
    plt.axis("off")
    plt.show()

## Model

In [22]:
# Backbone: pretrained ResNet-50 with its classification layer replaced by an
# identity, so it outputs the pooled features.
base_model = models.resnet50(pretrained=True)
#base_model = encoding.models.get_model('ResNeSt50', pretrained=True)
base_model.fc = nn.Identity()

# Projection head -> conditional masks (3 conditions) -> triplet wrapper.
model = Triplet_model(base_model)
#model.float()
csn_model = ConditionalSimNet(
    model,
    n_conditions=3,
    embedding_size=128,
    learnedmask=True,
    prein=False,  # mask_weight=normal(0.9, 0.7)
)
tnet = CS_Tripletnet(csn_model).to(device)

In [23]:
criterion = torch.nn.MarginRankingLoss(margin = 0.2)

# Per-module learning rates: a very small one for the pretrained backbone, a
# larger one for the projection head.
param_groups = [
    {'params': tnet.embeddingnet.embeddingnet.base_model.parameters(), 'lr': 1.e-7},
    {'params': tnet.embeddingnet.embeddingnet.fc.parameters(), 'lr': 1.e-4},
]
# The condition masks were previously left out of the optimizer, so with
# learnedmask=True they never moved from their initial values. Trainable mask
# parameters now get their own group, using the optimizer's default lr.
mask_params = [p for p in tnet.embeddingnet.masks.parameters() if p.requires_grad]
if mask_params:
    param_groups.append({'params': mask_params})

#optimizer = torch.optim.Adam([
                             # {'params':tnet.embeddingnet.embeddingnet.base_model.parameters(), 'lr':1.e-7},
                              #{'params':tnet.embeddingnet.embeddingnet.fc.parameters(),   'lr':1.e-4}
                              #], lr=1.e-5)
optimizer = torch.optim.SGD(param_groups, lr=1.e-4, momentum=0.9, weight_decay=0.0001)
#adam alfa=5E-5, beta1=0.1, beta2=0.001
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.90)
# Multiply the learning rates by 0.98 when the monitored validation loss stops decreasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.98)

In [24]:
def train_func(train_loader, tnet, criterion, optimizer, epoch, device, embed_loss=5.e-3, mask_loss=5.e-4):
    """Train `tnet` for one pass over `train_loader`.

    Every batch holds one anchor plus (negative, positive, condition) for
    genre, artist and style. For each condition the loss is the triplet loss
    plus `embed_loss` times the embedding norm (scaled by sqrt(batch size))
    plus `mask_loss` times the mask norm (scaled by batch size); the three
    losses are averaged and back-propagated together.

    Args:
        train_loader: iterable yielding the ten-element batches described above.
        tnet: triplet network called as tnet(anchor, negative, positive, condition).
        criterion: ranking loss called as criterion(dist_an, dist_ap, target).
        optimizer: optimizer stepped once per batch.
        epoch: unused; kept for call compatibility.
        device: device the batch tensors are moved to.
        embed_loss: weight of the embedding-norm term.
        mask_loss: weight of the mask-norm term.

    Returns:
        (tnet, mean accuracy, mean triplet loss, correct genre count, correct
        artist count, correct style count, genre accuracy, artist accuracy,
        style accuracy).
    """
    tasks = ('genre', 'artist', 'style')
    loss_meters = {task: AverageMeter() for task in tasks}
    acc_meters = {task: AverageMeter() for task in tasks}
    correct = {task: 0 for task in tasks}

    tnet.train()
    for batch_idx, batch in enumerate(train_loader):
        (anchor, g_negative, g_positive, condition_genre,
         a_negative, a_positive, condition_artist,
         s_negative, s_positive, condition_style) = batch

        anchor = anchor.to(device)
        batch_n = anchor.size(0)
        triplets = {
            'genre': (g_negative, g_positive, condition_genre),
            'artist': (a_negative, a_positive, condition_artist),
            'style': (s_negative, s_positive, condition_style),
        }

        task_losses = []
        for task in tasks:
            negative, positive, condition = (tensor.to(device) for tensor in triplets[task])
            dist_an, dist_ap, mask_norm, embed_norm, _, _ = tnet(anchor, negative, positive, condition)

            # A target of 1 means dist_an should be larger than dist_ap.
            target = torch.FloatTensor(dist_an.size()).fill_(1).to(device)

            loss_triplet = criterion(dist_an, dist_ap, target)
            loss_embedd = embed_norm / np.sqrt(batch_n)
            loss_mask = mask_norm / batch_n
            task_losses.append(loss_triplet + embed_loss * loss_embedd + mask_loss * loss_mask)

            # Measure accuracy and record the plain triplet loss.
            acc, n_hits = accuracy(dist_an, dist_ap)
            correct[task] = correct[task] + n_hits.item()
            loss_meters[task].update(loss_triplet.data.item(), batch_n)
            acc_meters[task].update(acc, batch_n)

        loss_tot = (task_losses[0] + task_losses[1] + task_losses[2]) / 3

        # Compute gradients and take one optimizer step.
        optimizer.zero_grad()
        loss_tot.backward()
        optimizer.step()

        acc_g, acc_a, acc_s = (acc_meters[task].avg for task in tasks)
        running_loss = (loss_meters['genre'].avg + loss_meters['artist'].avg + loss_meters['style'].avg)/3
        running_acc = (acc_g + acc_a + acc_s)/3
        template = f'Iteration [{batch_idx}/{len(train_loader)}] Training | Train Loss: {round(running_loss,3)} | Accuracy: {round(running_acc.item(),3)} | Correct Genre: {round(acc_g.item(),3)} | Correct Artist: {round(acc_a.item(),3)} | Correct Style: {round(acc_s.item(),3)}\r'
        print(template, end='')

    mean_accs = (acc_meters['genre'].avg + acc_meters['artist'].avg + acc_meters['style'].avg)/3
    mean_losses = (loss_meters['genre'].avg + loss_meters['artist'].avg + loss_meters['style'].avg)/3

    return (tnet, mean_accs, mean_losses,
            int(correct['genre']), int(correct['artist']), int(correct['style']),
            acc_meters['genre'].avg, acc_meters['artist'].avg, acc_meters['style'].avg)

In [25]:
def valid_func(valid_loader, tnet, criterion, epoch, device, embed_loss=1.e-4, mask_loss=1e-4):
    """Evaluate `tnet` on `valid_loader` without updating it.

    Args:
        valid_loader: iterable yielding ten-element batches: anchor, then
            (negative, positive, condition) for genre, artist and style.
        tnet: triplet network called as tnet(anchor, negative, positive, condition).
        criterion: ranking loss called as criterion(dist_an, dist_ap, target).
        epoch, embed_loss, mask_loss: unused; kept for call compatibility.
        device: device the batch tensors are moved to.

    Returns:
        (mean accuracy, mean triplet loss, correct genre count, correct artist
        count, correct style count, genre accuracy, artist accuracy, style
        accuracy, list of per-batch anchor embeddings as numpy arrays). The
        anchor embeddings are those returned by the genre forward pass.
    """
    tasks = ('genre', 'artist', 'style')
    loss_meters = {task: AverageMeter() for task in tasks}
    acc_meters = {task: AverageMeter() for task in tasks}
    correct = {task: 0 for task in tasks}
    anchors_embedded = []

    tnet.eval()
    # No gradients are needed here; disabling autograd stops every forward
    # pass from building and holding a computation graph.
    with torch.no_grad():
        for batch_idx, batch in enumerate(valid_loader):
            (anchor, g_negative, g_positive, condition_genre,
             a_negative, a_positive, condition_artist,
             s_negative, s_positive, condition_style) = batch

            anchor = anchor.to(device)
            batch_n = anchor.size(0)
            triplets = {
                'genre': (g_negative, g_positive, condition_genre),
                'artist': (a_negative, a_positive, condition_artist),
                'style': (s_negative, s_positive, condition_style),
            }

            for task in tasks:
                negative, positive, condition = (tensor.to(device) for tensor in triplets[task])
                dist_an, dist_ap, _, _, _, anchor_embedded = tnet(anchor, negative, positive, condition)

                if task == 'genre':
                    anchors_embedded.append(anchor_embedded.detach().cpu().numpy())

                # A target of 1 means dist_an should be larger than dist_ap.
                target = torch.ones_like(dist_an)
                valid_loss = criterion(dist_an, dist_ap, target)

                acc, n_hits = accuracy(dist_an, dist_ap)
                correct[task] = correct[task] + n_hits.item()
                acc_meters[task].update(acc, batch_n)
                loss_meters[task].update(valid_loss.item(), batch_n)

            acc_g, acc_a, acc_s = (acc_meters[task].avg for task in tasks)
            running_loss = (loss_meters['genre'].avg + loss_meters['artist'].avg + loss_meters['style'].avg)/3
            running_acc = (acc_g + acc_a + acc_s)/3
            template = f'Iteration [{batch_idx}/{len(valid_loader)}] Validating | Valid Loss: {round(running_loss,3)} | Accuracy: {round(running_acc.item(),3)} | Correct Genre: {round(acc_g.item(),3)} | Correct Artist: {round(acc_a.item(),3)} | Correct Style: {round(acc_s.item(),3)}\r'
            print(template, end='')

    mean_accs = (acc_meters['genre'].avg + acc_meters['artist'].avg + acc_meters['style'].avg)/3
    mean_losses = (loss_meters['genre'].avg + loss_meters['artist'].avg + loss_meters['style'].avg)/3

    return (mean_accs, mean_losses,
            int(correct['genre']), int(correct['artist']), int(correct['style']),
            acc_meters['genre'].avg, acc_meters['artist'].avg, acc_meters['style'].avg,
            anchors_embedded)#, anchors_row

In [26]:
# Best validation loss so far; a checkpoint is saved only when the loss drops
# below this value.
min_los = 10
# Per-epoch histories used for the loss/accuracy plot.
loss_train, loss_valid, accuracy_train, accuracy_valid = [], [], [], []

In [None]:
# The log file and the loss plot are written into this directory.
os.makedirs('results_csn', exist_ok=True)

for epoch in range(0, 2000):

    if epoch == 0:
        # Baseline evaluation of the untrained head. Its results are superseded
        # by the evaluation at the end of this same iteration; only its
        # progress line is of interest.
        valid_func(test_loader, tnet, criterion, epoch, device)

    # Draw a fresh subset of 100 training rows per genre every epoch.
    train_loader = DataLoader(TripletCSNDataset(train_df, train_transforms, sample=100),
                          batch_size=batch_size, shuffle=True,
                          num_workers = 8, pin_memory=False,
                          drop_last=False)

    tnet, acc_train, losses_train, correct_genre_train, correct_artist_train, correct_style_train, mean_accs_genre_train, mean_accs_artist_train, mean_accs_style_train = train_func(train_loader, tnet, criterion, optimizer, epoch, device)
    # The second returned count used to be unpacked into `correct_genre_valid`
    # again, overwriting the genre count and never defining the artist count.
    acc_valid, losses_valid, correct_genre_valid, correct_artist_valid, correct_style_valid, mean_accs_genre_valid, mean_accs_artist_valid, mean_accs_style_valid, anchors_embedded = valid_func(test_loader, tnet, criterion, epoch, device)

    template = f'Epoch: {epoch}\t| Train Loss: {round(losses_train,3)} | Accuracy: {round(acc_train.item(),3)} | Correct Genre: {round(mean_accs_genre_train.item(),3)} | Correct Artist: {round(mean_accs_artist_train.item(),3)} | Correct Style: {round(mean_accs_style_train.item(),3)}\n\
        \t\t| Valid Loss: {round(losses_valid,3)} | Accuracy: {round(acc_valid.item(),3)} | Correct Genre: {round(mean_accs_genre_valid.item(),3)} | Correct Artist: {round(mean_accs_artist_valid.item(),3)} | Correct Style: {round(mean_accs_style_valid.item(),3)}\n'
    print(template)
    with open("results_csn/sample_SGD.txt", "a") as file_object:
        file_object.write(template)

    # Checkpoint whenever the validation loss improves.
    if losses_valid < min_los:
        min_los = losses_valid
        torch.save(tnet, 'csn_model_SGD.pt')
        with open("results_csn/sample_SGD.txt", "a") as file_object:
            file_object.write('------------------- Saving model -------------------\n')
        print('------------------- Saving model -------------------')

    loss_train.append(losses_train)
    loss_valid.append(losses_valid)
    accuracy_train.append(acc_train)
    accuracy_valid.append(acc_valid)
    save_loss_fig(loss_train, loss_valid, accuracy_train, accuracy_valid, epoch)

    # ReduceLROnPlateau monitors the validation loss.
    scheduler.step(losses_valid)

Epoch: 0	| Train Loss: 0.21 | Accuracy: 0.541 | Correct Genre: 0.535 | Correct Artist: 0.572 | Correct Style: 0.517 | Correct Style: 0.604
        		| Valid Loss: 0.142 | Accuracy: 0.681 | Correct Genre: 0.592 | Correct Artist: 0.847 | Correct Style: 0.604

------------------- Saving model -------------------
Epoch: 1	| Train Loss: 0.188 | Accuracy: 0.565 | Correct Genre: 0.539 | Correct Artist: 0.619 | Correct Style: 0.537| Correct Style: 0.612
        		| Valid Loss: 0.141 | Accuracy: 0.687 | Correct Genre: 0.604 | Correct Artist: 0.846 | Correct Style: 0.612

------------------- Saving model -------------------
Epoch: 2	| Train Loss: 0.186 | Accuracy: 0.578 | Correct Genre: 0.545 | Correct Artist: 0.646 | Correct Style: 0.544rrect Style: 0.6477646
        		| Valid Loss: 0.14 | Accuracy: 0.7 | Correct Genre: 0.634 | Correct Artist: 0.82 | Correct Style: 0.647

------------------- Saving model -------------------
Epoch: 3	| Train Loss: 0.178 | Accuracy: 0.593 | Correct Genre: 0.565 |

## Testing

In [None]:
# The cells of the testing section run on the CPU.
device = 'cpu'

In [None]:
# Load the pickled network. `map_location` remaps tensors saved from a GPU onto
# the current device; without it, loading a CUDA checkpoint on a CPU-only host
# fails before `.to(device)` is reached.
# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# NOTE(review): the training loop saves 'csn_model_SGD.pt', while this loads
# 'csn_model_gsa.pt' -- confirm this is the intended checkpoint.
model = torch.load('csn_model_gsa.pt', map_location=device).to(device)

In [None]:
# Show anchor/negative and anchor/positive genre pairs for the first 10 test
# batches, with the distances computed by the model.
model.eval()
with torch.no_grad():  # inference only
    # Iterate the loader once. Calling next(iter(test_loader)) inside the loop
    # restarted it every time, so the same first batch was shown 10 times.
    # range() comes first in zip() so no 11th batch is fetched.
    for i, batch in zip(range(10), test_loader):
        anchor, g_negative, g_positive, condition_genre, a_negative, a_positive, condition_artist, s_negative, s_positive, condition_style = batch
        anchor, g_negative, g_positive, condition_genre = anchor.to(device), g_negative.to(device), g_positive.to(device), condition_genre.to(device)
        dist_an_genre, dist_ap_genre, _, _, _, _ = model(anchor, g_negative, g_positive, condition_genre)
        concatenated_an = torch.cat((anchor, g_negative), 0)
        concatenated_ap = torch.cat((anchor, g_positive), 0)
        #print('Genre anchor: '.format(anchor_row.genre))
        # imshow converts with .numpy(), which needs a CPU tensor.
        imshow(torchvision.utils.make_grid(concatenated_an).cpu())
        print('Dissimilarity a-n: ', dist_an_genre)
        imshow(torchvision.utils.make_grid(concatenated_ap).cpu())
        print('Dissimilarity a-p: ', dist_ap_genre)

In [None]:
# Embed every test anchor under the genre condition and keep its genre and
# image path alongside.
model.eval()
anchors_embedded = []
with torch.no_grad():  # inference only
    # Iterate over all batches. The previous `in next(iter(test_loader))`
    # iterated over the tensors of a single batch instead.
    for anchor, g_negative, g_positive, condition_genre, a_negative, a_positive, condition_artist, s_negative, s_positive, condition_style in test_loader:
        anchor, g_negative, g_positive, condition_genre = anchor.to(device), g_negative.to(device), g_positive.to(device), condition_genre.to(device)
        dist_an_genre, dist_ap_genre, _, _, _, anchor_embedded = model(anchor, g_negative, g_positive, condition_genre)
        anchors_embedded.append(anchor_embedded.detach().cpu().numpy())

# One row per anchor: (n_samples, embedding_size).
anchors_embedded = np.concatenate(anchors_embedded)

# The loader does not shuffle, so embeddings follow the row order of the
# dataset's own frame. The batches carry no metadata (the previous code read
# an undefined `row`), so genre and path are taken from that frame.
anchors_genre = test_dataset.df.genre.tolist()
anchors_path = test_dataset.df.img_path.tolist()

In [None]:
def tsne_image(embeddings=None, image_paths=None):
    """Lay image thumbnails out on a canvas according to a 2-D t-SNE of their embeddings.

    Args:
        embeddings: (n_samples, n_features) array, or a list of per-batch
            arrays. Defaults to the global `anchors_embedded`.
        image_paths: one image path per embedding row. Defaults to the global
            `anchors_path`.

    The canvas is displayed and saved to ./tsne_plot_csn.jpg.
    Layout approach: https://github.com/sinanatra/image-tsne/blob/master/notebooks/image_tsne.ipynb
    """
    if embeddings is None:
        embeddings = anchors_embedded
    if image_paths is None:
        image_paths = anchors_path
    # Stack per-batch arrays (or rows) into one 2-D matrix.
    embeddings = np.vstack(embeddings)

    tsne = TSNE(n_components=2, perplexity=50, random_state=1) #try different perplexity (between 5 and 50)
    X_embedded = tsne.fit_transform(embeddings)
    # Coordinates come from the fitted output, not from the TSNE object itself.
    tx, ty = X_embedded[:, 0], X_embedded[:, 1]
    tx = (tx-np.min(tx)) / (np.max(tx) - np.min(tx))
    ty = (ty-np.min(ty)) / (np.max(ty) - np.min(ty))

    width = 4000
    height = 3000
    max_dim = 100  # largest thumbnail side, in pixels

    full_image = Image.new('RGB', (width, height))
    for path, x, y in zip(image_paths, tx, ty):
        with Image.open(path) as opened:
            # RGBA so the tile can serve as its own paste mask (an RGB image
            # is not a valid mask).
            tile = opened.convert('RGBA')
        rs = max(1, tile.width/max_dim, tile.height/max_dim)
        new_size = (max(1, int(tile.width/rs)), max(1, int(tile.height/rs)))
        # LANCZOS is the filter Image.ANTIALIAS aliased; ANTIALIAS was removed in Pillow 10.
        tile = tile.resize(new_size, Image.LANCZOS)
        full_image.paste(tile, (int((width-max_dim)*x), int((height-max_dim)*y)), mask=tile)

    # `plt` is the name bound by the notebook's matplotlib import.
    plt.figure(figsize = (16,12))
    plt.imshow(full_image)
    plt.axis("off")
    plt.show()

    # PIL images are written with .save(); savefig is a matplotlib method.
    full_image.save(os.path.join('.', "tsne_plot_csn.jpg"))

In [None]:
def save_tsne_fig(X_train, y_train):
    """Scatter-plot a 2-D t-SNE of `X_train`, coloured by genre class, and save it.

    Args:
        X_train: (n_samples, n_features) embeddings.
        y_train: one class id per sample; ids must be keys of the global
            `class_dict_genre`, which supplies the legend labels.

    The figure is written to ./csn_plot.jpg.
    """
    # A plain list would make `y_train == label` a single bool, not a mask.
    y_train = np.asarray(y_train)
    labels = np.unique(y_train)

    tsne = TSNE(n_components=2, perplexity=50) #try different perplexity (between 5 and 50)
    X_embedded = tsne.fit_transform(X_train)

    fig = plt.figure(figsize=(17, 10))
    plt.title('t-SNE')

    for label in labels:
        selected = y_train == label
        plt.scatter(X_embedded[selected, 0],
                    X_embedded[selected, 1],
                    # `color=` takes a single RGBA value; passing it through
                    # `c=` is ambiguous with a sequence of scalar values.
                    color=plt.cm.Set1(label / float(len(labels))),
                    alpha=0.8,
                    label=class_dict_genre[label])
    plt.legend(loc='best')
    # Save first, then release the figure.
    fig.savefig(os.path.join('.', "csn_plot.jpg"), pad_inches=0)
    plt.close(fig)

## Visualize Nearest Neighbors

# Collects true labels and argmax predictions, assuming a loader that yields
# (images, target) pairs and a model called with the images only.
# NOTE(review): this does not match the rest of the notebook -- `test_loader`
# yields ten-element triplet batches and the triplet network's forward takes
# (x, y, z, c). It looks like a leftover from a classification notebook;
# confirm before running.
truelabels = []
predictions = []
model.eval()

for images, target in test_loader:
    for label in target.data.numpy():
        truelabels.append(label)
    out = model(images.to(device))
    for prediction in out.data.argmax(1):
        predictions.append(prediction.detach().cpu().numpy().item())

In [None]:
from sklearn.neighbors import NearestNeighbors

In [None]:
#def get_image_as_np_array(filename: str):
 #   """Returns an image as an numpy array
  #  """
   # img = Image.open(filename)
    #return np.asarray(img)


def plot_knn_examples(embeddings, dataframe, n_neighbors=4, num_examples=6):
    """Plot rows of randomly chosen images next to their nearest neighbours.

    Args:
        embeddings: (n_samples, n_features) array, or a list of per-batch arrays.
        dataframe: frame with an `img_path` column whose rows are in the same
            order as the embeddings (row k describes embedding k).
        n_neighbors: images per row. Each query is part of the fitted set, so
            it normally appears as its own first neighbour.
        num_examples: number of rows to draw (capped at the number of samples).
    """
    # Stack per-batch arrays (or rows) into one 2-D matrix.
    embeddings = np.vstack(embeddings)
    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(embeddings)
    distances, indices = nbrs.kneighbors(embeddings)

    # Neighbour indices are positions, so paths are looked up by position; a
    # label lookup breaks on frames whose index is no longer 0..n-1.
    paths = dataframe.img_path.to_numpy()

    # Pick random query samples, never more than there are.
    num_examples = min(num_examples, len(indices))
    samples_idx = np.random.choice(len(indices), size=num_examples, replace=False)

    for idx in samples_idx:
        fig = plt.figure()
        neighbours = indices[idx]
        for plot_x_offset, neighbor_idx in enumerate(neighbours):
            ax = fig.add_subplot(1, len(neighbours), plot_x_offset + 1)
            # Draw on this subplot. The notebook's imshow() helper expects a
            # tensor and calls plt.show() itself, which broke up the row.
            with Image.open(paths[neighbor_idx]) as img:
                ax.imshow(img.convert('RGB'))
            # Title: distance of this neighbour to the query.
            ax.set_title(f'd={distances[idx][plot_x_offset]:.3f}')
            ax.axis('off')
        plt.show()

In [None]:
# The embeddings come from `test_loader`, which wraps `test_dataset` -- a
# per-genre sample of `test_df`, not the full table. Pass the dataset's own
# frame so that row k matches embedding k.
plot_knn_examples(anchors_embedded, test_dataset.df) #embeddings n samples x n features

# Picks one random test painting, prints its genre, and tries to classify its
# embedding with a pre-fitted classifier.
# NOTE(review): this cell cannot run as written against the visible notebook:
#   - `valid_transforms` and `clf` are not defined anywhere visible;
#   - the model output is bound to `distout`, but `out` is appended to X;
#   - the triplet network's forward expects (x, y, z, c), not a single image.
# It looks like a leftover from an earlier experiment; confirm before running.
row = test_df.sample(1).iloc[0]
print(f'Actual genre: {row.genre}')
im = Image.open(row.img_path).convert('RGB')
X = []
distout = model(valid_transforms(im).unsqueeze(0).to(device))
X.append(out.detach().cpu().numpy())
pred = clf.predict(np.concatenate(X))

# Resize the sampled image so its shorter side is 250 px.
T.Resize(250)(im)

def change_color(train_data):
    """Tint each row of a 2-D intensity array with its own random RGB colour.

    Args:
        train_data: array of shape (num, size); each value multiplies the
            row's colour.

    Returns:
        (output, color): `output` is a uint8 array of shape (num, size, 3)
        with output[i, j] = color[i] * train_data[i, j]; `color` is a uint8
        array of shape (num, 3) with channel values drawn from [0, 200).
    """
    num, size = train_data.shape
    color = np.zeros([num, 3], dtype=np.uint8)
    # Three scalar draws per row, in r, g, b order, so a seeded run yields the
    # same colours as the original per-row loop.
    for i in range(num):
        color[i] = [np.random.randint(0, 200) for _ in range(3)]

    # Vectorised replacement of the per-pixel loop (and its debug print).
    # Zero intensities give zero output, as with the old explicit check.
    output = np.zeros([num, size, 3], dtype=np.uint8)
    output[...] = np.asarray(train_data)[:, :, None] * color[:, None, :]

    return output, color