In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
SEED = 42

# Set the random seed for PyTorch
torch.manual_seed(SEED)

# Additionally, if using NumPy operations in your data processing, set its seed
np.random.seed(SEED)

class MFVI_Layer(nn.Module):
    def __init__(self, in_features, out_features):
        super(MFVI_Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Mean parameters
        self.W_m = nn.Parameter(torch.Tensor(out_features, in_features))
        self.b_m = nn.Parameter(torch.Tensor(out_features))

        # Log variance parameters
        self.W_logv = nn.Parameter(torch.Tensor(out_features, in_features))
        self.b_logv = nn.Parameter(torch.Tensor(out_features))

        # Prior distributions (initialized to None and will be set in reset_parameters)
        self.prior_W_m = None
        self.prior_b_m = None
        self.prior_W_logv = None
        self.prior_b_logv = None

        # Initialize parameters
        self.reset_parameters()

    def reset_parameters(self):
        self.W_m.data.normal_(0, 0.1)
        self.b_m.data.normal_(0, 0.1)
        self.W_logv.data.fill_(-6.0)
        self.b_logv.data.fill_(-6.0)

        # Initially set priors to match the initial parameters
        self.set_priors()

    def set_priors(self):
        # Update or set the prior distributions to the current parameters
        self.prior_W_m = self.W_m.detach().clone()
        self.prior_b_m = self.b_m.detach().clone()
        self.prior_W_logv = self.W_logv.detach().clone()
        self.prior_b_logv = self.b_logv.detach().clone()

    def forward(self, x, sample=True):
        # Calculate standard deviations from log variances
        W_std = torch.exp(0.5 * self.W_logv)
        b_std = torch.exp(0.5 * self.b_logv)
        
        # Calculate the output mean
        act_mu = F.linear(x, self.W_m, self.b_m)
        
        # Calculate the output variance
        # The variance of the output is given by the sum of variances of the weighted inputs (assuming independence)
        act_var = 1e-16 + F.linear(x.pow(2), W_std.pow(2)) + b_std.pow(2)
        act_std = torch.sqrt(act_var)

        if self.training or sample:
            # If in training mode or sample is True, sample from the posterior
            eps = torch.randn_like(act_mu)
            return act_mu + act_std * eps
        else:
            # Otherwise, return the mean of the posterior
            return act_mu


    def kl_divergence(self,device):
        # Calculate the number of parameters for normalization
        self.update_prior_device(device)
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        # Convert log variances to standard deviations
        W_std = torch.exp(0.5 * self.W_logv)
        b_std = torch.exp(0.5 * self.b_logv)
        prior_W_std = torch.exp(0.5 * self.prior_W_logv)
        prior_b_std = torch.exp(0.5 * self.prior_b_logv)

        # Calculate KL divergence for weights using the standard deviation-based formula
        kl_W = 0.5 * (2 * torch.log(prior_W_std / W_std) - 1 + (W_std / prior_W_std).pow(2) + ((self.W_m - self.prior_W_m) / prior_W_std).pow(2)).sum()

        # Calculate KL divergence for biases using the standard deviation-based formula
        kl_b = 0.5 * (2 * torch.log(prior_b_std / b_std) - 1 + (b_std / prior_b_std).pow(2) + ((self.b_m - self.prior_b_m) / prior_b_std).pow(2)).sum()

        # Return the normalized KL divergence
        return (kl_W + kl_b) /num_params

    def update_prior_device(self, device):
        self.prior_W_m = self.prior_W_m.to(device)
        self.prior_b_m = self.prior_b_m.to(device)
        self.prior_W_logv = self.prior_W_logv.to(device)
        self.prior_b_logv = self.prior_b_logv.to(device)



class MFVI_NN(nn.Module):
    def __init__(self, input_size, hidden_sizes,  output_size, num_tasks = 1, single_head =False):
        super(MFVI_NN, self).__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.num_tasks = num_tasks
        self.single_head = single_head

        self.layers = nn.ModuleList()
        self.task_specific_layers = nn.ModuleDict()  # Using ModuleDict to hold task-specific layers

        # Define shared layers
        sizes = [input_size] + hidden_sizes
        for i in range(len(sizes)-1):
            self.layers.append(MFVI_Layer(sizes[i], sizes[i+1]))

        # Define task-specific output layers
        if single_head:
            self.task_specific_layers[str(0)] = MFVI_Layer(sizes[-1], output_size)
        else:
            for task_id in range(num_tasks):
                self.task_specific_layers[str(task_id)] = MFVI_Layer(sizes[-1], output_size)

    def forward(self, x, task_id=0, sample=True):
        for layer in self.layers:
            x = F.relu(layer(x, sample))
        if self.single_head:
            task_layer = self.task_specific_layers["0"]
            x = task_layer(x, sample)
        else:
            # Select and apply the task-specific output layer
            task_layer = self.task_specific_layers[str(task_id)]
            x = task_layer(x, sample)
        return x

    def kl_divergence(self):
        kl_div = 0
        # Sum KL divergence from shared layers
        for layer in self.layers:
            kl_div += layer.kl_divergence(next(self.parameters()).device)
        if self.single_head:
            layer = self.task_specific_layers["0"]
            kl_div += layer.kl_divergence(next(self.parameters()).device)
        # Sum KL divergence from task-specific layers
        # for task_layer in self.task_specific_layers.values():
        #     kl_div += task_layer.kl_divergence(next(self.parameters()).device)

        return kl_div


    def update_priors(self):
        # Update priors in each shared layer
        for layer in self.layers:
            layer.set_priors()  # Assuming each MFVI_Layer has a method called set_priors
        
        for task_layer in self.task_specific_layers.values():
            task_layer.set_priors()




In [17]:
import torch.nn.functional as F

def train(model, trainloader, optimizer, epoch, device, kl_weight=1, task_id = 0, binary_label = None):
    model.train()
    for batch_idx, (data, target) in enumerate(trainloader):
        data, target = data.to(device), target.to(device)
        data = data.view(data.size(0), -1)  # Flatten the images
        optimizer.zero_grad()
        output = model(data, sample=True, task_id = task_id)
        if binary_label != None:
            target = (target == binary_label[0]).long()
        reconstruction_loss = F.cross_entropy(output, target, reduction='mean')
        
#         print(kl_divergence,reconstruction_loss)
        if task_id == 0:
            loss = reconstruction_loss
        else:
            kl_divergence = model.kl_divergence()
            loss = reconstruction_loss + kl_divergence * kl_weight
        loss.backward()
        optimizer.step()
#     print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(trainloader.dataset)} ({100. * batch_idx / len(trainloader):.0f}%)]\tLoss: {loss.item()}")

def test(model, testloader, device, task_id = 0, binary_label = None):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)  # Flatten the images
            output = model(data, sample=False, task_id = task_id)
            if binary_label != None:
                target = (target == binary_label[0]).long()

            # Use cross_entropy for test loss calculation, sum up batch loss
            test_loss += F.cross_entropy(output, target, reduction='mean').item()
            pred = output.argmax(dim=1, keepdim=True)  # Get the index of the max probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(testloader.dataset)
    # print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(testloader.dataset)} ({100. * correct / len(testloader.dataset):.0f}%)\n')
    return test_loss, correct / len(testloader.dataset)


