In [1]:
!pip install laplace-torch

Collecting laplace-torch
  Downloading laplace_torch-0.2.2.2-py3-none-any.whl.metadata (5.1 kB)
Collecting asdfghjkl==0.1a4 (from laplace-torch)
  Downloading asdfghjkl-0.1a4-py3-none-any.whl.metadata (3.2 kB)
Collecting backpack-for-pytorch (from laplace-torch)
  Downloading backpack_for_pytorch-1.7.1-py3-none-any.whl.metadata (4.4 kB)
Collecting curvlinops-for-pytorch>=2.0 (from laplace-torch)
  Downloading curvlinops_for_pytorch-2.0.1-py3-none-any.whl.metadata (4.9 kB)
Collecting torchmetrics (from laplace-torch)
  Downloading torchmetrics-1.6.1-py3-none-any.whl.metadata (21 kB)
Collecting einconv (from curvlinops-for-pytorch>=2.0->laplace-torch)
  Downloading einconv-0.1.0-py3-none-any.whl.metadata (1.9 kB)
Collecting unfoldNd<1.0.0,>=0.2.0 (from backpack-for-pytorch->laplace-torch)
  Downloading unfoldNd-0.2.3-py3-none-any.whl.metadata (1.5 kB)
Collecting lightning-utilities>=0.8.0 (from torchmetrics->laplace-torch)
  Downloading lightning_utilities-0.11.9-py3-none-any.whl.metadat

In [2]:
import random
import numpy as np
import os
import sys
import torch
from torchvision.transforms import transforms
from torchvision import datasets
from collections import Counter
from torch import nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, Subset, ConcatDataset, random_split
import numpy as np
from copy import deepcopy
import numpy as np
from laplace import Laplace
from laplace.curvature import AsdlGGN
import torch
import torch.nn.functional as F
import pickle
import time

In [3]:
# Check if running on Colab
try:
  import google.colab  # imported only to detect the Colab runtime
  IN_COLAB = True
except ImportError:
  # Only a missing package means "not on Colab"; any other error should surface
  # instead of being silently swallowed by a bare except.
  IN_COLAB = False

if IN_COLAB:
  from google.colab import drive
  # Connect to Google drive where the training data is located
  drive.mount("/content/gdrive")
  work_dir = "/content/gdrive/My Drive/Colab Notebooks/DL-Project-2024-Experiments/SUBMISSION"
  # All relative paths used later ('./data', the results folder) resolve against this directory.
  os.chdir(work_dir)
  print(f"Connected to Google drive, setting working directory to '{work_dir}'")

Mounted at /content/gdrive
Connected to Google drive, setting working directory to '/content/gdrive/My Drive/Colab Notebooks/DL-Project-2024-Experiments/SUBMISSION'


In [4]:
# Select the compute device: CUDA when present, CPU otherwise
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Number of available GPUs: {torch.cuda.device_count()}")
print(f"Device: {device}")

# Report the host environment for reproducibility
print(f"Number of CPUs: {os.cpu_count()}")
print(f"System version: {sys.version_info}")

Number of available GPUs: 1
Device: cuda
Number of CPUs: 12
System version: sys.version_info(major=3, minor=10, micro=12, releaselevel='final', serial=0)


In [5]:
# Create results folder if it does not exist yet
# exist_ok=True makes this a single idempotent call, with no gap between
# checking for the folder and creating it.
results_folder_name = 'raw_result_data'
os.makedirs(results_folder_name, exist_ok=True)

In [6]:
# Number of training epochs; the experiment cell uses it for both task 1 and task 2
EPOCHS = 50
# Mini-batch size for every DataLoader built in task1 and passed on to task2
BATCH_SIZE = 100
# SGD step size and momentum
LEARNING_RATE = 0.001
OPTIMIZER_MOMENTUM = 0.9
# Width of the one-hot encoding used when computing prediction sensitivities
N_CLASSES = 10
# True: task1 computes sensitivities once after the final epoch;
# False: task1 computes them after every epoch and averages them
MEASURE_SENSITIVITY_ON_LAST_EPOCH_ONLY = True

