In [74]:
import torchmeta
from torchmeta.datasets.helpers import omniglot
from functools import reduce
from operator import __add__
import numpy as np
import pytorch_lightning as pl
from torchvision.transforms import Compose, Resize, ToTensor
import torch
import umap
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
from unsupervised_meta_learning.pl_dataloaders import (UnlabelledDataModule, get_episode_loader,
                                                       UnlabelledDataset)
from unsupervised_meta_learning.proto_utils import euclidean_distance, cosine_similarity, nt_xent_loss, cluster_diff_loss
from unsupervised_meta_learning.protoclr import ProtoCLR, get_train_images
from sklearnex import patch_sklearn
patch_sklearn()
from sklearn import cluster


Intel(R) Extension for Scikit-learn* enabled (https://github.com/intel/scikit-learn-intelex)


In [57]:
dataset = torchmeta.datasets.Omniglot('./data/untarred', num_classes_per_task=5, transform=Compose([Resize(28), ToTensor()]), meta_train=True, use_vinyals_split=True, )

In [58]:
dataset = torchmeta.transforms.ClassSplitter(dataset, shuffle=True, num_train_per_class=1,)

In [138]:
dataloader = torchmeta.utils.data.BatchMetaDataLoader(dataset, batch_size=1, num_workers=0)

In [69]:
xs = next(iter(dataloader))

In [99]:
dataset = omniglot("./data/untarred", ways=5, shots=1, use_vinyals_split=True,meta_train=True, download=True)
dataloader = torchmeta.utils.data.BatchMetaDataLoader(dataset, batch_size=1, num_workers=4, shuffle=True)

In [139]:
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_descriptor')

In [142]:
data = []
targets  = []
cntr = 0
for xs in dataloader:
    if cntr > 5000:
        break
    else:
        cntr += 1
    d, t = xs['train']
    data.append(d.squeeze(0))
    targets.append(t)


In [143]:
torch.cat(data).shape

torch.Size([25005, 1, 28, 28])

In [145]:
torch.cat(targets)

tensor([[3, 1, 2, 0, 4],
        [0, 3, 1, 4, 2],
        [4, 0, 1, 3, 2],
        ...,
        [3, 2, 1, 4, 0],
        [4, 3, 2, 1, 0],
        [4, 3, 1, 0, 2]])

In [117]:
torch.cat(targets).flatten()

tensor([3, 2, 0,  ..., 2, 3, 1])

In [2]:
dm = UnlabelledDataModule(
    "omniglot",
    "./data/untarred",
    split="train",
    transform=None,
    n_support=1,
    n_query=3,
    n_images=None,
    n_classes=None,
    batch_size=50,
    seed=10,
    mode="trainval",
    num_workers=0,
    eval_ways=5,
    eval_support_shots=1,
    eval_query_shots=15,
)


In [3]:
dataset_train = UnlabelledDataset(
            'omniglot',
            'data/untarred/',
            split="train",
            transform=None,
            n_images=None,
            n_classes=5,
            n_support=1,
            n_query=3,
            no_aug_support=False,
            no_aug_query=False,
)

5 100
<class 'numpy.ndarray'>


In [4]:
dataset_train.data[0][0].shape

(105,)

In [5]:
len(dataset_train.data)

100

In [6]:
dm.setup()

1028 20560
<class 'numpy.ndarray'>
172 3440
<class 'numpy.ndarray'>


In [None]:
trainer = pl.Trainer(
    profiler='simple',
    max_epochs=2,
    limit_train_batches=100,
    fast_dev_run=False,
    limit_val_batches=15,
    limit_test_batches=600,
    num_sanity_val_steps=2, gpus=1,
)

In [None]:
kernel_sizes = (3, 3)
conv_padding = reduce(__add__, [(k // 2 + (k - 2 * (k // 2)) - 1, k // 2) for k in kernel_sizes[::-1]])

class Encoder(nn.Module):
    def __init__(self, in_channels=1, hidden_size=64, out_channels=64):
        super().__init__()

        self.encoder = nn.Sequential(
            # nn.ZeroPad2d(conv_padding),
            nn.Conv2d(in_channels, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 14 x 14

            # nn.ZeroPad2d(conv_padding),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 7x7

            # nn.ZeroPad2d(conv_padding),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 3x3

            # nn.ZeroPad2d(conv_padding),
            nn.Conv2d(hidden_size, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(2),  # 1x1
            # nn.Flatten()
        )

    def forward(self, inputs):
        return self.encoder(inputs)

class Decoder(nn.Module):
    def __init__(self, in_channels=1, hidden_size=64, out_channels=64):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.UpsamplingNearest2d(size=(4, 4)),
            nn.Conv2d(in_channels=out_channels,
                      out_channels=hidden_size, kernel_size=3, padding='same'),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),

            nn.UpsamplingNearest2d(size=(7, 7)),
            nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size,
                      kernel_size=3, padding='same'),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),

            nn.UpsamplingNearest2d(size=(14, 14)),
            nn.Conv2d(in_channels=hidden_size, out_channels=hidden_size,
                      kernel_size=3, padding='same'),
            nn.BatchNorm2d(hidden_size),
            nn.ReLU(),

            nn.UpsamplingNearest2d(size=(28, 28)),
            nn.Conv2d(in_channels=hidden_size, out_channels=in_channels,
                      kernel_size=3, padding='same'),
            nn.BatchNorm2d(in_channels),
            nn.Sigmoid()
        )

    def forward(self, inputs):
        return self.decoder(inputs)
        
class AE(nn.Module):
    def __init__(self, in_channels=1, hidden_size=64, out_channels=64):
        super().__init__()

        self.encoder = Encoder(in_channels=in_channels, hidden_size=hidden_size, out_channels=out_channels)
        self.decoder = Decoder(in_channels=in_channels, hidden_size=hidden_size, out_channels=out_channels)

    def forward(self, inputs):
        print(inputs.shape)
        embeddings = self.encoder(inputs.view(-1, *inputs.shape[-3:]))
        print(embeddings.shape)
        recons = self.decoder(embeddings.unsqueeze(-1).unsqueeze(-1))
        return embeddings.view(*inputs.shape[:-3], -1), recons.view(*inputs.shape)

In [None]:
model = ProtoCLR(
    n_support=1, n_query=3, batch_size=50, distance='cosine', τ=.5,
    num_input_channels=1, decoder_class=Decoder, encoder_class=Encoder,
    lr_decay_step=25000, lr_decay_rate=.5, ae=True, gamma=1., log_images=True)
dataset_train = UnlabelledDataset(
    dataset='omniglot',
    datapath='./data/',
    split='train',
    n_support=1,
    n_query=0
)

In [None]:
class Encoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of input channels of the image. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the first convolutional layers. Deeper layers might use a duplicate of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the encoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, c_hid, kernel_size=3, padding=1, stride=2), # 32x32 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 16x16 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1, stride=2), # 8x8 => 4x4
            act_fn(),
            nn.Flatten(), # Image grid to single feature vector
            nn.Linear(2*16*c_hid, latent_dim)
        )

    def forward(self, x):
        return self.net(x)