## Permuted MNIST

In [2]:
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [3]:
import torch

def generate_permutations(task_count, image_size):
    permutations = [torch.randperm(image_size) for _ in range(task_count)]
    return permutations

task_count = 10
image_size = 28 * 28  # MNIST images are 28x28
permutations = generate_permutations(task_count, image_size)

In [4]:
from torch.utils.data import Dataset

class PermutedMNIST(Dataset):
    def __init__(self, mnist_dataset, permutation=None):
        self.mnist_dataset = mnist_dataset
        self.permutation = permutation

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        image, label = self.mnist_dataset[idx]
        if self.permutation is not None:
            # Apply permutation
            image = image.view(-1)[self.permutation].view(1, 28, 28)
        return image, label


In [5]:
from torch.utils.data import DataLoader

batch_size = 256

# Create a DataLoader for the original MNIST
pmnist_train_loaders = []
pmnist_test_loaders = []

# Create DataLoaders for permuted tasks
for perm in permutations:
    permuted_train = PermutedMNIST(mnist_trainset, permutation=perm)
    permuted_test = PermutedMNIST(mnist_testset, permutation=perm)

    train_loader = DataLoader(permuted_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(permuted_test, batch_size=batch_size, shuffle=False)

    pmnist_train_loaders.append(train_loader)
    pmnist_test_loaders.append(test_loader)


In [6]:
from tqdm import tqdm
# Assuming model, optimizer, train_loaders, test_loaders are defined
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [7]:
from torch.utils.data import DataLoader, ConcatDataset, Subset
def random_coreset(dataset, coreset_size):
    """
    Randomly selects a subset of data points to form a coreset.

    Args:
    - dataset (torch.utils.data.Dataset): The dataset to sample from.
    - coreset_size (int): The number of samples to include in the coreset.

    Returns:
    - coreset_indices (torch.Tensor): Indices of the selected samples.
    """
    # Ensure coreset size does not exceed dataset size
    coreset_size = min(coreset_size, len(dataset))
    
    # Randomly select indices without replacement
    coreset_indices = np.random.choice(len(dataset), size=coreset_size, replace=False)
    
    # Convert numpy array to torch tensor
    coreset_indices = torch.from_numpy(coreset_indices)
    
    coreset = Subset(dataset, coreset_indices)
    return coreset



In [8]:
from tqdm import tqdm
epoch_per_task = 10
model = MFVI_NN(28*28, [100, 100], 2, num_tasks = 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
coreset_size = 200
beta = 1

def run_vcl(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size=0, beta=1, binary_labels = None):
    ave_acc_trend_rc = []
    prev_test_loaders= []
    coresets = []
    if binary_labels is None:
        binary_labels = [None] * model.output_size
    for task_id, (train_loader, test_loader) in enumerate(zip(train_loaders, test_loaders), start=0):
        task_accuracies_rc = []
        if coreset_size > 0:
            for i in (range(len(coresets))):
                for epoch in (range(1, epoch_per_task + 1)):
                    coreset_loader = DataLoader(coresets[i], batch_size=batch_size, shuffle=True)
                    train(model, coreset_loader, optimizer, epoch, device, beta, task_id=i, binary_label=binary_labels[i])
                model.update_priors()
        for epoch in (range(1, epoch_per_task + 1)):
            train(model, train_loader, optimizer, epoch, device, beta, task_id=task_id, binary_label=binary_labels[task_id])
        model.update_priors()


        # for prediction
        prediction_model = type(model)(model.input_size, model.hidden_sizes, model.output_size, model.num_tasks, model.single_head).to(device)
        prediction_model.load_state_dict(model.state_dict())
        # replay
        if coreset_size > 0:
            coresets.append(random_coreset(train_loader.dataset, coreset_size))
            for i in (range(len(coresets))):
                for epoch in (range(1, epoch_per_task + 1)):
                    coreset_loader = DataLoader(coresets[i], batch_size=batch_size, shuffle=True)
                    train(prediction_model, coreset_loader, optimizer, epoch, device, beta, task_id=i, binary_label=binary_labels[i])
        task_num = 0  
        prev_test_loaders.append(test_loader)
        for ptl in prev_test_loaders: 
            test_loss, task_accuracy = test(prediction_model, ptl, device,task_id=task_num, binary_label=binary_labels[task_num])
            task_accuracies_rc.append(task_accuracy)
            task_num += 1
        average_accuracy = sum(task_accuracies_rc) / len(task_accuracies_rc)
        ave_acc_trend_rc.append(average_accuracy)
        print(f'Average Accuracy across {len(task_accuracies_rc)} tasks: {average_accuracy*100:.2f}%')
    return ave_acc_trend_rc

def scale_similarity(sim, a, b):
    return 1/(1+np.exp(-20*(sim-(a+b)/2)))

def run_auto_vcl(model, train_loaders,test_loaders, optimizer, epoch_per_task, coreset_size, 
                beta_star=1, raw_training_epoch = 1,raw_train_size = 1000, 
                binary_labels = None, dor = False, return_betas = False):
    task_difficulties = []
    
    ave_acc_trend_rc = []
    prev_test_loaders= []
    coresets = []
    betas = []
    if binary_labels is None:
        binary_labels = [None] * model.output_size
    for task_id, (train_loader, test_loader) in enumerate(zip(train_loaders, test_loaders), start=0):
        raw_acc = []
        for i in range(10):
            raw_model = type(model)(model.input_size, model.hidden_sizes, model.output_size, model.num_tasks, model.single_head).to(device)
            ## raw training

            raw_trainset = random_coreset(train_loader.dataset, raw_train_size)
            raw_train_loader = DataLoader(raw_trainset, batch_size=batch_size, shuffle=True)
            raw_optimizer = torch.optim.Adam(raw_model.parameters(), lr=0.001)

            for epoch in (range(1, raw_training_epoch + 1)):
                train(raw_model, raw_train_loader, raw_optimizer, epoch, device, beta_star, task_id=0, binary_label=binary_labels[task_id])
            _, acc_simple_train = test(raw_model, test_loader, device,task_id=0, binary_label=binary_labels[task_id])
            raw_acc.append(acc_simple_train)
        
        acc_simple_train = np.mean(raw_acc)
        print(acc_simple_train)

        dummy_pred = 1/model.output_size
        curr_difficulty = min(max((1-(acc_simple_train - dummy_pred)/(1-dummy_pred)),0),1)
        if task_id > 0:
            print(task_id-1)
            _, raw_pred = test(model, test_loader, device,task_id=task_id-1, binary_label=binary_labels[task_id])
            print(raw_pred,'raw_pred')
            
            # similarity = min(max(np.abs(raw_pred - dummy_pred)/(prev_acc- dummy_pred),0),1)
            similarity = scale_similarity(np.abs(raw_pred-dummy_pred), 0, 1-dummy_pred)
            prev_difficulty = np.max(task_difficulties)
            print(prev_difficulty, curr_difficulty,similarity,'all')
            beta = beta_star*10**((prev_difficulty-curr_difficulty)*4+similarity*4)
            betas.append(beta)
            print(beta,'beta')
        else: 
            beta = beta_star

        if coreset_size > 0 and task_id>0:
            if (dor):
                zipped_and_indices = sorted(enumerate(zip(task_difficulties, coresets)), key=lambda x: x[1][0], reverse=True)

                sorted_task_nums = [index for index, _ in zipped_and_indices]
                sorted_difficulties = [pair[0] for _, pair in zipped_and_indices]
                replay_coresets = [pair[1] for _, pair in zipped_and_indices]
                replay_betas = [beta_star*10**(2-d) for d in sorted_difficulties]
                print(sorted_difficulties, replay_betas)
            else:
                replay_coresets = coresets
            for i in (range(len(coresets))):
                for epoch in (range(1, epoch_per_task + 1)):
                    coreset_loader = DataLoader(replay_coresets[i], batch_size=batch_size, shuffle=True)
                    if dor:
                        sorted_task_num = sorted_task_nums[i]
                        train(model, coreset_loader, optimizer, epoch, device, replay_betas[i], task_id=sorted_task_num, binary_label=binary_labels[sorted_task_num])
                    else:
                        train(model, coreset_loader, optimizer, epoch, device, beta_star, task_id=i, binary_label=binary_labels[i])
                model.update_priors()
        for epoch in (range(1, epoch_per_task + 1)):
            train(model, train_loader, optimizer, epoch, device, beta, task_id=task_id, binary_label=binary_labels[task_id])
        model.update_priors()


        # for prediction
        prediction_model = type(model)(model.input_size, model.hidden_sizes, model.output_size, model.num_tasks, model.single_head).to(device)
        prediction_model.load_state_dict(model.state_dict())
        # replay

        task_difficulties.append(curr_difficulty)
        if coreset_size > 0:
            
            coresets.append(random_coreset(train_loader.dataset, coreset_size))
            if (dor):
                zipped_and_indices = sorted(enumerate(zip(task_difficulties, coresets)), key=lambda x: x[1][0], reverse=True)

                sorted_task_nums = [index for index, _ in zipped_and_indices]
                sorted_difficulties = [pair[0] for _, pair in zipped_and_indices]
                replay_coresets = [pair[1] for _, pair in zipped_and_indices]
                replay_betas = [beta_star*10**(2-d) for d in sorted_difficulties]
            else:
                replay_coresets = coresets
            for i in (range(len(coresets))):
                for epoch in (range(1, epoch_per_task + 1)):
                    coreset_loader = DataLoader(replay_coresets[i], batch_size=batch_size, shuffle=True)
                    if dor:
                        sorted_task_num = sorted_task_nums[i]
                        train(model, coreset_loader, optimizer, epoch, device, replay_betas[i], task_id=sorted_task_num, binary_label=binary_labels[sorted_task_num])
                    else:
                        train(model, coreset_loader, optimizer, epoch, device, beta_star, task_id=i, binary_label=binary_labels[i])
        task_num = 0  
        prev_test_loaders.append(test_loader)

        task_accuracies_rc = []
        for ptl in prev_test_loaders: 
            test_loss, task_accuracy = test(prediction_model, ptl, device,task_id=task_num, binary_label=binary_labels[task_num])
            task_accuracies_rc.append(task_accuracy)
            task_num += 1
        
        prev_acc = task_accuracy
        average_accuracy = sum(task_accuracies_rc) / len(task_accuracies_rc)
        ave_acc_trend_rc.append(average_accuracy)
        print(f'Average Accuracy across {len(task_accuracies_rc)} tasks: {average_accuracy*100:.2f}%')
    if return_betas:
        return ave_acc_trend_rc, betas
    
    return ave_acc_trend_rc

In [9]:
def create_split_task(dataset, classes):
    """
    Create a binary classification task from the MNIST dataset.
    
    Parameters:
    - dataset: The original MNIST dataset (training or test).
    - classes: A tuple of two integers representing the classes to include in the split.
    
    Returns:
    - A Subset of the original dataset containing only the specified classes.
    """
    # Find indices of classes we're interested in
    indices = [i for i, (_, target) in enumerate(dataset) if target in classes]
    
    # Create a subset of the dataset with only the specified classes
    subset = Subset(dataset, indices)
    
    return subset

def create_split_dataloaders(train_dataset, test_dataset, tasks, batch_size=256):
    """
    Create DataLoaders for each binary task in Split MNIST.
    
    Parameters:
    - train_dataset: The MNIST training dataset.
    - test_dataset: The MNIST test dataset.
    - batch_size: The batch size for the DataLoader.
    
    Returns:
    - A list of tuples containing (train_loader, test_loader) for each binary task.
    """
    train_loaders = []
    test_loaders = []
    for task in tasks:
        # Create training subset and DataLoader
        train_subset = create_split_task(train_dataset, task)
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

        # Create test subset and DataLoader
        test_subset = create_split_task(test_dataset, task)
        test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

        train_loaders.append(train_loader)
        test_loaders.append(test_loader)
    
    return train_loaders, test_loaders

In [10]:
import pandas as pd


def plot_trends(trends, title = 'Accuracy Trends in the Permuated MNIST Experiment', lower = 0.7):
    df = pd.DataFrame({
        '# of tasks': range(len(trends[0][0])),
        'beta = 0.01': np.mean(trends[0], axis = 0),
        'beta = 1': np.mean(trends[1], axis = 0),
        'beta = 100': np.mean(trends[2], axis = 0),
        'AutoVCL': np.mean(trends[3], axis = 0)
    })
    # Convert the DataFrame to long format
    df_long = df.melt('# of tasks', var_name='Series', value_name='Values')

    import altair as alt
    legend_order = ['beta = 0.01', 'beta = 1', 'beta = 100', 'AutoVCL']
    # Create the plot
    chart = alt.Chart(df_long).mark_line(point=True).encode(
        
        x=alt.X('# of tasks:Q', title='# tasks', axis=alt.Axis(values=list(range(df_long['# of tasks'].max() + 1)))), # Ensure integer ticks),
        y=alt.Y('Values:Q', scale=alt.Scale(domain=[lower, 1]), title='Accuracy'),
        color=alt.Color('Series:N', sort=legend_order,scale=alt.Scale(scheme='category10'), legend=alt.Legend(title="Model")),
        tooltip=['# of tasks', 'Values', 'Series']
    ).properties(
        width=800,
        height=400,
        title=title
    ).configure_axis(
        labelFontSize=15,
        titleFontSize=20
    ).configure_legend(
        labelFontSize=15,
        titleFontSize=15
    ).configure_title(
        fontSize=24
    )

    chart.display()

## Reproduce

In [32]:
epoch_per_task

10

In [54]:
vcl_results = {}
coreset_results = {}
for coreset_size in [200, 400, 1000, 2500, 5000]:
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    vcl_trend, coreset_trend = run_vcl(model,  pmnist_train_loaders,pmnist_test_loaders, optimizer, epoch_per_task, coreset_size, 
        beta=1, return_coreset_score=False)
    vcl_results[coreset_size] = vcl_trend
    coreset_results[coreset_size] = coreset_trend
vcl_trend =  run_vcl(model,  pmnist_train_loaders,pmnist_test_loaders, optimizer, epoch_per_task, coreset_size=0, beta=1)
vcl_results[0] = vcl_trend

Average Accuracy across 1 tasks: 97.05%
Average Accuracy across 2 tasks: 94.08%
Average Accuracy across 3 tasks: 92.56%


[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:110] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 

## split no core

In [55]:
tasks = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
split_train_loaders, split_test_loaders = create_split_dataloaders(mnist_trainset, mnist_testset, tasks, batch_size=256)

In [16]:

coreset_size = 0
trends_1 = []
for i in range(10):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model,  split_train_loaders,split_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e-2, binary_labels = tasks)
    trends_1.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.73%
Average Accuracy across 3 tasks: 97.64%
Average Accuracy across 4 tasks: 97.37%
Average Accuracy across 5 tasks: 97.08%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.76%
Average Accuracy across 3 tasks: 98.83%
Average Accuracy across 4 tasks: 98.68%
Average Accuracy across 5 tasks: 96.74%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.46%
Average Accuracy across 3 tasks: 99.13%
Average Accuracy across 4 tasks: 96.26%
Average Accuracy across 5 tasks: 95.84%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.66%
Average Accuracy across 3 tasks: 98.82%
Average Accuracy across 4 tasks: 95.63%
Average Accuracy across 5 tasks: 92.63%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.59%
Average Accuracy across 3 tasks: 98.25%
Average Accuracy across 4 tasks: 95.73%
Average Accuracy across 5 tasks: 94.85%


In [17]:
trends_2 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model,  split_train_loaders,split_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1, binary_labels = tasks)
    trends_2.append(trend)