In [7]:
# Load training data
# Images are converted to tensors and each of the three channels is normalized
# with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

print(f"Loaded CIFAR10 data: training={len(trainset)} items, testing={len(testset)} items.")

Files already downloaded and verified
Files already downloaded and verified
Loaded CIFAR10 data: training=50000 items, testing=10000 items.


In [8]:
def split_dataset_to_labels(dataset, subset_labels):
  """Return the `Subset` of `dataset` whose target label is in `subset_labels`.

  Labels are read from `dataset.targets`, so no item is loaded or transformed.
  The kept indices are in ascending order.
  """
  kept_indices = []
  for index, label in enumerate(dataset.targets):
    if label in subset_labels:
      kept_indices.append(index)
  return Subset(dataset, kept_indices)

# Label groups that define the two tasks
subset_A_labels = [0, 1, 2, 3, 4]
subset_B_labels = [5, 6, 7, 8, 9]

# split the CIFAR10 data into two subsets according to their labels
trainset_A, trainset_B = (
    split_dataset_to_labels(trainset, labels)
    for labels in (subset_A_labels, subset_B_labels)
)
testset_A, testset_B = (
    split_dataset_to_labels(testset, labels)
    for labels in (subset_A_labels, subset_B_labels)
)

In [9]:
def check_class_balance(dataset):
  """Print, for each label in `dataset`, its count and percentage of all items.

  `dataset` must yield `(item, label)` pairs and support `len()`. Every item is
  iterated (and therefore loaded) once. Nothing is returned.
  """
  counts = Counter(label for _, label in dataset)
  size = len(dataset)
  for label in sorted(counts):
    share = 100 * (counts[label] / size)
    print(f"Label {label}: {counts[label]}/{size}={share:.0f}%, ")

In [10]:
# check if labels are balanced
# check_class_balance prints its report itself and returns None, so it is called
# directly; wrapping it in print() only appended a stray "None" line.
check_class_balance(trainset)

Label 0: 5000/50000=10%, 
Label 1: 5000/50000=10%, 
Label 2: 5000/50000=10%, 
Label 3: 5000/50000=10%, 
Label 4: 5000/50000=10%, 
Label 5: 5000/50000=10%, 
Label 6: 5000/50000=10%, 
Label 7: 5000/50000=10%, 
Label 8: 5000/50000=10%, 
Label 9: 5000/50000=10%, 
None


