In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import os
import torch.nn.init as init
from torchvision import datasets, transforms
from torch import optim, nn
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import logging
import math
import copy

In [2]:
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
output_dir = './output'

In [3]:
def adjust_learning_rate(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.75 * param_group['lr']

# Dataloader

In [4]:
tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

In [5]:
def task_dataloader(task_num):
    task_dir = {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7], 4:[8, 9]}

    train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=tf)

    indices = [i for i, (_, label) in enumerate(train_dataset) if label in task_dir[task_num]]
    dataset = Subset(train_dataset, indices)
    task_train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    indices = [i for i, (_, label) in enumerate(test_dataset) if label in task_dir[task_num]]
    dataset = Subset(test_dataset, indices)
    task_test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return task_train_loader, task_test_loader

def load_all_data():
    train_dataset = datasets.MNIST(root='./data', train=True, transform=tf, download=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, test_loader

In [6]:
train_loader, test_loader = load_all_data()
task_dataloaders = {}
for task in range(0, 5):
    train_dl, test_ld = task_dataloader(task)
    task_dataloaders[task] = (train_dl, test_ld)

# Model

In [19]:
class DeepModelWithAttention(nn.Module):
    def __init__(self, input_size, desired_sum):
        super(DeepModelWithAttention, self).__init__()
        self.desired_sum = desired_sum

        self.attn = nn.MultiheadAttention(embed_dim=input_size, num_heads= 1)
        self.fc = nn.Linear(input_size, 1)
        
    def forward(self, x):
        attn_output, _ = self.attn(x, x, x)
        attn_output = attn_output.mean(dim = 1)
        output = self.fc(attn_output).squeeze(-1) * self.desired_sum

        return output

In [20]:
class MLP_Enhance(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=28, hidden_dim=400):
        super(MLP_Enhance, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.last = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.Softmax(dim=1)
        )
        self.init_weights()

    def init_weights(self):
        for layer in self.linear:
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                init.constant_(layer.bias, 0)
        last_linear = self.last[0]
        init.xavier_normal_(last_linear.weight)
        init.constant_(last_linear.bias, 0)


    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [21]:
def cal_acc(model, dataloader, device):
    model.eval()
    correct, total = 0., 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [22]:
def plot_acc_history(all_accuracies):
    plt.figure(figsize=(18, 12))

    for i, task_accuracies in enumerate(all_accuracies):
        plt.subplot(2, 3, i+1)
        for task, accs in task_accuracies.items():
            plt.plot(accs,'-', label=f'Task {task}')
        plt.title(f'Accuracy per Mini-Batch for Task {i}')
        plt.xlabel('Mini-Batch Number')
        plt.ylabel('Accuracy (%)')
        plt.ylim(-5, 105)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()


def plot_task_acc_history(all_accuracies, save_name, epochs):
    for task in [0, 1, 2, 3, 4]:
        if (task == 0):
            task_history = []
        else:
            task_history =  [0] * (epochs * task)
        for i, task_accuracies in enumerate (all_accuracies):
            if (i >= task):
                task_history.extend(task_accuracies[task])
                    
        plt.plot(task_history, '-', label=f'Task {task}')

    plt.xlabel('Mini-Batch Number')
    plt.ylabel('Accuracy (%)')
    plt.ylim(-5, 105)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_name) 
    plt.show()

# Train Function

In [27]:
import random
def pretrain_meta(meta_model, optimizer_M):
    target = torch.full((5, 1, 1), 2.0).to(device)
    criterion = nn.MSELoss()

    print("Pretrain on meta_model")
    num_epochs = 50
    for epoch in range(num_epochs):
        prev_outputs = None
        for k in range (5):
            past_output = F.softmax(torch.rand(batch_size, 10), dim=-1).to(device)
            outputs = F.softmax(torch.rand(batch_size, 10), dim=-1).to(device)
            diff = past_output[:,k*2:k * 2 + 2] - outputs[:,k*2:k * 2 + 2]
            diff = diff.unsqueeze(0)
            if (prev_outputs == None):
                prev_outputs = diff
            else:
                prev_outputs = torch.cat((prev_outputs, diff), dim=0)

        meta_out = meta_model(prev_outputs).view(-1, 1, 1)
        meta_model.train()
        optimizer_M.zero_grad()
        loss = criterion(meta_out, target)
        loss.backward()
        optimizer_M.step()
        
        if((epoch + 1) % 10 == 0):
            print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