Average Accuracy across 1 tasks: 99.86%
Average Accuracy across 2 tasks: 99.59%
Average Accuracy across 3 tasks: 98.99%
Average Accuracy across 4 tasks: 97.95%
Average Accuracy across 5 tasks: 92.80%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.64%
Average Accuracy across 3 tasks: 99.40%
Average Accuracy across 4 tasks: 98.87%
Average Accuracy across 5 tasks: 96.44%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.64%
Average Accuracy across 3 tasks: 98.37%
Average Accuracy across 4 tasks: 93.14%
Average Accuracy across 5 tasks: 97.01%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.54%
Average Accuracy across 3 tasks: 96.81%
Average Accuracy across 4 tasks: 94.45%
Average Accuracy across 5 tasks: 94.69%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.76%
Average Accuracy across 3 tasks: 98.92%
Average Accuracy across 4 tasks: 97.14%
Average Accuracy across 5 tasks: 92.71%


In [19]:

trends_3 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, split_train_loaders,split_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e2, binary_labels = tasks)
    trends_3.append(trend)

Average Accuracy across 1 tasks: 99.76%
Average Accuracy across 2 tasks: 99.02%
Average Accuracy across 3 tasks: 98.38%
Average Accuracy across 4 tasks: 98.63%
Average Accuracy across 5 tasks: 97.39%
Average Accuracy across 1 tasks: 99.86%
Average Accuracy across 2 tasks: 98.90%
Average Accuracy across 3 tasks: 98.65%
Average Accuracy across 4 tasks: 98.56%
Average Accuracy across 5 tasks: 97.83%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.05%
Average Accuracy across 3 tasks: 98.33%
Average Accuracy across 4 tasks: 98.41%
Average Accuracy across 5 tasks: 98.01%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 98.51%
Average Accuracy across 3 tasks: 98.68%
Average Accuracy across 4 tasks: 98.80%
Average Accuracy across 5 tasks: 98.00%
Average Accuracy across 1 tasks: 99.86%
Average Accuracy across 2 tasks: 98.78%
Average Accuracy across 3 tasks: 98.01%
Average Accuracy across 4 tasks: 98.57%
Average Accuracy across 5 tasks: 97.72%