In [11]:
class CNN(nn.Module):
    """Small convolutional classifier: two conv+pool stages, then three linear layers.

    Both convolutions use 5x5 kernels without padding and are each followed by
    a 2x2 max-pool. `fc1` expects 16 * 5 * 5 flattened features, which is what
    these stages produce for 3x32x32 inputs. The network outputs 10 raw logits
    per sample (no softmax).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor (the single pooling layer is reused after each convolution)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Map a batch of images to a `(batch, 10)` tensor of logits."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # flatten all dimensions except batch
        x = x.flatten(start_dim=1)
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

In [12]:
def softmax_hessian(probs, eps=1e-10):
  """copied from https://github.com/team-approx-bayes/memory-perturbation

  Returns `probs - probs**2` element-wise, clamped to `[eps, 1 - eps]`.
  """
  p_times_one_minus_p = probs - probs * probs
  return p_times_one_minus_p.clamp(min=eps, max=1 - eps)

def get_pred_vars_laplace(net, trainloader, nc, version='kfac', device='cuda'):
  """Adapted from https://github.com/team-approx-bayes/memory-perturbation

  Fits a Laplace approximation over all weights of `net` on `trainloader` and
  returns, for every sample, the diagonal of the GLM predictive covariance.

  Args:
    net: classification network handed to `Laplace`.
    trainloader: loader yielding `(inputs, targets)`; it is iterated twice
      (once by `fit`, once to collect the variances).
    nc: number of classes. Kept for interface compatibility; the output width
      is taken from the covariance returned by the Laplace object.
    version: 'kfac' (Kronecker-factored Hessian) or 'diag' (diagonal Hessian).
    device: device name or `torch.device` the inputs are moved to.

  Returns:
    A list with one row per sample, in loader order (an empty list when the
    loader yields nothing).

  Raises:
    ValueError: if `version` is neither 'kfac' nor 'diag'.
  """
  # Callers pass either a string or a torch.device. A torch.device does not
  # compare equal to a plain string, so branch on the normalized device type.
  device_type = torch.device(device).type

  def _empty_cache():
      # Release cached accelerator memory around the memory-hungry steps.
      if device_type == 'cuda':
          torch.cuda.empty_cache()
      elif device_type == 'mps':
          torch.mps.empty_cache()

  _empty_cache()

  if device_type == 'mps':
      # use simplified approximation on local since kfac times out
      version = 'diag'

  hessian_structures = {'kfac': 'kron', 'diag': 'diag'}
  if version not in hessian_structures:
      raise ValueError(
          f"Unknown version {version!r}; expected one of {sorted(hessian_structures)}")
  hessian_structure = hessian_structures[version]

  laplace_object = Laplace(
      net, 'classification',
      subset_of_weights='all',
      hessian_structure=hessian_structure,
      backend=AsdlGGN,
      )

  _empty_cache()

  laplace_object.fit(trainloader)

  # Collect per-batch results and concatenate once, instead of re-copying the
  # growing array on every batch.
  fvar_batches = []
  for inputs, _ in trainloader:
      inputs = inputs.to(device)
      # NOTE(review): private laplace-torch API; presumably returns (mean, covariance)
      # with covariance shaped (batch, classes, classes) -- confirm on library upgrades.
      _, fvar = laplace_object._glm_predictive_distribution(inputs)
      fvar_batches.append(np.diagonal(fvar.detach().cpu().numpy(), axis1=1, axis2=2))

  del laplace_object
  _empty_cache()

  if not fvar_batches:
      return []
  return np.concatenate(fvar_batches, axis=0).tolist()

def prediction_sensitivity(model, var_dataloader, n_classes, device):
  """Compute one sensitivity score per sample of `var_dataloader`.

  For each sample the score is `sum_c |residual_c * lambda_c * variance_c|`, where
  `residual = softmax(logits) - one_hot(label)`, `lambda = softmax_hessian(probs)`
  and `variance` comes from `get_pred_vars_laplace`.

  The loader is traversed twice (once here, once inside `get_pred_vars_laplace`)
  and the results are multiplied element-wise, so it must yield samples in the
  same order on every pass (i.e. no shuffling).

  Side effect: the model is left in training mode on return.

  Raises:
    ValueError: if `var_dataloader` yields no batches.
  """
  residual_batches = []
  lambda_batches = []

  # eval mode and no_grad are loop-invariant, so set them once up front
  model.eval()
  with torch.no_grad():
      for X, y in var_dataloader:
          X, y = X.to(device), y.to(device)
          probs = F.softmax(model(X), dim=-1)
          residual_batches.append((probs - F.one_hot(y, n_classes)).cpu().numpy())
          lambda_batches.append(softmax_hessian(probs).cpu().numpy())

  if not residual_batches:
      raise ValueError("var_dataloader yielded no batches; cannot compute sensitivities")

  lambdas, residuals = np.vstack(lambda_batches), np.vstack(residual_batches)

  print('tracking variances...')
  # Kept from the original flow: the model is switched to train mode before the Laplace fit.
  model.train()
  # Named pred_vars so the builtin vars() is not shadowed
  pred_vars = get_pred_vars_laplace(model, var_dataloader, n_classes, device=device)
  print('done')

  sensitivities = residuals * lambdas * np.asarray(pred_vars)
  sensitivities = np.sum(np.abs(sensitivities), axis=-1)

  return sensitivities

def test_loop(test_loader, model, loss_fn, device):
  model.eval()
  running_loss = 0.0
  n_correct=0.0
  for X,y in test_loader:
    X, y = X.to(device), y.to(device)
    pred = model(X)
    loss = loss_fn(pred, y)
    running_loss+=loss.item()
    n_correct+=(pred.argmax(1) == y).type(torch.float).sum().item()
  accuracy = n_correct/len(test_loader.dataset)
  return running_loss, accuracy

def task1(trainset, test_set_1, test_set_2, loss_fn, n_classes, n_epochs, device):
  """Train a fresh `CNN` on `trainset` and collect per-sample statistics.

  After every epoch the model is evaluated on both test sets. During training,
  each sample's prediction (made before that batch's optimizer step) is recorded
  as correct or incorrect.

  Returns:
    sensitivities: per-sample sensitivity, averaged over the epochs at which it
      was measured (only the last one when MEASURE_SENSITIVITY_ON_LAST_EPOCH_ONLY).
    learning_speeds: per-sample fraction of epochs in which the training-time
      prediction was correct.
    test_accuracies_1, test_accuracies_2: per-epoch accuracy lists.
    model: the trained model.
  """
  # shuffle=False keeps the sample order identical in every epoch, so the
  # per-sample correctness arrays and the sensitivities line up by index.
  dataloader=DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=False)
  dataloader_var=DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=False)

  test_loader_1 = DataLoader(test_set_1, batch_size=BATCH_SIZE, shuffle=True)
  test_loader_2 = DataLoader(test_set_2, batch_size=BATCH_SIZE, shuffle=True)
  model = CNN()
  model = model.to(device)
  optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=OPTIMIZER_MOMENTUM)
  cp_epochs_list=[]
  test_accuracies_1 = []
  test_accuracies_2 = []
  sensitivities_epochs_list = []
  for epoch in range(n_epochs):
    # test_loop (and prediction_sensitivity) change the model's mode; re-enter
    # training mode at the start of every epoch, not just before the first one.
    model.train()
    cp_batches_list=[]
    #training loop
    for X, y in dataloader:
      X, y = X.to(device), y.to(device)
      optimizer.zero_grad()
      pred = model(X)
      loss = loss_fn(pred, y)
      loss.backward()
      optimizer.step()
      with torch.no_grad():
        # same correctness criterion as train_loop / test_loop
        cp_batch=(pred.argmax(1)==y).detach().cpu().numpy()
        cp_batches_list.append(cp_batch)

    with torch.no_grad():
      _, test_1_accuracy = test_loop(test_loader_1, model, loss_fn, device)
      test_accuracies_1.append(test_1_accuracy)
      _, test_2_accuracy = test_loop(test_loader_2, model, loss_fn, device)
      test_accuracies_2.append(test_2_accuracy)
    print(f"Task 1/2: Epoch [{epoch+1}/{n_epochs}]: test_1_accuracy={test_1_accuracy:.2%},  test_2_accuracy={test_2_accuracy:.2%}")

    cp_epoch=np.hstack(cp_batches_list)
    cp_epochs_list.append(cp_epoch)
    if not MEASURE_SENSITIVITY_ON_LAST_EPOCH_ONLY:
      # calculate sensitivity
      sensitivities_epoch = prediction_sensitivity(model, dataloader_var, n_classes, device)
      sensitivities_epochs_list.append(sensitivities_epoch)

  if MEASURE_SENSITIVITY_ON_LAST_EPOCH_ONLY:
    # calculate sensitivity
    sensitivities_epoch = prediction_sensitivity(model, dataloader_var, n_classes, device)
    sensitivities_epochs_list.append(sensitivities_epoch)

  # (n_samples, n_epochs) matrix of booleans -> mean over epochs per sample
  correct_predictions=np.stack(cp_epochs_list, axis=-1)
  learning_speeds=np.mean(correct_predictions.astype(int), axis=-1)

  sensitivities_experiment=np.stack(sensitivities_epochs_list, axis=0)
  sensitivities=np.mean(sensitivities_experiment, axis=0)
  return sensitivities, learning_speeds, test_accuracies_1, test_accuracies_2, model


In [13]:
def train_loop(train_loader, model, loss_fn, optimizer, device):
  """Run one training epoch over `train_loader`.

  Returns:
    (total_loss, accuracy): the sum of the per-batch loss values and the fraction
    of samples in `train_loader.dataset` classified correctly. Each prediction is
    the one made before that batch's optimizer step.
  """
  model.train()
  total_loss, hits = 0.0, 0.0
  for inputs, targets in train_loader:
    inputs = inputs.to(device)
    targets = targets.to(device)

    optimizer.zero_grad()
    logits = model(inputs)
    batch_loss = loss_fn(logits, targets)
    batch_loss.backward()
    optimizer.step()

    total_loss += batch_loss.item()
    hits += (logits.argmax(1) == targets).type(torch.float).sum().item()
  return total_loss, hits / len(train_loader.dataset)

def task2(model, train_set_2, buffer_set, test_set_1, test_set_2, n_epochs, learning_rate, optimizer_momentum, batch_size, verbose=True, device=None):
  """Continue training `model` on `train_set_2` concatenated with `buffer_set`.

  After every epoch the model is evaluated on both test sets.

  Args:
    model: model to train; it is moved to `device` and updated in place.
    train_set_2: dataset of the new task.
    buffer_set: replay samples mixed into the training data.
    test_set_1, test_set_2: evaluation datasets.
    n_epochs, learning_rate, optimizer_momentum, batch_size: SGD settings.
    verbose: print the accuracies after every epoch.
    device: device to train on. When omitted, the notebook-level `device`
      global is used, which is what this function always did before the
      parameter existed.

  Returns:
    (test_accuracies_1, test_accuracies_2): numpy arrays with one accuracy per epoch.
  """
  if device is None:
    # Backward-compatible fallback to the module-level `device` variable.
    device = globals()["device"]

  test_accuracies_1 = []
  test_accuracies_2 = []

  train_loader_2_with_buffer = DataLoader(ConcatDataset([train_set_2, buffer_set]), batch_size=batch_size, shuffle=True)

  test_loader_1 = DataLoader(test_set_1, batch_size=batch_size, shuffle=True)
  test_loader_2 = DataLoader(test_set_2, batch_size=batch_size, shuffle=True)
  model = model.to(device)

  optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=optimizer_momentum)
  loss_fn = nn.CrossEntropyLoss()

  for epoch in range(n_epochs):

    train_loop(train_loader_2_with_buffer, model, loss_fn, optimizer, device)

    # Evaluation needs no gradients; disabling them avoids building autograd graphs.
    with torch.no_grad():
      _, test_1_accuracy = test_loop(test_loader_1, model, loss_fn, device)
      test_accuracies_1.append(test_1_accuracy)
      _, test_2_accuracy = test_loop(test_loader_2, model, loss_fn, device)
      test_accuracies_2.append(test_2_accuracy)

    if verbose:
      print(f"Task 2/2: Epoch [{epoch+1}/{n_epochs}]: test_1_accuracy={test_1_accuracy:.2%},  test_2_accuracy={test_2_accuracy:.2%}")

  # asarray yields the same 1-D arrays as stacking scalars, and also works for n_epochs == 0
  test_accuracies_1=np.asarray(test_accuracies_1)
  test_accuracies_2=np.asarray(test_accuracies_2)

  return test_accuracies_1, test_accuracies_2

def get_random_buffer(dataset, n_samples):
  """Return a `Subset` of `n_samples` distinct items drawn uniformly from `dataset`.

  A fresh, unseeded generator is used on every call, so the selection differs
  between calls and between runs.
  """
  picked = np.random.default_rng().choice(len(dataset), n_samples, replace=False)
  return Subset(dataset, picked)

def get_goldilocks_buffer(dataset, n_samples, learning_speeds, remove_lowest_pct=0.0, remove_highest_pct=0.0):
  """Sample `n_samples` distinct items from the middle band of learning speeds.

  Samples are ranked by ascending learning speed. The lowest `remove_lowest_pct`
  and the highest `remove_highest_pct` share of the ranking are discarded, and
  the buffer is drawn uniformly (unseeded) from what remains. `n_samples` must
  not exceed the size of that remaining band.
  """
  n_ranked = len(learning_speeds)
  band_start = int(n_ranked * remove_lowest_pct)
  band_stop = int(n_ranked * (1.0 - remove_highest_pct))
  candidates = np.argsort(learning_speeds)[band_start:band_stop]
  picked = np.random.default_rng().choice(candidates, n_samples, replace=False)
  return Subset(dataset, picked)

def get_sensitivity_buffer(dataset, n_samples, sensitivities, highest):
  """Return a `Subset` with the `n_samples` most extreme items by sensitivity.

  With `highest` truthy the samples with the largest sensitivities are taken,
  otherwise those with the smallest. The selection is deterministic.
  """
  ascending = np.argsort(sensitivities)
  ranking = ascending[::-1] if highest else ascending
  return Subset(dataset, ranking[0:n_samples])

def get_mixed_buffer(dataset, size_pct, goldilocks_rate=0.0, mpe_rate=0.0, learning_speeds=None, sensitivities=None, goldilocks_remove_lowest_pct=0.0, goldilocks_remove_highest_pct=0.0, mpe_highest=False):
  """Build a replay buffer that mixes goldilocks, sensitivity (MPE) and random samples.

  The buffer targets `int(len(dataset) * size_pct)` samples. `goldilocks_rate`
  and `mpe_rate` are the shares taken by the respective strategies; whatever is
  left is filled with uniformly random samples.

  The three parts are selected independently of each other, so the same sample
  can appear more than once in the result.

  Returns:
    A `ConcatDataset` of the non-empty parts, or an empty `Subset` when no
    samples are requested.
  """
  import math  # local import: the notebook's import cell does not provide math

  n_samples_total=int(len(dataset)*size_pct)

  n_samples_goldilocks=int(n_samples_total*goldilocks_rate)
  if math.isclose(goldilocks_rate+mpe_rate, 1.0):
    # The two strategies fill the whole buffer: give MPE the exact remainder so
    # truncation cannot leave the buffer one sample short. isclose (rather than
    # ==) keeps this working when the rates only sum to 1 up to float rounding.
    n_samples_mpe=n_samples_total-n_samples_goldilocks
  else:
    n_samples_mpe=int(n_samples_total*mpe_rate)
  # Random samples absorb the remainder (zero when the rates sum to 1).
  n_samples_random=n_samples_total-n_samples_goldilocks-n_samples_mpe

  mix=[]

  if n_samples_goldilocks>0:
    mix.append(get_goldilocks_buffer(dataset, n_samples_goldilocks, learning_speeds, remove_lowest_pct=goldilocks_remove_lowest_pct, remove_highest_pct=goldilocks_remove_highest_pct))
  if n_samples_mpe>0:
    mix.append(get_sensitivity_buffer(dataset, n_samples_mpe, sensitivities, mpe_highest))
  if n_samples_random>0:
    mix.append(get_random_buffer(dataset, n_samples_random))

  if not mix:
    return Subset(dataset, [])
  return ConcatDataset(mix)

def get_file_key(n_epochs, experiment, buffer_pct, goldilocks_rate, mpe_rate, goldilocks_remove_lowest_pct, goldilocks_remove_highest_pct, mpe_highest):
  """Encode one experiment configuration as the key used in result file names."""
  return (
      f'e2e_ep_{n_epochs}_ex_{experiment}_b_{buffer_pct}'
      f'_gl_{goldilocks_rate}_mpe_{mpe_rate}'
      f'_gll_{goldilocks_remove_lowest_pct}_glh_{goldilocks_remove_highest_pct}'
      f'_mpeh_{mpe_highest}'
  )

def task_2_orchestration(model, trainset_A, trainset_B, testset_A, testset_B, n_epochs, experiment, learning_rate, optimizer_momentum, batch_size, buffer_pct, test_1_accuracy, test_2_accuracy, goldilocks_rate=0.0, mpe_rate=0.0, learning_speeds=None, sensitivities=None, goldilocks_remove_lowest_pct=0.0, goldilocks_remove_highest_pct=0.0, mpe_highest=False, verbose=True):
  """Run one task-2 experiment and pickle its results.

  Builds a replay buffer from `trainset_A` according to the buffer settings,
  trains `model` on `trainset_B` plus that buffer via `task2`, and writes the
  configuration together with the accuracy curves to
  `<results_folder_name>/task_accuracies_data_<file_key>.pkl`.

  `test_1_accuracy` / `test_2_accuracy` are the task-1 accuracy curves; the
  task-2 curves are appended to them so the saved lists cover both phases.
  `verbose` is forwarded to `task2`. Nothing is returned.
  """
  file_key=get_file_key(n_epochs, experiment, buffer_pct, goldilocks_rate, mpe_rate, goldilocks_remove_lowest_pct, goldilocks_remove_highest_pct, mpe_highest)
  filename=f"{results_folder_name}/task_accuracies_data_{file_key}.pkl"

  buffer_set = get_mixed_buffer(trainset_A, buffer_pct, goldilocks_rate, mpe_rate, learning_speeds, sensitivities, goldilocks_remove_lowest_pct, goldilocks_remove_highest_pct, mpe_highest)
  # Forward the caller's verbosity instead of always printing.
  test_1_accuracy_2, test_2_accuracy_2 = task2(model, trainset_B, buffer_set, testset_A, testset_B, n_epochs, learning_rate, optimizer_momentum, batch_size, verbose=verbose)
  description=f"epochs={n_epochs}, exp={experiment}, lr={learning_rate}, bsize={buffer_pct:.0%} (gl={goldilocks_rate:.0%}, mpe={mpe_rate:.0%}, rnd={1.0-goldilocks_rate-mpe_rate:.0%})\n"
  if goldilocks_rate>0.0:
    description=description+f", goldilocks: [remove_lowest_pct={goldilocks_remove_lowest_pct:.0%}, remove_highest_pct={goldilocks_remove_highest_pct:.0%}]"
  if mpe_rate>0.0:
    description=description+f", mpe: [highest={mpe_highest}]"

  # save the results to a file
  full_test_accuracies_1 = [*test_1_accuracy, *test_1_accuracy_2]
  full_test_accuracies_2 = [*test_2_accuracy, *test_2_accuracy_2]

  all_results={"n_epochs":n_epochs,"experiment":experiment, "learning_rate":learning_rate,"optimizer_momentum": optimizer_momentum,
                "batch_size":batch_size,"buffer_pct":buffer_pct,"goldilocks_rate":goldilocks_rate, "mpe_rate":mpe_rate,
                "goldilocks_remove_lowest_pct":goldilocks_remove_lowest_pct,"goldilocks_remove_highest_pct": goldilocks_remove_highest_pct,
                "mpe_highest":mpe_highest, "description": description, "test_accuracies_1": full_test_accuracies_1, "test_accuracies_2": full_test_accuracies_2,
                "learning_speeds": learning_speeds, "sensitivities": sensitivities}

  with open(filename, 'wb') as f:
    pickle.dump(all_results, f)

In [None]:
# Task 1: train on subset A once; every task-2 experiment below starts from a
# deep copy of this model, so the experiments do not influence each other.
loss_fn = nn.CrossEntropyLoss()
sensitivities, learning_speed, test_1_accuracy, test_2_accuracy, model = task1(trainset_A, testset_A, testset_B, loss_fn, N_CLASSES, EPOCHS, device)

# Arguments shared by every task-2 run. "experiment" is a timestamp that ties
# all result files of this run together.
base_parameters={"trainset_A": trainset_A, "trainset_B":trainset_B, "testset_A": testset_A, "testset_B":testset_B,
                    "n_epochs": EPOCHS, "experiment": int(time.time()), "learning_rate":LEARNING_RATE,
                    "optimizer_momentum": OPTIMIZER_MOMENTUM, "batch_size": BATCH_SIZE, "verbose": True,
                    "learning_speeds":learning_speed, "sensitivities":sensitivities,
                    "test_1_accuracy": test_1_accuracy, "test_2_accuracy": test_2_accuracy}

# Experiment grid: buffer sizes (as a share of trainset_A) and goldilocks
# windows given as (remove_lowest_pct, remove_highest_pct).
BUFFER_PCTS = (0.40, 0.20, 0.04)
GOLDILOCKS_WINDOWS = ((0.15, 0.45), (0.30, 0.30), (0.45, 0.15))

# No buffer (baseline)
task_2_orchestration(deepcopy(model), buffer_pct=0.00, **base_parameters)

# Pure random buffers
for buffer_pct in BUFFER_PCTS:
    task_2_orchestration(deepcopy(model), buffer_pct=buffer_pct, **base_parameters)

# Pure Goldilocks
for buffer_pct in BUFFER_PCTS:
    for remove_lowest_pct, remove_highest_pct in GOLDILOCKS_WINDOWS:
        task_2_orchestration(deepcopy(model), buffer_pct=buffer_pct, goldilocks_rate=1.00, mpe_rate=0.00,
                             goldilocks_remove_lowest_pct=remove_lowest_pct,
                             goldilocks_remove_highest_pct=remove_highest_pct, **base_parameters)

# Pure MPE
for buffer_pct in BUFFER_PCTS:
    for mpe_highest in (False, True):
        task_2_orchestration(deepcopy(model), buffer_pct=buffer_pct, goldilocks_rate=0.00, mpe_rate=1.00,
                             mpe_highest=mpe_highest, **base_parameters)

# Mixed Goldilocks and MPE
for buffer_pct in BUFFER_PCTS:
    for mpe_highest in (False, True):
        for remove_lowest_pct, remove_highest_pct in GOLDILOCKS_WINDOWS:
            task_2_orchestration(deepcopy(model), buffer_pct=buffer_pct, goldilocks_rate=0.50, mpe_rate=0.50,
                                 goldilocks_remove_lowest_pct=remove_lowest_pct,
                                 goldilocks_remove_highest_pct=remove_highest_pct,
                                 mpe_highest=mpe_highest, **base_parameters)


Task 1/2: Epoch [1/50]: test_1_accuracy=20.16%,  test_2_accuracy=0.00%
Task 1/2: Epoch [2/50]: test_1_accuracy=35.00%,  test_2_accuracy=0.00%
Task 1/2: Epoch [3/50]: test_1_accuracy=37.98%,  test_2_accuracy=0.00%
Task 1/2: Epoch [4/50]: test_1_accuracy=43.70%,  test_2_accuracy=0.00%
Task 1/2: Epoch [5/50]: test_1_accuracy=48.44%,  test_2_accuracy=0.00%
Task 1/2: Epoch [6/50]: test_1_accuracy=51.48%,  test_2_accuracy=0.00%
Task 1/2: Epoch [7/50]: test_1_accuracy=53.10%,  test_2_accuracy=0.00%
Task 1/2: Epoch [8/50]: test_1_accuracy=54.16%,  test_2_accuracy=0.00%
Task 1/2: Epoch [9/50]: test_1_accuracy=54.54%,  test_2_accuracy=0.00%
Task 1/2: Epoch [10/50]: test_1_accuracy=55.96%,  test_2_accuracy=0.00%
Task 1/2: Epoch [11/50]: test_1_accuracy=56.96%,  test_2_accuracy=0.00%
Task 1/2: Epoch [12/50]: test_1_accuracy=58.38%,  test_2_accuracy=0.00%
Task 1/2: Epoch [13/50]: test_1_accuracy=59.58%,  test_2_accuracy=0.00%
Task 1/2: Epoch [14/50]: test_1_accuracy=60.70%,  test_2_accuracy=0.00%
T

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


done
Task 2/2: Epoch [1/50]: test_1_accuracy=0.00%,  test_2_accuracy=58.56%
