In [148]:
import torch
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
import numpy as np
from tqdm.notebook import tqdm

In [20]:
class tuple_to_dict:
    
    def __init__(self, tuple_object):
        self.tuple = tuple_object
    
    def __len__(self):
        return len(self.tuple)
    
    def __getitem__(self, index):
        return {
            "input" : self.tuple[index][0],
            "target" : self.tuple[index][1],
        }

In [100]:
def get_mnist(train=False):
    from torchvision import datasets, transforms
    return tuple_to_dict(datasets.MNIST("./resources/data/raw", train=train, transform=transforms.Compose([transforms.Grayscale(num_output_channels=3),transforms.Resize([32, 32]), transforms.ToTensor()])))

In [121]:
class mnist_m:
    
    def __init__(self, path, train=False):
        
        if train:
            self.files = open(path+"/mnist_m_train_labels.txt").read().split("\n")[:-1]
            self.path = path + "/mnist_m_train/"
        else:
            self.files = open(path+"/mnist_m_test_labels.txt").read().split("\n")[:-1]
            self.path = path + "/mnist_m_test/"
        
        print("Fetching {} files".format(len(self.files)))
        
        self.data = []
        for e in self.files:
            ee = e.split(" ")
            self.data.append({
                "input" : ee[0],
                "target" : int(ee[1])
            })
        
        self.transform = transforms.Compose([transforms.ToTensor()])
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        image = Image.open(self.path + self.data[index]['input'])
        
        return {
            "input" : self.transform(image),
            "target" : self.data[index]['target']
        }

In [122]:
dataset = {
    "train" :{
        "mnist" : torch.utils.data.DataLoader(get_mnist(train=True), batch_size=64),
        "mnist_m" : torch.utils.data.DataLoader(mnist_m("resources/data/raw/mnist_m/", train=True), batch_size=64),
    },
    "test" : {
        "mnist" : torch.utils.data.DataLoader(get_mnist(train=False), batch_size=64),
        "mnist_m" : torch.utils.data.DataLoader(mnist_m("resources/data/raw/mnist_m/", train=False), batch_size=64),
    }
}

Fetching 59001 files
Fetching 9001 files


In [124]:
class grl_grad(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        return input

    @staticmethod
    def backward(ctx, grad_output):
        return -1*grad_output



class grl(torch.nn.Module):
    
    def __init__(self):
        super(grl, self).__init__()
    
    def forward(self, x):
        grl_grad.apply(x)

In [167]:
class model(torch.nn.Module):        
    
    def __encoder__(self):
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(kernel_size=5, in_channels=3, out_channels=32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),
            torch.nn.Conv2d(kernel_size=5, in_channels=32, out_channels=48),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2)
        )
    
    def __classifier__(self):
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(in_features=self.encoder_shape, out_features=100),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=100, out_features=100),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=100, out_features=10),
        )
    
    def __domain__(self):
        self.domain =  torch.nn.Sequential(
            torch.nn.Linear(in_features=self.encoder_shape, out_features=100),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=100, out_features=1),
        )
    
    
    def __init__(self,):
        super(model, self).__init__()
        self.encoder_shape = 48*5*5   
        self.__encoder__()
        self.__classifier__()
        self.grl = grl()
        self.__domain__()
    
    
    def forward(self, x, split=False):
        encoded = self.encoder(x).reshape([-1, self.encoder_shape])
        batch_size = encoded.shape[0]
        start_index = 0
        end_index = batch_size// 2
        
        if split:
            start_index = end_index
            end_index= batch_size
        
        return {
                "classifier" : self.classifier(encoded[start_index:end_index]),
                "domain" : self.domain(encoded).squeeze(-1)
            }
    

In [168]:
def write_stats(stats, model, step):
    tensorboard_writer.add_scalar("loss/epoch/{}/domain".format(model), np.mean(loss['loss']['domain']), step)
    tensorboard_writer.add_scalar("loss/epoch/{}/classifier".format(model), np.mean(loss['loss']['classifier']), step)
    tensorboard_writer.add_scalar("acc/epoch/{}/domain".format(model), np.mean(loss['acc']['domain']), step)
    tensorboard_writer.add_scalar("acc/epoch/{}/classifier".format(model), np.mean(loss['acc']['classifier']), step)

