In [2]:
import torch
import os
from torchvision import datasets, transforms
from torch import optim, nn
from torch.utils.data import DataLoader, Subset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
num_epochs = 4
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

In [5]:
def task_dataloader(task_num):
    task_dir = {0: [0, 1], 1: [2, 3], 2: [4, 5], 3: [6, 7], 4:[8, 9]}

    train_dataset = datasets.MNIST(root='data', train=True, download=True, transform=tf)

    indices = [i for i, (_, label) in enumerate(train_dataset) if label in task_dir[task_num]]
    dataset = Subset(train_dataset, indices)
    task_train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    indices = [i for i, (_, label) in enumerate(test_dataset) if label in task_dir[task_num]]
    dataset = Subset(test_dataset, indices)
    task_test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return task_train_loader, task_test_loader

def load_all_data():
    train_dataset = datasets.MNIST(root='./data', train=True, transform=tf, download=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    test_dataset = datasets.MNIST(root='./data', train=False, transform=tf, download=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, test_loader

In [6]:
train_loader, test_loader = load_all_data()

In [7]:
class MLP(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=28, hidden_dim=256):
        super(MLP, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        )
        self.last = nn.Linear(hidden_dim, out_dim)

    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [8]:
def cal_acc(model, dataloader, device):
    model.eval()
    correct, total = 0., 0.
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct / total

In [9]:
def train(criterion, optimizer, model, num_epochs, trainloader, valloader, valid_out_dim, device):
    for epoch in range(num_epochs):
        model.train()
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs[:,:valid_out_dim], labels)
            loss.backward()
            optimizer.step()

        test_acc = cal_acc(model, valloader, device)

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}, Test Acc: {test_acc:.4f}")

In [10]:
task_0_train_loader, task_0_test_loader = task_dataloader(0)

In [11]:
model = MLP()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 0.0005)

if not os.path.exists('task_1.pth'):
    train(criterion, optimizer, model, num_epochs, task_0_train_loader, test_loader, 2, device)

    torch.save(model.state_dict(), 'task_1.pth')
else:
    model.load_state_dict(torch.load('task_1.pth'))

In [12]:
task_1_train_loader, task_1_test_loader = task_dataloader(1)

# No Method

In [26]:
batch_num = 0

try:
    for images, labels in task_1_train_loader:
        model.train()
        
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs[:,:4], labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        task_0_acc = cal_acc(model, task_0_test_loader, device)
        task_1_acc = cal_acc(model, task_1_test_loader, device)
        task_all_acc = cal_acc(model, test_loader, device)

        print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f}, task_all_acc:{task_all_acc:.4f}")

        batch_num += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0172, task_all_acc:0.2192
Batch: 1, task_0_acc: 0.9985, task_1_acc: 0.0479, task_all_acc:0.2276
Batch: 2, task_0_acc: 0.9971, task_1_acc: 0.1000, task_all_acc:0.2397
Batch: 3, task_0_acc: 0.9941, task_1_acc: 0.1786, task_all_acc:0.2512
Batch: 4, task_0_acc: 0.9614, task_1_acc: 0.2578, task_all_acc:0.2594
Batch: 5, task_0_acc: 0.7803, task_1_acc: 0.3224, task_all_acc:0.2336
Batch: 6, task_0_acc: 0.5137, task_1_acc: 0.3885, task_all_acc:0.1914
Batch: 7, task_0_acc: 0.4312, task_1_acc: 0.4604, task_all_acc:0.1870
Train interrupt by keyboard


# reg to output

## Method 1

In [31]:
batch_num = 0

try:
    follower = MLP()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)

    for images, labels in task_1_train_loader:
        follower.train()
        
        images, labels = images.to(device), labels.to(device)
        outputs = follower(images)

        with torch.no_grad():
            leader_output = model(images)

        loss = criterion(outputs[:,:4], labels) + torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) 

        optimizer_F.zero_grad()
        loss.backward()
        optimizer_F.step()

        task_0_acc = cal_acc(follower, task_0_test_loader, device)
        task_1_acc = cal_acc(follower, task_1_test_loader, device)
        task_all_acc = cal_acc(follower, test_loader, device)

        print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f}, task_all_acc:{task_all_acc:.4f}")

        batch_num += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0172, task_all_acc:0.2192
Batch: 1, task_0_acc: 0.9980, task_1_acc: 0.0391, task_all_acc:0.2264
Batch: 2, task_0_acc: 0.9980, task_1_acc: 0.0880, task_all_acc:0.2365
Batch: 3, task_0_acc: 0.9980, task_1_acc: 0.1516, task_all_acc:0.2485
Batch: 4, task_0_acc: 0.9980, task_1_acc: 0.2099, task_all_acc:0.2591
Batch: 5, task_0_acc: 0.9976, task_1_acc: 0.2687, task_all_acc:0.2698
Batch: 6, task_0_acc: 0.9976, task_1_acc: 0.3401, task_all_acc:0.2831
Batch: 7, task_0_acc: 0.9971, task_1_acc: 0.4260, task_all_acc:0.2993
Batch: 8, task_0_acc: 0.9971, task_1_acc: 0.4979, task_all_acc:0.3137
Batch: 9, task_0_acc: 0.9976, task_1_acc: 0.5703, task_all_acc:0.3283
Batch: 10, task_0_acc: 0.9946, task_1_acc: 0.6396, task_all_acc:0.3410
Batch: 11, task_0_acc: 0.9907, task_1_acc: 0.7000, task_all_acc:0.3521
Batch: 12, task_0_acc: 0.9854, task_1_acc: 0.7375, task_all_acc:0.3588
Batch: 13, task_0_acc: 0.9731, task_1_acc: 0.7656, task_all_acc:0.3620
Batch: 14, task_

## Method 2

In [13]:
batch_num = 0

