In [1]:
#training methods:
#-> normal classification
#-> arc classification
#-> encoding
#-> trained embedding
#-> trained retrieval

In [2]:
!ls /kaggle/input/embedding-data/embeddings

architecture-dataset  fruits360        imagenet1000  products-10k  shopee
food101		      imagenet-sketch  places	     rp2k


In [3]:
import os
import cv2
import sys
import math
import copy
import torch
import random
import numpy as np
import pandas as pd
from torch import nn
from PIL import Image
from tqdm import tqdm
from zipfile import ZipFile
from torch.nn import LayerNorm
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import transforms

    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: :
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin
            cos(theta + m)
        """
    def __init__(self, in_features, out_features, s=30.0, 
                 m=0.5, easy_margin=False, ls_eps=0.0):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        one_hot = torch.zeros(cosine.size(), device=device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # -------------torch.where(out_i = {x_i if condition_i else y_i) ------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

In [5]:
class TopTrainer(nn.Module):
    def __init__(self, fw):
        super().__init__()
        self.fw = fw
        self.lid = None
        self.pool = nn.AdaptiveAvgPool1d(64)
        self.arc = ArcMarginProduct(1000,1000)
    def arc_sim(self, x, label):
        x = self.fw(x)
        return self.arc( x, label)
    def with_arc(self, x, label, use_lid=True):
        x = self.fw(x)
        if use_lid:
            x = self.lid(x)
        x = self.arc( x, label)
        return x
        
    def forward(self, x, use_lid=True):
        x = self.fw(x) 
        if use_lid:
            x = self.lid(x)
        return x

model = TopTrainer(
    nn.Sequential(
            nn.Linear( 768, 64)
        )
).to(device)
model

TopTrainer(
  (fw): Sequential(
    (0): Linear(in_features=768, out_features=64, bias=True)
  )
  (pool): AdaptiveAvgPool1d(output_size=64)
  (arc): ArcMarginProduct()
)

In [6]:
class class_ds_config:
    def __init__(
            self,
            path,
            lr,
            epochs,
            train, 
            max_class
    ):
        self.path=path
        self.lr=lr
        self.epochs=epochs
        self.train=train
        self.max_class = max_class
class embed_ds_config:
    def __init__(
        self,
        path,
        lr,
        epochs,
        train,
        criterion
    ):
        self.path=path
        self.lr=lr
        self.epochs=epochs
        self.train=train,
        self.criterion=criterion
class config:
    
    class_data_ls = [     
        class_ds_config(
            '/kaggle/input/embedding-data/embeddings/architecture-dataset',
            0.0001, 2, True, 25
        ),
        class_ds_config(
            '/kaggle/input/notebook-data/classification/GPR',
            0.00001,2,True,1200
        ),
        class_ds_config(
            '/kaggle/input/embedding-data/embeddings/food101',
            0.00008, 2, True, 101
        ),
        class_ds_config(
            '/kaggle/input/embedding-data/embeddings/fruits360',
            0.0001, 1, False, 131
        ),
        class_ds_config(
            '/kaggle/input/embedding-data/embeddings/rp2k',
            0.0001, 15, True, 2384
        ),
        class_ds_config(
            '/kaggle/input/embedding-data/embeddings/products-10k',
            0.0001, 5, True, 9691
        ),
        class_ds_config(
            '/kaggle/input/embedding-data/embeddings/shopee',
            0.0001, 10, True, 11014
        ),
        class_ds_config(
            '/kaggle/input/embedding-data/embeddings/imagenet1000',
            0.0001, 10, True, 1000
        ),
        class_ds_config(
            '/kaggle/input/embedding-data/embeddings/places',
            0.0001, 2, True, 1000
        ),
        
        class_ds_config(
            '/kaggle/input/embedding-data/embeddings/imagenet-sketch',
            0.0001, 2, True, 1000
        ),
        class_ds_config(
            '/kaggle/input/notebook-data/classification/imagenet1000',
            0.0001,3,False,1000
        )
    ]
    embed_data_ls = [
        embed_ds_config(
            '/kaggle/input/notebook-data/embedding/imagenet1000',
            0.0005,
            3,
            False,
            nn.MSELoss()
        ),
        embed_ds_config(
            '/kaggle/input/notebook-data/embedding/google-landmarks-2021-V1',
            0.0005,
            2,
            False,
            nn.MSELoss()
        ),
        embed_ds_config(
            '/kaggle/input/notebook-data/embedding/fashion',
            0.00005,
            2,
            False,
            nn.MSELoss()
        ),
        embed_ds_config(
            '/kaggle/input/notebook-data/embedding/GPR12000',
            0.0005,
            2,
            False,
            nn.MSELoss()
        )
    ]

In [7]:
def retrieval_evaluate( model, ds, ds_config):
    def get_embeds( model, ds):
        embeds = []
        labels = []
        with torch.no_grad():
            for vec, labels_ in ds:
                out = model( vec.to(device), use_lid=False)
                embeds.append(out)
                labels.append(labels_)
        embeds = torch.cat(embeds)
        labels = torch.cat(labels)
        return (embeds, labels)
    def normalize(a, eps=1e-8):
        a_n = a.norm(dim=1)[:, None]
        a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
        return a_norm
    def k_nearest_neighbors(embeds, k=5):
        #print(embeds.shape)
        normalized = normalize(embeds)
        preds = normalized @ normalized.T
        #print(preds.shape)
        #vals, indices = preds.sort(dim=1, descending=True)
        vals,indices = torch.topk(preds,6)
        k += 1
        return indices[:, 1:k].long()
    embeds, labels = get_embeds( model, ds)
    preds = k_nearest_neighbors(embeds, k=5)
    accs = (labels[preds] == labels.view(-1, 1)).float().mean(dim=1)
    return accs.mean()

last_retrieval = [0 for i in config.class_data_ls]
def get_retrieval_scores():
    idx=0
    for ds_config in config.class_data_ls:
        
        ds = classification_dataset(ds_config)
        if ds_config.path.split('/')[-1] == 'products-10k':
            #print('splicing ds...',end='    ')
            ds = ds.splice(0,512)
        #ds = ds.splice(0,512)
        #val_score = retrieval_evaluate( model, ds.val_split(), ds_config)
        #train_score = retrieval_evaluate( model, ds.train_split(), ds_config)
        score = retrieval_evaluate( model, ds, ds_config)
        diff = score.item() - last_retrieval[idx]
        delta = '+' if diff >= 0 else ' ' 
        pad = 20 - len(ds_config.path.split('/')[-1])
        print( ds_config.path.split('/')[-1], " "*pad,": ", score.item(),"  {}{}".format(delta, diff))
        last_retrieval[idx] = score.item()
        #print( ds_config.path.split('/')[-1]," train: ", train_score.item())
        #print( ds_config.path.split('/')[-1],"   val: ", val_score.item())
        idx+=1
def _eval( model):
    for ds_config in config.class_data_ls[:1]:
        ds = classification_dataset(ds_config)
        ds = ds.val_split()
        score = retrieval_evaluate( model, ds, ds_config)
        print( ds_config.path.split('/')[-1],": ", score.item())

In [8]:
class classification_dataset(torch.utils.data.Dataset):
    def __init__(self, config):
        self.config = config
        self.vecs = ['vec/'+_dir for _dir in os.listdir(os.path.join(config.path, 'vec'))]
        self.vecs.sort()
        self.labels = ['label/'+_dir for _dir in os.listdir(os.path.join(config.path, 'label'))]
        self.labels.sort()
    def __len__(self):
        return len(self.vecs)
    def train_split(self):
        _len = len(self)
        split = int(_len * 0.8)
        self.vecs = self.vecs[:split]
        self.labels = self.labels[:split]
        return self
    def val_split(self):
        _len = len(self)
        split = int(_len * 0.8)
        self.vecs = self.vecs[split:]
        self.labels = self.labels[split:]
        return self
        
    def splice(self, start, end):
        self.vecs = self.vecs[start:end]
        self.labels = self.labels[start:end]
        return self
    def __getitem__(self, idx):
        if self.vecs[idx].split('/')[1] != self.labels[idx].split('/')[1]:
            print('error')
        vec = torch.load(os.path.join(self.config.path,self.vecs[idx]),map_location=device)
        label = torch.load(os.path.join(self.config.path,self.labels[idx]),map_location=device)
        return vec, label

In [9]:
def train_classification_ds( model, ds, ds_config):
    optim = torch.optim.Adam(model.parameters(), lr=ds_config.lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    
    for epoch in range(ds_config.epochs):
        print('training epoch ',epoch+1,'...')
        for i, batch in enumerate(ds):
            _input, label = batch
            
            _input = _input.to(device)
            optim.zero_grad()
            output = model(_input)
            
            loss = criterion( output, label.to(device))
            loss.backward()
            optim.step()
    return model

def eval_classification_ds( model, ds, ds_config):
    criterion = nn.CrossEntropyLoss()
    model.eval()
    total_loss=0
    for i, batch in enumerate(ds):
        _input, label = batch
        _input = _input.to(device)
        output = model(_input)
        loss = criterion( output, label.to(device)).detach().item()
        total_loss += loss
    return total_loss / len(ds)

In [10]:
class SAM(torch.optim.Optimizer):
    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                e_w = p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"
                self.state[p]["e_w"] = e_w

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.sub_(self.state[p]["e_w"])  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    def step(self, closure=None):
        raise NotImplementedError("SAM doesn't work like the other optimizers, you should first call `first_step` and the `second_step`; see the documentation for more info.")

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        p.grad.norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

In [11]:
def train_arc_ds( model, ds, ds_config):    
    
    #optim = torch.optim.Adam(model.parameters(), lr=ds_config.lr)
    base_optimizer = torch.optim.Adam
    optimizer = SAM(model.parameters(), base_optimizer, lr=0.001)  
    criterion = nn.CrossEntropyLoss()
    model.train()
    
    for epoch in range(ds_config.epochs):
        print('training epoch ',epoch+1,'...',end='')
        total_loss = 0
        for i, batch in enumerate(ds):
            _input, label = batch
            
            image_preds = model.with_arc(_input.to(device),label.to(device))   #output = model(input)
            #print(image_preds.shape, exam_pred.shape)

            loss = criterion(image_preds, label.to(device)) 
            loss.backward()
            optimizer.first_step(zero_grad=True)

            # second forward-backward pass
            criterion(model.with_arc(_input.to(device), label.to(device)), label.to(device)).backward()
            optimizer.second_step(zero_grad=True)
            total_loss += loss.detach().item()

            #_input = _input.to(device)
            #optim.zero_grad()
            #output = model.with_arc( _input, label)

            #loss = criterion( output, label.to(device))
            #loss.backward()
            #optim.step()
        print('      {}'.format(total_loss))
        total_loss = 0
    return model
if True:
    get_retrieval_scores()
    for ds_config in config.class_data_ls:
        if ds_config.train:
            print('='*50)
            ds = classification_dataset(ds_config)
            ds = ds.train_split()
            model.lid = nn.Linear(64, ds_config.max_class).to(device)
            model.arc = ArcMarginProduct( ds_config.max_class,ds_config.max_class).to(device)
            print('pre training loss: ',eval_classification_ds( model, ds, ds_config))

            print('training dataset: {}'.format(ds_config.path.split('/')[-1]))
            model = train_arc_ds( model, ds, ds_config)
            print('post training loss: ',eval_classification_ds( model, ds, ds_config))
            #get_retrieval_scores()
            print('finished.')
            print('='*50)
            print('')
    get_retrieval_scores()

architecture-dataset  :  0.7294535040855408   +0.7294535040855408
GPR                   :  0.676633358001709   +0.676633358001709
food101               :  0.7326732277870178   +0.7326732277870178
fruits360             :  0.99440598487854   +0.99440598487854
rp2k                  :  0.429698646068573   +0.429698646068573
products-10k          :  0.17433473467826843   +0.17433473467826843
shopee                :  0.34236496686935425   +0.34236496686935425
imagenet1000          :  0.4473622143268585   +0.4473622143268585
places                :  0.40837037563323975   +0.40837037563323975
imagenet-sketch       :  0.5465935468673706   +0.5465935468673706
imagenet1000          :  0.44729891419410706   +0.44729891419410706
pre training loss:  3.253124757607778
training dataset: architecture-dataset
training epoch  1 ...      884.9143877029419
training epoch  2 ...      461.63996028900146
post training loss:  3.2141392389933268
finished.

pre training loss:  7.1238253180185955
training dataset

In [12]:
class ContrastiveLoss(nn.Module):
    def __init__(self, m=2.0):
        super(ContrastiveLoss, self).__init__()  # pre 3.3 syntax
        self.m = m  # margin or radius

    def forward(self, y1, y2, d=0):
        # d = 0 means y1 and y2 are supposed to be same
        # d = 1 means y1 and y2 are supposed to be different

        euc_dist = T.nn.functional.pairwise_distance(y1, y2)

        if d == 0:
            return T.mean(T.pow(euc_dist, 2))  # distance squared
        else:  # d == 1
            delta = self.m - euc_dist  # sort of reverse distance
            delta = T.clamp(delta, min=0.0, max=None)
            return T.mean(T.pow(delta, 2))  # mean over all rows

In [13]:
def train_encoding_class_ds(model, ds, ds_config):
    optim = torch.optim.SGD(model.fw.parameters(), lr=ds_config.lr)
    criterion = nn.MSELoss()#ds_config.criterion
    model.train()
    total_loss=0
    for i, batch in enumerate(ds):
        _input, label = batch
        optim.zero_grad()
        model.zero_grad()
        output = model(_input.to(device))
        loss = criterion( output, _input.to(device)) 
        loss.backward()
        optim.step()
        total_loss += loss.detach().item()
    print('     {}'.format(total_loss / len(ds)))
    return model

if False:
    get_retrieval_scores()
    model.lid = nn.Linear(64,768).to(device)
    for epoch in range(10):
        for ds_config in config.class_data_ls:
            if ds_config.train:
                ds = classification_dataset(ds_config)
                print('training ',ds_config.path)
                print('epoch {} training {}...'.format( epoch+1, ds_config.path.split('/')[-1]),end='')
                model = train_encoding_class_ds( model, ds, ds_config)

                print('')
    get_retrieval_scores()

In [14]:
def evaluate_margin(model, ds):
    pass
    #this should evaluate the embedding difference between model outputs of similar and different classes

In [15]:
#idea, have a special arcface layer to go on top of the flattened covariance matrix,
#it will just have class 0 and 1
#input: 64^2, output: 64^2

def train_similarities( model, ds, ds_config):
    optim = torch.optim.Adam(model.parameters(), lr=0.00005)#ds_config.lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    
    for epoch in range(ds_config.epochs):
        print('training epoch ',epoch+1,'...',end='')
        total_loss = 0
        for i, batch in enumerate(ds):
            _input, label = batch
            
            _input = _input.to(device)
            optim.zero_grad()
            output = model( _input, use_lid=False)
            #output = model.arc_sim( _input, label.to(device))
            x, y = torch.meshgrid(label,label)
            mesh = (x==y).type(torch.uint8).type(torch.float)
            covariance = output @ output.T
            loss = criterion( covariance, mesh)
            total_loss += loss.detach().item()
            
            #loss = criterion( output, label.to(device))
            loss.backward()
            optim.step()
        print('   loss: {}'.format(total_loss / len(ds)))
    return model
if False:
    print('training similarities... ')
    #get_retrieval_scores()
    model.arc = ArcMarginProduct(64,64).to(device)
    for ds_config in config.class_data_ls:
        if ds_config.train:
            ds = classification_dataset(ds_config)
            print('training ',ds_config.path)
            
            model = train_similarities( model, ds, ds_config)
            print('finished.')
    get_retrieval_scores()

In [16]:
class embedding_dataset(torch.utils.data.Dataset):
    def __init__(self, config):
        self.config = config
        self.vecs = ['vec/'+_dir for _dir in os.listdir(os.path.join(config.path, 'vec'))]
        self.vecs.sort()
       
    def __len__(self):
        return len(self.vecs)

    def __getitem__(self, idx):
        vec = torch.load(os.path.join(self.config.path,self.vecs[idx]),map_location=device)
        return vec

def train_embedding_ds( model, ds, ds_config):
    optim = torch.optim.SGD(model.fw.parameters(), lr=ds_config.lr)
    criterion = ds_config.criterion
    model.train()
    for i, batch in enumerate(ds):
        optim.zero_grad()
        model.zero_grad()

        output = model(batch.to(device), use_lid=False)
        label_vec = torch.sum( output, 0) / batch.shape[0] 

        loss = criterion( output, label_vec[None,:].repeat(batch.shape[0], 1))# * 10 

        loss.backward()
        optim.step()
    return model


In [17]:
def train_encoding_ds(model, ds, ds_config):
    optim = torch.optim.SGD(model.fw.parameters(), lr=ds_config.lr)
    criterion = nn.MSELoss()#ds_config.criterion
    model.train()
    total_loss=0
    for i, batch in enumerate(ds):
        optim.zero_grad()
        model.zero_grad()
        output = model(batch.to(device))
        loss = criterion( output, batch) 
        loss.backward()
        optim.step()
        total_loss += loss.detach().item()
    print('     {}'.format(total_loss / len(ds)))
    return model
get_retrieval_scores()
for ds_config in config.embed_data_ls:
    if ds_config.train:
        model.lid = nn.Linear(64,768).to(device)
        ds = embedding_dataset(ds_config)
        print('training ',ds_config.path)
        
        for epoch in range(ds_config.epochs):
            print('{} / {} training {}...'.format( epoch+1, ds_config.epochs, ds_config.path),end='')
            model = train_encoding_ds( model, ds, ds_config)
        
        print('')
get_retrieval_scores()

architecture-dataset  :  0.7722986936569214   +0.0
GPR                   :  0.8054666519165039   +0.0
food101               :  0.8609066009521484   +0.0
fruits360             :  0.9955152273178101   +0.0
rp2k                  :  0.6035001873970032   +0.0
products-10k          :  0.3372558653354645   +0.0
shopee                :  0.41354745626449585   +0.0
imagenet1000          :  0.7062138319015503   +0.0
places                :  0.537222146987915   +0.0
imagenet-sketch       :  0.7114307284355164   +0.0
imagenet1000          :  0.7063923478126526   +0.0
training  /kaggle/input/notebook-data/embedding/imagenet1000
1 / 3 training /kaggle/input/notebook-data/embedding/imagenet1000...     6.631022381852586
2 / 3 training /kaggle/input/notebook-data/embedding/imagenet1000...     6.471067062904739
3 / 3 training /kaggle/input/notebook-data/embedding/imagenet1000...     6.388194638531542

training  /kaggle/input/notebook-data/embedding/google-landmarks-2021-V1
1 / 2 training /kaggle/input/no

In [18]:
def eval_embedding_ds( model, ds, ds_config):
    
    criterion = ds_config.criterion
    model.eval()
    total_loss=0
    for i, batch in enumerate(ds):
        with torch.no_grad():
            output = model(batch.to(device), use_lid=False)
        label_vec = torch.sum( output, 0) / batch.shape[0] 
        loss = criterion( output, label_vec[None,:].repeat(batch.shape[0], 1)).detach().item()# * 10 
        total_loss+=loss
    return total_loss / len(ds)
    

In [19]:
#for ds_config in config.embed_data_ls:
#    if ds_config.train[0]:
#        ds = embedding_dataset(ds_config)
#       print('training ',ds_config.path)
#       print('pre training loss: ',eval_embedding_ds( model, ds, ds_config))
#        get_retrieval_scores()
#       for epoch in range(ds_config.epochs):
#           print('{} / {} training {}...'.format( epoch+1, ds_config.epochs, ds_config.path))
#            model = train_embedding_ds( model, ds, ds_config)
#            get_retrieval_scores()
#        print('post training loss: ',eval_embedding_ds( model, ds, ds_config))
#        get_retrieval_scores()
#        print('')

In [20]:
def make_embedding_lib( model, ds, ds_config):    
    #1. find each class in the dataset
    all_labels = None
    for i, batch in enumerate(ds):
        _, label = batch
        if all_labels is None:
            all_labels = label
        else:
            all_labels = torch.concat((all_labels, label),0)
    unique = torch.unique(all_labels)
    
    #2. for each class, gather all of the vectors for that class
    vecs = None
    for label in unique:
        print('-',end='')
        class_vecs = None
        for i, batch in enumerate(ds):
            vec, labels = batch
            mask = (labels == label).nonzero()#torch.where(labels==label, label, -1)
            selected = vec[mask]
            if selected.shape[0] is not 0:
                
                if class_vecs is None:
                    class_vecs = selected
                else:
                    class_vecs = torch.concat((class_vecs, selected),0)    
                    
        #3. for each set of vectors run the model on each vector, and get the average vector
        class_vecs = class_vecs.squeeze()
        loader = torch.utils.data.DataLoader(
            class_vecs,
            batch_size=64
        )
        output_vecs = None
        for i, batch in enumerate(loader):
            output = model(batch, use_lid=False)
            if output_vecs is None:
                output_vecs = output
            else:
                output_vecs = torch.concat((output_vecs, output),0)
                
        #4. calculate an average output vector for each class
        mean_vec = torch.sum( output_vecs, 0) / class_vecs.shape[0]
        mean_vec = mean_vec[None, :]
        if vecs is None:
            vecs = mean_vec
        else:
            vecs = torch.concat(( vecs, mean_vec), 0)
    print(vecs.shape)
    return vecs
        
        
    
        
#for ds_config in config.class_data_ls[:1]:
#    if ds_config.train:
#        ds = classification_dataset(ds_config)
#        sample_vec_lib = make_embedding_lib( model, ds, ds_config)

In [21]:
class JSD(nn.Module):
    #thanks to:
    #https://discuss.pytorch.org/t/jensen-shannon-divergence/2626/12
    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.tensor, q: torch.tensor):
        p, q = p.view(-1, p.size(-1)).log_softmax(-1), q.view(-1, q.size(-1)).log_softmax(-1)
        m = (0.5 * (p + q))
        return 0.5 * (self.kl(m, p) + self.kl(m, q))

In [22]:
temp = model
def train_lookup_embed_ds( model, ds, ds_config, vec_lib):
    optim = torch.optim.Adam(model.parameters(), lr=ds_config.lr)
    criterion = JSD()
    model.train()
    
    for epoch in range(5):#ds_config.epochs):
        print('training epoch ',epoch+1,'...')
        for i, batch in enumerate(ds):
            _input, label = batch
            
            _input = _input.to(device)
            optim.zero_grad()
            model.zero_grad()
            output = model(_input,use_lid=False)
            
            label = vec_lib[label]
            loss = criterion( output, label.detach().to(device))
            loss.backward()
            optim.step()
    return model
#for ds_config in config.class_data_ls:
#    if ds_config.train:
#        ds = classification_dataset(ds_config)
#        print('training ',ds_config.path)
#        temp = train_lookup_embed_ds( temp, ds, ds_config, sample_vec_lib)
#        print('finished.')

In [23]:
class LookupModel(nn.Module):
    def __init__(self, _model):
        super().__init__()
        self._model = _model
        self.vec_lib = None
    def forward( self, x, use_lid=False):
        output = self._model(x,use_lid=False)[:,None,:]
        output = output.repeat(1,self.vec_lib.shape[0],1)
        vec_lib = self.vec_lib.repeat(output.shape[0],1,1)
        product = output @ vec_lib.transpose(1,2)
        product = product[:,:1,:]
        value, index = torch.topk(product.squeeze(), k=1,dim=1)
        return self.vec_lib[index.squeeze()]
#m = LookupModel(model)
#m.vec_lib = sample_vec_lib
#m(torch.zeros(3,768)).shape

In [24]:
#model(torch.zeros(3,768))

In [25]:
#temp(torch.zeros(3,768))

In [26]:
#_eval(model)
#_eval(m)
#trained_temp = LookupModel(temp)
#trained_temp.vec_lib = sample_vec_lib
#_eval(trained_temp)

In [27]:
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git


import torch
import clip
from clip.clip import _download, _MODELS


model_path = _download(_MODELS['ViT-L/14@336px'], os.path.expanduser("~/.cache/clip"))
with open(model_path, 'rb') as opened_file:
    print('opening: ',model_path)
    clip_vit_l14_336 = torch.jit.load(opened_file, map_location=device).visual.eval()
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = clip_vit_l14_336
        
        self.fw = None
        self.pool = nn.AdaptiveAvgPool1d(64)
    def preprocess_image(self, x):
        x = transforms.functional.resize(x,size=[336, 336])
        x = x/255.0
        x = transforms.functional.normalize(x, 
                                            mean=[0.48145466, 0.4578275, 0.40821073],
                                            std=[0.26862954, 0.26130258, 0.27577711])
        return x
    
    def forward(self, x):
        x = self.preprocess_image(x)
        x = self.encoder(x.half())
        x = self.fw(x)
        #x = torch.nn.functional.normalize(x, p=2.0, dim=1, eps=1e-12)
        return x

sub = MyModel().to(device).eval()
sub.fw = model.fw.half()
#print(sub)
print(sub(torch.randn(1,3,336,336).to(device)).detach())

sub.eval()
saved_model = torch.jit.script(sub)
saved_model.save("saved_model.pt")
with ZipFile('submission.zip','w') as zip:           
    zip.write("saved_model.pt", arcname='saved_model.pt')
sub = torch.jit.load("saved_model.pt").to('cuda').eval()
input_batch = torch.rand(1, 3, 336, 336).to('cuda')
with torch.no_grad():
    embedding = sub(input_batch).cpu().data.numpy()
embedding

Collecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.1/53.1 kB[0m [31m670.5 kB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.1.1
[0mCollecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-_6h4rjl8
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-_6h4rjl8
  Resolved https://github.com/openai/CLIP.git to commit d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
  Preparing metadata (setup.py) ... [?25l- done
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l- \ | done
[?25h  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369409 sha256=980b176f29c705675b77647b70cdd6212d7afafb03baa4840a92d8c89fe0212b
  Stored in directory: /tmp/pip-ephem-wheel-cache-1wrcxmkw/wheels/fd/b9/c

100%|████████████████████████████████████████| 891M/891M [00:03<00:00, 245MiB/s]


opening:  /root/.cache/clip/ViT-L-14-336px.pt


To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  /usr/local/src/pytorch/aten/src/ATen/native/BinaryOps.cpp:601.)
  return forward_call(*input, **kwargs)


tensor([[ 1.9873, -7.2148, -2.6152,  7.3945, -1.9160,  3.2363, -2.2305, -4.0938,
         -1.5557,  1.3984,  1.6836,  3.2324, -5.2461,  0.1132, -4.1406, -4.7656,
         -0.2913,  2.2266,  3.8047, -0.4146, -2.6523, -2.4512, -2.6992,  1.2139,
         -2.3438,  1.5293, -2.9336, -0.7905, -3.1309, -0.7441, -2.2168,  0.8643,
         -2.3926,  1.1592, -0.1013, -4.0586,  1.6426,  2.7227,  2.6777,  0.4241,
          4.9297, -2.3809, -1.7236,  0.9590,  1.2266,  2.7754, -3.4023, -0.1344,
          0.6519, -1.4082,  3.6172,  0.6284, -1.3262,  3.9141,  2.5703,  6.6797,
         -2.7188,  4.3242, -0.4131, -2.0098, -6.7227,  1.3379,  0.9478,  0.0483]],
       device='cuda:0', dtype=torch.float16)


array([[ 5.473  , -4.97   , -2.479  ,  1.81   , -2.357  ,  3.297  ,
        -1.745  , -4.375  , -1.862  ,  2.08   , -0.1088 ,  1.363  ,
        -0.879  ,  3.889  , -0.2479 , -1.391  , -1.278  ,  2.8    ,
         4.906  , -2.104  , -3.885  ,  1.708  , -3.998  , -0.6377 ,
        -0.807  ,  1.625  , -0.741  , -3.172  , -0.678  ,  1.252  ,
         0.2866 ,  2.244  , -2.107  , -0.859  , -2.633  , -2.094  ,
         1.699  , -0.235  ,  0.4277 ,  0.06097,  2.965  , -1.166  ,
        -0.479  ,  3.441  ,  2.436  ,  2.23   , -4.184  , -0.3992 ,
         3.834  , -0.51   , -0.977  ,  0.8135 ,  0.5063 ,  2.121  ,
         3.283  ,  1.277  ,  0.606  ,  2.846  , -0.2996 , -1.479  ,
        -2.742  , -0.4424 , -1.549  , -0.528  ]], dtype=float16)

In [28]:
model.fw

Sequential(
  (0): Linear(in_features=768, out_features=64, bias=True)
)

In [29]:
def train_embed_ds(path_ds, model):
    
    loader = torch.utils.data.DataLoader(
        path_ds,
        batch_size=config.batch_size,
    )
    
    mse = torch.nn.MSELoss()
    
    #optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.000005, amsgrad=True)
    optim = torch.optim.SGD(model.fw.parameters(), lr=config.embed_lr)
    
    for i, batch in enumerate(loader):
        optim.zero_grad()
        model.zero_grad()
        model.train()
        output = model(batch.to(device))
        #print('ran model')
        
        
        label_vec = torch.sum( output, 0) / batch.shape[0] # calculate the average output vector
        # essentially, make the outputs of the model more similar to each other for each class
        # 'tighen' the vector output for the distribution of samples
        loss = mse( output, label_vec[None,:].repeat(batch.shape[0], 1)) * 10 
        print(loss.item())

        
        loss.backward()
        #print(model.fw[0].weight)
        #torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) # I have no idea what this does
        optim.step()
        #print(model.fw[0].weight)


if False:#config.train_class:
    
    top_fw = nn.Sequential(
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, 768),
            nn.ReLU(),
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
    ).to(device)
    

In [30]:
if False:#config.train_embed:
    
    #train on imagenet classes
    #path = '/kaggle/input/imagenetmini-1000/imagenet-mini/train'
    #classes = os.listdir(path)
    #idx=0
    #for _class in classes[:config.imagenet_classes]:
    #    idx+=1
    #    print(' imagenetmini1000 | trainig class {} | {} / {} ...'.format( _class, idx, config.imagenet_classes))
    #    path_ds = load_imagenet_class(os.path.join(path, _class))
    #    train_class( path_ds, model)
    idx=0 
    for _class in get_imagenet1000_classes()[:config.imagenet_embed_classes]:
        idx+=1
        print(' imagenet1000 | training embed class {} | {} / {} ...'.format( _class, idx, 200))
        path_ds = load_imagenet1000_class(_class)
        train_embed_ds( path_ds, model)
    #train on caltech256 classes
    idx=0
    for _class in get_caltech256_classes()[:config.caltech256_embed_classes]:
        idx+=1
        print(' caltech256 | training embed class {} | {} / {} ...'.format( _class, idx, 20))
        path_ds = load_caltech256_class(_class)
        train_embed_ds( path_ds, model)
    idx=0
    for _class in get_fashion_classes()[:config.fashion_embed_classes]:
        idx+=1
        print(' caltech256 | training embed class {} | {} / {} ...'.format( _class, idx, config.fasion_embed_classes))
        path_ds = load_fashion_class(_class)
        train_embed_ds( path_ds, model)
        
        
    
    #train on google landmark recognition 2021 classes
    #idx=0
    #for _class in get_google_landmarks_2021_classes()[:500]:
    #    idx+=1
    #    print(' Google Landmarks 2021 | training class {} | {} / {}'.format( _class, idx, 500))
    #    path_ds = load_google_landmarks_2021_class(_class)
    #    train_class( path_ds, model)