In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from src.models import DummyNet
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from sklearn.manifold import TSNE

%matplotlib inline


In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Restore the trained weights: the checkpoint's "model" entry is handed to
# load_state_dict, and map_location remaps the stored tensors onto `device`.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints that come from a trusted source.
checkpoint = torch.load("checkpoints/checkpoint_10.pkl", map_location=device)
model = DummyNet()
model = model.to(device)
model.load_state_dict(checkpoint["model"])


In [None]:
# Evaluation dataset: class names come from the sub-directory names under `dataset/`;
# each image is converted to a tensor and resized to 224x224.
# NOTE(review): ToTensor runs before Resize, so resizing happens on tensors —
# presumably this mirrors the training pipeline; keep the two in sync.
dataset = ImageFolder(
    root="dataset/",
    transform=transforms.Compose(
        [transforms.ToTensor(), transforms.Resize((224, 224))]
    ),
)

# Draw one random batch to visualize. The shuffle order comes from a seeded
# generator so that Restart & Run All reproduces the same sample (and plot).
N_SAMPLES = 200
SAMPLE_SEED = 0
loader = DataLoader(
    dataset,
    batch_size=N_SAMPLES,
    shuffle=True,
    generator=torch.Generator().manual_seed(SAMPLE_SEED),
)
images, labels = next(iter(loader))
images.size(), labels.size()


In [None]:
# Embed the sampled images with the trained model, then project the embeddings
# to 2-D with t-SNE for plotting.
images = images.to(device)

# eval() puts layers such as dropout / batch-norm (if DummyNet has any) into
# inference behaviour; no_grad() avoids building an autograd graph we never use.
model.eval()
with torch.no_grad():
    features = model(images)

# t-SNE expects a 2-D (n_samples, n_features) array; flatten any trailing
# dimensions in case the model returns feature maps instead of vectors.
features = features.cpu().numpy()
features_flat = features.reshape(features.shape[0], -1)

TSNE_SEED = 0
embedded_images = TSNE(
    n_components=2,
    learning_rate="auto",
    init="pca",
    # scikit-learn requires perplexity < n_samples; 30.0 is the library default.
    perplexity=min(30.0, features_flat.shape[0] - 1),
    random_state=TSNE_SEED,
).fit_transform(features_flat)
embedded_images.shape

In [None]:
# Scatter the 2-D t-SNE points coloured by ground-truth class, then save the figure.
labels_np = np.asarray(labels)
fig, ax = plt.subplots(dpi=100)
scatter = ax.scatter(embedded_images[:, 0], embedded_images[:, 1], c=labels_np)

# legend_elements yields one handle per *distinct* class id in this batch, in
# ascending order. A random batch may miss some classes, so name exactly the ids
# that appear instead of zipping against the full class list. num=None stops
# matplotlib from subsampling the values when there are more than 9 classes.
handles, _ = scatter.legend_elements(prop="colors", num=None)
present_class_names = [dataset.classes[i] for i in np.unique(labels_np)]
ax.legend(handles, present_class_names)

ax.set_title(
    f"Visualizing The Encoded Data Using t-SNE\nModel: {model.__class__.__name__}",
    fontsize=10,
)
ax.set_xlabel("$x_1$", fontsize=15)
ax.set_ylabel("$x_2$", fontsize=15)
fig.savefig("tsne_vis_.jpg")