In [None]:
class Decoder(nn.Module):

    def __init__(self,
                 num_input_channels : int,
                 base_channel_size : int,
                 latent_dim : int,
                 act_fn : object = nn.GELU):
        """
        Inputs:
            - num_input_channels : Number of channels of the image to reconstruct. For CIFAR, this parameter is 3
            - base_channel_size : Number of channels we use in the last convolutional layers. Early layers might use a duplicate of it.
            - latent_dim : Dimensionality of latent representation z
            - act_fn : Activation function used throughout the decoder network
        """
        super().__init__()
        c_hid = base_channel_size
        self.linear = nn.Sequential(
            nn.Linear(latent_dim, 2*16*c_hid),
            act_fn()
        )
        self.net = nn.Sequential(
            nn.ConvTranspose2d(2*c_hid, 2*c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 4x4 => 8x8
            act_fn(),
            nn.Conv2d(2*c_hid, 2*c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(2*c_hid, c_hid, kernel_size=3, output_padding=1, padding=1, stride=2), # 8x8 => 16x16
            act_fn(),
            nn.Conv2d(c_hid, c_hid, kernel_size=3, padding=1),
            act_fn(),
            nn.ConvTranspose2d(c_hid, num_input_channels, kernel_size=3, output_padding=1, padding=1, stride=2), # 16x16 => 32x32
            nn.Tanh() # The input images is scaled between -1 and 1, hence the output has to be bounded as well
        )

    def forward(self, x):
        x = self.linear(x)
        x = x.reshape(x.shape[0], -1, 4, 4)
        x = self.net(x)
        return x

In [None]:
class Autoencoder(pl.LightningModule):

    def __init__(self,
                 base_channel_size: int,
                 latent_dim: int,
                 encoder_class : object = Encoder,
                 decoder_class : object = Decoder,
                 num_input_channels: int = 1,
                 width: int = 28,
                 height: int = 28):
        super().__init__()
        # Saving hyperparameters of autoencoder
        self.save_hyperparameters()
        # Creating encoder and decoder
        self.encoder = encoder_class(num_input_channels, base_channel_size, latent_dim)
        self.decoder = decoder_class(num_input_channels, base_channel_size, latent_dim)
        # Example input array needed for visualizing the graph of the network
        self.example_input_array = torch.zeros(2, num_input_channels, width, height)

    def forward(self, x):
        """
        The forward function takes in an image and returns the reconstructed image
        """
        # print(x.shape)
        z = self.encoder(x.view(-1, *x.shape[-3:]))
        x_hat = self.decoder(z)
        # z = self.encoder(x)
        # x_hat = self.decoder(z)
        return x_hat.view(*x.shape)

    def _get_reconstruction_loss(self, batch, ways, n_supp, n_query):
        """
        Given a batch of images, this function returns the reconstruction loss (MSE in our case)
        """
        x, _ = batch # We do not need the labels
        x_hat = self.forward(x)

        x_supp = x[:,:ways * n_supp]
        r_supp = x_hat[:,:ways * n_supp]
        r_query = x_hat[:,ways * n_supp:]
        # print("####", r_supp.shape, r_query.shape)

        r_query = r_query.view(1, r_supp.shape[1], 3, 1, 28, 28)



        # loss = F.mse_loss(x.squeeze(0), x_hat, reduction='none').sum(dim=[1, 2, 3,]).mean(dim=[0])
        loss = F.mse_loss(
                r_query,
                torch.broadcast_to(x_supp.unsqueeze(2), r_query.shape),
                reduction='none').sum(dim=[1, 2, 3, 4, 5]).mean(dim=[0])
        # loss = F.mse_loss(
        #         r_supp, 
        #         x_supp,
        #         reduction='none').sum(dim=[1, 2, 3, 4,]).mean(dim=[0])

        # loss = F.mse_loss(x, x_hat, reduction="none")
        # loss = loss.sum(dim=[1,2,3]).mean(dim=[0])
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        # Using a scheduler is optional but can be helpful.
        # The scheduler reduces the LR if the validation performance hasn't improved for the last N epochs
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                         mode='min',
                                                         factor=0.2,
                                                         patience=20,
                                                         min_lr=5e-5)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "train_loss"}

    def training_step(self, batch, batch_idx):
        data = batch['data'].to(self.device) # [batch_size x ways x shots x image_dim]
        data = data.unsqueeze(0)
        # e.g. 50 images, 2 support, 2 query, miniImageNet: torch.Size([1, 50, 4, 3, 84, 84])
        batch_size = data.size(0)
        ways = data.size(1)
        x_support = data[:,:,:1]
        x_support = x_support.reshape((batch_size, ways * 1, *x_support.shape[-3:])) # e.g. [1,50*n_support,*(3,84,84)]
        x_query = data[:,:,1:]
        x_query = x_query.reshape((batch_size, ways * 3, *x_query.shape[-3:])) # e.g. [1,50*n_query,*(3,84,84)]
        # print(f'!!!! {x_support.shape}')
        # Create dummy query labels
        y_query = torch.arange(ways).unsqueeze(0).unsqueeze(2) # batch and shot dim
        y_query = y_query.repeat(batch_size, 1, 1)
        y_query = y_query.view(batch_size, -1).to('cuda')

        y_support = torch.arange(ways).unsqueeze(0).unsqueeze(2) # batch and shot dim
        y_support = y_support.repeat(batch_size, 1, 1)
        y_support = y_support.view(batch_size, -1).to('cuda')
        x = torch.cat([x_support, x_query], 1) # e.g. [1,50*(n_support+n_query),*(3,84,84)]

        loss = self._get_reconstruction_loss((x, y_support), ways, 1, 3)
        self.log('train_loss', loss)
        return loss

    # def validation_step(self, batch, batch_idx):
    #     loss = self._get_reconstruction_loss(batch)
    #     self.log('val_loss', loss)

    # def test_step(self, batch, batch_idx):
    #     loss = self._get_reconstruction_loss(batch)
    #     self.log('test_loss', loss)

