In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import os
import torch.nn.init as init
from torchvision import datasets, transforms
from torch import optim, nn
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import logging
import math
import copy

In [2]:
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
output_dir = './output'

In [3]:
def adjust_learning_rate(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.75 * param_group['lr']

# Dataloader

In [4]:
tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

In [5]:
def task_dataloader(task_num):
    task_dir = {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7], 4:[8, 9]}

    train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=tf)

    indices = [i for i, (_, label) in enumerate(train_dataset) if label in task_dir[task_num]]
    dataset = Subset(train_dataset, indices)
    task_train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    indices = [i for i, (_, label) in enumerate(test_dataset) if label in task_dir[task_num]]
    dataset = Subset(test_dataset, indices)
    task_test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return task_train_loader, task_test_loader

def load_all_data():
    train_dataset = datasets.MNIST(root='./data', train=True, transform=tf, download=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, test_loader

In [6]:
train_loader, test_loader = load_all_data()
task_dataloaders = {}
for task in range(0, 5):
    train_dl, test_ld = task_dataloader(task)
    task_dataloaders[task] = (train_dl, test_ld)

# Model

In [7]:
class DeepModelWithAttention(nn.Module):
    def __init__(self, desired_sum):
        super(DeepModelWithAttention, self).__init__()
        self.desired_sum = desired_sum

        self.attn = nn.MultiheadAttention(embed_dim=2, num_heads= 1)
        self.fc = nn.Linear(2, 1)
        
    def forward(self, x):
        attn_output, _ = self.attn(x, x, x)
        attn_output = attn_output.mean(dim = 1)
        output = self.fc(attn_output).squeeze(-1) 
        
        output = F.softmax(output, dim = 0) * self.desired_sum

        return output

In [8]:
class MLP_Enhance(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=28, hidden_dim=400):
        super(MLP_Enhance, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.last = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.Softmax(dim=1)
        )
        self.init_weights()

    def init_weights(self):
        for layer in self.linear:
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                init.constant_(layer.bias, 0)
        last_linear = self.last[0]
        init.xavier_normal_(last_linear.weight)
        init.constant_(last_linear.bias, 0)


    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [9]:
def cal_acc(model, dataloader, device):
    model.eval()
    correct, total = 0., 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [10]:
def plot_acc_history(all_accuracies):
    plt.figure(figsize=(18, 12))

    for i, task_accuracies in enumerate(all_accuracies):
        plt.subplot(2, 3, i+1)
        for task, accs in task_accuracies.items():
            plt.plot(accs,'-', label=f'Task {task}')
        plt.title(f'Accuracy per Mini-Batch for Task {i}')
        plt.xlabel('Mini-Batch Number')
        plt.ylabel('Accuracy (%)')
        plt.ylim(-5, 105)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()


def plot_task_acc_history(all_accuracies, save_name, epochs):
    for task in [0, 1, 2, 3, 4]:
        if (task == 0):
            task_history = []
        else:
            task_history =  [0] * (epochs * task)
        for i, task_accuracies in enumerate (all_accuracies):
            if (i >= task):
                task_history.extend(task_accuracies[task])
                    
        plt.plot(task_history, '-', label=f'Task {task}')

    plt.xlabel('Mini-Batch Number')
    plt.ylabel('Accuracy (%)')
    plt.ylim(-5, 105)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_name) 
    plt.show()

# Train Function

In [11]:
def pretrain_meta(meta_model, optimizer_M, task_num):
    target = torch.full((task_num, 1, 1), 2.0).to(device)
    criterion = nn.MSELoss()

    print("Pretrain on meta_model")
    for epoch in range(400):
        prev_outputs = None
        for k in range (task_num):
            past_output = F.softmax(torch.rand(batch_size, 10), dim=-1).to(device)
            outputs = F.softmax(torch.rand(batch_size, 10), dim=-1).to(device)
            diff = past_output[:,k*2:k * 2 + 2] - outputs[:,k*2:k * 2 + 2]
            diff = diff.unsqueeze(0)
            if (prev_outputs == None):
                prev_outputs = diff
            else:
                prev_outputs = torch.cat((prev_outputs, diff), dim=0)

        meta_out = meta_model(prev_outputs).view(-1, 1, 1 )

        meta_model.train()
        optimizer_M.zero_grad()
        loss = criterion(meta_out, target)
        loss.backward()
        optimizer_M.step()
        
        if((epoch + 1) % 100 == 0):
            print(f'Epoch [{epoch+1}/{400}], Loss: {loss.item():.4f}')

In [12]:
task_num = 1
meta_model = DeepModelWithAttention(desired_sum = 2 * task_num).to(device)
optimizer_M = optim.Adam(meta_model.parameters(), lr=0.001)
pretrain_meta(meta_model, optimizer_M, task_num)

Pretrain on meta_model
Epoch [100/400], Loss: 0.0000
Epoch [200/400], Loss: 0.0000
Epoch [300/400], Loss: 0.0000
Epoch [400/400], Loss: 0.0000


