In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import os
import torch.nn.init as init
from torchvision import datasets, transforms
from torch import optim, nn
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import logging
import math
import copy

In [2]:
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
output_dir = './output'

In [3]:
def adjust_learning_rate(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.75 * param_group['lr']

# Dataloader

In [4]:
tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

In [5]:
def task_dataloader(task_num):
    task_dir = {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7], 4:[8, 9]}

    train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=tf)

    indices = [i for i, (_, label) in enumerate(train_dataset) if label in task_dir[task_num]]
    dataset = Subset(train_dataset, indices)
    task_train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    indices = [i for i, (_, label) in enumerate(test_dataset) if label in task_dir[task_num]]
    dataset = Subset(test_dataset, indices)
    task_test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return task_train_loader, task_test_loader

def load_all_data():
    train_dataset = datasets.MNIST(root='./data', train=True, transform=tf, download=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, test_loader

In [6]:
train_loader, test_loader = load_all_data()
task_dataloaders = {}
for task in range(0, 5):
    train_dl, test_ld = task_dataloader(task)
    task_dataloaders[task] = (train_dl, test_ld)

# Model

In [74]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.scale = torch.sqrt(torch.FloatTensor([embed_size])).to(device)

    def forward(self, values, keys, query):
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(query)

        energy = torch.matmul(queries, keys.transpose(-2, -1)) / self.scale
        attention = torch.softmax(energy, dim=-1)
        out = torch.matmul(attention, values)

        return out

class DeepModelWithAttention(nn.Module):
    def __init__(self):
        super(DeepModelWithAttention, self).__init__()
        self.fc1 = nn.Linear(128 * 2, 256)
        self.fc2 = nn.Linear(256, 256)
        self.attention = SelfAttention(embed_size=256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128, 1)
        
    def forward(self, x):
        x = x.view(x.size(0), 256) 
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        
        x = x.unsqueeze(1) 
        x = self.attention(x, x, x)
        x = x.squeeze(1) 
        
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return x.squeeze()

In [75]:
class MLP_Enhance(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=28, hidden_dim=400):
        super(MLP_Enhance, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.last = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.Softmax(dim=1)
        )
        self.init_weights()

    def init_weights(self):
        for layer in self.linear:
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                init.constant_(layer.bias, 0)
        last_linear = self.last[0]
        init.xavier_normal_(last_linear.weight)
        init.constant_(last_linear.bias, 0)


    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [76]:
def cal_acc(model, dataloader, device):
    model.eval()
    correct, total = 0., 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    

    return correct / total

In [77]:
def plot_acc_history(all_accuracies):
    plt.figure(figsize=(18, 12))

    for i, task_accuracies in enumerate(all_accuracies):
        plt.subplot(2, 3, i+1)
        for task, accs in task_accuracies.items():
            plt.plot(accs,'-', label=f'Task {task}')
        plt.title(f'Accuracy per Mini-Batch for Task {i}')
        plt.xlabel('Mini-Batch Number')
        plt.ylabel('Accuracy (%)')
        plt.ylim(-5, 105)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()


def plot_task_acc_history(all_accuracies, save_name, epochs):
    for task in [0, 1, 2, 3, 4]:
        if (task == 0):
            task_history = []
        else:
            task_history =  [0] * (epochs * task)
        for i, task_accuracies in enumerate (all_accuracies):
            if (i >= task):
                task_history.extend(task_accuracies[task])
                    
        plt.plot(task_history, '-', label=f'Task {task}')

    plt.xlabel('Mini-Batch Number')
    plt.ylabel('Accuracy (%)')
    plt.ylim(-5, 105)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_name) 
    plt.show()

# Train Function

In [82]:
import random
def pretrain_meta(meta_model, optimizer_M):
    target = torch.full((5, ), 2.0).to(device)
    criterion = nn.MSELoss()

    print("Pretrain on meta_model")
    num_epochs = 50
    for epoch in range(num_epochs):
        prev_outputs = None
        for k in range (5):
            past_output = F.softmax(torch.rand(batch_size, 10), dim=-1).to(device)
            outputs = F.softmax(torch.rand(batch_size, 10), dim=-1).to(device)
            diff = past_output[:,k*2:k * 2 + 2] - outputs[:,k*2:k * 2 + 2]
            diff = diff.unsqueeze(0)
            if (prev_outputs == None):
                prev_outputs = diff
            else:
                prev_outputs = torch.cat((prev_outputs, diff), dim=0)

        meta_out = meta_model(prev_outputs).view(-1, 1, 1)
        meta_model.train()
        optimizer_M.zero_grad()
        loss = criterion(meta_out, target)
        loss.backward()
        optimizer_M.step()
        
        if((epoch + 1) % 10 == 0):
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

