## Import Models and Configure Experiment

In [1]:
from model.vcl.model import MFVI_NN
from model.util.vcl_experiment import train, test, run_vcl, run_auto_vcl
from torchvision import datasets, transforms
from model.util.processing import *
import torch
# Global experiment settings shared by every cell below.
SEED = 42              # passed to torch.manual_seed before each experiment
epoch_per_task = 10    # forwarded to run_vcl / run_auto_vcl for every task
batch_size = 256       # mini-batch size used when building the DataLoaders

# Use the GPU whenever PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### Permuted MNIST Experiment

#### Data preprocessing

In [2]:
import torch
# Convert to tensors and normalise with the standard MNIST mean/std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Build the train split first, then the test split.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [3]:
from tqdm import tqdm
def permute_mnist(mnist, perm):
    """Apply one fixed pixel permutation to every image in a dataset.

    Args:
        mnist: iterable of ``(img, target)`` pairs where ``img`` is a tensor
            (for MNIST after ``ToTensor`` this is ``1x28x28``).
        perm: 1-D index tensor of length ``img.numel()`` giving the new
            pixel order.

    Returns:
        A list of ``(permuted_img, target)`` pairs. Each permuted image keeps
        the shape of its input image.
    """
    permuted_data = []
    for img, target in mnist:
        # Flatten, reorder pixels, then restore the image's own shape.
        # `reshape` (unlike `view`) also accepts non-contiguous tensors, and
        # using `img.shape` instead of a hard-coded 1x28x28 lets the helper
        # permute images of any size.
        img_permuted = img.reshape(-1)[perm].reshape(img.shape)
        permuted_data.append((img_permuted, target))
    return permuted_data

# Decode and normalise each MNIST split exactly once. The transform
# (ToTensor + Normalize) is deterministic and uses no RNG, so iterating the
# torchvision dataset again for every task only repeated identical work.
base_train = [(img, target) for img, target in mnist_trainset]
base_test = [(img, target) for img, target in mnist_testset]

# Lists of per-task permuted datasets (one entry per task).
permuted_mnist_train_datasets = []
permuted_mnist_test_datasets = []
torch.manual_seed(SEED)
num_permuted_tasks = 10
for _ in tqdm(range(num_permuted_tasks)):
    # One fixed pixel permutation per task, shared by its train and test split.
    fixed_permutation = torch.randperm(784)
    permuted_mnist_train_datasets.append(permute_mnist(base_train, fixed_permutation))
    permuted_mnist_test_datasets.append(permute_mnist(base_test, fixed_permutation))

# The unpermuted copies are no longer needed; free their memory.
del base_train, base_test

100%|██████████| 10/10 [00:29<00:00,  2.92s/it]


In [4]:
from torch.utils.data import DataLoader
# Uses the notebook-wide `batch_size` from the configuration cell, so changing
# it there affects every experiment consistently.
# Training loaders shuffle each epoch; test loaders keep a fixed order.
pmnist_train_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=True) for ds in permuted_mnist_train_datasets]
pmnist_test_loaders = [DataLoader(ds, batch_size=batch_size, shuffle=False) for ds in permuted_mnist_test_datasets]

#### Running VCL with fixed β (0.01, 1, 100) and auto-VCL

In [5]:
## beta = 0.01
# Five repetitions from a single seed; beta is forwarded to run_vcl.
coreset_size = 0
p_trends_1 = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10,
                    num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        pmnist_train_loaders,
        pmnist_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1e-2,
        device=device,
    )
    p_trends_1.append(trend)