In [38]:
coreset_size


200

In [37]:
trends_4 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_auto_vcl(model, 
        split_train_loaders,
        split_test_loaders,
        optimizer, 
        epoch_per_task, 
        coreset_size,
        binary_labels = tasks,
        dor = True)
    trends_4.append(trend)

0.9952718676122931
Average Accuracy across 1 tasks: 99.91%
0.9492654260528894
0
0.009456264775413725 0.10146914789422112 0 all
0.42849767274278683 beta
Average Accuracy across 2 tasks: 99.73%
0.9790821771611526
1
0.10146914789422112 0.04183564567769471 0.4677617709992642 all
128.70075447851292 beta
Average Accuracy across 3 tasks: 99.31%
0.9854984894259818
2
0.10146914789422112 0.029003021148036323 0 all
1.949236373939389 beta
Average Accuracy across 4 tasks: 98.99%
0.9598587997982854
3
0.10146914789422112 0.08028240040342927 0.2372270358654626 all
10.805755420837402 beta
Average Accuracy across 5 tasks: 98.44%
0.9955555555555555
Average Accuracy across 1 tasks: 99.91%
0.9564152791381
0
0.008888888888888946 0.08716944172380003 0.056914691989304614 all
0.8213656881538066 beta
Average Accuracy across 2 tasks: 99.56%
0.9702241195304161
1
0.08716944172380003 0.05955176093916781 0.505907912791584 all
136.17656475885408 beta
Average Accuracy across 3 tasks: 98.70%
0.9834843907351459
2
0.0871