In [20]:
euclidean_distance(z_j_1, z_j_1).shape

torch.Size([27, 2])

In [23]:
torch.sum((z_j_1.unsqueeze(0) - z_j_1.unsqueeze(1))** 2, dim=-1).shape

torch.Size([27, 27])

In [15]:
z_j_1.unsqueeze(1).shape, z_j_1.unsqueeze(0).shape

(torch.Size([27, 1, 2]), torch.Size([1, 27, 2]))

In [86]:
sim = F.cosine_similarity(z_j_1.unsqueeze(1), z_j_1.unsqueeze(0), dim=2) / .5
sim.shape

torch.Size([27, 27])

In [73]:
F.cross_entropy(sim, torch.Tensor([0 for i in range(27)]).long())

tensor(3.2901)

In [68]:
torch.diag(sim, 1)

tensor([0.9916, 0.9948, 0.9888, 0.9997, 0.9840, 0.9487, 0.9652, 0.9932, 0.9709,
        0.9991, 0.9631, 0.9999, 0.9989, 0.9460, 0.9816, 0.9933, 0.9951, 0.9948,
        0.9868, 0.9812, 0.9757, 0.9996, 0.9969, 0.9909, 0.9986, 0.9937])

In [88]:
loss = 0.
temperature = .5
for label in uniq_labels:
    z_j_t = z_j_labels[z_j_labels[:, 2] == label][:, :2]
    sim = F.cosine_similarity(z_j_t.unsqueeze(1), z_j_t.unsqueeze(0), dim=2) / temperature
    loss += F.cross_entropy(sim, torch.Tensor([1/temperature for i in range(sim.shape[0])]).long())

