In [1]:
import altair as alt
import numpy as np
import pandas as pd
import torch
from torchvision import datasets, transforms

from model.vcl.model import MFVI_NN
from model.util.vcl_experiment import train, test, run_vcl, run_auto_vcl
from model.util.processing import *
# ---- Experiment-wide configuration ----
SEED = 42              # re-applied via torch.manual_seed at the start of each experiment cell
epoch_per_task = 10    # passed to run_vcl as the per-task epoch count
batch_size = 256       # mini-batch size shared by every DataLoader in this notebook

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
# Convert PIL images to tensors, then normalise with the (mean,), (std,) pair
# passed to Normalize (presumably the commonly cited MNIST training-set statistics).
mnist_transform_steps = [
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
]
transform = transforms.Compose(mnist_transform_steps)

# Train split first, then test split; both are downloaded into ./data if missing.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

### Permuted MNIST: data preparation

In [3]:
from tqdm.auto import tqdm

# Permuted MNIST: each task applies one fixed pixel permutation to every train and
# test image. permute_mnist (from model.util.processing) presumably reorders the
# flattened pixels of each image according to that permutation.
num_permuted_tasks = 10   # number of permuted tasks to generate
num_pixels = 28 * 28      # flattened MNIST image size (previously the literal 784)

permuted_mnist_train_datasets = []
permuted_mnist_test_datasets = []

# Seed once so the sequence of permutations is reproducible on Restart & Run All.
torch.manual_seed(SEED)
for _ in tqdm(range(num_permuted_tasks), desc='building permuted tasks'):
    # Draw one permutation per task and share it between the train and test
    # splits so both halves of a task are permuted identically.
    fixed_permutation = torch.randperm(num_pixels)
    permuted_mnist_train_datasets.append(permute_mnist(mnist_trainset, fixed_permutation))
    permuted_mnist_test_datasets.append(permute_mnist(mnist_testset, fixed_permutation))

100%|██████████| 10/10 [00:33<00:00,  3.33s/it]


In [4]:
from torch.utils.data import DataLoader

# One DataLoader per permuted task. `batch_size` is defined once in the config
# cell and is intentionally not redefined here, so changing it there takes effect.
# Training loaders shuffle; test loaders keep a fixed order for evaluation.
pmnist_train_loaders = [
    DataLoader(task_ds, batch_size=batch_size, shuffle=True)
    for task_ds in permuted_mnist_train_datasets
]
pmnist_test_loaders = [
    DataLoader(task_ds, batch_size=batch_size, shuffle=False)
    for task_ds in permuted_mnist_test_datasets
]

### Split MNIST Experiment

In [5]:
# Split MNIST: five tasks, each built from a consecutive digit pair
# (0/1, 2/3, 4/5, 6/7, 8/9).
tasks = [(first_digit, first_digit + 1) for first_digit in range(0, 10, 2)]

split_train_loaders, split_test_loaders = create_split_dataloaders(
    mnist_trainset,
    mnist_testset,
    tasks,
    batch_size=batch_size,
)

In [6]:
# Split MNIST with plain VCL (coreset_size = 0): 5 independent runs, each
# starting from a freshly initialised model and optimizer.
torch.manual_seed(SEED)
coreset_size = 0
trends_1 = []
ind_acc_no_core = []
for _run in range(5):
    # Args appear to be (input dim, hidden sizes, outputs per head); single_head is
    # not set here, unlike the Permuted MNIST runs.
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, ind_acc = run_vcl(
        model,
        split_train_loaders,
        split_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        binary_labels=tasks,
        return_individual_acc=True,
        use_prior=True,
        device=device,
    )
    trends_1.append(trend)
    ind_acc_no_core.append(ind_acc)

Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.66%
Average Accuracy across 3 tasks: 99.39%
Average Accuracy across 4 tasks: 99.50%
Average Accuracy across 5 tasks: 97.50%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.68%
Average Accuracy across 3 tasks: 99.56%
Average Accuracy across 4 tasks: 98.79%
Average Accuracy across 5 tasks: 97.61%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.88%
Average Accuracy across 3 tasks: 99.49%
Average Accuracy across 4 tasks: 99.15%
Average Accuracy across 5 tasks: 98.30%
Average Accuracy across 1 tasks: 100.00%
Average Accuracy across 2 tasks: 99.73%
Average Accuracy across 3 tasks: 99.61%
Average Accuracy across 4 tasks: 99.14%
Average Accuracy across 5 tasks: 98.48%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.68%
Average Accuracy across 3 tasks: 99.75%
Average Accuracy across 4 tasks: 99.39%
Average Accuracy across 5 tasks: 97.38%