In [66]:
import matplotlib.pyplot as plt
# def plot_trends(trends):
#     for t in trends:
#         plt.plot(range(len(t[0])),np.mean(t, axis = 0))
plot_trends([trends_1,trends_2,trends_3,trends_4,])

NameError: name 'trends_1' is not defined

## Intentially alike


In [226]:
tasks = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
split_alike_train_loaders, split_alike_test_loaders = create_split_dataloaders(mnist_trainset, mnist_testset, tasks, batch_size=batch_size)

In [190]:
coreset_size = 0
trends_alike_1 = []
torch.manual_seed(SEED)
for i in range(10):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model,  split_alike_train_loaders,split_alike_test_loaders, optimizer, epoch_per_task, coreset_size, 
    beta=1e-2, binary_labels = tasks)
    trends_alike_1.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.66%
Average Accuracy across 3 tasks: 99.46%
Average Accuracy across 4 tasks: 98.49%
Average Accuracy across 5 tasks: 98.56%


KeyboardInterrupt: 

In [None]:
coreset_size = 0
trends_alike_2 = []
torch.manual_seed(SEED)
for i in range(10):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model,  split_alike_train_loaders,split_alike_test_loaders, optimizer, epoch_per_task, coreset_size, 
        beta=1, binary_labels = tasks)
    trends_alike_2.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.51%