In [83]:
def train(model, task_num, criterion, optimizer, prev_models, task_dataloaders, epoches = 4):
    train_loader = task_dataloaders[task_num][0]
    task_accuracies = {task: [] for task in range(task_num + 1)}
    distribution_factor  = 20

    meta_model = DeepModelWithAttention().to(device)
    optimizer_M = optim.Adam(meta_model.parameters(), lr=0.001)
    pretrain_meta(meta_model, optimizer_M)
    
    
    valid_out_dim = task_num * 2 + 2
    logging.info(f"##########Task {task_num}##########")
    for e in range(epoches):
        logging.info(f"Epoch {e}")
        batch_num = 0
        for images, labels in train_loader:
            model.train()
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            random_data = torch.rand(256, 1, 28, 28).to(device)
            random_data = (random_data - 0.5) / 0.5
            random_target = torch.randint(0, 10, (256,), dtype=torch.int64).to(device)
            fake_output = model(random_data)
            dis_loss = distribution_factor * criterion(fake_output, random_target)

            prev_outputs = None

            for k, prev_mod in enumerate(prev_models):
                with torch.no_grad():
                    past_output = prev_mod(images)
                
                diff = past_output[:,k*2:k * 2 + 2] - outputs[:,k*2:k * 2 + 2]
                diff = diff.unsqueeze(0)
                if (prev_outputs == None):
                    prev_outputs = diff
                else:
                    prev_outputs = torch.cat((prev_outputs, diff), dim=0)
                
            reg_loss = 0
            if(prev_outputs is not None):
                print(prev_outputs.size())
                meta_out = meta_model(prev_outputs).view(-1, 1, 1)
                reg_loss = torch.mean(meta_out * diff)
            
                if (batch_num % 20 == 0):
                    print(meta_out)
                    
            loss = criterion(outputs[:,:valid_out_dim], labels) + dis_loss + reg_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_num += 1

        adjust_learning_rate(optimizer)

        avg_acc = 0
        log_message = []
        log_message.append(f"Epoch: {e}, ")
        for task in range(task_num + 1):
            acc = cal_acc(model, task_dataloaders[task][0], device)
            avg_acc += acc
            task_accuracies[task].append(acc * 100)
            log_message.append(f"Task {task} acc: {acc * 100:.4f}")

        logging.info(', '.join(log_message) + f", Task avg acc:{avg_acc*100/(task_num + 1):.4f}")
    
    return model, task_accuracies

# Train

In [84]:
def setup_logging(file_name):
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(file_name),
                            logging.StreamHandler()
                        ],
                        force=True)
    

def train_split_task(task_name, task_dataloaders):
    print(f"Training on {task_name}")
    prev_models = []
    model = MLP_Enhance()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.005)
    criterion = nn.CrossEntropyLoss()

    epochs, lr = 3, 0.005
    file_name = f"{output_dir}/{task_name}_epochs={epochs}"
    log_file_name = file_name + '.log'
    img_file_name = file_name + '.png'

    setup_logging(log_file_name)
    all_accuracies = []
    for task in [0, 1, 2, 3, 4]:
        model, task_accuracies = train(model, task, criterion, optimizer, prev_models, task_dataloaders, epoches = 4)
        all_accuracies.append(task_accuracies)
        prev_models.append(copy.deepcopy(model))

    plot_task_acc_history(all_accuracies, img_file_name, epochs)

In [85]:
train_split_task("0", task_dataloaders)

Training on 0
Pretrain on meta_model
Epoch [1/20], Loss: 4.4301
Epoch [2/20], Loss: 4.2952
Epoch [3/20], Loss: 4.1678
Epoch [4/20], Loss: 4.0422
Epoch [5/20], Loss: 3.9155
Epoch [6/20], Loss: 3.7835
Epoch [7/20], Loss: 3.6075
Epoch [8/20], Loss: 3.4123
Epoch [9/20], Loss: 3.1477
Epoch [10/20], Loss: 2.8265
Epoch [11/20], Loss: 2.4944
Epoch [12/20], Loss: 2.0126


  return F.mse_loss(input, target, reduction=self.reduction)
2024-06-01 13:09:28,699 - INFO - ##########Task 0##########
2024-06-01 13:09:28,699 - INFO - Epoch 0


Epoch [13/20], Loss: 1.5425
Epoch [14/20], Loss: 1.1233
Epoch [15/20], Loss: 0.5338
Epoch [16/20], Loss: 0.1572
Epoch [17/20], Loss: 0.0038
Epoch [18/20], Loss: 0.2287
Epoch [19/20], Loss: 0.4810
Epoch [20/20], Loss: 0.7749


