In [1]:
# Work around HTTPS certificate-verification failures when downloading MNIST
# (presumably a missing/outdated CA bundle in this Python install).
#
# SECURITY WARNING: this replaces the process-wide default HTTPS context factory, so
# *every* stdlib HTTPS request in this kernel that does not pass its own SSL context
# (urllib, http.client, ...) skips certificate checks and is open to man-in-the-middle
# attacks. Prefer fixing the CA bundle. Once the dataset is cached, restore with:
#     ssl._create_default_https_context = ORIGINAL_HTTPS_CONTEXT
import ssl

# Remember the secure factory only once, so re-running this cell does not overwrite
# it with the unverified one.
if "ORIGINAL_HTTPS_CONTEXT" not in globals():
    ORIGINAL_HTTPS_CONTEXT = ssl._create_default_https_context
ssl._create_default_https_context = ssl._create_unverified_context

In [2]:
# Import dependencies
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader, Subset

# Import of custom functions
from specific_sample_unlearning import selective_train_unlearning
from plot_generator import plot_accuracy, plot_loss, plot_confusion_matrix, plot_pca
from utils import adjust_class_representation, check_class_distribution

In [3]:
# Load the MNIST dataset, normalised with mean 0.1307 / std 0.3081.
# Seed torch's global RNG first so DataLoader shuffling and the model's weight
# initialisation (later cells) are reproducible on Restart & Run All.
# NOTE(review): if adjust_class_representation draws from Python's `random` or NumPy,
# those generators need seeding too — confirm in utils.
SEED = 0
torch.manual_seed(SEED)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = MNIST('./data', train=True, download=True, transform=transform)
# download=True here too, so this line does not rely on the call above having populated ./data
test_dataset = MNIST('./data', train=False, download=True, transform=transform)

# Prepare data loaders (train_loader is rebuilt from the adjusted dataset further below)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

Create an imbalanced training set in which class 6 is overrepresented

In [None]:
# Build an imbalanced training set by oversampling class 6.
# adjust_class_representation convention: i > 1 -> overrepresentation, 0 < i < 1 -> underrepresentation.
adjusted_train_dataset = adjust_class_representation(
    train_dataset,
    class_to_adjust=6,
    i=2,
    reduce_to=100,          # NOTE(review): presumably only used when underrepresenting — confirm in utils
    duplication_factor=10,  # NOTE(review): presumably controls how many copies are added — confirm in utils
)

# Check the class distribution in the overrepresented dataset
check_class_distribution(adjusted_train_dataset)

# Replace the balanced train_loader with one over the adjusted dataset
train_loader = DataLoader(adjusted_train_dataset, batch_size=64, shuffle=True)

In [5]:
# Define the model
class CustomNeuralNetwork(nn.Module):
    """Fully connected MNIST classifier: 784 -> 128 -> ReLU -> 64 -> ReLU -> 10.

    ``forward`` returns per-class log-probabilities (``log_softmax`` over dim 1),
    so the matching training loss is ``nn.NLLLoss``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        # Activations are registered as submodules AND used in forward, so they
        # appear in the module tree (print(model), named_modules, forward hooks).
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Map a batch of images to class log-probabilities.

        Args:
            x: tensor whose elements regroup into rows of 28*28 pixels,
               e.g. shape (N, 1, 28, 28) or (N, 784).

        Returns:
            Tensor of shape (N, 10) holding log-probabilities per class.
        """
        x = x.reshape(-1, 28 * 28)  # flatten; reshape (unlike view) also accepts non-contiguous input
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

In [6]:
# Initialize model, loss, and optimizer.
model = CustomNeuralNetwork()
optimizer = optim.Adam(model.parameters())
# The model already ends in log_softmax, so pair it with NLLLoss. CrossEntropyLoss would
# apply log_softmax a second time: numerically a no-op on log-probabilities, but it hides
# the intended pairing and does redundant work.
criterion = nn.NLLLoss()

In [None]:
# Render the autograd graph of a single forward pass with torchviz.
from torchviz import make_dot

dummy_input = torch.randn(1, 1, 28, 28)  # one random 1x28x28 "image"
y = model(dummy_input)
make_dot(
    y.mean(),
    params={name: tensor for name, tensor in model.named_parameters()},
)

In [8]:
# Per-epoch histories filled during training/unlearning and plotted at the end.
train_acc = []     # appended to by train()
unlearn_acc = []   # passed to selective_train_unlearning()
train_loss = []    # appended to by train()
unlearn_loss = []  # passed to selective_train_unlearning()

NUM_EPOCHS = 19                          # epochs per phase (pre-training and unlearning)
epochs = list(range(1, NUM_EPOCHS + 1))  # 1-based epoch numbers; x-axis for the plots