Average Accuracy across 3 tasks: 99.43%
Average Accuracy across 4 tasks: 98.89%
Average Accuracy across 5 tasks: 96.65%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.56%
Average Accuracy across 3 tasks: 98.74%
Average Accuracy across 4 tasks: 94.36%
Average Accuracy across 5 tasks: 96.41%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.67%
Average Accuracy across 3 tasks: 99.43%
Average Accuracy across 4 tasks: 94.82%
Average Accuracy across 5 tasks: 94.30%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.65%
Average Accuracy across 3 tasks: 99.45%
Average Accuracy across 4 tasks: 88.87%
Average Accuracy across 5 tasks: 96.86%
Average Accuracy across 1 tasks: 99.95%
Average Accuracy across 2 tasks: 99.46%
Average Accuracy across 3 tasks: 99.54%
Average Accuracy across 4 tasks: 97.31%
Average Accuracy across 5 tasks: 92.54%


In [191]:
coreset_size = 0
trends_alike_3 = []
torch.manual_seed(SEED)
for i in range(10):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model,  split_alike_train_loaders,split_alike_test_loaders,
     optimizer, epoch_per_task, coreset_size, 
        beta=1e2, binary_labels = tasks)
    trends_alike_3.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 99.23%
Average Accuracy across 3 tasks: 98.92%
Average Accuracy across 4 tasks: 99.34%
Average Accuracy across 5 tasks: 99.05%


KeyboardInterrupt: 

In [227]:
trends_alike_4 = []
alike_betas = []
torch.manual_seed(SEED)
for i in range(10):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, alike_beta = run_auto_vcl(model, 
        split_alike_train_loaders,
        split_alike_test_loaders,
        optimizer, 
        epoch_per_task, 
        coreset_size,
        binary_labels = tasks,
        return_betas = True)
    trends_alike_4.append(trend)
    alike_betas.append(alike_beta)

0.9667612293144208
Average Accuracy across 1 tasks: 99.91%
0.8170421155729677
0
0.5416258570029383 raw_pred
0.06647754137115847 0.3659157688540646 0.015254888028760561 all
0.0729905510109667 beta
Average Accuracy across 2 tasks: 99.66%
0.7935965848452508
1
0.7764140875133404 raw_pred
0.3659157688540646 0.4128068303094985 0.6290822684137637 all
213.18901736295032 beta
Average Accuracy across 3 tasks: 99.07%
0.9110271903323263
2
0.47532729103726085 raw_pred
0.4128068303094985 0.17794561933534747 0.010916041445929923 all
9.618536772691854 beta
Average Accuracy across 4 tasks: 97.50%
0.6989914271306101
3
0.7231467473524962 raw_pred
0.4128068303094985 0.6020171457387797 0.3688705894889076 all
5.231679085537842 beta
Average Accuracy across 5 tasks: 98.57%
0.9621749408983451


KeyboardInterrupt: 

In [185]:
model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train(model, split_alike_train_loaders[0], optimizer, 1, device, 1, 0, binary_label=[0,1])

In [188]:
test(model, split_alike_test_loaders[1], device,task_id=0, binary_label=[2,1])

(0.014337299053824108, 0.7378864790032302)

In [170]:
split_alike_train_loaders[1].dataset[2][1]

7

In [144]:
plot_trends([trends_alike_1, trends_alike_2, trends_alike_3, trends_alike_4], lower = 0.9)

In [None]:
f_mnist_train_dataset = datasets.FashionMNIST(
    root='data', 
    train=True, 
    download=True, 
    transform=transform_mnist
)

f_mnist_test_dataset = datasets.FashionMNIST(
    root='data', 
    train=False, 
    download=True, 
    transform=transform_mnist
)

## different difficulties

In [126]:
transform_cifar = transforms.Compose([
    transforms.Grayscale(num_output_channels=1), # Convert image to grayscale
    transforms.Resize((28, 28)),
    transforms.ToTensor(), 
    transforms.Normalize((0.5,), (0.5,))])

# Load the CIFAR-10 training dataset with the defined transform
cifar_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_cifar)

# Load the CIFAR-10 test dataset with the defined transform
cifar_test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_cifar)


Files already downloaded and verified
Files already downloaded and verified


