In [0]:
import os
# Work from the project folder (presumably a Google Drive mount in Colab) so that
# relative paths used later, such as the './data' download root, resolve inside it.
# NOTE(review): absolute, environment-specific path; fails if the drive is not mounted.
root_dir = "/content/drive/My Drive/bayesian-multitask"
os.chdir(root_dir)

In [0]:
import math
import time
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
from sklearn.preprocessing import KBinsDiscretizer
import torchvision
from torchvision import datasets, transforms
from mnist import MNIST
from cifar import CIFAR10

# One device shared by every tensor and module in the notebook:
# the first CUDA GPU when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


class Data(Dataset):
    """Multi-task dataset whose item ``idx`` is row ``idx`` of every task's arrays.

    ``X[i]`` / ``Y[i]`` are the inputs / labels of task ``i``. All of them are
    converted once, up front, to float32 tensors on the global ``device``.
    Item ``idx`` is ``([X[0][idx, :], ..., X[T-1][idx, :]], [Y[0][idx, :], ..., Y[T-1][idx, :]])``.

    Args:
        feature_num: number of samples reported by ``len()``; must not exceed
            the row count of any task's inputs or labels.
        X: per-task 2-D input arrays; one is needed for every entry of ``Y``.
        Y: per-task 2-D label arrays; ``len(Y)`` sets the number of tasks.

    Raises:
        ValueError: if ``X`` has fewer task arrays than ``Y``, or if
            ``feature_num`` exceeds the length of some task's data.
    """

    def __init__(self, feature_num, X, Y):
        self.num_tasks = len(Y)
        self.feature_num = feature_num
        if len(X) < self.num_tasks:
            raise ValueError('X has %d task arrays but Y has %d' % (len(X), self.num_tasks))
        self.X = [torch.tensor(X[i], dtype=torch.float32, device=device) for i in range(self.num_tasks)]
        self.Y = [torch.tensor(Y[i], dtype=torch.float32, device=device) for i in range(self.num_tasks)]
        # Fail here rather than with an IndexError halfway through an epoch.
        shortest = min([len(t) for t in self.X + self.Y], default=feature_num)
        if feature_num > shortest:
            raise ValueError('feature_num=%d exceeds the shortest task length %d' % (feature_num, shortest))

    def __len__(self):
        return self.feature_num

    def __getitem__(self, idx):
        return [x[idx, :] for x in self.X], [y[idx, :] for y in self.Y]

class MultiTaskLossWrapper(nn.Module):
    """Weights per-task losses by learnable log-variances.

    Keeps one learnable ``s_i`` (``log_vars[i]``, initialised to 0) per task. A
    call scores a single task ``i`` as ``0.5 * exp(-s_i) * L_i + s_i``, where
    ``L_i`` is the per-sample sum of squared errors (``regression=True``) or the
    mean negative log-likelihood of ``model``'s log-probabilities (``regression=False``).

    NOTE(review): the usual uncertainty-weighting form for regression is
    ``0.5*exp(-s)*L + 0.5*s``; here ``s`` is not halved, which shifts the
    optimum of ``s``. Confirm this is intended.

    Args:
        num_tasks: number of tasks (heads of ``model``).
        model: module whose forward returns an indexable sequence of per-task outputs.
        regression: squared-error loss if True, NLL loss if False.
    """

    def __init__(self, num_tasks, model, regression=True):
        super(MultiTaskLossWrapper, self).__init__()
        self.model = model
        self.num_tasks = num_tasks
        self.log_vars = nn.Parameter(torch.zeros((num_tasks), device=device))
        self.regression = regression

    def forward(self, input, targets, i):
        """Compute the weighted loss of task ``i`` on one batch.

        Returns:
            (loss, task_losses, log_vars): ``loss`` is the mean weighted loss;
            ``task_losses`` is a length-``num_tasks`` list in which only entry
            ``i`` is filled (others are 0); ``log_vars`` is a detached Python
            list of the current log-variances.
        """
        outputs = self.model(input)
        log_var = self.log_vars[i]
        precision = 0.5 * torch.exp(-log_var)
        if self.regression:
            task_loss = torch.sum(precision * (targets - outputs[i]) ** 2. + log_var, -1)
        else:
            # reshape(-1), not squeeze(): squeeze() would turn the (1, 1) target of a
            # one-sample batch into a 0-d tensor, which nll_loss rejects.
            nll = F.nll_loss(outputs[i], targets.long().reshape(-1))
            task_loss = precision * nll + log_var
        task_losses = [0] * self.num_tasks
        task_losses[i] = task_loss
        return torch.mean(task_loss), task_losses, self.log_vars.data.tolist()


