In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from moe import SimpleCNN, RoutedCNN, SparseMoEConvBlock, SparseMoEConvBlockWeighted
from torchvision import transforms
import torchvision
from tqdm.notebook import tqdm

In [None]:
# Load the "cifar100" dataset through `datasets.load_dataset`.
# The rest of the notebook indexes it by split ("train" / "test") and reads the
# "img" and "coarse_label" fields of each example.
dataset = load_dataset("cifar100")

In [None]:
# Per-image preprocessing used by `collate_fn`: convert to a tensor, then
# normalize the three channels with fixed mean / std constants.
# NOTE(review): these constants look like the commonly used ImageNet statistics
# rather than CIFAR-100-specific ones — confirm they match what the checkpoint
# was trained with before changing them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [None]:
def collate_fn(batch):
    """Collate a list of dataset examples into a single batch dict.

    Every example's "img" is converted to a PIL image, forced to RGB, passed
    through the module-level `transform`, and the results are stacked. The
    "coarse_label" values are gathered into one tensor.

    Args:
        batch: sequence of examples, each exposing "img" and "coarse_label".

    Returns:
        dict with "img" (stacked image tensor) and "coarse_label" (label tensor).
    """
    to_pil = torchvision.transforms.ToPILImage()
    images = [transform(to_pil(example["img"]).convert("RGB")) for example in batch]
    coarse_labels = [example["coarse_label"] for example in batch]
    return {
        "img": torch.stack(images),
        "coarse_label": torch.tensor(coarse_labels),
    }

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim


def identify_top_experts(model, dataloader, target_class, num_top_experts=2):
    """Rank the experts of `model.conv1` by how often they are routed `target_class`.

    For every batch, the samples labelled `target_class` are flattened and fed
    to `model.conv1.router`; the `model.conv1.top_k` highest-scoring expert
    indices per sample are tallied. Runs in eval mode under `torch.no_grad()`
    and relies on the module-level `device`.

    Args:
        model: network exposing `conv1.num_experts`, `conv1.router`, `conv1.top_k`.
        dataloader: iterable of batches with "img" and "coarse_label" entries.
        target_class: label whose routing pattern is inspected.
        num_top_experts: how many expert indices to return.

    Returns:
        List of expert indices, most frequently selected first (ties keep
        ascending index order).
    """
    model.eval()
    hits = dict.fromkeys(range(model.conv1.num_experts), 0)

    with torch.no_grad():
        for batch in dataloader:
            images = batch["img"].to(device)
            classes = batch["coarse_label"].to(device)
            is_target = classes == target_class
            if is_target.any():
                selected = images[is_target]
                # The router is queried directly on the flattened inputs.
                scores = model.conv1.router(selected.view(selected.size(0), -1))
                chosen = torch.topk(scores, model.conv1.top_k)[1]
                for expert_idx in chosen.flatten().tolist():
                    hits[expert_idx] += 1

    # Stable sort on the counts: equal counts stay in ascending index order.
    ranking = sorted(hits, key=hits.get, reverse=True)
    return ranking[:num_top_experts]


def evaluate_model(model, test_loader, target_class):
    """Measure accuracy separately on `target_class` and on every other class.

    The model is switched to eval mode and run under `torch.no_grad()`.
    Relies on the module-level `device` and `tqdm`.

    Args:
        model: network called as `model(images)`; the arg-max over dim 1 of its
            output is taken as the prediction.
        test_loader: iterable of batches with "img" and "coarse_label" entries.
        target_class: label whose accuracy is reported on its own.

    Returns:
        Tuple `(accuracy_target, accuracy_non_target, targets, preds)`:
        the two accuracies as percentages, then the per-sample labels and
        predictions in iteration order. An accuracy is NaN when its group
        contains no samples (instead of raising ZeroDivisionError).
    """
    model.eval()
    preds = []
    targets = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            images = batch["img"].to(device)
            labels = batch["coarse_label"].to(device)
            predicted = model(images).argmax(dim=1)
            preds.extend(predicted.cpu().numpy())
            targets.extend(labels.cpu().numpy())

    # Single pass that splits the samples into target / non-target groups.
    correct_target = total_target = 0
    correct_non_target = total_non_target = 0
    for pred, target in zip(preds, targets):
        is_hit = int(pred == target)
        if target == target_class:
            total_target += 1
            correct_target += is_hit
        else:
            total_non_target += 1
            correct_non_target += is_hit

    # An empty group has no defined accuracy; report NaN rather than crash.
    accuracy_target = (
        100 * correct_target / total_target if total_target else float("nan")
    )
    accuracy_non_target = (
        100 * correct_non_target / total_non_target
        if total_non_target
        else float("nan")
    )
    return accuracy_target, accuracy_non_target, targets, preds


