In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class MFVI_Layer(nn.Module):
    """Mean-field Gaussian linear layer.

    Every weight and bias has an independent Gaussian posterior described by a
    mean (``W_m`` / ``b_m``) and a log-variance (``W_logv`` / ``b_logv``).
    ``prior_*`` holds detached copies of those tensors, refreshed by
    :meth:`set_priors`, and is the reference distribution for :meth:`kl_divergence`.
    """

    def __init__(self, in_features, out_features):
        super(MFVI_Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        weight_shape = (out_features, in_features)
        # Posterior means (uninitialised storage; values set in reset_parameters).
        self.W_m = nn.Parameter(torch.empty(weight_shape))
        self.b_m = nn.Parameter(torch.empty(out_features))
        # Posterior log-variances.
        self.W_logv = nn.Parameter(torch.empty(weight_shape))
        self.b_logv = nn.Parameter(torch.empty(out_features))

        # Prior snapshots are plain tensor attributes (not parameters/buffers);
        # they are populated by set_priors() via reset_parameters().
        self.prior_W_m = None
        self.prior_b_m = None
        self.prior_W_logv = None
        self.prior_b_logv = None

        self.reset_parameters()

    def reset_parameters(self):
        """Draw means from N(0, 0.1^2), set every log-variance to -6, then copy into the prior."""
        with torch.no_grad():
            self.W_m.normal_(0, 0.1)
            self.b_m.normal_(0, 0.1)
            self.W_logv.fill_(-6.0)
            self.b_logv.fill_(-6.0)
        self.set_priors()

    def set_priors(self):
        """Snapshot the current posterior (detached clones) as the prior."""
        for name in ("W_m", "b_m", "W_logv", "b_logv"):
            setattr(self, "prior_" + name, getattr(self, name).detach().clone())

    def forward(self, x, sample=True):
        """Apply the layer using the local reparameterisation trick.

        Returns the pre-activation mean when the module is in eval mode and
        ``sample`` is False; otherwise returns one Gaussian sample of it.
        """
        act_mu = F.linear(x, self.W_m, self.b_m)
        if not (self.training or sample):
            return act_mu

        w_sigma = torch.exp(0.5 * self.W_logv)
        b_sigma = torch.exp(0.5 * self.b_logv)
        # Independent weights => output variance is sum_j x_j^2 * var(W_ij) + var(b_i);
        # the 1e-16 keeps sqrt away from zero.
        out_var = 1e-16 + F.linear(x.pow(2), w_sigma.pow(2)) + b_sigma.pow(2)
        noise = torch.randn_like(act_mu)
        return act_mu + torch.sqrt(out_var) * noise

    @staticmethod
    def _gaussian_kl_sum(mean, std, prior_mean, prior_std):
        """Summed element-wise KL( N(mean, std^2) || N(prior_mean, prior_std^2) )."""
        std_ratio = std / prior_std
        z = (mean - prior_mean) / prior_std
        return 0.5 * (2 * torch.log(prior_std / std) - 1 + std_ratio.pow(2) + z.pow(2)).sum()

    def kl_divergence(self,device):
        """KL(posterior || prior) over weights and biases, divided by the layer's
        trainable element count (means and log-variances together).

        ``device`` is where the prior snapshots are moved before computing.
        """
        self.update_prior_device(device)
        n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)

        kl_weights = self._gaussian_kl_sum(
            self.W_m, torch.exp(0.5 * self.W_logv),
            self.prior_W_m, torch.exp(0.5 * self.prior_W_logv),
        )
        kl_bias = self._gaussian_kl_sum(
            self.b_m, torch.exp(0.5 * self.b_logv),
            self.prior_b_m, torch.exp(0.5 * self.prior_b_logv),
        )
        return (kl_weights + kl_bias) / n_trainable

    def update_prior_device(self, device):
        """Move the prior snapshots to ``device`` (they are not moved by ``Module.to``)."""
        for name in ("prior_W_m", "prior_b_m", "prior_W_logv", "prior_b_logv"):
            setattr(self, name, getattr(self, name).to(device))


class MFVI_NN(nn.Module):
    """Multi-head mean-field VI network.

    A stack of shared ``MFVI_Layer`` s with ReLU, followed by one ``MFVI_Layer``
    output head per task (stored under the string key ``str(task_id)``).
    """

    def __init__(self, input_size, hidden_sizes,  output_size, no_train_samples=10, no_pred_samples=100, num_tasks = 1):
        super(MFVI_NN, self).__init__()

        self.layers = nn.ModuleList()
        self.task_specific_layers = nn.ModuleDict()
        # Stored for callers; not used inside this class.
        self.no_train_samples = no_train_samples
        self.no_pred_samples = no_pred_samples

        widths = [input_size] + hidden_sizes
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            self.layers.append(MFVI_Layer(fan_in, fan_out))

        for task_id in range(num_tasks):
            self.task_specific_layers[str(task_id)] = MFVI_Layer(widths[-1], output_size)

    def forward(self, x, task_id=0, sample=True):
        """Run the shared trunk, then the output head selected by ``task_id``."""
        hidden = x
        for shared_layer in self.layers:
            hidden = F.relu(shared_layer(hidden, sample))
        head = self.task_specific_layers[str(task_id)]
        return head(hidden, sample)

    def kl_divergence(self):
        """Sum of the per-layer KL terms over shared layers and every output head."""
        param_device = next(self.parameters()).device
        total = 0
        for bayes_layer in [*self.layers, *self.task_specific_layers.values()]:
            total += bayes_layer.kl_divergence(param_device)
        return total

    def update_priors(self):
        """Make each layer's current posterior its new prior."""
        for bayes_layer in [*self.layers, *self.task_specific_layers.values()]:
            bayes_layer.set_priors()


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

SEED = 42
torch.manual_seed(SEED)  # Set the random seed for PyTorch
# random_coreset() samples with NumPy's global RNG, so seed NumPy too for reproducible coresets.
np.random.seed(SEED)