In [127]:
tasks = [(0,1),(2,3),(4,5),(6,7),(8,9)]
mixed_tasks = [tasks[i//2] for i in range(len(tasks)*2)]

In [129]:
torch.manual_seed(SEED)
mnist_train_loaders, mnist_test_loaders = \
    create_split_dataloaders(mnist_trainset, mnist_testset, tasks, batch_size=batch_size)
cifar_train_loaders, cifar_test_loaders = \
    create_split_dataloaders(cifar_train_dataset, cifar_test_dataset, tasks, batch_size=batch_size)

In [130]:
mixed_train_loaders = [mnist_train_loaders, cifar_train_loaders]
mixed_test_loaders = [mnist_test_loaders, cifar_test_loaders]

mixed_train_loaders = [mixed_train_loaders[i%2][i//2] for i in range(len(mixed_tasks))]
mixed_test_loaders = [mixed_test_loaders[i%2][i//2] for i in range(len(mixed_tasks))]

In [131]:
coreset_size= 0
torch.manual_seed(SEED)
mixed_trends_1 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, mixed_train_loaders,mixed_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e-2, binary_labels = mixed_tasks)
    mixed_trends_1.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.26%
Average Accuracy across 3 tasks: 93.34%
Average Accuracy across 4 tasks: 83.85%
Average Accuracy across 5 tasks: 83.81%
Average Accuracy across 6 tasks: 82.11%
Average Accuracy across 7 tasks: 77.71%
Average Accuracy across 8 tasks: 79.99%
Average Accuracy across 9 tasks: 82.04%
Average Accuracy across 10 tasks: 79.90%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.48%
Average Accuracy across 3 tasks: 93.87%
Average Accuracy across 4 tasks: 86.28%
Average Accuracy across 5 tasks: 83.17%
Average Accuracy across 6 tasks: 82.75%
Average Accuracy across 7 tasks: 79.27%
Average Accuracy across 8 tasks: 79.36%
Average Accuracy across 9 tasks: 80.23%
Average Accuracy across 10 tasks: 77.54%
Average Accuracy across 1 tasks: 99.76%
Average Accuracy across 2 tasks: 93.29%
Average Accuracy across 3 tasks: 92.39%
Average Accuracy across 4 tasks: 86.93%
Average Accuracy across 5 tasks: 85.35

In [132]:
torch.manual_seed(SEED)
mixed_trends_2 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, mixed_train_loaders,mixed_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1, binary_labels = mixed_tasks)
    mixed_trends_2.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.28%
Average Accuracy across 3 tasks: 94.46%
Average Accuracy across 4 tasks: 85.02%
Average Accuracy across 5 tasks: 87.14%
Average Accuracy across 6 tasks: 82.45%
Average Accuracy across 7 tasks: 83.87%
Average Accuracy across 8 tasks: 80.21%
Average Accuracy across 9 tasks: 82.58%
Average Accuracy across 10 tasks: 81.63%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 93.23%
Average Accuracy across 3 tasks: 94.03%
Average Accuracy across 4 tasks: 85.22%
Average Accuracy across 5 tasks: 87.65%
Average Accuracy across 6 tasks: 84.47%
Average Accuracy across 7 tasks: 81.84%
Average Accuracy across 8 tasks: 83.43%
Average Accuracy across 9 tasks: 83.55%
Average Accuracy across 10 tasks: 82.06%
Average Accuracy across 1 tasks: 99.76%
Average Accuracy across 2 tasks: 93.03%
Average Accuracy across 3 tasks: 92.69%
Average Accuracy across 4 tasks: 87.62%
Average Accuracy across 5 tasks: 85.79

In [133]:
torch.manual_seed(SEED)
mixed_trends_3 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, mixed_train_loaders,mixed_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e2, binary_labels = mixed_tasks)
    mixed_trends_3.append(trend)

Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 89.58%
Average Accuracy across 3 tasks: 91.79%
Average Accuracy across 4 tasks: 84.56%
Average Accuracy across 5 tasks: 87.47%
Average Accuracy across 6 tasks: 84.63%
Average Accuracy across 7 tasks: 86.83%
Average Accuracy across 8 tasks: 84.39%
Average Accuracy across 9 tasks: 85.76%
Average Accuracy across 10 tasks: 85.10%
Average Accuracy across 1 tasks: 99.91%
Average Accuracy across 2 tasks: 89.53%
Average Accuracy across 3 tasks: 92.17%
Average Accuracy across 4 tasks: 85.96%
Average Accuracy across 5 tasks: 88.29%
Average Accuracy across 6 tasks: 85.02%
Average Accuracy across 7 tasks: 86.78%
Average Accuracy across 8 tasks: 85.79%
Average Accuracy across 9 tasks: 86.67%
Average Accuracy across 10 tasks: 86.10%
Average Accuracy across 1 tasks: 99.76%
Average Accuracy across 2 tasks: 88.05%
Average Accuracy across 3 tasks: 90.87%
Average Accuracy across 4 tasks: 85.06%
Average Accuracy across 5 tasks: 87.75

In [228]:
mixed_trends_4 = []
mixed_betas = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, m_beta = run_auto_vcl(model, 
        mixed_train_loaders,
        mixed_test_loaders,
        optimizer, 
        epoch_per_task, 
        coreset_size,
        binary_labels = mixed_tasks,
        return_betas = True)
    mixed_trends_4.append(trend)
    mixed_betas.append(m_beta)

0.9836406619385342
Average Accuracy across 1 tasks: 99.91%
0.7062000000000002
0
0.578 raw_pred
0.03271867612293167 0.5875999999999997 0.031068484159667384 all
0.00803061195761599 beta
Average Accuracy across 2 tasks: 93.50%
0.7895200783545544
1
0.4681684622918707 raw_pred
0.5875999999999997 0.4209598432908912 0.012575255575269687 all
5.210287033464248 beta
Average Accuracy across 3 tasks: 94.32%
0.5709500000000001
2
0.4655 raw_pred
0.5875999999999997 0.8580999999999999 0.013255481427098773 all
0.0935452900052146 beta
Average Accuracy across 4 tasks: 86.19%
0.7798292422625399
3
0.4018143009605123 raw_pred
0.8580999999999999 0.4403415154749202 0.0458132516264789 all
71.49685600561959 beta
Average Accuracy across 5 tasks: 86.88%
0.61615
4
0.4495 raw_pred
0.8580999999999999 0.7677 0.018163691028432512 all
2.718033243194161 beta
Average Accuracy across 6 tasks: 84.86%
0.8834340382678751
5
0.49244712990936557 raw_pred
0.8580999999999999 0.2331319234642497 0.007775723965440524 all
339.6058807

In [224]:
scale_similarity(0.2, 0,1-0.5)

0.26894142136999516

In [None]:
# mixed_trends_5 = []
# for i in range(5):
#     model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#     trend = run_auto_vcl(model, 
#         mixed_train_loaders,
#         mixed_test_loaders,
#         optimizer, 
#         epoch_per_task, 
#         coreset_size,
#         binary_labels = mixed_tasks,
#         dor = True)
#     mixed_trends_5.append(trend)

In [229]:
plot_trends([mixed_trends_1, mixed_trends_2, mixed_trends_3, mixed_trends_4])

P-MNIST

In [240]:
p_trends_1 = []
torch.manual_seed(SEED)
coreset_size = 0
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, pmnist_train_loaders,pmnist_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e-2)
    p_trends_1.append(trend)

Average Accuracy across 1 tasks: 97.26%
Average Accuracy across 2 tasks: 82.98%


KeyboardInterrupt: 

In [246]:
coreset_size

0

In [18]:
p_trends_2 = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10, single_head=False).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, pmnist_train_loaders,pmnist_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1)
    p_trends_2.append(trend)

