In [44]:
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from torch.utils.data.sampler import BatchSampler
import numpy as np
from torchvision import  models
from pytorch_metric_learning import losses
import umap
import  matplotlib.pyplot as plt
import torch.optim as optim
# Use the first CUDA GPU when torch reports one as available; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')


In [45]:
# Per-image preprocessing: convert to a tensor, then normalise each of the
# three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 is read from the current directory; with download=False the files
# must already be present there.
main_dataset = torchvision.datasets.CIFAR10(
    root='.', train=True, transform=transform, download=False
)
test_dataset = torchvision.datasets.CIFAR10(
    root='.', train=False, transform=transform, download=False
)

In [46]:
# Split integer positions rather than the images themselves: only indices into
# main_dataset are partitioned, stratified by class label.
all_indices = range(len(main_dataset.data))
all_targets = main_dataset.targets
# test_size=0.8 sends 80% of the indices to the second (validation) output,
# so the training subset is the remaining 20%.
train_indices, val_indices, _, _ = train_test_split(
    all_indices,
    all_targets,
    test_size=0.8,
    stratify=all_targets,
    random_state=42,
)
# Wrap each index list as a Subset view over main_dataset.
train_dataset = Subset(main_dataset, train_indices)
val_dataset = Subset(main_dataset, val_indices)

In [47]:
# Sanity check on the stratified split: gather the label of every training
# sample, then list the distinct labels (counts holds how often each occurs).
train_labels = torch.tensor([label for _image, label in train_dataset])
unique_values, counts = torch.unique(train_labels, return_counts=True)
unique_values

tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [48]:
class BalancedBatchSampler(BatchSampler):
    """Batch sampler that yields class-balanced batches of dataset indices.

    Each batch is built from ``n_classes`` distinct labels drawn at random
    (without replacement) and up to ``n_samples`` indices for each of those
    labels, so a full batch holds ``n_classes * n_samples`` indices.

    ``BatchSampler.__init__`` is not invoked; this class sets every attribute
    it uses itself.

    Args:
        dataset: Map-style dataset whose items carry the label at position 1.
            Labels must be convertible to integers.
        n_classes: Number of distinct labels per batch.
        n_samples: Number of indices taken per chosen label.

    Raises:
        ValueError: If ``n_classes`` or ``n_samples`` is not positive, or if
            ``n_classes`` exceeds the number of distinct labels in ``dataset``.
    """

    def __init__(self, dataset, n_classes, n_samples):
        if n_classes <= 0 or n_samples <= 0:
            raise ValueError(
                f"n_classes and n_samples must be positive, got {n_classes} and {n_samples}"
            )

        # Read the labels in one pass over the dataset. Batching the pass
        # avoids collating one sample at a time.
        loader = DataLoader(dataset, batch_size=256)
        label_chunks = [torch.as_tensor(batch[1]).reshape(-1) for batch in loader]
        if label_chunks:
            self.labels = torch.cat(label_chunks).long()
        else:
            self.labels = torch.empty(0, dtype=torch.long)

        labels_np = self.labels.numpy()
        # Sorted distinct labels, so the choice order is stable under a seeded RNG.
        self.labels_set = list(np.unique(labels_np))
        if n_classes > len(self.labels_set):
            raise ValueError(
                f"n_classes={n_classes} exceeds the {len(self.labels_set)} distinct labels in the dataset"
            )

        # For each label: the shuffled positions in `dataset` that carry it.
        self.label_to_indices = {
            label: np.where(labels_np == label)[0] for label in self.labels_set
        }
        for label_indices in self.label_to_indices.values():
            np.random.shuffle(label_indices)

        # Per-label cursor into label_to_indices[label].
        self.used_label_indices_count = {label: 0 for label in self.labels_set}
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        """Yield lists of dataset indices, ``len(self)`` batches per pass."""
        self.count = 0
        # `<=` (not `<`) so the last full batch is emitted when len(dataset) is
        # an exact multiple of batch_size, keeping the pass length equal to __len__().
        while self.count + self.batch_size <= len(self.dataset):
            classes = np.random.choice(self.labels_set, self.n_classes, replace=False)
            indices = []
            for class_ in classes:
                start = self.used_label_indices_count[class_]
                stop = start + self.n_samples
                # .tolist() hands plain Python ints to the dataset.
                indices.extend(self.label_to_indices[class_][start:stop].tolist())
                self.used_label_indices_count[class_] = stop
                # Not enough unused indices left for another draw from this
                # label: reshuffle and restart its cursor.
                if stop + self.n_samples > len(self.label_to_indices[class_]):
                    np.random.shuffle(self.label_to_indices[class_])
                    self.used_label_indices_count[class_] = 0
            yield indices
            self.count += self.batch_size

    def __len__(self):
        """Number of full batches that fit into the dataset."""
        return len(self.dataset) // self.batch_size