class MFVI_Layer(nn.Module):
    """Mean-field variational Gaussian linear layer.

    Each weight/bias has an independent Gaussian posterior ``N(mean, exp(logv))``.
    The prior used by :meth:`kl_divergence` is a detached snapshot of the
    posterior taken by :meth:`set_priors` (at construction, and again whenever
    the caller refreshes it).
    """

    def __init__(self, in_features, out_features):
        super(MFVI_Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Posterior means and log-variances (uninitialised storage, filled by reset_parameters).
        self.W_m = nn.Parameter(torch.empty(out_features, in_features))
        self.b_m = nn.Parameter(torch.empty(out_features))
        self.W_logv = nn.Parameter(torch.empty(out_features, in_features))
        self.b_logv = nn.Parameter(torch.empty(out_features))

        # Prior snapshots are plain tensor attributes (not parameters or buffers): they are
        # not optimised, not in state_dict, and not moved by Module.to() -- see update_prior_device.
        self.prior_W_m = self.prior_b_m = self.prior_W_logv = self.prior_b_logv = None

        self.reset_parameters()

    def reset_parameters(self):
        """Means ~ N(0, 0.1^2), log-variances = -6, then copy the result into the prior."""
        self.W_m.data.normal_(0, 0.1)
        self.b_m.data.normal_(0, 0.1)
        self.W_logv.data.fill_(-6.0)
        self.b_logv.data.fill_(-6.0)
        self.set_priors()

    def set_priors(self):
        """Snapshot the current posterior (detached clones) as the prior."""
        self.prior_W_m = self.W_m.detach().clone()
        self.prior_b_m = self.b_m.detach().clone()
        self.prior_W_logv = self.W_logv.detach().clone()
        self.prior_b_logv = self.b_logv.detach().clone()

    def forward(self, x, sample=True):
        """Local-reparameterisation forward pass.

        Returns the pre-activation mean in eval mode with ``sample=False``;
        otherwise returns one Gaussian sample of the pre-activation.
        """
        W_std = torch.exp(0.5 * self.W_logv)
        b_std = torch.exp(0.5 * self.b_logv)

        act_mu = F.linear(x, self.W_m, self.b_m)
        # Output variance under independent weights; 1e-16 keeps sqrt away from 0.
        act_var = 1e-16 + F.linear(x.pow(2), W_std.pow(2)) + b_std.pow(2)
        act_std = torch.sqrt(act_var)

        if self.training or sample:
            eps = torch.randn_like(act_mu)
            return act_mu + act_std * eps
        else:
            return act_mu

    @staticmethod
    def _gaussian_kl(mean, logv, prior_mean, prior_logv):
        """Summed element-wise KL( N(mean, e^logv) || N(prior_mean, e^prior_logv) ).

        Written in log-variance form,
        ``0.5 * (prior_logv - logv + e^(logv - prior_logv) + (mean - prior_mean)^2 / e^prior_logv - 1)``,
        which avoids taking the log of a ratio of exponentials.
        """
        var_ratio = torch.exp(logv - prior_logv)
        scaled_sq_diff = (mean - prior_mean).pow(2) * torch.exp(-prior_logv)
        return 0.5 * (prior_logv - logv + var_ratio + scaled_sq_diff - 1).sum()

    def kl_divergence(self, device=None):
        """KL(posterior || prior) summed over weights and biases.

        The sum is divided by the element count of all trainable parameters of the
        layer (means and log-variances together, i.e. twice the number of weights
        plus biases).

        Args:
            device: device to move the prior snapshots to first; defaults to the
                device of ``W_m``.

        Returns:
            0-dim tensor.
        """
        if device is None:
            device = self.W_m.device
        self.update_prior_device(device)
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        kl_div_W = self._gaussian_kl(self.W_m, self.W_logv, self.prior_W_m, self.prior_W_logv)
        kl_div_b = self._gaussian_kl(self.b_m, self.b_logv, self.prior_b_m, self.prior_b_logv)
        return (kl_div_W + kl_div_b) / num_params

    def update_prior_device(self, device):
        """Move the prior snapshots to ``device`` (they are not moved by Module.to)."""
        self.prior_W_m = self.prior_W_m.to(device)
        self.prior_b_m = self.prior_b_m.to(device)
        self.prior_W_logv = self.prior_W_logv.to(device)
        self.prior_b_logv = self.prior_b_logv.to(device)

class MFVI_NN(nn.Module):
    """Mean-field VI MLP: shared ``MFVI_Layer`` trunk with ReLU, plus output head(s).

    With ``single_head=True`` a single output layer (key ``"0"``) serves every task
    and ``task_id`` is ignored in :meth:`forward`; otherwise there is one output
    layer per task, keyed by ``str(task_id)``.
    """

    def __init__(self, input_size, hidden_sizes, output_size, num_tasks=1, single_head=False):
        super(MFVI_NN, self).__init__()
        # Constructor arguments are kept so an identical model can be rebuilt with
        # type(model)(input_size, hidden_sizes, output_size, num_tasks, single_head).
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.num_tasks = num_tasks
        self.single_head = single_head
        self.layers = nn.ModuleList()
        self.task_specific_layers = nn.ModuleDict()  # Using ModuleDict for task-specific layers

        # Construct shared layers
        sizes = [input_size] + hidden_sizes
        for i in range(len(sizes) - 1):
            self.layers.append(MFVI_Layer(sizes[i], sizes[i + 1]))

        # Construct task-specific output layers (only one when single_head)
        num_heads = 1 if single_head else num_tasks
        for task_id in range(num_heads):
            self.task_specific_layers[str(task_id)] = MFVI_Layer(sizes[-1], output_size)

    def forward(self, x, task_id=0, sample=True):
        """Shared trunk with ReLU, then the selected output head (no final activation)."""
        for layer in self.layers:
            x = F.relu(layer(x, sample=sample))
        head_key = "0" if self.single_head else str(task_id)
        return self.task_specific_layers[head_key](x, sample=sample)

    def kl_divergence(self, include_head_kl=False):
        """Sum of per-layer KL terms.

        Args:
            include_head_kl: also add the KL of every task-specific output layer.
                Defaults to False: only the shared layers are regularised.
        """
        device = next(self.parameters()).device
        regularised = list(self.layers)
        if include_head_kl:
            regularised += list(self.task_specific_layers.values())
        kl_div = 0
        for layer in regularised:
            kl_div += layer.kl_divergence(device)
        return kl_div

    def update_priors(self):
        """Make every layer's (shared and head) current posterior its new prior."""
        # Plain list unpacking instead of ``ModuleList + list``, which needs
        # ModuleList.__add__ (missing in older PyTorch) and builds a throwaway ModuleList.
        for layer in [*self.layers, *self.task_specific_layers.values()]:
            layer.set_priors()


In [5]:
import torch.nn.functional as F

def train(model, trainloader, optimizer, epoch, device, kl_weight=1, task_id = 0, binary_label = None):
    """Run one training epoch of ``model`` over ``trainloader``.

    Per-batch loss is mean cross-entropy, plus ``kl_weight * model.kl_divergence()``
    for every ``task_id`` other than 0 (task 0 is trained on the data term alone).

    Args:
        model: called as ``model(x, sample=True, task_id=task_id)``; must provide
            ``kl_divergence()`` when ``task_id != 0``.
        trainloader: iterable of ``(data, target)`` batches; data is flattened to
            ``(batch, -1)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, kept for the caller's bookkeeping (unused here).
        device: device each batch is moved to.
        kl_weight: multiplier (beta) on the KL term.
        task_id: output head to train.
        binary_label: optional class pair; targets equal to ``binary_label[0]``
            become 1 and everything else 0.

    Returns:
        Mean total loss over the batches as a float, or ``nan`` if the loader is empty.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for data, target in trainloader:
        data, target = data.to(device), target.to(device)
        data = data.view(data.size(0), -1)  # Flatten the images
        optimizer.zero_grad()
        output = model(data, sample=True, task_id=task_id)
        if binary_label is not None:
            target = (target == binary_label[0]).long()
        reconstruction_loss = F.cross_entropy(output, target, reduction='mean')

        if task_id == 0:
            loss = reconstruction_loss
        else:
            loss = reconstruction_loss + model.kl_divergence() * kl_weight
        loss.backward()
        optimizer.step()
        # Accumulate on-device; a single .item()-style sync happens at the end.
        running_loss += loss.detach()
        num_batches += 1
    if num_batches == 0:
        return float('nan')
    return float(running_loss / num_batches)

def test(model, testloader, device, task_id = 0, binary_label = None):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)  # Flatten the images
            output = model(data, sample=False, task_id = task_id)
            if binary_label != None:
                target = (target == binary_label[0]).long()

            # Use cross_entropy for test loss calculation, sum up batch loss
            test_loss += F.cross_entropy(output, target, reduction='mean').item()
            pred = output.argmax(dim=1, keepdim=True)  # Get the index of the max probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(testloader.dataset)
    # print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(testloader.dataset)} ({100. * correct / len(testloader.dataset):.0f}%)\n')
    return test_loss, correct / len(testloader.dataset)


## Permuted MNIST

In [6]:
from torchvision import datasets, transforms

# Scale pixels to [0, 1] with ToTensor, then map them to [-1, 1] via (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Train split first, then test split; both downloaded to ./data on first use.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [7]:
def permute_mnist(mnist, perm):
    """Eagerly build a list of ``(image, label)`` pairs with ``perm`` applied to every image.

    Each image is flattened, re-indexed by the fixed permutation ``perm`` and
    reshaped to ``(1, 28, 28)``; labels pass through unchanged.
    """
    return [(img.view(-1)[perm].view(1, 28, 28), target) for img, target in mnist]

# tqdm is otherwise only imported in later cells; import it here so this cell
# also runs on a fresh kernel.
from tqdm import tqdm

# Initialize lists to store the permuted datasets
permuted_mnist_train_datasets = []
permuted_mnist_test_datasets = []

# Generate 10 permuted datasets; each task's train and test split share one permutation.
for _ in tqdm(range(10)):
    # Generate a fixed permutation
    fixed_permutation = torch.randperm(784)

    # Apply this permutation to the train and test datasets
    permuted_train = permute_mnist(mnist_trainset, fixed_permutation)
    permuted_test = permute_mnist(mnist_testset, fixed_permutation)

    # Store the permuted datasets
    permuted_mnist_train_datasets.append(permuted_train)
    permuted_mnist_test_datasets.append(permuted_test)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:33<00:00,  3.38s/it]


NameError: name 'DataLoader' is not defined

In [18]:

from torch.utils.data import DataLoader


batch_size = 64
# Training loaders reshuffle every epoch; test loaders keep a fixed order.
permuted_mnist_train_loaders = list(map(
    lambda ds: DataLoader(ds, batch_size=batch_size, shuffle=True),
    permuted_mnist_train_datasets,
))
permuted_mnist_test_loaders = list(map(
    lambda ds: DataLoader(ds, batch_size=batch_size, shuffle=False),
    permuted_mnist_test_datasets,
))

In [None]:
import torch
# Re-seed torch so the permutations drawn below do not depend on how many random
# numbers earlier cells consumed.
torch.manual_seed(SEED)
def generate_permutations(task_count, image_size):
    """Return ``task_count - 1`` random pixel permutations, each of length ``image_size``.

    One fewer than ``task_count``: the caller uses the unpermuted dataset as the
    first task and one permutation for each later task.
    """
    perms = []
    for _ in range(task_count - 1):
        perms.append(torch.randperm(image_size))
    return perms

task_count, image_size = 10, 28 * 28  # 10 tasks over flattened 28x28 MNIST images
permutations = generate_permutations(task_count=task_count, image_size=image_size)

In [None]:
from torch.utils.data import Dataset

class PermutedMNIST(Dataset):
    """Lazily apply a fixed pixel permutation to each sample of a wrapped dataset.

    Args:
        mnist_dataset: map-style dataset yielding ``(image_tensor, label)``.
        permutation: 1-D index tensor over the flattened image, or ``None`` to
            return images unchanged.
    """

    def __init__(self, mnist_dataset, permutation=None):
        self.mnist_dataset = mnist_dataset
        self.permutation = permutation

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        image, label = self.mnist_dataset[idx]
        if self.permutation is not None:
            # Permute the flattened pixels, then restore the input's own shape so the
            # wrapper also works for images that are not 1x28x28.
            image = image.reshape(-1)[self.permutation].view_as(image)
        return image, label


In [None]:
from torch.utils.data import DataLoader

# Task 0 is plain MNIST; each later task wraps MNIST with one fixed pixel permutation
# (shared by that task's train and test split). Train loaders shuffle, test loaders do not.
pmnist_data_loader = [
    (DataLoader(train_ds, batch_size=batch_size, shuffle=True),
     DataLoader(test_ds, batch_size=batch_size, shuffle=False))
    for train_ds, test_ds in [(mnist_trainset, mnist_testset)] + [
        (PermutedMNIST(mnist_trainset, permutation=perm),
         PermutedMNIST(mnist_testset, permutation=perm))
        for perm in permutations
    ]
]
pmnist_train_loaders = [train_dl for train_dl, _ in pmnist_data_loader]
pmnist_test_loaders = [test_dl for _, test_dl in pmnist_data_loader]


In [11]:
from tqdm import tqdm
# Assuming model, optimizer, train_loaders, test_loaders are defined
# Prefer the GPU whenever PyTorch can see one.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [12]:
from torch.utils.data import DataLoader, ConcatDataset, Subset
def random_coreset(dataset, coreset_size):
    """Draw a uniformly random coreset, without replacement, from ``dataset``.

    Args:
        dataset: map-style dataset supporting ``len``.
        coreset_size: requested number of samples; clipped to ``len(dataset)``.

    Returns:
        ``torch.utils.data.Subset`` of ``dataset`` whose ``indices`` is an int64
        tensor of the chosen positions. Sampling uses NumPy's global RNG.
    """
    n_available = len(dataset)
    take = min(coreset_size, n_available)
    picked = np.random.choice(n_available, size=take, replace=False)
    return Subset(dataset, torch.from_numpy(picked))



In [13]:
from tqdm import tqdm
# Training epochs per task (callers also pass it as the epoch count for coreset replay).
# NOTE(review): a later cell that displays `epoch_per_task` recorded the value 10, not 20,
# so the saved results may come from a different setting; re-run top to bottom to confirm.
epoch_per_task = 20

def run_vcl(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size=0, beta=1, binary_labels = None):
    """Train ``model`` sequentially with VCL and track the running average accuracy.

    For each task: train ``epoch_per_task`` epochs with KL weight ``beta``, make the
    new posterior the prior, then evaluate a copy of the model (posterior-mean
    predictions) on every task seen so far.

    Args:
        model: ``MFVI_NN``-like model exposing ``update_priors`` and the constructor
            attributes used to rebuild it.
        train_loaders, test_loaders: per-task loaders, zipped together.
        optimizer: optimizer over ``model``'s parameters.
        epoch_per_task: training epochs per task.
        coreset_size: accepted for signature parity with ``run_auto_vcl``; coreset
            replay is not performed here, so it is ignored.
        beta: KL weight passed to ``train``.
        binary_labels: optional per-task class pairs (see ``train``); ``None`` for
            plain multi-class tasks.

    Returns:
        List whose ``t``-th entry is the mean accuracy over tasks ``0..t`` after
        training task ``t``.

    Uses the notebook globals ``device``, ``train`` and ``test``.
    """
    def label_for(task):
        # Looked up lazily so multi-class runs work for any number of tasks/heads.
        return None if binary_labels is None else binary_labels[task]

    ave_acc_trend_rc = []
    prev_test_loaders = []
    for task_id, (train_loader, test_loader) in enumerate(zip(train_loaders, test_loaders)):
        for epoch in range(1, epoch_per_task + 1):
            train(model, train_loader, optimizer, epoch, device, beta, task_id=task_id, binary_label=label_for(task_id))
        model.update_priors()

        # Evaluate a freshly built copy of the trained weights so evaluation never
        # changes ``model``'s mode or state. Prior snapshots are not part of
        # state_dict, but evaluation does not use them.
        prediction_model = type(model)(model.input_size, model.hidden_sizes, model.output_size, model.num_tasks, model.single_head).to(device)
        prediction_model.load_state_dict(model.state_dict())

        prev_test_loaders.append(test_loader)
        task_accuracies_rc = []
        for task_num, ptl in enumerate(prev_test_loaders):
            _, task_accuracy = test(prediction_model, ptl, device, task_id=task_num, binary_label=label_for(task_num))
            task_accuracies_rc.append(task_accuracy)
        print(task_accuracies_rc)
        average_accuracy = sum(task_accuracies_rc) / len(task_accuracies_rc)
        ave_acc_trend_rc.append(average_accuracy)
        print(f'Average Accuracy across {len(task_accuracies_rc)} tasks: {average_accuracy*100:.2f}%')
    return ave_acc_trend_rc

def scale_similarity(sim, a, b, steepness=20):
    """Squash a raw similarity with a logistic curve centred at ``(a + b) / 2``.

    Args:
        sim: raw similarity (scalar or NumPy array).
        a, b: ends of the expected range of ``sim``; the curve's midpoint is their mean.
        steepness: logistic slope (default 20).

    Returns:
        ``1 / (1 + exp(-steepness * (sim - (a + b) / 2)))``, a value in ``(0, 1)``.
    """
    midpoint = (a + b) / 2
    return 1 / (1 + np.exp(-steepness * (sim - midpoint)))

def _replay_coresets(model, coresets, task_difficulties, optimizer, epoch_per_task,
                     beta_star, label_for, dor, refresh_priors, verbose):
    """Train ``model`` on every stored coreset for ``epoch_per_task`` epochs each.

    With ``dor``, coresets are visited hardest-task-first (by recorded difficulty,
    ties in task order), and each uses KL weight ``beta_star * 10 ** (2 - difficulty)``.
    Otherwise they are visited in task order with ``beta_star``. If ``refresh_priors``,
    the priors are reset to the posterior after each coreset. If ``verbose`` (and
    ``dor``), the sorted difficulties and betas are printed.

    Uses the notebook globals ``DataLoader``, ``batch_size``, ``device`` and ``train``.
    """
    if dor:
        # sorted() is stable, also with reverse=True, so equal difficulties keep task order.
        order = sorted(range(len(coresets)), key=lambda t: task_difficulties[t], reverse=True)
        sorted_difficulties = [task_difficulties[t] for t in order]
        replay_betas = [beta_star*10**(2-d) for d in sorted_difficulties]
        if verbose:
            print(sorted_difficulties, replay_betas)
        schedule = list(zip(order, replay_betas))
    else:
        schedule = [(t, beta_star) for t in range(len(coresets))]

    for task_num, kl_weight in schedule:
        coreset_loader = DataLoader(coresets[task_num], batch_size=batch_size, shuffle=True)
        for epoch in range(1, epoch_per_task + 1):
            train(model, coreset_loader, optimizer, epoch, device, kl_weight, task_id=task_num, binary_label=label_for(task_num))
        if refresh_priors:
            model.update_priors()


def run_auto_vcl(model, train_loaders,test_loaders, optimizer, epoch_per_task, coreset_size, 
                beta_star=1, raw_training_epoch = 5,raw_train_size = 1000, 
                binary_labels = None, dor = False, return_betas = False, raw_repeats=10):
    """AutoVCL: VCL whose per-task KL weight is set from task difficulty and similarity.

    For each task ``t``:

    1. Difficulty: train ``raw_repeats`` fresh models for ``raw_training_epoch`` epochs
       on ``raw_train_size`` random training samples, average their test accuracy,
       and map chance accuracy (``1/output_size``) to 1 and perfect accuracy to 0,
       clipped to [0, 1].
    2. KL weight: for ``t > 0``, score the current model's head ``t-1`` on task ``t``,
       squash ``|acc - chance|`` with ``scale_similarity``, and use
       ``beta = beta_star * 10 ** (4 * (max earlier difficulty - difficulty) + 4 * similarity)``.
       Task 0 uses ``beta_star``.
    3. If ``coreset_size > 0`` and ``t > 0``, replay earlier coresets (refreshing the
       priors after each). Then train on the task and refresh the priors.
    4. Evaluate a copy of the model on every task seen so far.

    Returns:
        The list of running average accuracies, plus the list of betas chosen for
        tasks ``t > 0`` when ``return_betas`` is true.

    Uses the notebook globals ``device``, ``batch_size``, ``train``, ``test``,
    ``random_coreset``, ``scale_similarity`` and ``DataLoader``.
    """
    def label_for(task):
        # One label pair per *task* (not per output unit); None means multi-class.
        return None if binary_labels is None else binary_labels[task]

    def fresh_model():
        return type(model)(model.input_size, model.hidden_sizes, model.output_size, model.num_tasks, model.single_head).to(device)

    task_difficulties = []
    ave_acc_trend_rc = []
    prev_test_loaders = []
    coresets = []
    betas = []
    for task_id, (train_loader, test_loader) in enumerate(zip(train_loaders, test_loaders)):
        # 1. Difficulty from briefly trained fresh models.
        raw_acc = []
        for _ in range(raw_repeats):
            raw_model = fresh_model()
            raw_trainset = random_coreset(train_loader.dataset, raw_train_size)
            raw_train_loader = DataLoader(raw_trainset, batch_size=batch_size, shuffle=True)
            raw_optimizer = torch.optim.Adam(raw_model.parameters(), lr=0.001)
            for epoch in range(1, raw_training_epoch + 1):
                train(raw_model, raw_train_loader, raw_optimizer, epoch, device, beta_star, task_id=0, binary_label=label_for(task_id))
            _, acc_simple_train = test(raw_model, test_loader, device, task_id=0, binary_label=label_for(task_id))
            raw_acc.append(acc_simple_train)

        acc_simple_train = np.mean(raw_acc)
        print(acc_simple_train)

        dummy_pred = 1/model.output_size  # chance-level accuracy
        curr_difficulty = min(max((1-(acc_simple_train - dummy_pred)/(1-dummy_pred)),0),1)

        # 2. KL weight for this task.
        if task_id > 0:
            print(task_id-1)
            _, raw_pred = test(model, test_loader, device, task_id=task_id-1, binary_label=label_for(task_id))
            print(raw_pred,'raw_pred')
            similarity = scale_similarity(np.abs(raw_pred-dummy_pred), 0, 1-dummy_pred)
            prev_difficulty = np.max(task_difficulties)
            print(prev_difficulty, curr_difficulty,similarity,'all')
            beta = beta_star*10**((prev_difficulty-curr_difficulty)*4+similarity*4)
            betas.append(beta)
            print(beta,'beta')
        else:
            beta = beta_star

        # 3. Replay earlier coresets, then train on the new task.
        if coreset_size > 0 and task_id > 0:
            _replay_coresets(model, coresets, task_difficulties, optimizer, epoch_per_task,
                             beta_star, label_for, dor, refresh_priors=True, verbose=True)
        for epoch in range(1, epoch_per_task + 1):
            train(model, train_loader, optimizer, epoch, device, beta, task_id=task_id, binary_label=label_for(task_id))
        model.update_priors()

        # Snapshot used for evaluation below.
        prediction_model = fresh_model()
        prediction_model.load_state_dict(model.state_dict())

        task_difficulties.append(curr_difficulty)
        if coreset_size > 0:
            coresets.append(random_coreset(train_loader.dataset, coreset_size))
            # NOTE(review): this replay trains `model` *after* the evaluation snapshot
            # above was taken, so it does not affect the accuracies reported for this
            # task; it only changes the starting point for the next task. Confirm intended.
            _replay_coresets(model, coresets, task_difficulties, optimizer, epoch_per_task,
                             beta_star, label_for, dor, refresh_priors=False, verbose=False)

        # 4. Evaluate on every task seen so far.
        prev_test_loaders.append(test_loader)
        task_accuracies_rc = []
        for task_num, ptl in enumerate(prev_test_loaders):
            _, task_accuracy = test(prediction_model, ptl, device, task_id=task_num, binary_label=label_for(task_num))
            task_accuracies_rc.append(task_accuracy)

        average_accuracy = sum(task_accuracies_rc) / len(task_accuracies_rc)
        ave_acc_trend_rc.append(average_accuracy)
        print(f'Average Accuracy across {len(task_accuracies_rc)} tasks: {average_accuracy*100:.2f}%')

    if return_betas:
        return ave_acc_trend_rc, betas
    return ave_acc_trend_rc

In [None]:
def create_split_task(dataset, classes):
    """Restrict a labelled dataset to the samples whose label is in ``classes``.

    Parameters:
    - dataset: map-style dataset yielding ``(input, label)``, e.g. torchvision MNIST.
    - classes: collection of labels to keep (e.g. a pair of digits).

    Returns:
    - A ``Subset`` of ``dataset`` with the matching indices, in original order.

    When ``dataset`` has a ``targets`` attribute and no ``target_transform``, the
    labels are read from ``targets`` instead of decoding and transforming every
    sample. This assumes ``targets[i]`` is the label ``dataset[i]`` returns, which
    holds for torchvision's MNIST.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is not None and getattr(dataset, 'target_transform', None) is None:
        labels = targets.tolist() if hasattr(targets, 'tolist') else list(targets)
    else:
        labels = [target for _, target in dataset]

    indices = [i for i, label in enumerate(labels) if label in classes]
    return Subset(dataset, indices)

def create_split_dataloaders(train_dataset, test_dataset, tasks, batch_size=256):
    """Build one (train, test) DataLoader pair per class-subset task (Split MNIST).

    Parameters:
    - train_dataset: the full training dataset.
    - test_dataset: the full test dataset.
    - tasks: iterable of label collections, one per task (e.g. ``[(0, 1), (2, 3)]``).
    - batch_size: batch size for every loader.

    Returns:
    - ``(train_loaders, test_loaders)``: two lists aligned by task. Training loaders
      shuffle; test loaders keep dataset order.
    """
    loader_pairs = []
    for classes in tasks:
        train_split = DataLoader(create_split_task(train_dataset, classes), batch_size=batch_size, shuffle=True)
        test_split = DataLoader(create_split_task(test_dataset, classes), batch_size=batch_size, shuffle=False)
        loader_pairs.append((train_split, test_split))

    train_loaders = [pair[0] for pair in loader_pairs]
    test_loaders = [pair[1] for pair in loader_pairs]
    return train_loaders, test_loaders

In [None]:
import pandas as pd


def plot_trends(trends, title = 'Accuracy Trends in the Permuted MNIST Experiment', lower = 0.7, labels = None):
    """Plot mean running-accuracy curves for several methods with Altair.

    Args:
        trends: one entry per method. Each entry is a list of runs, and each run is
            the list returned by ``run_vcl``/``run_auto_vcl`` (average accuracy after
            task 1, 2, ...). All runs must have the same length.
        title: chart title.
        lower: lower bound of the y-axis.
        labels: legend name for each entry of ``trends``. Defaults to
            ``['beta = 0.01', 'beta = 1', 'beta = 100', 'AutoVCL']``; entries of
            ``trends`` beyond the number of labels are not plotted.
    """
    import altair as alt

    if labels is None:
        labels = ['beta = 0.01', 'beta = 1', 'beta = 100', 'AutoVCL']

    # Point k of a run is the average accuracy after training k + 1 tasks, so the x-axis starts at 1.
    num_tasks = len(trends[0][0])
    columns = {'# of tasks': range(1, num_tasks + 1)}
    for label, runs in zip(labels, trends):
        columns[label] = np.mean(runs, axis = 0)
    df = pd.DataFrame(columns)
    # Convert the DataFrame to long format
    df_long = df.melt('# of tasks', var_name='Series', value_name='Values')

    legend_order = list(labels)
    chart = alt.Chart(df_long).mark_line(point=True).encode(
        # Integer ticks only.
        x=alt.X('# of tasks:Q', title='# tasks', axis=alt.Axis(values=list(range(1, num_tasks + 1)))),
        y=alt.Y('Values:Q', scale=alt.Scale(domain=[lower, 1]), title='Accuracy'),
        color=alt.Color('Series:N', sort=legend_order, scale=alt.Scale(scheme='category10'), legend=alt.Legend(title="Model")),
        tooltip=['# of tasks', 'Values', 'Series']
    ).properties(
        width=800,
        height=400,
        title=title
    ).configure_axis(
        labelFontSize=15,
        titleFontSize=20
    ).configure_legend(
        labelFontSize=15,
        titleFontSize=15
    ).configure_title(
        fontSize=24
    )

    chart.display()

## Reproduce

In [23]:
# Show the epochs-per-task value actually in effect for the experiments below.
# NOTE(review): the recorded output is 10 although the run_vcl cell assigns 20,
# which suggests cells were executed out of order.
epoch_per_task

10

## Split MNIST (no coreset)

In [13]:
# Split MNIST: five binary tasks, each separating one pair of consecutive digits.
tasks = [(low, low + 1) for low in range(0, 10, 2)]
split_train_loaders, split_test_loaders = create_split_dataloaders(
    mnist_trainset, mnist_testset, tasks, batch_size=256
)

In [14]:

coreset_size = 0
trends_1 = []
# Seed once before the repetitions (like the later "alike" cells) so this cell is
# reproducible regardless of which cells ran before it.
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # VCL baseline with a small KL weight (beta = 0.01).
    trend = run_vcl(model,  split_train_loaders,split_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e-2, binary_labels = tasks)
    trends_1.append(trend)

[0.9985815602836879]
Average Accuracy across 1 tasks: 99.86%
[0.9981087470449173, 0.9931439764936337]
Average Accuracy across 2 tasks: 99.56%
[0.9749408983451536, 0.9882468168462292, 0.9973319103521878]
Average Accuracy across 3 tasks: 98.68%
[0.9962174940898345, 0.9059745347698335, 0.923692636072572, 0.9979859013091642]
Average Accuracy across 4 tasks: 95.60%
[0.9929078014184397, 0.871694417238002, 0.9386339381003201, 0.9859013091641491, 0.9904185577407968]
Average Accuracy across 5 tasks: 95.59%
[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9990543735224586, 0.9951028403525954]
Average Accuracy across 2 tasks: 99.71%
[0.9111111111111111, 0.9172380019588638, 0.9983991462113126]
Average Accuracy across 3 tasks: 94.22%
[0.7300236406619386, 0.8770812928501469, 0.7524012806830309, 0.9984894259818731]
Average Accuracy across 4 tasks: 83.95%
[0.9919621749408983, 0.8751224289911851, 0.8009605122732124, 0.9884189325276939, 0.9924357034795764]
Average Accuracy across 5 tasks:

In [15]:
trends_2 = []
# Seed once so this cell is reproducible on its own.
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # VCL baseline with unit KL weight (beta = 1).
    trend = run_vcl(model,  split_train_loaders,split_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1, binary_labels = tasks)
    trends_2.append(trend)

[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9971631205673759, 0.9960822722820764]
Average Accuracy across 2 tasks: 99.66%
[0.9985815602836879, 0.9671890303623898, 0.9983991462113126]
Average Accuracy across 3 tasks: 98.81%
[0.9125295508274232, 0.9353574926542605, 0.9695837780149413, 0.9984894259818731]
Average Accuracy across 4 tasks: 95.40%
[0.9427895981087471, 0.9201762977473066, 0.9866595517609391, 0.9934541792547835, 0.994452849218356]
Average Accuracy across 5 tasks: 96.75%
[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9995271867612293, 0.9936336924583742]
Average Accuracy across 2 tasks: 99.66%
[0.9962174940898345, 0.9857982370225269, 0.9994663820704376]
Average Accuracy across 3 tasks: 99.38%
[0.9981087470449173, 0.9382957884427032, 0.9893276414087513, 0.9984894259818731]
Average Accuracy across 4 tasks: 98.11%
[0.9947990543735225, 0.8712047012732616, 0.9754535752401281, 0.9929506545820745, 0.9949571356530509]
Average Accuracy across 5 tasks

In [16]:

trends_3 = []
# Seed once so this cell is reproducible on its own.
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # VCL baseline with a large KL weight (beta = 100).
    trend = run_vcl(model, split_train_loaders,split_test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e2, binary_labels = tasks)
    trends_3.append(trend)

[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9995271867612293, 0.9715964740450539]
Average Accuracy across 2 tasks: 98.56%
[0.9990543735224586, 0.9686581782566112, 0.9898612593383138]
Average Accuracy across 3 tasks: 98.59%
[0.9990543735224586, 0.9637610186092067, 0.9866595517609391, 0.9939577039274925]
Average Accuracy across 4 tasks: 98.59%
[0.9995271867612293, 0.8839373163565132, 0.9807897545357525, 0.9894259818731118, 0.9788199697428139]
Average Accuracy across 5 tasks: 96.65%
[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9990543735224586, 0.9779627815866797]
Average Accuracy across 2 tasks: 98.85%
[0.9995271867612293, 0.9760039177277179, 0.9887940234791889]
Average Accuracy across 3 tasks: 98.81%
[0.9990543735224586, 0.9720861900097943, 0.9818569903948773, 0.9929506545820745]
Average Accuracy across 4 tasks: 98.65%
[0.9990543735224586, 0.9647404505386875, 0.9781216648879403, 0.9939577039274925, 0.9773071104387292]
Average Accuracy across 5 task

In [17]:
trends_4 = []
# Seed both RNGs: torch for model init/training, NumPy for random_coreset's
# raw-training subsets, so this cell is reproducible on its own.
torch.manual_seed(SEED)
np.random.seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_auto_vcl(model, 
        split_train_loaders,
        split_test_loaders,
        optimizer, 
        epoch_per_task, 
        coreset_size,
        binary_labels = tasks,
        dor = True)
    trends_4.append(trend)

0.9849172576832153
Average Accuracy across 1 tasks: 99.95%
0.790303623898139
0
0.47355533790401566 raw_pred
0.030165484633569495 0.41939275220372196 0.011305380343249198 all
0.03078310694200783 beta
Average Accuracy across 2 tasks: 99.66%
0.7532550693703308
1
0.7572038420490929 raw_pred
0.41939275220372196 0.49348986125933836 0.5359570316743368 all
70.3784488137113 beta
Average Accuracy across 3 tasks: 99.48%
0.9025176233635447
2
0.3076535750251762 raw_pred
0.49348986125933836 0.19496475327291063 0.23992850820810221 all
142.49986889497907 beta
Average Accuracy across 4 tasks: 99.34%
0.7853756933938476
3
0.6666666666666666 raw_pred
0.49348986125933836 0.4292486132123048 0.15886910488091505 all
7.806231234471929 beta
Average Accuracy across 5 tasks: 97.64%
0.9733806146572104
Average Accuracy across 1 tasks: 99.95%
0.8158178256611166
0
0.4387855044074437 raw_pred
0.05323877068557925 0.3683643486777668 0.022407217512722184 all
0.06747246210619345 beta
Average Accuracy across 2 tasks: 99.42

In [18]:
import matplotlib.pyplot as plt

# Compare the four split-task experiment groups. `plot_trends` is defined elsewhere
# in the notebook, not in this cell.
plot_trends([trends_1, trends_2, trends_3, trends_4])

## Intentionally alike tasks


In [57]:
# Five binary tasks over consecutive digit pairs: (0, 1), (2, 3), ..., (8, 9).
# NOTE(review): these are the standard Split-MNIST pairs; confirm this is the intended "alike" split.
tasks = [(digit, digit + 1) for digit in range(0, 10, 2)]
# One (train, test) loader per task in `tasks`.
split_alike_train_loaders, split_alike_test_loaders = create_split_dataloaders(
    mnist_trainset,
    mnist_testset,
    tasks,
    batch_size=batch_size,
)

In [58]:
# Alike split, fixed beta = 1e-2, no coreset, 10 seeded runs.
coreset_size = 0
trends_alike_1 = []
torch.manual_seed(SEED)
for i in range(10):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        split_alike_train_loaders,
        split_alike_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1e-2,
        binary_labels=tasks,
    )
    trends_alike_1 += [trend]

