In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# from datasets import load_dataset
# from peft import get_peft_model, LoraConfig, TaskType
import flwr as fl
# from opacus import PrivacyEngine
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
# from collections import OrderedDict

# my py files
from device import move_to_device, get_device
from data import load_data
from lora import create_lora_model
from clients import FlowerClient
from differential_privacy import differential_privacy
from server import start_fl_server, stop_fl_server
import logging

logging.basicConfig(level=logging.INFO)

In [2]:
# list of  different experimental setup options:

scenarios = [
    {"name": "LoRA Only", "federated": False, "differential_privacy": False},
    {"name": "LoRA + Federated", "federated": True, "differential_privacy": False},
    {"name": "LoRA + Federated + DP", "federated": True, "differential_privacy": True},
]

# List of transformers to test

models_to_test = [
    "prajjwal1/bert-tiny",
    "google/mobilebert-uncased",
    "distilbert-base-uncased",
]

Prepare a DataFrame to store the results of each experiment:

In [3]:
results = pd.DataFrame(columns=["Model", "Scenario", "Accuracy", "Loss"])

In [None]:
for model_name in models_to_test:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    trainloader = load_data(tokenizer, split="train[:10%]") # 10% for testing 
    testloader = load_data(tokenizer, split="test[:10%]")

    # trainloader = load_data(tokenizer, split="train")       # 100% for final 
    # testloader = load_data(tokenizer, split="test")



    for scenario in scenarios:
        logging.info(f"Running experiment: {scenario['name']} with {model_name}")

        #initialize the model
        model = create_lora_model(model_class= AutoModelForSequenceClassification, transformer_model = model_name, rank = 16, num_labels = 2)

        # Differential Privacy parameters
        dp_params = {
            "noise_multiplier": 1.0,
            "max_grad_norm": 1.0,
        }

         # Federated Learning Scenario
        if scenario["federated"]:
            # Start FL server
            server_process = start_fl_server()

            try:
                # Simulate Federated Learning
                client = FlowerClient(
                    model=model,
                    trainloader=trainloader,
                    testloader=testloader,
                    device=get_device(),
                    dp_enabled=scenario["differential_privacy"],
                    dp_params=dp_params
                )
                fl.client.start_numpy_client(server_address="localhost:8080", client=client)
            finally:
                # Stop FL server
                stop_fl_server(server_process)
        else:
            # Local training
            optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
            if scenario["differential_privacy"]:
                model, optimizer, trainloader = differential_privacy(model, optimizer, trainloader, **dp_params)
                criterion = torch.nn.CrossEntropyLoss()
                model.train()
                for epoch in range(5):  # Adjust epochs as needed
                    for batch in trainloader:
                        optimizer.zero_grad()

                        # GPU acceleration 
                        input_ids = move_to_device(batch["input_ids"])
                        attention_mask = move_to_device(batch["attention_mask"])
                        labels = move_to_device(batch["label"])

                        # The acual training
                        outputs = model(input_ids, attention_mask=attention_mask)
                        loss = criterion(outputs.logits, labels)
                        loss.backward()
                        optimizer.step()

        # Evaluation
        criterion = torch.nn.CrossEntropyLoss()
        model.eval()
        total_loss, correct, total = 0.0, 0, 0
        with torch.no_grad():
            for batch in testloader:
                print(f"Processing batch of size: {len(batch['input_ids'])}")
                print("Batch name:",batch)
                print(batch["input_ids"].shape)
                
                # GPU acceleration 
                input_ids = move_to_device(batch["input_ids"])
                attention_mask = move_to_device(batch["attention_mask"])
                labels = move_to_device(batch["label"])

                # The actual testing/evaluation
                outputs = model(input_ids, attention_mask=attention_mask)
                total_loss += criterion(outputs.logits, labels).item()
                preds = outputs.logits.argmax(dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
            
        # Handle zero division
        if total == 0:
            print("Warning: No data in testloader.")
            accuracy = 0.0
        else:
            accuracy = correct / total
        # Add new rows with .loc
        results.loc[len(results)] = {
            "Model": model_name,
            "Scenario": scenario["name"],
            "Loss": total_loss / (len(testloader) if len(testloader) > 0 else 1),
            "Accuracy": accuracy,
        }


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Processing batch of size: 8
Batch name: {'label': tensor([0, 0, 0, 0, 0, 0, 0, 0]), 'input_ids': tensor([[  101,  1045,  2293,  ...,     0,     0,     0],
        [  101,  4276,  1996,  ...,     0,     0,     0],
        [  101,  2049,  1037,  ...,     0,     0,     0],
        ...,
        [  101,  1045,  2018,  ...,     0,     0,     0],
        [  101,  7527, 13109,  ...,     0,     0,     0],
        [  101,  2009,  2941,  ...,     0,     0,     0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}
torch.Size([8, 512])
Processing batch of size: 8
Batch name: {'label': tensor([0, 0, 0, 0, 0, 0, 0, 0]), 'input_ids': tensor([[  101, 10892,  1045,  ...,     0,     0,     0],
        [  101,  9826,  9643,  ...,     0,     0,     0],
        [  101,  2023, 17312,  ...,     0,     0,     0],
        ...

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO flwr 2024-12-01 19:34:06,035 | app.py:163 | Starting Flower server, config: ServerConfig(num_rounds=10, round_timeout=None)
INFO flwr 2024-12-01 19:34:06,048 | app.py:176 | Flower ECE: gRPC server running (10 rounds), SSL is disabled
INFO flwr 2024-12-01 19:34:06,048 | server.py:89 | Initializing global parameters
INFO flwr 2024-12-01 19:34:06,048 | server.py:276 | Requesting initial parameters from one random client
INFO flwr 2024-12-01 19:34:10,802 | grpc.py:52 | Opened insecure gRPC connection (no certificates were passed)
12/01/2024 19:34:10:INFO:Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2024-12-01 19:34:10,806 | connection.py:42 | ChannelConnectivity.IDLE


In [None]:
# Convert results to DataFrame
results_df = pd.DataFrame(results)
print(results_df)

# Accuracy visualization
sns.set(style="whitegrid")
plt.figure(figsize=(12, 6))
sns.barplot(data=results_df, x="Scenario", y="Accuracy", hue="Model")
plt.title("Accuracy Across Different Scenarios and Models")
plt.ylabel("Accuracy")
plt.xticks(rotation=45)
plt.legend(title="Model")
plt.tight_layout()
plt.show()

# Loss visualization
plt.figure(figsize=(12, 6))
sns.barplot(data=results_df, x="Scenario", y="Loss", hue="Model")
plt.title("Loss Across Different Scenarios and Models")
plt.ylabel("Loss")
plt.xticks(rotation=45)
plt.legend(title="Model")
plt.tight_layout()
plt.show()