In [1]:
#Imports 
import torch
import copy
import pandas as pd
import numpy as np

from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

from transformers import AutoTokenizer, AutoModelForSequenceClassification

from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix

#Checking Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:
#Dataset Class Definition
class FinSentimentDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts.tolist()
        self.labels = labels.tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.long)
        }

In [3]:
#Loading Tokeniser
tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert")

In [5]:
#Loading Validation Splits
X_tw_val = pd.read_csv("data/splits/twitter_val_text.csv").squeeze()
y_tw_val = pd.read_csv("data/splits/twitter_val_labels.csv").squeeze()

X_news_val = pd.read_csv("data/splits/news_val_text.csv").squeeze()
y_news_val = pd.read_csv("data/splits/news_val_labels.csv").squeeze()

X_reports_val = pd.read_csv("data/splits/reports_val_text.csv").squeeze()
y_reports_val = pd.read_csv("data/splits/reports_val_labels.csv").squeeze()

print("Validation splits loaded")

Validation splits loaded


In [6]:
#DataLoaders
tw_val_loader = DataLoader(
    FinSentimentDataset(X_tw_val, y_tw_val, tokenizer),
    batch_size=16,
    shuffle=False
)

news_val_loader = DataLoader(
    FinSentimentDataset(X_news_val, y_news_val, tokenizer),
    batch_size=16,
    shuffle=False
)

reports_val_loader = DataLoader(
    FinSentimentDataset(X_reports_val, y_reports_val, tokenizer),
    batch_size=16,
    shuffle=False
)

print("Validation loaders ready")

Validation loaders ready


In [7]:
#Loading Full Processed Dataset
twitter = pd.read_csv("data/processed/twitter.csv")
news = pd.read_csv("data/processed/news.csv")
reports = pd.read_csv("data/processed/reports.csv")

print("Full datasets loaded")

Full datasets loaded


In [8]:
#Recreating Train Loaders
from sklearn.model_selection import train_test_split

X_tw_train, _, y_tw_train, _ = train_test_split(
    twitter["text"], twitter["label"], test_size=0.2, random_state=42)

X_news_train, _, y_news_train, _ = train_test_split(
    news["text"], news["label"], test_size=0.2, random_state=42)

X_reports_train, _, y_reports_train, _ = train_test_split(
    reports["text"], reports["label"], test_size=0.2, random_state=42)

tw_train_loader = DataLoader(
    FinSentimentDataset(X_tw_train, y_tw_train, tokenizer),
    batch_size=16,
    shuffle=True
)

news_train_loader = DataLoader(
    FinSentimentDataset(X_news_train, y_news_train, tokenizer),
    batch_size=16,
    shuffle=True
)

reports_train_loader = DataLoader(
    FinSentimentDataset(X_reports_train, y_reports_train, tokenizer),
    batch_size=16,
    shuffle=True
)

print("Train loaders ready")

Train loaders ready


In [9]:
#Defining Evaluation Function
def evaluate_model(model, dataloader):
    model.eval()
    preds, labels = [], []

    with torch.no_grad():
        for batch in dataloader:
            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device)
            )
            pred = torch.argmax(outputs.logits, dim=1).cpu().numpy()
            preds.extend(pred)
            labels.extend(batch["labels"].numpy())

    acc = accuracy_score(labels, preds)
    _, _, f1, _ = precision_recall_fscore_support(labels, preds, average="macro")

    return acc, f1

In [10]:
#Initialising a Global Model
global_model = AutoModelForSequenceClassification.from_pretrained(
    "ProsusAI/finbert",
    num_labels=3
).to(device)

print("Global model initialised")

Global model initialised


In [12]:
#Local Train Function - FedAvg (Standard Fine Tuning)
def local_train(model, train_loader, epochs=3):
    model.train()
    optimizer = AdamW(model.parameters(), lr=2e-5)

    for _ in range(epochs):
        for batch in train_loader:
            optimizer.zero_grad()

            outputs = model(
                input_ids=batch["input_ids"].to(device),
                attention_mask=batch["attention_mask"].to(device),
                labels=batch["labels"].to(device)
            )

            loss = outputs.loss
            loss.backward()
            optimizer.step()

    return model

In [13]:
#Adaptive Aggregation Function
def adaptive_fedavg(global_model, client_models, client_sizes, client_scores):

    global_dict = global_model.state_dict()

    # Weighted score = size Ã— performance
    weighted_scores = [
        client_sizes[i] * client_scores[i]
        for i in range(len(client_models))
    ]

    total_weight = sum(weighted_scores)

    for key in global_dict.keys():
        global_dict[key] = sum(
            weighted_scores[i] * client_models[i].state_dict()[key]
            for i in range(len(client_models))
        ) / total_weight

    global_model.load_state_dict(global_dict)

    return global_model

In [14]:
#Adaptive Federated Training (Multi Round)
ROUNDS = 10
LOCAL_EPOCHS = 3

adaptive_results = []