[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9981087470449173, 0.9936336924583742]
Average Accuracy across 2 tasks: 99.59%
[0.9947990543735225, 0.9809010773751224, 0.9989327641408752]
Average Accuracy across 3 tasks: 99.15%
[0.47801418439716314, 0.9005876591576886, 0.6947705442902882, 0.9984894259818731]
Average Accuracy across 4 tasks: 76.80%
[0.968321513002364, 0.7610186092066601, 0.9749199573105657, 0.9783484390735147, 0.9929399899142713]
Average Accuracy across 5 tasks: 93.51%
[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9990543735224586, 0.9946131243878551]
Average Accuracy across 2 tasks: 99.68%
[0.9976359338061466, 0.8712047012732616, 0.9983991462113126]
Average Accuracy across 3 tasks: 95.57%
[0.9952718676122931, 0.6214495592556317, 0.7630736392742796, 0.998992950654582]
Average Accuracy across 4 tasks: 84.47%
[0.9782505910165484, 0.7311459353574926, 0.983991462113127, 0.9838872104733132, 0.9954614220877458]
Average Accuracy across 5 tasks:

In [59]:
# Alike split, fixed beta = 1, no coreset, 10 seeded runs.
coreset_size = 0
trends_alike_2 = []
torch.manual_seed(SEED)
for i in range(10):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        split_alike_train_loaders,
        split_alike_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1,
        binary_labels=tasks,
    )
    trends_alike_2 += [trend]