class MTLModel(torch.nn.Module):
    """Shared two-layer MLP trunk followed by one log-softmax classifier head per task.

    Args:
        num_tasks: number of task-specific heads.
        in_features: length of each flattened input row (784 in this notebook).
        num_classes: number of classes scored by every head.
    """

    def __init__(self, num_tasks, in_features=784, num_classes=10):
        super(MTLModel, self).__init__()
        self.num_tasks = num_tasks
        self.shared_fc = nn.Sequential(nn.Linear(in_features, 128), nn.ReLU(), nn.Linear(128, 64), nn.ReLU())
        # nn.ModuleList (not a plain list) registers the heads as submodules, so
        # their weights show up in .parameters() (and are therefore optimised),
        # move with .to(device), and are included in state_dict().
        self.nets = nn.ModuleList(
            nn.Sequential(nn.Linear(64, num_classes), nn.LogSoftmax(dim=1))
            for _ in range(num_tasks)
        )

    def forward(self, x):
        """Return a list with one (batch, num_classes) log-probability tensor per task."""
        shared_out = self.shared_fc(x)
        return [head(shared_out) for head in self.nets]



In [0]:


def _prepare_images(images, normalize, mean=0.1307, std=0.3081):
    """Return ``images`` unchanged, or scaled to [0, 1] and standardised (ToTensor + Normalize)."""
    if not normalize:
        return images
    return ((images.astype(np.float32) / 255.0) - mean) / std


def _corrupt_labels(labels, sigma, num_classes):
    """Return a copy of ``labels`` where each entry is flipped to a uniformly chosen
    different class with probability P(|N(0, sigma)| > 2)."""
    # One normal draw per label, in label order, from the global NumPy RNG.
    flip = np.abs(np.random.normal(0, sigma, len(labels))) > 2
    noisy = labels.copy()
    for j in np.flatnonzero(flip):
        true_label = int(labels[j])
        noisy[j] = random.choice(list(range(0, true_label)) + list(range(true_label + 1, num_classes)))
    return noisy


def gen_mnist_data(sigmas, datasets=('mnist', 'fashion'), num_classes=10, normalize=True, test_sigma=0):
    """Load one (Fashion-)MNIST copy per task and inject label noise into its training labels.

    For task ``k``, ``datasets[k]`` ('mnist' or 'fashion') is loaded and, if
    needed, downloaded into ./data. Each training label is replaced, with
    probability P(|N(0, sigmas[k])| > 2), by a label drawn uniformly from the
    other ``num_classes - 1`` classes. Test labels get the same treatment with
    ``test_sigma`` (default 0, i.e. clean test labels).

    Args:
        sigmas: per-task noise scale; its length is the number of tasks.
        datasets: per-task dataset name, indexed in step with ``sigmas``.
        num_classes: number of label classes.
        normalize: if True, pixels are scaled to [0, 1] and standardised with
            mean 0.1307 / std 0.3081. torchvision transforms are not applied to
            ``.data``, so this is done explicitly. If False, raw pixels are returned.
            NOTE(review): these are MNIST statistics; they are also used for 'fashion'.
        test_sigma: noise scale for the test labels.

    Returns:
        (X, Y, X_test, Y_test): lists with one entry per task. Images keep the
        layout of the torchvision ``.data`` tensor. Labels are (N, 1) integer arrays.

    Raises:
        ValueError: for a dataset name other than 'mnist' or 'fashion'.
    """
    loaders = {
        'mnist': torchvision.datasets.MNIST,
        'fashion': torchvision.datasets.FashionMNIST,
    }
    X, Y, X_test, Y_test = [], [], [], []
    for idx, sigma in enumerate(sigmas):
        name = datasets[idx]
        if name not in loaders:
            raise ValueError('unknown dataset %r; expected one of %s' % (name, sorted(loaders)))
        train_set = loaders[name]('./data', train=True, download=True)
        test_set = loaders[name]('./data', train=False, download=True)

        X.append(_prepare_images(train_set.data.numpy(), normalize))
        X_test.append(_prepare_images(test_set.data.numpy(), normalize))
        # Train labels are corrupted before test labels so the RNG draw order is fixed.
        Y.append(_corrupt_labels(train_set.targets.numpy(), sigma, num_classes)[..., np.newaxis])
        Y_test.append(_corrupt_labels(test_set.targets.numpy(), test_sigma, num_classes)[..., np.newaxis])
    print('Done creating data...')
    return X, Y, X_test, Y_test