Average Accuracy across 1 tasks: 97.84%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 95.50%, Standard Deviation: 2.16%
Average Accuracy across 3 tasks: 89.42%, Standard Deviation: 8.50%
Average Accuracy across 4 tasks: 82.91%, Standard Deviation: 16.87%
Average Accuracy across 5 tasks: 79.45%, Standard Deviation: 18.97%
Average Accuracy across 6 tasks: 75.32%, Standard Deviation: 22.79%
Average Accuracy across 7 tasks: 70.47%, Standard Deviation: 22.40%
Average Accuracy across 8 tasks: 63.10%, Standard Deviation: 24.65%
Average Accuracy across 9 tasks: 60.58%, Standard Deviation: 25.56%
Average Accuracy across 10 tasks: 60.60%, Standard Deviation: 23.85%
Average Accuracy across 1 tasks: 97.84%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 95.47%, Standard Deviation: 2.22%
Average Accuracy across 3 tasks: 90.20%, Standard Deviation: 7.60%
Average Accuracy across 4 tasks: 80.44%, Standard Deviation: 18.16%
Average Accuracy across 5 tasks: 76.43%, Standard Dev

In [6]:
## beta = 1
# Five repetitions from a single seed.
p_trends_2 = []
torch.manual_seed(SEED)
# Set explicitly so this cell does not depend on the previous cell's value.
coreset_size = 0
for _ in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, like the beta = 0.01 and beta = 100 cells.
    # The model lives on `device`, and run_vcl's own default is not visible here.
    trend = run_vcl(model, pmnist_train_loaders, pmnist_test_loaders,
        optimizer, epoch_per_task, coreset_size, beta=1, device = device)
    p_trends_2.append(trend)

Average Accuracy across 1 tasks: 97.84%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 96.84%, Standard Deviation: 0.74%
Average Accuracy across 3 tasks: 96.07%, Standard Deviation: 0.47%
Average Accuracy across 4 tasks: 95.36%, Standard Deviation: 0.56%
Average Accuracy across 5 tasks: 94.44%, Standard Deviation: 0.99%
Average Accuracy across 6 tasks: 93.50%, Standard Deviation: 1.76%
Average Accuracy across 7 tasks: 92.33%, Standard Deviation: 2.64%
Average Accuracy across 8 tasks: 92.05%, Standard Deviation: 2.86%
Average Accuracy across 9 tasks: 90.83%, Standard Deviation: 3.91%
Average Accuracy across 10 tasks: 89.19%, Standard Deviation: 5.42%
Average Accuracy across 1 tasks: 97.84%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 96.84%, Standard Deviation: 0.68%
Average Accuracy across 3 tasks: 96.24%, Standard Deviation: 0.44%
Average Accuracy across 4 tasks: 95.60%, Standard Deviation: 0.45%
Average Accuracy across 5 tasks: 94.80%, Standard Deviation: 

In [7]:
## beta = 100
# Five repetitions from a single seed; coreset_size comes from the beta = 0.01 cell.
p_trends_3 = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10,
                    num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        pmnist_train_loaders,
        pmnist_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=100,
        device=device,
    )
    p_trends_3.append(trend)

Average Accuracy across 1 tasks: 97.84%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 85.37%, Standard Deviation: 12.47%
Average Accuracy across 3 tasks: 80.95%, Standard Deviation: 11.93%
Average Accuracy across 4 tasks: 77.96%, Standard Deviation: 11.46%
Average Accuracy across 5 tasks: 76.58%, Standard Deviation: 10.62%
Average Accuracy across 6 tasks: 75.71%, Standard Deviation: 9.97%
Average Accuracy across 7 tasks: 74.48%, Standard Deviation: 9.57%
Average Accuracy across 8 tasks: 73.66%, Standard Deviation: 9.20%
Average Accuracy across 9 tasks: 72.74%, Standard Deviation: 8.94%
Average Accuracy across 10 tasks: 71.99%, Standard Deviation: 8.64%
Average Accuracy across 1 tasks: 97.84%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 85.92%, Standard Deviation: 11.85%
Average Accuracy across 3 tasks: 81.32%, Standard Deviation: 11.65%
Average Accuracy across 4 tasks: 78.54%, Standard Deviation: 11.07%
Average Accuracy across 5 tasks: 77.16%, Standard Devi

