In [67]:
import os
from functools import reduce
from operator import __add__
import numpy as np
import pytorch_lightning as pl
import torch
import umap
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F

from unsupervised_meta_learning.pl_dataloaders import (UnlabelledDataModule, get_episode_loader,
                                                       UnlabelledDataset)
from unsupervised_meta_learning.proto_utils import euclidean_distance, cosine_similarity, nt_xent_loss
from unsupervised_meta_learning.protoclr import ProtoCLR, get_train_images
from sklearnex import patch_sklearn
patch_sklearn()
from sklearn import cluster


Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [None]:
trainer = pl.Trainer(
    profiler='simple',
    max_epochs=2,
    limit_train_batches=100,
    fast_dev_run=False,
    limit_val_batches=15,
    limit_test_batches=600,
    num_sanity_val_steps=2, gpus=1,
)

In [None]:
lr_finder = trainer.tuner.lr_find(model, train_dataloader=dm)

In [None]:
kernel_sizes = (3, 3)
conv_padding = reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in kernel_sizes[::-1]])

class Encoder(nn.Module):
    def __init__(self, in_channels=1, hidden_size=64, out_channels=64):
        super().__init__()

        self.encoder = nn.Sequential(
            # nn.ZeroPad2d(conv_padding),
            nn.Conv2d(in_channels, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14 x 14

            # nn.ZeroPad2d(conv_padding),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7

            # nn.ZeroPad2d(conv_padding),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 3x3

            # nn.ZeroPad2d(conv_padding),
            nn.Conv2d(hidden_size, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 1x1
            # nn.Flatten()
        )

    def forward(self, inputs):
        return self.encoder(inputs)

class Decoder(nn.Module):
    def __init__(self, in_channels=1, hidden_size=64, out_channels=64):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.UpsamplingNearest2d(size=(4, 4)),
            nn.Conv2d(in_channels=out_channels,
                      out_channels=hidden_size, kernel_size=3, padding='same'),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),

            nn.UpsamplingNearest2d(size=(7, 7)),
            nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size,
                      kernel_size=3, padding='same'),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),

            nn.UpsamplingNearest2d(size=(14, 14)),
            nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size,
                      kernel_size=3, padding='same'),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),

            nn.UpsamplingNearest2d(size=(28, 28)),
            nn.Conv2d(in_channels=hidden_size, out_channels=in_channels,
                      kernel_size=3, padding='same'),
            nn.BatchNorm2d(in_channels),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.decoder(inputs)
        
class AE(nn.Module):
    def __init__(self, in_channels=1, hidden_size=64, out_channels=64):
        super().__init__()

        self.encoder = Encoder(in_channels=in_channels, hidden_size=hidden_size, out_channels=out_channels)
        self.decoder = Decoder(in_channels=in_channels, hidden_size=hidden_size, out_channels=out_channels)

    def forward(self, inputs):
        print(inputs.shape)
        embeddings = self.encoder(inputs.view(-1, *inputs.shape[-3:]))
        print(embeddings.shape)
        recons = self.decoder(embeddings.unsqueeze(-1).unsqueeze(-1))
        return embeddings.view(*inputs.shape[:-3], -1), recons.view(*inputs.shape)

In [None]:
model = ProtoCLR(
    n_support=1, n_query=3, batch_size=50, distance='cosine', τ=.5,
    num_input_channels=1, decoder_class=Decoder, encoder_class=Encoder,
    lr_decay_step=25000, lr_decay_rate=.5, ae=True, gamma=1., log_images=True)
dataset_train = UnlabelledDataset(
    dataset='omniglot',
    datapath='./data/',
    split='train',
    n_support=1,
    n_query=0
)

In [None]:
class Encoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers might use a duplicate of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 8x8 => 4x4
            act_fn(),
            nn.Flatten(), # Image grid to single feature vector
            nn.Linear(2*16*c_hid, latent_dim)
        )

    def forward(self, x):
        return self.net(x)