In [7]:
# Split MNIST with VCL plus a coreset of 40 (same setup as the no-coreset cell
# otherwise): 5 independent runs from fresh models.
torch.manual_seed(SEED)
coreset_size = 40
trends_2 = []
ind_acc_with_core = []
for _run in range(5):
    # Args appear to be (input dim, hidden sizes, outputs per head).
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, ind_acc = run_vcl(
        model,
        split_train_loaders,
        split_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        binary_labels=tasks,
        return_individual_acc=True,
        use_prior=True,
        device=device,
    )
    trends_2.append(trend)
    ind_acc_with_core.append(ind_acc)

Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.83%
Average Accuracy across 3 tasks: 99.69%
Average Accuracy across 4 tasks: 98.62%
Average Accuracy across 5 tasks: 97.43%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.78%
Average Accuracy across 3 tasks: 99.61%
Average Accuracy across 4 tasks: 98.65%
Average Accuracy across 5 tasks: 98.85%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.81%
Average Accuracy across 3 tasks: 99.02%
Average Accuracy across 4 tasks: 98.69%
Average Accuracy across 5 tasks: 97.08%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.61%
Average Accuracy across 3 tasks: 99.66%
Average Accuracy across 4 tasks: 99.10%
Average Accuracy across 5 tasks: 98.74%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.76%
Average Accuracy across 3 tasks: 99.66%
Average Accuracy across 4 tasks: 99.62%
Average Accuracy across 5 tasks: 99.15%


In [20]:
# Average the per-run trends into one curve per method. trends_1 / trends_2 are
# Python lists with one trend per run; np.mean(..., axis=0) averages across runs,
# so every run must report the same number of tasks.
mean_trends_1 = np.mean(trends_1, axis=0)
mean_trends_2 = np.mean(trends_2, axis=0)

# Build each series' x positions from its own length so the two curves can never
# be misaligned. x is the 0-based task index (presumably 0 = after the first task).
x_vcl = list(range(len(mean_trends_1)))
x_vcl_core = list(range(len(mean_trends_2)))

# Long-format frame for Altair: one row per (task index, method).
data = pd.DataFrame({
    'X': x_vcl + x_vcl_core,
    'Mean Trends': np.concatenate([mean_trends_1, mean_trends_2]),
    'Trend Type': ['VCL'] * len(mean_trends_1) + ['VCL + Core'] * len(mean_trends_2),
})

# One tick per task index, covering the longer of the two series.
x_tick_values = x_vcl if len(x_vcl) >= len(x_vcl_core) else x_vcl_core

chart = alt.Chart(data).mark_line().encode(
    x=alt.X('X', axis=alt.Axis(values=x_tick_values, title='Task #')),
    # Values are presumably fractions in [0, 1]; the axis is zoomed to [0.95, 1].
    y=alt.Y('Mean Trends', scale=alt.Scale(domain=[0.95, 1]),
            title='Average accuracy (mean over runs)'),
    color='Trend Type'
).properties(
    width=600,
    height=300,
    title='Split MNIST Experiment'
)

chart.display()