[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9981087470449173, 0.9906953966699314]
Average Accuracy across 2 tasks: 99.44%
[0.9858156028368794, 0.9823702252693438, 0.9994663820704376]
Average Accuracy across 3 tasks: 98.92%
[0.9555555555555556, 0.910871694417238, 0.9823906083244397, 0.9979859013091642]
Average Accuracy across 4 tasks: 96.17%
[0.9801418439716312, 0.8344760039177277, 0.9893276414087513, 0.9813695871097684, 0.9914271306101866]
Average Accuracy across 5 tasks: 95.53%
[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9966903073286052, 0.9916748285994124]
Average Accuracy across 2 tasks: 99.42%
[0.9981087470449173, 0.9045053868756121, 0.9978655282817502]
Average Accuracy across 3 tasks: 96.68%
[0.9947990543735225, 0.8604309500489716, 0.9935965848452508, 0.9979859013091642]
Average Accuracy across 4 tasks: 96.17%
[0.9867612293144208, 0.8192948090107738, 0.987726787620064, 0.9919436052366566, 0.994452849218356]
Average Accuracy across 5 tasks: 

In [60]:
# Alike split, fixed beta = 1e2, no coreset, 10 seeded runs.
coreset_size = 0
trends_alike_3 = []
torch.manual_seed(SEED)
for i in range(10):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        split_alike_train_loaders,
        split_alike_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1e2,
        binary_labels=tasks,
    )
    trends_alike_3 += [trend]