In [None]:
class Decoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the last convolutional layers. Early layers might use a duplicate of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the decoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2*16*c_hid),
            act_fn()
        )
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 4x4 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 8x8 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2), # 16x16 => 32x32
            nn.Tanh() # The input images is scaled between -1 and 1, hence the output has to be bounded as well
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x

In [None]:
class Autoencoder(pl.LightningModule):

    def __init__(self,
                 base_channel_size: int,
                 latent_dim: int,
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 num_input_channels: int = 1,
                 width: int = 28,
                 height: int = 28):
        super().__init__()
        # Saving hyperparameters of autoencoder
        self.save_hyperparameters()
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Example input array needed for visualizing the graph of the network
        self.example_input_array = torch.zeros(2, num_input_channels, width, height)

    def forward(self, x):
        """
        The forward function takes in an image and returns the reconstructed image
        """
        # print(x.shape)
        z = self.encoder(x.view(-1, *x.shape[-3:]))
        x_hat = self.decoder(z)
        # z = self.encoder(x)
        # x_hat = self.decoder(z)
        return x_hat.view(*x.shape)

    def _get_reconstruction_loss(self, batch, ways, n_supp, n_query):
        """
        Given a batch of images, this function returns the reconstruction loss (MSE in our case)
        """
        x, _ = batch # We do not need the labels
        x_hat = self.forward(x)

        x_supp = x[:,:ways * n_supp]
        r_supp = x_hat[:,:ways * n_supp]
        r_query = x_hat[:,ways * n_supp:]
        # print("####", r_supp.shape, r_query.shape)

        r_query = r_query.view(1, r_supp.shape[1], 3, 1, 28, 28)



        # loss = F.mse_loss(x.squeeze(0), x_hat, reduction='none').sum(dim=[1, 2, 3,]).mean(dim=[0])
        loss = F.mse_loss(
                r_query,
                torch.broadcast_to(x_supp.unsqueeze(2), r_query.shape),
                reduction='none').sum(dim=[1, 2, 3, 4, 5]).mean(dim=[0])
        # loss = F.mse_loss(
        #         r_supp, 
        #         x_supp,
        #         reduction='none').sum(dim=[1, 2, 3, 4,]).mean(dim=[0])

        # loss = F.mse_loss(x, x_hat, reduction="none")
        # loss = loss.sum(dim=[1,2,3]).mean(dim=[0])
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode='min',
                                                         factor=0.2,
                                                         patience=20,
                                                         min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}

    def training_step(self, batch, batch_idx):
        data = batch['data'].to(self.device) # [batch_size x ways x shots x image_dim]
        data = data.unsqueeze(0)
        # e.g. 50 images, 2 support, 2 query, miniImageNet: torch.Size([1, 50, 4, 3, 84, 84])
        batch_size = data.size(0)
        ways = data.size(1)
        x_support = data[:,:,:1]
        x_support = x_support.reshape((batch_size, ways * 1, *x_support.shape[-3:])) # e.g. [1,50*n_support,*(3,84,84)]
        x_query = data[:,:,1:]
        x_query = x_query.reshape((batch_size, ways * 3, *x_query.shape[-3:])) # e.g. [1,50*n_query,*(3,84,84)]
        # print(f'!!!! {x_support.shape}')
        # Create dummy query labels
        y_query = torch.arange(ways).unsqueeze(0).unsqueeze(2) # batch and shot dim
        y_query = y_query.repeat(batch_size, 1, 1)
        y_query = y_query.view(batch_size, -1).to('cuda')

        y_support = torch.arange(ways).unsqueeze(0).unsqueeze(2) # batch and shot dim
        y_support = y_support.repeat(batch_size, 1, 1)
        y_support = y_support.view(batch_size, -1).to('cuda')
        x = torch.cat([x_support, x_query], 1) # e.g. [1,50*(n_support+n_query),*(3,84,84)]

        loss = self._get_reconstruction_loss((x, y_support), ways, 1, 3)
        self.log('train_loss', loss)
        return loss

    # def validation_step(self, batch, batch_idx):
    #     loss = self._get_reconstruction_loss(batch)
    #     self.log('val_loss', loss)

    # def test_step(self, batch, batch_idx):
    #     loss = self._get_reconstruction_loss(batch)
    #     self.log('test_loss', loss)

