In [None]:
from fisherunlearn import compute_client_information, find_informative_params, reset_parameters, mia_attack
from fisherunlearn import UnlearnNet

import fisherunlearn

import copy
from torch.utils.data import DataLoader, Subset
from backpack import extend, backpack
from backpack.extensions import DiagHessian, DiagGGNMC, DiagGGNExact
from tqdm.notebook import tqdm

import functools
import logging
import math
import os
import pickle
import random

import torch
from torch import nn
from torch.utils.data import DataLoader, Subset
from torchvision.models import resnet18
import seaborn as sns

import torchmetrics
from torch.multiprocessing import Pool, Queue
torch.multiprocessing.set_start_method('spawn', force=True)

import numpy as np
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm

from typing import TypedDict, Literal

import pandas as pd

In [None]:
# True: train a fresh Net from scratch in the training cell.
# False: load weights from the '<experiment_name>_best.pth' checkpoint instead.
TRAIN: bool = True

In [None]:
# Prefer the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# NOTE(review): not referenced in the visible cells (the loaders use `batch_size`);
# confirm whether it is still needed.
INFO_BATCH_SIZE = 5

In [None]:
# create validation routine
def validate(net, dl, n_classes, device):
    """Evaluate ``net`` on every batch of ``dl``.

    Args:
        net: classifier whose forward pass returns raw logits.
        dl: iterable of ``(inputs, targets)`` batches (e.g. a DataLoader).
            Inputs are cast to float; a trailing singleton label dimension
            (``(B, 1)``) is dropped.
        n_classes: number of target classes.
        device: device to evaluate on.

    Returns:
        ``(acc, con)``: the torchmetrics macro-averaged top-1 multiclass
        accuracy and the multiclass confusion matrix.

    Side effects:
        ``net`` is moved to ``device`` (in place) and left in eval mode.
    """
    # create metric objects on the evaluation device
    tm_acc = torchmetrics.Accuracy(task='multiclass', num_classes=n_classes, average='macro', top_k=1)
    tm_con = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=n_classes)
    tm_acc.to(device)
    tm_con.to(device)

    # The model must live on the same device as the batches. Models returned by
    # adam_trainer are on the CPU, so move explicitly instead of assuming.
    net.to(device)
    net.eval()

    for inp, gt in dl:
        inp = inp.to(device).float()
        gt = gt.to(device)
        # Drop a trailing singleton label dim: (B, 1) -> (B,). squeeze(-1) keeps
        # the batch axis even when B == 1, whereas a bare squeeze() would produce
        # a 0-d tensor that no longer matches the (1, n_classes) logits.
        if gt.ndim > 1:
            gt = gt.squeeze(-1)
        with torch.no_grad():
            logits = net(inp)

        # update metrics with the raw logits
        tm_acc.update(logits, gt)
        tm_con.update(logits, gt)

    # at the end, compute metrics over all batches
    return tm_acc.compute(), tm_con.compute()