try:
    follower = MLP()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)

    for images, labels in task_1_train_loader:
        follower.train()
        
        images, labels = images.to(device), labels.to(device)
        outputs = follower(images)

        with torch.no_grad():
            leader_output = model(images)

        loss = criterion(outputs[:,:4], labels) + 10 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) 

        optimizer_F.zero_grad()
        loss.backward()
        optimizer_F.step()

        task_0_acc = cal_acc(follower, task_0_test_loader, device)
        task_1_acc = cal_acc(follower, task_1_test_loader, device)
        task_all_acc = cal_acc(follower, test_loader, device)

        print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f}, task_all_acc:{task_all_acc:.4f}")

        batch_num += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0172, task_all_acc:0.2192
Batch: 1, task_0_acc: 0.9980, task_1_acc: 0.0198, task_all_acc:0.2206
Batch: 2, task_0_acc: 0.9980, task_1_acc: 0.0203, task_all_acc:0.2206
Batch: 3, task_0_acc: 0.9990, task_1_acc: 0.0208, task_all_acc:0.2214
Batch: 4, task_0_acc: 0.9985, task_1_acc: 0.0312, task_all_acc:0.2237
Batch: 5, task_0_acc: 0.9985, task_1_acc: 0.0380, task_all_acc:0.2248
Batch: 6, task_0_acc: 0.9985, task_1_acc: 0.0505, task_all_acc:0.2266
Batch: 7, task_0_acc: 0.9985, task_1_acc: 0.0599, task_all_acc:0.2285
Batch: 8, task_0_acc: 0.9985, task_1_acc: 0.0729, task_all_acc:0.2300
Batch: 9, task_0_acc: 0.9985, task_1_acc: 0.0833, task_all_acc:0.2309
Batch: 10, task_0_acc: 0.9985, task_1_acc: 0.1005, task_all_acc:0.2334
Batch: 11, task_0_acc: 0.9985, task_1_acc: 0.1177, task_all_acc:0.2368
Batch: 12, task_0_acc: 0.9985, task_1_acc: 0.1323, task_all_acc:0.2388
Batch: 13, task_0_acc: 0.9985, task_1_acc: 0.1432, task_all_acc:0.2410
Batch: 14, task_

## Method 3

In [14]:
batch_num = 0

try:
    follower = MLP()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)

    for images, labels in task_1_train_loader:
        follower.train()
        
        images, labels = images.to(device), labels.to(device)
        outputs = follower(images)

        with torch.no_grad():
            leader_output = model(images)

        loss = criterion(outputs[:,:4], labels) + 20 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) 

        optimizer_F.zero_grad()
        loss.backward()
        optimizer_F.step()

        task_0_acc = cal_acc(follower, task_0_test_loader, device)
        task_1_acc = cal_acc(follower, task_1_test_loader, device)
        task_all_acc = cal_acc(follower, test_loader, device)

        print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f}, task_all_acc:{task_all_acc:.4f}")

        batch_num += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0172, task_all_acc:0.2192
Batch: 1, task_0_acc: 0.9985, task_1_acc: 0.0130, task_all_acc:0.2179
Batch: 2, task_0_acc: 0.9985, task_1_acc: 0.0135, task_all_acc:0.2184
Batch: 3, task_0_acc: 0.9985, task_1_acc: 0.0208, task_all_acc:0.2201
Batch: 4, task_0_acc: 0.9985, task_1_acc: 0.0240, task_all_acc:0.2205
Batch: 5, task_0_acc: 0.9985, task_1_acc: 0.0328, task_all_acc:0.2209
Batch: 6, task_0_acc: 0.9985, task_1_acc: 0.0448, task_all_acc:0.2225
Batch: 7, task_0_acc: 0.9985, task_1_acc: 0.0589, task_all_acc:0.2246
Batch: 8, task_0_acc: 0.9985, task_1_acc: 0.0677, task_all_acc:0.2256
Batch: 9, task_0_acc: 0.9985, task_1_acc: 0.0797, task_all_acc:0.2279
Batch: 10, task_0_acc: 0.9985, task_1_acc: 0.0828, task_all_acc:0.2287
Batch: 11, task_0_acc: 0.9985, task_1_acc: 0.0927, task_all_acc:0.2307
Batch: 12, task_0_acc: 0.9985, task_1_acc: 0.0938, task_all_acc:0.2308
Batch: 13, task_0_acc: 0.9985, task_1_acc: 0.1068, task_all_acc:0.2335
Batch: 14, task_

## Method 4

In [16]:
batch_num = 0

try:
    follower = MLP()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        print(f"Epoch {epoch}")
        for images, labels in task_1_train_loader:
            follower.train()
            
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)

            loss = criterion(outputs[:,:4], labels) + 100 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) 

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            task_0_acc = cal_acc(follower, task_0_test_loader, device)
            task_1_acc = cal_acc(follower, task_1_test_loader, device)
            task_all_acc = cal_acc(follower, test_loader, device)

            print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f}, task_all_acc:{task_all_acc:.4f}")

            batch_num += 1
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0172, task_all_acc:0.2192
Batch: 1, task_0_acc: 0.9985, task_1_acc: 0.0141, task_all_acc:0.2180
Batch: 2, task_0_acc: 0.9985, task_1_acc: 0.0125, task_all_acc:0.2180
Batch: 3, task_0_acc: 0.9990, task_1_acc: 0.0146, task_all_acc:0.2186
Batch: 4, task_0_acc: 0.9990, task_1_acc: 0.0234, task_all_acc:0.2191
Batch: 5, task_0_acc: 0.9990, task_1_acc: 0.0312, task_all_acc:0.2200
Batch: 6, task_0_acc: 0.9990, task_1_acc: 0.0453, task_all_acc:0.2226
Batch: 7, task_0_acc: 0.9990, task_1_acc: 0.0589, task_all_acc:0.2253
Batch: 8, task_0_acc: 0.9990, task_1_acc: 0.0651, task_all_acc:0.2261
Batch: 9, task_0_acc: 0.9985, task_1_acc: 0.0651, task_all_acc:0.2262
Batch: 10, task_0_acc: 0.9985, task_1_acc: 0.0693, task_all_acc:0.2270
Batch: 11, task_0_acc: 0.9985, task_1_acc: 0.0797, task_all_acc:0.2290
Batch: 12, task_0_acc: 0.9985, task_1_acc: 0.0948, task_all_acc:0.2317
Batch: 13, task_0_acc: 0.9985, task_1_acc: 0.1099, task_all_acc:0.2342
Batch: 1