In [None]:
emb_i = torch.tensor([[ 6.6270, 11.6813],
        [ 6.4168,  7.7024],
        [ 7.9163,  8.0348],
        [ 4.6681,  7.2211],
        [ 5.5762, 10.0012]])

emb_j = torch.tensor([[ 6.3032, 10.0734],
        [ 6.9568,  9.8194],
        [ 8.8078,  8.1451],
        [ 4.3386,  8.5691],
        [ 7.0590,  6.7981],
        [ 5.4431,  8.0148],
        [ 5.0362,  6.1590],
        [ 5.1802, 10.7874],
        [ 7.7798,  7.3308],
        [ 5.4246,  6.2691],
        [ 5.8096, 10.1320],
        [ 5.6292, 10.5785],
        [ 4.9312,  5.8408],
        [ 4.3520,  6.3029],
        [ 5.4147,  9.8841],
        [ 6.2309,  8.4320],
        [ 6.0696,  8.4501],
        [ 6.0402,  6.7177],
        [ 4.5353,  8.5051],
        [ 5.2821,  8.1228],
        [ 4.8968,  7.7489],
        [ 5.6221,  6.7439],
        [ 7.4615,  7.0588],
        [ 7.2516,  7.3432],
        [ 6.0943,  9.1108],
        [ 6.3836,  7.8931],
        [ 4.9303,  6.2700],
        [ 8.0341,  7.9363],
        [ 5.0337, 10.5701],
        [ 6.0564, 11.4888],
        [ 6.2713,  8.1638],
        [ 6.0115,  8.6252],
        [ 8.0858,  8.3572],
        [ 7.4930,  7.2508],
        [ 5.3719,  6.0493],
        [ 8.1590,  7.5640],
        [ 6.8972,  7.4481],
        [ 5.8029,  6.7042],
        [ 6.5177,  8.4637],
        [ 6.5807,  6.9593],
        [ 7.0109,  9.6658],
        [ 3.9612,  6.7887],
        [ 7.0019, 12.0692],
        [ 6.6533, 11.9660],
        [ 6.8464, 12.0303],
        [ 5.3894,  7.2251],
        [ 6.4877, 11.8622],
        [ 5.7942,  9.3818],
        [ 8.7622,  8.5982],
        [ 4.9367, 10.4112],
        [ 8.2535,  7.8129],
        [ 4.8408, 10.0947],
        [ 5.4412, 10.4591],
        [ 6.0605, 10.4132],
        [ 6.6809,  8.1679],
        [ 5.1636,  8.1599],
        [ 6.0977,  8.0572],
        [ 7.5062,  7.4989],
        [ 6.5523,  7.0610],
        [ 8.6864,  7.8542],
        [ 7.8051,  9.0486],
        [ 6.4663,  9.5589],
        [ 8.1567,  7.3624],
        [ 6.8541,  7.5712],
        [ 4.4631,  7.2664],
        [ 7.7564,  7.6146],
        [ 4.2792,  8.4721],
        [ 5.6382,  8.5117],
        [ 4.7060,  7.7046],
        [ 7.2034,  8.9113],
        [ 8.4993,  7.6508],
        [ 6.4934,  9.8171],
        [ 7.7694,  8.6558],
        [ 6.1207,  6.5721],
        [ 5.0343,  7.8402],
        [ 7.4249,  7.1609],
        [ 4.5241,  6.9034],
        [ 6.5902,  7.3029],
        [ 6.9436, 12.0691],
        [ 6.6224, 11.9718],
        [ 6.5599, 12.0108],
        [ 6.4960,  8.6471],
        [ 4.2922,  7.0729],
        [ 4.2204,  6.7476],
        [ 7.2825,  8.0695],
        [ 5.0677,  9.7671],
        [ 4.0716,  7.9130],
        [ 7.6866,  7.9514],
        [ 6.5860, 10.6934],
        [ 6.4237,  7.7053],
        [ 4.9411,  7.0874],
        [ 6.1804,  7.2141],
        [ 5.3615,  6.0960],
        [ 4.8339,  7.1372],
        [ 6.9698,  8.4045],
        [ 7.0948,  7.6626],
        [ 5.8325,  7.4330],
        [ 4.1195,  7.4736],
        [ 4.3055,  7.2621],
        [ 6.2035, 11.6645],
        [ 6.8866,  6.4410],
        [ 5.5439, 10.7403],
        [ 6.3185,  7.0267],
        [ 6.6733,  7.9577],
        [ 5.0281,  6.9473],
        [ 6.2958, 10.8109],
        [ 5.4935, 10.6249],
        [ 5.3473, 10.2931],
        [ 8.0300,  7.0989],
        [ 8.0298,  7.2092],
        [ 8.5735,  8.2420],
        [ 4.6567,  8.3166],
        [ 4.4944,  8.3375],
        [ 4.2823,  6.1382],
        [ 6.2566, 10.9357],
        [ 6.1064, 10.7109],
        [ 6.1095, 10.9985],
        [ 4.6992,  5.9350],
        [ 4.6401,  8.5465],
        [ 4.8380,  6.1790],
        [ 4.3196,  7.8004],
        [ 4.9383,  9.5967],
        [ 7.2132,  8.6334],
        [ 8.6895,  7.8514],
        [ 7.7159,  7.6624],
        [ 8.1368,  7.7046],
        [ 6.6882,  8.9046],
        [ 7.6113,  8.5579],
        [ 4.9312,  8.7897],
        [ 6.9077,  7.7897],
        [ 8.3960,  8.0787],
        [ 7.5751,  9.3934],
        [ 5.2499,  9.0045],
        [ 5.8971, 10.0067],
        [ 4.7151, 10.0582],
        [ 6.7354,  7.5874],
        [ 6.4488,  8.1458],
        [ 5.1967,  7.3349],
        [ 4.6785,  6.5824],
        [ 4.0479,  7.6312],
        [ 6.2214,  6.7932],
        [ 5.4997, 10.0227],
        [ 5.1693,  6.3969],
        [ 5.0809, 10.1422],
        [ 4.8012,  9.8925],
        [ 6.9810,  8.2349],
        [ 8.0005,  8.0630],
        [ 6.9695, 11.8812],
        [ 5.6798, 10.3008],
        [ 6.1110, 10.3863]])