[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9990543735224586, 0.970617042115573]
Average Accuracy across 2 tasks: 98.48%
[0.9985815602836879, 0.9671890303623898, 0.9919957310565635]
Average Accuracy across 3 tasks: 98.59%
[0.9990543735224586, 0.965230166503428, 0.987726787620064, 0.9959718026183283]
Average Accuracy across 4 tasks: 98.70%
[0.9985815602836879, 0.9559255631733594, 0.991462113127001, 0.9904330312185297, 0.967725668179526]
Average Accuracy across 5 tasks: 98.08%
[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9995271867612293, 0.9725759059745348]
Average Accuracy across 2 tasks: 98.61%
[0.9995271867612293, 0.975024485798237, 0.9930629669156884]
Average Accuracy across 3 tasks: 98.92%
[0.9995271867612293, 0.9725759059745348, 0.9925293489861259, 0.9929506545820745]
Average Accuracy across 4 tasks: 98.94%
[0.9990543735224586, 0.9573947110675808, 0.9903948772678762, 0.9788519637462235, 0.9757942511346445]
Average Accuracy across 5 tasks: 98.

In [61]:
# Alike split with run_auto_vcl, 10 seeded runs. With return_betas=True the call
# returns (trend, betas); both are collected.
# `coreset_size` is not set here; it comes from an earlier cell.
trends_alike_4 = []
alike_betas = []
torch.manual_seed(SEED)
for i in range(10):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, alike_beta = run_auto_vcl(
        model,
        split_alike_train_loaders,
        split_alike_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        binary_labels=tasks,
        return_betas=True,
    )
    trends_alike_4 += [trend]
    alike_betas += [alike_beta]