# Random item to reg on output

In [25]:
batch_num = 0

try:
    follower = MLP()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        print(f"Epoch {epoch}")
        for images, labels in task_1_train_loader:
            follower.train()
            
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)


            fake_image = torch.randn(batch_size, 1, 28, 28)

            with torch.no_grad():
                leader_fake_output = model(fake_image)

            fake_outputs = follower(images)

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) \
                                + 5 * torch.mean(torch.abs(leader_fake_output[:,:2] - fake_outputs[:,:2]))

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            task_0_acc = cal_acc(follower, task_0_test_loader, device)
            task_1_acc = cal_acc(follower, task_1_test_loader, device)  
            task_all_acc = cal_acc(follower, test_loader, device)

            print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                        task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Batch: 0, task_0_acc: 0.9990, task_1_acc: 0.0115,                        task_all_acc:0.2183, avg_task_1_2:0.5052
Batch: 1, task_0_acc: 0.9990, task_1_acc: 0.0203,                        task_all_acc:0.2204, avg_task_1_2:0.5097
Batch: 2, task_0_acc: 0.9985, task_1_acc: 0.0339,                        task_all_acc:0.2231, avg_task_1_2:0.5162
Batch: 3, task_0_acc: 0.9990, task_1_acc: 0.0505,                        task_all_acc:0.2269, avg_task_1_2:0.5248
Batch: 4, task_0_acc: 0.9990, task_1_acc: 0.0849,                        task_all_acc:0.2329, avg_task_1_2:0.5420
Batch: 5, task_0_acc: 0.9990, task_1_acc: 0.1203,                        task_all_acc:0.2396, avg_task_1_2:0.5597
Batch: 6, task_0_acc: 0.9985, task_1_acc: 0.1646,                        task_all_acc:0.2477, avg_task_1_2:0.5816
Batch: 7, task_0_acc: 0.9985, task_1_acc: 0.1990,                        task_all_acc:0.2540, avg_task_1_2:0.5987
Batch: 8, task_0_acc: 0.9980, task_1_acc: 0.2422,                        task_al

In [26]:
batch_num = 0

try:
    follower = MLP()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        print(f"Epoch {epoch}")
        for images, labels in task_1_train_loader:
            follower.train()
            
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)


            fake_image = torch.randn(batch_size, 1, 28, 28)

            with torch.no_grad():
                leader_fake_output = model(fake_image)

            fake_outputs = follower(images)

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) \
                                + 5 * torch.mean(torch.abs(leader_fake_output - fake_outputs))

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            task_0_acc = cal_acc(follower, task_0_test_loader, device)
            task_1_acc = cal_acc(follower, task_1_test_loader, device)  
            task_all_acc = cal_acc(follower, test_loader, device)

            print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                        task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0167,                        task_all_acc:0.2190, avg_task_1_2:0.5076
Batch: 1, task_0_acc: 0.9985, task_1_acc: 0.0203,                        task_all_acc:0.2210, avg_task_1_2:0.5094
Batch: 2, task_0_acc: 0.9985, task_1_acc: 0.0281,                        task_all_acc:0.2218, avg_task_1_2:0.5133
Batch: 3, task_0_acc: 0.9985, task_1_acc: 0.0427,                        task_all_acc:0.2242, avg_task_1_2:0.5206
Batch: 4, task_0_acc: 0.9985, task_1_acc: 0.0547,                        task_all_acc:0.2265, avg_task_1_2:0.5266
Batch: 5, task_0_acc: 0.9985, task_1_acc: 0.0604,                        task_all_acc:0.2266, avg_task_1_2:0.5295
Batch: 6, task_0_acc: 0.9985, task_1_acc: 0.0677,                        task_all_acc:0.2283, avg_task_1_2:0.5331
Batch: 7, task_0_acc: 0.9985, task_1_acc: 0.0927,                        task_all_acc:0.2333, avg_task_1_2:0.5456
Batch: 8, task_0_acc: 0.9985, task_1_acc: 0.1036,                        task_al

In [29]:
batch_num = 0

try:
    follower = MLP()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        print(f"Epoch {epoch}")
        batch_num = 0
        for images, labels in task_1_train_loader:
            follower.train()
            
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)


            fake_image = torch.randn(batch_size, 1, 28, 28)

            with torch.no_grad():
                leader_fake_output = model(fake_image)

            fake_outputs = follower(images)

            loss = criterion(outputs[:,:4], labels) + 10 * torch.mean(torch.abs(leader_fake_output - fake_outputs))

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            task_0_acc = cal_acc(follower, task_0_test_loader, device)
            task_1_acc = cal_acc(follower, task_1_test_loader, device)  
            task_all_acc = cal_acc(follower, test_loader, device)

            print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                        task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Batch: 0, task_0_acc: 0.9990, task_1_acc: 0.0109,                        task_all_acc:0.2178, avg_task_1_2:0.5050
Batch: 0, task_0_acc: 0.9990, task_1_acc: 0.0151,                        task_all_acc:0.2186, avg_task_1_2:0.5071
Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0276,                        task_all_acc:0.2220, avg_task_1_2:0.5131
Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0484,                        task_all_acc:0.2260, avg_task_1_2:0.5235
Batch: 0, task_0_acc: 0.9971, task_1_acc: 0.0776,                        task_all_acc:0.2322, avg_task_1_2:0.5373
Batch: 0, task_0_acc: 0.9951, task_1_acc: 0.1042,                        task_all_acc:0.2372, avg_task_1_2:0.5496
Batch: 0, task_0_acc: 0.9722, task_1_acc: 0.1427,                        task_all_acc:0.2409, avg_task_1_2:0.5574
Batch: 0, task_0_acc: 0.9082, task_1_acc: 0.1807,                        task_all_acc:0.2361, avg_task_1_2:0.5445
Batch: 0, task_0_acc: 0.7759, task_1_acc: 0.1995,                        task_al

