In [35]:
import torch
import torchvision

class TrainData:
    """CIFAR-10 training split filtered into a cat/dog task plus auxiliary images.

    On construction the CIFAR-10 training set is downloaded (if needed) and two
    tensors are filled: cat/dog images with binary targets, and images from the
    airplane/automobile/bird/ship/truck classes with no targets.
    """

    def __init__(self):
        self._download_train_data()

    def _download_train_data(self):
        """Load CIFAR-10 (train) and store the filtered tensors on ``self``.

        Raises:
            Exception: if the number of cat/dog images is not 10000 or the
                number of auxiliary images is not 25000.
        """
        cifar = torchvision.datasets.CIFAR10(
            'data',
            train=True,
            transform=torchvision.transforms.ToTensor(),
            download=True,
        )
        class_names = tuple(cifar.classes)

        # The position inside keep_labels becomes the target: cat -> 0, dog -> 1.
        keep_labels = tuple(class_names.index(name) for name in ('cat', 'dog'))
        aux_labels = tuple(
            class_names.index(name)
            for name in ('airplane', 'automobile', 'bird', 'ship', 'truck')
        )

        # Expected sample counts and per-image shape (channels, width, height).
        n, auxn = 10000, 25000
        image_shape = (3, 32, 32)

        X = torch.zeros((n, *image_shape))
        y = torch.zeros((n,), dtype=torch.long)
        auxX = torch.zeros((auxn, *image_shape))

        kept = 0
        aux_kept = 0
        for image, label in cifar:
            if label in keep_labels:
                X[kept] = image
                y[kept] = keep_labels.index(label)
                kept += 1
            if label in aux_labels:
                auxX[aux_kept] = image
                aux_kept += 1

        # Under-filled buffers would leave all-zero rows, so fail loudly.
        if kept != n:
            raise Exception("Wrong number of valid examples")
        if aux_kept != auxn:
            raise Exception("Wrong number of aux examples")

        self._X, self._y, self._n = X, y, n
        self._auxX, self._auxn = auxX, auxn

    def in_distribution_dataset(self):
        """Return a list of ``(image, target)`` pairs for the cat/dog images."""
        samples = []
        for idx in range(self._n):
            samples.append((self._X[idx], self._y[idx]))
        return samples

    def mixed_dataset(self):
        """Return cat/dog samples followed by auxiliary samples.

        Each item is ``(image, (target, flag))``: cat/dog items carry their
        binary target and flag 0, auxiliary items carry target 0.5 and flag 1.
        """
        # NOTE(review): cat/dog targets are 0-dim tensors while auxiliary
        # targets are Python floats; confirm the consumer handles both.
        mixed = []
        for idx in range(self._n):
            mixed.append((self._X[idx], (self._y[idx], 0)))
        for idx in range(self._auxn):
            mixed.append((self._auxX[idx], (0.5, 1)))
        return mixed
        
# Build the training data once at module level; Experiment.__init__ reads this global.
train_data = TrainData()
print("Have training data")

Files already downloaded and verified
Have training data


In [52]:

class TestData:
    """CIFAR-10 test split filtered into cat/dog images plus deer/frog/horse images.

    On construction the CIFAR-10 test set is downloaded (if needed); cat/dog
    images get binary targets, and deer/frog/horse images are stored separately
    without targets.
    """

    def __init__(self):
        self._download_test_data()

    def _download_test_data(self):
        """Load CIFAR-10 (test) and store the filtered tensors on ``self``.

        Raises:
            Exception: if the number of cat/dog images is not 2000 or the
                number of deer/frog/horse images is not 3000.
        """
        cifar_test = torchvision.datasets.CIFAR10(
            'data',
            train=False,
            transform=torchvision.transforms.ToTensor(),
            download=True,
        )
        class_names = tuple(cifar_test.classes)

        # The position inside keep_labels becomes the target: cat -> 0, dog -> 1.
        keep_labels = tuple(class_names.index(name) for name in ('cat', 'dog'))
        ood_labels = tuple(
            class_names.index(name) for name in ('deer', 'frog', 'horse')
        )

        # Expected sample counts and per-image shape (channels, width, height).
        n, oodn = 2000, 3000
        image_shape = (3, 32, 32)

        X = torch.zeros((n, *image_shape))
        y = torch.zeros((n,), dtype=torch.long)
        oodX = torch.zeros((oodn, *image_shape))

        j = 0
        oodj = 0
        for image, label in cifar_test:
            if label in keep_labels:
                X[j] = image
                y[j] = keep_labels.index(label)
                j += 1
            if label in ood_labels:
                oodX[oodj] = image
                oodj += 1

        # Under-filled buffers would leave all-zero rows, so fail loudly.
        if j != n:
            raise Exception(f"Wrong number of valid examples {j} {n}")
        if oodj != oodn:
            raise Exception(f"Wrong number of ood examples {oodj} {oodn}")

        self._X, self._y, self._n = X, y, n
        # NOTE(review): no method visible in this class returns _oodX/_oodn;
        # confirm how callers are expected to reach them.
        self._oodX, self._oodn = oodX, oodn

    def in_distribution_dataset(self):
        """Return a list of ``(image, target)`` pairs for the cat/dog images."""
        samples = []
        for idx in range(self._n):
            samples.append((self._X[idx], self._y[idx]))
        return samples
        