0.979243498817967
Average Accuracy across 1 tasks: 99.91%
0.8014201762977473
0
0.5264446620959843 raw_pred
0.041513002364065965 0.3971596474045054 0.011305380343249186 all
0.04194062916890819 beta
Average Accuracy across 2 tasks: 99.61%
0.8017075773745997
1
0.7038420490928495 raw_pred
0.3971596474045054 0.3965848452508005 0.28431466007599576 all
13.789796236598516 beta
Average Accuracy across 3 tasks: 99.45%
0.8942598187311178
2
0.3504531722054381 raw_pred
0.3971596474045054 0.21148036253776437 0.118254598002793 all
16.43370669344585 beta
Average Accuracy across 4 tasks: 97.45%
0.7869894099848713
3
0.6036308623298033 raw_pred
0.3971596474045054 0.42602118003025735 0.050816417179571506 all
1.224107441900368 beta
Average Accuracy across 5 tasks: 96.22%
0.9639243498817966
Average Accuracy across 1 tasks: 99.91%
0.8270812928501469
0
0.5186092066601371 raw_pred
0.0721513002364067 0.34583741429970627 0.009681441455361192 all
0.08789846877597776 beta
Average Accuracy across 2 tasks: 99.66%
0.

In [62]:
# Quick probe: train a fresh model on the first alike task only.
# NOTE(review): the positional values after `optimizer` are passed straight to train(),
# whose signature is not visible here — confirm their meaning.
model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(tasks)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train(
    model, split_alike_train_loaders[0], optimizer,
    1, device, 1, 0,
    binary_label=[0, 1],
)

