In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import models, transforms
import matplotlib
import matplotlib.pyplot as plt
import time
import os
import copy
import random
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torchsummary import summary
from collections import OrderedDict
import re

from data_helper import UnlabeledDataset, LabeledDataset
from helper import collate_fn, draw_box, compute_ts_road_map
from modelzoo import *
from simclr_transforms import *

In [6]:
xi = torch.rand(3,512)
xj = torch.rand(3,512)

In [9]:
x = torch.cat((xi, xj), dim=0)
x.shape

torch.Size([6, 512])

In [10]:
sim_mat = torch.mm(x, x.T)
sim_mat.shape

torch.Size([6, 6])

In [16]:
torch.norm(x, dim=1).unsqueeze(1).shape

torch.Size([6, 1])

In [11]:
if True:
    sim_mat_denom = torch.mm(torch.norm(x, dim=1).unsqueeze(1), torch.norm(x, dim=1).unsqueeze(1).T)
    sim_mat = sim_mat / sim_mat_denom.clamp(min=1e-16)
    

In [12]:
sim_mat_denom.shape

torch.Size([6, 6])

In [18]:
sim_mat = torch.exp(sim_mat / 0.1)

In [28]:
torch.sum(xi * xj, dim=-1).shape

torch.Size([3])

In [20]:
if True:
    sim_mat_denom = torch.norm(xi, dim=1) * torch.norm(xj, dim=1)
    sim_match = torch.exp(torch.sum(xi * xj, dim=-1) / sim_mat_denom / 0.1)
else:
    sim_match = torch.exp(torch.sum(xi * xj, dim=-1) / 0.1)

In [21]:
sim_match.shape

torch.Size([3])

In [22]:
sim_match = torch.cat((sim_match, sim_match), dim=0)
sim_match.shape

torch.Size([6])

In [31]:
norm_sum = torch.exp(torch.ones(x.size(0)) / 0.1)
loss = torch.mean(-torch.log(sim_match / (torch.sum(sim_mat, dim=-1) - norm_sum)))

In [32]:
loss.shape

torch.Size([])

In [46]:
xi = torch.Tensor([[5,4],[1,0],[3,3.5]])
xj = torch.Tensor([[4,3],[2,1],[3,2.5]])

print(xi.shape,xj.shape)

torch.Size([3, 2]) torch.Size([3, 2])


In [47]:
x = torch.cat((xi, xj), dim=0)
batch_size = x.shape[0]

In [48]:
sim_mat = torch.mm(x, x.T)
sim_mat.shape

torch.Size([6, 6])

In [49]:
x

tensor([[5.0000, 4.0000],
        [1.0000, 0.0000],
        [3.0000, 3.5000],
        [4.0000, 3.0000],
        [2.0000, 1.0000],
        [3.0000, 2.5000]])

In [50]:
x.T

tensor([[5.0000, 1.0000, 3.0000, 4.0000, 2.0000, 3.0000],
        [4.0000, 0.0000, 3.5000, 3.0000, 1.0000, 2.5000]])

In [51]:
sim_mat

tensor([[41.0000,  5.0000, 29.0000, 32.0000, 14.0000, 25.0000],
        [ 5.0000,  1.0000,  3.0000,  4.0000,  2.0000,  3.0000],
        [29.0000,  3.0000, 21.2500, 22.5000,  9.5000, 17.7500],
        [32.0000,  4.0000, 22.5000, 25.0000, 11.0000, 19.5000],
        [14.0000,  2.0000,  9.5000, 11.0000,  5.0000,  8.5000],
        [25.0000,  3.0000, 17.7500, 19.5000,  8.5000, 15.2500]])

In [52]:
torch.norm(x, dim=1).unsqueeze(1)

tensor([[6.4031],
        [1.0000],
        [4.6098],
        [5.0000],
        [2.2361],
        [3.9051]])

In [53]:
torch.topk(sim_mat,k=2)

torch.return_types.topk(
values=tensor([[41.0000, 32.0000],
        [ 5.0000,  4.0000],
        [29.0000, 22.5000],
        [32.0000, 25.0000],
        [14.0000, 11.0000],
        [25.0000, 19.5000]]),
indices=tensor([[0, 3],
        [0, 3],
        [0, 3],
        [0, 3],
        [0, 3],
        [0, 3]]))

In [54]:
closest_vectors = torch.topk(sim_mat,k=2)[1][:,1]
target_vectors = torch.cat((torch.arange(batch_size/2,batch_size),torch.arange(0,batch_size/2)),dim=0)
batch_acc = torch.sum(closest_vectors == target_vectors)*100/batch_size


In [55]:
sim_mat_denom = torch.mm(torch.norm(x, dim=1).unsqueeze(1), torch.norm(x, dim=1).unsqueeze(1).T)
sim_mat = sim_mat / sim_mat_denom.clamp(min=1e-16)

In [56]:
sim_mat_denom

tensor([[41.0000,  6.4031, 29.5169, 32.0156, 14.3178, 25.0050],
        [ 6.4031,  1.0000,  4.6098,  5.0000,  2.2361,  3.9051],
        [29.5169,  4.6098, 21.2500, 23.0489, 10.3078, 18.0017],
        [32.0156,  5.0000, 23.0489, 25.0000, 11.1803, 19.5256],
        [14.3178,  2.2361, 10.3078, 11.1803,  5.0000,  8.7321],
        [25.0050,  3.9051, 18.0017, 19.5256,  8.7321, 15.2500]])

In [40]:
sim_mat

tensor([[1.0000, 0.7809, 0.9825, 0.9995, 0.9778, 0.9998],
        [0.7809, 1.0000, 0.6508, 0.8000, 0.8944, 0.7682],
        [0.9825, 0.6508, 1.0000, 0.9762, 0.9216, 0.9860],
        [0.9995, 0.8000, 0.9762, 1.0000, 0.9839, 0.9987],
        [0.9778, 0.8944, 0.9216, 0.9839, 1.0000, 0.9734],
        [0.9998, 0.7682, 0.9860, 0.9987, 0.9734, 1.0000]])

In [41]:
torch.topk(sim_mat,k=2)

torch.return_types.topk(
values=tensor([[1.0000, 0.9998],
        [1.0000, 0.8944],
        [1.0000, 0.9860],
        [1.0000, 0.9995],
        [1.0000, 0.9839],
        [1.0000, 0.9998]]),
indices=tensor([[0, 5],
        [1, 4],
        [2, 5],
        [3, 0],
        [4, 3],
        [5, 0]]))

In [42]:
closest_vectors = torch.topk(sim_mat,k=2)[1][:,1]
target_vectors = torch.cat((torch.arange(batch_size/2,batch_size),torch.arange(0,batch_size/2)),dim=0)
batch_acc = torch.sum(closest_vectors == target_vectors)*100/batch_size


In [43]:
closest_vectors

tensor([5, 4, 5, 0, 3, 0])

In [44]:
target_vectors

tensor([3., 4., 5., 0., 1., 2.])

In [45]:
x

tensor([[5.0000, 4.0000],
        [1.0000, 0.0000],
        [3.0000, 3.5000],
        [4.0000, 3.0000],
        [2.0000, 1.0000],
        [3.0000, 2.5000]])

In [57]:
torch.norm(x, dim=1).unsqueeze(1)

tensor([[6.4031],
        [1.0000],
        [4.6098],
        [5.0000],
        [2.2361],
        [3.9051]])

In [65]:
126*134*6

101304