In [None]:
z_i = torch.tensor([[0.4934, 0.8698],
        [0.6401, 0.7683],
        [0.7018, 0.7123],
        [0.5429, 0.8398],
        [0.4870, 0.8734]])

z_j = torch.tensor([[0.5304, 0.8477],
        [0.5781, 0.8160],
        [0.7342, 0.6789],
        [0.4517, 0.8922],
        [0.7203, 0.6937],
        [0.5618, 0.8273],
        [0.6330, 0.7741],
        [0.4329, 0.9014],
        [0.7278, 0.6858],
        [0.6543, 0.7562],
        [0.4974, 0.8675],
        [0.4698, 0.8828],
        [0.6451, 0.7641],
        [0.5682, 0.8229],
        [0.4804, 0.8770],
        [0.5943, 0.8042],
        [0.5834, 0.8122],
        [0.6686, 0.7436],
        [0.4705, 0.8824],
        [0.5452, 0.8383],
        [0.5342, 0.8454],
        [0.6403, 0.7681],
        [0.7264, 0.6872],
        [0.7027, 0.7115],
        [0.5560, 0.8312],
        [0.6288, 0.7775],
        [0.6181, 0.7861],
        [0.7114, 0.7028],
        [0.4300, 0.9028],
        [0.4663, 0.8846],
        [0.6092, 0.7930],
        [0.5718, 0.8204],
        [0.6953, 0.7187],
        [0.7186, 0.6954],
        [0.6640, 0.7477],
        [0.7333, 0.6799],
        [0.6795, 0.7337],
        [0.6544, 0.7561],
        [0.6101, 0.7923],
        [0.6871, 0.7266],
        [0.5871, 0.8095],
        [0.5040, 0.8637],
        [0.5018, 0.8650],
        [0.4859, 0.8740],
        [0.4946, 0.8691],
        [0.5979, 0.8016],
        [0.4798, 0.8774],
        [0.5255, 0.8508],
        [0.7138, 0.7004],
        [0.4284, 0.9036],
        [0.7262, 0.6875],
        [0.4324, 0.9017],
        [0.4615, 0.8871],
        [0.5030, 0.8643],
        [0.6331, 0.7740],
        [0.5347, 0.8450],
        [0.6035, 0.7974],
        [0.7075, 0.7068],
        [0.6802, 0.7330],
        [0.7417, 0.6707],
        [0.6532, 0.7572],
        [0.5603, 0.8283],
        [0.7423, 0.6700],
        [0.6711, 0.7413],
        [0.5234, 0.8521],
        [0.7136, 0.7006],
        [0.4508, 0.8926],
        [0.5522, 0.8337],
        [0.5213, 0.8534],
        [0.6286, 0.7777],
        [0.7432, 0.6690],
        [0.5517, 0.8341],
        [0.6680, 0.7442],
        [0.6815, 0.7318],
        [0.5403, 0.8415],
        [0.7198, 0.6942],
        [0.5481, 0.8364],
        [0.6700, 0.7424],
        [0.4987, 0.8668],
        [0.4840, 0.8750],
        [0.4793, 0.8776],
        [0.6006, 0.7995],
        [0.5188, 0.8549],
        [0.5303, 0.8478],
        [0.6700, 0.7424],
        [0.4606, 0.8876],
        [0.4575, 0.8892],
        [0.6950, 0.7190],
        [0.5244, 0.8515],
        [0.6403, 0.7681],
        [0.5719, 0.8203],
        [0.6506, 0.7594],
        [0.6604, 0.7509],
        [0.5608, 0.8280],
        [0.6383, 0.7697],
        [0.6794, 0.7338],
        [0.6173, 0.7867],
        [0.4827, 0.8758],
        [0.5100, 0.8602],
        [0.4695, 0.8829],
        [0.7303, 0.6831],
        [0.4587, 0.8886],
        [0.6686, 0.7436],
        [0.6426, 0.7662],
        [0.5863, 0.8101],
        [0.5032, 0.8641],
        [0.4593, 0.8883],
        [0.4610, 0.8874],
        [0.7492, 0.6623],
        [0.7441, 0.6681],
        [0.7209, 0.6930],
        [0.4886, 0.8725],
        [0.4745, 0.8803],
        [0.5722, 0.8201],
        [0.4966, 0.8680],
        [0.4953, 0.8687],
        [0.4856, 0.8742],
        [0.6208, 0.7840],
        [0.4771, 0.8788],
        [0.6165, 0.7874],
        [0.4844, 0.8748],
        [0.4576, 0.8892],
        [0.6412, 0.7674],
        [0.7420, 0.6704],
        [0.7096, 0.7046],
        [0.7261, 0.6876],
        [0.6006, 0.7996],
        [0.6646, 0.7472],
        [0.4893, 0.8721],
        [0.6635, 0.7482],
        [0.7206, 0.6934],
        [0.6277, 0.7784],
        [0.5037, 0.8639],
        [0.5077, 0.8615],
        [0.4245, 0.9054],
        [0.6639, 0.7478],
        [0.6207, 0.7840],
        [0.5781, 0.8160],
        [0.5793, 0.8151],
        [0.4686, 0.8834],
        [0.6754, 0.7375],
        [0.4811, 0.8767],
        [0.6285, 0.7778],
        [0.4479, 0.8941],
        [0.4366, 0.8996],
        [0.6466, 0.7628],
        [0.7043, 0.7099],
        [0.5060, 0.8625],
        [0.4829, 0.8757],
        [0.5071, 0.8619]])