loss

tensor(16.9763)

In [74]:
sim = F.cosine_similarity(z_i.unsqueeze(1), z_j.unsqueeze(0), dim=2).t().contiguous() / .5
sim.shape

torch.Size([150, 150])

In [100]:
i = 0
j = 0
label_i = query_labels[i]
numerator = torch.exp(sim[i, label_i])
numerator

tensor(7.2548)

In [107]:
denominator = torch.exp(sim[1:, label_i]).sum()
denominator

tensor(1081.2456)

In [109]:
-torch.log(numerator/denominator)

tensor(5.0042)

In [44]:
z_i / 2

tensor([[0.2467, 0.4349],
        [0.3201, 0.3841],
        [0.3509, 0.3562],
        [0.2715, 0.4199],
        [0.2435, 0.4367]])

In [41]:
z_i.shape, z_i.unsqueeze(0).shape

(torch.Size([5, 2]), torch.Size([1, 5, 2]))

In [46]:
z_j.shape, z_j.unsqueeze(1).shape

(torch.Size([150, 2]), torch.Size([150, 1, 2]))

In [48]:
torch.sum((z_i.unsqueeze(1) - z_j.unsqueeze(0))**2, dim=-1).unsqueeze(0).shape

torch.Size([1, 5, 150])

In [None]:
F.cross_entropy(dists, labels)

In [None]:
z = torch.cat([z_i, z_j], dim=0)

In [None]:
sim = F.cosine_similarity(z_i.unsqueeze(1), z_j.unsqueeze(0), dim=2) / .5

