In [9]:
class MFVI_Layer(nn.Module):
    def __init__(self, in_features, out_features):
        super(MFVI_Layer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Mean and log variance parameters
        self.W_m = nn.Parameter(torch.Tensor(out_features, in_features))
        self.b_m = nn.Parameter(torch.Tensor(out_features))
        self.W_logv = nn.Parameter(torch.Tensor(out_features, in_features))
        self.b_logv = nn.Parameter(torch.Tensor(out_features))

        # Prior distributions initialized to None
        self.prior_W_m = self.prior_b_m = self.prior_W_logv = self.prior_b_logv = None

        self.reset_parameters()

    def reset_parameters(self):
        # Initialize mean parameters to a normal distribution and log variance parameters to a small value
        self.W_m.data.normal_(0, 0.1)
        self.b_m.data.normal_(0, 0.1)
        self.W_logv.data.fill_(-6.0)
        self.b_logv.data.fill_(-6.0)
        self.set_priors()

    def set_priors(self):
        # Set priors to the current parameters
        self.prior_W_m = self.W_m.detach().clone()
        self.prior_b_m = self.b_m.detach().clone()
        self.prior_W_logv = self.W_logv.detach().clone()
        self.prior_b_logv = self.b_logv.detach().clone()

    def forward(self, x, sample=True):
        W_std = torch.exp(0.5 * self.W_logv)
        b_std = torch.exp(0.5 * self.b_logv)
        
        act_mu = F.linear(x, self.W_m, self.b_m)
        act_var = 1e-16 + F.linear(x.pow(2), W_std.pow(2)) + b_std.pow(2)
        act_std = torch.sqrt(act_var)

        if self.training or sample:
            eps = torch.randn_like(act_mu)
            return act_mu + act_std * eps
        else:
            return act_mu

    def kl_divergence(self, device):
        self.update_prior_device(device)
        num_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        # Convert log variance to standard deviation for posterior and prior
        W_std_post = torch.exp(0.5 * self.W_logv)
        b_std_post = torch.exp(0.5 * self.b_logv)
        W_std_prior = torch.exp(0.5 * self.prior_W_logv)
        b_std_prior = torch.exp(0.5 * self.prior_b_logv)

        # Calculate KL divergence for weights
        kl_div_W = torch.log(W_std_prior / W_std_post) + \
                ((W_std_post**2 + (self.W_m - self.prior_W_m)**2) / (2 * W_std_prior**2)) - 0.5
        # Sum over all elements
        kl_div_W = kl_div_W.sum()

        # Calculate KL divergence for biases
        kl_div_b = torch.log(b_std_prior / b_std_post) + \
                ((b_std_post**2 + (self.b_m - self.prior_b_m)**2) / (2 * b_std_prior**2)) - 0.5
        # Sum over all elements
        kl_div_b = kl_div_b.sum()

        # Total KL divergence
        total_kl = kl_div_W + kl_div_b

        return total_kl/num_params

    def update_prior_device(self, device):
        # Ensure priors are moved to the correct device
        self.prior_W_m = self.prior_W_m.to(device)
        self.prior_b_m = self.prior_b_m.to(device)
        self.prior_W_logv = self.prior_W_logv.to(device)
        self.prior_b_logv = self.prior_b_logv.to(device)

class MFVI_NN(nn.Module):
    def __init__(self, input_size, hidden_sizes, output_size, num_tasks=1, single_head=False):
        super(MFVI_NN, self).__init__()
        self.input_size = input_size
        self.hidden_sizes = hidden_sizes
        self.output_size = output_size
        self.num_tasks = num_tasks
        self.single_head = single_head
        self.layers = nn.ModuleList()
        self.task_specific_layers = nn.ModuleDict()  # Using ModuleDict for task-specific layers

        # Construct shared layers
        sizes = [input_size] + hidden_sizes
        for i in range(len(sizes) - 1):
            self.layers.append(MFVI_Layer(sizes[i], sizes[i + 1]))

        # Construct task-specific output layers
        if single_head:
            self.task_specific_layers[str(0)] = MFVI_Layer(sizes[-1], output_size)
        else:
            for task_id in range(num_tasks):
                self.task_specific_layers[str(task_id)] = MFVI_Layer(sizes[-1], output_size)

    def forward(self, x, task_id=0, sample=True):
        for layer in self.layers:
            x = F.relu(layer(x, sample=sample))
        if self.single_head:
            x = self.task_specific_layers["0"](x, sample=sample)
        else:
            x = self.task_specific_layers[str(task_id)](x, sample=sample)
        return x

    def kl_divergence(self):
        kl_div = 0
        # Accumulate KL divergence from shared and task-specific layers
        for layer in self.layers:
            kl_div += layer.kl_divergence(next(self.parameters()).device)
        # for task_layer in self.task_specific_layers.values():
        #     kl_div += task_layer.kl_divergence(next(self.parameters()).device)
        return kl_div

    def update_priors(self):
        # Update priors in each layer
        for layer in self.layers + list(self.task_specific_layers.values()):
            layer.set_priors()

In [16]:
import torch.nn.functional as F

def train(model, trainloader, optimizer, epoch, device, kl_weight=1, task_id = 0, binary_label = None):
    model.train()
    for batch_idx, (data, target) in enumerate(trainloader):
        data, target = data.to(device), target.to(device)
        data = data.view(data.size(0), -1)  # Flatten the images
        optimizer.zero_grad()
        output = model(data, sample=True, task_id = task_id)
        if binary_label != None:
            target = (target == binary_label[0]).long()
        reconstruction_loss = F.cross_entropy(output, target, reduction='mean')
        
#         print(kl_divergence,reconstruction_loss)
        if task_id == 0:
            loss = reconstruction_loss
        else:
            kl_divergence = model.kl_divergence()
            loss = reconstruction_loss + kl_divergence * kl_weight
        loss.backward()
        optimizer.step()
#     print(f"Train Epoch: {epoch} [{batch_idx * len(data)}/{len(trainloader.dataset)} ({100. * batch_idx / len(trainloader):.0f}%)]\tLoss: {loss.item()}")

def test(model, testloader, device, task_id = 0, binary_label = None):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            data = data.view(data.size(0), -1)  # Flatten the images
            output = model(data, sample=False, task_id = task_id)
            if binary_label != None:
                target = (target == binary_label[0]).long()

            # Use cross_entropy for test loss calculation, sum up batch loss
            test_loss += F.cross_entropy(output, target, reduction='mean').item()
            pred = output.argmax(dim=1, keepdim=True)  # Get the index of the max probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(testloader.dataset)
    # print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(testloader.dataset)} ({100. * correct / len(testloader.dataset):.0f}%)\n')
    return test_loss, correct / len(testloader.dataset)


## Permuted MNIST

In [11]:
from torchvision import datasets, transforms

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST(root='~/.pytorch/MNIST_data/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='~/.pytorch/MNIST_data/', train=False, download=True, transform=transform)

In [12]:
import torch

def generate_permutations(task_count, image_size):
    permutations = [torch.randperm(image_size) for _ in range(task_count - 1)]
    return permutations

# Generate permutations for 9 tasks (+1 original MNIST)
task_count = 10
image_size = 28 * 28  # MNIST images are 28x28
permutations = generate_permutations(task_count, image_size)

In [13]:
from torch.utils.data import Dataset

class PermutedMNIST(Dataset):
    def __init__(self, mnist_dataset, permutation=None):
        self.mnist_dataset = mnist_dataset
        self.permutation = permutation

    def __len__(self):
        return len(self.mnist_dataset)

    def __getitem__(self, idx):
        image, label = self.mnist_dataset[idx]
        if self.permutation is not None:
            # Apply permutation
            image = image.view(-1)[self.permutation].view(1, 28, 28)
        return image, label


In [22]:
from torch.utils.data import DataLoader

batch_size = 256

# Create a DataLoader for the original MNIST
pmnist_train_loaders = [DataLoader(train_dataset, batch_size=batch_size, shuffle=True)]
pmnist_test_loaders = [DataLoader(test_dataset, batch_size=batch_size, shuffle=False)]
train_datasets = [train_dataset]
# Create DataLoaders for permuted tasks
for perm in permutations:
    permuted_train = PermutedMNIST(train_dataset, permutation=perm)
    permuted_test = PermutedMNIST(test_dataset, permutation=perm)

    train_loader = DataLoader(permuted_train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(permuted_test, batch_size=batch_size, shuffle=False)

    pmnist_train_loaders.append(train_loader)
    pmnist_test_loaders.append(test_loader)
    train_datasets.append(permuted_train)


In [17]:
from tqdm import tqdm
# Assuming model, optimizer, train_loaders, test_loaders are defined
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
    
model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epoch_per_task = 10
beta = 1
def run_vcl(model, optimizer, epoch_per_task, beta):
    device = next(model.parameters()).device
    task_accuracies = []
    ave_acc_trend = []
    prev_test_loaders= []
    for task_id, (train_loader, test_loader) in enumerate(zip(train_loaders, test_loaders), start=0):
        task_accuracies = []
        prev_test_loaders.append(test_loader)
        for epoch in tqdm(range(1, epoch_per_task + 1)):
            train(model, train_loader, optimizer, epoch, device, beta, task_id=task_id)
        task_num = 0  
        for ptl in prev_test_loaders: 
            test_loss, task_accuracy = test(model, ptl, device,task_id=task_num)
            task_accuracies.append(task_accuracy)
            task_num += 1
        
        model.update_priors()

        average_accuracy = sum(task_accuracies) / len(task_accuracies)
        ave_acc_trend.append(average_accuracy)
        print(f'Average Accuracy across {len(task_accuracies)} tasks: {average_accuracy*100:.2f}%')
run_vcl(model, optimizer, epoch_per_task, beta)

  0%|          | 0/10 [00:00<?, ?it/s]

100%|██████████| 10/10 [00:33<00:00,  3.36s/it]


Average Accuracy across 1 tasks: 97.78%


100%|██████████| 10/10 [00:46<00:00,  4.66s/it]


Average Accuracy across 2 tasks: 96.72%


100%|██████████| 10/10 [00:51<00:00,  5.18s/it]


Average Accuracy across 3 tasks: 96.49%


100%|██████████| 10/10 [00:56<00:00,  5.66s/it]


Average Accuracy across 4 tasks: 95.99%


 90%|█████████ | 9/10 [01:04<00:07,  7.21s/it]


KeyboardInterrupt: 

In [27]:
pmnist_test_loaders[0].dataset.targets

tensor([7, 2, 1,  ..., 4, 5, 6])

In [8]:
import matplotlib.pyplot as plt

plt.plot(range(len(ave_acc_trend)),ave_acc_trend)

NameError: name 'ave_acc_trend' is not defined

In [9]:
from torch.utils.data import DataLoader, ConcatDataset, Subset
def random_coreset(dataset, coreset_size):
    """
    Randomly selects a subset of data points to form a coreset.

    Args:
    - dataset (torch.utils.data.Dataset): The dataset to sample from.
    - coreset_size (int): The number of samples to include in the coreset.

    Returns:
    - coreset_indices (torch.Tensor): Indices of the selected samples.
    """
    # Ensure coreset size does not exceed dataset size
    coreset_size = min(coreset_size, len(dataset))
    
    # Randomly select indices without replacement
    coreset_indices = np.random.choice(len(dataset), size=coreset_size, replace=False)
    
    # Convert numpy array to torch tensor
    coreset_indices = torch.from_numpy(coreset_indices)
    
    coreset = Subset(dataset, coreset_indices)
    return coreset



In [20]:
from tqdm import tqdm
epoch_per_task = 10
model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
coreset_size = 200
beta = 1

def run_vcl_with_coreset(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size, beta):
    ave_acc_trend_rc = []
    prev_test_loaders= []
    coresets = []
    for task_id, (train_loader, test_loader) in enumerate(zip(train_loaders, test_loaders), start=0):
        task_accuracies_rc = []
        for i in tqdm(range(len(coresets))):
            for epoch in (range(1, epoch_per_task + 1)):
                coreset_loader = DataLoader(coresets[i], batch_size=batch_size, shuffle=True)
                train(model, coreset_loader, optimizer, epoch, device, beta, task_id=i)
            model.update_priors()
        for epoch in tqdm(range(1, epoch_per_task + 1)):
            train(model, train_loader, optimizer, epoch, device, beta, task_id=task_id)
        model.update_priors()
        coresets.append(random_coreset(train_datasets[task_id], coreset_size))

        # for prediction
        prediction_model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
        prediction_model.load_state_dict(model.state_dict())
        # replay
        for i in tqdm(range(len(coresets))):
            for epoch in (range(1, epoch_per_task + 1)):
                coreset_loader = DataLoader(coresets[i], batch_size=batch_size, shuffle=True)
                train(prediction_model, coreset_loader, optimizer, epoch, device, beta, task_id=i)
        task_num = 0  
        prev_test_loaders.append(test_loader)
        for ptl in prev_test_loaders: 
            test_loss, task_accuracy = test(prediction_model, ptl, device,task_id=task_num)
            task_accuracies_rc.append(task_accuracy)
            task_num += 1

        average_accuracy = sum(task_accuracies_rc) / len(task_accuracies_rc)
        ave_acc_trend_rc.append(average_accuracy)
        print(f'Average Accuracy across {len(task_accuracies_rc)} tasks: {average_accuracy*100:.2f}%')
    return ave_acc_trend_rc

ave_acc_trend_1 = run_vcl_with_coreset(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size, beta=1e-1)

0it [00:00, ?it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:04<00:42,  4.75s/it]


KeyboardInterrupt: 

In [24]:
def run_vcl_with_coreset(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size, beta):
    ave_acc_trend_rc = []
    prev_test_loaders= []
    coresets = []
    for task_id, (train_loader, test_loader) in enumerate(zip(train_loaders, test_loaders), start=0):
        task_accuracies_rc = []
        # for i in tqdm(range(len(coresets))):
        #     for epoch in (range(1, epoch_per_task + 1)):
        #         coreset_loader = DataLoader(coresets[i], batch_size=batch_size, shuffle=True)
        #         train(model, coreset_loader, optimizer, epoch, device, beta, task_id=i)
        #     model.update_priors()
        for epoch in tqdm(range(1, epoch_per_task + 1)):
            train(model, train_loader, optimizer, epoch, device, beta, task_id=task_id)
        model.update_priors()
        # coresets.append(random_coreset(train_datasets[task_id], coreset_size))

        # for prediction
        prediction_model =type(model)(model.input_size, model.hidden_sizes, model.output_size, model.num_tasks, model.single_head).to(device)
        prediction_model.load_state_dict(model.state_dict())
        # replay
        # for i in tqdm(range(len(coresets))):
        #     for epoch in (range(1, epoch_per_task + 1)):
        #         coreset_loader = DataLoader(coresets[i], batch_size=batch_size, shuffle=True)
        #         train(prediction_model, coreset_loader, optimizer, epoch, device, beta, task_id=i)
        task_num = 0  
        prev_test_loaders.append(test_loader)
        for ptl in prev_test_loaders: 
            test_loss, task_accuracy = test(prediction_model, ptl, device,task_id=task_num)
            task_accuracies_rc.append(task_accuracy)
            task_num += 1

        average_accuracy = sum(task_accuracies_rc) / len(task_accuracies_rc)
        ave_acc_trend_rc.append(average_accuracy)
        print(f'Average Accuracy across {len(task_accuracies_rc)} tasks: {average_accuracy*100:.2f}%')
    return ave_acc_trend_rc

In [28]:
p_trends_2 = []
coreset_size = 0
for i in range(5):
    model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10, single_head=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    trend = run_vcl_with_coreset(model, pmnist_train_loaders, pmnist_train_loaders, 
        optimizer,10, coreset_size, beta=1)
    p_trends_2.append(trend)

100%|██████████| 10/10 [00:49<00:00,  4.90s/it]


Average Accuracy across 1 tasks: 99.03%


 70%|███████   | 7/10 [01:10<00:30, 10.06s/it]


KeyboardInterrupt: 

In [11]:
epoch_per_task = 10
model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
coreset_size = 200
ave_acc_trend_2 = run_vcl_with_coreset(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size, beta=1)

0it [00:00, ?it/s]
100%|██████████| 10/10 [02:07<00:00, 12.79s/it]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]


Average Accuracy across 1 tasks: 97.82%


100%|██████████| 1/1 [00:00<00:00,  2.20it/s]
100%|██████████| 10/10 [03:00<00:00, 18.10s/it]
100%|██████████| 2/2 [00:01<00:00,  1.79it/s]


Average Accuracy across 2 tasks: 95.91%


100%|██████████| 2/2 [00:01<00:00,  1.81it/s]
100%|██████████| 10/10 [03:01<00:00, 18.15s/it]
100%|██████████| 3/3 [00:01<00:00,  1.74it/s]


Average Accuracy across 3 tasks: 95.53%


100%|██████████| 3/3 [00:01<00:00,  1.67it/s]
100%|██████████| 10/10 [03:01<00:00, 18.14s/it]
100%|██████████| 4/4 [00:02<00:00,  1.67it/s]


Average Accuracy across 4 tasks: 94.73%


100%|██████████| 4/4 [00:02<00:00,  1.70it/s]
100%|██████████| 10/10 [03:01<00:00, 18.11s/it]
100%|██████████| 5/5 [00:03<00:00,  1.66it/s]


Average Accuracy across 5 tasks: 93.29%


100%|██████████| 5/5 [00:03<00:00,  1.61it/s]
100%|██████████| 10/10 [03:00<00:00, 18.06s/it]
100%|██████████| 6/6 [00:03<00:00,  1.62it/s]


Average Accuracy across 6 tasks: 93.78%


100%|██████████| 6/6 [00:03<00:00,  1.64it/s]
100%|██████████| 10/10 [03:01<00:00, 18.12s/it]
100%|██████████| 7/7 [00:04<00:00,  1.62it/s]


Average Accuracy across 7 tasks: 93.35%


100%|██████████| 7/7 [00:04<00:00,  1.59it/s]
100%|██████████| 10/10 [03:01<00:00, 18.15s/it]
100%|██████████| 8/8 [00:05<00:00,  1.60it/s]


Average Accuracy across 8 tasks: 92.62%


100%|██████████| 8/8 [00:05<00:00,  1.58it/s]
100%|██████████| 10/10 [03:01<00:00, 18.16s/it]
100%|██████████| 9/9 [00:05<00:00,  1.62it/s]


Average Accuracy across 9 tasks: 92.50%


100%|██████████| 9/9 [00:05<00:00,  1.57it/s]
100%|██████████| 10/10 [02:59<00:00, 17.95s/it]
100%|██████████| 10/10 [00:06<00:00,  1.61it/s]


Average Accuracy across 10 tasks: 91.71%


In [12]:
model = MFVI_NN(28*28, [100, 100], 10, num_tasks = 10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
ave_acc_trend_3 = run_vcl_with_coreset(model, train_loaders, test_loaders, optimizer, epoch_per_task, coreset_size, beta=10)

0it [00:00, ?it/s]
100%|██████████| 10/10 [02:05<00:00, 12.59s/it]
100%|██████████| 1/1 [00:00<00:00,  2.25it/s]


Average Accuracy across 1 tasks: 97.85%


100%|██████████| 1/1 [00:00<00:00,  2.30it/s]
100%|██████████| 10/10 [02:59<00:00, 17.95s/it]
100%|██████████| 2/2 [00:01<00:00,  1.78it/s]


Average Accuracy across 2 tasks: 93.47%


100%|██████████| 2/2 [00:01<00:00,  1.78it/s]
100%|██████████| 10/10 [03:03<00:00, 18.34s/it]
100%|██████████| 3/3 [00:01<00:00,  1.66it/s]


Average Accuracy across 3 tasks: 91.85%


100%|██████████| 3/3 [00:01<00:00,  1.66it/s]
100%|██████████| 10/10 [03:07<00:00, 18.71s/it]
100%|██████████| 4/4 [00:02<00:00,  1.66it/s]


Average Accuracy across 4 tasks: 90.50%


100%|██████████| 4/4 [00:02<00:00,  1.61it/s]
100%|██████████| 10/10 [03:04<00:00, 18.48s/it]
100%|██████████| 5/5 [00:03<00:00,  1.66it/s]


Average Accuracy across 5 tasks: 89.79%


100%|██████████| 5/5 [00:03<00:00,  1.54it/s]
100%|██████████| 10/10 [03:06<00:00, 18.64s/it]
100%|██████████| 6/6 [00:03<00:00,  1.61it/s]


Average Accuracy across 6 tasks: 89.01%


100%|██████████| 6/6 [00:03<00:00,  1.55it/s]
100%|██████████| 10/10 [03:04<00:00, 18.40s/it]
100%|██████████| 7/7 [00:04<00:00,  1.62it/s]


Average Accuracy across 7 tasks: 88.30%


100%|██████████| 7/7 [00:04<00:00,  1.55it/s]
100%|██████████| 10/10 [03:05<00:00, 18.54s/it]
100%|██████████| 8/8 [00:05<00:00,  1.58it/s]


Average Accuracy across 8 tasks: 88.08%


100%|██████████| 8/8 [00:05<00:00,  1.57it/s]
100%|██████████| 10/10 [03:03<00:00, 18.36s/it]
100%|██████████| 9/9 [00:05<00:00,  1.60it/s]


Average Accuracy across 9 tasks: 87.32%


100%|██████████| 9/9 [00:05<00:00,  1.55it/s]
100%|██████████| 10/10 [03:04<00:00, 18.42s/it]
100%|██████████| 10/10 [00:06<00:00,  1.57it/s]


Average Accuracy across 10 tasks: 87.22%


In [13]:
plt.plot(range(len(ave_acc_trend_1)),ave_acc_trend_rc)
plt.plot(range(len(ave_acc_trend_2)),ave_acc_trend_rc)
plt.plot(range(len(ave_acc_trend_3)),ave_acc_trend_rc)

NameError: name 'ave_acc_trend_rc' is not defined

In [50]:
import np 
def greedy_k_center(X, k):
    """
    Selects k points from X using the greedy k-center algorithm.
    
    Args:
    - X (np.array): The dataset, shape (n_samples, n_features).
    - k (int): Number of centers to select.
    
    Returns:
    - centers (np.array): The selected centers, shape (k, n_features).
    - indices (list): Indices of the selected centers.
    """
    # Randomly choose the first center
    n_samples = X.shape[0]
    first_center_idx = np.random.choice(n_samples)
    centers = [X[first_center_idx]]
    indices = [first_center_idx]
    
    # Initialize the minimum distance to the closest center for each point
    min_distances = np.full(n_samples, np.inf)
    
    # Iteratively select k-1 remaining centers
    for _ in range(k - 1):
        # Update the minimum distances for all points
        distances = np.linalg.norm(X - centers[-1], axis=1)
        min_distances = np.minimum(min_distances, distances)
        
        # Select the next center to be the point with the maximum distance to its closest center
        next_center_idx = np.argmax(min_distances)
        centers.append(X[next_center_idx])
        indices.append(next_center_idx)
    
    return np.array(centers), indices


ModuleNotFoundError: No module named 'np'

In [51]:
def create_split_task(dataset, classes):
    """
    Create a binary classification task from the MNIST dataset.
    
    Parameters:
    - dataset: The original MNIST dataset (training or test).
    - classes: A tuple of two integers representing the classes to include in the split.
    
    Returns:
    - A Subset of the original dataset containing only the specified classes.
    """
    # Find indices of classes we're interested in
    indices = [i for i, (_, target) in enumerate(dataset) if target in classes]
    
    # Create a subset of the dataset with only the specified classes
    subset = Subset(dataset, indices)
    
    return subset

def create_split_dataloaders(train_dataset, test_dataset, tasks, batch_size=256):
    """
    Create DataLoaders for each binary task in Split MNIST.
    
    Parameters:
    - train_dataset: The MNIST training dataset.
    - test_dataset: The MNIST test dataset.
    - batch_size: The batch size for the DataLoader.
    
    Returns:
    - A list of tuples containing (train_loader, test_loader) for each binary task.
    """
    dataloaders = []

    for task in tasks:
        # Create training subset and DataLoader
        train_subset = create_split_task(train_dataset, task)
        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

        # Create test subset and DataLoader
        test_subset = create_split_task(test_dataset, task)
        test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False)

        dataloaders.append((train_loader, test_loader))
    
    return dataloaders

In [63]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5, 0.5, 0.5))]
)

In [64]:
cifar10_trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
cifar10_testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

Files already downloaded and verified

Files already downloaded and verified


In [65]:
tasks = [(6, 1), (9, 7), (8, 4), (5, 0), (2, 3)]
dataloaders = create_split_dataloaders(mnist_trainset, mnist_testset, tasks, batch_size=256)

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

In [None]:
ave_acc_trend_split_mnist_1 = run_vcl_with_coreset(model, *dataloaders optimizer, epoch_per_task, coreset_size, beta=10)