2024-06-01 13:09:34,530 - INFO - Epoch: 0, , Task 0 acc: 96.8511, Task avg acc:96.8511
2024-06-01 13:09:34,530 - INFO - Epoch 1
2024-06-01 13:09:40,248 - INFO - Epoch: 1, , Task 0 acc: 91.6374, Task avg acc:91.6374
2024-06-01 13:09:40,249 - INFO - Epoch 2
2024-06-01 13:09:45,401 - INFO - Epoch: 2, , Task 0 acc: 99.2746, Task avg acc:99.2746
2024-06-01 13:09:45,401 - INFO - Epoch 3
2024-06-01 13:09:50,435 - INFO - Epoch: 3, , Task 0 acc: 97.8635, Task avg acc:97.8635
2024-06-01 13:09:50,593 - INFO - ##########Task 1##########
2024-06-01 13:09:50,594 - INFO - Epoch 0


Pretrain on meta_model
Epoch [1/20], Loss: 4.1522
Epoch [2/20], Loss: 4.0151
Epoch [3/20], Loss: 3.8882
Epoch [4/20], Loss: 3.7642
Epoch [5/20], Loss: 3.6257
Epoch [6/20], Loss: 3.4694
Epoch [7/20], Loss: 3.2825
Epoch [8/20], Loss: 3.0503
Epoch [9/20], Loss: 2.7943
Epoch [10/20], Loss: 2.4537
Epoch [11/20], Loss: 2.0693
Epoch [12/20], Loss: 1.6587
Epoch [13/20], Loss: 1.1896
Epoch [14/20], Loss: 0.7060
Epoch [15/20], Loss: 0.2740
Epoch [16/20], Loss: 0.0242
Epoch [17/20], Loss: 0.0694
Epoch [18/20], Loss: 0.4538
Epoch [19/20], Loss: 0.7560
Epoch [20/20], Loss: 0.6252
torch.Size([1, 128, 2])
tensor([[[3.2678]]], device='cuda:0', grad_fn=<ViewBackward0>)
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
tor

2024-06-01 13:09:58,394 - INFO - Epoch: 0, , Task 0 acc: 87.5159, Task 1 acc: 0.0000, Task avg acc:43.7580
2024-06-01 13:09:58,395 - INFO - Epoch 1


torch.Size([1, 128, 2])
tensor([[[10.8942]]], device='cuda:0', grad_fn=<ViewBackward0>)
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
tensor([[[10.4985]]], device='cuda:0', grad_fn=<ViewBackward0>)
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Si

2024-06-01 13:10:05,854 - INFO - Epoch: 1, , Task 0 acc: 44.2203, Task 1 acc: 0.0000, Task avg acc:22.1102
2024-06-01 13:10:05,855 - INFO - Epoch 2


torch.Size([1, 128, 2])
tensor([[[11.3684]]], device='cuda:0', grad_fn=<ViewBackward0>)
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
tensor([[[10.8486]]], device='cuda:0', grad_fn=<ViewBackward0>)
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Si

2024-06-01 13:10:13,374 - INFO - Epoch: 2, , Task 0 acc: 39.1661, Task 1 acc: 0.0665, Task avg acc:19.6163
2024-06-01 13:10:13,375 - INFO - Epoch 3


torch.Size([1, 128, 2])
tensor([[[11.4031]]], device='cuda:0', grad_fn=<ViewBackward0>)
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
tensor([[[11.1482]]], device='cuda:0', grad_fn=<ViewBackward0>)
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Size([1, 128, 2])
torch.Si

2024-06-01 13:10:20,949 - INFO - Epoch: 3, , Task 0 acc: 41.5816, Task 1 acc: 0.0000, Task avg acc:20.7908
2024-06-01 13:10:21,102 - INFO - ##########Task 2##########
2024-06-01 13:10:21,103 - INFO - Epoch 0


Pretrain on meta_model
Epoch [1/20], Loss: 3.7295
Epoch [2/20], Loss: 3.5807
Epoch [3/20], Loss: 3.4235
Epoch [4/20], Loss: 3.2507
Epoch [5/20], Loss: 3.0820
Epoch [6/20], Loss: 2.8924
Epoch [7/20], Loss: 2.6741
Epoch [8/20], Loss: 2.4126
Epoch [9/20], Loss: 2.1083
Epoch [10/20], Loss: 1.7231
Epoch [11/20], Loss: 1.3163
Epoch [12/20], Loss: 0.8885
Epoch [13/20], Loss: 0.4907
Epoch [14/20], Loss: 0.1250
Epoch [15/20], Loss: 0.0050
Epoch [16/20], Loss: 0.1648
Epoch [17/20], Loss: 0.4721
Epoch [18/20], Loss: 0.6772
Epoch [19/20], Loss: 0.6102
Epoch [20/20], Loss: 0.4330
torch.Size([2, 128, 2])
tensor([[[8.2616]],

        [[2.3298]]], device='cuda:0', grad_fn=<ViewBackward0>)
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.Size([2, 128, 2])
torch.

KeyboardInterrupt: 