In [31]:
def train(model, task_num, criterion, optimizer, prev_models, task_dataloaders, epoches = 4):
    train_loader = task_dataloaders[task_num][0]
    task_accuracies = {task: [] for task in range(task_num + 1)}
    distribution_factor  = 20

    if (task_num != 0):
        meta_model = DeepModelWithAttention(desired_sum = 2, input_size= task_num).to(device)
        optimizer_M = optim.Adam(meta_model.parameters(), lr=0.001)
        pretrain_meta(meta_model, optimizer_M)
    
    valid_out_dim = task_num * 2 + 2
    logging.info(f"##########Task {task_num}##########")
    for e in range(epoches):
        logging.info(f"Epoch {e}")
        batch_num = 0
        for images, labels in train_loader:
            model.train()
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            random_data = torch.rand(256, 1, 28, 28).to(device)
            random_data = (random_data - 0.5) / 0.5
            random_target = torch.randint(0, 10, (256,), dtype=torch.int64).to(device)
            fake_output = model(random_data)
            dis_loss = distribution_factor * criterion(fake_output, random_target)

            reg_loss = 0
            prev_outputs = None
            if (task_num != 0):
                meta_model.eval()
                for k, prev_mod in enumerate(prev_models):
                    with torch.no_grad():
                        past_output = prev_mod(images)
                    
                    diff = torch.abs(past_output[:,k*2:k * 2 + 2] - outputs[:,k*2:k * 2 + 2])
                    diff = diff.unsqueeze(0)
                    if (prev_outputs == None):
                        prev_outputs = diff
                    else:
                        prev_outputs = torch.cat((prev_outputs, diff), dim=0)
                reg_loss = 0
                if(prev_outputs is not None):
                    with torch.no_grad():
                        meta_out = meta_model(prev_outputs).view(-1, 1, 1)
                    reg_loss = torch.mean(meta_out * diff)
                
                    if (batch_num % 20 == 0):
                        print(meta_out)
                    
            loss = criterion(outputs[:,:valid_out_dim], labels) + dis_loss + reg_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if (task_num != 0):
                meta_model.train()
                prev_outputs = None
                outputs = outputs.detach()
                for k, prev_mod in enumerate(prev_models):
                    with torch.no_grad():
                        past_output = prev_mod(images)
                    
                    diff = torch.abs(past_output[:,k*2:k * 2 + 2] - outputs[:,k*2:k * 2 + 2])
                    diff = diff.unsqueeze(0)
                    if (prev_outputs == None):
                        prev_outputs = diff
                    else:
                        prev_outputs = torch.cat((prev_outputs, diff), dim=0)
                reg_loss = 0
                if(prev_outputs is not None):
                    meta_out = meta_model(prev_outputs).view(-1, 1, 1)
                    reg_loss = torch.mean(meta_out * diff)

                optimizer_M.zero_grad()
                reg_loss.backward()
                optimizer_M.step()

                batch_num += 1

        adjust_learning_rate(optimizer)

        avg_acc = 0
        log_message = []
        log_message.append(f"Epoch: {e}, ")
        for task in range(task_num + 1):
            acc = cal_acc(model, task_dataloaders[task][0], device)
            avg_acc += acc
            task_accuracies[task].append(acc * 100)
            log_message.append(f"Task {task} acc: {acc * 100:.4f}")

        logging.info(', '.join(log_message) + f", Task avg acc:{avg_acc*100/(task_num + 1):.4f}")
    
    return model, task_accuracies

# Train

In [32]:
def setup_logging(file_name):
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(file_name),
                            logging.StreamHandler()
                        ],
                        force=True)
    

def train_split_task(task_name, task_dataloaders):
    print(f"Training on {task_name}")
    prev_models = []
    model = MLP_Enhance()
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr = 0.005)
    criterion = nn.CrossEntropyLoss()

    epochs, lr = 3, 0.005
    file_name = f"{output_dir}/{task_name}_epochs={epochs}"
    log_file_name = file_name + '.log'
    img_file_name = file_name + '.png'

    setup_logging(log_file_name)
    all_accuracies = []
    for task in [0, 1, 2, 3, 4]:
        model, task_accuracies = train(model, task, criterion, optimizer, prev_models, task_dataloaders, epoches = 4)
        all_accuracies.append(task_accuracies)
        prev_models.append(copy.deepcopy(model))

    plot_task_acc_history(all_accuracies, img_file_name, epochs)

In [33]:
train_split_task("0", task_dataloaders)

2024-06-01 15:28:53,426 - INFO - ##########Task 0##########
2024-06-01 15:28:53,427 - INFO - Epoch 0


Training on 0


2024-06-01 15:28:58,963 - INFO - Epoch: 0, , Task 0 acc: 43.7978, Task avg acc:43.7978
2024-06-01 15:28:58,964 - INFO - Epoch 1
2024-06-01 15:29:04,015 - INFO - Epoch: 1, , Task 0 acc: 93.0644, Task avg acc:93.0644
2024-06-01 15:29:04,015 - INFO - Epoch 2
2024-06-01 15:29:09,069 - INFO - Epoch: 2, , Task 0 acc: 96.0539, Task avg acc:96.0539
2024-06-01 15:29:09,070 - INFO - Epoch 3
2024-06-01 15:29:14,097 - INFO - Epoch: 3, , Task 0 acc: 96.7873, Task avg acc:96.7873


UnboundLocalError: local variable 'meta_model' referenced before assignment