In [32]:
# Per-task accuracy of the first two Split MNIST tasks (digits 0/1 and 2/3) for
# plain VCL, averaged over runs. Each `run` in ind_acc_no_core is a sequence of
# steps; `step[k]` is presumably the accuracy on task k measured at that step.
split_01 = [[step[0] for step in run] for run in ind_acc_no_core]

# Steps reporting only one accuracy (len(step) <= 1) have no task-2/3 entry and
# are skipped, so each run's split_23 list is shorter than its split_01 list.
split_23 = [[step[1] for step in run if len(step) > 1] for run in ind_acc_no_core]

avg_split_01 = np.average(split_01, axis=0)
avg_split_23 = np.average(split_23, axis=0)

In [39]:
# Plain-VCL (no coreset) accuracy on the first two Split MNIST tasks, averaged
# over runs. Task 0/1 has a value from step 0; task 2/3 only from step 1 onward,
# so its x values start at 1.
df_1 = pd.DataFrame({
    'X': range(len(avg_split_01)),
    'Y': avg_split_01,
    'Series': 'Split 0/1'
})
df_2 = pd.DataFrame({
    'X': range(1, 1 + len(avg_split_23)),
    'Y': avg_split_23,
    'Series': 'Split 2/3'
})

# ignore_index gives the combined frame a clean 0..n-1 index instead of
# duplicated row labels from the two source frames.
df_combined = pd.concat([df_1, df_2], ignore_index=True)

# 'X:O' treats the task index as ordinal, so every distinct value gets a tick;
# tickCount only affects quantitative scales and is therefore not set.
chart = alt.Chart(df_combined).mark_line(point=True).encode(
    x=alt.X('X:O', axis=alt.Axis(title='Task #')),
    y=alt.Y('Y:Q', scale=alt.Scale(domain=[0.9, 1]), title='Accuracy'),
    color='Series:N',
    tooltip=['X', 'Y', 'Series']
).properties(
    width=600,
    height=300,
    title='Comparison of Split 0/1 and Split 2/3 (VCL, no coreset)'
)

chart.display()

### Permuted MNIST Experiment