In [63]:
# Evaluate head 0 (task_id=0) of the probe model on the second alike task's test loader.
# NOTE(review): binary_label=[2, 1] does not match tasks[1] == (2, 3) — confirm this is intentional.
test(
    model,
    split_alike_test_loaders[1],
    device,
    task_id=0,
    binary_label=[2, 1],
)

(0.021845142241206154, 0.43927522037218414)

In [170]:
# Sanity probe: element [1] (presumably the label) of the 3rd training sample in alike task 1.
# NOTE(review): the saved output (7) is not in tasks[1] == (2, 3). It was produced at In [170],
# out of order with this section's cells (In [57]-[64]), so re-run to confirm.
split_alike_train_loaders[1].dataset[2][1]

7

In [64]:
# Compare fixed beta (1e-2, 1, 1e2) against auto-beta on the alike split.
plot_trends(
    [trends_alike_1, trends_alike_2, trends_alike_3, trends_alike_4],
    lower=0.9,
)

## Different difficulties (alternating MNIST and CIFAR-10 tasks)

In [66]:
# Make CIFAR-10 look like MNIST: single-channel, 28x28, normalised with mean 0.5 / std 0.5
# (maps ToTensor's [0, 1] range to [-1, 1]).
transform_cifar = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Build the train split first, then the test split, both with the transform above.
cifar_train_dataset, cifar_test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform_cifar)
    for is_train in (True, False)
)


Files already downloaded and verified
Files already downloaded and verified


In [67]:
tasks = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
# Every task appears twice in a row, matching the MNIST/CIFAR-10 interleaving of the
# mixed loaders (position 2k and 2k + 1 share tasks[k]).
mixed_tasks = [pair for pair in tasks for _ in range(2)]

In [68]:
# Split MNIST and CIFAR-10 into the same binary tasks; seed before building the loaders.
torch.manual_seed(SEED)
mnist_train_loaders, mnist_test_loaders = create_split_dataloaders(
    mnist_trainset, mnist_testset, tasks, batch_size=batch_size
)
cifar_train_loaders, cifar_test_loaders = create_split_dataloaders(
    cifar_train_dataset, cifar_test_dataset, tasks, batch_size=batch_size
)

In [69]:
# Interleave the two benchmarks task by task: position 2k is MNIST task k and
# position 2k + 1 is CIFAR-10 task k.
mixed_train_loaders = [
    (mnist_train_loaders, cifar_train_loaders)[pos % 2][pos // 2]
    for pos in range(len(mixed_tasks))
]
mixed_test_loaders = [
    (mnist_test_loaders, cifar_test_loaders)[pos % 2][pos // 2]
    for pos in range(len(mixed_tasks))
]

In [70]:
# Mixed MNIST/CIFAR-10 tasks, fixed beta = 1e-2, no coreset, 5 seeded runs.
coreset_size = 0
torch.manual_seed(SEED)
mixed_trends_1 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        mixed_train_loaders,
        mixed_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1e-2,
        binary_labels=mixed_tasks,
    )
    mixed_trends_1 += [trend]

[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9981087470449173, 0.867]
Average Accuracy across 2 tasks: 93.26%
[0.9990543735224586, 0.8065, 0.9946131243878551]
Average Accuracy across 3 tasks: 93.34%
[0.9030732860520094, 0.763, 0.9760039177277179, 0.712]
Average Accuracy across 4 tasks: 83.85%
[0.9375886524822695, 0.761, 0.8413320274240941, 0.652, 0.9983991462113126]
Average Accuracy across 5 tasks: 83.81%
[0.9687943262411347, 0.688, 0.8545543584720862, 0.6565, 0.9946638207043756, 0.764]
Average Accuracy across 6 tasks: 82.11%
[0.7456264775413711, 0.5765, 0.9030362389813908, 0.6185, 0.884204909284952, 0.7135, 0.9979859013091642]
Average Accuracy across 7 tasks: 77.71%
[0.8803782505910166, 0.6765, 0.8677766895200784, 0.613, 0.8324439701173959, 0.692, 0.9773413897280967, 0.8595]
Average Accuracy across 8 tasks: 79.99%
[0.9739952718676123, 0.696, 0.7541625857002938, 0.617, 0.9583778014941302, 0.6445, 0.9521651560926485, 0.7915, 0.9959657085224407]
Average Accuracy across

In [75]:
# Mixed MNIST/CIFAR-10 tasks, fixed beta = 1, 5 seeded runs.
# `coreset_size` is not set here; it comes from an earlier cell.
torch.manual_seed(SEED)
mixed_trends_2 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        mixed_train_loaders,
        mixed_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1,
        binary_labels=mixed_tasks,
    )
    mixed_trends_2 += [trend]

[0.9990543735224586]
Average Accuracy across 1 tasks: 99.91%
[0.9985815602836879, 0.867]
Average Accuracy across 2 tasks: 93.28%
[0.9990543735224586, 0.8415, 0.9931439764936337]
Average Accuracy across 3 tasks: 94.46%
[0.9167848699763593, 0.8105, 0.9760039177277179, 0.6975]
Average Accuracy across 4 tasks: 85.02%
[0.9787234042553191, 0.7925, 0.9138099902056807, 0.6735, 0.9983991462113126]
Average Accuracy across 5 tasks: 87.14%
[0.9229314420803783, 0.7225, 0.8594515181194907, 0.6695, 0.996264674493063, 0.7765]
Average Accuracy across 6 tasks: 82.45%


KeyboardInterrupt: 

In [41]:
# Mixed MNIST/CIFAR-10 tasks, fixed beta = 1e2, 5 seeded runs.
# NOTE(review): the saved output is a NameError for `mixed_tasks` because this cell ran
# (In [41]) before the cell defining it (In [67]); a top-to-bottom run is unaffected.
torch.manual_seed(SEED)
mixed_trends_3 = []
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        mixed_train_loaders,
        mixed_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1e2,
        binary_labels=mixed_tasks,
    )
    mixed_trends_3 += [trend]

NameError: name 'mixed_tasks' is not defined

