In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

from model import DownstreamModel, BaselineModel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
seed = 42 
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x111f136b0>

In [4]:
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

cifar_train = datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform
)
cifar_test = datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform
)

num_train = len(cifar_train)
num_valid = int(0.1 * num_train)
num_train -= num_valid

train_dataset, val_dataset = random_split(cifar_train, [num_train, num_valid])

train_transform = transforms.RandomResizedCrop(32,(0.8,1.0))
train_dataset.dataset.transform = transforms.Compose([transform, train_transform])

Files already downloaded and verified
Files already downloaded and verified


In [10]:
batch_size = 128
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(cifar_test, batch_size=batch_size, shuffle=False)

In [11]:
encoder_path = "models/encoder.pth"
encoder = torch.load(encoder_path)
model_head = BaselineModel(num_channels=2) 
num_classes = 10
model = DownstreamModel(encoder=encoder, model_head=model_head, num_classes=num_classes, batch_size=batch_size)

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [14]:
num_epochs = 1

for epoch in range(num_epochs): 
    training_loss = 0.0

    for i, (inputs, labels) in enumerate(train_dataloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        training_loss += loss.item()
        break

    training_loss /= len(train_dataloader)
    print(f"Epoch: {epoch+1}, training_loss: {training_loss}")    

print('Finished Training')



Epoch: 1, training_loss: 0.006495394489981912
Finished Training


In [15]:
correct = 0
total = 0

with torch.no_grad():
    for (images, labels) in test_dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        break

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')

Accuracy of the network on the 10000 test images: 7 %