In [30]:
try:
    follower = MLP()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        batch_num = 0
        print(f"Epoch {epoch}")

        for images, labels in task_1_train_loader:
            follower.train()
            
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)

            diff_loss  = 0
            for i in range (4):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(images)

                diff_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) + diff_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            task_0_acc = cal_acc(follower, task_0_test_loader, device)
            task_1_acc = cal_acc(follower, task_1_test_loader, device)  
            task_all_acc = cal_acc(follower, test_loader, device)

            print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                        task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0182,                        task_all_acc:0.2194, avg_task_1_2:0.5084
Batch: 1, task_0_acc: 0.9985, task_1_acc: 0.0255,                        task_all_acc:0.2210, avg_task_1_2:0.5120
Batch: 2, task_0_acc: 0.9985, task_1_acc: 0.0344,                        task_all_acc:0.2227, avg_task_1_2:0.5165
Batch: 3, task_0_acc: 0.9990, task_1_acc: 0.0474,                        task_all_acc:0.2253, avg_task_1_2:0.5232
Batch: 4, task_0_acc: 0.9990, task_1_acc: 0.0630,                        task_all_acc:0.2276, avg_task_1_2:0.5310
Batch: 5, task_0_acc: 0.9990, task_1_acc: 0.0724,                        task_all_acc:0.2296, avg_task_1_2:0.5357
Batch: 6, task_0_acc: 0.9985, task_1_acc: 0.0802,                        task_all_acc:0.2294, avg_task_1_2:0.5394
Batch: 7, task_0_acc: 0.9985, task_1_acc: 0.0943,                        task_all_acc:0.2322, avg_task_1_2:0.5464
Batch: 8, task_0_acc: 0.9990, task_1_acc: 0.1135,                        task_al

In [34]:
def adjust_learning_rate(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.8 * param_group['lr']

try:
    follower = MLP()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        batch_num = 0
        print(f"Epoch {epoch}")

        for images, labels in task_1_train_loader:
            follower.train()
            
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)

            diff_loss  = 0
            for i in range (4):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(images)

                diff_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) + diff_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            task_0_acc = cal_acc(follower, task_0_test_loader, device)
            task_1_acc = cal_acc(follower, task_1_test_loader, device)  
            task_all_acc = cal_acc(follower, test_loader, device)

            print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                        task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        adjust_learning_rate(optimizer_F)
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Batch: 0, task_0_acc: 0.9990, task_1_acc: 0.0177,                        task_all_acc:0.2198, avg_task_1_2:0.5084
Batch: 1, task_0_acc: 0.9985, task_1_acc: 0.0208,                        task_all_acc:0.2227, avg_task_1_2:0.5097
Batch: 2, task_0_acc: 0.9985, task_1_acc: 0.0307,                        task_all_acc:0.2241, avg_task_1_2:0.5146
Batch: 3, task_0_acc: 0.9990, task_1_acc: 0.0464,                        task_all_acc:0.2256, avg_task_1_2:0.5227
Batch: 4, task_0_acc: 0.9990, task_1_acc: 0.0651,                        task_all_acc:0.2283, avg_task_1_2:0.5321
Batch: 5, task_0_acc: 0.9990, task_1_acc: 0.0714,                        task_all_acc:0.2284, avg_task_1_2:0.5352
Batch: 6, task_0_acc: 0.9990, task_1_acc: 0.0839,                        task_all_acc:0.2301, avg_task_1_2:0.5414
Batch: 7, task_0_acc: 0.9990, task_1_acc: 0.0953,                        task_all_acc:0.2322, avg_task_1_2:0.5472
Batch: 8, task_0_acc: 0.9990, task_1_acc: 0.1115,                        task_al

In [59]:
class MLP_Enhance(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=28, hidden_dim=400):
        super(MLP_Enhance, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )
        self.last = nn.Linear(hidden_dim, out_dim)

    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [74]:
model = MLP_Enhance()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 0.0005)

if not os.path.exists('task_1_400.pth'):
    train(criterion, optimizer, model, num_epochs, task_0_train_loader, test_loader, 2, device)

    torch.save(model.state_dict(), 'task_1_400.pth')
else:
    model.load_state_dict(torch.load('task_1_400.pth'))

In [61]:
try:
    follower = MLP_Enhance()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1_400.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        batch_num = 0
        print(f"Epoch {epoch}")

        for images, labels in task_1_train_loader:

            # Train Follower:
            follower.train()
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)

            diff_loss  = 0
            for i in range (4):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(images)

                diff_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) + diff_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()


            task_0_acc = cal_acc(follower, task_0_test_loader, device)
            task_1_acc = cal_acc(follower, task_1_test_loader, device)  
            task_all_acc = cal_acc(follower, test_loader, device)

            print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                        task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        adjust_learning_rate(optimizer_F)
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Batch: 0, task_0_acc: 0.9966, task_1_acc: 0.0120,                        task_all_acc:0.2178, avg_task_1_2:0.5043
Batch: 1, task_0_acc: 0.9956, task_1_acc: 0.0214,                        task_all_acc:0.2182, avg_task_1_2:0.5085
Batch: 2, task_0_acc: 0.9951, task_1_acc: 0.0385,                        task_all_acc:0.2226, avg_task_1_2:0.5168
Batch: 3, task_0_acc: 0.9951, task_1_acc: 0.0948,                        task_all_acc:0.2353, avg_task_1_2:0.5450
Batch: 4, task_0_acc: 0.9956, task_1_acc: 0.1672,                        task_all_acc:0.2513, avg_task_1_2:0.5814
Batch: 5, task_0_acc: 0.9961, task_1_acc: 0.2240,                        task_all_acc:0.2640, avg_task_1_2:0.6100
Batch: 6, task_0_acc: 0.9961, task_1_acc: 0.2891,                        task_all_acc:0.2765, avg_task_1_2:0.6426
Batch: 7, task_0_acc: 0.9956, task_1_acc: 0.3427,                        task_all_acc:0.2900, avg_task_1_2:0.6692
Batch: 8, task_0_acc: 0.9956, task_1_acc: 0.4016,                        task_al