#    def mixed_dataset(self):
#        ind = [(self._X[i], (self._y[i], 0)) for i in range(self._n)]
#        ood = [(self._auxX[i], (0.5, 1)) for i in range(self._auxn)]
#        return ind + ood
        
# Build the test data once at module level; Experiment.__init__ reads this global.
test_data = TestData()
print("Have test data")

Files already downloaded and verified
Have test data


In [87]:
class SimpleCnnModel(torch.nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages followed by one linear layer.

    The linear layer takes 12 * 5 * 5 features, which is what the two
    conv(5) + pool(2) stages produce for 32x32 inputs, and emits two scores
    per image.
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Modules are created in the same order they run, so the indices in
        # self.layers (and the state-dict keys derived from them) are stable.
        conv_stages = [
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]
        head = [
            nn.Flatten(),
            nn.Linear(in_features=12 * 5 * 5, out_features=2),
        ]
        self.layers = nn.Sequential(*conv_stages, *head)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the two-column output."""
        scores = self.layers(x)
        return scores


In [86]:
class Experiment:
    """Train ``SimpleCnnModel`` on the in-distribution data and report test accuracy.

    The constructor reads the module-level ``train_data`` and ``test_data``
    objects (each must provide ``in_distribution_dataset()``) and wraps them in
    data loaders. Training uses Adam with default settings and a summed
    cross-entropy loss.

    Args:
        num_epochs: Number of passes over the training loader (default 50).
        batch_size: Mini-batch size for both loaders (default 64).
    """

    def __init__(self, num_epochs=50, batch_size=64):
        # train_data / test_data are only read, so no `global` statement is needed.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.model = SimpleCnnModel().to(self.device)
        self.dataloader = torch.utils.data.DataLoader(
            train_data.in_distribution_dataset(),
            batch_size=batch_size,
            shuffle=True
        )
        self.test_dataloader = torch.utils.data.DataLoader(
            test_data.in_distribution_dataset(),
            batch_size=batch_size,
            shuffle=False
        )
        # reduction='sum' so per-batch losses can be added and divided by the
        # sample count once per epoch.
        self.loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.num_epochs = num_epochs

    def _train(self):
        """Run the training loop, printing loss and test accuracy after every epoch.

        Returns:
            None. Model weights are updated in place.
        """
        for epoch in range(self.num_epochs):
            # Mode switches are no-ops for the current layers but keep the loop
            # correct if dropout or batch-norm layers are added to the model.
            self.model.train()
            running_loss = torch.zeros(())
            running_count = 0
            for inputs, labels in self.dataloader:
                self.optimizer.zero_grad()
                outputs = self.model(inputs.to(self.device))
                loss = self.loss_fn(outputs, labels.to(self.device))
                loss.backward()
                self.optimizer.step()
                running_loss += loss.detach().to('cpu')
                running_count += outputs.shape[0]

            self.model.eval()
            correct = 0
            test_count = 0
            with torch.no_grad():
                for inputs, labels in self.test_dataloader:
                    outputs = self.model(inputs.to(self.device))
                    predictions = outputs.argmax(dim=1)
                    # Labels stay on the CPU, so compare there.
                    correct += (predictions.to('cpu') == labels).sum().item()
                    test_count += outputs.shape[0]
            print('Epoch', f'{epoch:2}', 'Loss', f'{running_loss.item() / running_count:20}', 'Test', correct / test_count)

In [88]:
# Build a fresh Experiment and run its training loop, which prints one summary line per epoch.
Experiment()._train()

Epoch  0 Loss     0.68308212890625 Test 0.605
Epoch  1 Loss      0.6551603515625 Test 0.6305
Epoch  2 Loss    0.635703369140625 Test 0.6395
Epoch  3 Loss    0.614438427734375 Test 0.642
Epoch  4 Loss    0.599210009765625 Test 0.6775
Epoch  5 Loss    0.583020068359375 Test 0.6765
Epoch  6 Loss     0.57189814453125 Test 0.6895
Epoch  7 Loss       0.561106640625 Test 0.6895
Epoch  8 Loss      0.5526974609375 Test 0.7185
Epoch  9 Loss     0.54752431640625 Test 0.705
Epoch 10 Loss          0.539278125 Test 0.7145
Epoch 11 Loss    0.530228076171875 Test 0.713
Epoch 12 Loss    0.524926513671875 Test 0.706
Epoch 13 Loss    0.525994775390625 Test 0.72
Epoch 14 Loss     0.51737802734375 Test 0.7155
Epoch 15 Loss     0.51399697265625 Test 0.7215
Epoch 16 Loss      0.5121666015625 Test 0.7215
Epoch 17 Loss      0.5044685546875 Test 0.72
Epoch 18 Loss      0.4989287109375 Test 0.7285
Epoch 19 Loss      0.4943181640625 Test 0.716
Epoch 20 Loss     0.49219521484375 Test 0.7135
Epoch 21 Loss    0.4923