In [56]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

SEED = 123
torch.manual_seed(SEED)  # Set the random seed for PyTorch

class MFVI_Layer(nn.Module):
    """Linear layer with a factorised (mean-field) Gaussian posterior.

    Each weight and bias has a learnable mean (``W_m``, ``b_m``) and a
    learnable log-variance (``W_logv``, ``b_logv``).  ``forward`` draws the
    pre-activations from the Gaussian those parameters induce, and
    ``kl_divergence`` returns the KL term between the current posterior and a
    stored Gaussian prior (``prior_*``).

    The priors are registered as *non-persistent buffers*: they follow the
    module across ``.to(device)`` / ``.cuda()`` calls but are deliberately kept
    out of ``state_dict()`` so the checkpoint layout is the same as before.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
    """

    _PRIOR_NAMES = ("prior_W_m", "prior_b_m", "prior_W_logv", "prior_b_logv")

    def __init__(self, in_features, out_features):
        super(MFVI_Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Mean and log variance parameters (values are filled in by reset_parameters)
        self.W_m = nn.Parameter(torch.empty(out_features, in_features))
        self.b_m = nn.Parameter(torch.empty(out_features))
        self.W_logv = nn.Parameter(torch.empty(out_features, in_features))
        self.b_logv = nn.Parameter(torch.empty(out_features))

        # Prior tensors start as None and are populated by set_priors().
        for prior_name in self._PRIOR_NAMES:
            self.register_buffer(prior_name, None, persistent=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise means ~ N(0, 0.1^2), log-variances to -6, then reset the priors."""
        nn.init.normal_(self.W_m, 0, 0.1)
        nn.init.normal_(self.b_m, 0, 0.1)
        nn.init.constant_(self.W_logv, -6.0)
        nn.init.constant_(self.b_logv, -6.0)
        self.set_priors()

    def set_priors(self):
        """Snapshot the current posterior parameters as the new prior."""
        self.prior_W_m = self.W_m.detach().clone()
        self.prior_b_m = self.b_m.detach().clone()
        self.prior_W_logv = self.W_logv.detach().clone()
        self.prior_b_logv = self.b_logv.detach().clone()

    def forward(self, x, sample=True):
        """Apply the layer.

        Args:
            x: Input whose last dimension is ``in_features``.
            sample: When False *and* the module is in eval mode, return the
                posterior-mean activation.  In every other case a noisy
                activation is sampled.

        Returns:
            Tensor of activations with last dimension ``out_features``.
        """
        act_mu = F.linear(x, self.W_m, self.b_m)

        # Deterministic path: no need to compute the activation variance.
        if not (self.training or sample):
            return act_mu

        W_std = torch.exp(0.5 * self.W_logv)
        b_std = torch.exp(0.5 * self.b_logv)

        # Variance of the pre-activation under independent Gaussian weights;
        # the 1e-16 keeps the sqrt (and its gradient) finite.
        act_var = 1e-16 + F.linear(x.pow(2), W_std.pow(2)) + b_std.pow(2)
        act_std = torch.sqrt(act_var)

        eps = torch.randn_like(act_mu)
        return act_mu + act_std * eps

    def kl_divergence(self, device):
        """Return KL(posterior || prior) averaged over the layer's trainable parameters.

        Args:
            device: Device the prior tensors should live on before the
                computation (kept for backward compatibility with callers).

        Returns:
            Scalar tensor: summed element-wise Gaussian KL of weights and
            biases, divided by the number of trainable parameter elements.
        """
        self.update_prior_device(device)
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        # Convert log variance to standard deviation for posterior and prior
        W_std_post = torch.exp(0.5 * self.W_logv)
        b_std_post = torch.exp(0.5 * self.b_logv)
        W_std_prior = torch.exp(0.5 * self.prior_W_logv)
        b_std_prior = torch.exp(0.5 * self.prior_b_logv)

        # Closed-form KL between two univariate Gaussians, per weight
        kl_div_W = torch.log(W_std_prior / W_std_post) + \
                ((W_std_post**2 + (self.W_m - self.prior_W_m)**2) / (2 * W_std_prior**2)) - 0.5
        kl_div_W = kl_div_W.sum()

        # Same expression for the biases
        kl_div_b = torch.log(b_std_prior / b_std_post) + \
                ((b_std_post**2 + (self.b_m - self.prior_b_m)**2) / (2 * b_std_prior**2)) - 0.5
        kl_div_b = kl_div_b.sum()

        total_kl = kl_div_W + kl_div_b

        return total_kl/num_params

    def update_prior_device(self, device):
        """Move the prior tensors to ``device`` (a no-op when they are already there)."""
        self.prior_W_m = self.prior_W_m.to(device)
        self.prior_b_m = self.prior_b_m.to(device)
        self.prior_W_logv = self.prior_W_logv.to(device)
        self.prior_b_logv = self.prior_b_logv.to(device)

class MFVI_NN(nn.Module):
    """Multi-layer perceptron built from ``MFVI_Layer`` blocks.

    A stack of shared hidden layers (each followed by ReLU) feeds one output
    head per task, or a single shared head when ``single_head`` is True.

    Args:
        input_size: Width of the flattened input.
        hidden_sizes: List with the width of every shared hidden layer.
        output_size: Number of outputs produced by each head.
        num_tasks: Number of heads to create when ``single_head`` is False.
        single_head: Use one head (stored under key ``"0"``) for all tasks.
    """

    def __init__(self, input_size, hidden_sizes, output_size, num_tasks=1, single_head=False):
        super(MFVI_NN, self).__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.num_tasks = num_tasks
        self.single_head = single_head
        self.layers = nn.ModuleList()
        self.task_specific_layers = nn.ModuleDict()  # heads keyed by str(task_id)

        # Shared trunk: consecutive width pairs become one layer each
        layer_widths = [input_size] + hidden_sizes
        for fan_in, fan_out in zip(layer_widths[:-1], layer_widths[1:]):
            self.layers.append(MFVI_Layer(fan_in, fan_out))

        # Output heads: exactly one in single-head mode, otherwise one per task
        head_count = 1 if single_head else num_tasks
        for head_idx in range(head_count):
            self.task_specific_layers[str(head_idx)] = MFVI_Layer(layer_widths[-1], output_size)

    def forward(self, x, task_id=0, sample=True):
        """Run the shared trunk, then the head selected by ``task_id``.

        ``task_id`` is ignored in single-head mode.  ``sample`` is forwarded
        unchanged to every layer.
        """
        hidden = x
        for shared_layer in self.layers:
            hidden = F.relu(shared_layer(hidden, sample=sample))
        head_key = "0" if self.single_head else str(task_id)
        return self.task_specific_layers[head_key](hidden, sample=sample)

    def kl_divergence(self):
        """Sum the per-layer KL terms of the shared layers followed by all heads."""
        total = 0
        for shared_layer in self.layers:
            total = total + shared_layer.kl_divergence(next(self.parameters()).device)
        for head in self.task_specific_layers.values():
            total = total + head.kl_divergence(next(self.parameters()).device)
        return total

    def update_priors(self):
        """Make every layer adopt its current posterior as the new prior."""
        every_layer = list(self.layers) + list(self.task_specific_layers.values())
        for layer in every_layer:
            layer.set_priors()


In [6]:
import torch.nn.functional as F

def train(model, trainloader, optimizer, epoch, device, kl_weight=1, task_id = 0, binary_label = None):
    """Run one optimisation pass over ``trainloader``.

    Args:
        model: Network exposing ``forward(x, task_id=..., sample=...)`` and
            ``kl_divergence()``.
        trainloader: Iterable of ``(data, target)`` batches.
        optimizer: Optimizer holding ``model``'s parameters.
        epoch: Current epoch number (not used by the computation).
        device: Device the batches are moved to.
        kl_weight: Multiplier applied to the KL term.
        task_id: Head/task index.  For ``task_id == 0`` the KL term is skipped
            and only the cross-entropy is minimised.
        binary_label: Optional pair of class ids.  When given, targets are
            remapped to 1 where they equal ``binary_label[0]`` and 0 otherwise.

    Returns:
        None.  The model parameters are updated in place.
    """
    model.train()
    for data, target in trainloader:
        data, target = data.to(device), target.to(device)
        data = data.view(data.size(0), -1)  # Flatten the images
        optimizer.zero_grad()
        output = model(data, sample=True, task_id = task_id)
        if binary_label is not None:
            target = (target == binary_label[0]).long()
        reconstruction_loss = F.cross_entropy(output, target, reduction='mean')

        if task_id == 0:
            # First task: likelihood term only
            loss = reconstruction_loss
        else:
            loss = reconstruction_loss + model.kl_divergence() * kl_weight
        loss.backward()
        optimizer.step()

def test(model, testloader, device, task_id = 0, binary_label = None):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)  # Flatten the images
            output = model(data, sample=False, task_id = task_id)
            if binary_label != None:
                target = (target == binary_label[0]).long()

            # Use cross_entropy for test loss calculation, sum up batch loss
            test_loss += F.cross_entropy(output, target, reduction='mean').item()
            pred = output.argmax(dim=1, keepdim=True)  # Get the index of the max probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(testloader.dataset)
    # print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(testloader.dataset)} ({100. * correct / len(testloader.dataset):.0f}%)\n')
    return test_loss, correct / len(testloader.dataset)


## Permuted MNIST

In [7]:
from torchvision import datasets, transforms

# Preprocessing applied to every image: convert to a tensor, then normalise
# with the constants below (presumably the MNIST training mean/std - confirm).
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))
                                ])

# Download (if needed) and load the MNIST train and test splits under ./data.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [8]:
from tqdm import tqdm
def permute_mnist(mnist, perm):
    """Return a list of ``(image, target)`` pairs with every image's pixels reordered by ``perm``.

    Each image is flattened, indexed with ``perm`` and reshaped to 1x28x28;
    targets are passed through untouched.
    """
    return [
        (image.view(-1)[perm].view(1, 28, 28), label)
        for image, label in mnist
    ]

# Build the Permuted-MNIST benchmark: ten tasks, each defined by its own
# pixel permutation applied identically to the train and test splits.
permuted_mnist_train_datasets = []
permuted_mnist_test_datasets = []
# Re-seed so the sequence of permutations drawn below is reproducible.
torch.manual_seed(SEED)
# Generate 10 permuted datasets
for _ in tqdm(range(10)):
    # Draw one permutation of the 784 pixel positions for this task
    fixed_permutation = torch.randperm(784)
    
    # Apply the same permutation to the train and test datasets
    permuted_train = permute_mnist(mnist_trainset, fixed_permutation)
    permuted_test = permute_mnist(mnist_testset, fixed_permutation)
    
    # Store the permuted datasets (kept fully in memory as Python lists)
    permuted_mnist_train_datasets.append(permuted_train)
    permuted_mnist_test_datasets.append(permuted_test)

100%|██████████| 10/10 [00:34<00:00,  3.41s/it]


In [9]:

from torch.utils.data import DataLoader


# Mini-batch size; this global is also read inside run_vcl / run_auto_vcl.
batch_size = 256
# One loader per permuted task: training data is shuffled, test data is not.
pmnist_train_loaders = [DataLoader(m, batch_size=batch_size, shuffle=True) for m in permuted_mnist_train_datasets]
pmnist_test_loaders = [DataLoader(m, batch_size=batch_size, shuffle=False) for m in permuted_mnist_test_datasets]

In [10]:
from tqdm import tqdm
# Pick the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [11]:
from torch.utils.data import DataLoader, ConcatDataset, Subset
def random_coreset(dataset, coreset_size):
    """
    Randomly selects a subset of data points to form a coreset.

    Sampling uses NumPy's global random state (``np.random.choice``), so seed
    NumPy as well as PyTorch if reproducible coresets are required.

    Args:
    - dataset: Any sized, indexable dataset to sample from.
    - coreset_size (int): The number of samples to include in the coreset;
      clamped to ``len(dataset)``.

    Returns:
    - coreset (torch.utils.data.Subset): A view of ``dataset`` restricted to
      the sampled indices (distinct, in sampled order).
    """
    # Ensure coreset size does not exceed dataset size
    coreset_size = min(coreset_size, len(dataset))

    # Randomly select indices without replacement
    coreset_indices = np.random.choice(len(dataset), size=coreset_size, replace=False)

    # Plain Python ints index every kind of dataset (lists, Subsets, tensors)
    return Subset(dataset, coreset_indices.tolist())



In [96]:
from tqdm import tqdm
# Number of epochs per training pass; the experiment cells pass this to run_vcl / run_auto_vcl.
epoch_per_task = 10

def run_vcl(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size=0, beta=1, binary_labels = None):
    """Train ``model`` on a sequence of tasks with a fixed KL weight and track accuracy.

    For every task the function (1) optionally replays the coresets collected
    so far on ``model``, (2) trains ``model`` on the new task, (3) refreshes
    the priors, (4) clones the model into a separate prediction model that is
    fine-tuned on all coresets, and (5) evaluates the prediction model on every
    task seen so far.

    Relies on the notebook globals ``device`` and ``batch_size``.

    Args:
        model: The continual-learning network (exposes ``update_priors`` and
            the constructor attributes used to clone it).
        train_loaders: Per-task training loaders.
        test_loaders: Per-task test loaders, aligned with ``train_loaders``.
        optimizer: Optimizer bound to ``model``'s parameters.
        epoch_per_task: Epochs for every training / replay pass.
        coreset_size: Samples kept per task for replay; 0 disables replay.
        beta: KL weight passed to ``train``.
        binary_labels: Optional per-task class pairs forwarded to
            ``train`` / ``test``; defaults to ``None`` for every task.

    Returns:
        List with the average accuracy over all seen tasks after each task.
    """
    task_pairs = list(zip(train_loaders, test_loaders))
    if binary_labels is None:
        # One entry per task actually iterated (model.num_tasks may be 1 for
        # single-head models even when several tasks are trained).
        binary_labels = [None] * len(task_pairs)

    ave_acc_trend_rc = []
    prev_test_loaders = []
    coresets = []
    for task_id, (train_loader, test_loader) in enumerate(task_pairs):
        # Replay previously stored coresets before learning the new task
        if coreset_size > 0:
            for i, coreset in enumerate(coresets):
                coreset_loader = DataLoader(coreset, batch_size=batch_size, shuffle=True)
                for epoch in range(1, epoch_per_task + 1):
                    train(model, coreset_loader, optimizer, epoch, device, beta, task_id=i, binary_label=binary_labels[i])
                model.update_priors()

        for epoch in range(1, epoch_per_task + 1):
            train(model, train_loader, optimizer, epoch, device, beta, task_id=task_id, binary_label=binary_labels[task_id])
        model.update_priors()

        # Separate copy used only for prediction, so fine-tuning on the
        # coresets never touches the propagated posterior held by `model`.
        prediction_model = type(model)(model.input_size, model.hidden_sizes, model.output_size, model.num_tasks, model.single_head).to(device)
        prediction_model.load_state_dict(model.state_dict())
        # Priors are not part of the state dict; without this the copy would
        # regularise towards its own random initialisation.
        prediction_model.update_priors()

        if coreset_size > 0:
            coresets.append(random_coreset(train_loader.dataset, coreset_size))
            # The copy needs its own optimizer: `optimizer` only holds the
            # parameters of `model` and would leave the copy untouched.
            prediction_optimizer = type(optimizer)(prediction_model.parameters(), **optimizer.defaults)
            for i, coreset in enumerate(coresets):
                coreset_loader = DataLoader(coreset, batch_size=batch_size, shuffle=True)
                for epoch in range(1, epoch_per_task + 1):
                    train(prediction_model, coreset_loader, prediction_optimizer, epoch, device, beta, task_id=i, binary_label=binary_labels[i])

        # Evaluate on every task seen so far
        prev_test_loaders.append(test_loader)
        task_accuracies_rc = []
        for task_num, ptl in enumerate(prev_test_loaders):
            _, task_accuracy = test(prediction_model, ptl, device, task_id=task_num, binary_label=binary_labels[task_num])
            task_accuracies_rc.append(task_accuracy)
        average_accuracy = sum(task_accuracies_rc) / len(task_accuracies_rc)
        ave_acc_trend_rc.append(average_accuracy)
        print(f'Average Accuracy across {len(task_accuracies_rc)} tasks: {average_accuracy*100:.2f}%')
    return ave_acc_trend_rc

def scale_similarity(sim, a, b, steepness=20):
    """Squash ``sim`` into (0, 1) with a logistic curve centred on the midpoint of ``[a, b]``.

    Args:
        sim: Raw similarity value (scalar or NumPy array).
        a: Lower end of the expected range of ``sim``.
        b: Upper end of the expected range of ``sim``.
        steepness: Slope of the logistic curve; the default of 20 reproduces
            the previously hard-coded behaviour.

    Returns:
        ``1 / (1 + exp(-steepness * (sim - (a + b) / 2)))``.
    """
    midpoint = (a + b) / 2
    return 1 / (1 + np.exp(-steepness * (sim - midpoint)))

def _replay_schedule(coresets, task_difficulties, beta_star, dor):
    """Return ``(task_id, coreset, kl_weight)`` triples in the order they should be replayed.

    With ``dor`` the coresets are visited from the most to the least difficult
    task and each one gets the KL weight ``beta_star * 10 ** (2 - difficulty)``;
    otherwise they are visited in task order with weight ``beta_star``.
    """
    if dor:
        ranked = sorted(enumerate(zip(task_difficulties, coresets)), key=lambda x: x[1][0], reverse=True)
        return [(idx, coreset, beta_star*10**(2-difficulty)) for idx, (difficulty, coreset) in ranked]
    return [(idx, coreset, beta_star) for idx, coreset in enumerate(coresets)]


def run_auto_vcl(model, train_loaders,test_loaders, optimizer, epoch_per_task, coreset_size, 
                beta_star=1, raw_training_epoch = 1,raw_train_size = 1000, 
                binary_labels = None, dor = False, return_betas = False):
    """Continual training where the KL weight of each new task is chosen automatically.

    Per task the function:
      1. estimates the task difficulty from the mean test accuracy of ten
         freshly initialised models trained briefly on a random subset;
      2. for tasks after the first, estimates similarity to the previous task
         from the accuracy of the current model's previous head on the new
         test data, and derives ``beta`` from difficulty and similarity;
      3. optionally replays the stored coresets on ``model``;
      4. trains ``model`` on the new task with ``beta`` and refreshes priors;
      5. clones ``model`` into a prediction model, fine-tunes that copy on
         all coresets and evaluates it on every task seen so far.

    Relies on the notebook globals ``device`` and ``batch_size``.

    Args:
        model: The continual-learning network.
        train_loaders: Per-task training loaders.
        test_loaders: Per-task test loaders, aligned with ``train_loaders``.
        optimizer: Optimizer bound to ``model``'s parameters.
        epoch_per_task: Epochs for every training / replay pass.
        coreset_size: Samples kept per task for replay; 0 disables replay.
        beta_star: Base KL weight (used as-is for the first task).
        raw_training_epoch: Epochs used for each difficulty-probe model.
        raw_train_size: Training samples used for each difficulty-probe model.
        binary_labels: Optional per-task class pairs forwarded to
            ``train`` / ``test``; defaults to ``None`` for every task.
        dor: Replay coresets in order of decreasing task difficulty with
            difficulty-dependent KL weights.
        return_betas: Also return the list of automatically chosen betas
            (one per task after the first).

    Returns:
        The list of average accuracies after each task, or the tuple
        ``(accuracies, betas)`` when ``return_betas`` is True.
    """
    task_pairs = list(zip(train_loaders, test_loaders))
    if binary_labels is None:
        # Indexed by task id, so it must have one entry per task
        binary_labels = [None] * len(task_pairs)

    task_difficulties = []
    ave_acc_trend_rc = []
    prev_test_loaders = []
    coresets = []
    betas = []
    diff_gaps = []
    for task_id, (train_loader, test_loader) in enumerate(task_pairs):
        # --- 1. difficulty probe: average accuracy of ten throw-away models ---
        raw_acc = []
        for _ in range(10):
            raw_model = type(model)(model.input_size, model.hidden_sizes, model.output_size, model.num_tasks, model.single_head).to(device)

            raw_trainset = random_coreset(train_loader.dataset, raw_train_size)
            raw_train_loader = DataLoader(raw_trainset, batch_size=batch_size, shuffle=True)
            raw_optimizer = torch.optim.Adam(raw_model.parameters(), lr=0.001)

            for epoch in range(1, raw_training_epoch + 1):
                train(raw_model, raw_train_loader, raw_optimizer, epoch, device, beta_star, task_id=0, binary_label=binary_labels[task_id])
            _, probe_acc = test(raw_model, test_loader, device, task_id=0, binary_label=binary_labels[task_id])
            raw_acc.append(probe_acc)

        acc_simple_train = np.mean(raw_acc)
        print(acc_simple_train)

        # Difficulty in [0, 1]: 0 = probe is perfect, 1 = probe is at chance level
        dummy_pred = 1/model.output_size
        curr_difficulty = min(max((1-(acc_simple_train - dummy_pred)/(1-dummy_pred)),0),1)

        # --- 2. choose beta for this task ---
        if task_id > 0:
            print(task_id-1)
            # Accuracy of the previous task's head on the new task's test data
            _, raw_pred = test(model, test_loader, device,task_id=task_id-1, binary_label=binary_labels[task_id])
            print(raw_pred,'raw_pred')

            similarity = scale_similarity(np.abs(raw_pred-dummy_pred), 0, 1-dummy_pred)
            # Largest difficulty among the tasks learned so far
            prev_difficulty = np.max(task_difficulties)
            print(prev_difficulty, curr_difficulty,similarity,'all')
            diff_gaps.append(np.abs(prev_difficulty-curr_difficulty))
            avg_diff_gaps = np.mean(diff_gaps)
            # NOTE(review): only curr_difficulty is divided by (1 + avg_gap * task_id);
            # kept exactly as written - confirm the intended grouping.
            beta = beta_star*np.exp((prev_difficulty-curr_difficulty/(1+avg_diff_gaps*task_id))*5+similarity*5)
            betas.append(beta)
            print(beta,'beta')
        else:
            beta = beta_star

        # --- 3. replay stored coresets on the propagated model ---
        if coreset_size > 0 and task_id>0:
            schedule = _replay_schedule(coresets, task_difficulties, beta_star, dor)
            if dor:
                sorted_difficulties = [task_difficulties[idx] for idx, _, _ in schedule]
                replay_betas = [kl_weight for _, _, kl_weight in schedule]
                print(sorted_difficulties, replay_betas)
            for replay_task, coreset, kl_weight in schedule:
                coreset_loader = DataLoader(coreset, batch_size=batch_size, shuffle=True)
                for epoch in range(1, epoch_per_task + 1):
                    train(model, coreset_loader, optimizer, epoch, device, kl_weight, task_id=replay_task, binary_label=binary_labels[replay_task])
                model.update_priors()

        # --- 4. learn the new task ---
        for epoch in range(1, epoch_per_task + 1):
            train(model, train_loader, optimizer, epoch, device, beta, task_id=task_id, binary_label=binary_labels[task_id])
        model.update_priors()

        # --- 5. prediction copy: fine-tune on coresets, then evaluate ---
        prediction_model = type(model)(model.input_size, model.hidden_sizes, model.output_size, model.num_tasks, model.single_head).to(device)
        prediction_model.load_state_dict(model.state_dict())
        # Priors are not part of the state dict; align them with the loaded posterior.
        prediction_model.update_priors()

        task_difficulties.append(curr_difficulty)
        if coreset_size > 0:
            coresets.append(random_coreset(train_loader.dataset, coreset_size))
            # Fine-tune the *copy* (with its own optimizer) so that the model
            # that is evaluated has seen the coresets and the propagated
            # posterior in `model` is left untouched.
            prediction_optimizer = type(optimizer)(prediction_model.parameters(), **optimizer.defaults)
            for replay_task, coreset, kl_weight in _replay_schedule(coresets, task_difficulties, beta_star, dor):
                coreset_loader = DataLoader(coreset, batch_size=batch_size, shuffle=True)
                for epoch in range(1, epoch_per_task + 1):
                    train(prediction_model, coreset_loader, prediction_optimizer, epoch, device, kl_weight, task_id=replay_task, binary_label=binary_labels[replay_task])

        prev_test_loaders.append(test_loader)

        task_accuracies_rc = []
        for task_num, ptl in enumerate(prev_test_loaders):
            _, task_accuracy = test(prediction_model, ptl, device, task_id=task_num, binary_label=binary_labels[task_num])
            task_accuracies_rc.append(task_accuracy)

        average_accuracy = sum(task_accuracies_rc) / len(task_accuracies_rc)
        ave_acc_trend_rc.append(average_accuracy)
        print(f'Average Accuracy across {len(task_accuracies_rc)} tasks: {average_accuracy*100:.2f}%')
    if return_betas:
        return ave_acc_trend_rc, betas

    return ave_acc_trend_rc

In [13]:
def create_split_task(dataset, classes):
    """
    Restrict a labelled dataset to the samples whose target is in ``classes``.

    Parameters:
    - dataset: Iterable, indexable dataset yielding ``(sample, target)`` pairs.
    - classes: Collection of targets to keep (e.g. a pair of digit labels).

    Returns:
    - A Subset of ``dataset`` holding the matching samples in their original order.
    """
    kept_positions = []
    for position, (_sample, label) in enumerate(dataset):
        if label in classes:
            kept_positions.append(position)
    return Subset(dataset, kept_positions)

def create_split_dataloaders(train_dataset, test_dataset, tasks, batch_size=256):
    """
    Build one train loader and one test loader per class-subset task.

    Parameters:
    - train_dataset: The training dataset to split.
    - test_dataset: The test dataset to split.
    - tasks: Iterable of class collections; each entry defines one task.
    - batch_size: The batch size for every DataLoader.

    Returns:
    - ``(train_loaders, test_loaders)``: two lists aligned with ``tasks``.
      Training loaders shuffle their data, test loaders do not.
    """
    train_loaders, test_loaders = [], []
    for class_group in tasks:
        task_train_loader = DataLoader(
            create_split_task(train_dataset, class_group), batch_size=batch_size, shuffle=True
        )
        task_test_loader = DataLoader(
            create_split_task(test_dataset, class_group), batch_size=batch_size, shuffle=False
        )
        train_loaders.append(task_train_loader)
        test_loaders.append(task_test_loader)

    return train_loaders, test_loaders

In [305]:
def plot_trends(trends, betas,title='Accuracy Trends in the Permuated MNIST Experiment', lower=0.7):
    """Plot mean accuracy curves for four methods with the mean log-beta of AutoVCL as bars.

    Args:
        trends: Sequence of four collections of runs, in the order
            beta = 0.01, beta = 1, beta = 100, AutoVCL.  Each collection is
            averaged over axis 0 (runs), leaving one value per task count.
        betas: Collection of per-run beta lists from AutoVCL; averaged over
            axis 0 and log-transformed.  It has one entry fewer than the
            accuracy curves because no beta is chosen for the first task.
        title: Chart title.
        lower: Lower bound of the accuracy axis.

    Returns:
        None.  The layered Altair chart is displayed.
    """
    import pandas as pd
    import numpy as np
    import altair as alt

    # Insert a None at the beginning of the AutoBeta list to align with the second point onwards
    adjusted_auto_beta = [None] + list(np.log(np.mean(betas, axis=0)))

    # Wide table: one row per number of tasks, one column per series
    df = pd.DataFrame({
        '# of tasks': range(1,len(trends[0][0])+1),
        'beta = 0.01': np.mean(trends[0], axis=0),
        'beta = 1': np.mean(trends[1], axis=0),
        'beta = 100': np.mean(trends[2], axis=0),
        'AutoVCL': np.mean(trends[3], axis=0),
        'AutoBeta': adjusted_auto_beta  # Use the adjusted AutoBeta data
    })
    # Left edge of the x domain; starts past task 1, for which no beta exists
    axis_start = 1.5
    # Convert the DataFrame to long format for accuracies
    df_long_acc = df.melt('# of tasks', var_name='Series', value_name='Values', value_vars=['beta = 0.01', 'beta = 1', 'beta = 100', 'AutoVCL'])
    legend_order = ['beta = 0.01', 'beta = 1', 'beta = 100', 'AutoVCL']
    # Plot for accuracies
    acc_chart = alt.Chart(df_long_acc).mark_line(point=True).encode(
        x=alt.X('# of tasks:Q', title='# tasks', 
            scale=alt.Scale(domain=[axis_start, len(trends[0][0])]),  # Adjust domain slightly for visual alignment
            axis=alt.Axis(values=list(range(1, len(trends[0][0])+1)))
           ),
        y=alt.Y('Values:Q', scale=alt.Scale(domain=[lower, 1]), axis=alt.Axis(grid=True), title='Accuracy'),
        color=alt.Color('Series:N', sort=legend_order,scale=alt.Scale(scheme='category10'), legend=alt.Legend(title="Model")),
        tooltip=['# of tasks', 'Values', 'Series']
    )
    # Plot for AutoBeta
    # NOTE(review): the tick values here stop one short of those used by acc_chart - confirm intended.
    loss_chart = alt.Chart(df).mark_bar(opacity=0.3, color='skyblue', width=40).encode(
        x=alt.X('# of tasks:Q', title='# tasks', 
            scale=alt.Scale(domain=[axis_start, len(trends[0][0])]),  # Adjust domain slightly for visual alignment
            axis=alt.Axis(values=list(range(1, len(trends[0][0]))))
           ),
        y=alt.Y('AutoBeta:Q', title='log(AutoBeta)',scale=alt.Scale(domain=[-6, 6]), axis=alt.Axis(labelColor='skyblue',titleFontSize=18,titleColor='skyblue')),
        tooltip=['# of tasks', 'AutoBeta']
    )
 
    # Combine the charts with independent scales for y-axes
    chart = alt.layer(acc_chart, loss_chart).resolve_scale(y='independent').properties(
        width=800,
        height=400,
        title=title
    ).configure_axis(
        labelFontSize=15,
        titleFontSize=20
    ).configure_legend(
        labelFontSize=15,
        titleFontSize=15
    ).configure_title(
        fontSize=24
    )

    chart.display()

## Reproduce

In [None]:
# Display the epochs-per-task setting currently held by the kernel.
epoch_per_task

## Split MNIST without coreset

In [123]:
# Split MNIST: five binary tasks over consecutive digit pairs.
tasks = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
# One (train, test) loader pair per task, restricted to that task's two digits.
split_train_loaders, split_test_loaders = create_split_dataloaders(mnist_trainset, mnist_testset, tasks, batch_size=256)

In [124]:

# Baseline: VCL with fixed KL weight beta = 1e-2 and no coreset replay,
# repeated for five independently initialised multi-head models.
coreset_size = 0
trends_1 = []
for i in range(5):
    # Fresh model (one 2-way head per task) and optimizer for every run
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model,  split_train_loaders,split_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e-2, binary_labels = tasks)
    trends_1.append(trend)

Average Accuracy across 1 tasks: 99.86%
Average Accuracy across 2 tasks: 99.64%
Average Accuracy across 3 tasks: 99.32%
Average Accuracy across 4 tasks: 96.82%
Average Accuracy across 5 tasks: 90.68%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.78%
Average Accuracy across 3 tasks: 99.43%
Average Accuracy across 4 tasks: 97.52%
Average Accuracy across 5 tasks: 96.74%
Average Accuracy across 1 tasks: 99.72%
Average Accuracy across 2 tasks: 99.71%
Average Accuracy across 3 tasks: 99.54%
Average Accuracy across 4 tasks: 94.69%
Average Accuracy across 5 tasks: 97.85%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.68%
Average Accuracy across 3 tasks: 99.50%
Average Accuracy across 4 tasks: 93.90%
Average Accuracy across 5 tasks: 96.39%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.78%
Average Accuracy across 3 tasks: 99.53%
Average Accuracy across 4 tasks: 98.02%
Average Accuracy across 5 tasks: 96.90%


In [125]:
# Baseline: VCL with fixed KL weight beta = 1, five independent runs.
# NOTE: `coreset_size` is not set in this cell; it uses the value left by an earlier cell.
trends_2 = []
for i in range(5):
    # Fresh model and optimizer for every run
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model,  split_train_loaders,split_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1, binary_labels = tasks)
    trends_2.append(trend)

Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.71%
Average Accuracy across 3 tasks: 99.72%
Average Accuracy across 4 tasks: 99.37%
Average Accuracy across 5 tasks: 98.27%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.71%
Average Accuracy across 3 tasks: 99.28%
Average Accuracy across 4 tasks: 98.63%
Average Accuracy across 5 tasks: 98.84%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.68%
Average Accuracy across 3 tasks: 99.57%
Average Accuracy across 4 tasks: 99.15%
Average Accuracy across 5 tasks: 98.91%
Average Accuracy across 1 tasks: 99.76%
Average Accuracy across 2 tasks: 99.73%
Average Accuracy across 3 tasks: 99.64%
Average Accuracy across 4 tasks: 99.36%
Average Accuracy across 5 tasks: 97.53%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.78%
Average Accuracy across 3 tasks: 99.64%
Average Accuracy across 4 tasks: 97.19%
Average Accuracy across 5 tasks: 98.74%


In [126]:

# Baseline: VCL with fixed KL weight beta = 1e2, five independent runs.
# NOTE: `coreset_size` is not set in this cell; it uses the value left by an earlier cell.
trends_3 = []
for i in range(5):
    # Fresh model and optimizer for every run
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, split_train_loaders,split_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e2, binary_labels = tasks)
    trends_3.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 98.97%
Average Accuracy across 3 tasks: 98.94%
Average Accuracy across 4 tasks: 98.98%
Average Accuracy across 5 tasks: 98.20%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 98.85%
Average Accuracy across 3 tasks: 98.96%
Average Accuracy across 4 tasks: 99.01%
Average Accuracy across 5 tasks: 97.98%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.22%
Average Accuracy across 3 tasks: 99.19%
Average Accuracy across 4 tasks: 98.65%
Average Accuracy across 5 tasks: 96.46%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 98.88%
Average Accuracy across 3 tasks: 98.91%
Average Accuracy across 4 tasks: 98.88%
Average Accuracy across 5 tasks: 98.07%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.07%
Average Accuracy across 3 tasks: 98.99%
Average Accuracy across 4 tasks: 98.98%
Average Accuracy across 5 tasks: 98.27%


In [310]:
# AutoVCL on Split MNIST: beta is chosen per task by run_auto_vcl; five
# independent runs.  Both the accuracy trend and the chosen betas are kept.
# NOTE: `coreset_size` is not set in this cell; it uses the value left by an earlier cell.
trends_4 = []
betas_split = []
for i in range(5):
    # Fresh model and optimizer for every run
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, beta = run_auto_vcl(model, 
        split_train_loaders,
        split_test_loaders,
        optimizer, 
        epoch_per_task, 
        coreset_size,
        binary_labels = tasks, 
        return_betas = True)
    trends_4.append(trend)
    betas_split.append(beta)

0.994420803782506
Average Accuracy across 1 tasks: 99.95%
0.9128305582761997
0
0.45396669931439765 raw_pred
0.011158392434988063 0.1743388834476005 0.01663724888539962 all
0.5431184870548946 beta
Average Accuracy across 2 tasks: 99.76%
0.9288153681963713
1
0.7086446104589115 raw_pred
0.1743388834476005 0.14236926360725732 0.3042569596635395 all
6.033738116294315 beta
Average Accuracy across 3 tasks: 99.57%
0.9665659617321248
2
0.34390735146022156 raw_pred
0.1743388834476005 0.06686807653575033 0.1326018546001556 all
3.5895837378074518 beta
Average Accuracy across 4 tasks: 99.43%
0.8985375693393847
3
0.7256681795259707 raw_pred
0.1743388834476005 0.20292486132123067 0.3806862631191963 all
7.485306388009886 beta
Average Accuracy across 5 tasks: 98.90%
0.9940898345153663
Average Accuracy across 1 tasks: 100.00%
0.8992654260528894
0
0.4720861900097943 raw_pred
0.01182033096926749 0.2014691478942212 0.011638570781885236 all
0.4821671674710661 beta
Average Accuracy across 2 tasks: 99.78%
0.9

In [311]:
import matplotlib.pyplot as plt

# Plot the four trend collections together with the betas gathered from the
# run_auto_vcl runs. trends_1..trends_3 are produced by earlier cells.
# NOTE(review): `lower` presumably sets the lower axis bound -- confirm in plot_trends.
plot_trends(
    [trends_1, trends_2, trends_3, trends_4],
    betas_split,
    lower=0.9,
)

## Intentionally alike


In [117]:
# Digit pairs used as the binary tasks of this section. Note that this rebinds
# the notebook-level `tasks` used by the cells that follow.
tasks = [
    (0, 1),
    (8, 7),
    (9, 4),
    (6, 2),
    (3, 5),
]
# Build the train/test loaders for these tasks from the MNIST datasets
# loaded in an earlier cell.
split_alike_train_loaders, split_alike_test_loaders = create_split_dataloaders(
    mnist_trainset,
    mnist_testset,
    tasks,
    batch_size=batch_size,
)

In [118]:
# run_vcl on the split-alike loaders with beta fixed at 1e-2 and coreset_size 0.
# The seed is set once up front; the five runs then continue on the same RNG stream.
torch.manual_seed(SEED)
coreset_size = 0
trends_alike_1 = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend = run_vcl(
        model,
        split_alike_train_loaders,
        split_alike_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=0.01,
        binary_labels=tasks,
    )
    trends_alike_1.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.73%
Average Accuracy across 3 tasks: 98.90%
Average Accuracy across 4 tasks: 97.60%
Average Accuracy across 5 tasks: 92.08%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.68%
Average Accuracy across 3 tasks: 98.83%
Average Accuracy across 4 tasks: 96.32%
Average Accuracy across 5 tasks: 93.46%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.75%
Average Accuracy across 3 tasks: 99.43%
Average Accuracy across 4 tasks: 98.52%
Average Accuracy across 5 tasks: 86.86%
Average Accuracy across 1 tasks: 100.00%
Average Accuracy across 2 tasks: 99.83%
Average Accuracy across 3 tasks: 98.78%
Average Accuracy across 4 tasks: 97.23%
Average Accuracy across 5 tasks: 92.14%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.78%
Average Accuracy across 3 tasks: 99.00%
Average Accuracy across 4 tasks: 94.47%
Average Accuracy across 5 tasks: 91.44%

In [119]:
# Same split-alike experiment as the beta=1e-2 cell, but with beta=1.
torch.manual_seed(SEED)
coreset_size = 0
trends_alike_2 = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend = run_vcl(
        model,
        split_alike_train_loaders,
        split_alike_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1,
        binary_labels=tasks,
    )
    trends_alike_2.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.75%
Average Accuracy across 3 tasks: 98.64%
Average Accuracy across 4 tasks: 98.60%
Average Accuracy across 5 tasks: 95.57%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.70%
Average Accuracy across 3 tasks: 99.32%
Average Accuracy across 4 tasks: 98.75%
Average Accuracy across 5 tasks: 94.41%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.80%
Average Accuracy across 3 tasks: 99.40%
Average Accuracy across 4 tasks: 98.79%
Average Accuracy across 5 tasks: 95.35%
Average Accuracy across 1 tasks: 100.00%
Average Accuracy across 2 tasks: 99.70%
Average Accuracy across 3 tasks: 98.84%
Average Accuracy across 4 tasks: 98.69%
Average Accuracy across 5 tasks: 94.29%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.75%
Average Accuracy across 3 tasks: 99.19%
Average Accuracy across 4 tasks: 97.86%
Average Accuracy across 5 tasks: 95.28%

In [120]:
# Same split-alike experiment again, this time with beta=1e2.
torch.manual_seed(SEED)
coreset_size = 0
trends_alike_3 = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend = run_vcl(
        model,
        split_alike_train_loaders,
        split_alike_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=100.0,
        binary_labels=tasks,
    )
    trends_alike_3.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.13%
Average Accuracy across 3 tasks: 97.91%
Average Accuracy across 4 tasks: 97.90%
Average Accuracy across 5 tasks: 97.74%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.38%
Average Accuracy across 3 tasks: 98.03%
Average Accuracy across 4 tasks: 97.68%
Average Accuracy across 5 tasks: 95.66%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.15%
Average Accuracy across 3 tasks: 98.01%
Average Accuracy across 4 tasks: 97.38%
Average Accuracy across 5 tasks: 96.97%
Average Accuracy across 1 tasks: 100.00%
Average Accuracy across 2 tasks: 99.15%
Average Accuracy across 3 tasks: 97.89%
Average Accuracy across 4 tasks: 97.68%
Average Accuracy across 5 tasks: 97.94%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.13%
Average Accuracy across 3 tasks: 97.91%
Average Accuracy across 4 tasks: 97.42%
Average Accuracy across 5 tasks: 97.18%

In [121]:
# Split-alike experiment using run_auto_vcl instead of a fixed beta. With
# return_betas=True each call's result is unpacked into a trend and a beta.
# NOTE(review): coreset_size is not set here; it is whatever an earlier cell assigned.
torch.manual_seed(SEED)
trends_alike_4 = []
alike_betas = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend, alike_beta = run_auto_vcl(
        model,
        split_alike_train_loaders,
        split_alike_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        binary_labels=tasks,
        return_betas=True,
    )
    trends_alike_4.append(trend)
    alike_betas.append(alike_beta)

0.995839243498818
Average Accuracy across 1 tasks: 100.00%
0.9428071928071928
0
0.42357642357642356 raw_pred
0.008321513002363945 0.11438561438561434 0.030133279405254842 all
0.7226646049900293 beta
Average Accuracy across 2 tasks: 99.78%
0.7797086891009544
1
0.4535409342039176 raw_pred
0.11438561438561434 0.44058262179809127 0.01677713800406201 all
0.4138494857055068 beta
Average Accuracy across 3 tasks: 99.11%
0.9182914572864324
2
0.21809045226130652 raw_pred
0.44058262179809127 0.1634170854271353 0.6543444090561353 all
147.9163856087284 beta
Average Accuracy across 4 tasks: 98.56%
0.8257097791798106
3
0.2870662460567823 raw_pred
0.44058262179809127 0.3485804416403788 0.32271448795659746 all
17.27006311166509 beta
Average Accuracy across 5 tasks: 97.96%
0.994468085106383
Average Accuracy across 1 tasks: 99.91%
0.9303696303696304
0
0.44955044955044954 raw_pred
0.01106382978723408 0.13926073926073923 0.01814570565892329 all
0.6242932337685918 beta
Average Accuracy across 2 tasks: 99.75

array([  0.64009456,   0.44214465, 153.48629473,  22.31186373])

In [307]:
# Plot the three fixed-beta split-alike trend collections next to the
# run_auto_vcl one, passing along the betas collected from those runs.
plot_trends(
    [trends_alike_1, trends_alike_2, trends_alike_3, trends_alike_4],
    alike_betas,
    lower=0.9,
)

## Different difficulties

In [21]:
# Preprocessing pipeline for CIFAR-10 images, applied in this order:
# single-channel grayscale -> resize to 28x28 -> tensor -> normalise with
# mean 0.5 and std 0.5.
transform_cifar = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# CIFAR-10 train split first, then the test split; both are stored under
# ./data (downloaded if needed) and use the transform defined above.
cifar_train_dataset, cifar_test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform_cifar)
    for is_train in (True, False)
)



Files already downloaded and verified
Files already downloaded and verified


In [22]:
# Label pairs for the five binary tasks (rebinds the notebook-level `tasks`).
tasks = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
# Every pair appears twice in a row, giving a task list twice as long as `tasks`.
mixed_tasks = [pair for pair in tasks for _ in range(2)]

In [23]:
# Seed before the loaders are built, then create per-task train/test loaders
# for MNIST and for the grayscale CIFAR-10 datasets using the same task pairs.
torch.manual_seed(SEED)
mnist_train_loaders, mnist_test_loaders = create_split_dataloaders(
    mnist_trainset,
    mnist_testset,
    tasks,
    batch_size=batch_size,
)
cifar_train_loaders, cifar_test_loaders = create_split_dataloaders(
    cifar_train_dataset,
    cifar_test_dataset,
    tasks,
    batch_size=batch_size,
)

In [24]:
# Interleave the two loader lists position by position: even positions take the
# MNIST loader, odd positions the CIFAR loader, and position // 2 selects the
# task within that list. One entry is produced per element of mixed_tasks.
mixed_train_loaders = [
    (mnist_train_loaders, cifar_train_loaders)[slot % 2][slot // 2]
    for slot in range(len(mixed_tasks))
]
mixed_test_loaders = [
    (mnist_test_loaders, cifar_test_loaders)[slot % 2][slot // 2]
    for slot in range(len(mixed_tasks))
]

In [74]:
# run_vcl on the interleaved MNIST/CIFAR loaders with beta fixed at 1e-2 and
# coreset_size 0. Seeded once; the five runs share the same RNG stream.
coreset_size = 0
torch.manual_seed(SEED)
mixed_trends_1 = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend = run_vcl(
        model,
        mixed_train_loaders,
        mixed_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=0.01,
        binary_labels=mixed_tasks,
    )
    mixed_trends_1.append(trend)

Average Accuracy across 1 tasks: 100.00%
Average Accuracy across 2 tasks: 93.63%
Average Accuracy across 3 tasks: 92.74%
Average Accuracy across 4 tasks: 85.73%
Average Accuracy across 5 tasks: 85.50%
Average Accuracy across 6 tasks: 82.72%
Average Accuracy across 7 tasks: 79.29%
Average Accuracy across 8 tasks: 76.47%
Average Accuracy across 9 tasks: 72.76%
Average Accuracy across 10 tasks: 72.57%
Average Accuracy across 1 tasks: 99.76%
Average Accuracy across 2 tasks: 93.28%
Average Accuracy across 3 tasks: 89.76%
Average Accuracy across 4 tasks: 85.13%
Average Accuracy across 5 tasks: 81.70%
Average Accuracy across 6 tasks: 81.25%
Average Accuracy across 7 tasks: 79.09%
Average Accuracy across 8 tasks: 79.44%
Average Accuracy across 9 tasks: 77.17%
Average Accuracy across 10 tasks: 77.25%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.75%
Average Accuracy across 3 tasks: 94.42%
Average Accuracy across 4 tasks: 85.73%
Average Accuracy across 5 tasks: 84.7

In [309]:
# Mixed MNIST/CIFAR experiment with beta=1.
# NOTE(review): coreset_size is not set here; it is whatever an earlier cell assigned.
torch.manual_seed(SEED)
mixed_trends_2 = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend = run_vcl(
        model,
        mixed_train_loaders,
        mixed_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1,
        binary_labels=mixed_tasks,
    )
    mixed_trends_2.append(trend)

Average Accuracy across 1 tasks: 100.00%
Average Accuracy across 2 tasks: 93.30%
Average Accuracy across 3 tasks: 93.37%
Average Accuracy across 4 tasks: 86.41%
Average Accuracy across 5 tasks: 87.59%
Average Accuracy across 6 tasks: 82.02%
Average Accuracy across 7 tasks: 84.26%
Average Accuracy across 8 tasks: 81.85%
Average Accuracy across 9 tasks: 79.67%
Average Accuracy across 10 tasks: 80.03%
Average Accuracy across 1 tasks: 99.76%
Average Accuracy across 2 tasks: 93.25%
Average Accuracy across 3 tasks: 93.35%
Average Accuracy across 4 tasks: 87.12%
Average Accuracy across 5 tasks: 88.32%
Average Accuracy across 6 tasks: 76.68%
Average Accuracy across 7 tasks: 78.54%
Average Accuracy across 8 tasks: 78.89%
Average Accuracy across 9 tasks: 80.05%
Average Accuracy across 10 tasks: 81.86%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.83%
Average Accuracy across 3 tasks: 94.31%
Average Accuracy across 4 tasks: 88.02%
Average Accuracy across 5 tasks: 89.8

In [76]:
# Mixed MNIST/CIFAR experiment with beta=1e2.
# NOTE(review): coreset_size is not set here; it is whatever an earlier cell assigned.
torch.manual_seed(SEED)
mixed_trends_3 = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend = run_vcl(
        model,
        mixed_train_loaders,
        mixed_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=100.0,
        binary_labels=mixed_tasks,
    )
    mixed_trends_3.append(trend)

Average Accuracy across 1 tasks: 100.00%
Average Accuracy across 2 tasks: 87.35%
Average Accuracy across 3 tasks: 90.90%
Average Accuracy across 4 tasks: 83.75%
Average Accuracy across 5 tasks: 86.76%
Average Accuracy across 6 tasks: 83.42%
Average Accuracy across 7 tasks: 85.36%
Average Accuracy across 8 tasks: 83.45%
Average Accuracy across 9 tasks: 84.41%
Average Accuracy across 10 tasks: 83.19%
Average Accuracy across 1 tasks: 99.76%
Average Accuracy across 2 tasks: 87.96%
Average Accuracy across 3 tasks: 91.58%
Average Accuracy across 4 tasks: 84.29%
Average Accuracy across 5 tasks: 86.94%
Average Accuracy across 6 tasks: 80.66%
Average Accuracy across 7 tasks: 82.60%
Average Accuracy across 8 tasks: 82.07%
Average Accuracy across 9 tasks: 82.07%
Average Accuracy across 10 tasks: 81.14%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 89.25%
Average Accuracy across 3 tasks: 92.28%
Average Accuracy across 4 tasks: 85.49%
Average Accuracy across 5 tasks: 88.1

In [97]:
# Mixed MNIST/CIFAR experiment using run_auto_vcl. With return_betas=True each
# call's result is unpacked into a trend and a beta, both collected per run.
# NOTE(review): coreset_size is not set here; it is whatever an earlier cell assigned.
torch.manual_seed(SEED)
mixed_trends_4 = []
mixed_betas = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend, m_beta = run_auto_vcl(
        model,
        mixed_train_loaders,
        mixed_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        binary_labels=mixed_tasks,
        return_betas=True,
    )
    mixed_trends_4.append(trend)
    mixed_betas.append(m_beta)

0.9946572104018913
Average Accuracy across 1 tasks: 99.91%
0.6978500000000001
0
0.6465 raw_pred
0.010685579196217487 0.6042999999999998 0.11204703855699023 all
0.2773870820168487 beta
Average Accuracy across 2 tasks: 93.45%
0.9015181194906955
1
0.46767874632713025 raw_pred
0.6042999999999998 0.19696376101860902 0.012697455202165291 all
13.36730525473021 beta
Average Accuracy across 3 tasks: 95.26%
0.5807499999999999
2
0.492 raw_pred
0.6042999999999998 0.8385000000000002 0.007845023030455634 all
3.2707281286801315 beta
Average Accuracy across 4 tasks: 86.92%
0.9059231590181431
3
0.503735325506937 raw_pred
0.8385000000000002 0.18815368196371374 0.0072082588077857146 all
49.52650036842785 beta
Average Accuracy across 5 tasks: 89.76%
0.6123
4
0.4885 raw_pred
0.8385000000000002 0.7754000000000001 0.008409068065962979 all
18.535405289814072 beta
Average Accuracy across 6 tasks: 85.47%
0.9598187311178249
5
0.5196374622356495 raw_pred
0.8385000000000002 0.0803625377643502 0.009880615143752616 

In [30]:
# Disabled experiment, kept for reference only: the whole cell is commented out,
# so mixed_trends_5 is never created. It would repeat the mixed run_auto_vcl
# experiment with dor=True and without return_betas.
# NOTE(review): the meaning of `dor` is not visible here -- confirm in
# run_auto_vcl, or delete this cell if the experiment is abandoned.
# mixed_trends_5 = []
# for i in range(5):
#     model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#     trend = run_auto_vcl(model, 
#         mixed_train_loaders,
#         mixed_test_loaders,
#         optimizer, 
#         epoch_per_task, 
#         coreset_size,
#         binary_labels = mixed_tasks,
#         dor = True)
#     mixed_trends_5.append(trend)

In [308]:
# Plot the three fixed-beta mixed-task trend collections next to the
# run_auto_vcl one, passing along the betas collected from those runs.
plot_trends(
    [mixed_trends_1, mixed_trends_2, mixed_trends_3, mixed_trends_4],
    mixed_betas,
    lower=0.7,
)

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


## P-MNIST

In [101]:
# run_vcl on the pmnist loaders with beta fixed at 1e-2 and coreset_size 0.
# The model is built with single_head=True, and no binary_labels are passed.
torch.manual_seed(SEED)
coreset_size = 0
p_trends_1 = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend = run_vcl(
        model,
        pmnist_train_loaders,
        pmnist_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=0.01,
    )
    p_trends_1.append(trend)

Average Accuracy across 1 tasks: 97.68%
Average Accuracy across 2 tasks: 95.38%
Average Accuracy across 3 tasks: 88.66%
Average Accuracy across 4 tasks: 82.93%
Average Accuracy across 5 tasks: 79.31%
Average Accuracy across 6 tasks: 73.25%
Average Accuracy across 7 tasks: 70.21%
Average Accuracy across 8 tasks: 67.15%
Average Accuracy across 9 tasks: 61.20%
Average Accuracy across 10 tasks: 57.76%
Average Accuracy across 1 tasks: 97.85%
Average Accuracy across 2 tasks: 95.78%
Average Accuracy across 3 tasks: 90.71%
Average Accuracy across 4 tasks: 82.54%
Average Accuracy across 5 tasks: 74.92%
Average Accuracy across 6 tasks: 74.06%
Average Accuracy across 7 tasks: 67.81%
Average Accuracy across 8 tasks: 64.33%
Average Accuracy across 9 tasks: 57.56%
Average Accuracy across 10 tasks: 57.05%
Average Accuracy across 1 tasks: 97.90%
Average Accuracy across 2 tasks: 96.42%
Average Accuracy across 3 tasks: 88.64%
Average Accuracy across 4 tasks: 83.64%
Average Accuracy across 5 tasks: 79.19

In [115]:
# P-MNIST experiment with beta=1.
# NOTE(review): coreset_size is not set here; it is whatever an earlier cell assigned.
torch.manual_seed(SEED)
p_trends_2 = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend = run_vcl(
        model,
        pmnist_train_loaders,
        pmnist_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1,
    )
    p_trends_2.append(trend)

Average Accuracy across 1 tasks: 97.68%
Average Accuracy across 2 tasks: 96.77%
Average Accuracy across 3 tasks: 95.98%
Average Accuracy across 4 tasks: 95.32%
Average Accuracy across 5 tasks: 94.90%
Average Accuracy across 6 tasks: 93.94%
Average Accuracy across 7 tasks: 93.19%
Average Accuracy across 8 tasks: 92.03%
Average Accuracy across 9 tasks: 91.00%
Average Accuracy across 10 tasks: 89.98%
Average Accuracy across 1 tasks: 97.85%
Average Accuracy across 2 tasks: 96.91%
Average Accuracy across 3 tasks: 96.30%
Average Accuracy across 4 tasks: 95.75%
Average Accuracy across 5 tasks: 95.09%
Average Accuracy across 6 tasks: 93.65%
Average Accuracy across 7 tasks: 92.35%
Average Accuracy across 8 tasks: 91.95%
Average Accuracy across 9 tasks: 91.13%
Average Accuracy across 10 tasks: 89.87%
Average Accuracy across 1 tasks: 97.90%
Average Accuracy across 2 tasks: 96.79%
Average Accuracy across 3 tasks: 96.14%
Average Accuracy across 4 tasks: 95.63%
Average Accuracy across 5 tasks: 94.95

In [103]:
# P-MNIST experiment with beta=100.
# NOTE(review): coreset_size is not set here; it is whatever an earlier cell assigned.
torch.manual_seed(SEED)
p_trends_3 = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend = run_vcl(
        model,
        pmnist_train_loaders,
        pmnist_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=100,
    )
    p_trends_3.append(trend)

Average Accuracy across 1 tasks: 97.68%
Average Accuracy across 2 tasks: 86.02%
Average Accuracy across 3 tasks: 81.05%
Average Accuracy across 4 tasks: 78.82%
Average Accuracy across 5 tasks: 77.67%
Average Accuracy across 6 tasks: 75.98%
Average Accuracy across 7 tasks: 74.76%
Average Accuracy across 8 tasks: 74.06%
Average Accuracy across 9 tasks: 73.25%
Average Accuracy across 10 tasks: 72.62%
Average Accuracy across 1 tasks: 97.85%
Average Accuracy across 2 tasks: 85.99%
Average Accuracy across 3 tasks: 81.60%
Average Accuracy across 4 tasks: 79.13%
Average Accuracy across 5 tasks: 77.60%
Average Accuracy across 6 tasks: 76.88%
Average Accuracy across 7 tasks: 76.00%
Average Accuracy across 8 tasks: 75.01%
Average Accuracy across 9 tasks: 74.34%
Average Accuracy across 10 tasks: 73.51%
Average Accuracy across 1 tasks: 97.90%
Average Accuracy across 2 tasks: 85.93%
Average Accuracy across 3 tasks: 81.64%
Average Accuracy across 4 tasks: 79.34%
Average Accuracy across 5 tasks: 77.97

In [104]:
# P-MNIST experiment using run_auto_vcl. With return_betas=True each call's
# result is unpacked into a trend and a beta, both collected per run.
# NOTE(review): coreset_size is not set here; it is whatever an earlier cell assigned.
torch.manual_seed(SEED)
p_trends_4 = []
p_betas = []
for i in range(5):
    # Fresh model and optimizer for every run.
    model = MFVI_NN(28 * 28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    trend, p_beta = run_auto_vcl(
        model,
        pmnist_train_loaders,
        pmnist_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        return_betas=True,
    )
    p_trends_4.append(trend)
    p_betas.append(p_beta)

0.2434
Average Accuracy across 1 tasks: 97.75%
0.29325
0
0.132 raw_pred
0.8406666666666667 0.7852777777777777 0.00023398956983617628 all
1.6228421399680775 beta
Average Accuracy across 2 tasks: 96.49%
0.30765
1
0.0897 raw_pred
0.8406666666666667 0.7692777777777778 0.00015161720091445732 all
2.2044398336895314 beta
Average Accuracy across 3 tasks: 95.77%
0.28714
2
0.0468 raw_pred
0.8406666666666667 0.7920666666666667 0.00035750630121053004 all
2.30644579110361 beta
Average Accuracy across 4 tasks: 95.12%
0.23016999999999999
3
0.087 raw_pred
0.8406666666666667 0.8553666666666667 0.00016002827476199207 all
1.8411507225137402 beta
Average Accuracy across 5 tasks: 94.70%
0.27287
4
0.1046 raw_pred
0.8553666666666667 0.8079222222222222 0.0001352838637004801 all
2.7545006099288583 beta
Average Accuracy across 6 tasks: 93.74%
0.2983700000000001
5
0.1149 raw_pred
0.8553666666666667 0.7795888888888888 0.00016622533685894952 all
3.704732755327189 beta
Average Accuracy across 7 tasks: 92.75%
0.2834

In [306]:
# Plot the three fixed-beta P-MNIST trend collections next to the
# run_auto_vcl one, passing along the betas collected from those runs.
plot_trends(
    [p_trends_1, p_trends_2, p_trends_3, p_trends_4],
    p_betas,
    lower=0.5,
)

In [111]:
# Mean of p_trends_3 along axis 0 (across the collected runs); as the last
# expression of the cell, the resulting array is displayed as its output.
np.asarray(p_trends_3).mean(axis=0)

array([0.978     , 0.85905   , 0.81322667, 0.787145  , 0.774216  ,
       0.76382333, 0.75400286, 0.746775  , 0.73916444, 0.73209   ])