In [42]:
from dataloader_utils import get_conbined_permute_mnist, get_conbined_split_mnist, get_conbined_splitted_and_shuffled_mnist
from autoencoder import Autoencoder
from autoencoder_utils import * 

import numpy as np

import matplotlib.pyplot as plt
import pandas as pd

import torch
import torch.nn as nn

from tqdm.notebook import tqdm

In [44]:
# Define a simple CNN and MLP using PyTorch
class SmallCNN(nn.Module):
    """Small two-convolution CNN classifier for 28x28 single-channel images.

    The model owns its Adam optimizer and cross-entropy loss, so a training
    loop can call ``model.optimizer`` / ``model.loss`` directly.

    Args:
        outdim: number of output classes (size of the logit vector).
        lr: Adam learning rate; defaults to 0.01, the value previously hard-coded.
    """

    def __init__(self, outdim, lr=0.01):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)   # 28x28 -> 24x24, max-pooled to 12x12
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)  # 12x12 -> 8x8, max-pooled to 4x4
        self.fc1 = nn.Linear(64 * 4 * 4, 128)          # 64 channels * 4 * 4 = 1024 features
        self.fc2 = nn.Linear(128, outdim)
        # Built after every layer exists so the optimizer tracks all parameters.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits of shape (batch, outdim).

        ``x`` may be flat (batch, 784) or image-shaped (batch, 1, 28, 28); it is
        reshaped to image form first.
        """
        x = x.reshape((x.shape[0], 1, 28, 28))
        x = torch.relu(torch.max_pool2d(self.conv1(x), 2))
        x = torch.relu(torch.max_pool2d(self.conv2(x), 2))
        # Flatten per sample; unlike view(-1, 1024) this never mixes batch rows.
        x = torch.flatten(x, start_dim=1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

class MLP(nn.Module):
    """Three-layer fully connected classifier for 28x28 images (784 -> 512 -> 128 -> outdim).

    The model owns its Adam optimizer and cross-entropy loss, so a training
    loop can call ``model.optimizer`` / ``model.loss`` directly.

    Args:
        outdim: number of output classes (size of the logit vector).
        lr: Adam learning rate; defaults to 0.01, the value previously hard-coded.
    """

    def __init__(self, outdim, lr=0.01):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, outdim)

        # Built after every layer exists so the optimizer tracks all parameters.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits of shape (N, outdim), where N = x.numel() // 784."""
        # reshape (not view) also accepts non-contiguous inputs, e.g. permuted tensors.
        x = x.reshape(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

In [4]:
def _fit_expert(model, images, labels, steps):
    """Take `steps` full-batch optimizer steps of `model` on one (images, labels) batch; return the model."""
    for _ in range(steps):
        model.optimizer.zero_grad()
        fit = model.loss(model(images), labels)
        fit.backward()
        model.optimizer.step()
    return model


def _add_expert_pair(images, labels, auto_list, expert_list, expert_outdim, expert_factory, expert_steps):
    """Fit a new autoencoder and a new expert on the batch and register both under the next free index."""
    new_autoencoder = Autoencoder(input_dims=28*28, code_dims=CODE_DIM)
    for _ in range(NEW_AUTOENCODER_EPOCH):
        new_autoencoder.optimize_params(images, images)
    new_index = len(auto_list)
    auto_list[new_index] = new_autoencoder
    expert_list[new_index] = _fit_expert(expert_factory(expert_outdim), images, labels, expert_steps)


def train(train_loader, expert_outdim, expert_factory=None, expert_steps=10):
    """Grow paired autoencoders ("gates") and expert classifiers over a stream of batches.

    For each batch, `find_best_autoencoders` selects an existing autoencoder and
    `find_num_of_outliers` counts outliers for it. If that count exceeds
    OUTLIER_THRESHOLD (or no autoencoder exists yet), a new autoencoder/expert
    pair is created and fit on the batch. Otherwise, the selected autoencoder and
    its expert are trained further on the batch.

    Reads the notebook globals CODE_DIM, NEW_AUTOENCODER_EPOCH,
    TRAIN_AUTOENCODER_EPOCH and OUTLIER_THRESHOLD at call time. Run the config
    cell before calling this function.

    Args:
        train_loader: iterable of (images, labels, task_ids) batches with a len().
        expert_outdim: number of output classes for every expert.
        expert_factory: callable `outdim -> model`. The model must expose
            `.optimizer` and `.loss` like MLP / SmallCNN. Defaults to MLP,
            looked up at call time.
        expert_steps: optimizer steps applied to an expert per batch (default 10).

    Returns:
        (auto_list, expert_list): dicts sharing integer keys. expert_list[k] is
        the classifier paired with auto_list[k].
    """
    if expert_factory is None:
        expert_factory = MLP

    auto_list = {}
    expert_list = {}
    # debug: task id -> number of samples seen while a pair existed for it
    record = {}

    # One pass over the stream, several optimizer steps per batch; see
    # https://stats.stackexchange.com/questions/521461/train-a-model-on-batches-with-multiple-epochs-vs-each-batch-with-multiple-epoch
    for i, (images, labels, task_ids) in tqdm(enumerate(train_loader), total=len(train_loader)):
        # NOTE(review): only the first task id is inspected; assumes task-pure batches.
        task = task_ids[0].item()
        batch_size = len(labels)  # actual size, so a short last batch is counted correctly

        if not auto_list:
            ###debug###
            print(f"[@ batch {i}] NEW autoencoder at {len(auto_list)} for Task {task}")
            record[task] = batch_size
            ###########
            _add_expert_pair(images, labels, auto_list, expert_list,
                             expert_outdim, expert_factory, expert_steps)
            continue

        best_index = find_best_autoencoders(images, auto_list)
        best_autoencoder = auto_list[best_index]
        outliers = find_num_of_outliers(images, best_autoencoder)

        if outliers > OUTLIER_THRESHOLD:
            # Even the selected autoencoder flags too many outliers: treat the batch as a new task.
            ###debug###
            print(f"[@ batch {i}] outliers for best autoencoders at index: {best_index} : {outliers}")
            print(f"[@ batch {i}] NEW autoencoder at {len(auto_list)} for Task: {task}")
            if task in record:
                print(f"[@ batch {i}] DUPLICATE autoencoder for Task: {task}")
            record[task] = batch_size
            ###########
            _add_expert_pair(images, labels, auto_list, expert_list,
                             expert_outdim, expert_factory, expert_steps)
        else:
            ###debug###
            if task not in record:
                print(f"[@ batch {i}] outliers for best autoencoders at index: {best_index} : {outliers}")
                print(f"[@ batch {i}] MISSING autoencoder with Task {task}")
            else:
                record[task] += batch_size
            ###########

            # Keep refining the matched autoencoder and its existing expert.
            for _ in range(TRAIN_AUTOENCODER_EPOCH):
                best_autoencoder.optimize_params(images, images)
            _fit_expert(expert_list[best_index], images, labels, expert_steps)

    print("Complete!")
    return auto_list, expert_list

In [5]:
def test(test_loader, auto_list, expert_list):
    """Evaluate the autoencoder-routed experts and print mean loss and accuracy.

    Each batch is flattened to (batch, -1). `find_best_autoencoders` picks an
    index, and the expert stored under that index in `expert_list` classifies
    the whole batch. The printed loss is the mean of the per-batch losses.

    NOTE(review): routing is per batch, so a batch mixing several tasks is
    classified by a single expert. Confirm the test loader yields task-pure batches.
    """
    test_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    with torch.no_grad():
        for images, labels, _ in test_loader:
            images = images.view(images.shape[0], -1)  # flatten images
            best_index = find_best_autoencoders(images, auto_list)
            classifier = expert_list[best_index]

            outputs = classifier(images)
            test_loss += classifier.loss(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            num_batches += 1

    if total == 0:
        # Previously this divided by zero on an empty loader.
        print("Test loader is empty; nothing to evaluate.")
        return

    print(f"Test Loss: {test_loss / num_batches}")
    print(f"Accuracy: {100 * correct / total}%")

In [10]:
BATCH_SIZE = 300

# train() spawns a new autoencoder/expert when a batch has more outliers than this.
OUTLIER_THRESHOLD = 0.2*BATCH_SIZE
NEW_AUTOENCODER_EPOCH = 100
TRAIN_AUTOENCODER_EPOCH = 10
CODE_DIM = 350

NUM_TASK = 10
# A fresh seed is drawn per run. It is printed and also seeds torch (model
# initialisation), so a run can be reproduced by hard-coding the printed value.
RANDOM_SEED = np.random.randint(100)
#RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
print(f"RANDOM_SEED = {RANDOM_SEED}")
#RANDOM_SEED = 42

In [12]:
# Build the train/test loaders; per the helper's name, presumably a permuted-MNIST task stream.
# NUM_TASK, BATCH_SIZE and RANDOM_SEED come from the config cell above.
train_loader, test_loader = get_conbined_permute_mnist(NUM_TASK, BATCH_SIZE, RANDOM_SEED)

In [13]:
# Every expert gets 10 output logits for this stream.
auto_list, expert_list = train(train_loader, expert_outdim=10)

  0%|          | 0/2000 [00:00<?, ?it/s]

[@ batch 0] NEW autoencoder at 0 for Task 6
[@ batch 1] outliers for best autoencoders at index: 0 : 300
[@ batch 1] NEW autoencoder at 1 for Task: 4
[@ batch 4] outliers for best autoencoders at index: 1 : 300
[@ batch 4] NEW autoencoder at 2 for Task: 1
[@ batch 5] outliers for best autoencoders at index: 1 : 300
[@ batch 5] NEW autoencoder at 3 for Task: 2
[@ batch 8] outliers for best autoencoders at index: 3 : 300
[@ batch 8] NEW autoencoder at 4 for Task: 0
[@ batch 11] outliers for best autoencoders at index: 1 : 300
[@ batch 11] NEW autoencoder at 5 for Task: 7
[@ batch 12] outliers for best autoencoders at index: 4 : 300
[@ batch 12] NEW autoencoder at 6 for Task: 5
[@ batch 14] outliers for best autoencoders at index: 1 : 300
[@ batch 14] NEW autoencoder at 7 for Task: 9
[@ batch 18] outliers for best autoencoders at index: 5 : 298
[@ batch 18] NEW autoencoder at 8 for Task: 3
[@ batch 21] outliers for best autoencoders at index: 2 : 300
[@ batch 21] NEW autoencoder at 9 for 

In [14]:
# Evaluate the autoencoders/experts learned on the permuted stream.
test(test_loader=test_loader, auto_list=auto_list, expert_list=expert_list)

Test Loss: 0.45130700755496317
Accuracy: 92.785%


In [15]:
# Unsuccessful configuration: fails because different tasks are too related (e.g. [2, 7] and [3, 1])
#BATCH_SIZE = 1000

#OUTLIER_THRESHOLD = 0.1*BATCH_SIZE
#NEW_AUTOENCODER_EPOCH = 500
#TRAIN_AUTOENCODER_EPOCH = 20
#CODE_DIM = 350

#NUM_TASK = 3
#RANDOM_SEED = np.random.randint(100)
#RANDOM_SEED = 42

#train_loader, test_loader = get_conbined_split_mnist(NUM_TASK, BATCH_SIZE, RANDOM_SEED)

#auto_list, expert_list = train(train_loader, 2)

#test(test_loader, auto_list, expert_list)

In [31]:
BATCH_SIZE = 300

# train() spawns a new autoencoder/expert when a batch has more outliers than this.
OUTLIER_THRESHOLD = 0.2*BATCH_SIZE
NEW_AUTOENCODER_EPOCH = 100
TRAIN_AUTOENCODER_EPOCH = 10
CODE_DIM = 500

NUM_TASK = 3
# A fresh seed is drawn per run. It is printed and also seeds torch (model
# initialisation), so a run can be reproduced by hard-coding the printed value.
RANDOM_SEED = np.random.randint(100)
#RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
print(f"RANDOM_SEED = {RANDOM_SEED}")

In [33]:
# Build split (and, per the helper's name, pixel-shuffled) MNIST loaders; the helper prints the class split.
# shuffle_idx is the pixel index mapping that the commented-out visualisation cell uses to undo the shuffle.
train_loader, test_loader, shuffle_idx = get_conbined_splitted_and_shuffled_mnist(NUM_TASK, BATCH_SIZE, RANDOM_SEED)

split classes: [[8, 6], [2, 9], [0, 7]]


In [18]:
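# With the MLP classifier there is no need to un-shuffle the pixels; the code below only visualises a sample before and after un-shuffling.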
#show_image = images[9].cpu().detach().numpy().reshape((28,28))
#plt.imshow(show_image) # Plot the 28x28 image
#plt.show()

#newImage = np.array([0.0]*(784))
#show_image = show_image.reshape(-1)
#for i,j in enumerate(shuffle_idx):
#    newImage[j] = show_image[i]
#plt.imshow(newImage.reshape(28,28)) # Plot the 28x28 image
#plt.show()


In [34]:
# Each split task holds two classes (see the printed split above), so experts get 2 outputs.
auto_list, expert_list = train(train_loader, expert_outdim=2)

  0%|          | 0/120 [00:00<?, ?it/s]

[@ batch 0] NEW autoencoder at 0 for Task 2
[@ batch 1] outliers for best autoencoders at index: 0 : 198
[@ batch 1] NEW autoencoder at 1 for Task: 0
[@ batch 3] outliers for best autoencoders at index: 1 : 87
[@ batch 3] NEW autoencoder at 2 for Task: 1
Complete!


In [35]:
# Evaluate the autoencoders/experts learned on the split, shuffled stream.
test(test_loader=test_loader, auto_list=auto_list, expert_list=expert_list)

Test Loss: 0.08548424832886212
Accuracy: 97.50877779635512%
