In [None]:
import numpy as np
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from sklearn.manifold import TSNE
import matplotlib
from matplotlib import pyplot as plt
import tqdm.notebook as tqdm

In [None]:
from google.colab import drive
# Mount at /content/gdrive: the checkpoint paths used further down
# (/content/gdrive/MyDrive/...) live under this mount point.
drive.mount('/content/gdrive')

In [None]:
# Sanity check: confirm the checkpoint file that is loaded further down exists on the mounted Drive.
!ls /content/gdrive/MyDrive/cifar10_classifier_large.pth

In [None]:
# Use the GPU when one is available; the model is moved onto `device` later.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device", device)

# Single cache directory for both CIFAR-10 splits.
DATA_ROOT = "./data"

# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 per channel
# then maps them to [-1, 1].
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(),
     torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Training split (train=True).
training_data = torchvision.datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=transform)

# Held-out test split (train=False).
test_data = torchvision.datasets.CIFAR10(
    root=DATA_ROOT,
    train=False,
    download=True,
    transform=transform)

batch_size = 4
trainloader = torch.utils.data.DataLoader(training_data, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
# No shuffling for evaluation, so batches follow dataset order.
testloader = torch.utils.data.DataLoader(test_data, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

In [None]:
# Show the first 9 training images as a 3x3 grid.
# The dataset tensors are normalized to [-1, 1] (mean = std = 0.5), so undo that
# before plotting; otherwise imshow clips every negative value to 0 and the
# colors come out wrong.
images = [training_data[i][0] for i in range(9)]
grid = torchvision.utils.make_grid(torch.stack(images) * 0.5 + 0.5, nrow=3, padding=5)
# make_grid returns (C, H, W); imshow expects (H, W, C).
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.axis("off");

In [None]:
# ...

In [None]:
class Net(nn.Module):
  """CNN classifier: two (conv, conv, max-pool) stages, then a two-layer MLP head.

  Every conv/linear layer except the last is followed by ReLU and then
  BatchNorm. linear1 is sized for 256 * 8 * 8 features, i.e. 32x32 inputs after
  two 2x poolings. Attribute names must not change: they are the keys of the
  state_dict loaded from the saved checkpoint.
  """

  def __init__(self):
    super().__init__()
    # Stage 1: 3 -> 128 channels, then halve the spatial size.
    self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1)
    self.bn1 = nn.BatchNorm2d(128)
    self.conv2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(128)
    self.pool1 = nn.MaxPool2d(kernel_size=2)
    # Stage 2: 128 -> 256 channels, then halve the spatial size again.
    self.conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
    self.bn3 = nn.BatchNorm2d(256)
    self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
    self.bn4 = nn.BatchNorm2d(256)
    self.pool2 = nn.MaxPool2d(kernel_size=2)
    # Head: flattened feature map -> 256 hidden units -> 10 logits.
    self.linear1 = nn.Linear(256 * 8 * 8, 256)
    self.bn_l1 = nn.BatchNorm1d(256)
    self.linear2 = nn.Linear(256, 10)

  def forward(self, x):
    """Return unnormalized class scores of shape (batch, 10)."""
    # ReLU before BatchNorm: presumably the order the checkpoint was trained with.
    h = x
    for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
      h = norm(F.relu(conv(h)))
    h = self.pool1(h)
    for conv, norm in ((self.conv3, self.bn3), (self.conv4, self.bn4)):
      h = norm(F.relu(conv(h)))
    h = self.pool2(h)
    hidden = self.bn_l1(F.relu(self.linear1(h.flatten(1))))
    return self.linear2(hidden)

# Rebuild the architecture on `device` and load the trained weights from Drive.
net = Net().to(device)
model_save_name = 'cifar10_classifier_large.pth'
path = f"/content/gdrive/MyDrive/{model_save_name}"
# map_location remaps stored tensors onto the current device, so a checkpoint
# saved on a GPU still loads on a CPU-only runtime. weights_only=True restricts
# unpickling to tensors/primitive containers, which is all a state_dict needs.
net.load_state_dict(torch.load(path, map_location=device, weights_only=True))
# Evaluation mode: BatchNorm layers use their stored running statistics.
net.eval()

In [None]:
# ...

In [None]:
class SaveFeatures():
  """Capture a module's forward output as a NumPy array via a forward hook.

  After each forward pass through the hooked module, `features` holds that
  module's output, detached from the autograd graph and copied to the CPU.
  The hook stays registered until `remove()` is called; the instance can also
  be used as a context manager (`with SaveFeatures(layer) as sf: ...`), which
  removes the hook on exit.
  """
  features = None  # set by hook_fn; stays None until the hooked module runs

  def __init__(self, m):
    # Keep the handle so the hook can be detached later.
    self.hook = m.register_forward_hook(self.hook_fn)

  def hook_fn(self, module, input, output):
    # detach() instead of .data: same values, but the supported way to step
    # outside autograd before converting to NumPy.
    self.features = output.detach().cpu().numpy()

  def remove(self):
    """Unregister the forward hook; `features` keeps its last value."""
    self.hook.remove()

  def __enter__(self):
    return self

  def __exit__(self, exc_type, exc, tb):
    # Always unregister, even if the body raised; never swallow the exception.
    self.remove()
    return False

In [None]:
def get_features_from_layer(layer):
  """Attach a SaveFeatures hook to `layer` and hand it back.

  The returned object's `features` attribute is filled in on the next forward
  pass through `layer`; the caller is responsible for calling `.remove()`.
  """
  return SaveFeatures(layer)

In [None]:
# ...