def _count_correct(model, data_loader, num_tasks, device_):
    """Return per-task (correct, total) counts of argmax predictions over ``data_loader``.

    Task ``i`` is scored with head ``i`` on task ``i``'s inputs against task ``i``'s labels.
    """
    correct = [0] * num_tasks
    total = [0] * num_tasks
    for X_batch, Y_batch in data_loader:
        for i in range(num_tasks):
            images = X_batch[i].to(device_)
            labels = torch.flatten(Y_batch[i].to(device_))
            # Flatten per sample with -1 so a short final batch still works.
            outputs = model(images.reshape(images.shape[0], -1))
            _, pred = torch.max(outputs[i], 1)
            correct[i] += (pred == labels).sum().item()
            total[i] += labels.shape[0]
    return correct, total


def calc_acc(model, train_data_loader, val_data_loader, epoch, num_tasks):
    """Print per-task train/validation accuracy and return aggregated counts.

    Inputs are moved to the device of the model's first parameter (the model is
    assumed to have at least one parameter). ``epoch`` is accepted for interface
    compatibility and is not used.

    Returns:
        (correct_train, total_train, correct_test, total_test), summed over all
        tasks and batches.
    """
    device_ = next(model.parameters()).device
    with torch.no_grad():
        correct_train_list, total_train_list = _count_correct(model, train_data_loader, num_tasks, device_)
        correct_test_list, total_test_list = _count_correct(model, val_data_loader, num_tasks, device_)

    print('Train acc: ', [c / t if t else float('nan') for c, t in zip(correct_train_list, total_train_list)])
    print('   Val acc: ', [c / t if t else float('nan') for c, t in zip(correct_test_list, total_test_list)])
    return sum(correct_train_list), sum(total_train_list), sum(correct_test_list), sum(total_test_list)




In [0]:
import random

# Seed every RNG the notebook draws from: NumPy (label-noise draws), Python's
# `random` (replacement labels, random sigmas) and torch (weight init,
# DataLoader shuffling). Previously only NumPy was seeded.
seed = 41
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

feature_num = 60000       # training samples per task exposed by the Dataset
val_feature_num = 10000   # validation samples per task
nb_epoch = 1000
batch_size = 200
hidden_dim = 512          # NOTE(review): not referenced by any visible cell
num_classes = 10
lr = 0.00001
patience = 1000           # epochs without sufficient val-loss improvement before stopping
delta = 1e-4              # minimum val-loss decrease that counts as an improvement
max_num_tasks = 2

random_sigma = False
regression = False

if random_sigma:
    sigmas = [random.randint(1, 3) for _ in range(max_num_tasks)]
else:
    sigmas = [2 for _ in range(max_num_tasks)]
# NOTE(review): the data-generation cell passes sigmas=[1, 3] explicitly, so `sigmas` is currently unused.