In [8]:
## auto-VCL: with return_betas=True, run_auto_vcl returns (trend, betas); keep both per repetition
p_trends_4, p_betas = [], []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10,
                    num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, p_beta = run_auto_vcl(
        model, pmnist_train_loaders, pmnist_test_loaders, optimizer,
        epoch_per_task, coreset_size,
        return_betas=True, device=device,
    )
    p_trends_4.append(trend)
    p_betas.append(p_beta)

Average Accuracy across 1 tasks: 97.66%. Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 96.68%. Standard Deviation: 0.76%
Average Accuracy across 3 tasks: 95.83%. Standard Deviation: 0.22%
Average Accuracy across 4 tasks: 95.45%. Standard Deviation: 0.19%
Average Accuracy across 5 tasks: 94.57%. Standard Deviation: 0.38%
Average Accuracy across 6 tasks: 93.80%. Standard Deviation: 0.88%
Average Accuracy across 7 tasks: 92.88%. Standard Deviation: 1.28%
Average Accuracy across 8 tasks: 91.95%. Standard Deviation: 1.64%
Average Accuracy across 9 tasks: 91.39%. Standard Deviation: 2.27%
Average Accuracy across 10 tasks: 91.02%. Standard Deviation: 2.71%
Average Accuracy across 1 tasks: 97.77%. Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 96.70%. Standard Deviation: 0.63%
Average Accuracy across 3 tasks: 95.98%. Standard Deviation: 0.36%
Average Accuracy across 4 tasks: 95.64%. Standard Deviation: 0.36%
Average Accuracy across 5 tasks: 95.09%. Standard Deviation: 

In [9]:
# Compare the three fixed-beta runs with auto-VCL on Permuted MNIST.
p_chart = plot_trends_with_autovcl(
    [p_trends_1, p_trends_2, p_trends_3, p_trends_4],
    p_betas,
    lower=0.5,
)
p_chart

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