In [73]:
# Mixed MNIST/CIFAR-10 tasks with run_auto_vcl, 5 seeded runs. With return_betas=True
# the call returns (trend, betas); both are collected.
mixed_trends_4 = []
mixed_betas = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [256, 256], 2, num_tasks=len(mixed_tasks)).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, m_beta = run_auto_vcl(
        model,
        mixed_train_loaders,
        mixed_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        binary_labels=mixed_tasks,
        return_betas=True,
    )
    mixed_trends_4 += [trend]
    mixed_betas += [m_beta]

0.9734278959810874
Average Accuracy across 1 tasks: 99.91%
0.7145499999999999
0
0.5635 raw_pred
0.05314420803782527 0.5709000000000002 0.023430667673670316 all
0.01053657597065627 beta
Average Accuracy across 2 tasks: 93.28%
0.7502938295788442
1
0.4720861900097943 raw_pred
0.5709000000000002 0.4994123408423117 0.011638570781885236 all
2.1503290384782243 beta
Average Accuracy across 3 tasks: 95.12%
0.5905000000000001
2
0.502 raw_pred
0.5709000000000002 0.8189999999999997 0.006964089177762089 all
0.1085066503231236 beta
Average Accuracy across 4 tasks: 83.65%
0.7350586979722518
3
0.4823906083244397 raw_pred
0.8189999999999997 0.5298826040554965 0.00949159044096126 all
15.647175207027473 beta
Average Accuracy across 5 tasks: 87.70%
0.6059
4
0.414 raw_pred
0.8189999999999997 0.7882 0.03626371637464836 all
1.8546196867506108 beta
Average Accuracy across 6 tasks: 83.59%
0.9106747230614299
5
0.4405840886203424 raw_pred
0.8189999999999997 0.17865055387714013 0.021632643504852576 all
444.557926

In [None]:
# Disabled experiment: run_auto_vcl with dor=True on the mixed tasks (would fill mixed_trends_5).
# NOTE(review): mixed_trends_5 is not included in the plot below; delete this cell or re-enable it.
# mixed_trends_5 = []
# for i in range(5):
#     model = MFVI_NN(28*28, [256, 256], 2, num_tasks = len(mixed_tasks)).to(device)
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
#     trend = run_auto_vcl(model, 
#         mixed_train_loaders,
#         mixed_test_loaders,
#         optimizer, 
#         epoch_per_task, 
#         coreset_size,
#         binary_labels = mixed_tasks,
#         dor = True)
#     mixed_trends_5.append(trend)

In [74]:
# Compare fixed beta (1e-2, 1, 1e2) against auto-beta on the mixed MNIST/CIFAR-10 tasks.
# NOTE(review): in the saved outputs the mixed_trends_2 run was interrupted and the
# mixed_trends_3 cell failed, so those lists may be incomplete or empty until a full re-run.
plot_trends([
    mixed_trends_1,
    mixed_trends_2,
    mixed_trends_3,
    mixed_trends_4,
])

## P-MNIST (Permuted MNIST)

In [30]:
# P-MNIST, single_head=True, fixed beta = 1e-2, no coreset, 5 seeded runs.
torch.manual_seed(SEED)
coreset_size = 0
p_trends_1 = []
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        pmnist_train_loaders,
        pmnist_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=1e-2,
    )
    p_trends_1 += [trend]

Average Accuracy across 1 tasks: 97.18%
Average Accuracy across 2 tasks: 86.45%


KeyboardInterrupt: 

In [71]:
# Inspect the module structure of the most recently built `model`.
model

MFVI_NN(
  (layers): ModuleList(
    (0-1): 2 x MFVI_Layer()
  )
  (task_specific_layers): ModuleDict(
    (0): MFVI_Layer()
  )
)

In [22]:
# P-MNIST, single_head=True, fixed beta = 1: the middle setting between p_trends_1
# (beta=1e-2) and p_trends_3 (beta=100), matching the other *_trends_2 cells.
# Uses the same pmnist_* loaders and epoch_per_task budget as the other P-MNIST cells.
p_trends_2 = []
coreset_size = 0
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, pmnist_train_loaders, pmnist_test_loaders,
        optimizer, epoch_per_task, coreset_size, beta=1)
    p_trends_2.append(trend)

TypeError: MFVI_NN.__init__() got an unexpected keyword argument 'single_head'

In [None]:
# P-MNIST, single_head=True, fixed beta = 100, 5 seeded runs.
# `coreset_size` is not set here; it comes from an earlier cell.
torch.manual_seed(SEED)
p_trends_3 = []
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(
        model,
        pmnist_train_loaders,
        pmnist_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        beta=100,
    )
    p_trends_3 += [trend]

[0.9721]
Average Accuracy across 1 tasks: 97.21%
[0.9254, 0.9682]
Average Accuracy across 2 tasks: 94.68%
[0.8229, 0.8182, 0.9667]
Average Accuracy across 3 tasks: 86.93%


KeyboardInterrupt: 

In [50]:
# P-MNIST, single_head=True, run_auto_vcl, 5 seeded runs. With return_betas=True the
# call returns (trend, betas); both are collected.
p_trends_4 = []
p_betas = []
torch.manual_seed(SEED)
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks=10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend, p_beta = run_auto_vcl(
        model,
        pmnist_train_loaders,
        pmnist_test_loaders,
        optimizer,
        epoch_per_task,
        coreset_size,
        return_betas=True,
    )
    p_trends_4 += [trend]
    p_betas += [p_beta]

0.6274200000000001
Average Accuracy across 1 tasks: 97.39%
0.59131
0
0.1433 raw_pred
0.41397777777777767 0.45409999999999995 0.0002933062276259872 all
0.692921977329255 beta
Average Accuracy across 2 tasks: 95.61%
0.62581
1
0.0637 raw_pred
0.45409999999999995 0.41576666666666673 0.00025499795550306954 all
1.4267678845975205 beta
Average Accuracy across 3 tasks: 91.74%
0.61495
2
0.1025 raw_pred
0.45409999999999995 0.4278333333333333 0.00012972033049847497 all
1.2752212702334638 beta


KeyboardInterrupt: 

In [None]:
# Compare fixed beta (1e-2, 1, 100) against auto-beta on P-MNIST.
# NOTE(review): every P-MNIST cell above was interrupted or failed in the saved outputs.
plot_trends([
    p_trends_1,
    p_trends_2,
    p_trends_3,
    p_trends_4,
])

In [53]:
# Coreset variant (coreset_size = 1000): three fixed betas plus run_auto_vcl.
# NOTE(review): one run per setting (range(1)), whereas the other sections use 5-10 runs.
# NOTE(review): uses `train_loaders` / `test_loaders` rather than the pmnist_* loaders
# used by the p_trends cells — confirm these are the intended datasets.
coreset_size = 1000
# Seed once up front, as the other experiment cells do, so these runs are reproducible.
torch.manual_seed(SEED)

pc_trends_1 = []
for i in range(1):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e-2)
    pc_trends_1.append(trend)

pc_trends_2 = []
for i in range(1):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size, beta=1)
    pc_trends_2.append(trend)

pc_trends_3 = []
for i in range(1):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e2)
    pc_trends_3.append(trend)

pc_trends_4 = []
for i in range(1):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # raw_training_epoch / raw_train_size are set in another cell.
    trend = run_auto_vcl(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size,
            raw_training_epoch = raw_training_epoch,
            raw_train_size= raw_train_size)
    pc_trends_4.append(trend)

Average Accuracy across 1 tasks: 99.91%


KeyboardInterrupt: 

In [None]:
# Compare fixed beta (1e-2, 1, 1e2) against auto-beta in the coreset setting.
# NOTE(review): the coreset cell above was interrupted in the saved outputs.
plot_trends([
    pc_trends_1,
    pc_trends_2,
    pc_trends_3,
    pc_trends_4,
])