In [None]:
sim.shape

In [None]:
sim = sim.t().contiguous()

In [None]:
sim_i_j = torch.diag(sim)

In [None]:
list(range(0, 5))

In [None]:
sim_i_j = sim[0, 5-1]

In [None]:
sim_i_j

In [None]:
numerator = torch.exp(sim_i_j)
numerator

In [None]:
denominator = torch.sum(
    torch.ones((5, )).scatter_(0, torch.tensor([0]), 0.0).bool() * torch.exp(sim[0,:])
)
denominator

In [None]:
loss_ij = -torch.log(numerator / denominator)
loss_ij

In [None]:
####### for ji?

sim_i_j_2 = sim[0 + 5- 1, 0]
sim_i_j_2

In [None]:
numerator2 = torch.exp(sim_i_j_2)
numerator2

In [None]:
denominator2 = torch.sum(
    torch.ones((5, )).scatter_(0, torch.tensor([4]), 0.0).bool() * torch.exp(sim[4,:])
)
denominator2

In [None]:
loss_ij2 = -torch.log(numerator2 / denominator2)

In [None]:
loss_ij, loss_ij2

In [None]:
def l_ij(i, j):
    # shape of sim initially is (n_clusters, n_aug_images)
    # so we have (5, 150) - similarity between each cluster and augmented image
    # transposed it becomes (150, 5) - similarity between each augmented image and cluster
    sim = F.cosine_similarity(z_i.unsqueeze(1), z_j.unsqueeze(0), dim=2) / .5
    sim = sim.t().contiguous()

    # numerator math
    sim_i_j = sim[i, j]
    print(f"sim({i}, {j})={sim_i_j}")

    numerator = torch.exp(sim_i_j)
    print("Numerator", numerator)
    

    # denominator math
    # because there are 5 classes

    mask = torch.ones((5, )).scatter(0, torch.tensor([i]), 0.0).bool()
    print(f"1{{k!={i}}}", mask)
    denominator = torch.sum(
        mask * torch.exp(sim[i,:])
    )
    print("Denominator", denominator)
    loss_ij = -torch.log(numerator / denominator)
    print(f"loss({i},{j})={loss_ij}\n")
    return loss_ij.squeeze(0)

In [None]:
N = 5 # n_clusters
loss = 0.
for k in range(0, N):
    loss += l_ij(k, k + N - 1) + l_ij(k + N - 1, k)
print(loss)

In [None]:
class_num = 5
N = 2 * class_num
mask = torch.ones((N, N))
mask = mask.fill_diagonal_(0)
for i in range(class_num):
    mask[i, class_num + i] = 0
    mask[class_num + i, i] = 0
    mask = mask.bool()
mask.shape

In [None]:
p_i = z_i.sum(0).view(-1)

In [None]:
z_i

In [None]:
z_i.unsqueeze(1).shape, z_j.unsqueeze(0).shape

In [None]:
# dists = torch.sum((z_i.unsqueeze(1) - z_j.unsqueeze(0))** 2, dim=-1)
dists = F.cosine_similarity(z_i.unsqueeze(1), z_j.unsqueeze(0), dim=2) / 0.5
dists.shape

In [None]:
raw_labels = torch.tensor([[1, 1, 4, 4, 4, 1, 0, 3, 4, 4, 2, 3, 4, 4, 1, 2, 1, 0, 4, 0, 0, 2,
       1, 4, 3, 2, 1, 4, 3, 2, 0, 1, 4, 0, 0, 0, 4, 1, 0, 0, 0, 1, 0, 4,
       2, 2, 4, 3, 0, 0, 1, 1, 1, 2, 1, 2, 0, 4, 0, 0, 2, 4, 2, 2, 4, 1,
       3, 3, 4, 3, 0, 3, 3, 2, 0, 4, 4, 4, 4, 4, 3, 2, 0, 3, 1, 3, 4, 4,
       4, 3, 2, 3, 1, 1, 2, 1, 3, 3, 3, 2, 1, 0, 4, 3, 4, 4, 4, 3, 4, 4,
       0, 0, 0, 4, 4, 0, 2, 2, 2, 4, 4, 4, 1, 1, 2, 2, 2, 1, 0, 1, 4, 4,
       3, 3, 3, 3, 2, 3, 3, 1, 2, 2, 1, 1, 1, 1, 4, 3, 0, 0, 0, 1, 1, 3,
       0, 3, 4, 3, 4, 4, 2, 0, 2, 0, 3, 3, 0, 0, 2, 2, 4, 0, 0, 1, 1, 2,
       2, 1, 1, 0, 0, 4, 3, 2, 2, 3, 0, 0, 1, 4, 2, 2, 2, 2, 3, 4, 1, 4,
       0, 1]])