In [None]:
#Loading the data
class Dataset(torch.utils.data.Dataset):
    """Tabular dataset read from a CSV with an integer-valued ``outcome`` column.

    Every column other than ``outcome`` is used as an input feature, in the
    order the columns appear in the CSV. Rows with missing values are kept
    (no ``dropna`` is applied).
    """

    def __init__(self, csv):
        # read the csv file (any path or buffer pd.read_csv accepts)
        self.df = pd.read_csv(csv, sep=',')
        self.output_cols = ['outcome']
        # Keep CSV column order. A set difference iterates in string-hash order,
        # which can change between interpreter runs (hash randomisation) and
        # would silently permute the features a saved model was trained on.
        self.input_cols = [col for col in self.df.columns if col not in self.output_cols]

    def __len__(self):
        # number of rows in the CSV
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the row at position ``idx``.

        ``features`` is a 1-D tensor over ``input_cols``; ``label`` is a 0-d
        int64 tensor.
        """
        cur_sample = self.df.iloc[idx]
        cur_sample_x = torch.tensor(cur_sample[self.input_cols].tolist())
        cur_sample_y = torch.tensor(cur_sample[self.output_cols].tolist())
        # single label column -> drop the length-1 dimension, then cast to int
        cur_sample_y = cur_sample_y.squeeze().long()
        return cur_sample_x, cur_sample_y

# create train and test datasets
train_ds = Dataset('Brcancer/train.csv')
test_ds = Dataset('Brcancer/test.csv')

# Candidate rows to forget: training samples with age <= 40. The column is read
# from the DataFrame because dataset items are (features, label) tuples and
# cannot be indexed by column name. enumerate yields positional indices, which
# is what Subset expects.
eligible = [i for i, age in enumerate(train_ds.df["age"]) if age <= 40]

# Forget-set selection strategy:
#   "one"     - exactly one random eligible sample (Option A, default)
#   "count"   - k random eligible samples (Option B)
#   "percent" - floor(p% of eligible) random samples (Option C)
FORGET_MODE = "one"
k = 10  # used by "count"
p = 30  # used by "percent"
n = len(eligible) * p // 100  # exact integer floor of p% of len(eligible)

random.seed(42)  # for reproducibility
if FORGET_MODE == "one":
    indices_to_forget = [random.choice(eligible)] if eligible else []
elif FORGET_MODE == "count":
    indices_to_forget = random.sample(eligible, k=min(k, len(eligible)))
elif FORGET_MODE == "percent":
    indices_to_forget = random.sample(eligible, k=n) if n > 0 else []
else:
    raise ValueError(f"unknown FORGET_MODE: {FORGET_MODE!r}")

# Everything not forgotten is retained, in original dataset order.
full_train_indices = list(range(len(train_ds)))
forget_lookup = set(indices_to_forget)
indices_to_retain = [i for i in full_train_indices if i not in forget_lookup]

# Create the two datasets using torch.utils.data.Subset
forget_set = Subset(train_ds, indices_to_forget)
retain_set = Subset(train_ds, indices_to_retain)
assert len(forget_set) + len(retain_set) == len(train_ds)

# Index 0 = retain "client", index 1 = forget "client".
datasets_for_unlearning = [retain_set, forget_set]

print(f"Total training samples: {len(train_ds)}")
print(f"Number of samples to retain: {len(retain_set)}")
print(f"Number of samples to forget: {len(forget_set)}")

In [None]:
n_inputs = 12
n_classes = 3


class Net(nn.Sequential):
    """Fully connected ReLU classifier producing raw logits.

    With the defaults this is exactly
    Linear(12, 32) -> ReLU -> Linear(32, 16) -> ReLU -> Linear(16, 8) -> ReLU -> Linear(8, 3).
    The Linear layers stay at indices 0, 2, 4, 6, so state_dict keys (and
    checkpoints) are unchanged. Defaults come from ``n_inputs`` / ``n_classes``
    as they were when this cell ran.

    Args:
        in_features: number of input features.
        num_classes: number of output logits.
        hidden_sizes: widths of the hidden layers, each followed by a ReLU.
    """

    def __init__(self, in_features=n_inputs, num_classes=n_classes, hidden_sizes=(32, 16, 8)):
        widths = [in_features, *hidden_sizes]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # final projection to logits (no activation)
        layers.append(nn.Linear(widths[-1], num_classes))
        super().__init__(*layers)

In [None]:
def adam_trainer(model, loss_fn, dataloader, epochs, lr=0.001, device=None):
    """Train ``model`` in place with Adam and return it on the CPU in eval mode.

    Args:
        model: torch module to optimise; its parameters are updated in place.
        loss_fn: callable ``loss_fn(outputs, targets)`` returning a scalar loss.
        dataloader: iterable of ``(inputs, targets)`` batches; inputs are cast to float.
        epochs: number of passes over ``dataloader``.
        lr: Adam learning rate (0.001 was previously hard-coded).
        device: device to train on; defaults to CUDA when available, else CPU.

    Returns:
        The same ``model`` object, moved to the CPU and switched to eval mode.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in tqdm(range(epochs), desc="Training (Adam)", unit="epoch", leave=False):
        epoch_loss = 0.0
        n_batches = 0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device).float(), targets.to(device)
            optimizer.zero_grad()
            loss = loss_fn(model(inputs), targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1

        # Count batches rather than dividing by len(dataloader): an empty loader
        # (e.g. drop_last=True over fewer samples than batch_size) would raise
        # ZeroDivisionError.
        mean_loss = epoch_loss / n_batches if n_batches else float("nan")
        logging.info(f"Epoch {epoch+1}/{epochs}, Loss: {mean_loss:.4f}")

    model.eval()
    return model.cpu()

In [None]:
criterion = nn.CrossEntropyLoss()
batch_size = 5

if TRAIN:
    # Seed torch's RNG so weight initialisation and shuffling are reproducible
    # (random.seed only covers Python's `random` module).
    torch.manual_seed(42)
    train_dl = torch.utils.data.DataLoader(
        train_ds,
        batch_size=batch_size,
        drop_last=True,
        shuffle=True,
    )
    trained_model = adam_trainer(Net(), criterion, train_dl, epochs=300)
    print("Successfully trained the model.")
else:
    experiment_name = 'test'
    trained_model = Net()
    # The checkpoint must be a dict holding the model state_dict under 'net'.
    checkpoint = torch.load(experiment_name + '_best.pth', map_location=torch.device(DEVICE))
    trained_model.load_state_dict(checkpoint['net'])
    trained_model.to(DEVICE)
    print("Successfully loaded the trained model.")

In [None]:
# Evaluation loaders. drop_last=False so every sample is scored: with
# drop_last=True and batch_size=5, a forget set smaller than 5 (the default pick
# is a single sample) produced zero batches and an accuracy computed on no data.
# retain_loader is also reused for fine-tuning inside run_tests.
retain_loader = torch.utils.data.DataLoader(
    retain_set,
    batch_size=batch_size,
    drop_last=False,
    shuffle=False,
)

forget_loader = torch.utils.data.DataLoader(
    forget_set,
    batch_size=batch_size,
    drop_last=False,
    shuffle=False,
)

test_dl = torch.utils.data.DataLoader(
    test_ds,
    batch_size=1,
    drop_last=False,
    shuffle=False,
)

assert len(forget_loader) > 0, "forget set is empty: no eligible samples were selected"

# Scores of the original (not unlearned) model: first row of the results table.
acc_original_on_retain, _ = validate(trained_model, retain_loader, n_classes, DEVICE)
acc_original_on_forget, _ = validate(trained_model, forget_loader, n_classes, DEVICE)
acc_original_on_test, _ = validate(trained_model, test_dl, n_classes, DEVICE)

baseline_results = {
    "Unlearning %": "0% (Original Model)",
    "Forget Set Acc": f"{acc_original_on_forget.item():.4f}",
    "Retain Set Acc": f"{acc_original_on_retain.item():.4f}",
    "Test Set Acc": f"{acc_original_on_test.item():.4f}",
    "Reset Params": 0
}


In [None]:
# Position in datasets_for_unlearning of the client to forget
# (0 = retain_set, 1 = forget_set).
client_idx_to_remove = 1

# Per-parameter information for the client to remove. method='diag_ggn'
# presumably selects a diagonal GGN estimate - fisherunlearn internals are not
# visible here, confirm against its docs.
unlearning_info = compute_client_information(
    client_idx=client_idx_to_remove,
    model=trained_model,
    criterion=criterion,
    datasets_list=datasets_for_unlearning,
    method='diag_ggn',
)

# Default single-run settings (the evaluation loop builds its own parameter dicts).
unlearning_method = "information"  # alternative: "parameter"
unlearning_percentage = 40
whitelist = blacklist = None

test_params_dict = dict(
    unlearning_method=unlearning_method,
    unlearning_percentage=unlearning_percentage,
    whitelist=whitelist,
    blacklist=blacklist,
)

def run_tests(test_params_dict, trained_model, DEVICE):
    """Reset the most informative parameters, fine-tune, and score the result.

    Relies on notebook-level ``unlearning_info``, ``criterion``, ``n_classes``,
    ``Net`` and the ``retain_loader`` / ``forget_loader`` / ``test_dl`` loaders.

    Args:
        test_params_dict: dict with 'unlearning_method' and
            'unlearning_percentage', plus optional 'whitelist' / 'blacklist'.
        trained_model: the original trained ``Net``. Its weights are copied, and
            the model object itself is not moved or modified here.
        DEVICE: device used for evaluation.

    Returns:
        dict of formatted metrics forming one row of the results table.
    """
    unlearning_method = test_params_dict['unlearning_method']
    unlearning_percentage = test_params_dict['unlearning_percentage']
    whitelist = test_params_dict.get('whitelist', None)
    blacklist = test_params_dict.get('blacklist', None)

    informative_params = find_informative_params(
        unlearning_info, unlearning_method, unlearning_percentage, whitelist, blacklist
    )
    # Total number of selected indices across all parameter tensors.
    total_individual_reset_params = sum(
        indices.shape[0]
        for indices in informative_params.values()
        if indices is not None and indices.numel() > 0
    )

    # Reset a fresh CPU copy of the trained weights. The original called
    # trained_model.cpu(), which moved the caller's model. Whether
    # reset_parameters modifies its input in place is not visible here, and a
    # copy guarantees each sweep step starts from the same trained weights.
    source_model = Net()
    source_model.load_state_dict(trained_model.state_dict())
    reset_model = Net()
    reset_model.load_state_dict(reset_parameters(source_model, informative_params))
    reset_model.to(DEVICE)

    # Score the reset model BEFORE fine-tuning. UnlearnNet is handed reset_model
    # and may share its parameters (its internals are not visible here), in which
    # case training the wrapper would also change reset_model.
    acc_forget, _ = validate(reset_model, forget_loader, n_classes, DEVICE)
    acc_retain, _ = validate(reset_model, retain_loader, n_classes, DEVICE)
    acc_test, _ = validate(reset_model, test_dl, n_classes, DEVICE)

    unlearn_model = UnlearnNet(reset_model, informative_params)

    print("Unlearn model parameters:")
    for name, param in unlearn_model.named_parameters():
        print(name, param.shape)

    unlearned_model_finetuned = adam_trainer(unlearn_model, criterion, retain_loader, epochs=1)
    # adam_trainer returns the model on the CPU; evaluate on DEVICE like the others.
    unlearned_model_finetuned.to(DEVICE)

    acc_forget_relearned, _ = validate(unlearned_model_finetuned, forget_loader, n_classes, DEVICE)
    acc_retain_relearned, _ = validate(unlearned_model_finetuned, retain_loader, n_classes, DEVICE)
    acc_test_relearned, _ = validate(unlearned_model_finetuned, test_dl, n_classes, DEVICE)

    return {
        "Unlearning %": f"{unlearning_percentage}%",
        "Forget Set Acc": f"{acc_forget.item():.4f}",
        "Retain Set Acc": f"{acc_retain.item():.4f}",
        "Test Set Acc": f"{acc_test.item():.4f}",
        "Forget Set Acc Relearning": f"{acc_forget_relearned.item():.4f}",
        "Retain Set Acc Relearning": f"{acc_retain_relearned.item():.4f}",
        "Test Set Acc Relearning": f"{acc_test_relearned.item():.4f}",
        "Reset Params": total_individual_reset_params
    }

In [None]:
percentages_to_test = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
all_results = [baseline_results]

print("\n--- Starting Unlearning Evaluation Loop ---")

# Loop variable is `pct`, not `p`, so the forget-set percent config `p` survives.
for pct in percentages_to_test:
    print(f"Running test for {pct}% unlearning...")

    test_params = {
        'unlearning_method': "information",
        'unlearning_percentage': pct,
        'whitelist': None,
        'blacklist': None
    }

    # run_tests takes (params, model, device). Passing n_inputs / n_classes as
    # extra positional args made every call raise TypeError.
    current_result = run_tests(test_params, trained_model, DEVICE)
    all_results.append(current_result)

print("\n--- Evaluation Complete ---")

# Build the table once from the list of row dicts.
results_df = pd.DataFrame(all_results)
display(results_df)