In [1]:
from dataloader_utils import get_conbined_permute_mnist, get_conbined_split_mnist, get_conbined_splitted_and_shuffled_mnist
from autoencoder import Autoencoder
from autoencoder_utils import *

import numpy as np

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import transforms

# Any truthy value turns on the per-batch diagnostic prints inside train().
DEBUG = 1

In [2]:
# Define a simple CNN and MLP using PyTorch
class SmallCNN(nn.Module):
    """Small convolutional classifier for 28x28 single-channel images.

    Besides the layers, every instance carries its own Adam optimizer
    (lr=0.01) and a cross-entropy criterion as attributes.
    """

    def __init__(self, outdim):
        """Create the layers; ``outdim`` is the number of output logits."""
        super().__init__()
        # Two 5x5 convolutions; forward() applies 2x2 max-pooling after each one.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        # For a 28x28 input the conv stack ends at 64 x 4 x 4 = 1024 features.
        self.fc1 = nn.Linear(1024, 128)
        self.fc2 = nn.Linear(128, outdim)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the logits for a batch whose samples each hold 28*28 values."""
        batch = x.shape[0]
        feats = x.reshape((batch, 1, 28, 28))
        for conv in (self.conv1, self.conv2):
            feats = torch.relu(torch.max_pool2d(conv(feats), 2))
        flat = feats.view(-1, 1024)  # one row of 1024 features per sample
        hidden = torch.relu(self.fc1(flat))
        return self.fc2(hidden)

class MLP(nn.Module):
    """Fully connected classifier (784 -> 512 -> 128 -> outdim) over flattened images.

    Like SmallCNN, each instance owns an Adam optimizer (lr=0.01) and a
    cross-entropy criterion.
    """

    def __init__(self, outdim):
        """Create the three linear layers; ``outdim`` is the number of output logits."""
        super().__init__()
        self.fc1 = nn.Linear(28*28, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, outdim)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        """Flatten the input to rows of 784 values and return the logits."""
        hidden = x.view(-1, 28*28)
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return self.fc3(hidden)

In [3]:
# the training
def train(train_loader, num_epochs=10):
    """Grow a bank of autoencoders and linear experts while streaming over ``train_loader``.

    Every batch is unpacked as ``(images, labels, indicies)``. ``indicies[0]`` is used as the
    task id of the whole batch; it only feeds the debug bookkeeping and the MISSING check.

    Per batch:
      * if no autoencoder exists yet, or the best-matching autoencoder reports more than
        ``OUTLIER_THRESHOLD`` outliers, a new autoencoder and a new ``nn.Linear(28*28, 10)``
        expert are created and fitted on that batch;
      * otherwise the best-matching autoencoder and the expert stored under the same index
        are trained further on the batch.

    Reads the notebook globals ``DEBUG``, ``CODE_DIM``, ``NEW_AUTOENCODER_EPOCH``,
    ``TRAIN_AUTOENCODER_EPOCH``, ``OUTLIER_THRESHOLD``, ``BATCH_SIZE`` and ``NUM_TASK``.

    Args:
        train_loader: sized iterable of ``(images, labels, indicies)`` batches.
        num_epochs: number of optimizer steps an expert takes on each batch it sees.

    Returns:
        ``(auto_list, expert_list, the_loss_history)``. The two dicts are keyed by creation
        order (0, 1, ...). ``the_loss_history`` holds one list of per-step losses for every
        batch that trained an *existing* expert; losses recorded while fitting a brand-new
        expert are not added (unchanged from the earlier behaviour).
    """
    auto_list = {}
    expert_list = {}
    the_loss_history = []

    # debug bookkeeping: autoencoder index -> task id of the batch it was created from
    record = {}
    #https://stats.stackexchange.com/questions/521461/train-a-model-on-batches-with-multiple-epochs-vs-each-batch-with-multiple-epoch
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        images, labels, indicies = data
        task_id = indicies[0].item()

        # First batch: nothing to compare against yet, so bootstrap the first pair.
        if len(auto_list) == 0:
            if DEBUG: print(f"[@ batch {i}] NEW autoencoder at {len(auto_list)} for Task {task_id}")
            record[len(auto_list)] = task_id
            _add_autoencoder_and_expert(auto_list, expert_list, images, labels, num_epochs)
            continue

        # Route the batch to the autoencoder that matches it best.
        best_index = find_best_autoencoders(images, auto_list)
        best_autoencoder = auto_list[best_index]

        outliers = find_num_of_outliers(images, best_autoencoder)
        if DEBUG: print(f"[@ batch {i}] outliers for best autoencoders {best_index}: {outliers}")

        if outliers > OUTLIER_THRESHOLD:
            # Even the best match rejects this batch: treat it as a new task.
            if DEBUG: print(f"[@ batch {i}] outliers for best autoencoders at index: {best_index} : {outliers}")
            if DEBUG: print(f"[@ batch {i}] NEW autoencoder at {len(auto_list)} for Task: {task_id}")
            if DEBUG and task_id in record.values():
                print(f"[@ batch {i}] DUPLICATE autoencoder for Task: {task_id}")
            record[len(auto_list)] = task_id
            _add_autoencoder_and_expert(auto_list, expert_list, images, labels, num_epochs)
            continue

        if DEBUG: print(f"training autoencoder at {best_index} with index: {task_id} with Task {task_id}")
        if task_id not in record.values():
            # The batch was accepted by an autoencoder although no autoencoder was ever
            # created for its task; skip it instead of training the wrong expert.
            if DEBUG: print(f"[@ batch {i}] outliers for best autoencoders at index: {best_index} : {outliers}")
            if DEBUG: print(f"[@ batch {i}] MISSING autoencoder with Task {task_id}")
            continue

        # Refine the matching autoencoder, then its expert.
        for _ in range(TRAIN_AUTOENCODER_EPOCH):
            best_autoencoder.optimize_params(images, images)

        the_loss_history.append(
            _fit_expert(expert_list[best_index], images, labels, num_epochs)
        )

    print(f"BATCH_SIZE:{BATCH_SIZE} NUM_TASK:{NUM_TASK} train done!")

    return auto_list, expert_list, the_loss_history


def _fit_expert(classifier, images, labels, num_epochs):
    """Take ``num_epochs`` full-batch Adam steps on ``classifier``; return the per-step losses.

    The optimizer and criterion are built once per call. Re-creating Adam before every step
    would throw away its running moment estimates each time.
    """
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)
    loss_function = nn.CrossEntropyLoss()
    flat_images = images.view(images.shape[0], -1)

    losses = []
    for _ in range(num_epochs):
        optimizer.zero_grad()
        loss = loss_function(classifier(flat_images), labels)
        loss.backward()
        optimizer.step()
        # Detach so the stored history does not keep every step's autograd graph alive.
        losses.append(loss.detach())
    return losses


def _add_autoencoder_and_expert(auto_list, expert_list, images, labels, num_epochs):
    """Create an autoencoder and a linear expert under the next free index and fit both on the batch."""
    new_index = len(auto_list)

    autoencoder = Autoencoder(input_dims=28*28, code_dims=CODE_DIM)
    for _ in range(NEW_AUTOENCODER_EPOCH):
        autoencoder.optimize_params(images, images)
    auto_list[new_index] = autoencoder

    classifier = nn.Linear(28*28, 10)
    expert_list[new_index] = classifier
    return _fit_expert(classifier, images, labels, num_epochs)

In [4]:
def test(test_loader, auto_list, expert_list):
    with torch.no_grad():
        test_loss = 0
        correct = 0
        total = 0
        for images, labels, indices in test_loader:
            images = images.view(images.shape[0], -1)  # Flatten images
            best_index = find_best_autoencoders(images, auto_list)
            best_autoencoder = auto_list[best_index]
            classifier = expert_list[best_index]

            # Forward pass
            outputs = classifier(images)
            loss = nn.CrossEntropyLoss()(outputs, labels)
            test_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        ac = 100 * correct / total
        print(f"Test Loss: {test_loss / len(test_loader)}")
        print(f"Accuracy: {ac}%")
        return test_loss/len(test_loader), ac

In [5]:
# BATCH_SIZE = 300

# OUTLIER_THRESHOLD = 0.2*BATCH_SIZE
# NEW_AUTOENCODER_EPOCH = 100
# TRAIN_AUTOENCODER_EPOCH = 10
# CODE_DIM = 350

# NUM_TASK = 5
# RANDOM_SEED = np.random.randint(100)
# #RANDOM_SEED = 42
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# train_loader, test_loader = get_conbined_permute_mnist(NUM_TASK, BATCH_SIZE, RANDOM_SEED)

In [7]:
# auto_list, expert_list, the_loss_history = train(train_loader)

In [8]:
import random

# the autoencoder training
# num_train_batch is the number of batches of data to train the autoencoder
def train_ae(train_loader, num_epochs=10, num_train_batch=1, num_trials=10):
    """Measure how well one freshly trained autoencoder separates its own task from the rest.

    Runs ``num_trials`` independent trials. In each trial:
      1. a batch position is drawn at random from the first half of the loader; the batch
         found there seeds a new autoencoder (``NEW_AUTOENCODER_EPOCH`` updates) and fixes
         the trial's task id (``indicies[0]`` of that batch);
      2. up to ``num_train_batch - 1`` later batches of the same task each add
         ``TRAIN_AUTOENCODER_EPOCH`` updates;
      3. the loader is scanned again and a batch counts as an error when it is accepted
         (outliers <= ``OUTLIER_THRESHOLD``) although it belongs to another task, or
         rejected although it belongs to the trial's task.

    Reads the notebook globals ``CODE_DIM``, ``NEW_AUTOENCODER_EPOCH``,
    ``TRAIN_AUTOENCODER_EPOCH`` and ``OUTLIER_THRESHOLD``.

    Args:
        train_loader: sized iterable of ``(images, labels, indicies)`` batches.
        num_epochs: unused; kept so existing calls keep working.
        num_train_batch: total number of batches the autoencoder is trained on per trial.
        num_trials: number of trials to run (defaults to the previously hard-coded 10).

    Returns:
        ``(count_wrong, count_wrong_vital, auto_list)``:
        ``count_wrong`` maps task id -> list of error counts, one per trial seeded by that task;
        ``count_wrong_vital`` lists, per trial, how many other-task batches were accepted;
        ``auto_list`` maps task id -> the most recent autoencoder seeded by that task.

    Raises:
        ValueError: if ``train_loader`` is empty.
    """
    group_size = len(train_loader)
    if group_size == 0:
        raise ValueError("train_loader is empty: no batch available to seed an autoencoder")

    auto_list = {}
    count_wrong = {}
    count_wrong_vital = []
    for _ in range(num_trials):
        counter_vital = 0
        counter = 0
        # max(1, ...) keeps the range non-empty when the loader holds a single batch.
        random_idx1 = random.randrange(0, max(1, group_size // 2))
        the_ae = None
        taskNo1 = -1  # no real task until the seed batch is reached
        trained_counter = 1

        # Pass 1: seed the autoencoder, then optionally refine it on later same-task batches.
        for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
            images, labels, indicies = data
            taskNo = indicies[0].item()
            if i == random_idx1:
                #initial autoencoder
                the_ae = Autoencoder(input_dims=28*28, code_dims=CODE_DIM)
                auto_list[taskNo] = the_ae
                taskNo1 = taskNo
                count_wrong.setdefault(taskNo1, [])
                for _epoch in range(NEW_AUTOENCODER_EPOCH):
                    the_ae.optimize_params(images, images)
            elif taskNo == taskNo1 and trained_counter < num_train_batch:
                trained_counter += 1
                for _epoch in range(TRAIN_AUTOENCODER_EPOCH):
                    the_ae.optimize_params(images, images)

        # Pass 2: count the batches the trained autoencoder gets wrong.
        for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
            images, labels, indicies = data
            taskNo = indicies[0].item()
            outliers = find_num_of_outliers(images, the_ae)
            accepted = outliers <= OUTLIER_THRESHOLD
            if accepted and taskNo != taskNo1:
                # accepted a batch from a different task
                counter_vital += 1
                counter += 1
            elif not accepted and taskNo == taskNo1:
                # rejected a batch from its own task
                counter += 1

        count_wrong[taskNo1].append(counter)
        count_wrong_vital.append(counter_vital)

    return count_wrong, count_wrong_vital, auto_list

In [None]:
# Samples per batch; passed to the dataset builder and used to scale the outlier threshold.
BATCH_SIZE = 300

# A batch is accepted by an autoencoder while its outlier count stays at or below this (20% of a batch).
OUTLIER_THRESHOLD = 0.2*BATCH_SIZE
# Number of optimize_params() calls made on the batch that seeds a new autoencoder.
NEW_AUTOENCODER_EPOCH = 100
# Number of optimize_params() calls made per further batch on an existing autoencoder.
TRAIN_AUTOENCODER_EPOCH = 10
# Passed as code_dims when an Autoencoder is constructed.
CODE_DIM = 350

# Number of tasks requested from the dataset builder.
NUM_TASK = 5
# No np.random.seed() call is visible in this notebook, so this value changes from run to run;
# switch to the fixed value below to repeat a run.
RANDOM_SEED = np.random.randint(100)
#RANDOM_SEED = 42
# NOTE(review): `device` is not referenced by any code visible in this notebook.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader, test_loader = get_conbined_permute_mnist(NUM_TASK, BATCH_SIZE, RANDOM_SEED)
# num_train_batch setting -> per-task running mean (over the outer repetitions) of the
# mean error count reported by train_ae().
error_rates = {
    1: [],
    2: [],
    5: [],
    10: [],
}
for x in range(10):
  for i in (1,2,5,10):
    count_wrong, count_wrong_vital, auto_list = train_ae(train_loader, num_train_batch=i)
    for t in range(NUM_TASK):
      task_counts = count_wrong.get(t)
      # NOTE(review): a task that never seeded a trial in this call is averaged in as 0 errors,
      # which pulls its running mean down; confirm this is intended.
      e_mean = np.mean([0] if task_counts is None else task_counts)
      if x == 0:
        error_rates[i].append(e_mean)
      else:
        # incremental mean over the x + 1 repetitions seen so far
        error_rates[i][t] = (error_rates[i][t]*x + e_mean)/(x+1)
      print(f'{i} Task{t} {task_counts}, {count_wrong_vital}, {e_mean}')

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./MNIST/PermutedMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:08<00:00, 1207241.87it/s]


Extracting ./MNIST/PermutedMNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/PermutedMNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./MNIST/PermutedMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 130054.86it/s]


Extracting ./MNIST/PermutedMNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/PermutedMNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/PermutedMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:01<00:00, 1055414.62it/s]


Extracting ./MNIST/PermutedMNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/PermutedMNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/PermutedMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4542329.22it/s]


Extracting ./MNIST/PermutedMNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/PermutedMNIST/raw



  self.pid = os.fork()
  self.pid = os.fork()


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

1 Task0 [18, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 9.5
1 Task1 [2, 25], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 13.5
1 Task2 [2, 4], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 3.0
1 Task3 [0, 1], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 0.5
1 Task4 [20, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 10.0


  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

[1, 3, 0, 8, 1, 25, 11, 2, 2, 20]

[1, 0, 10, 0, 6, 6, 6, 0, 0, 0]

[0, 0, 10, 5, 5, 2, 6, 10, 12, 18]



In [None]:
import matplotlib.pyplot as plt

# train_ae() scores every batch of the loader, so the loader length is the denominator
# needed to turn a mean count of misclassified batches into a percentage.
num_eval_batches = len(train_loader)

for i in range(NUM_TASK):
  x = [1, 2, 5, 10]
  # error_rates stores mean counts; the axis is labelled in percent, so convert here.
  y = [ 100 * error_rates[v][i] / num_eval_batches for v in x ]

  plt.figure(figsize=(8, 6))
  plt.plot(x, y, marker='o', linestyle='-', color='b')
  plt.title(f'Error Rates by Number of Data Batches Trained for Task {i} of Permuted MNIST')
  plt.xlabel('Number of Batches of Data')
  plt.ylabel('Error Rate(%)')
  plt.grid(True)
  plt.xticks(x)
  plt.show()

In [None]:
# Accumulators filled by traint(); the outer index is the autoencoder slot (0 or 1).
#   auto_loss[k]  : reduced loss of autoencoder k on batches of its own index group
#   auto_lossx[k] : reduced loss of autoencoder k on batches of the other index group
#   outliers1[k]  : outlier counts of autoencoder k on batches of its own index group
#   outliers2[k]  : outlier counts of autoencoder k on batches of the other index group
auto_loss = [[],[]]
auto_lossx = [[],[]]
outliers1 = [[],[]]
outliers2 = [[],[]]
# ixx = {0,1,2,7,12}
# train autoencoder for two tasks and get loss
def traint(train_loader, num_epochs=10):
    """Record loss and outlier counts of the two autoencoders ``auto_list[0]`` / ``auto_list[1]``.

    No training happens here. Batches whose position is in ``idxx[0]`` are treated as
    belonging to autoencoder 0, those in ``idxx[1]`` to autoencoder 1; all other batches are
    skipped. For every kept batch both autoencoders are evaluated and the results are appended
    to the module-level lists ``auto_loss`` / ``outliers1`` (own autoencoder) and
    ``auto_lossx`` / ``outliers2`` (the other autoencoder).

    NOTE(review): ``idxx`` is not defined anywhere in the visible notebook; it has to exist as
    a global holding two collections of batch positions before this function is called.

    Args:
        train_loader: sized iterable of ``(images, labels, indicies)`` batches.
        num_epochs: unused; kept so existing calls keep working.

    Returns:
        ``(auto_list, expert_list, the_loss_history)`` taken from the notebook globals.
        The last two are ``None`` when ``train()`` has not populated them yet.
    """
    idx1 = idxx[0]
    idx2 = idxx[1]
    #https://stats.stackexchange.com/questions/521461/train-a-model-on-batches-with-multiple-epochs-vs-each-batch-with-multiple-epoch
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        images, labels, indicies = data
        # thei: slot of the autoencoder this batch is attributed to; thex: the other slot
        if i in idx1:
            thei, thex = 0, 1
        elif i in idx2:
            thei, thex = 1, 0
        else:
            continue

        best_autoencoder = auto_list[thei]
        py = best_autoencoder.get_prediction(images)
        ploss = best_autoencoder.get_reduced_loss(py, images)
        an_autoencoder = auto_list[thex]
        py2 = an_autoencoder.get_prediction(images)
        ploss2 = an_autoencoder.get_reduced_loss(py2, images)

        outliers1[thei].append(find_num_of_outliers(images, best_autoencoder))
        outliers2[thex].append(find_num_of_outliers(images, an_autoencoder))

        auto_loss[thei].append(ploss)
        auto_lossx[thex].append(ploss2)

    print(f"BATCH_SIZE:{BATCH_SIZE} NUM_TASK:{NUM_TASK} loss done!")

    # expert_list / the_loss_history only exist after a train() cell has run. Look them up
    # defensively so this diagnostic does not end in a NameError on a fresh kernel.
    notebook_state = globals()
    return auto_list, notebook_state.get("expert_list"), notebook_state.get("the_loss_history")

# Called for its side effect of filling auto_loss / auto_lossx / outliers1 / outliers2;
# the returned tuple is discarded.
_, _, _ = traint(train_loader)

In [None]:
# Each series is drawn against its own 0..len-1 positions (matplotlib's default x when only
# y is given), so the two batch groups do not need to contain the same number of batches.

# Autoencoder 1 (slot 0): loss on its own group vs. on the other group
plt.figure()
plt.plot(auto_loss[0], marker='', label='Task 1')   # line only, no point markers
plt.plot(auto_lossx[0], marker='', label='Task 2')  # line only, no point markers

# Adding title and labels
# plt.ylim(0.01, 0.1)
plt.title('Loss of Autoencoder 1')
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend()

# Display the plot
plt.show()

# Autoencoder 2 (slot 1): auto_lossx[1] holds its loss on Task 1 batches,
# auto_loss[1] its loss on its own Task 2 batches
plt.figure()
plt.plot(auto_lossx[1], marker='', label='Task 1')  # line only, no point markers
plt.plot(auto_loss[1], marker='', label='Task 2')   # line only, no point markers

# Adding title and labels
# plt.ylim(0.01, 0.1)
plt.title('Loss for Autoencoder 2')
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend()

# Display the plot
plt.show()


In [None]:
# Each series is drawn against its own 0..len-1 positions (matplotlib's default x when only
# y is given), so the two batch groups do not need to contain the same number of batches.

# Autoencoder 1 (slot 0): outliers on its own group vs. on the other group
plt.figure()
plt.plot(outliers1[0], marker='', label='Task 1')  # line only, no point markers
plt.plot(outliers2[0], marker='', label='Task 2')  # line only, no point markers
plt.axhline(y=OUTLIER_THRESHOLD, color='r', linestyle='--', label='THRESHOLD')  # red dashed line at the acceptance threshold

# Adding title and labels
# plt.ylim(0.01, 0.1)
plt.title('Outliers of Autoencoder 1')
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend()

# Display the plot
plt.show()

# Autoencoder 2 (slot 1): traint() stores its counts on Task 1 batches in outliers2[1] and
# its counts on its own Task 2 batches in outliers1[1]. This mirrors the loss figure, where
# auto_lossx[1] is Task 1 and auto_loss[1] is Task 2.
plt.figure()
plt.plot(outliers2[1], marker='', label='Task 1')  # line only, no point markers
plt.plot(outliers1[1], marker='', label='Task 2')  # line only, no point markers
plt.axhline(y=OUTLIER_THRESHOLD, color='r', linestyle='--', label='THRESHOLD')  # red dashed line at the acceptance threshold

# Adding title and labels
# plt.ylim(0.01, 0.1)
plt.title('Outliers for Autoencoder 2')
plt.xlabel('Index')
plt.ylabel('Value')
plt.legend()

# Display the plot
plt.show()


In [None]:
# exp_losses = {}
# acs = {}
# test_losses = {}
# for i in range(1,11):
#   NUM_TASK = i
#   RANDOM_SEED = np.random.randint(50)+50
#   print(f"NUM_TASK:{NUM_TASK} Started!")
#   train_loader, test_loader = get_conbined_permute_mnist(NUM_TASK, BATCH_SIZE, RANDOM_SEED)
#   auto_list, expert_list, the_loss_history = train(train_loader)
#   exp_losses[NUM_TASK] = the_loss_history
#   test_loss, ac = test(test_loader, auto_list, expert_list)
#   acs[NUM_TASK] = ac
#   test_losses[NUM_TASK] = test_loss
#   print(f"NUM_TASK:{NUM_TASK} Finished!------------------")

In [None]:
# Settings for the split-and-shuffled MNIST runs; these rebind the globals that
# train(), train_ae() and the plotting cells read.
# Samples per batch; passed to the dataset builder and used to scale the outlier threshold.
BATCH_SIZE = 320

# A batch is accepted by an autoencoder while its outlier count stays at or below this (10% of a batch).
OUTLIER_THRESHOLD = 0.1*BATCH_SIZE
# Number of optimize_params() calls made on the batch that seeds a new autoencoder.
NEW_AUTOENCODER_EPOCH = 200
# Number of optimize_params() calls made per further batch on an existing autoencoder.
TRAIN_AUTOENCODER_EPOCH = 20
# Passed as code_dims when an Autoencoder is constructed.
CODE_DIM = 350

# Number of tasks requested from the dataset builder.
NUM_TASK = 5
# No np.random.seed() call is visible in this notebook, so this value changes from run to run;
# switch to the fixed value below to repeat a run.
RANDOM_SEED = np.random.randint(100)
#RANDOM_SEED = 42
train_loader, test_loader, idx = get_conbined_splitted_and_shuffled_mnist(NUM_TASK, BATCH_SIZE, RANDOM_SEED)

# num_train_batch setting -> per-task running mean (over the outer repetitions) of the
# mean error count reported by train_ae().
error_ratess = {
    1: [],
    2: [],
    5: [],
    10: [],
}
for x in range(10):
  # A tuple rather than a set literal: set iteration order is unspecified, and the
  # permuted-MNIST experiment iterates the same settings in this fixed order.
  for i in (1, 2, 5, 10):
    count_wrong_split, count_wrong_vital_split, auto_list_split = train_ae(train_loader, num_train_batch=i)
    for t in range(NUM_TASK):
      task_counts = count_wrong_split.get(t)
      # NOTE(review): a task that never seeded a trial in this call is averaged in as 0 errors,
      # which pulls its running mean down; confirm this is intended.
      e_mean = np.mean([0] if task_counts is None else task_counts)
      if x == 0:
        error_ratess[i].append(e_mean)
      else:
        # incremental mean over the x + 1 repetitions seen so far
        error_ratess[i][t] = (error_ratess[i][t]*x + e_mean)/(x+1)
      print(x, i, t, error_ratess)
      # print(f'{x} {i} Task{t} {count_wrong_split.get(t)}, {count_wrong_vital_split}, {e_mean}')

In [None]:
import matplotlib.pyplot as plt

# train_ae() scores every batch of the loader, so the loader length (instead of a hard-coded
# 188) is the denominator that turns a mean count of misclassified batches into a rate.
num_eval_batches = len(train_loader)

for i in range(NUM_TASK):
  x = [1, 2, 5, 10]
  # Multiply by 100 so the values match the percent unit on the axis label.
  y = [ 100 * error_ratess[v][i] / num_eval_batches for v in x ]

  plt.figure(figsize=(8, 6))
  plt.plot(x, y, marker='o', linestyle='-', color='b')
  plt.title(f'Error Rates by Number of Data Batches Trained for Task {i} of Splitted&Shuffled MNIST')
  plt.xlabel('Number of Batches of Data')
  plt.ylabel('Error Rate (%)')
  plt.grid(True)
  plt.xticks(x)
  plt.show()

In [None]:
# Settings for the NUM_TASK sweep in the next cell.
# NOTE(review): these are the same values as the split-and-shuffled settings cell earlier
# in the notebook; only RANDOM_SEED is drawn again.
# Samples per batch; passed to the dataset builder and used to scale the outlier threshold.
BATCH_SIZE = 320

# A batch is accepted by an autoencoder while its outlier count stays at or below this (10% of a batch).
OUTLIER_THRESHOLD = 0.1*BATCH_SIZE
# Number of optimize_params() calls made on the batch that seeds a new autoencoder.
NEW_AUTOENCODER_EPOCH = 200
# Number of optimize_params() calls made per further batch on an existing autoencoder.
TRAIN_AUTOENCODER_EPOCH = 20
# Passed as code_dims when an Autoencoder is constructed.
CODE_DIM = 350

# Number of tasks requested from the dataset builder.
NUM_TASK = 5
# No np.random.seed() call is visible in this notebook, so this value changes from run to run.
RANDOM_SEED = np.random.randint(100)
#RANDOM_SEED = 42

In [None]:
# Sweep the number of tasks from 1 to 5. For each setting: build the split-and-shuffled
# loaders, train the autoencoder/expert bank, evaluate it, and store the results by NUM_TASK.
# exp_losses:  NUM_TASK -> loss history returned by train()
# acs:         NUM_TASK -> accuracy (percent) returned by test()
# test_losses: NUM_TASK -> mean test loss returned by test()
exp_losses = {}
acs = {}
test_losses = {}
for i in range(1,6):
  # train() reads the global NUM_TASK for its summary print, so it is rebound per setting.
  NUM_TASK = i
  # A fresh, unseeded draw for every setting.
  RANDOM_SEED = np.random.randint(100)
  print(f"NUM_TASK:{NUM_TASK} Started!")
  # idx is returned by the builder but not used in this cell.
  train_loader, test_loader, idx = get_conbined_splitted_and_shuffled_mnist(NUM_TASK, BATCH_SIZE, RANDOM_SEED)
  auto_list, expert_list, the_loss_history = train(train_loader)
  exp_losses[NUM_TASK] = the_loss_history
  test_loss, ac = test(test_loader, auto_list, expert_list)
  acs[NUM_TASK] = ac
  test_losses[NUM_TASK] = test_loss
  print(f"NUM_TASK:{NUM_TASK} Finished!------------------")

In [None]:
# Print the accuracies for NUM_TASK = 1..5 on one line, each followed by a comma.
for i in range(1, 6):
    ac = acs[i]
    print(ac, end=',')


In [None]:
# train_loader, test_loader, idx = get_conbined_splitted_and_shuffled_mnist(NUM_TASK, BATCH_SIZE, RANDOM_SEED)

In [None]:
# auto_list, expert_list, the_loss_history = train(train_loader)

In [None]:
# test(test_loader, auto_list, expert_list)