In [None]:
(z_i @ z_j.T).t().contiguous().shape

In [None]:
z = torch.cat([z_i, z_j], dim=0)

In [None]:
sim = F.cosine_similarity(z_i.unsqueeze(1), z_j.unsqueeze(0), dim=2) / .5

In [None]:
sim.shape

In [None]:
sim = sim.t().contiguous()

In [None]:
sim_i_j = torch.diag(sim)

In [None]:
list(range(0, 5))

In [None]:
sim_i_j = sim[0, 5-1]

In [None]:
sim_i_j

In [None]:
numerator = torch.exp(sim_i_j)
numerator

In [None]:
denominator = torch.sum(
    torch.ones((5, )).scatter_(0, torch.tensor([0]), 0.0).bool() * torch.exp(sim[0,:])
)
denominator

In [None]:
loss_ij = -torch.log(numerator / denominator)
loss_ij

In [None]:
####### for ji?

sim_i_j_2 = sim[0 + 5- 1, 0]
sim_i_j_2

In [None]:
numerator2 = torch.exp(sim_i_j_2)
numerator2

In [None]:
denominator2 = torch.sum(
    torch.ones((5, )).scatter_(0, torch.tensor([4]), 0.0).bool() * torch.exp(sim[4,:])
)
denominator2

In [None]:
loss_ij2 = -torch.log(numerator2 / denominator2)

