In [1]:
import numpy as np
from PIL import Image
from PIL import ImageShow
from torch.utils.data import Dataset, Subset
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

def show_image(t):
    t = (t.permute(1,2,0) * 255).to(torch.uint8)
    img = Image.fromarray(t.numpy()).resize((200,200), Image.Resampling.NEAREST)
    ImageShow.IPythonViewer().show(img)

CLASSES = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

def data(train:bool, keep:list[str], ood:list[str]) -> Dataset:
    """Represents part of the CIFAR-10 dataset.
    
    Targets will be relabeled to be in the same order as `classes`.
    Any "out of distribution" targets from the `ood` list will be given a target of `len(classes)`.
    
    train: Whether to use the train or test portion of the dataset
    keep: List of classes that we want to keep. Available classes:
         airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
    ood: List of classes that are considered out-of-distribution
    """
    
    target_map = {}
    for i,c in enumerate(keep):
        target_map[CLASSES.index(c)] = i
    for c in ood:
        target_map[CLASSES.index(c)] = len(keep)
    
    orig = CIFAR10('data', train=train, transform=ToTensor(), target_transform=lambda c:target_map[c], download=True)
    indices = [i for i,t in enumerate(orig.targets) if t in target_map]
    return Subset(orig, indices)


In [4]:
import torch
from torch.nn import Module, Conv2d, ReLU, MaxPool2d, Flatten, Linear, CrossEntropyLoss
from torch.utils.data import DataLoader
import time

class Model(Module):
    def __init__(self, n_outputs):
        super().__init__()
        self.layers = torch.nn.Sequential(
            Conv2d(3, 32, 3),
            ReLU(),
            MaxPool2d(2, 2),
            Conv2d(32, 64, 3),
            ReLU(),
            MaxPool2d(2, 2),
            Conv2d(64, 32, 4),
            ReLU(),
            Flatten(),
            Linear(288, n_outputs),
        )

    def forward(self, x):
        return self.layers(x)
    
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def test(model, test_data):
    with torch.no_grad():
        dataloader = DataLoader(test_data, batch_size=64, shuffle=False)
        running_acc = torch.zeros(()).to(device)
        running_n = 0
        for X,targets in dataloader:
            X = X.to(device)
            targets = targets.to(device)
            
            outputs = model(X)
            
            predictions = outputs.argmax(dim=1)
            running_acc += (predictions == targets).sum(dim=0)
            running_n += X.shape[0]
        print(f'    Accuracy', running_acc.item() / running_n)

def train(keep:list[str], ood_train:list[str], ood_test:list[str], max_epochs:int = 10, max_minutes:int = 5):
    print(device)
    
    train_data = data(True, keep, ood_train)
    test_data = data(False, keep, ood_test)
    dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
    
    n_outputs = len(keep)
    if len(ood_train) > 0:
        n_outputs += 1
        
    model = Model(n_outputs).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.0001)
    loss = CrossEntropyLoss()

    start_time = time.monotonic()

    for epoch in range(max_epochs):
        running_loss = torch.zeros(()).to(device)
        running_n = 0
        for X,targets in dataloader:
            X = X.to(device)
            targets = targets.to(device)
            opt.zero_grad()
            outputs = model(X)

            loss_value = loss(outputs, targets)
            loss_value.backward()
            opt.step()
            running_loss += loss_value.detach()
            running_n += X.shape[0]
        print('Epoch', epoch, '   Loss', running_loss.item() / running_n)
        
        test(model, test_data)

        if time.monotonic() - start_time > max_minutes * 60:
            print("Run out of time!")
            break
    print('Time taken', time.monotonic() - start_time, ' seconds')

    return model

In [5]:
keep = ['cat','dog']
ood_train = ['horse']
ood_test = ['ship']

model = train(keep, ood_train, ood_test)

cuda:0
Files already downloaded and verified
Files already downloaded and verified
Epoch 0    Loss 0.016696444702148438
    Accuracy 0.39866666666666667
Epoch 1    Loss 0.015561748250325521
    Accuracy 0.47533333333333333
Epoch 2    Loss 0.01444874267578125
    Accuracy 0.4876666666666667
Epoch 3    Loss 0.013751644897460938
    Accuracy 0.5113333333333333
Epoch 4    Loss 0.013458868408203125
    Accuracy 0.523
Epoch 5    Loss 0.013134765625
    Accuracy 0.5436666666666666
Epoch 6    Loss 0.012840402221679688
    Accuracy 0.543
Epoch 7    Loss 0.012672285970052084
    Accuracy 0.5536666666666666
Epoch 8    Loss 0.012467698160807292
    Accuracy 0.5493333333333333
Epoch 9    Loss 0.012376154581705729
    Accuracy 0.5193333333333333
Time taken 24.670533333002822  seconds
