In [None]:
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, FashionMNIST, MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm

## Install `neurve` package from GitHub

In [None]:
!pip install git+https://github.com/ekorman/neurve

## Choose dataset

In [None]:
# options are "fashion_mnist", "mnist", and "cifar"
DATASET_OPTIONS = ("fashion_mnist", "mnist", "cifar")

dataset = "cifar"

# Fail fast on a typo: `dataset` is interpolated into the download URL and the
# archive name, and is used as the key that picks the torchvision dataset class.
if dataset not in DATASET_OPTIONS:
    raise ValueError(f"dataset must be one of {DATASET_OPTIONS}, got {dataset!r}")

## Download and load model

In [None]:
!wget https://github.com/ekorman/neurve/releases/download/v0.1.0/{dataset}.tar.gz
!tar xzvf {dataset}.tar.gz

In [None]:
from neurve.unsupervised.models import SimCLRMfld

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the model from the folder named after `dataset` (presumably the one
# extracted from the archive above), move it to `device` and call `.eval()`
# (assumed to be a torch nn.Module, i.e. inference mode — confirm).
net = SimCLRMfld.load_from_folder(dataset)
net = net.to(device)
net = net.eval()

## Download and load dataset

In [None]:
# Map the chosen `dataset` name to its torchvision dataset class.
dset_class = {"fashion_mnist": FashionMNIST, "mnist": MNIST, "cifar": CIFAR10}[dataset]
# Instantiate the *selected* class so the images match the loaded model
# (train=False -> the test split).
dset = dset_class(root="neurve/data/", train=False, download=True, transform=ToTensor())
# No shuffling: row i of the encoded outputs must correspond to dset[i],
# which the plotting cell relies on.
dl = DataLoader(dset, batch_size=32, shuffle=False)

## Run inference

In [None]:
# Encode the whole dataset batch by batch. `net.encode` returns a pair
# (q, coords); per-batch results are collected in lists and concatenated once,
# instead of re-concatenating the growing tensors every batch (which re-copies
# all previous results each time).
q_batches, coord_batches = [], []
with torch.no_grad():
    for x, _ in tqdm(dl):
        q, coords = net.encode(x.to(device))
        q_batches.append(q.cpu())
        coord_batches.append(coords.cpu())

# Assumed shapes (from how they are indexed later — confirm):
# all_q (N, n_charts), all_coords (N, n_charts, 2).
all_q = torch.cat(q_batches)
# Squash raw coordinates into (0, 1); the plots below place images at these values.
all_coords = torch.sigmoid(torch.cat(coord_batches))

## Display all the charts

Each image is drawn in the chart it most likely belongs to (the chart with the highest `q` score), positioned at its 2-D coordinates within that chart.

In [None]:
# For every encoded image, the index of the chart with the largest q score.
qamax = torch.argmax(all_q, dim=1)

# Number of charts (size of q's second dimension).
n_charts = all_q.size(1)

def get_im_tensor(i):
    """Return image ``i`` of ``dset`` as a channels-last numpy array.

    The image tensor (C, H, W, as produced by ``ToTensor``) is permuted to
    (H, W, C) so matplotlib can display it.
    """
    image = dset[i][0]
    return image.permute(1, 2, 0).numpy()

# One figure per chart that is the most likely chart for at least one image;
# each such image is drawn at its 2-D coordinates within that chart.
for c in qamax.unique():
    chart = int(c)
    members = torch.where(qamax == c)[0]

    fig, ax = plt.subplots(figsize=(16, 16))
    ax.axis("off")
    # Coordinates went through a sigmoid, so every point lies in (0, 1).
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_title(f"Chart {chart}: {len(members)} images (of {n_charts} charts)")

    for i in members:
        idx = i.item()
        im_tensor = get_im_tensor(idx)
        cmap = None
        if im_tensor.shape[-1] == 1:
            # Single-channel image: drop the channel axis and show in grayscale.
            im_tensor = im_tensor[:, :, 0]
            cmap = "gray"
        ab = AnnotationBbox(
            OffsetImage(im_tensor, cmap=cmap),
            (all_coords[idx, chart, 0].item(), all_coords[idx, chart, 1].item()),
            frameon=False,
        )
        # Add every image, not only the last one of the chart.
        ax.add_artist(ab)

    plt.show()
    # Free the figure explicitly; this loop can create many large figures.
    plt.close(fig)