In [98]:
import torch
import os
import torch.nn.init as init
from torchvision import datasets, transforms
from torch import optim, nn
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import logging

In [99]:
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [100]:
train_settings = {
    "Baseline": {"epoches": 4, "lr": 0.001, "l2": None, "loss_type": None},
    "Baseline": {"epoches": 4, "lr": 0.001, "l2": 1e-4, "loss_type": None},
    "Distribution": {"epoches": 4, "lr": 0.001, "l2": None, "loss_type": "distribution"},
    "FollowLeader": {"epoches": 4, "lr": 0.001, "l2": None, "loss_type": "follow_leader"},
    "Interupt": {"epoches": 4, "lr": 0.001, "l2": None, "loss_type": "Interupt"},
    "Interupt": {"epoches": 4, "lr": 0.0005, "l2": None, "loss_type": "Interupt"},
}

## Data Loader

In [101]:
tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

In [102]:
def task_dataloader(task_num):
    task_dir = {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7], 4:[8, 9]}

    train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=tf)

    indices = [i for i, (_, label) in enumerate(train_dataset) if label in task_dir[task_num]]
    dataset = Subset(train_dataset, indices)
    task_train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    indices = [i for i, (_, label) in enumerate(test_dataset) if label in task_dir[task_num]]
    dataset = Subset(test_dataset, indices)
    task_test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return task_train_loader, task_test_loader

def load_all_data():
    train_dataset = datasets.MNIST(root='./data', train=True, transform=tf, download=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, test_loader

## Model

In [103]:
def adjust_learning_rate(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.75 * param_group['lr']

In [104]:
class MLP_Enhance(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=28, hidden_dim=400):
        super(MLP_Enhance, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )
        self.last = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.Softmax(dim=1)
        )
        self.init_weights()

    def init_weights(self):
        for layer in self.linear:
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                init.constant_(layer.bias, 0)
        last_linear = self.last[0]
        init.xavier_normal_(last_linear.weight)
        init.constant_(last_linear.bias, 0)


    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [105]:
train_loader, test_loader = load_all_data()
task_dataloaders = {}
for task in range(0, 5):
    train_dl, test_ld = task_dataloader(task)
    task_dataloaders[task] = (train_dl, test_ld)

# Eval

In [106]:
def cal_acc(model, dataloader, device):
    model.eval()
    correct, total = 0., 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [107]:
def plot_acc_history(all_accuracies):
    plt.figure(figsize=(18, 12))

    for i, task_accuracies in enumerate(all_accuracies):
        plt.subplot(2, 3, i+1)
        for task, accs in task_accuracies.items():
            plt.plot(accs,'-', label=f'Task {task}')
        plt.title(f'Accuracy per Mini-Batch for Task {i}')
        plt.xlabel('Mini-Batch Number')
        plt.ylabel('Accuracy (%)')
        plt.ylim(-5, 105)
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.show()


def plot_task_acc_history(all_accuracies, save_name):
    for task in [0, 1, 2, 3, 4]:
        if (task == 0):
            task_history = []
        else:
            task_history =  [0] * (50 * task)
        for i, task_accuracies in enumerate (all_accuracies):
            if (i >= task):
                task_history.extend(task_accuracies[task])
                    
        plt.plot(task_history, '-', label=f'Task {task}')

    plt.title(f'Accuracy per Mini-Batch for base line')
    plt.xlabel('Mini-Batch Number')
    plt.ylabel('Accuracy (%)')
    plt.ylim(-5, 105)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(save_name) 

# Train

In [111]:
def train(model, task_num, criterion, epoches = 4, lr = 0.001, l2 = None, loss_type = None):
    follower = MLP_Enhance()
    follower = follower.to(device)
    follower.load_state_dict(model.state_dict())
    criterion = nn.CrossEntropyLoss()
    if l2 is not None:
        optimizer_F = torch.optim.Adam(follower.parameters(), lr)
    else:
        optimizer_F = torch.optim.Adam(follower.parameters(), lr, weight_decay = l2)
        
    train_loader = task_dataloaders[task_num][0]

    task_accuracies = {task: [] for task in range(task_num + 1)}

    valid_out_dim = task_num * 2 + 2
    logging.info(f"##########Task {task_num}##########")
    for e in range(epoches):
        logging.info(f"Epoch {e}")
        batch_num = 0
        for images, labels in train_loader:
            follower.train()
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            reg_loss = 0
            if (loss_type == "distribution"):
                except_mean = 0.1
                for task in range (10):
                    task_start = task
                    task_end = (task + 1)

                    fake_image = torch.randn(batch_size, 1, 28, 28).to(device)
                    fake_output = follower(fake_image)

                    real_mean = torch.mean(torch.abs(fake_output[:,task_start:task_end]))
                    reg_loss += torch.abs(except_mean - real_mean)
            elif (loss_type == "follow_leader"):
                with torch.no_grad():
                    leader_output = model(images)
                if (task_num != 0):
                    reg_loss =  2 * torch.mean(torch.abs(leader_output[:,:valid_out_dim - 2] - outputs[:,:valid_out_dim - 2]))
            elif (loss_type == "interupt"):
                for _ in range (4):
                    fake_image = torch.randn(batch_size, 1, 28, 28).to(device)

                    with torch.no_grad():
                        leader_fake_output = model(fake_image)
                    
                    fake_outputs = follower(images)

                    if (task_num != 0):
                        reg_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:valid_out_dim], labels) + reg_loss
            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            if(batch_num % 10 == 0):
                avg_acc = 0
                log_message = []
                for task in range(task_num + 1):
                    acc = cal_acc(follower, task_dataloaders[task][0], device)
                    avg_acc += acc
                    task_accuracies[task].append(acc * 100)
                    log_message.append(f"Batch num: {batch_num}, Task {task} acc: {acc * 100:.4f}")

                logging.info(', '.join(log_message) + f", Task avg acc:{avg_acc*100/(task_num + 1):.4f}")

            batch_num += 1

        adjust_learning_rate(optimizer_F)
    
    return follower, task_accuracies