class UnlearningLoss(nn.Module):
    """Cross-entropy loss that steers the model away from one class.

    Samples labelled `target_class` are given a random *different* label before
    cross-entropy is computed; all other samples keep their label. A penalty of
    `penalty_weight` is added for each target-class sample the model still
    predicts as `target_class`. The penalty is derived from arg-max predictions,
    so it changes the returned value but carries no gradient.

    Args:
        target_class: label to be unlearned.
        penalty_weight: amount added per still-recognised target sample.
        num_classes: number of classes to draw replacement labels from; when
            None it is taken from the second dimension of `outputs`.
    """

    def __init__(self, target_class, penalty_weight=2, num_classes=None):
        super().__init__()
        self.target_class = target_class
        self.ce_loss = nn.CrossEntropyLoss()
        self.penalty_weight = penalty_weight
        self.num_classes = num_classes

    def forward(self, outputs, labels):
        """Return cross-entropy on the relabelled batch plus the penalty term.

        Args:
            outputs: logits, indexed as (sample, class).
            labels: integer class labels, one per sample.

        Raises:
            ValueError: if fewer than two classes are available, since no
                replacement label different from the target can be drawn.
        """
        num_classes = (
            self.num_classes if self.num_classes is not None else outputs.size(1)
        )
        if num_classes < 2:
            raise ValueError(
                "UnlearningLoss needs at least 2 classes to pick a replacement label"
            )

        # Remember which samples belong to the target class *before* relabelling;
        # the penalty below must be based on the true labels.
        is_target = labels == self.target_class

        # Draw from num_classes - 1 values and shift everything at or above the
        # target up by one, so the target class itself can never be drawn.
        substitute = torch.randint_like(labels, 0, num_classes - 1)
        substitute = substitute + (substitute >= self.target_class).to(substitute.dtype)
        relabelled = torch.where(is_target, substitute, labels)

        ce_loss = self.ce_loss(outputs, relabelled)

        # Penalize target-class samples that are still classified correctly.
        predicted = torch.argmax(outputs, 1)
        still_recognised = is_target & (predicted == self.target_class)
        penalty = still_recognised.sum() * self.penalty_weight

        return ce_loss + penalty


def unlearning_procedure(
    model,
    train_dataloader,
    test_dataloder,
    target_class,
    num_epochs=20,
    learning_rate=0.001,
    verbose=False,
):
    """Fine-tune a subset of `model` to forget `target_class`, then evaluate it.

    Steps:
      1. `identify_top_experts` picks the conv1 experts most used for the class.
      2. Everything is frozen except those experts' parameters,
         `model.conv1.router.weight` and `model.fc2.weight` (weights only).
      3. The trainable parameters are optimized with Adam on `UnlearningLoss`.
      4. The `requires_grad` flags are restored to what they were on entry,
         even if training raises.
      5. `evaluate_model` is run on the test loader.

    Relies on the module-level `device`.

    Args:
        model: network to modify in place.
        train_dataloader: batches with "img" / "coarse_label" used for expert
            selection and for training.
        test_dataloder: batches used for the final evaluation (parameter name
            spelling kept for compatibility with existing callers).
        target_class: label to unlearn.
        num_epochs: number of passes over `train_dataloader`.
        learning_rate: Adam learning rate.
        verbose: when True, print the chosen experts, the mean loss per epoch
            and the final accuracies.

    Returns:
        `(accuracy_target, accuracy_non_target)` as returned by `evaluate_model`.
    """
    # Identify top experts for the target class
    top_experts = identify_top_experts(model, train_dataloader, target_class)
    if verbose:
        print(f"Top experts for class {target_class}: {top_experts}")

    # Snapshot the flags so they can be restored exactly, whatever happens below.
    original_flags = [(param, param.requires_grad) for param in model.parameters()]
    try:
        # Freeze all parameters except the identified experts
        for param, _ in original_flags:
            param.requires_grad = False

        for expert_idx in top_experts:
            for param in model.conv1.experts[expert_idx].parameters():
                param.requires_grad = True

        model.conv1.router.weight.requires_grad = True
        model.fc2.weight.requires_grad = True

        # Set up optimizer and loss function
        optimizer = optim.Adam(
            [param for param in model.parameters() if param.requires_grad],
            lr=learning_rate,
        )
        unlearning_loss = UnlearningLoss(target_class)

        model.train()
        for epoch in range(num_epochs):
            total_loss = 0.0
            for batch in train_dataloader:
                inputs = batch["img"].to(device)
                labels = batch["coarse_label"].to(device)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = unlearning_loss(outputs, labels)
                loss.backward()
                optimizer.step()

                if verbose:
                    # Only pay for the host sync when the value is reported.
                    total_loss += loss.item()

            if verbose:
                mean_loss = total_loss / max(len(train_dataloader), 1)
                print(f"Epoch {epoch+1}/{num_epochs}, Loss: {mean_loss:.4f}")
    finally:
        for param, was_trainable in original_flags:
            param.requires_grad = was_trainable

    # Evaluate the model
    accuracy_target, accuracy_non_target, _, _ = evaluate_model(
        model, test_dataloder, target_class
    )

    if verbose:
        print("Unlearning procedure completed.")
        print(f"Accuracy on target class {target_class}: {accuracy_target}%")
        print(f"Accuracy on non-target classes: {accuracy_non_target}%")
    return accuracy_target, accuracy_non_target