In [49]:
# Build the loaders. Validation and test use fixed-size batches of 60 in
# dataset order; training batches come from the class-balanced sampler
# (10 classes x 6 samples per class).
train_batch_sampler = BalancedBatchSampler(train_dataset, n_classes=10, n_samples=6)
val_loader = DataLoader(val_dataset, batch_size=60)
train_loader = DataLoader(train_dataset, batch_sampler=train_batch_sampler)
test_loader = DataLoader(test_dataset, batch_size=60)

In [50]:
# Define ResNet18 model
class ResNet18(models.ResNet):
    """ResNet-18 (BasicBlock, [2, 2, 2, 2]) that also returns its pooled features.

    Args:
        num_classes: Output size of the final fully connected layer ``fc``.
        is_classifer: When True, the first value returned by ``forward`` is
            the ``fc`` output. When False, ``fc`` is skipped and the flattened
            backbone features are returned in its place. (The parameter name's
            spelling is kept so existing keyword callers keep working.)
    """

    def __init__(self, num_classes=10, is_classifer: bool = True):
        super().__init__(models.resnet.BasicBlock, [2, 2, 2, 2], num_classes=num_classes)
        self.is_classifier = is_classifer

    def forward(self, x):
        """Run the network and return ``(output, backbone)``.

        ``backbone`` is the avgpool output flattened from dimension 1 onward.
        ``output`` is ``fc(backbone)`` when ``is_classifier`` is set, otherwise
        ``backbone`` itself.
        """
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        backbone = torch.flatten(x, 1)
        # Previously the unflattened avgpool output leaked out as the first
        # return value when the classifier head was disabled; return the
        # flattened embedding instead so both outputs are always 2-D.
        output = self.fc(backbone) if self.is_classifier else backbone
        return output, backbone

# Instantiate with the default arguments, then move the module to `device`.
model = ResNet18()
model = model.to(device)
# classifier =nn.Linear(128, 10)

In [51]:
# angular = losses.AngularLoss(40)
# TODO: change this -- cross-entropy on the logits is what is used for now.
loss_fn = nn.CrossEntropyLoss()

# Adam over all model parameters with the library's default hyperparameters.
optimizer = optim.Adam(params=model.parameters())
# optimizer = optim.SGD(backbone_model.parameters(), lr=0.0001, momentum=0.9)


In [None]:
# Train for 20 epochs; after each one, report the mean per-batch loss on the
# training and validation loaders.
for epoch in range(20):
    train_loss = 0.0
    val_loss = 0.0
    # Count the batches actually seen instead of trusting len(loader), so the
    # reported mean stays correct even if a sampler's __len__ is off.
    train_batches = 0
    val_batches = 0

    model.train()
    for inputs, targets in train_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        logits, embeddings = model(inputs)
        loss = loss_fn(logits, targets) # TODO: change this
        loss.backward()
        optimizer.step()
        # accumulate statistics
        train_loss += loss.item()
        train_batches += 1

    model.eval()
    # Validation never calls backward(), so disable autograd to avoid
    # building (and holding memory for) a graph on every batch.
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits, embeddings = model(inputs)
            loss = loss_fn(logits, targets) # TODO: change this
            val_loss += loss.item()
            val_batches += 1

    # max(..., 1) guards against an empty loader.
    print(f'{epoch + 1} train loss: {train_loss / max(train_batches, 1)} val loss: {val_loss / max(val_batches, 1)}')


1 train loss: 1.686223339603608 val loss: 1.6378221007837528


In [None]:
# Write the model's state_dict (not the module object itself) to PATH.
PATH = './resnet.pth'
torch.save(obj=model.state_dict(), f=PATH)

In [None]:
# Stack the first element (the image tensor) of every test sample along a new
# leading dimension.
x_test = torch.stack([image for image, _label in test_dataset], dim=0)

In [None]:
# Collect the second element (the label) of every test sample into one tensor.
y_test = torch.tensor([label for _image, label in test_dataset])

In [None]:
# Embed the test set on the CPU. The original pushed all of x_test through a
# single forward pass with autograd enabled, which keeps every intermediate
# activation alive. Here gradients are disabled and the input is processed in
# chunks. eval mode is set explicitly so batch-norm uses its running
# statistics and the result does not depend on the chunk size.
model = model.to('cpu')
model.eval()
with torch.no_grad():
    test_embeddings = torch.cat(
        [model(chunk)[1] for chunk in torch.split(x_test, 256)],
        dim=0,
    )

In [None]:
# Reduce the test embeddings to two components with UMAP.
umap_reducer = umap.UMAP(n_components=2)
test_embeddings_np = test_embeddings.detach().numpy()
x_umap = umap_reducer.fit_transform(test_embeddings_np)

In [None]:
# Display the dimensions of the projected array.
np.shape(x_umap)

In [None]:
# Scatter plot of the two UMAP components, coloured by test label.
# 3-D variant, kept for reference:
# ax = fig.add_subplot(projection='3d')
# ax.scatter(x_umap[:,0], x_umap[:,1], x_umap[:,2],c=y_test)
fig, ax = plt.subplots()
ax.scatter(x_umap[:, 0], x_umap[:, 1], c=y_test, cmap='Spectral')