In [0]:
# Build per-task datasets, then train one uncertainty-weighted multi-task model
# for each task count in [2, max_num_tasks], reporting accuracy every 50 epochs.
# NOTE(review): sigmas=[1, 3] is hardcoded here; the config cell's `sigmas` is unused.
X, Y_data, X_val, Y_val_data = gen_mnist_data(sigmas=[1,3], datasets=['mnist', 'mnist'])
# One flattened row per sample. reshape(len(x), -1) is idempotent, so re-running the cell is safe.
X = [x.reshape(len(x), -1) for x in X]
X_val = [x_val.reshape(len(x_val), -1) for x_val in X_val]
for num_tasks in range(2, max_num_tasks+1):
    Y = Y_data[:num_tasks]
    Y_val = Y_val_data[:num_tasks]
    lowest_val_loss = None
    counter = 0
    early_stop = False

    train_data = Data(feature_num, X, Y)
    val_data = Data(val_feature_num, X_val, Y_val)
    train_data_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
    val_data_loader = DataLoader(val_data, shuffle=True, batch_size=batch_size)

    model = MTLModel(num_tasks)
    mtl = MultiTaskLossWrapper(num_tasks, model, regression=regression)
    model.to(device)
    mtl.to(device)

    # `model` is a submodule of `mtl`, so this optimises the network weights
    # together with the per-task log-variances.
    optimizer = torch.optim.SGD(mtl.parameters(), lr=lr)

    loss_list = []
    val_loss_list = []
    times = []
    for t in range(nb_epoch):
        cumulative_loss = 0.0
        cumulative_val_loss = 0.0
        cumulative_task_losses = [0] * num_tasks
        cumulative_task_losses_val = [0] * num_tasks
        for X_batch, Y_batch in train_data_loader:
            # Sum the weighted losses of all tasks, then take a single optimizer step.
            loss = 0.0
            for idx in range(num_tasks):
                X_batch_idx, Y_batch_idx = X_batch[idx].to(device), Y_batch[idx].to(device)

                task_loss, task_losses, log_vars = mtl(X_batch_idx, Y_batch_idx, idx)
                # detach(): summing graph-carrying tensors would keep every batch's
                # autograd graph alive until the end of the epoch.
                cumulative_task_losses[idx] += task_losses[idx].detach()
                cumulative_loss += task_loss.item()
                loss = loss + task_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        with torch.no_grad():
            for X_val_batch, Y_val_batch in val_data_loader:
                for idx in range(num_tasks):
                    X_val_batch_idx, Y_val_batch_idx = X_val_batch[idx].to(device), Y_val_batch[idx].to(device)
                    val_task_loss, task_losses_val, _ = mtl(X_val_batch_idx, Y_val_batch_idx, idx)
                    cumulative_task_losses_val[idx] += task_losses_val[idx].detach()
                    cumulative_val_loss += val_task_loss.item()

        # Average over the number of batches actually iterated (handles a partial last batch).
        loss_list.append(cumulative_loss / len(train_data_loader))
        val_loss_batch = cumulative_val_loss / len(val_data_loader)
        val_loss_list.append(val_loss_batch)

        # Early stopping: count epochs whose val loss is not at least `delta` below the best so far.
        if lowest_val_loss is None:
            lowest_val_loss = val_loss_batch
        elif val_loss_batch > lowest_val_loss - delta:
            counter += 1
            if counter >= patience:
                early_stop = True
        else:
            lowest_val_loss = val_loss_batch
            counter = 0
        if t % 50 == 0:
            correct_train, total_train, correct_test, total_test = calc_acc(model, train_data_loader, val_data_loader, t, num_tasks)
            # Printed values are sqrt(exp(log_var)) from the last training batch.
            print('   Log vars: ', [math.exp(log_var) ** 0.5 for log_var in log_vars])
        if early_stop:
            print('Epochs:', t)
            break
    correct_train, total_train, correct_test, total_test = calc_acc(model, train_data_loader, val_data_loader, t, num_tasks)
    pred_log_vars = [math.exp(log_var) ** 0.5 for log_var in log_vars]

    print('Finished Task', num_tasks)