In [9]:
# Train model on the dataset
def train(model, train_loader, criterion, optimizer, epoch, acc_history=None, loss_history=None):
    """Run one training epoch and record its accuracy and mean per-sample loss.

    Args:
        model: network to train; switched to training mode here.
        train_loader: iterable of ``(data, target)`` batches.
        criterion: loss function; assumed to use mean reduction over the batch
            (the torch.nn default), which is re-weighted by batch size below.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress message.
        acc_history: list to append this epoch's accuracy (%) to; defaults to the
            notebook-level ``train_acc``.
        loss_history: list to append this epoch's mean loss to; defaults to the
            notebook-level ``train_loss``.

    Returns:
        ``(accuracy_percent, avg_loss)`` for this epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    if acc_history is None:
        acc_history = train_acc
    if loss_history is None:
        loss_history = train_loss

    model.train()
    correct = 0
    running_loss = 0.0
    n_samples = 0
    for data, target in train_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        batch_size = target.size(0)
        # criterion returns the batch MEAN; weight it by batch size so dividing by the
        # sample count yields a true per-sample mean (also correct for a short last batch).
        running_loss += loss.item() * batch_size
        n_samples += batch_size
        pred = output.argmax(dim=1)
        correct += (pred == target.view(-1)).sum().item()

    if n_samples == 0:
        raise ValueError("train_loader yielded no samples")

    accuracy = 100. * correct / n_samples
    avg_loss = running_loss / n_samples
    # NOTE(review): confirm selective_train_unlearning normalises unlearn_loss the same
    # way, otherwise plot_loss compares curves on different scales.
    acc_history.append(accuracy)
    loss_history.append(avg_loss)
    print(f"Train Epoch: {epoch} Accuracy: {accuracy:.2f}%, Loss: {avg_loss:.4f}")
    return accuracy, avg_loss

In [10]:
# Test model performance
def test(model, test_loader):
    model.eval()  # Set the model to evaluation mode
    test_loss = 0
    correct = 0
    y_pred = []  # To store all predictions
    y_true = []  # To store all true labels        
        
    with torch.no_grad():  # Disable gradient calculation for inference
        for data, target in test_loader:
            output = model(data)
            test_loss += criterion(output, target).item()  # Add batch loss
            pred = output.argmax(dim=1, keepdim=True)  # Get the index of the max log-probability
            y_pred.extend(pred.view(-1).tolist())  # Store predictions, flatten tensor and convert to list
            y_true.extend(target.view(-1).tolist())  # Store true labels, flatten tensor and convert to list
            correct += pred.eq(target.view_as(pred)).sum().item()  # Count correct predictions
    
    test_loss /= len(test_loader.dataset)  # Average loss per sample
    accuracy = 100. * correct / len(test_loader.dataset)  # Compute accuracy
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)\n')
    
    return accuracy, y_pred, y_true

In [None]:
# Train the model on the training set (train_loader) first.
# Iterate over `epochs` so the number of recorded history points always matches
# the x-axis passed to plot_accuracy / plot_loss.
for epoch in epochs:
    train(model, train_loader, criterion, optimizer, epoch)

In [None]:
# Evaluate on the held-out test set before any unlearning.
# NOTE: despite its name, train_accuracy_before holds the *test-set* accuracy from test().
print("Testing pre-trained model:")
(train_accuracy_before,
 y_pred_before,
 y_true_before) = test(model, test_loader)

In [13]:
# Now for unlearning class 6 and learning class 3:
# freeze all weights except for the final layer (fc3).
for layer in (model.fc1, model.fc2):
    for param in layer.parameters():
        param.requires_grad = False
        # Drop any gradient left over from pre-training. `optimizer` still holds these
        # parameters: torch.optim skips params whose .grad is None, but a stale or
        # zero-filled .grad would let Adam keep moving them via its momentum state.
        param.grad = None

In [None]:
# Train with selective specific-sample unlearning, reusing `epochs` so the number of
# iterations matches the x-axis used by the plots.
# NOTE(review): presumably each call appends one point to unlearn_acc / unlearn_loss —
# confirm in specific_sample_unlearning.
for epoch in epochs:
    selective_train_unlearning(model, train_loader, optimizer, criterion, epoch, unlearn_acc, unlearn_loss)

In [None]:
# Re-evaluate on the same held-out test set once unlearning has finished.
print("Testing model after unlearning:")
(unlearn_accuracy_after,
 y_pred_after,
 y_true_after) = test(model, test_loader)

In [None]:
# Accuracy per epoch: pre-training history vs. unlearning history
plot_accuracy(
    train_acc,
    unlearn_acc,
    epochs,
)

In [None]:
# Loss per epoch: pre-training history vs. unlearning history
plot_loss(
    train_loss,
    unlearn_loss,
    epochs,
)

In [None]:
# Confusion matrix from the test-set predictions gathered BEFORE unlearning
plot_confusion_matrix(
    y_true_before,
    y_pred_before,
    title="Confusion Matrix Before Unlearning",
)

In [None]:
# Confusion matrix from the test-set predictions gathered AFTER unlearning
plot_confusion_matrix(
    y_true_after,
    y_pred_after,
    title="Confusion Matrix After Unlearning",
)

In [None]:
# Collect the model's outputs on the test set — per-class log-probabilities from
# CustomNeuralNetwork.forward, not penultimate-layer features — plus the true labels.
model.eval()
feature_batches, label_batches = [], []
with torch.no_grad():  # inference only: don't build autograd graphs for the whole test set
    for data, target in test_loader:
        feature_batches.append(model(data))
        label_batches.append(target)
# A single pass keeps features and labels aligned even if test_loader were shuffled.
features = torch.cat(feature_batches, dim=0).cpu().numpy()
labels = torch.cat(label_batches, dim=0).cpu().numpy()

# Plot the PCA visualization
plot_pca(features, labels, title="PCA Visualization After Unlearning")