## Import Models and Configure Experiment

In [2]:
from model.vcl.model import MFVI_NN
from model.util.vcl_experiment import train, test, run_vcl, run_auto_vcl
from torchvision import datasets, transforms
from model.util.processing import *
import torch
# Experiment-wide settings used throughout the notebook.
SEED = 42
epoch_per_task = 10  # training epochs spent on each task
batch_size = 256
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### Permuted MNIST Experiment

#### Data preprocessing

In [3]:
from torchvision import datasets, transforms
import torch
# Convert PIL images to tensors, then normalise the single grayscale channel
# with mean 0.1307 and std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Download (if needed) and load the train split, then the test split, with the same transform.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [4]:
from tqdm import tqdm
def permute_mnist(mnist, perm):
    """Apply a fixed pixel permutation to every image in a dataset.

    Args:
        mnist: iterable of ``(img, target)`` pairs, where ``img`` is a tensor
            with ``perm.numel()`` elements (e.g. a 1x28x28 MNIST image).
        perm: 1-D index tensor (e.g. from ``torch.randperm``); output pixel ``k``
            takes the value of flattened input pixel ``perm[k]``.

    Returns:
        A list of ``(permuted_img, target)`` pairs. Each permuted image has the
        same shape as the image it came from.
    """
    # Reshaping back to img.shape (instead of a hard-coded 1x28x28) keeps the
    # function correct for any image size; reshape also tolerates
    # non-contiguous inputs, where view would raise.
    return [(img.reshape(-1)[perm].reshape(img.shape), target) for img, target in mnist]

# Build 10 permuted-MNIST tasks. Each task draws one pixel permutation (seeded
# for reproducibility) and applies that same permutation to both splits.
permuted_mnist_train_datasets = []
permuted_mnist_test_datasets = []
torch.manual_seed(SEED)
for _task in tqdm(range(10)):
    pixel_perm = torch.randperm(784)
    permuted_mnist_train_datasets.append(permute_mnist(mnist_trainset, pixel_perm))
    permuted_mnist_test_datasets.append(permute_mnist(mnist_testset, pixel_perm))

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:32<00:00,  3.29s/it]


In [5]:
from torch.utils.data import DataLoader
# One DataLoader per permuted task. batch_size comes from the configuration
# cell; redefining it here would silently override any change made there.
pmnist_train_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in permuted_mnist_train_datasets]
pmnist_test_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=False) for ds in permuted_mnist_test_datasets]

#### Running: VCL with β ∈ {0.01, 1, 100}, then AutoVCL

In [6]:
## beta = 0.01
# VCL with beta = 1e-2 on permuted MNIST, repeated over 5 freshly initialised models.
p_trends_1 = []
torch.manual_seed(SEED)
coreset_size = 0  # no coreset; also reused by the following permuted-MNIST cells
for _run in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    p_trends_1.append(run_vcl(
        model, pmnist_train_loaders, pmnist_test_loaders, optimizer,
        epoch_per_task, coreset_size, beta=1e-2, device=device,
    ))

Average Accuracy across 1 tasks: 97.64%
Average Accuracy across 2 tasks: 95.94%
Average Accuracy across 3 tasks: 90.39%
Average Accuracy across 4 tasks: 85.50%
Average Accuracy across 5 tasks: 77.25%
Average Accuracy across 6 tasks: 71.46%
Average Accuracy across 7 tasks: 68.96%
Average Accuracy across 8 tasks: 65.50%
Average Accuracy across 9 tasks: 61.87%
Average Accuracy across 10 tasks: 60.69%
Average Accuracy across 1 tasks: 97.88%
Average Accuracy across 2 tasks: 95.16%
Average Accuracy across 3 tasks: 89.28%
Average Accuracy across 4 tasks: 83.31%
Average Accuracy across 5 tasks: 78.06%
Average Accuracy across 6 tasks: 75.42%
Average Accuracy across 7 tasks: 71.26%
Average Accuracy across 8 tasks: 66.47%
Average Accuracy across 9 tasks: 60.38%
Average Accuracy across 10 tasks: 56.05%
Average Accuracy across 1 tasks: 97.72%
Average Accuracy across 2 tasks: 96.02%
Average Accuracy across 3 tasks: 92.20%
Average Accuracy across 4 tasks: 84.93%
Average Accuracy across 5 tasks: 77.10

In [7]:
## beta = 1
# VCL with beta = 1 on permuted MNIST, repeated over 5 freshly initialised models.
p_trends_2 = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass device explicitly (as the beta=0.01 and beta=100 cells do) so run_vcl
    # works on the device the model was moved to rather than its own default.
    trend = run_vcl(model, pmnist_train_loaders, pmnist_test_loaders,
        optimizer, epoch_per_task, coreset_size, beta=1, device = device)
    p_trends_2.append(trend)

Average Accuracy across 1 tasks: 97.64%
Average Accuracy across 2 tasks: 96.74%
Average Accuracy across 3 tasks: 95.99%
Average Accuracy across 4 tasks: 95.13%
Average Accuracy across 5 tasks: 94.20%
Average Accuracy across 6 tasks: 93.41%
Average Accuracy across 7 tasks: 92.46%
Average Accuracy across 8 tasks: 91.92%
Average Accuracy across 9 tasks: 90.69%
Average Accuracy across 10 tasks: 89.34%
Average Accuracy across 1 tasks: 97.88%
Average Accuracy across 2 tasks: 96.85%
Average Accuracy across 3 tasks: 96.12%
Average Accuracy across 4 tasks: 95.38%
Average Accuracy across 5 tasks: 94.73%
Average Accuracy across 6 tasks: 93.88%
Average Accuracy across 7 tasks: 92.66%
Average Accuracy across 8 tasks: 92.01%
Average Accuracy across 9 tasks: 90.91%
Average Accuracy across 10 tasks: 90.04%
Average Accuracy across 1 tasks: 97.72%
Average Accuracy across 2 tasks: 96.72%
Average Accuracy across 3 tasks: 96.10%
Average Accuracy across 4 tasks: 95.49%
Average Accuracy across 5 tasks: 95.04

In [8]:
# beta = 100: VCL on permuted MNIST, repeated over 5 freshly initialised models.
p_trends_3 = []
torch.manual_seed(SEED)
for _run in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    p_trends_3.append(run_vcl(
        model, pmnist_train_loaders, pmnist_test_loaders, optimizer,
        epoch_per_task, coreset_size, beta=100, device=device,
    ))

Average Accuracy across 1 tasks: 97.64%
Average Accuracy across 2 tasks: 85.25%
Average Accuracy across 3 tasks: 80.89%
Average Accuracy across 4 tasks: 78.05%
Average Accuracy across 5 tasks: 76.51%
Average Accuracy across 6 tasks: 75.92%
Average Accuracy across 7 tasks: 74.48%
Average Accuracy across 8 tasks: 73.90%
Average Accuracy across 9 tasks: 73.00%
Average Accuracy across 10 tasks: 72.20%
Average Accuracy across 1 tasks: 97.88%
Average Accuracy across 2 tasks: 85.93%
Average Accuracy across 3 tasks: 81.30%
Average Accuracy across 4 tasks: 78.52%
Average Accuracy across 5 tasks: 77.02%
Average Accuracy across 6 tasks: 75.85%
Average Accuracy across 7 tasks: 74.84%
Average Accuracy across 8 tasks: 74.12%
Average Accuracy across 9 tasks: 73.22%
Average Accuracy across 10 tasks: 72.47%
Average Accuracy across 1 tasks: 97.72%
Average Accuracy across 2 tasks: 84.95%
Average Accuracy across 3 tasks: 81.02%
Average Accuracy across 4 tasks: 78.39%
Average Accuracy across 5 tasks: 77.19

In [10]:
# AutoVCL on permuted MNIST: no beta is passed, and return_betas=True makes
# run_auto_vcl also return the betas of each run so they can be plotted.
p_trends_4 = []
p_betas = []
torch.manual_seed(SEED)
for _run in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, p_beta = run_auto_vcl(
        model, pmnist_train_loaders, pmnist_test_loaders, optimizer,
        epoch_per_task, coreset_size, return_betas=True, device=device,
    )
    p_trends_4.append(trend)
    p_betas.append(p_beta)

Average Accuracy across 1 tasks: 97.61%
Average Accuracy across 2 tasks: 96.81%
Average Accuracy across 3 tasks: 96.02%
Average Accuracy across 4 tasks: 95.59%
Average Accuracy across 5 tasks: 94.70%
Average Accuracy across 6 tasks: 93.90%
Average Accuracy across 7 tasks: 92.96%
Average Accuracy across 8 tasks: 91.81%
Average Accuracy across 9 tasks: 91.28%
Average Accuracy across 10 tasks: 90.48%
Average Accuracy across 1 tasks: 97.84%
Average Accuracy across 2 tasks: 96.54%
Average Accuracy across 3 tasks: 95.91%
Average Accuracy across 4 tasks: 95.41%
Average Accuracy across 5 tasks: 94.95%
Average Accuracy across 6 tasks: 93.95%
Average Accuracy across 7 tasks: 93.13%
Average Accuracy across 8 tasks: 92.51%
Average Accuracy across 9 tasks: 91.89%
Average Accuracy across 10 tasks: 90.91%
Average Accuracy across 1 tasks: 97.79%
Average Accuracy across 2 tasks: 96.62%
Average Accuracy across 3 tasks: 95.97%
Average Accuracy across 4 tasks: 95.51%
Average Accuracy across 5 tasks: 94.87

In [11]:
# Compare the three fixed-beta VCL runs with AutoVCL on permuted MNIST.
# NOTE(review): `lower` presumably sets the lower y-axis limit — confirm in plot_trends_with_autovcl.
p_chart = plot_trends_with_autovcl(
    [p_trends_1, p_trends_2, p_trends_3, p_trends_4],
    p_betas,
    lower=0.5,
)
p_chart

![Permuted MNIST: VCL (β = 0.01, 1, 100) vs. AutoVCL](results/figure/pm_avcl.png)

### Split MNIST Experiment with Custom Targets


In [12]:
# Five custom binary tasks; each tuple presumably lists the two digit classes
# of one task (the same list is later passed as binary_labels).
tasks = [(0, 1), (8, 7), (9, 4), (6, 2), (3, 5)]
split_alike_train_loaders, split_alike_test_loaders = create_split_dataloaders(
    mnist_trainset, mnist_testset, tasks, batch_size=batch_size
)

In [13]:
# VCL with beta = 1e-2 on the custom split-MNIST tasks, 5 repetitions.
coreset_size = 0
trends_alike_1 = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass device explicitly (as the AutoVCL cells do) so run_vcl works on the
    # device the model was moved to rather than its own default.
    trend = run_vcl(model, split_alike_train_loaders, split_alike_test_loaders, optimizer, epoch_per_task, coreset_size,
        beta=1e-2, binary_labels = tasks, device = device)
    trends_alike_1.append(trend)

Average Accuracy across 1 tasks: 99.86%
Average Accuracy across 2 tasks: 99.61%
Average Accuracy across 3 tasks: 98.96%
Average Accuracy across 4 tasks: 95.33%
Average Accuracy across 5 tasks: 96.65%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.68%
Average Accuracy across 3 tasks: 98.92%
Average Accuracy across 4 tasks: 96.62%
Average Accuracy across 5 tasks: 93.88%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.71%
Average Accuracy across 3 tasks: 98.97%
Average Accuracy across 4 tasks: 96.44%
Average Accuracy across 5 tasks: 89.27%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 98.57%
Average Accuracy across 3 tasks: 93.01%
Average Accuracy across 4 tasks: 91.73%
Average Accuracy across 5 tasks: 90.44%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.64%
Average Accuracy across 3 tasks: 98.00%
Average Accuracy across 4 tasks: 97.70%
Average Accuracy across 5 tasks: 96.07%


In [14]:
# VCL with beta = 1 on the custom split-MNIST tasks, 5 repetitions.
coreset_size = 0
trends_alike_2 = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass device explicitly (as the AutoVCL cells do) so run_vcl works on the
    # device the model was moved to rather than its own default.
    trend = run_vcl(model, split_alike_train_loaders, split_alike_test_loaders, optimizer, epoch_per_task, coreset_size,
        beta=1, binary_labels = tasks, device = device)
    trends_alike_2.append(trend)

Average Accuracy across 1 tasks: 99.86%
Average Accuracy across 2 tasks: 99.58%
Average Accuracy across 3 tasks: 99.00%
Average Accuracy across 4 tasks: 98.19%
Average Accuracy across 5 tasks: 96.03%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.78%
Average Accuracy across 3 tasks: 98.91%
Average Accuracy across 4 tasks: 97.05%
Average Accuracy across 5 tasks: 95.20%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.63%
Average Accuracy across 3 tasks: 99.15%
Average Accuracy across 4 tasks: 97.29%
Average Accuracy across 5 tasks: 95.13%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.68%
Average Accuracy across 3 tasks: 99.04%
Average Accuracy across 4 tasks: 96.57%
Average Accuracy across 5 tasks: 93.66%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.63%
Average Accuracy across 3 tasks: 98.98%
Average Accuracy across 4 tasks: 98.83%
Average Accuracy across 5 tasks: 97.13%


In [15]:
# VCL with beta = 100 on the custom split-MNIST tasks, 5 repetitions.
coreset_size = 0
trends_alike_3 = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass device explicitly (as the AutoVCL cells do) so run_vcl works on the
    # device the model was moved to rather than its own default.
    trend = run_vcl(model, split_alike_train_loaders, split_alike_test_loaders,
        optimizer, epoch_per_task, coreset_size,
        beta=1e2, binary_labels = tasks, device = device)
    trends_alike_3.append(trend)

Average Accuracy across 1 tasks: 99.86%
Average Accuracy across 2 tasks: 99.00%
Average Accuracy across 3 tasks: 97.95%
Average Accuracy across 4 tasks: 97.36%
Average Accuracy across 5 tasks: 96.23%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.25%
Average Accuracy across 3 tasks: 97.33%
Average Accuracy across 4 tasks: 97.08%
Average Accuracy across 5 tasks: 96.50%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.08%
Average Accuracy across 3 tasks: 98.38%
Average Accuracy across 4 tasks: 97.98%
Average Accuracy across 5 tasks: 97.31%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.20%
Average Accuracy across 3 tasks: 98.01%
Average Accuracy across 4 tasks: 96.30%
Average Accuracy across 5 tasks: 96.26%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.13%
Average Accuracy across 3 tasks: 97.95%
Average Accuracy across 4 tasks: 98.10%
Average Accuracy across 5 tasks: 97.74%


In [16]:
# AutoVCL on the custom split-MNIST tasks; the returned betas are kept for plotting.
trends_alike_4 = []
alike_betas = []
torch.manual_seed(SEED)
for _run in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, alike_beta = run_auto_vcl(
        model, split_alike_train_loaders, split_alike_test_loaders, optimizer,
        epoch_per_task, coreset_size,
        binary_labels=tasks, return_betas=True, device=device,
    )
    trends_alike_4.append(trend)
    alike_betas.append(alike_beta)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.76%
Average Accuracy across 3 tasks: 98.89%
Average Accuracy across 4 tasks: 98.78%
Average Accuracy across 5 tasks: 96.85%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.70%
Average Accuracy across 3 tasks: 98.85%
Average Accuracy across 4 tasks: 98.13%
Average Accuracy across 5 tasks: 97.73%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.80%
Average Accuracy across 3 tasks: 99.14%
Average Accuracy across 4 tasks: 98.79%
Average Accuracy across 5 tasks: 98.54%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.73%
Average Accuracy across 3 tasks: 98.99%
Average Accuracy across 4 tasks: 98.84%
Average Accuracy across 5 tasks: 97.95%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.73%
Average Accuracy across 3 tasks: 99.14%
Average Accuracy across 4 tasks: 98.83%
Average Accuracy across 5 tasks: 96.45%


In [17]:
# Compare the three fixed-beta VCL runs with AutoVCL on the custom split-MNIST tasks.
alike_chart = plot_trends_with_autovcl(
    [trends_alike_1, trends_alike_2, trends_alike_3, trends_alike_4],
    alike_betas,
    title='Split MNIST Experiment with Custom Targets',
    lower=0.9,
)
alike_chart

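<!-- NOTE(review): this image path is identical to the Permuted MNIST figure above; the Split MNIST chart was probably saved under a different filename — confirm. -->
![Split MNIST with custom targets: VCL (β = 0.01, 1, 100) vs. AutoVCL](results/figure/pm_avcl.png)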

### Mixed Experiment with Split CIFAR-10 and Split MNIST

In [18]:
# CIFAR-10 is turned into single-channel 28x28 images so its inputs have the
# same 28*28 size the MNIST models (MFVI_NN(28*28, ...)) expect.
transform_cifar = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Load the train split, then the test split, both with the transform above.
cifar_train_dataset, cifar_test_dataset = (
    datasets.CIFAR10(root='./data', train=split_is_train, download=True, transform=transform_cifar)
    for split_is_train in (True, False)
)


Files already downloaded and verified
Files already downloaded and verified


In [19]:
# Class pairs defining five binary split tasks, shared by MNIST and CIFAR-10.
tasks = [(0,1),(2,3),(4,5),(6,7),(8,9)]
# Every pair appears twice in a row: once for the MNIST task and once for the
# CIFAR-10 task that follows it in the interleaved task sequence.
mixed_tasks = [pair for pair in tasks for _ in range(2)]

In [20]:
# Seed first so any randomness inside create_split_dataloaders (if any) is reproducible.
torch.manual_seed(SEED)
mnist_train_loaders, mnist_test_loaders = create_split_dataloaders(
    mnist_trainset, mnist_testset, tasks, batch_size=batch_size)
cifar_train_loaders, cifar_test_loaders = create_split_dataloaders(
    cifar_train_dataset, cifar_test_dataset, tasks, batch_size=batch_size)

In [21]:
# Interleave per-task loaders: MNIST task 0, CIFAR task 0, MNIST task 1, ...
# Even positions k come from MNIST and odd positions from CIFAR-10, both at index k // 2.
mixed_train_loaders = [
    (mnist_train_loaders, cifar_train_loaders)[k % 2][k // 2]
    for k in range(len(mixed_tasks))
]
mixed_test_loaders = [
    (mnist_test_loaders, cifar_test_loaders)[k % 2][k // 2]
    for k in range(len(mixed_tasks))
]

In [22]:
# VCL with beta = 1e-2 on the interleaved MNIST / CIFAR-10 tasks, 5 repetitions.
coreset_size = 0
torch.manual_seed(SEED)
mixed_trends_1 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass device explicitly (as the AutoVCL cells do) so run_vcl works on the
    # device the model was moved to rather than its own default.
    trend = run_vcl(model, mixed_train_loaders, mixed_test_loaders, optimizer, epoch_per_task, coreset_size,
        beta=1e-2, binary_labels = mixed_tasks, device = device)
    mixed_trends_1.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.88%
Average Accuracy across 3 tasks: 93.46%
Average Accuracy across 4 tasks: 85.73%
Average Accuracy across 5 tasks: 78.96%
Average Accuracy across 6 tasks: 81.81%
Average Accuracy across 7 tasks: 64.92%
Average Accuracy across 8 tasks: 74.86%
Average Accuracy across 9 tasks: 77.66%
Average Accuracy across 10 tasks: 79.41%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 93.46%
Average Accuracy across 3 tasks: 93.64%
Average Accuracy across 4 tasks: 85.17%
Average Accuracy across 5 tasks: 84.72%
Average Accuracy across 6 tasks: 80.62%
Average Accuracy across 7 tasks: 78.59%
Average Accuracy across 8 tasks: 76.27%
Average Accuracy across 9 tasks: 78.19%
Average Accuracy across 10 tasks: 78.41%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 93.73%
Average Accuracy across 3 tasks: 93.30%
Average Accuracy across 4 tasks: 86.82%
Average Accuracy across 5 tasks: 86.86

In [23]:
# VCL with beta = 1 on the interleaved MNIST / CIFAR-10 tasks, 5 repetitions.
torch.manual_seed(SEED)
mixed_trends_2 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass device explicitly (as the AutoVCL cells do) so run_vcl works on the
    # device the model was moved to rather than its own default.
    trend = run_vcl(model, mixed_train_loaders, mixed_test_loaders, optimizer, epoch_per_task, coreset_size,
        beta=1, binary_labels = mixed_tasks, device = device)
    mixed_trends_2.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.73%
Average Accuracy across 3 tasks: 93.79%
Average Accuracy across 4 tasks: 85.86%
Average Accuracy across 5 tasks: 86.55%
Average Accuracy across 6 tasks: 83.70%
Average Accuracy across 7 tasks: 85.45%
Average Accuracy across 8 tasks: 84.48%
Average Accuracy across 9 tasks: 84.78%
Average Accuracy across 10 tasks: 84.06%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 92.85%
Average Accuracy across 3 tasks: 94.19%
Average Accuracy across 4 tasks: 87.29%
Average Accuracy across 5 tasks: 88.78%
Average Accuracy across 6 tasks: 83.32%
Average Accuracy across 7 tasks: 83.49%
Average Accuracy across 8 tasks: 80.47%
Average Accuracy across 9 tasks: 82.66%
Average Accuracy across 10 tasks: 80.88%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 93.18%
Average Accuracy across 3 tasks: 93.47%
Average Accuracy across 4 tasks: 86.28%
Average Accuracy across 5 tasks: 88.71

In [24]:
# VCL with beta = 100 on the interleaved MNIST / CIFAR-10 tasks, 5 repetitions.
torch.manual_seed(SEED)
mixed_trends_3 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass device explicitly (as the AutoVCL cells do) so run_vcl works on the
    # device the model was moved to rather than its own default.
    trend = run_vcl(model, mixed_train_loaders, mixed_test_loaders, optimizer,
        epoch_per_task, coreset_size, beta=1e2, binary_labels = mixed_tasks, device = device)
    mixed_trends_3.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 88.13%
Average Accuracy across 3 tasks: 91.60%
Average Accuracy across 4 tasks: 84.49%
Average Accuracy across 5 tasks: 87.14%
Average Accuracy across 6 tasks: 83.97%
Average Accuracy across 7 tasks: 85.79%
Average Accuracy across 8 tasks: 84.36%
Average Accuracy across 9 tasks: 85.40%
Average Accuracy across 10 tasks: 84.34%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 88.28%
Average Accuracy across 3 tasks: 91.52%
Average Accuracy across 4 tasks: 84.68%
Average Accuracy across 5 tasks: 86.86%
Average Accuracy across 6 tasks: 83.60%
Average Accuracy across 7 tasks: 85.60%
Average Accuracy across 8 tasks: 83.73%
Average Accuracy across 9 tasks: 84.37%
Average Accuracy across 10 tasks: 82.85%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 86.15%
Average Accuracy across 3 tasks: 89.40%
Average Accuracy across 4 tasks: 83.87%
Average Accuracy across 5 tasks: 86.76

In [25]:
# AutoVCL on the interleaved MNIST / CIFAR-10 tasks; returned betas are kept for plotting.
mixed_trends_4 = []
mixed_betas = []
torch.manual_seed(SEED)
for _run in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, m_beta = run_auto_vcl(
        model, mixed_train_loaders, mixed_test_loaders, optimizer,
        epoch_per_task, coreset_size,
        binary_labels=mixed_tasks, return_betas=True, device=device,
    )
    mixed_trends_4.append(trend)
    mixed_betas.append(m_beta)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.93%
Average Accuracy across 3 tasks: 95.52%
Average Accuracy across 4 tasks: 86.24%
Average Accuracy across 5 tasks: 88.86%
Average Accuracy across 6 tasks: 84.91%
Average Accuracy across 7 tasks: 86.89%
Average Accuracy across 8 tasks: 85.10%
Average Accuracy across 9 tasks: 86.58%
Average Accuracy across 10 tasks: 85.52%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.65%
Average Accuracy across 3 tasks: 95.41%
Average Accuracy across 4 tasks: 86.44%
Average Accuracy across 5 tasks: 89.08%
Average Accuracy across 6 tasks: 78.69%
Average Accuracy across 7 tasks: 80.68%
Average Accuracy across 8 tasks: 83.22%
Average Accuracy across 9 tasks: 86.28%
Average Accuracy across 10 tasks: 85.87%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.65%
Average Accuracy across 3 tasks: 95.12%
Average Accuracy across 4 tasks: 87.34%
Average Accuracy across 5 tasks: 89.05

In [26]:
# Disabled experiment, kept for reference: AutoVCL on the mixed tasks with `dor = True`.
# NOTE(review): unlike the other repetition cells, this loop does not call
# torch.manual_seed(SEED), does not pass device=device, and does not request
# return_betas; the meaning of `dor` is not visible here — confirm before re-enabling.
# mixed_trends_5 = []
# for i in range(5):
#     model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#     trend = run_auto_vcl(model, 
#         mixed_train_loaders,
#         mixed_test_loaders,
#         optimizer, 
#         epoch_per_task, 
#         coreset_size,
#         binary_labels = mixed_tasks,
#         dor = True)
#     mixed_trends_5.append(trend)

In [27]:
# Compare the three fixed-beta VCL runs with AutoVCL on the mixed task sequence.
mixed_chart = plot_trends_with_autovcl(
    [mixed_trends_1, mixed_trends_2, mixed_trends_3, mixed_trends_4],
    mixed_betas,
    title='Mixed Experiment',
    lower=0.7,
)
mixed_chart

![Mixed Split MNIST / Split CIFAR-10: VCL (β = 0.01, 1, 100) vs. AutoVCL](results/figure/m_avcl.png)

### Final Visualization

In [28]:
# Place the three experiment charts side by side and enlarge all fonts for the final figure.
(
    (alike_chart | p_chart | mixed_chart)
    .configure_axis(labelFontSize=26, titleFontSize=26)
    .configure_legend(labelFontSize=26, titleFontSize=26)
    .configure_title(fontSize=28)
)

![Split MNIST, Permuted MNIST and Mixed experiments side by side: VCL vs. AutoVCL](results/figure/combined_avcl.png)