In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import os
import torch.nn.init as init
from torchvision import datasets, transforms
from torch import optim, nn
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import logging
import math
import copy

In [2]:
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
output_dir = './final_out'

In [3]:
def adjust_learning_rate(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.75 * param_group['lr']

# Dataloader

In [4]:
tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

In [5]:
def task_dataloader(task_num):
    task_dir = {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7], 4:[8, 9]}

    train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=tf)

    indices = [i for i, (_, label) in enumerate(train_dataset) if label in task_dir[task_num]]
    dataset = Subset(train_dataset, indices)
    task_train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    indices = [i for i, (_, label) in enumerate(test_dataset) if label in task_dir[task_num]]
    dataset = Subset(test_dataset, indices)
    task_test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return task_train_loader, task_test_loader

def load_all_data():
    train_dataset = datasets.MNIST(root='./data', train=True, transform=tf, download=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, test_loader

In [6]:
train_loader, test_loader = load_all_data()
task_dataloaders = {}
for task in range(0, 5):
    train_dl, test_ld = task_dataloader(task)
    task_dataloaders[task] = (train_dl, test_ld)

# Model

In [7]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.scale = torch.sqrt(torch.FloatTensor([embed_size])).to(device)

    def forward(self, values, keys, query):
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)

        energy = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        attention = torch.softmax(energy, dim=-1)
        out = torch.matmul(attention, values)

        return out

class DeepModelWithAttention(nn.Module):
    def __init__(self):
        super(DeepModelWithAttention, self).__init__()
        self.fc1 = nn.Linear(128 * 10, 512)
        self.fc2 = nn.Linear(512, 256)
        self.attention = SelfAttention(embed_size=256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 1)
        
    def forward(self, x):
        x = x.view(x.size(0), -1) 
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        x = x.unsqueeze(1) 
        x = self.attention(x, x, x)
        x = x.squeeze(1) 
        
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x.squeeze()

In [8]:
class MLP_Enhance(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=28, hidden_dim=400):
        super(MLP_Enhance, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.last = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.Softmax(dim=1)
        )
        self.init_weights()

    def init_weights(self):
        for layer in self.linear:
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                init.constant_(layer.bias, 0)
        last_linear = self.last[0]
        init.xavier_normal_(last_linear.weight)
        init.constant_(last_linear.bias, 0)


    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [9]:
def cal_acc(model, dataloader, device):
    model.eval()
    correct, total = 0., 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [10]:
def plot_acc_history(all_accuracies):
    plt.figure(figsize=(18, 12))

    for i, task_accuracies in enumerate(all_accuracies):
        plt.subplot(2, 3, i+1)
        for task, accs in task_accuracies.items():
            plt.plot(accs,'-', label=f'Task {task}')
        plt.title(f'Accuracy per Mini-Batch for Task {i}')
        plt.xlabel('Mini-Batch Number')
        plt.ylabel('Accuracy (%)')
        plt.ylim(-5, 105)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()


def plot_task_acc_history(all_accuracies, save_name, epochs):
    for task in [0, 1, 2, 3, 4]:
        if (task == 0):
            task_history = []
        else:
            task_history =  [0] * (epochs * task)
        for i, task_accuracies in enumerate (all_accuracies):
            if (i >= task):
                task_history.extend(task_accuracies[task])
                    
        plt.plot(task_history, '-', label=f'Task {task}')

    plt.xlabel('Mini-Batch Number')
    plt.ylabel('Accuracy (%)')
    plt.ylim(-5, 105)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_name) 
    plt.show()

# Train Function

In [11]:
def pretrain_meta(meta_model, optimizer_M):
    target = torch.full((batch_size, ), 2.0)
    criterion = nn.MSELoss()

    print("Pretrain on meta_model")
    num_epochs = 10
    for epoch in range(num_epochs):
        for _ in range (10):
            input = F.softmax(torch.rand(batch_size, batch_size, 10), dim=-1)
            meta_model.train()
            optimizer_M.zero_grad()
            outputs = meta_model(input)
            loss = criterion(outputs, target)
            loss.backward()
            optimizer_M.step()
            
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