In [31]:
def train(model, task_num, criterion, optimizer, prev_models, task_dataloaders, epoches = 4):
    train_loader = task_dataloaders[task_num][0]
    task_accuracies = {task: [] for task in range(task_num + 1)}
    distribution_factor  = 20

    if (task_num != 0):
        meta_model = DeepModelWithAttention(desired_sum = 2 * task_num).to(device)
        optimizer_M = optim.Adam(meta_model.parameters(), lr=0.001)
        pretrain_meta(meta_model, optimizer_M, task_num)
    
    valid_out_dim = task_num * 2 + 2
    logging.info(f"##########Task {task_num}##########")
    for e in range(epoches):
        logging.info(f"Epoch {e}")
        batch_num = 0
        for images, labels in train_loader:
            model.train()
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            random_data = torch.rand(256, 1, 28, 28).to(device)
            random_data = (random_data - 0.5) / 0.5
            random_target = torch.randint(0, 10, (256,), dtype=torch.int64).to(device)
            fake_output = model(random_data)
            dis_loss = distribution_factor * criterion(fake_output, random_target)

            reg_loss = 0
            prev_outputs = None
            if (task_num != 0):
                meta_model.eval()
                for k, prev_mod in enumerate(prev_models):
                    with torch.no_grad():
                        past_output = prev_mod(images)
                    
                    reg_loss += 2 * torch.mean(torch.abs(past_output[:,k*2:k * 2 + 2] - outputs[:,k*2:k * 2 + 2]))
                    '''
                    diff = diff.unsqueeze(0)
                    if (prev_outputs == None):
                        prev_outputs = diff
                    else:
                        prev_outputs = torch.cat((prev_outputs, diff), dim=0)'''

                '''with torch.no_grad():
                    meta_out = meta_model(prev_outputs).view(-1, 1, 1)'''
                
                '''meta_out = 2
                reg_loss = torch.mean(meta_out * prev_outputs)'''
                
                ''' if (batch_num == 0):
                    print(meta_out)'''
                    
            loss = criterion(outputs[:,:valid_out_dim], labels) + dis_loss + reg_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            '''if (task_num != 0):
                meta_model.train()
                prev_outputs = None
                outputs = outputs.detach()
                for k, prev_mod in enumerate(prev_models):
                    with torch.no_grad():
                        past_output = prev_mod(images)
                    
                    diff = torch.abs(past_output[:,k*2:k * 2 + 2] - outputs[:,k*2:k * 2 + 2])
                    diff = diff.unsqueeze(0)
                    if (prev_outputs == None):
                        prev_outputs = diff
                    else:
                        prev_outputs = torch.cat((prev_outputs, diff), dim=0)

                meta_out = meta_model(prev_outputs).view(-1, 1, 1)
                meta_loss = torch.mean(meta_out * prev_outputs)

                optimizer_M.zero_grad()
                meta_loss.backward()
                optimizer_M.step()'''

            batch_num += 1

        adjust_learning_rate(optimizer)

        avg_acc = 0
        log_message = []
        log_message.append(f"Epoch: {e}, ")
        for task in range(task_num + 1):
            acc = cal_acc(model, task_dataloaders[task][0], device)
            avg_acc += acc
            task_accuracies[task].append(acc * 100)
            log_message.append(f"Task {task} acc: {acc * 100:.4f}")

        logging.info(', '.join(log_message) + f", Task avg acc:{avg_acc*100/(task_num + 1):.4f}")
        if(task_num != 0):
            logging.info(f'loss:{loss.item()}, reg_loss:{reg_loss:.4f}')
    
    return model, task_accuracies

# Train

In [32]:
def setup_logging(file_name):
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(file_name),
                            logging.StreamHandler()
                        ],
                        force=True)
    

def train_split_task(task_name, task_dataloaders):
    print(f"Training on {task_name}")
    prev_models = []
    model = MLP_Enhance()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.005)
    criterion = nn.CrossEntropyLoss()

    epochs, lr = 3, 0.005
    file_name = f"{output_dir}/{task_name}_epochs={epochs}"
    log_file_name = file_name + '.log'
    img_file_name = file_name + '.png'

    setup_logging(log_file_name)
    all_accuracies = []
    for task in [0, 1, 2, 3, 4]:
        model, task_accuracies = train(model, task, criterion, optimizer, prev_models, task_dataloaders, epoches = 4)
        all_accuracies.append(task_accuracies)
        prev_models.append(copy.deepcopy(model))

    plot_task_acc_history(all_accuracies, img_file_name, epochs)

In [33]:
train_split_task("0", task_dataloaders)

2024-06-02 01:06:58,342 - INFO - ##########Task 0##########
2024-06-02 01:06:58,343 - INFO - Epoch 0


Training on 0