In [77]:
try:
    follower = MLP_Enhance()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1_400.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        batch_num = 0
        print(f"Epoch {epoch}")

        for images, labels in task_1_train_loader:

            # Train Follower:
            follower.train()
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)

            diff_loss  = 0
            for i in range (2):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(images)

                diff_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) + diff_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()


            task_0_acc = cal_acc(follower, task_0_test_loader, device)
            task_1_acc = cal_acc(follower, task_1_test_loader, device)  
            task_all_acc = cal_acc(follower, test_loader, device)

            print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                        task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        adjust_learning_rate(optimizer_F)
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Batch: 0, task_0_acc: 0.9985, task_1_acc: 0.0276,                        task_all_acc:0.2219, avg_task_1_2:0.5131
Batch: 1, task_0_acc: 0.9985, task_1_acc: 0.0688,                        task_all_acc:0.2292, avg_task_1_2:0.5336
Batch: 2, task_0_acc: 0.9990, task_1_acc: 0.1005,                        task_all_acc:0.2352, avg_task_1_2:0.5498
Batch: 3, task_0_acc: 0.9985, task_1_acc: 0.1536,                        task_all_acc:0.2451, avg_task_1_2:0.5761
Batch: 4, task_0_acc: 0.9985, task_1_acc: 0.1891,                        task_all_acc:0.2518, avg_task_1_2:0.5938
Batch: 5, task_0_acc: 0.9976, task_1_acc: 0.2161,                        task_all_acc:0.2568, avg_task_1_2:0.6069
Batch: 6, task_0_acc: 0.9976, task_1_acc: 0.2375,                        task_all_acc:0.2609, avg_task_1_2:0.6175
Batch: 7, task_0_acc: 0.9976, task_1_acc: 0.2672,                        task_all_acc:0.2665, avg_task_1_2:0.6324
Batch: 8, task_0_acc: 0.9976, task_1_acc: 0.2990,                        task_al