In [179]:
class train:
    
    def __init__(self, model_to_train):
        self.model = model_to_train()
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 
        self.model.to(self.device)
        self.loss = {
            "logistic" : torch.nn.BCEWithLogitsLoss(),
            "softmax" : torch.nn.CrossEntropyLoss(),
        }
        self.optimizer = torch.optim.SGD(lr=0.01, momentum=0.9, params=self.model.parameters())
        self.__get_tensorboard()
    
    def loop(self, X, to_cpu=False):
        X = X.to(self.device)
        if to_cpu:
            return self.model(X).cpu()    
        return self.model(X)
    
    def compute_loss(self, X, output):
        print(X['target'].keys(), output.keys())
        loss = self.loss['softmax'](target=X['target']['classifier'], input=output['classifier'])
        loss += self.loss['logistic'](target=X['target']['domain'], input=output['domain'])
        return loss
    
    def stats(self, X):
        output = self.loop(X['input'])
        loss = self.compute_loss(X, output)
        classifier_acc = np.count_nonzero(torch.nn.softmax(output['classifier']).argmax(-1) == X['target']['classifier'])/len(X['target']['classifier'])
        target_out = self.model(X['input'], split=False)
        domain_acc = np.count_nonzero(torch.nn.softmax(target_out['classifier']).argmax(-1) == X['target']['domain_classifier'])/len(X['target']['domain_classifier'])
        
        return {
            "accuracy" : {
                "classifier" : classifier_acc,
                "domain" : domain_acc,
            },
            "loss" : loss
        }
    
    def cycle(self, X):
        self.optimizer.zero_grad()
        output = self.loop(X['input'])
        cur_loss = self.compute_loss(X, output)
        cur_loss.backward()
        self.optimizer.step()
        return cur_loss
    
    def batch(self, mnist_batch, mnist_m_batch, task):
        
        mnist_batch_size = mnist_batch['input'].shape[0]
        mnist_m_batch_size = mnist_m_batch['input'].shape[0]
        
        X = {
            "input" : torch.cat([mnist_batch['input'], mnist_m_batch['input']], axis=0).to(self.device),
            "target" : {
                "classifier" : mnist_batch['target'].to(self.device),
                "domain" : torch.cat([torch.ones(mnist_batch_size), torch.zeros(mnist_m_batch_size)], axis=0).to(self.device),
                "domain_classifier" : mnist_m_batch['target'].to(self.device)
            }
        }
        
        if task:
            return trainer.cycle(X)
        else:
            return train.stats(X)
    
    def __get_tensorboard(self):
        from torch.utils.tensorboard import SummaryWriter
        from datetime import datetime

        task_type = "MNIST"
        time_date = "{}".format(datetime.now())
        self.tensorboard_writer = SummaryWriter("log/{}/{}".format(task_type, time_date))

In [180]:
trainer = train(model)

In [174]:
best_acc = 0
for epoch in tqdm(range(1000)):
    
    mnist_m_iter = dataset['train']['mnist_m'].__iter__()
    for mnist_batch in dataset['train']['mnist']:
        
        mnist_m_batch = next(mnist_m_iter)
        loss = trainer.batch(mnist_batch, mnist_m_batch, True)
        trainer.tensorboard_writer.add_scalar("loss/batch/train", loss.item())
    
    stats = {
        "loss":{
            "domain":[],
            "classifer":[],
        },
        "acc" : {
            "domain":[],
            "classifier":[]
        }
    }
    mnist_m_iter = dataset['train']['mnist_m'].__iter__()
    for mnist_batch in dataset['train']['mnist']:
        
        mnist_m_batch = next(mnist_m_iter)
        batch_stats = trainer.batch(mnist_batch, mnist_m_batch, False)
        stats['loss']['domain'].append(batch_stats['loss']['domain'])
        stats['loss']['classifier'].append(batch_stats['loss']['classifier'])
        stats['acc']['domain'].append(batch_stats['acc']['domain'])
        stats['acc']['classifier'].append(batch_stats['acc']['classifier'])
    
    write_stats(stats, "train", epoch)
    
    stats = {
        "loss":{
            "domain":[],
            "classifer":[],
        },
        "acc" : {
            "domain":[],
            "classifier":[]
        }
    }
    mnist_m_iter = dataset['test']['mnist_m'].__iter__()
    for mnist_batch in dataset['test']['mnist']:
        
        mnist_m_batch = next(mnist_m_iter)
        batch_stats = trainer.batch(mnist_batch, mnist_m_batch, False)
        stats['loss']['domain'].append(batch_stats['loss']['domain'])
        stats['loss']['classifier'].append(batch_stats['loss']['classifier'])
        stats['acc']['domain'].append(batch_stats['acc']['domain'])
        stats['acc']['classifier'].append(batch_stats['acc']['classifier'])
    
    write_stats(stats, "test", epoch)
    
    
    if np.mean(stats['acc']['domain']) > best_acc:
        torch.save(segnet_model.state_dict(), "./weights/q3_{}".format(epoch))
        
print('Finished Training')

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

dict_keys(['classifier', 'domain', 'domain_classifier']) dict_keys(['classifier', 'domain'])


NameError: name 'tensorboard_writer' is not defined