In [None]:
loss_ij, loss_ij2

In [None]:
def l_ij(i, j):
    # shape of sim initially is (n_clusters, n_aug_images)
    # so we have (5, 150) - similarity between each cluster and augmented image
    # transposed it becomes (150, 5) - similarity between each augmented image and cluster
    sim = F.cosine_similarity(z_i.unsqueeze(1), z_j.unsqueeze(0), dim=2) / .5
    sim = sim.t().contiguous()

    # numerator math
    sim_i_j = sim[i, j]
    print(f"sim({i}, {j})={sim_i_j}")

    numerator = torch.exp(sim_i_j)
    print("Numerator", numerator)
    

    # denominator math
    # because there are 5 classes

    mask = torch.ones((5, )).scatter(0, torch.tensor([i]), 0.0).bool()
    print(f"1{{k!={i}}}", mask)
    denominator = torch.sum(
        mask * torch.exp(sim[i,:])
    )
    print("Denominator", denominator)
    loss_ij = -torch.log(numerator / denominator)
    print(f"loss({i},{j})={loss_ij}\n")
    return loss_ij.squeeze(0)

In [None]:
N = 5 # n_clusters
loss = 0.
for k in range(0, N):
    loss += l_ij(k, k + N - 1) + l_ij(k + N - 1, k)
print(loss)

In [None]:
class_num = 5
N = 2 * class_num
mask = torch.ones((N, N))
mask = mask.fill_diagonal_(0)
for i in range(class_num):
    mask[i, class_num + i] = 0
    mask[class_num + i, i] = 0
    mask = mask.bool()
mask.shape

In [None]:
p_i = z_i.sum(0).view(-1)

In [None]:
z_i

In [None]:
z_i.unsqueeze(1).shape, z_j.unsqueeze(0).shape

In [None]:
# dists = torch.sum((z_i.unsqueeze(1) - z_j.unsqueeze(0))** 2, dim=-1)
dists = F.cosine_similarity(z_i.unsqueeze(1), z_j.unsqueeze(0), dim=2) / 0.5
dists.shape