In [75]:
try:
    follower = MLP_Enhance()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1_400.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        batch_num = 0
        print(f"Epoch {epoch}")

        for images, labels in task_1_train_loader:

            # Train Follower:
            follower.train()
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)

            diff_loss  = 0
            for i in range (4):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(fake_image)

                diff_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) + diff_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()


            task_0_acc = cal_acc(follower, task_0_test_loader, device)
            task_1_acc = cal_acc(follower, task_1_test_loader, device)  
            task_all_acc = cal_acc(follower, test_loader, device)

            print(f"Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                        task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        adjust_learning_rate(optimizer_F)
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Batch: 0, task_0_acc: 0.9980, task_1_acc: 0.0401,                        task_all_acc:0.2274, avg_task_1_2:0.5191
Batch: 1, task_0_acc: 0.9966, task_1_acc: 0.0875,                        task_all_acc:0.2357, avg_task_1_2:0.5420
Batch: 2, task_0_acc: 0.9961, task_1_acc: 0.1276,                        task_all_acc:0.2423, avg_task_1_2:0.5618
Batch: 3, task_0_acc: 0.9956, task_1_acc: 0.1661,                        task_all_acc:0.2483, avg_task_1_2:0.5809
Batch: 4, task_0_acc: 0.9941, task_1_acc: 0.2057,                        task_all_acc:0.2539, avg_task_1_2:0.5999
Batch: 5, task_0_acc: 0.9937, task_1_acc: 0.2453,                        task_all_acc:0.2611, avg_task_1_2:0.6195
Batch: 6, task_0_acc: 0.9932, task_1_acc: 0.2724,                        task_all_acc:0.2663, avg_task_1_2:0.6328
Batch: 7, task_0_acc: 0.9922, task_1_acc: 0.2953,                        task_all_acc:0.2705, avg_task_1_2:0.6438
Batch: 8, task_0_acc: 0.9897, task_1_acc: 0.3099,                        task_al

In [79]:
model = MLP_Enhance()
model = model.to(device)
optimizer_L = torch.optim.Adam(model.parameters(), 0.0001)

if not os.path.exists('task_1_400.pth'):
    train(criterion, optimizer, model, num_epochs, task_0_train_loader, test_loader, 2, device)

    torch.save(model.state_dict(), 'task_1_400.pth')
else:
    model.load_state_dict(torch.load('task_1_400.pth'))

In [80]:
try:
    follower = MLP_Enhance()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1_400.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        batch_num = 0
        print(f"Epoch {epoch}")

        for images, labels in task_1_train_loader:
            # Train Follower:
            follower.train()
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)

            diff_loss  = 0
            for i in range (4):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(images)

                diff_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) + diff_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            # Train Leader
            if (batch_num + 1) % 10 == 0:
                model.train()
                diff_loss = 0
                for i in range (4):
                    fake_image = torch.randn(batch_size, 1, 28, 28)

                    with torch.no_grad():
                        follower_fake_output = follower(fake_image)
                    
                    model_fake_outputs = model(fake_image)

                    diff_loss += torch.mean(torch.abs(follower_fake_output - model_fake_outputs))
                    
                optimizer_L.zero_grad()
                diff_loss.backward()
                optimizer_L.step()


            if (batch_num + 1) % 10 == 0:
                task_0_acc = cal_acc(follower, task_0_test_loader, device)
                task_1_acc = cal_acc(follower, task_1_test_loader, device)  
                task_all_acc = cal_acc(follower, test_loader, device)

                print(f"Follower: Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                            task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")
                
                task_0_acc = cal_acc(model, task_0_test_loader, device)
                task_1_acc = cal_acc(model, task_1_test_loader, device)  
                task_all_acc = cal_acc(model, test_loader, device)
                
                print(f"Leader: Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                            task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        adjust_learning_rate(optimizer_F)
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Follower: Batch: 9, task_0_acc: 0.9971, task_1_acc: 0.3208,                            task_all_acc:0.2768, avg_task_1_2:0.6590
Leader: Batch: 9, task_0_acc: 0.9951, task_1_acc: 0.0073,                            task_all_acc:0.2274, avg_task_1_2:0.5012
Follower: Batch: 19, task_0_acc: 0.9922, task_1_acc: 0.5385,                            task_all_acc:0.3187, avg_task_1_2:0.7654
Leader: Batch: 19, task_0_acc: 0.9941, task_1_acc: 0.0083,                            task_all_acc:0.2265, avg_task_1_2:0.5012
Follower: Batch: 29, task_0_acc: 0.8975, task_1_acc: 0.5297,                            task_all_acc:0.2977, avg_task_1_2:0.7136
Leader: Batch: 29, task_0_acc: 0.9937, task_1_acc: 0.0078,                            task_all_acc:0.2259, avg_task_1_2:0.5007
Follower: Batch: 39, task_0_acc: 0.9053, task_1_acc: 0.4000,                            task_all_acc:0.2722, avg_task_1_2:0.6526
Leader: Batch: 39, task_0_acc: 0.9932, task_1_acc: 0.0151,                            task_all_ac

In [82]:
model = MLP_Enhance()
model = model.to(device)
optimizer_L = torch.optim.Adam(model.parameters(), 0.0001)

if not os.path.exists('task_1_400.pth'):
    train(criterion, optimizer, model, num_epochs, task_0_train_loader, test_loader, 2, device)

    torch.save(model.state_dict(), 'task_1_400.pth')
else:
    model.load_state_dict(torch.load('task_1_400.pth'))

In [83]:
try:
    follower = MLP_Enhance()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1_400.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        batch_num = 0
        print(f"Epoch {epoch}")

        for images, labels in task_1_train_loader:
            # Train Follower:
            follower.train()
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)

            diff_loss  = 0
            for i in range (4):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(fake_image)

                diff_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) + diff_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            # Train Leader
            if (batch_num + 1) % 10 == 0:
                model.train()
                diff_loss = 0
                for i in range (4):
                    fake_image = torch.randn(batch_size, 1, 28, 28)

                    with torch.no_grad():
                        follower_fake_output = follower(fake_image)
                    
                    model_fake_outputs = model(fake_image)

                    diff_loss += torch.mean(torch.abs(follower_fake_output - model_fake_outputs))
                    
                optimizer_L.zero_grad()
                diff_loss.backward()
                optimizer_L.step()


            if (batch_num + 1) % 10 == 0:
                task_0_acc = cal_acc(follower, task_0_test_loader, device)
                task_1_acc = cal_acc(follower, task_1_test_loader, device)  
                task_all_acc = cal_acc(follower, test_loader, device)

                print(f"Follower: Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                            task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")
                
                task_0_acc = cal_acc(model, task_0_test_loader, device)
                task_1_acc = cal_acc(model, task_1_test_loader, device)  
                task_all_acc = cal_acc(model, test_loader, device)
                
                print(f"Leader: Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                            task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        adjust_learning_rate(optimizer_F)
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Follower: Batch: 9, task_0_acc: 0.9775, task_1_acc: 0.2938,                            task_all_acc:0.2669, avg_task_1_2:0.6356
Leader: Batch: 9, task_0_acc: 0.9941, task_1_acc: 0.0057,                            task_all_acc:0.2265, avg_task_1_2:0.4999
Follower: Batch: 19, task_0_acc: 0.9189, task_1_acc: 0.3224,                            task_all_acc:0.2588, avg_task_1_2:0.6207
Leader: Batch: 19, task_0_acc: 0.9937, task_1_acc: 0.0063,                            task_all_acc:0.2256, avg_task_1_2:0.5000
Follower: Batch: 29, task_0_acc: 0.8418, task_1_acc: 0.4344,                            task_all_acc:0.2656, avg_task_1_2:0.6381
Leader: Batch: 29, task_0_acc: 0.9932, task_1_acc: 0.0109,                            task_all_acc:0.2252, avg_task_1_2:0.5021
Follower: Batch: 39, task_0_acc: 0.7847, task_1_acc: 0.3240,                            task_all_acc:0.2311, avg_task_1_2:0.5543
Leader: Batch: 39, task_0_acc: 0.9922, task_1_acc: 0.0156,                            task_all_ac

In [84]:
try:
    follower = MLP_Enhance()
    follower = follower.to(device)
    follower.load_state_dict(torch.load('task_1_400.pth'))
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    epoch = 0
    while(True):
        batch_num = 0
        print(f"Epoch {epoch}")

        for images, labels in task_1_train_loader:
            # Train Follower:
            follower.train()
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            with torch.no_grad():
                leader_output = model(images)

            diff_loss  = 0
            for i in range (4):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(fake_image)

                diff_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:4], labels) + 5 * torch.mean(torch.abs(leader_output[:,:2] - outputs[:,:2])) + diff_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            # Train Leader
            if (batch_num + 1) % 10 == 0:
                model.train()
                diff_loss = 0
                for i in range (10):
                    fake_image = torch.randn(batch_size, 1, 28, 28)

                    with torch.no_grad():
                        follower_fake_output = follower(fake_image)
                    
                    model_fake_outputs = model(fake_image)

                    diff_loss += torch.mean(torch.abs(follower_fake_output - model_fake_outputs))
                    
                optimizer_L.zero_grad()
                diff_loss.backward()
                optimizer_L.step()


            if (batch_num + 1) % 10 == 0:
                task_0_acc = cal_acc(follower, task_0_test_loader, device)
                task_1_acc = cal_acc(follower, task_1_test_loader, device)  
                task_all_acc = cal_acc(follower, test_loader, device)

                print(f"Follower: Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                            task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")
                
                task_0_acc = cal_acc(model, task_0_test_loader, device)
                task_1_acc = cal_acc(model, task_1_test_loader, device)  
                task_all_acc = cal_acc(model, test_loader, device)
                
                print(f"Leader: Batch: {batch_num}, task_0_acc: {task_0_acc:.4f}, task_1_acc:{task_1_acc: .4f},\
                            task_all_acc:{task_all_acc:.4f}, avg_task_1_2:{(task_0_acc + task_1_acc)/2:.4f}")

            batch_num += 1
        adjust_learning_rate(optimizer_F)
        epoch += 1