# Use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


# Fixed-order loader over the test split; built once and reused by `unlearn`.
test_loader = DataLoader(
    dataset=dataset["test"].with_format("torch"),
    collate_fn=collate_fn,
    shuffle=False,
    batch_size=32,
)


def unlearn(h_params):
    """Run one unlearning experiment from a fresh checkpoint.

    A new `RoutedCNN` is loaded from "routed_cnn_100.pth", a random subset of
    the training split is drawn, and `unlearning_procedure` is run against the
    module-level `test_loader`.

    Args:
        h_params: dict with
            "train_dataset_percentage": percentage (0-100) of the training split
                to use,
            "num_epochs": passed to `unlearning_procedure`,
            "learning_rate": passed to `unlearning_procedure`,
            "target_class": label to unlearn,
            "seed" (optional): when present, seeds both the subset shuffle and
                torch's global RNG so the run can be repeated.

    Returns:
        `(accuracy_target, accuracy_non_target)` from `unlearning_procedure`.
    """
    seed = h_params.get("seed")
    if seed is not None:
        torch.manual_seed(seed)

    model = RoutedCNN().to(device)
    # map_location keeps loading working when the checkpoint was saved from a
    # different device than the one selected now (e.g. CUDA-saved, CPU-only host).
    state_dict = torch.load(
        "routed_cnn_100.pth", map_location=device, weights_only=True
    )
    model.load_state_dict(state_dict)

    train_split = dataset["train"]
    subset_size = int(
        len(train_split) * (h_params["train_dataset_percentage"] / 100)
    )
    train_subset = (
        train_split.shuffle(seed=seed)
        .select(range(subset_size))
        .with_format("torch")
    )
    train_loader = DataLoader(
        train_subset,
        batch_size=32,
        shuffle=True,
        collate_fn=collate_fn,
    )

    accuracy_target, accuracy_non_target = unlearning_procedure(
        model,
        train_loader,
        test_loader,
        h_params["target_class"],
        num_epochs=h_params["num_epochs"],
        learning_rate=h_params["learning_rate"],
    )

    return accuracy_target, accuracy_non_target


In [None]:
# One unlearning run with a single hand-picked configuration.
hyperparameters = dict(
    train_dataset_percentage=25,
    num_epochs=2,
    learning_rate=0.001,
    target_class=0,
)

accuracy_target, accuracy_non_target = unlearn(hyperparameters)

In [None]:
# Show the accuracies currently held in `accuracy_target` / `accuracy_non_target`.
print(
    f"Accuracy on target class: {accuracy_target}",
    f"Accuracy on non-target classes: {accuracy_non_target}",
    sep="\n",
)

In [None]:
import itertools
import json

from tqdm.notebook import tqdm

# Search space: every combination of the listed values is tried once.
hyperparameters = {
    "train_dataset_percentage": [5, 25, 50, 100],
    "num_epochs": [1, 2, 3, 4, 5],
    "learning_rate": [0.01, 0.001, 0.0001],
    "target_class": [0],  # , 1, 2, 3, 4, 5, 6, 7, 8, 9],
}


def _selection_score(result):
    """Mean of forgetting (100 - target accuracy) and non-target accuracy."""
    return ((100 - result["accuracy_target"]) + result["accuracy_non_target"]) / 2


# Generate all combinations of hyperparameters
param_combinations = list(itertools.product(*hyperparameters.values()))

results = []

# Iterate over all combinations
for params in tqdm(param_combinations, desc="Grid Search Progress"):
    h_params = dict(zip(hyperparameters.keys(), params))

    accuracy_target, accuracy_non_target = unlearn(h_params)

    results.append(
        {
            "hyperparameters": h_params,
            "accuracy_target": accuracy_target,
            "accuracy_non_target": accuracy_non_target,
        }
    )

    # Rewrite the results file after every run so that a failure late in the
    # search does not discard the runs that already finished. The file written
    # after the last iteration is the complete result list.
    with open("grid_search_results.json", "w") as f:
        json.dump(results, f)

best_result = max(results, key=_selection_score)
print("Best hyperparameters:", best_result["hyperparameters"])
print("Best target accuracy:", best_result["accuracy_target"])
print("Corresponding non-target accuracy:", best_result["accuracy_non_target"])