query_labels = raw_labels[:, 50 * 1 :]

In [None]:
augtarget = torch.tensor([[ 0,  0,  0,  1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  5,  5,  5,
          6,  6,  6,  7,  7,  7,  8,  8,  8,  9,  9,  9, 10, 10, 10, 11, 11, 11,
         12, 12, 12, 13, 13, 13, 14, 14, 14, 15, 15, 15, 16, 16, 16, 17, 17, 17,
         18, 18, 18, 19, 19, 19, 20, 20, 20, 21, 21, 21, 22, 22, 22, 23, 23, 23,
         24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29,
         30, 30, 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35,
         36, 36, 36, 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41,
         42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47,
         48, 48, 48, 49, 49, 49]])

In [None]:
z_j.shape

In [None]:
query_labels.shape

In [None]:
dists.unsqueeze(0).shape

In [None]:
F.cross_entropy(dists.unsqueeze(0), torch.Tensor([[0 for i in range(150)]]).long(), reduction='sum')

In [None]:
F.cross_entropy(dists.unsqueeze(0), query_labels, reduction='sum')

In [None]:
raw_labels.max()

In [5]:
rz = torch.load('data/raw_embeddings.pt').cpu()

In [13]:
rz.squeeze(0).shape

torch.Size([200, 64])

In [14]:
mapper = umap.UMAP(n_components=3, random_state=42).fit(rz.detach().squeeze())

In [18]:
mrz = mapper.transform(rz.detach().squeeze())

In [23]:
mrz_support = mrz[: 50 * 1]
# e.g. [1,50*n_query,*(3,84,84)]
mrz_query = mrz[50 * 1 :]

In [32]:
clf = cluster.KMeans(n_clusters=5)
pred_labels = clf.fit_predict(mrz)

In [34]:
centroids = clf.cluster_centers_

In [38]:
centroids = torch.from_numpy(mapper.inverse_transform(centroids)).unsqueeze(0)

In [40]:
centroids.shape

torch.Size([1, 5, 64])

In [47]:
rz_query = rz[:, 50:,:]

In [71]:
centroids.shape, rz_query.shape

(torch.Size([1, 5, 64]), torch.Size([1, 150, 64]))

In [82]:
dists = cosine_similarity(centroids, rz_query)
dists.shape, torch.from_numpy(pred_labels[50:]).unsqueeze(0).shape

(torch.Size([1, 5, 150]), torch.Size([1, 150]))

In [87]:
F.cross_entropy(dists, torch.from_numpy(pred_labels[50:]).unsqueeze(0).long(), reduction='mean')

tensor(1.5604, grad_fn=<NllLoss2DBackward>)

In [92]:
def nt_xent_loss(z_i, z_j, query_labels, temperature=.5, reduction='mean'):
    N = z_j.shape[1]
    # calculating distance from every centroid to every augmented image
    dists = cosine_similarity(z_i, z_j) / temperature
    labels = torch.zeros(N).to(dists.device).long()
    print(dists.shape, query_labels.shape)
    loss = F.cross_entropy(dists, labels.unsqueeze(0), reduction=reduction)
    return loss

In [93]:
nt_xent_loss(centroids, rz_query, torch.from_numpy(pred_labels[50:]).long(), temperature=1., reduction='mean')

torch.Size([1, 5, 150]) torch.Size([150])


tensor(1.6147, grad_fn=<NllLoss2DBackward>)