for r in range(ROUNDS):
    print(f"\n===== Adaptive Round {r+1}/{ROUNDS} =====")

    #Copying global model to clients
    tw_model = copy.deepcopy(global_model)
    news_model = copy.deepcopy(global_model)
    reports_model = copy.deepcopy(global_model)

    #Local training
    tw_model = local_train(tw_model, tw_train_loader, LOCAL_EPOCHS)
    news_model = local_train(news_model, news_train_loader, LOCAL_EPOCHS)
    reports_model = local_train(reports_model, reports_train_loader, LOCAL_EPOCHS)

    #Evaluating each client model on its validation set
    _, tw_f1 = evaluate_model(tw_model, tw_val_loader)
    _, news_f1 = evaluate_model(news_model, news_val_loader)
    _, reports_f1 = evaluate_model(reports_model, reports_val_loader)

    client_scores = [tw_f1, news_f1, reports_f1]
    client_sizes = [len(X_tw_train), len(X_news_train), len(X_reports_train)]

    #Adaptive Aggregation 
    global_model = adaptive_fedavg(
        global_model,
        [tw_model, news_model, reports_model],
        client_sizes,
        client_scores
    )

    #Evaluating updated global model
    _, g_tw = evaluate_model(global_model, tw_val_loader)
    _, g_news = evaluate_model(global_model, news_val_loader)
    _, g_reports = evaluate_model(global_model, reports_val_loader)

    avg_f1 = (g_tw + g_news + g_reports) / 3

    adaptive_results.append([r+1, g_tw, g_news, g_reports, avg_f1])

    print(f"Twitter F1: {g_tw:.4f} | News F1: {g_news:.4f} | Reports F1: {g_reports:.4f}")
    print(f"Average F1: {avg_f1:.4f}")


===== Adaptive Round 1/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.4737 | News F1: 0.5751 | Reports F1: 0.6085
Average F1: 0.5524

===== Adaptive Round 2/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.5963 | News F1: 0.8725 | Reports F1: 0.8020
Average F1: 0.7570

===== Adaptive Round 3/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.6218 | News F1: 0.9210 | Reports F1: 0.8355
Average F1: 0.7928

===== Adaptive Round 4/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.6312 | News F1: 0.9240 | Reports F1: 0.8322
Average F1: 0.7958

===== Adaptive Round 5/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.6351 | News F1: 0.9283 | Reports F1: 0.8349
Average F1: 0.7994

===== Adaptive Round 6/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.6282 | News F1: 0.9718 | Reports F1: 0.8591
Average F1: 0.8197

===== Adaptive Round 7/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.6275 | News F1: 0.9741 | Reports F1: 0.8630
Average F1: 0.8215

===== Adaptive Round 8/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.6306 | News F1: 0.9604 | Reports F1: 0.8586
Average F1: 0.8165

===== Adaptive Round 9/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.6306 | News F1: 0.9649 | Reports F1: 0.8578
Average F1: 0.8178

===== Adaptive Round 10/10 =====


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Twitter F1: 0.6310 | News F1: 0.9752 | Reports F1: 0.8637
Average F1: 0.8233


In [15]:
#Saving Adaptive Aggregation Results

import os
os.makedirs("results/adaptive", exist_ok=True)

adaptive_df = pd.DataFrame(
    adaptive_results,
    columns=["Round","Twitter_F1","News_F1","Reports_F1","Avg_F1"]
)

adaptive_df.to_csv("results/adaptive/adaptive_fedavg_10_rounds.csv", index=False)

print("Results saved successfully")
adaptive_df.tail()

Results saved successfully


Unnamed: 0,Round,Twitter_F1,News_F1,Reports_F1,Avg_F1
5,6,0.62815,0.971818,0.859079,0.819682
6,7,0.627504,0.974073,0.862977,0.821518
7,8,0.630613,0.960372,0.858615,0.816534
8,9,0.630588,0.964942,0.857829,0.817786
9,10,0.630984,0.975198,0.863701,0.823294


In [16]:
#Saving Final Adaptive Global Model
os.makedirs("models/adaptive_fedavg", exist_ok=True)

global_model.save_pretrained("models/adaptive_fedavg")
tokenizer.save_pretrained("models/adaptive_fedavg")

print("Adaptive FedAvg model saved")

Adaptive FedAvg model saved


In [17]:
adaptive_df

Unnamed: 0,Round,Twitter_F1,News_F1,Reports_F1,Avg_F1
0,1,0.473695,0.575124,0.608521,0.552446
1,2,0.596299,0.87255,0.80201,0.756953
2,3,0.621846,0.921038,0.83553,0.792805
3,4,0.631242,0.924029,0.832222,0.795831
4,5,0.635147,0.928296,0.834875,0.799439
5,6,0.62815,0.971818,0.859079,0.819682
6,7,0.627504,0.974073,0.862977,0.821518
7,8,0.630613,0.960372,0.858615,0.816534
8,9,0.630588,0.964942,0.857829,0.817786
9,10,0.630984,0.975198,0.863701,0.823294
