In [None]:
import numpy as np

import torch
import torch.nn as nn

import torchvision
import torchvision.transforms as T

from barlow_twins import BarlowTwins

from augmentation import apply_transforms_inf
from short_video_dataset import ShortVideoDataset

In [None]:
# Load the pickled model object stored in 'model.pth'.
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on a
# CPU-only machine; the model is moved to the selected device in a later cell.
# NOTE(review): torch.load unpickles arbitrary Python objects - only load
# checkpoints from a trusted source. Recent PyTorch releases may also require
# weights_only=False to load a fully pickled module - confirm against the
# installed version.
model = torch.load('model.pth', map_location='cpu')

In [None]:
# Prefer the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Move the model onto the chosen device (left as the last expression so the
# cell still displays the returned value).
model.to(device)

In [None]:
# Each item is converted to a tensor and then centre-cropped to size 720.
dataset = ShortVideoDataset(
    'video_short_half_res',
    transform=T.Compose([T.ToTensor(), T.CenterCrop(720)]),
)

# Iterate in dataset order, one item per batch, with a single worker process.
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=1, shuffle=False, num_workers=1
)

In [None]:
# Compute one flattened embedding vector per dataloader batch.
embeddings = []

model.eval()
# Inference only: disable autograd so no computation graph is built (and kept
# alive) for each forward pass.
with torch.no_grad():
    for image in dataloader:
        image = apply_transforms_inf(image).to(device)
        # Bring the output back to the CPU and flatten it to a 1-D vector.
        emb = model(image).detach().cpu().flatten()

        embeddings.append(emb.numpy())

In [None]:
# Number of embedding vectors collected (one is appended per dataloader batch).
len(embeddings)

In [None]:
# Turn the list of per-item vectors into a single NumPy array
# (2-D as long as every vector has the same length).
embeddings = np.asarray(embeddings)

## PCA

In [None]:
from sklearn.decomposition import PCA

In [None]:
# Fit a 2-component PCA on the embeddings and project them onto it.
pca = PCA(n_components=2).fit(embeddings)
embeddings_red = pca.transform(embeddings)

# Split the projection into its first (X) and second (Y) component.
X, Y = embeddings_red.T

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Scatter plot of the two PCA components.
plt.scatter(x=X, y=Y)

In [None]:
# (first PCA component, dataset index) pairs, ordered by ascending component value.
X_sorted = sorted(
    ((value, index) for index, value in enumerate(X)),
    key=lambda pair: pair[0],
)

## Nearest Neighbours

In [None]:
from sklearn.neighbors import NearestNeighbors
from mpl_toolkits.axes_grid1 import ImageGrid

In [None]:
# Build the nearest-neighbour index over all embeddings (10 neighbours by default).
neigh = NearestNeighbors(n_neighbors=10).fit(embeddings)
neigh

In [None]:
# Spot check: pick one embedding at random and list the indices of its
# 10 nearest neighbours.
random_index = np.random.randint(0, len(embeddings))
neigh.kneighbors(
    embeddings[random_index:random_index + 1],  # single query row, shape (1, D)
    n_neighbors=10,
    return_distance=False,
)

In [None]:
# Show, for several randomly chosen items, the images of their nearest
# neighbours: one grid row per sampled item, one column per neighbour.
num_samples = 10
num_neighb = 5

fig = plt.figure(figsize=(20, 20))
# The row count follows num_samples so the grid always matches the loop below
# (it was previously hard-coded to 10).
grid = ImageGrid(fig, 111, nrows_ncols=(num_samples, num_neighb), axes_pad=0.1)

for r in range(num_samples):
    random_index = np.random.randint(len(embeddings))
    neighb_index = neigh.kneighbors(
        embeddings[random_index].reshape(1, -1), num_neighb, return_distance=False
    )

    for c, idx in enumerate(neighb_index[0]):
        im = dataset[idx]
        # Reorder axes (0, 1, 2) -> (1, 2, 0) for imshow
        # (presumably channels-first -> channels-last; confirm against the dataset).
        im = im.swapaxes(0, -1).swapaxes(0, 1)
        grid[num_neighb * r + c].imshow(im)


In [None]:
# Inspect the (first PCA component, dataset index) pairs in ascending component order.
X_sorted

In [None]:
# Lay out dataset images in a 100 x 5 grid, ordered by their first PCA component.
# zip() stops at whichever runs out first: grid cells or sorted entries.
fig = plt.figure(figsize=(200, 200))
grid = ImageGrid(fig, 111, nrows_ncols=(100, 5), axes_pad=0.1)

for ax, (_, sample_idx) in zip(grid, X_sorted):
    # Reorder axes (0, 1, 2) -> (1, 2, 0) before handing the image to imshow.
    ax.imshow(dataset[sample_idx].swapaxes(0, -1).swapaxes(0, 1))