In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T

In [2]:
from barlow_twins import BarlowTwins
from torchvision.models import ResNet, resnet18
from custom_mnist import CustomMnist
from tqdm import tqdm

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
model = torch.load('models/model_x.pth')

In [5]:
backbone = torch.load('models/backbone_x.pth')


In [6]:
train_dataset = CustomMnist('datasets/mnist', train=True, transform=T.ToTensor())
test_dataset = CustomMnist('datasets/mnist', train=False, transform=T.ToTensor())

In [7]:
train_dataset.data = train_dataset.data.unsqueeze(dim=1).repeat(1,3,1,1)
test_dataset.data = test_dataset.data.unsqueeze(dim=1).repeat(1,3,1,1)

In [8]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=1024, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1024)

In [9]:
class LinearEval(nn.Module):
    """Some Information about LinearEval"""
    def __init__(self, backbone: nn.Module):
        super(LinearEval, self).__init__()
        self.backbone = backbone
        self.linear = nn.Linear(in_features=512, out_features=10, bias=True)

    def forward(self, x):
        x = self.backbone(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)

        return x

In [10]:
eval_model = LinearEval(backbone=backbone).to(device)
for param in eval_model.backbone.parameters():
    param.requires_grad = False

In [11]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.SGD(eval_model.parameters(), lr=0.06, momentum=0.8)

In [12]:
def test_model(model, testloader):
    wrongly_classified = 0
    for i, data in enumerate(testloader, 0):
        total = len(data[0])
        inputs, labels = data
        inputs,labels = inputs.to(device), labels.to(device)

        with torch.no_grad():
            preds = model(inputs).argmax(dim=1)

        wrong = total - (preds == labels).sum()
        wrongly_classified += wrong

    return wrongly_classified / len(test_dataset)

In [13]:
def train(model, train_loader, test_loader, device):
    # loop over the dataset multiple times
    for epoch in range(30):
        running_loss = 0.0
        for i, data in enumerate(tqdm(train_loader), 0):
            inputs, labels = data
            labels = nn.functional.one_hot(labels).float()
            inputs, labels = inputs.to(device), labels.to(device)


            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = eval_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print('Epoch: {}, Loss: {}, Test error: {}'.format(epoch, running_loss, test_model(eval_model, test_loader)))

    print('Finished Training')

In [14]:
train(eval_model, train_loader, test_loader, device)

100%|██████████| 59/59 [00:06<00:00,  8.88it/s]


Epoch: 0, Loss: 20.877052783966064, Test error: 0.06529999524354935


100%|██████████| 59/59 [00:05<00:00, 10.15it/s]


Epoch: 1, Loss: 11.4639972448349, Test error: 0.05889999866485596


100%|██████████| 59/59 [00:05<00:00, 10.10it/s]


Epoch: 2, Loss: 10.644036769866943, Test error: 0.05810000002384186


100%|██████████| 59/59 [00:05<00:00, 10.05it/s]


Epoch: 3, Loss: 10.08543911576271, Test error: 0.05700000002980232


100%|██████████| 59/59 [00:05<00:00,  9.97it/s]


Epoch: 4, Loss: 9.770879067480564, Test error: 0.055399999022483826


100%|██████████| 59/59 [00:05<00:00, 10.26it/s]


Epoch: 5, Loss: 9.536826767027378, Test error: 0.05509999766945839


100%|██████████| 59/59 [00:05<00:00, 10.28it/s]


Epoch: 6, Loss: 9.285341687500477, Test error: 0.05350000038743019


100%|██████████| 59/59 [00:05<00:00, 10.13it/s]


Epoch: 7, Loss: 9.095866352319717, Test error: 0.051500000059604645


100%|██████████| 59/59 [00:05<00:00, 10.16it/s]


Epoch: 8, Loss: 8.863294132053852, Test error: 0.053999997675418854


100%|██████████| 59/59 [00:05<00:00, 10.20it/s]


Epoch: 9, Loss: 8.714181773364544, Test error: 0.05249999836087227


100%|██████████| 59/59 [00:05<00:00, 10.22it/s]


Epoch: 10, Loss: 8.644704982638359, Test error: 0.052299998700618744


100%|██████████| 59/59 [00:05<00:00, 10.05it/s]


Epoch: 11, Loss: 8.493742361664772, Test error: 0.0510999970138073


100%|██████████| 59/59 [00:05<00:00,  9.96it/s]


Epoch: 12, Loss: 8.422406524419785, Test error: 0.05169999971985817


100%|██████████| 59/59 [00:05<00:00, 10.18it/s]


Epoch: 13, Loss: 8.339198671281338, Test error: 0.04959999769926071


100%|██████████| 59/59 [00:05<00:00, 10.08it/s]


Epoch: 14, Loss: 8.212627649307251, Test error: 0.04999999701976776


100%|██████████| 59/59 [00:05<00:00, 10.01it/s]


Epoch: 15, Loss: 8.250644117593765, Test error: 0.050200000405311584


100%|██████████| 59/59 [00:06<00:00,  9.83it/s]


Epoch: 16, Loss: 8.059763118624687, Test error: 0.050200000405311584


100%|██████████| 59/59 [00:05<00:00,  9.96it/s]


Epoch: 17, Loss: 8.018186494708061, Test error: 0.050200000405311584


100%|██████████| 59/59 [00:05<00:00, 10.05it/s]


Epoch: 18, Loss: 7.930015340447426, Test error: 0.049799997359514236


100%|██████████| 59/59 [00:05<00:00, 10.05it/s]


Epoch: 19, Loss: 7.808642916381359, Test error: 0.04699999839067459


100%|██████████| 59/59 [00:05<00:00, 10.10it/s]


Epoch: 20, Loss: 7.849245570600033, Test error: 0.04809999838471413


100%|██████████| 59/59 [00:05<00:00, 10.16it/s]


Epoch: 21, Loss: 7.771488256752491, Test error: 0.048899997025728226


100%|██████████| 59/59 [00:05<00:00, 10.05it/s]


Epoch: 22, Loss: 7.677585020661354, Test error: 0.05009999871253967


100%|██████████| 59/59 [00:05<00:00, 10.19it/s]


Epoch: 23, Loss: 7.647305317223072, Test error: 0.047599997371435165


100%|██████████| 59/59 [00:05<00:00, 10.20it/s]


Epoch: 24, Loss: 7.552098289132118, Test error: 0.04699999839067459


100%|██████████| 59/59 [00:05<00:00, 10.23it/s]


Epoch: 25, Loss: 7.580753400921822, Test error: 0.047199998050928116


100%|██████████| 59/59 [00:05<00:00, 10.16it/s]


Epoch: 26, Loss: 7.525532782077789, Test error: 0.04959999769926071


100%|██████████| 59/59 [00:05<00:00,  9.96it/s]


Epoch: 27, Loss: 7.436621397733688, Test error: 0.04749999940395355


100%|██████████| 59/59 [00:05<00:00, 10.18it/s]


Epoch: 28, Loss: 7.421942211687565, Test error: 0.04859999939799309


100%|██████████| 59/59 [00:05<00:00, 10.09it/s]


Epoch: 29, Loss: 7.394971892237663, Test error: 0.050200000405311584
Finished Training