2024-06-02 01:07:03,831 - INFO - Epoch: 0, , Task 0 acc: 95.7350, Task avg acc:95.7350
2024-06-02 01:07:03,832 - INFO - Epoch 1
2024-06-02 01:07:09,110 - INFO - Epoch: 1, , Task 0 acc: 95.3922, Task avg acc:95.3922
2024-06-02 01:07:09,111 - INFO - Epoch 2
2024-06-02 01:07:14,397 - INFO - Epoch: 2, , Task 0 acc: 94.2044, Task avg acc:94.2044
2024-06-02 01:07:14,398 - INFO - Epoch 3
2024-06-02 01:07:19,800 - INFO - Epoch: 3, , Task 0 acc: 97.7997, Task avg acc:97.7997


Pretrain on meta_model
Epoch [100/400], Loss: 0.0000
Epoch [200/400], Loss: 0.0000
Epoch [300/400], Loss: 0.0000


2024-06-02 01:07:20,852 - INFO - ##########Task 1##########
2024-06-02 01:07:20,853 - INFO - Epoch 0


Epoch [400/400], Loss: 0.0000


2024-06-02 01:07:28,401 - INFO - Epoch: 0, , Task 0 acc: 65.9279, Task 1 acc: 26.9365, Task avg acc:46.4322
2024-06-02 01:07:28,402 - INFO - loss:47.443511962890625, reg_loss:0.2794
2024-06-02 01:07:28,404 - INFO - Epoch 1
2024-06-02 01:07:36,020 - INFO - Epoch: 1, , Task 0 acc: 73.5730, Task 1 acc: 41.3979, Task avg acc:57.4855
2024-06-02 01:07:36,021 - INFO - loss:47.53248596191406, reg_loss:0.3147
2024-06-02 01:07:36,022 - INFO - Epoch 2
2024-06-02 01:07:43,572 - INFO - Epoch: 2, , Task 0 acc: 74.6811, Task 1 acc: 42.7443, Task avg acc:58.7127
2024-06-02 01:07:43,573 - INFO - loss:47.54520034790039, reg_loss:0.3417
2024-06-02 01:07:43,574 - INFO - Epoch 3
2024-06-02 01:07:51,170 - INFO - Epoch: 3, , Task 0 acc: 77.9496, Task 1 acc: 47.3737, Task avg acc:62.6616
2024-06-02 01:07:51,171 - INFO - loss:47.32173538208008, reg_loss:0.3405


Pretrain on meta_model
Epoch [100/400], Loss: 0.0000
Epoch [200/400], Loss: 0.0000
Epoch [300/400], Loss: 0.0000


2024-06-02 01:07:52,452 - INFO - ##########Task 2##########
2024-06-02 01:07:52,453 - INFO - Epoch 0


Epoch [400/400], Loss: 0.0000


2024-06-02 01:08:01,943 - INFO - Epoch: 0, , Task 0 acc: 61.1129, Task 1 acc: 10.2311, Task 2 acc: 17.1336, Task avg acc:29.4925
2024-06-02 01:08:01,944 - INFO - loss:48.25156784057617, reg_loss:0.4009
2024-06-02 01:08:01,945 - INFO - Epoch 1
2024-06-02 01:08:11,648 - INFO - Epoch: 1, , Task 0 acc: 65.7765, Task 1 acc: 5.9840, Task 2 acc: 23.9045, Task avg acc:31.8883
2024-06-02 01:08:11,650 - INFO - loss:48.185611724853516, reg_loss:0.4076
2024-06-02 01:08:11,651 - INFO - Epoch 2
2024-06-02 01:08:22,326 - INFO - Epoch: 2, , Task 0 acc: 64.3654, Task 1 acc: 6.1503, Task 2 acc: 31.5104, Task avg acc:34.0087
2024-06-02 01:08:22,328 - INFO - loss:47.86362075805664, reg_loss:0.4387
2024-06-02 01:08:22,328 - INFO - Epoch 3
2024-06-02 01:08:32,511 - INFO - Epoch: 3, , Task 0 acc: 64.2538, Task 1 acc: 6.9481, Task 2 acc: 32.7856, Task avg acc:34.6625
2024-06-02 01:08:32,513 - INFO - loss:48.04780960083008, reg_loss:0.4380


Pretrain on meta_model
Epoch [100/400], Loss: 0.0000
Epoch [200/400], Loss: 0.0000
Epoch [300/400], Loss: 0.0000


2024-06-02 01:08:34,123 - INFO - ##########Task 3##########
2024-06-02 01:08:34,123 - INFO - Epoch 0


Epoch [400/400], Loss: 0.0000


2024-06-02 01:08:46,591 - INFO - Epoch: 0, , Task 0 acc: 70.6314, Task 1 acc: 11.4777, Task 2 acc: 23.9404, Task 3 acc: 3.0263, Task avg acc:27.2689
2024-06-02 01:08:46,592 - INFO - loss:48.56460189819336, reg_loss:0.4033
2024-06-02 01:08:46,594 - INFO - Epoch 1


KeyboardInterrupt: 