except KeyboardInterrupt:
    print("Train interrupt by keyboard")

Epoch 0
Follower: Batch: 9, task_0_acc: 0.9458, task_1_acc: 0.5557,                            task_all_acc:0.3122, avg_task_1_2:0.7508
Leader: Batch: 9, task_0_acc: 0.9653, task_1_acc: 0.2130,                            task_all_acc:0.2478, avg_task_1_2:0.5892
Follower: Batch: 19, task_0_acc: 0.7271, task_1_acc: 0.4125,                            task_all_acc:0.2375, avg_task_1_2:0.5698
Leader: Batch: 19, task_0_acc: 0.9634, task_1_acc: 0.2156,                            task_all_acc:0.2479, avg_task_1_2:0.5895
Follower: Batch: 29, task_0_acc: 0.6714, task_1_acc: 0.3839,                            task_all_acc:0.2200, avg_task_1_2:0.5276
Leader: Batch: 29, task_0_acc: 0.9595, task_1_acc: 0.2229,                            task_all_acc:0.2482, avg_task_1_2:0.5912
Follower: Batch: 39, task_0_acc: 0.6650, task_1_acc: 0.2198,                            task_all_acc:0.1846, avg_task_1_2:0.4424
Leader: Batch: 39, task_0_acc: 0.9585, task_1_acc: 0.2276,                            task_all_ac

# Best Thing So far

In [253]:
import torch.nn.init as init
class MLP_Enhance(nn.Module):
    def __init__(self, out_dim=10, in_channel=1, img_sz=28, hidden_dim=400):
        super(MLP_Enhance, self).__init__()
        self.in_dim = in_channel*img_sz*img_sz
        self.linear = nn.Sequential(
            nn.Linear(self.in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )
        self.last = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            nn.Softmax(dim=1)
        )
        self.init_weights()

    def init_weights(self):
        # Initialize Linear layers using He initialization
        for layer in self.linear:
            if isinstance(layer, nn.Linear):
                init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                init.constant_(layer.bias, 0)
        # Initialize the last layer using Xavier initialization
        init.xavier_normal_(self.last.weight)
        init.constant_(self.last.bias, 0)


    def features(self, x):
        x = self.linear(x.view(-1,self.in_dim))
        return x

    def logits(self, x):
        x = self.last(x)
        return x

    def forward(self, x):
        x = self.features(x)
        x = self.logits(x)
        return x

In [85]:
task_dataloaders = {}
for task in range(0, 5):
    train_dl, test_ld = task_dataloader(task)
    task_dataloaders[task] = (train_dl, test_ld)

In [254]:
def adjust_learning_rate(optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.75 * param_group['lr']

In [283]:
def train(model, task_num, criterion, epoches):
    follower = MLP_Enhance()
    follower = follower.to(device)
    follower.load_state_dict(model.state_dict())
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.005)
    train_loader = task_dataloaders[task_num][0]

    valid_out_dim = task_num * 2 + 2
    print(f"##########Task {task_num}##########")
    for e in range(epoches):
        print(f"Epoch {e}")

        for images, labels in train_loader:
            follower.train()
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            except_mean = 0.2

            # Loss part 1
            empty_loss = 0
            for task in range (5):
                    task_start = task * 2
                    task_end = (task + 1) * 2

                    fake_image = torch.randn(batch_size, 1, 28, 28)
                    fake_output = follower(fake_image)

                    real_mean = torch.mean(torch.abs(fake_output[:,task_start:task_end]))
                    empty_loss += torch.abs(except_mean - real_mean)
            

            # Loss part 2
            with torch.no_grad():
                leader_output = model(images)

            if (task_num == 0):
                diff_loss =  0
            else:
                diff_loss =  10 * torch.mean(torch.abs(leader_output[:,:valid_out_dim - 2] - outputs[:,:valid_out_dim - 2]))


            # Loss part 3
            interupt_loss = 0
            for i in range (4):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(images)

                if (task_num == 0):
                    interupt_loss =  0
                else:
                    interupt_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:valid_out_dim], labels) + empty_loss + diff_loss + 0.65 * interupt_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

        adjust_learning_rate(optimizer_F)
        avg_acc = 0
        for task in range(task_num + 1):
            acc = cal_acc(follower, task_dataloaders[task][0], device)
            avg_acc += acc
            print(f"Task {task} acc: { acc * 100:.4f}", end = ', ')
        
        print(f"Task avg acc:{avg_acc*100/(task_num + 1):.4f}")
    
    return follower

In [284]:
model = MLP_Enhance()
model = model.to(device)
for task in [0, 1, 2, 3, 4]:
    model = train(model, task, criterion, 10)