![Permuted MNIST results: fixed β = 0.01, 1, 100 vs. auto-VCL](https://raw.githubusercontent.com/lukeyf/variational_continual_learning/main/results/figure/pm_avcl.png)

### Split MNIST Experiment with Custom Targets


In [10]:
# Five binary tasks: each pair is the digit pair for one task (also passed as binary_labels).
tasks = [(0, 1), (8, 7), (9, 4), (6, 2), (3, 5)]
split_alike_train_loaders, split_alike_test_loaders = create_split_dataloaders(
    mnist_trainset,
    mnist_testset,
    tasks,
    batch_size=batch_size,
)

In [11]:
## beta = 0.01
coreset_size = 0
trends_alike_1 = []
torch.manual_seed(SEED)
for _ in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the auto-VCL cells do. The model lives on
    # `device`, and run_vcl's own default is not visible here.
    trend = run_vcl(model, split_alike_train_loaders, split_alike_test_loaders, optimizer,
        epoch_per_task, coreset_size, beta=1e-2, binary_labels = tasks, device = device)
    trends_alike_1.append(trend)

Average Accuracy across 1 tasks: 99.86%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.66%, Standard Deviation: 0.16%
Average Accuracy across 3 tasks: 99.09%, Standard Deviation: 0.34%
Average Accuracy across 4 tasks: 98.23%, Standard Deviation: 1.27%
Average Accuracy across 5 tasks: 96.61%, Standard Deviation: 3.37%
Average Accuracy across 1 tasks: 99.95%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.68%, Standard Deviation: 0.08%
Average Accuracy across 3 tasks: 99.01%, Standard Deviation: 0.26%
Average Accuracy across 4 tasks: 96.51%, Standard Deviation: 3.28%
Average Accuracy across 5 tasks: 94.19%, Standard Deviation: 6.59%
Average Accuracy across 1 tasks: 99.91%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.71%, Standard Deviation: 0.11%
Average Accuracy across 3 tasks: 98.97%, Standard Deviation: 0.39%
Average Accuracy across 4 tasks: 96.44%, Standard Deviation: 2.90%
Average Accuracy across 5 tasks: 88.74%, Standard Deviation: 1

In [12]:
## beta = 1
coreset_size = 0
trends_alike_2 = []
torch.manual_seed(SEED)
for _ in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the auto-VCL cells do. The model lives on
    # `device`, and run_vcl's own default is not visible here.
    trend = run_vcl(model, split_alike_train_loaders, split_alike_test_loaders, optimizer,
        epoch_per_task, coreset_size, beta=1, binary_labels = tasks, device = device)
    trends_alike_2.append(trend)

Average Accuracy across 1 tasks: 99.86%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.63%, Standard Deviation: 0.13%
Average Accuracy across 3 tasks: 98.69%, Standard Deviation: 1.08%
Average Accuracy across 4 tasks: 96.44%, Standard Deviation: 4.83%
Average Accuracy across 5 tasks: 94.69%, Standard Deviation: 8.23%
Average Accuracy across 1 tasks: 99.95%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.78%, Standard Deviation: 0.13%
Average Accuracy across 3 tasks: 98.96%, Standard Deviation: 0.45%
Average Accuracy across 4 tasks: 97.63%, Standard Deviation: 2.44%
Average Accuracy across 5 tasks: 97.00%, Standard Deviation: 3.65%
Average Accuracy across 1 tasks: 99.91%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.63%, Standard Deviation: 0.18%
Average Accuracy across 3 tasks: 99.19%, Standard Deviation: 0.44%
Average Accuracy across 4 tasks: 97.93%, Standard Deviation: 1.70%
Average Accuracy across 5 tasks: 94.30%, Standard Deviation: 7

In [13]:
## beta = 100
coreset_size = 0
trends_alike_3 = []
torch.manual_seed(SEED)
for _ in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the auto-VCL cells do. The model lives on
    # `device`, and run_vcl's own default is not visible here.
    trend = run_vcl(model, split_alike_train_loaders, split_alike_test_loaders,
        optimizer, epoch_per_task, coreset_size,
        beta=1e2, binary_labels = tasks, device = device)
    trends_alike_3.append(trend)

Average Accuracy across 1 tasks: 99.86%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.00%, Standard Deviation: 0.90%
Average Accuracy across 3 tasks: 97.93%, Standard Deviation: 1.62%
Average Accuracy across 4 tasks: 97.17%, Standard Deviation: 1.93%
Average Accuracy across 5 tasks: 96.07%, Standard Deviation: 3.08%
Average Accuracy across 1 tasks: 99.95%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.23%, Standard Deviation: 0.73%
Average Accuracy across 3 tasks: 97.38%, Standard Deviation: 1.82%
Average Accuracy across 4 tasks: 96.73%, Standard Deviation: 2.82%
Average Accuracy across 5 tasks: 96.32%, Standard Deviation: 3.34%
Average Accuracy across 1 tasks: 99.91%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.08%, Standard Deviation: 0.88%
Average Accuracy across 3 tasks: 98.41%, Standard Deviation: 1.20%
Average Accuracy across 4 tasks: 98.02%, Standard Deviation: 1.58%
Average Accuracy across 5 tasks: 97.29%, Standard Deviation: 1

In [14]:
# Auto-VCL on the custom-target split: keep each repetition's trend and returned betas.
trends_alike_4, alike_betas = [], []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, alike_beta = run_auto_vcl(
        model, split_alike_train_loaders, split_alike_test_loaders,
        optimizer, epoch_per_task, coreset_size,
        binary_labels=tasks, return_betas=True, device=device,
    )
    trends_alike_4.append(trend)
    alike_betas.append(alike_beta)

Average Accuracy across 1 tasks: 99.91%. Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.75%. Standard Deviation: 0.10%
Average Accuracy across 3 tasks: 98.82%. Standard Deviation: 0.43%
Average Accuracy across 4 tasks: 98.33%. Standard Deviation: 0.29%
Average Accuracy across 5 tasks: 97.11%. Standard Deviation: 2.07%
Average Accuracy across 1 tasks: 99.91%. Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.70%. Standard Deviation: 0.15%
Average Accuracy across 3 tasks: 99.09%. Standard Deviation: 0.56%
Average Accuracy across 4 tasks: 98.30%. Standard Deviation: 1.00%
Average Accuracy across 5 tasks: 97.75%. Standard Deviation: 1.17%
Average Accuracy across 1 tasks: 99.95%. Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 99.80%. Standard Deviation: 0.10%
Average Accuracy across 3 tasks: 99.34%. Standard Deviation: 0.32%
Average Accuracy across 4 tasks: 98.76%. Standard Deviation: 0.47%
Average Accuracy across 5 tasks: 98.38%. Standard Deviation: 0

In [15]:
# Compare the three fixed-beta runs with auto-VCL on the custom-target split.
alike_chart = plot_trends_with_autovcl(
    [trends_alike_1, trends_alike_2, trends_alike_3, trends_alike_4],
    alike_betas,
    title='Split MNIST Experiment with Custom Targets',
    lower=0.9,
)
alike_chart

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


![Split MNIST (custom targets) results: fixed β = 0.01, 1, 100 vs. auto-VCL](https://raw.githubusercontent.com/lukeyf/variational_continual_learning/main/results/figure/sp_avcl.png)

### Mixed Experiment with Split CIFAR-10 and Split MNIST

In [16]:
# Map CIFAR-10 onto the MNIST input format: one channel, 28x28 pixels,
# values rescaled from [0, 1] to [-1, 1] by Normalize(0.5, 0.5).
transform_cifar = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Build the training split first, then the test split, with the same transform.
cifar_train_dataset, cifar_test_dataset = (
    datasets.CIFAR10(root='./data', train=split_is_train, download=True, transform=transform_cifar)
    for split_is_train in (True, False)
)


Files already downloaded and verified
Files already downloaded and verified


In [17]:
# Class pairs for the standard 5-task split; in the mixed stream each pair
# appears twice in a row (one MNIST task, then one CIFAR-10 task).
tasks = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
mixed_tasks = [pair for pair in tasks for _ in range(2)]

In [18]:
torch.manual_seed(SEED)
# Same class-pair splits for both datasets; MNIST first, then CIFAR-10.
mnist_train_loaders, mnist_test_loaders = create_split_dataloaders(
    mnist_trainset, mnist_testset, tasks, batch_size=batch_size)
cifar_train_loaders, cifar_test_loaders = create_split_dataloaders(
    cifar_train_dataset, cifar_test_dataset, tasks, batch_size=batch_size)

In [19]:
# Interleave the two benchmarks task by task:
# MNIST task 0, CIFAR task 0, MNIST task 1, CIFAR task 1, ...
mixed_train_loaders = [
    (mnist_train_loaders, cifar_train_loaders)[k % 2][k // 2]
    for k in range(len(mixed_tasks))
]
mixed_test_loaders = [
    (mnist_test_loaders, cifar_test_loaders)[k % 2][k // 2]
    for k in range(len(mixed_tasks))
]

In [20]:
## beta = 0.01
coreset_size = 0
torch.manual_seed(SEED)
mixed_trends_1 = []
for _ in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the auto-VCL cell does. The model lives on
    # `device`, and run_vcl's own default is not visible here.
    trend = run_vcl(model, mixed_train_loaders, mixed_test_loaders, optimizer, epoch_per_task,
        coreset_size, beta=1e-2, binary_labels = mixed_tasks, device = device)
    mixed_trends_1.append(trend)

Average Accuracy across 1 tasks: 99.91%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 93.98%, Standard Deviation: 5.93%
Average Accuracy across 3 tasks: 92.99%, Standard Deviation: 9.40%
Average Accuracy across 4 tasks: 85.95%, Standard Deviation: 13.20%
Average Accuracy across 5 tasks: 84.63%, Standard Deviation: 12.84%
Average Accuracy across 6 tasks: 83.66%, Standard Deviation: 13.20%
Average Accuracy across 7 tasks: 74.05%, Standard Deviation: 16.38%
Average Accuracy across 8 tasks: 77.16%, Standard Deviation: 14.78%
Average Accuracy across 9 tasks: 78.19%, Standard Deviation: 16.38%
Average Accuracy across 10 tasks: 79.71%, Standard Deviation: 12.82%
Average Accuracy across 1 tasks: 99.95%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 93.38%, Standard Deviation: 6.43%
Average Accuracy across 3 tasks: 93.69%, Standard Deviation: 8.62%
Average Accuracy across 4 tasks: 86.13%, Standard Deviation: 13.32%
Average Accuracy across 5 tasks: 86.12%, Standard Dev

In [21]:
## beta = 1
# Set explicitly so this cell does not depend on an earlier cell's value.
coreset_size = 0
torch.manual_seed(SEED)
mixed_trends_2 = []
for _ in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the auto-VCL cell does. The model lives on
    # `device`, and run_vcl's own default is not visible here.
    trend = run_vcl(model, mixed_train_loaders, mixed_test_loaders, optimizer, epoch_per_task,
        coreset_size, beta=1, binary_labels = mixed_tasks, device = device)
    mixed_trends_2.append(trend)

Average Accuracy across 1 tasks: 99.91%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 93.70%, Standard Deviation: 6.20%
Average Accuracy across 3 tasks: 93.66%, Standard Deviation: 8.49%
Average Accuracy across 4 tasks: 86.30%, Standard Deviation: 12.33%
Average Accuracy across 5 tasks: 88.25%, Standard Deviation: 12.88%
Average Accuracy across 6 tasks: 83.81%, Standard Deviation: 13.17%
Average Accuracy across 7 tasks: 83.07%, Standard Deviation: 14.39%
Average Accuracy across 8 tasks: 83.35%, Standard Deviation: 13.95%
Average Accuracy across 9 tasks: 84.00%, Standard Deviation: 14.70%
Average Accuracy across 10 tasks: 84.11%, Standard Deviation: 13.21%
Average Accuracy across 1 tasks: 99.95%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 92.90%, Standard Deviation: 7.05%
Average Accuracy across 3 tasks: 94.30%, Standard Deviation: 7.71%
Average Accuracy across 4 tasks: 87.17%, Standard Deviation: 12.01%
Average Accuracy across 5 tasks: 85.07%, Standard Dev

In [22]:
## beta = 100
# Set explicitly so this cell does not depend on an earlier cell's value.
coreset_size = 0
torch.manual_seed(SEED)
mixed_trends_3 = []
for _ in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the auto-VCL cell does. The model lives on
    # `device`, and run_vcl's own default is not visible here.
    trend = run_vcl(model, mixed_train_loaders, mixed_test_loaders, optimizer,
        epoch_per_task, coreset_size, beta=1e2, binary_labels = mixed_tasks, device = device)
    mixed_trends_3.append(trend)

Average Accuracy across 1 tasks: 99.91%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 87.88%, Standard Deviation: 12.08%
Average Accuracy across 3 tasks: 91.22%, Standard Deviation: 10.24%
Average Accuracy across 4 tasks: 84.14%, Standard Deviation: 15.02%
Average Accuracy across 5 tasks: 86.77%, Standard Deviation: 14.98%
Average Accuracy across 6 tasks: 83.40%, Standard Deviation: 15.90%
Average Accuracy across 7 tasks: 85.61%, Standard Deviation: 15.74%
Average Accuracy across 8 tasks: 83.93%, Standard Deviation: 15.29%
Average Accuracy across 9 tasks: 85.29%, Standard Deviation: 14.28%
Average Accuracy across 10 tasks: 84.39%, Standard Deviation: 13.94%
Average Accuracy across 1 tasks: 99.95%, Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 88.45%, Standard Deviation: 11.50%
Average Accuracy across 3 tasks: 91.80%, Standard Deviation: 10.11%
Average Accuracy across 4 tasks: 85.01%, Standard Deviation: 14.37%
Average Accuracy across 5 tasks: 87.32%, Standard

In [23]:
# Auto-VCL on the interleaved MNIST / CIFAR-10 stream: keep each repetition's trend and returned betas.
mixed_trends_4, mixed_betas = [], []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, m_beta = run_auto_vcl(
        model, mixed_train_loaders, mixed_test_loaders,
        optimizer, epoch_per_task, coreset_size,
        binary_labels=mixed_tasks, return_betas=True, device=device,
    )
    mixed_trends_4.append(trend)
    mixed_betas.append(m_beta)

Average Accuracy across 1 tasks: 99.91%. Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 93.60%. Standard Deviation: 6.25%
Average Accuracy across 3 tasks: 95.49%. Standard Deviation: 5.83%
Average Accuracy across 4 tasks: 86.72%. Standard Deviation: 11.99%
Average Accuracy across 5 tasks: 89.45%. Standard Deviation: 11.63%
Average Accuracy across 6 tasks: 86.47%. Standard Deviation: 12.54%
Average Accuracy across 7 tasks: 87.95%. Standard Deviation: 12.61%
Average Accuracy across 8 tasks: 85.74%. Standard Deviation: 13.30%
Average Accuracy across 9 tasks: 86.78%. Standard Deviation: 12.86%
Average Accuracy across 10 tasks: 85.10%. Standard Deviation: 13.35%
Average Accuracy across 1 tasks: 99.91%. Standard Deviation: 0.00%
Average Accuracy across 2 tasks: 93.95%. Standard Deviation: 6.00%
Average Accuracy across 3 tasks: 95.72%. Standard Deviation: 5.50%
Average Accuracy across 4 tasks: 86.00%. Standard Deviation: 11.06%
Average Accuracy across 5 tasks: 88.36%. Standard Dev

In [24]:
# Disabled variant, kept for reference: auto-VCL on the mixed stream with `dor = True`.
# NOTE(review): what `dor` toggles is defined inside run_auto_vcl, which is not visible here. Confirm before re-enabling.
# To match the other run cells, a re-enabled version should also call
# torch.manual_seed(SEED) before the loop and pass `device = device` to run_auto_vcl.
# mixed_trends_5 = []
# for i in range(5):
#     model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#     trend = run_auto_vcl(model, 
#         mixed_train_loaders,
#         mixed_test_loaders,
#         optimizer, 
#         epoch_per_task, 
#         coreset_size,
#         binary_labels = mixed_tasks,
#         dor = True)
#     mixed_trends_5.append(trend)

In [25]:
# Compare the three fixed-beta runs with auto-VCL on the mixed MNIST / CIFAR-10 stream.
mixed_chart = plot_trends_with_autovcl(
    [mixed_trends_1, mixed_trends_2, mixed_trends_3, mixed_trends_4],
    mixed_betas,
    title='Mixed Experiment',
    lower=0.7,
)
mixed_chart

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


![Mixed Split MNIST / Split CIFAR-10 results: fixed β = 0.01, 1, 100 vs. auto-VCL](https://raw.githubusercontent.com/lukeyf/variational_continual_learning/main/results/figure/m_avcl.png)

### Final Visualization

In [26]:
# Place the three experiment charts side by side and enlarge all axis, legend and title text.
(
    (alike_chart | p_chart | mixed_chart)
    .configure_axis(labelFontSize=26, titleFontSize=26)
    .configure_legend(labelFontSize=26, titleFontSize=26)
    .configure_title(fontSize=28)
)

  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)
  col = df[col_name].apply(to_list_if_array, convert_dtype=False)


![Combined figure: Split MNIST (custom targets), Permuted MNIST and mixed experiment side by side](https://raw.githubusercontent.com/lukeyf/variational_continual_learning/main/results/figure/combined_avcl.png)