Average Accuracy across 1 tasks: 97.41%
Average Accuracy across 2 tasks: 93.11%


KeyboardInterrupt: 

In [None]:
p_trends_3 = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, pmnist_train_loaders,pmnist_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e2)
    p_trends_3.append(trend)

Average Accuracy across 1 tasks: 97.80%
Average Accuracy across 2 tasks: 90.43%
Average Accuracy across 3 tasks: 88.51%
Average Accuracy across 4 tasks: 87.55%
Average Accuracy across 5 tasks: 86.75%
Average Accuracy across 6 tasks: 86.48%
Average Accuracy across 7 tasks: 86.20%
Average Accuracy across 8 tasks: 86.03%
Average Accuracy across 9 tasks: 85.61%
Average Accuracy across 10 tasks: 85.40%
Average Accuracy across 1 tasks: 97.71%
Average Accuracy across 2 tasks: 90.53%
Average Accuracy across 3 tasks: 88.20%
Average Accuracy across 4 tasks: 86.75%
Average Accuracy across 5 tasks: 86.31%
Average Accuracy across 6 tasks: 86.01%
Average Accuracy across 7 tasks: 85.85%
Average Accuracy across 8 tasks: 85.56%
Average Accuracy across 9 tasks: 85.32%
Average Accuracy across 10 tasks: 85.15%
Average Accuracy across 1 tasks: 97.74%
Average Accuracy across 2 tasks: 90.80%
Average Accuracy across 3 tasks: 88.24%
Average Accuracy across 4 tasks: 87.25%
Average Accuracy across 5 tasks: 86.77

In [230]:
p_trends_4 = []
p_betas = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, p_beta= run_auto_vcl(model, pmnist_train_loaders,pmnist_test_loaders, optimizer, 
        epoch_per_task, coreset_size, return_betas=True)
    p_trends_4.append(trend)
    p_betas.append(p_beta)

0.24968999999999997
Average Accuracy across 1 tasks: 97.30%
0.24671000000000004
0
0.1387 raw_pred
0.8336777777777779 0.8369888888888889 0.00026753301764113245 all
0.9723568714678044 beta
Average Accuracy across 2 tasks: 94.41%
0.22808
1
0.112 raw_pred
0.8369888888888889 0.857688888888889 0.0001568599996668718 all
0.8276132594156334 beta
Average Accuracy across 3 tasks: 90.74%
0.24731999999999998
2
0.0719 raw_pred
0.857688888888889 0.8363111111111111 0.00021643582836173894 all
1.2200478232689953 beta
Average Accuracy across 4 tasks: 84.60%
0.2266
3
0.1234 raw_pred
0.857688888888889 0.8593333333333333 0.00019702162945985427 all
0.9867572108509809 beta


KeyboardInterrupt: 

In [None]:
plot_trends([p_trends_1, p_trends_2,p_trends_3, p_trends_4])

In [53]:
pc_trends_1 = []
coreset_size = 1000
for i in range(1):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, train_loaders,test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e-2)
    pc_trends_1.append(trend)
pc_trends_2 = []
for i in range(1):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, train_loaders,test_loaders, optimizer, epoch_per_task, coreset_size, beta=1)
    pc_trends_2.append(trend)
pc_trends_3 = []
for i in range(1):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, train_loaders,test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e2)
    pc_trends_3.append(trend)
pc_trends_4 = []
for i in range(1):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend  = run_auto_vcl(model, train_loaders,test_loaders, optimizer, epoch_per_task, coreset_size, 
            raw_training_epoch = raw_training_epoch,
            raw_train_size= raw_train_size)
    pc_trends_4.append(trend)

Average Accuracy across 1 tasks: 99.91%


KeyboardInterrupt: 

In [None]:
plot_trends([pc_trends_1, pc_trends_2, pc_trends_3, pc_trends_4])