##########Task 0##########
Epoch 0
Task 0 acc: 99.5057, Task avg acc:99.5057
Epoch 1
Task 0 acc: 99.6811, Task avg acc:99.6811
Epoch 2
Task 0 acc: 99.5615, Task avg acc:99.5615
Epoch 3
Task 0 acc: 99.5934, Task avg acc:99.5934
Epoch 4
Task 0 acc: 99.6253, Task avg acc:99.6253
Epoch 5
Task 0 acc: 99.5137, Task avg acc:99.5137
Epoch 6
Task 0 acc: 99.5615, Task avg acc:99.5615
Epoch 7
Task 0 acc: 99.6652, Task avg acc:99.6652
Epoch 8
Task 0 acc: 99.4898, Task avg acc:99.4898
Epoch 9
Task 0 acc: 99.5376, Task avg acc:99.5376
##########Task 1##########
Epoch 0
Task 0 acc: 66.1033, Task 1 acc: 68.9079, Task avg acc:67.5056
Epoch 1
Task 0 acc: 86.3999, Task 1 acc: 61.2866, Task avg acc:73.8432
Epoch 2
Task 0 acc: 73.7723, Task 1 acc: 69.0409, Task avg acc:71.4066
Epoch 3
Task 0 acc: 80.3412, Task 1 acc: 70.3042, Task avg acc:75.3227
Epoch 4
Task 0 acc: 72.9592, Task 1 acc: 74.5180, Task avg acc:73.7386
Epoch 5
Task 0 acc: 72.9592, Task 1 acc: 70.3956, Task avg acc:71.6774
Epoch 6
Task 0 acc: 

In [252]:
torch.save(model.state_dict(), 'task_2_400.pth')

In [257]:
def train_2(model, task_num, criterion, epoches):
    follower = MLP_Enhance()
    follower = follower.to(device)
    follower.load_state_dict(model.state_dict())
    criterion = nn.CrossEntropyLoss()
    optimizer_F = torch.optim.Adam(follower.parameters(), 0.0005)
    train_loader = task_dataloaders[task_num][0]

    valid_out_dim = task_num * 2 + 2
    print(f"##########Task {task_num}##########")
    for e in range(epoches):
        print(f"Epoch {e}")

        for images, labels in train_loader:
            follower.train()
            images, labels = images.to(device), labels.to(device)
            outputs = follower(images)

            empty_loss  = 0
            except_mean = 0.2

            # Loss part 1
            for i in range (10):
                task_start = task_num * 2
                task_end = (task_num + 1) * 2

                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    fake_output = model(fake_image)

                real_mean = torch.mean(torch.abs(fake_output[:,task_start:task_end]))
                empty_loss += torch.abs(except_mean - real_mean)
            

            # Loss part 2
            with torch.no_grad():
                leader_output = model(images)

            if (task_num == 0):
                diff_loss =  0
            else:
                diff_loss =  10 * torch.mean(torch.abs(leader_output[:,:valid_out_dim - 2] - outputs[:,:valid_out_dim - 2]))


            # Loss part 4
            for i in range (10):
                task_start = task_num * 2
                task_end = (task_num + 1) * 2

                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    fake_output = model(fake_image)
                
                fake_output_follower = model(fake_image)

                real_mean = torch.mean(torch.abs(fake_output[:,task_start:task_end]))
                empty_loss += torch.abs(except_mean - real_mean)


            # Loss part 3
            interupt_loss = 0
            for i in range (4):
                fake_image = torch.randn(batch_size, 1, 28, 28)

                with torch.no_grad():
                    leader_fake_output = model(fake_image)
                
                fake_outputs = follower(images)

                if (task_num == 0):
                    diff_loss =  0
                else:
                    interupt_loss += torch.mean(torch.abs(leader_fake_output - fake_outputs))

            loss = criterion(outputs[:,:valid_out_dim], labels) + empty_loss + diff_loss + 1.3 * interupt_loss

            optimizer_F.zero_grad()
            loss.backward()
            optimizer_F.step()

            if (task_num == 2):
                avg_acc = 0
                for task in range(task_num + 1):
                    acc = cal_acc(follower, task_dataloaders[task][0], device)
                    avg_acc += acc
                    print(f"Task {task} acc: { acc * 100:.4f}", end = ', ')
                
                print(f"Task avg acc:{avg_acc*100/(task_num + 1):.4f}")

        adjust_learning_rate(optimizer_F)
    
    return follower

In [258]:
model = MLP_Enhance()
model = model.to(device)
model.load_state_dict(torch.load('task_2_400.pth'))
model = train_2(model, 2, criterion, 4)

##########Task 2##########
Epoch 0
Task 0 acc: 79.8788, Task 1 acc: 89.6193, Task 2 acc: 0.0000, Task avg acc:56.4994
Task 0 acc: 71.9627, Task 1 acc: 88.4890, Task 2 acc: 0.0000, Task avg acc:53.4839
Task 0 acc: 60.6744, Task 1 acc: 86.6772, Task 2 acc: 0.0000, Task avg acc:49.1172
Task 0 acc: 52.7583, Task 1 acc: 83.2862, Task 2 acc: 0.0000, Task avg acc:45.3482
Task 0 acc: 54.5281, Task 1 acc: 80.9342, Task 2 acc: 0.0000, Task avg acc:45.1541
Task 0 acc: 55.4927, Task 1 acc: 80.8178, Task 2 acc: 0.0000, Task avg acc:45.4368
Task 0 acc: 57.3262, Task 1 acc: 79.3966, Task 2 acc: 0.0000, Task avg acc:45.5743
Task 0 acc: 58.5938, Task 1 acc: 79.7374, Task 2 acc: 0.0000, Task avg acc:46.1104
Task 0 acc: 59.9171, Task 1 acc: 80.9092, Task 2 acc: 0.0000, Task avg acc:46.9421
Task 0 acc: 62.2608, Task 1 acc: 81.3913, Task 2 acc: 0.0000, Task avg acc:47.8840
Task 0 acc: 61.6948, Task 1 acc: 83.8265, Task 2 acc: 0.0000, Task avg acc:48.5071
Task 0 acc: 63.5762, Task 1 acc: 83.3943, Task 2 acc

KeyboardInterrupt: 