In [None]:
raw_labels = torch.tensor([[1, 1, 4, 4, 4, 1, 0, 3, 4, 4, 2, 3, 4, 4, 1, 2, 1, 0, 4, 0, 0, 2,
       1, 4, 3, 2, 1, 4, 3, 2, 0, 1, 4, 0, 0, 0, 4, 1, 0, 0, 0, 1, 0, 4,
       2, 2, 4, 3, 0, 0, 1, 1, 1, 2, 1, 2, 0, 4, 0, 0, 2, 4, 2, 2, 4, 1,
       3, 3, 4, 3, 0, 3, 3, 2, 0, 4, 4, 4, 4, 4, 3, 2, 0, 3, 1, 3, 4, 4,
       4, 3, 2, 3, 1, 1, 2, 1, 3, 3, 3, 2, 1, 0, 4, 3, 4, 4, 4, 3, 4, 4,
       0, 0, 0, 4, 4, 0, 2, 2, 2, 4, 4, 4, 1, 1, 2, 2, 2, 1, 0, 1, 4, 4,
       3, 3, 3, 3, 2, 3, 3, 1, 2, 2, 1, 1, 1, 1, 4, 3, 0, 0, 0, 1, 1, 3,
       0, 3, 4, 3, 4, 4, 2, 0, 2, 0, 3, 3, 0, 0, 2, 2, 4, 0, 0, 1, 1, 2,
       2, 1, 1, 0, 0, 4, 3, 2, 2, 3, 0, 0, 1, 4, 2, 2, 2, 2, 3, 4, 1, 4,
       0, 1]])
query_labels = raw_labels[:, 50 * 1 :]

In [None]:
augtarget = torch.tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,
          6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11,
         12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17,
         18, 18, 18, 19, 19, 19, 20, 20, 20, 21, 21, 21, 22, 22, 22, 23, 23, 23,
         24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29,
         30, 30, 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35,
         36, 36, 36, 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41,
         42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47,
         48, 48, 48, 49, 49, 49]])

In [None]:
z_j.shape

In [None]:
query_labels.shape

In [None]:
dists.unsqueeze(0).shape

In [None]:
F.cross_entropy(dists.unsqueeze(0), torch.Tensor([[0 for i in range(150)]]).long(), reduction='sum')

In [None]:
F.cross_entropy(dists.unsqueeze(0), query_labels, reduction='sum')

In [None]:
raw_labels.max()

In [5]:
rz = torch.load('data/raw_embeddings.pt').cpu()

In [13]:
rz.squeeze(0).shape

torch.Size([200, 64])

In [14]:
mapper = umap.UMAP(n_components=3, random_state=42).fit(rz.detach().squeeze())

In [18]:
mrz = mapper.transform(rz.detach().squeeze())

In [23]:
mrz_support = mrz[: 50 * 1]
# e.g. [1,50*n_query,*(3,84,84)]
mrz_query = mrz[50 * 1 :]

In [32]:
clf = cluster.KMeans(n_clusters=5)
pred_labels = clf.fit_predict(mrz)

In [34]:
centroids = clf.cluster_centers_

In [38]:
centroids = torch.from_numpy(mapper.inverse_transform(centroids)).unsqueeze(0)

In [40]:
centroids.shape

torch.Size([1, 5, 64])

In [47]:
rz_query = rz[:, 50:,:]

In [71]:
centroids.shape, rz_query.shape

(torch.Size([1, 5, 64]), torch.Size([1, 150, 64]))

In [82]:
dists = cosine_similarity(centroids, rz_query)
dists.shape, torch.from_numpy(pred_labels[50:]).unsqueeze(0).shape

(torch.Size([1, 5, 150]), torch.Size([1, 150]))

In [87]:
F.cross_entropy(dists, torch.from_numpy(pred_labels[50:]).unsqueeze(0).long(), reduction='mean')

tensor(1.5604, grad_fn=<NllLoss2DBackward>)

In [92]:
def nt_xent_loss(z_i, z_j, query_labels, temperature=.5, reduction='mean'):
    N = z_j.shape[1]
    # calculating distance from every centroid to every augmented image
    dists = cosine_similarity(z_i, z_j) / temperature
    labels = torch.zeros(N).to(dists.device).long()
    print(dists.shape, query_labels.shape)
    loss = F.cross_entropy(dists, labels.unsqueeze(0), reduction=reduction)
    return loss

In [93]:
nt_xent_loss(centroids, rz_query, torch.from_numpy(pred_labels[50:]).long(), temperature=1., reduction='mean')

torch.Size([1, 5, 150]) torch.Size([150])


tensor(1.6147, grad_fn=<NllLoss2DBackward>)