In [40]:
# Permuted MNIST with plain VCL (coreset_size = 0): 5 runs from fresh models.
p_trends_no_core = []
torch.manual_seed(SEED)
coreset_size = 0
for _ in range(5):
    # Single-head network shared by every permuted task; one task per loader.
    model = MFVI_NN(28*28, [100, 100], 10,
                    num_tasks=len(pmnist_train_loaders), single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the Split MNIST cells do, so run_vcl
    # (presumably) places batches on the same device as the model.
    trend = run_vcl(model, pmnist_train_loaders, pmnist_test_loaders,
                    optimizer, epoch_per_task, coreset_size, beta=1, device=device)
    p_trends_no_core.append(trend)

Average Accuracy across 1 tasks: 97.64%
Average Accuracy across 2 tasks: 96.74%
Average Accuracy across 3 tasks: 95.99%
Average Accuracy across 4 tasks: 95.13%
Average Accuracy across 5 tasks: 94.20%
Average Accuracy across 6 tasks: 93.41%
Average Accuracy across 7 tasks: 92.46%
Average Accuracy across 8 tasks: 91.92%
Average Accuracy across 9 tasks: 90.69%
Average Accuracy across 10 tasks: 89.34%
Average Accuracy across 1 tasks: 97.88%
Average Accuracy across 2 tasks: 96.85%
Average Accuracy across 3 tasks: 96.12%
Average Accuracy across 4 tasks: 95.38%
Average Accuracy across 5 tasks: 94.73%
Average Accuracy across 6 tasks: 93.88%
Average Accuracy across 7 tasks: 92.66%
Average Accuracy across 8 tasks: 92.01%
Average Accuracy across 9 tasks: 90.91%
Average Accuracy across 10 tasks: 90.04%
Average Accuracy across 1 tasks: 97.72%
Average Accuracy across 2 tasks: 96.72%
Average Accuracy across 3 tasks: 96.10%
Average Accuracy across 4 tasks: 95.49%
Average Accuracy across 5 tasks: 95.04

In [41]:
# Permuted MNIST with VCL + coreset of 200: 5 runs from fresh models.
p_trends_core_1 = []
torch.manual_seed(SEED)
coreset_size = 200
for _ in range(5):
    # Single-head network shared by every permuted task; one task per loader.
    model = MFVI_NN(28*28, [100, 100], 10,
                    num_tasks=len(pmnist_train_loaders), single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the Split MNIST cells do, so run_vcl
    # (presumably) places batches on the same device as the model.
    trend = run_vcl(model, pmnist_train_loaders, pmnist_test_loaders,
                    optimizer, epoch_per_task, coreset_size, beta=1, device=device)
    p_trends_core_1.append(trend)

Average Accuracy across 1 tasks: 97.64%
Average Accuracy across 2 tasks: 96.58%
Average Accuracy across 3 tasks: 95.61%
Average Accuracy across 4 tasks: 95.19%
Average Accuracy across 5 tasks: 91.72%
Average Accuracy across 6 tasks: 92.89%
Average Accuracy across 7 tasks: 92.79%
Average Accuracy across 8 tasks: 92.72%
Average Accuracy across 9 tasks: 92.40%
Average Accuracy across 10 tasks: 90.98%
Average Accuracy across 1 tasks: 97.66%
Average Accuracy across 2 tasks: 96.46%
Average Accuracy across 3 tasks: 95.57%
Average Accuracy across 4 tasks: 94.71%
Average Accuracy across 5 tasks: 92.71%
Average Accuracy across 6 tasks: 93.44%
Average Accuracy across 7 tasks: 92.97%
Average Accuracy across 8 tasks: 91.57%
Average Accuracy across 9 tasks: 92.28%
Average Accuracy across 10 tasks: 91.57%
Average Accuracy across 1 tasks: 97.62%
Average Accuracy across 2 tasks: 96.53%
Average Accuracy across 3 tasks: 95.20%
Average Accuracy across 4 tasks: 94.60%
Average Accuracy across 5 tasks: 94.00

In [42]:
# Permuted MNIST with VCL + coreset of 400: 5 runs from fresh models.
p_trends_core_2 = []
torch.manual_seed(SEED)
coreset_size = 400
for _ in range(5):
    # Single-head network shared by every permuted task; one task per loader.
    model = MFVI_NN(28*28, [100, 100], 10,
                    num_tasks=len(pmnist_train_loaders), single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the Split MNIST cells do, so run_vcl
    # (presumably) places batches on the same device as the model.
    trend = run_vcl(model, pmnist_train_loaders, pmnist_test_loaders,
                    optimizer, epoch_per_task, coreset_size, beta=1, device=device)
    p_trends_core_2.append(trend)

Average Accuracy across 1 tasks: 97.64%
Average Accuracy across 2 tasks: 95.75%
Average Accuracy across 3 tasks: 95.54%
Average Accuracy across 4 tasks: 95.13%
Average Accuracy across 5 tasks: 94.28%
Average Accuracy across 6 tasks: 94.17%
Average Accuracy across 7 tasks: 93.55%
Average Accuracy across 8 tasks: 92.32%
Average Accuracy across 9 tasks: 93.02%
Average Accuracy across 10 tasks: 92.35%
Average Accuracy across 1 tasks: 97.79%
Average Accuracy across 2 tasks: 95.71%
Average Accuracy across 3 tasks: 95.56%
Average Accuracy across 4 tasks: 95.20%
Average Accuracy across 5 tasks: 94.77%
Average Accuracy across 6 tasks: 93.52%
Average Accuracy across 7 tasks: 93.65%
Average Accuracy across 8 tasks: 93.29%
Average Accuracy across 9 tasks: 92.98%
Average Accuracy across 10 tasks: 92.75%
Average Accuracy across 1 tasks: 97.57%
Average Accuracy across 2 tasks: 96.55%
Average Accuracy across 3 tasks: 95.81%
Average Accuracy across 4 tasks: 94.91%
Average Accuracy across 5 tasks: 94.42

In [43]:
# Permuted MNIST with VCL + coreset of 1000: 5 runs from fresh models.
p_trends_core_3 = []
torch.manual_seed(SEED)
coreset_size = 1000
for _ in range(5):
    # Single-head network shared by every permuted task; one task per loader.
    model = MFVI_NN(28*28, [100, 100], 10,
                    num_tasks=len(pmnist_train_loaders), single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the Split MNIST cells do, so run_vcl
    # (presumably) places batches on the same device as the model.
    trend = run_vcl(model, pmnist_train_loaders, pmnist_test_loaders,
                    optimizer, epoch_per_task, coreset_size, beta=1, device=device)
    p_trends_core_3.append(trend)

Average Accuracy across 1 tasks: 97.64%
Average Accuracy across 2 tasks: 96.14%
Average Accuracy across 3 tasks: 95.92%
Average Accuracy across 4 tasks: 95.01%
Average Accuracy across 5 tasks: 95.01%
Average Accuracy across 6 tasks: 94.08%
Average Accuracy across 7 tasks: 94.29%
Average Accuracy across 8 tasks: 93.76%
Average Accuracy across 9 tasks: 93.70%
Average Accuracy across 10 tasks: 93.27%
Average Accuracy across 1 tasks: 97.83%
Average Accuracy across 2 tasks: 96.47%
Average Accuracy across 3 tasks: 95.58%
Average Accuracy across 4 tasks: 95.19%
Average Accuracy across 5 tasks: 95.14%
Average Accuracy across 6 tasks: 94.42%
Average Accuracy across 7 tasks: 94.21%
Average Accuracy across 8 tasks: 93.82%
Average Accuracy across 9 tasks: 93.58%
Average Accuracy across 10 tasks: 93.05%
Average Accuracy across 1 tasks: 97.67%
Average Accuracy across 2 tasks: 96.54%
Average Accuracy across 3 tasks: 95.88%
Average Accuracy across 4 tasks: 94.88%
Average Accuracy across 5 tasks: 94.62

In [44]:
# Permuted MNIST with VCL + coreset of 2500: 5 runs from fresh models.
p_trends_core_4 = []
torch.manual_seed(SEED)
coreset_size = 2500
for _ in range(5):
    # Single-head network shared by every permuted task; one task per loader.
    model = MFVI_NN(28*28, [100, 100], 10,
                    num_tasks=len(pmnist_train_loaders), single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the Split MNIST cells do, so run_vcl
    # (presumably) places batches on the same device as the model.
    trend = run_vcl(model, pmnist_train_loaders, pmnist_test_loaders,
                    optimizer, epoch_per_task, coreset_size, beta=1, device=device)
    p_trends_core_4.append(trend)

Average Accuracy across 1 tasks: 97.64%
Average Accuracy across 2 tasks: 96.22%
Average Accuracy across 3 tasks: 96.01%
Average Accuracy across 4 tasks: 95.43%
Average Accuracy across 5 tasks: 95.28%
Average Accuracy across 6 tasks: 94.92%
Average Accuracy across 7 tasks: 94.64%
Average Accuracy across 8 tasks: 94.36%
Average Accuracy across 9 tasks: 94.11%
Average Accuracy across 10 tasks: 93.51%
Average Accuracy across 1 tasks: 97.71%
Average Accuracy across 2 tasks: 96.43%
Average Accuracy across 3 tasks: 95.88%
Average Accuracy across 4 tasks: 95.55%
Average Accuracy across 5 tasks: 95.41%
Average Accuracy across 6 tasks: 94.83%
Average Accuracy across 7 tasks: 93.63%
Average Accuracy across 8 tasks: 94.42%
Average Accuracy across 9 tasks: 93.84%
Average Accuracy across 10 tasks: 93.78%
Average Accuracy across 1 tasks: 97.81%
Average Accuracy across 2 tasks: 96.48%
Average Accuracy across 3 tasks: 96.07%
Average Accuracy across 4 tasks: 95.31%
Average Accuracy across 5 tasks: 95.02

In [45]:
# Permuted MNIST with VCL + coreset of 5000: 5 runs from fresh models.
p_trends_core_5 = []
torch.manual_seed(SEED)
coreset_size = 5000
for _ in range(5):
    # Single-head network shared by every permuted task; one task per loader.
    model = MFVI_NN(28*28, [100, 100], 10,
                    num_tasks=len(pmnist_train_loaders), single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Pass `device` explicitly, as the Split MNIST cells do, so run_vcl
    # (presumably) places batches on the same device as the model.
    trend = run_vcl(model, pmnist_train_loaders, pmnist_test_loaders,
                    optimizer, epoch_per_task, coreset_size, beta=1, device=device)
    p_trends_core_5.append(trend)

Average Accuracy across 1 tasks: 97.64%
Average Accuracy across 2 tasks: 96.70%
Average Accuracy across 3 tasks: 95.79%
Average Accuracy across 4 tasks: 95.27%
Average Accuracy across 5 tasks: 95.31%
Average Accuracy across 6 tasks: 94.80%
Average Accuracy across 7 tasks: 94.66%
Average Accuracy across 8 tasks: 94.46%
Average Accuracy across 9 tasks: 94.35%
Average Accuracy across 10 tasks: 94.01%
Average Accuracy across 1 tasks: 97.96%
Average Accuracy across 2 tasks: 96.52%
Average Accuracy across 3 tasks: 95.83%
Average Accuracy across 4 tasks: 95.70%
Average Accuracy across 5 tasks: 95.15%
Average Accuracy across 6 tasks: 95.23%
Average Accuracy across 7 tasks: 94.77%
Average Accuracy across 8 tasks: 94.30%
Average Accuracy across 9 tasks: 94.26%
Average Accuracy across 10 tasks: 94.03%
Average Accuracy across 1 tasks: 97.66%
Average Accuracy across 2 tasks: 96.43%
Average Accuracy across 3 tasks: 96.13%
Average Accuracy across 4 tasks: 95.68%
Average Accuracy across 5 tasks: 95.14

In [60]:
# Permuted MNIST: mean trend (averaged over runs) for each coreset size.
# The dict is the single source of truth for the plotted series and their
# legend order.
trend_runs = {
    'No Core': p_trends_no_core,
    'Coreset 200': p_trends_core_1,
    'Coreset 400': p_trends_core_2,
    'Coreset 1000': p_trends_core_3,
    'Coreset 2500': p_trends_core_4,
    'Coreset 5000': p_trends_core_5,
}
legend_order = list(trend_runs)

# Average each setting's runs element-wise (runs must have equal length).
mean_trends = {label: np.mean(runs, axis=0) for label, runs in trend_runs.items()}

# Long format, one frame per setting: x positions come from each series' own
# length, so a setting with a different number of points cannot misalign others.
df = pd.concat(
    [
        pd.DataFrame({'Index': np.arange(len(values)), 'Value': values, 'Array': label})
        for label, values in mean_trends.items()
    ],
    ignore_index=True,
)

chart = alt.Chart(df).mark_line(point=True).encode(
    x=alt.X('Index:Q', title='Task #'),
    y=alt.Y('Value:Q', scale=alt.Scale(domain=[0.9, 1]),
            title='Average accuracy (mean over runs)'),
    color=alt.Color('Array:N', sort=legend_order, title='Setting'),
    tooltip=['Array', 'Index', 'Value']
).properties(
    width=600,
    height=300,
    title='Comparison of VCL with Different Coreset Sizes'
)

chart.display()
