In [None]:
# papermill parameters
# `aid` is a run identifier; it is embedded in the saved model's filename.
# The default below is used for interactive runs.
aid = 'interactive'
print('aid={}'.format(aid))

In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.models as models
from torch import nn
from torch import optim

from tqdm import tqdm
import os
import numpy as np
from PIL import Image
import pickle as pkl
import matplotlib.pyplot as plt

import util
from DuckDataset import DuckDataset

# Automatically reload edited local modules (e.g. util, DuckDataset) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# Run on a CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

## Load data and train a model

In [None]:
# dataset class for the background images
class ImageDataset(torch.utils.data.Dataset):
    """Map-style dataset over an indexable collection of images.

    Each item is ``(image, 0)``: the (optionally transformed) image paired
    with a constant dummy label of 0.

    Args:
        images: sized, indexable collection of images (e.g. a list).
        transform: optional callable applied to an image each time it is fetched.
    """

    def __init__(self, images, transform=None):
        super().__init__()
        self.images = images
        self.transform = transform

    def __getitem__(self, index):
        # Raise IndexError rather than assert: asserts vanish under `python -O`,
        # and iterating the dataset directly relies on IndexError to stop.
        if index >= len(self.images):
            raise IndexError('Invalid index!')
        img = self.images[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, 0

    def __len__(self):
        return len(self.images)

In [None]:
# load training and validation set
#   mode 'id':  load pre-built datasets pickled in `duckdata_dir`
#   mode 'iid': build DuckDatasets on the fly from pickled imagenet10 background images
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
duckdata_dir = 'data'
mode = 'id'

batch_size = 64
num_workers = 6


def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards."""
    with open(path, 'rb') as f:
        return pkl.load(f)


if mode == 'id':
    trainset = _load_pickle(f'{duckdata_dir}/duck_train.pkl')
    valset = _load_pickle(f'{duckdata_dir}/duck_val.pkl')

elif mode == 'iid':
    train_images = _load_pickle(f'{duckdata_dir}/imagenet10_train.pkl')
    val_images = _load_pickle(f'{duckdata_dir}/imagenet10_val.pkl')

    # Backgrounds get a random resized 224x224 crop and a random horizontal flip.
    # NOTE(review): validation backgrounds are randomly augmented too, so
    # validation metrics vary between runs — confirm this is intended.
    background_train = ImageDataset(train_images, transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]))
    background_val = ImageDataset(val_images, transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]))

    # the widely used ImageNet per-channel mean/std
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    random_size = (50, 100)  # passed to DuckDataset; presumably the duck size range — confirm

    trainset = DuckDataset(background_train,
                           random_size=random_size,
                           transform=transforms.Compose([transforms.ToTensor(), normalize]),
                           uniform_yellow=False
                           )
    valset = DuckDataset(background_val,
                         random_size=random_size,
                         transform=transforms.Compose([transforms.ToTensor(), normalize]),
                         uniform_yellow=False
                         )

else:
    raise ValueError(f"Unknown mode {mode!r}; expected 'id' or 'iid'")

# Shuffle the training set every epoch; keep the validation order fixed.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
valloader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
     

In [None]:
# specify model
# Build the torchvision architecture named by `net_name` (no pretrained weights
# requested), so the name used in the saved filename always matches the model.
# The architecture must expose a final `fc` layer (true for the ResNet family).
net_name = 'resnet18'
net = getattr(models, net_name)()
# Size the new head from the existing layer instead of hard-coding its width (512 for resnet18).
net.fc = nn.Linear(net.fc.in_features, 2)  # 2-class problem
net.to(device)

learning_rate = 0.001
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

In [None]:
# train model with the project's helper (util.train)
# NOTE(review): 25 is passed as util.train's 6th positional argument; presumably
# the number of epochs — confirm against util.train's signature.
num_epochs = 25
util.train(net, optimizer, trainloader, valloader, device, num_epochs, eps=0.005)

In [None]:
# save model
# Only the weights (state_dict) are stored. The filename records the data mode,
# the architecture and the run id.
model_dir = 'models'
os.makedirs(model_dir, exist_ok=True)  # torch.save does not create missing directories
model_path = os.path.join(model_dir, f'duck_{mode}_model_{net_name}_{aid}.pkl')
torch.save(net.state_dict(), model_path)
model_path

In [None]:
# remove net from gpu
# Moving `net` alone is not enough. `optimizer` still holds its own state tensors
# (Adam's running moments), which would keep occupying GPU memory. Move those
# to the CPU too, then return cached blocks to the CUDA driver.
net = net.to('cpu')
for param_state in optimizer.state.values():
    for key, value in list(param_state.items()):
        if torch.is_tensor(value):
            param_state[key] = value.to('cpu')
if torch.cuda.is_available():
    torch.cuda.empty_cache()