In [112]:
def setup_logging(task_name, epochs, lr, l2):
    log_filename = f"{task_name}_{epochs}_lr={lr}_l2={l2}.log"
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s',
                        handlers=[
                            logging.FileHandler(log_filename),
                            logging.StreamHandler()
                        ],
                        force=True)
    
def train_split_task(Task_Name):
    model = MLP_Enhance()
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    settings = train_settings[Task_Name]
    epochs, lr, l2, loss_type = settings['epochs'], settings['lr'], settings['l2'], settings['loss_type']

    if os.path.exists(f"{Task_Name}_{epochs}_lr={lr}_l2={l2}.log") == False:
        setup_logging(Task_Name, epochs, lr, l2)
        all_accuracies = []
        for task in [0, 1, 2, 3, 4]:
            model, task_accuracies = train(model, task, criterion, epoches = epochs, lr = lr, l2 = l2, loss_type = loss_type)
            all_accuracies.append(task_accuracies)

    plot_task_acc_history(all_accuracies, f"{Task_Name}_{epochs}_{lr}_{l2}.png")

# Experiment

In [113]:
for key in train_settings.keys():
    train_split_task(key)

2024-05-25 22:52:51,167 - INFO - ##########Task 0##########
2024-05-25 22:52:51,181 - INFO - Epoch 0
2024-05-25 22:52:54,022 - INFO - Batch num: 0, Task 0 acc: 3.6990, Task avg acc:3.6990
2024-05-25 22:52:56,214 - INFO - Batch num: 10, Task 0 acc: 98.1904, Task avg acc:98.1904
2024-05-25 22:52:58,938 - INFO - Batch num: 20, Task 0 acc: 99.6572, Task avg acc:99.6572
2024-05-25 22:53:01,332 - INFO - Batch num: 30, Task 0 acc: 99.7768, Task avg acc:99.7768
2024-05-25 22:53:03,742 - INFO - Batch num: 40, Task 0 acc: 99.8087, Task avg acc:99.8087
2024-05-25 22:53:05,882 - INFO - Batch num: 50, Task 0 acc: 99.8565, Task avg acc:99.8565
2024-05-25 22:53:08,172 - INFO - Batch num: 60, Task 0 acc: 99.8884, Task avg acc:99.8884
2024-05-25 22:53:10,156 - INFO - Batch num: 70, Task 0 acc: 99.9123, Task avg acc:99.9123
2024-05-25 22:53:12,820 - INFO - Batch num: 80, Task 0 acc: 99.9203, Task avg acc:99.9203
2024-05-25 22:53:15,358 - INFO - Batch num: 90, Task 0 acc: 99.9123, Task avg acc:99.9123
20

KeyboardInterrupt: 