In [12]:
def train(model, task_num, criterion, optimizer, prev_models, task_dataloaders, epoches = 4):
    train_loader = task_dataloaders[task_num][0]
    task_accuracies = {task: [] for task in range(task_num + 1)}
    distribution_factor  = 20

    meta_model = DeepModelWithAttention().to(device)
    optimizer_M = optim.Adam(meta_model.parameters(), lr=0.001)
    pretrain_meta(meta_model, optimizer_M)
    
    
    valid_out_dim = task_num * 2 + 2
    logging.info(f"##########Task {task_num}##########")
    for e in range(epoches):
        logging.info(f"Epoch {e}")
        batch_num = 0
        for images, labels in train_loader:
            model.train()
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            random_data = torch.rand(256, 1, 28, 28).to(device)
            random_data = (random_data - 0.5) / 0.5
            random_target = torch.randint(0, 10, (256,), dtype=torch.int64).to(device)
            fake_output = model(random_data)
            dis_loss = distribution_factor * criterion(fake_output, random_target)

            prev_outputs = None

            for k, prev_mod in enumerate(prev_models):
                with torch.no_grad():
                    past_output = prev_mod(images).unsqueeze(0)

                if (prev_outputs == None):
                    prev_outputs = past_output
                else:
                    prev_outputs = torch.cat((prev_outputs, past_output), dim=0)
                
            
            if(batch_num ==0 and prev_outputs is not None):
                meta_out = meta_model(prev_outputs)
                print(meta_out)
                    
            loss = criterion(outputs[:,:valid_out_dim], labels) + dis_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_num += 1

        adjust_learning_rate(optimizer)

        avg_acc = 0
        log_message = []
        log_message.append(f"Epoch: {e}, ")
        for task in range(task_num + 1):
            acc = cal_acc(model, task_dataloaders[task][0], device)
            avg_acc += acc
            task_accuracies[task].append(acc * 100)
            log_message.append(f"Task {task} acc: {acc * 100:.4f}")

        logging.info(', '.join(log_message) + f", Task avg acc:{avg_acc*100/(task_num + 1):.4f}")
    
    return model, task_accuracies

# Train

In [13]:
def train_split_task(task_name, task_dataloaders):
    print(f"Training on {task_name}")
    prev_models = []
    model = MLP_Enhance()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.005)
    criterion = nn.CrossEntropyLoss()

    all_accuracies = []
    for task in [0, 1, 2, 3, 4]:
        model, task_accuracies = train(model, task, criterion, optimizer, prev_models, task_dataloaders, epoches = 4)
        all_accuracies.append(task_accuracies)
        prev_models.append(copy.deepcopy(model))

In [14]:
train_split_task("0", task_dataloaders)

Training on 0
Pretrain on meta_model
Epoch [10/100], Loss: 0.2461
Epoch [20/100], Loss: 0.0694
Epoch [30/100], Loss: 0.0278
Epoch [40/100], Loss: 0.0109
Epoch [50/100], Loss: 0.0026
Epoch [60/100], Loss: 0.0002
Epoch [70/100], Loss: 0.0002
Epoch [80/100], Loss: 0.0002
Epoch [90/100], Loss: 0.0000
Epoch [100/100], Loss: 0.0000
Pretrain on meta_model
Epoch [10/100], Loss: 0.2302
Epoch [20/100], Loss: 0.0709
Epoch [30/100], Loss: 0.0287
Epoch [40/100], Loss: 0.0127
Epoch [50/100], Loss: 0.0034
Epoch [60/100], Loss: 0.0005
Epoch [70/100], Loss: 0.0001
Epoch [80/100], Loss: 0.0002
Epoch [90/100], Loss: 0.0000
Epoch [100/100], Loss: 0.0000
tensor(0.1052, device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor(0.1052, device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor(0.1052, device='cuda:0', grad_fn=<SqueezeBackward0>)
tensor(0.1052, device='cuda:0', grad_fn=<SqueezeBackward0>)
Pretrain on meta_model
Epoch [10/100], Loss: 0.6806
Epoch [20/100], Loss: 0.0669
Epoch [30/100], Loss: 0.0174
Epoch [4