In [None]:
# Upgrade transformers to the newest available release before the script below
# imports it. NOTE(review): presumably required for ModernBERT support — confirm
# and consider pinning a version for reproducibility.
!pip install --upgrade transformers

In [None]:
import os

# os.environ is a mapping (subscript it, do not call it) and its values must
# be strings. Child processes started from this notebook inherit the setting.
os.environ["OMP_NUM_THREADS"] = "1"

In [None]:
%%writefile ddp_modernbert_bilstm.py
#!/usr/bin/env python
import os
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import BCEWithLogitsLoss
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# -----------------------------
# Dataset Processing with Event Context
# -----------------------------
class BiasDataset(Dataset):
    """Row-wise dataset that tokenizes four text columns plus a float label.

    Each item holds ``input_ids`` / ``attention_mask`` tensors for the
    ``sentence``, ``article_text``, ``ev_1`` and ``ev_2`` columns of one
    dataframe row, all padded/truncated to ``max_length``, and the row's
    ``label`` as a float32 tensor.
    """

    # (output-key prefix, dataframe column), listed in emission order.
    _TEXT_FIELDS = (
        ("sentence", "sentence"),
        ("article", "article_text"),
        ("ev1", "ev_1"),
        ("ev2", "ev_2"),
    )

    def __init__(self, df, tokenizer, max_length=1024):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.df)

    def _encode(self, text):
        """Tokenize one string to fixed-length PyTorch tensors."""
        return self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        target = torch.tensor(record["label"], dtype=torch.float32)

        item = {}
        for prefix, column in self._TEXT_FIELDS:
            encoded = self._encode(record[column])
            # The tokenizer returns a leading batch axis of size 1; drop it so
            # the DataLoader can stack items itself.
            item[f"{prefix}_input_ids"] = encoded["input_ids"].squeeze(0)
            item[f"{prefix}_attention_mask"] = encoded["attention_mask"].squeeze(0)

        item["label"] = target
        return item

# -----------------------------
# Model: ModernBERT-BiLSTM with Event Context
# -----------------------------
class ModernBertBiLSTM(nn.Module):
    """ModernBERT encoder followed by a shared BiLSTM and an MLP classifier.

    Four texts (sentence, its article, and two event-context articles) are
    each encoded by the same backbone and BiLSTM; the four resulting vectors
    are concatenated and mapped to ``num_classes`` logits.

    Args:
        hidden_dim: Hidden size of each LSTM direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Size of the classifier output.
        fine_tune_layers: How many of the backbone's final ``layers`` stay
            trainable. A non-positive value keeps the whole backbone frozen.
    """

    def __init__(self, hidden_dim=512, num_layers=2, num_classes=1, fine_tune_layers=2):
        super().__init__()
        self.bert = AutoModel.from_pretrained("answerdotai/modernBERT-base")
        self._freeze_backbone(self.bert, fine_tune_layers)

        self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size, hidden_size=hidden_dim,
                            num_layers=num_layers, batch_first=True, bidirectional=True)

        # Each text yields 2 * hidden_dim features (one final state per LSTM
        # direction) and four texts are concatenated: 4 * 2 * hidden_dim.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 8, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    @staticmethod
    def _freeze_backbone(bert, fine_tune_layers):
        """Freeze every backbone parameter, then re-enable the last N layers."""
        for param in bert.parameters():
            param.requires_grad = False

        # ``layers[-0:]`` is the full list, so without this guard a request
        # for zero trainable layers would unfreeze the entire backbone.
        if fine_tune_layers > 0:
            for param in bert.layers[-fine_tune_layers:].parameters():
                param.requires_grad = True

    def encode_text(self, input_ids, attention_mask):
        """Return one ``2 * hidden_dim`` vector per sequence in the batch."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        # NOTE(review): attention_mask is not applied to the LSTM, so when the
        # inputs are padded the final states also cover pad positions —
        # consider pack_padded_sequence; confirm before changing.
        _, (hidden, _) = self.lstm(last_hidden_state)
        # For a bidirectional LSTM the last two entries of ``hidden`` are the
        # top layer's forward and backward final states.
        embedding = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        return embedding

    def forward(self, article_input_ids, article_attention_mask, sentence_input_ids, sentence_attention_mask,
                ev1_input_ids, ev1_attention_mask, ev2_input_ids, ev2_attention_mask):
        """Return logits with the trailing class axis squeezed out."""
        primary_embedding = self.encode_text(article_input_ids, article_attention_mask)
        sentence_embedding = self.encode_text(sentence_input_ids, sentence_attention_mask)
        ev1_embedding = self.encode_text(ev1_input_ids, ev1_attention_mask)
        ev2_embedding = self.encode_text(ev2_input_ids, ev2_attention_mask)

        combined_features = torch.cat([sentence_embedding, primary_embedding, ev1_embedding, ev2_embedding], dim=1)
        logits = self.classifier(combined_features).squeeze(1)
        return logits
# -----------------------------
# Loss Function
# -----------------------------
def get_weighted_loss(train_df, device):
    """Build a BCE-with-logits loss whose positive class is up-weighted.

    The weight is ``num_negative / num_positive`` computed from the ``label``
    column of ``train_df``. When either class is absent that ratio is
    undefined (division by zero) or zero (which would cancel the loss of every
    positive sample), so the weight falls back to 1.0, i.e. an unweighted loss.

    Args:
        train_df: DataFrame with a binary ``label`` column.
        device: Device the weight tensor is moved to.

    Returns:
        A ``BCEWithLogitsLoss`` carrying a scalar ``pos_weight`` tensor.
    """
    num_pos = train_df['label'].sum()
    num_neg = len(train_df) - num_pos

    if num_pos > 0 and num_neg > 0:
        ratio = num_neg / num_pos
    else:
        ratio = 1.0

    pos_weight = torch.tensor(ratio, dtype=torch.float32).to(device)
    return BCEWithLogitsLoss(pos_weight=pos_weight)

# -----------------------------
# Evaluation Function
# -----------------------------
def evaluate(model, dataloader, criterion, device):
    """Score ``model`` on ``dataloader`` and return loss plus binary metrics.

    When a multi-process ``torch.distributed`` group is active, each rank only
    iterates its own shard of the data, so the per-rank losses, predictions
    and labels are gathered from all ranks before the metrics are computed.
    Every rank therefore returns the same, full-dataset numbers. Because a
    collective is involved, all ranks must call this function together. A
    sampler that pads shards to equal length may contribute a few duplicated
    samples to the gathered set.

    Args:
        model: Module called with the eight input tensors positionally.
        dataloader: Yields dict batches with the input tensors and ``label``.
        criterion: Loss taking ``(logits, labels)``.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(mean_batch_loss, accuracy, precision, recall, f1)``.
    """
    # Positional order expected by the model's forward().
    input_keys = (
        "article_input_ids", "article_attention_mask",
        "sentence_input_ids", "sentence_attention_mask",
        "ev1_input_ids", "ev1_attention_mask",
        "ev2_input_ids", "ev2_attention_mask",
    )

    model.eval()
    total_loss = 0.0
    num_batches = len(dataloader)
    predictions, true_labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            inputs = [batch[key].to(device) for key in input_keys]
            labels = batch["label"].to(device)

            logits = model(*inputs)
            total_loss += criterion(logits, labels).item()

            # Threshold the sigmoid at 0.5 to obtain hard 0/1 predictions.
            predictions.extend(torch.sigmoid(logits).round().cpu().tolist())
            true_labels.extend(labels.cpu().tolist())

    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # Object gather tolerates ranks holding different numbers of samples.
        gathered = [None] * dist.get_world_size()
        dist.all_gather_object(gathered, (total_loss, num_batches, predictions, true_labels))
        total_loss = sum(part[0] for part in gathered)
        num_batches = sum(part[1] for part in gathered)
        predictions = [value for part in gathered for value in part[2]]
        true_labels = [value for part in gathered for value in part[3]]

    acc = accuracy_score(true_labels, predictions)
    precision, recall, f1, _ = precision_recall_fscore_support(true_labels, predictions, average="binary", zero_division=0)
    return total_loss / num_batches, acc, precision, recall, f1

# -----------------------------
# Training Function
# -----------------------------
def train_one_epoch(model, dataloader, optimizer, criterion, device, grad_accum_steps=1):
    """Run one pass over ``dataloader`` with gradient accumulation.

    Each batch loss is divided by ``grad_accum_steps`` before ``backward`` so
    that accumulated gradients average over the group. Gradients are clipped
    to a max norm of 1.0 and applied every ``grad_accum_steps`` batches, and
    once more on the final batch so a trailing partial group is not dropped.

    Returns:
        The mean (unscaled) batch loss over the epoch.
    """
    # Positional order expected by the model's forward().
    feature_keys = (
        "article_input_ids", "article_attention_mask",
        "sentence_input_ids", "sentence_attention_mask",
        "ev1_input_ids", "ev1_attention_mask",
        "ev2_input_ids", "ev2_attention_mask",
    )
    num_batches = len(dataloader)

    model.train()
    running_loss = 0
    optimizer.zero_grad()
    progress = tqdm(enumerate(dataloader), total=num_batches, desc="Training", leave=False)

    for step, batch in progress:
        features = [batch[key].to(device) for key in feature_keys]
        targets = batch["label"].to(device)

        scaled_loss = criterion(model(*features), targets) / grad_accum_steps
        scaled_loss.backward()

        # Undo the accumulation scaling for reporting purposes.
        batch_loss = scaled_loss.item() * grad_accum_steps
        running_loss += batch_loss

        batches_seen = step + 1
        is_update_step = batches_seen % grad_accum_steps == 0 or batches_seen == num_batches
        if is_update_step:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        progress.set_postfix(loss=batch_loss)

    return running_loss / num_batches

# -----------------------------
# Main DDP Training Function
# -----------------------------
def _load_train_val_frames():
    """Read the BASIL spreadsheets and return ``(train_df, val_df)``.

    Sentences are joined to their article text, labelled from the ``inf``
    column, and given up to two other articles of the same event as
    ``ev_1`` / ``ev_2``. Rows lacking a sentence or either context article are
    dropped. The split is by ``event_id`` so no event spans both frames.
    """
    articles_df = pd.read_excel("/kaggle/input/d/jahnavimurali/basil-dataset/articles.xlsx")
    bias_df = pd.read_excel("/kaggle/input/d/jahnavimurali/basil-dataset/labeled_dataset.xlsx")
    bias_df.dropna(subset=['sentence'], inplace=True)

    merged_df = bias_df.merge(articles_df[['event_id', 'source', 'article_text']], on=['event_id', 'source'], how="left")
    merged_df['label'] = merged_df['inf']

    articles_grouped = articles_df.groupby("event_id")["article_text"].apply(list).to_dict()

    def find_secondary_articles(row):
        all_articles = articles_grouped.get(row["event_id"], [])
        primary_article = row["article_text"]
        secondary_articles = [article for article in all_articles if article != primary_article]
        # Fewer than two alternatives -> (None, None) so the row is dropped below.
        return secondary_articles[:2] if len(secondary_articles) >= 2 else (None, None)

    ev_1_articles, ev_2_articles = zip(*merged_df.apply(find_secondary_articles, axis=1))
    merged_df["ev_1"], merged_df["ev_2"] = ev_1_articles, ev_2_articles
    merged_df = merged_df.dropna(subset=['sentence', 'ev_1', 'ev_2'])

    # 90% of events train; half of the remaining 10% validate, the rest is unused here.
    unique_events = merged_df['event_id'].unique()
    train_ids, rem_ids = train_test_split(unique_events, test_size=0.1, random_state=42)
    val_ids, _ = train_test_split(rem_ids, test_size=0.5, random_state=42)

    train_df = merged_df[merged_df['event_id'].isin(train_ids)].reset_index(drop=True)
    val_df = merged_df[merged_df['event_id'].isin(val_ids)].reset_index(drop=True)
    return train_df, val_df


def train_ddp(args):
    """Train and validate the model under DistributedDataParallel.

    Expects to be launched by a tool that sets ``LOCAL_RANK`` and the other
    process-group environment variables (e.g. torchrun). ``LOCAL_RANK`` only
    selects the GPU on this node; data sharding, logging and checkpointing use
    the global rank so that multi-node runs shard correctly and write a single
    set of checkpoints.

    Args:
        args: Parsed CLI namespace providing ``seed``, ``max_length``,
            ``batch_size``, ``num_workers``, ``hidden_dim``, ``num_layers``,
            ``lr``, ``weight_decay``, ``epochs``, ``grad_accum_steps`` and
            ``output_dir``.
    """
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    dist.init_process_group(backend="nccl")

    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        is_main = rank == 0

        torch.manual_seed(args.seed)

        tokenizer = AutoTokenizer.from_pretrained("answerdotai/modernBERT-base")

        # Load Data
        train_df, val_df = _load_train_val_frames()

        if is_main:
            print(f"Train: {len(train_df)} samples, Val: {len(val_df)} samples")

        train_dataset = BiasDataset(train_df, tokenizer, max_length=args.max_length)
        val_dataset = BiasDataset(val_df, tokenizer, max_length=args.max_length)

        # The sampler needs the global rank; the local rank repeats on every node.
        train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)

        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.num_workers, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, sampler=val_sampler, num_workers=args.num_workers, pin_memory=True)

        # Honour the model/optimizer CLI options instead of silently ignoring them.
        model = ModernBertBiLSTM(hidden_dim=args.hidden_dim, num_layers=args.num_layers,
                                 fine_tune_layers=2).to(device)
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

        criterion = get_weighted_loss(train_df, device)
        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs)

        for epoch in range(1, args.epochs + 1):
            # Re-seed the sampler so each epoch sees a different shuffle.
            train_sampler.set_epoch(epoch)
            train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, args.grad_accum_steps)
            scheduler.step()

            # Perform validation after each epoch (on every rank)
            val_loss, val_acc, val_prec, val_rec, val_f1 = evaluate(model, val_loader, criterion, device)

            # Only global rank 0 prints results and saves checkpoints
            if is_main:
                print(f"\nEpoch {epoch}:")
                print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
                print(f"Val Accuracy: {val_acc:.4f} | Precision: {val_prec:.4f} | Recall: {val_rec:.4f} | F1-score: {val_f1:.4f}")

                # Save the unwrapped module so the checkpoint loads without DDP.
                checkpoint = {
                    "model_state_dict": model.module.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "scheduler_state_dict": scheduler.state_dict(),
                    "epoch": epoch
                }
                checkpoint_path = os.path.join(args.output_dir, f"model_ddp_epoch_{epoch}.pth")
                torch.save(checkpoint, checkpoint_path)
                print(f"Checkpoint saved at {checkpoint_path}")
    finally:
        # Tear the process group down even when training raises.
        dist.destroy_process_group()

# -----------------------------
# Main function and argument parsing
# -----------------------------
def build_arg_parser():
    """Return the command-line parser for the DDP training script."""
    parser = argparse.ArgumentParser(description="DDP Training for ModernBERT+BiLSTM with Event Context")
    parser.add_argument("--output_dir", type=str, default="./checkpoints", help="Directory to save checkpoints")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size per GPU")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for DataLoader")
    parser.add_argument("--epochs", type=int, default=5, help="Total number of epochs")
    parser.add_argument("--lr", type=float, default=1e-5, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay")
    parser.add_argument("--max_length", type=int, default=1024, help="Maximum sequence length")
    parser.add_argument("--hidden_dim", type=int, default=512, help="Hidden dimension for BiLSTM")
    parser.add_argument("--num_layers", type=int, default=2, help="Number of BiLSTM layers")
    parser.add_argument("--grad_accum_steps", type=int, default=4, help="Gradient accumulation steps")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    return parser


def main():
    """Parse CLI options, ensure the checkpoint directory exists, and train."""
    args = build_arg_parser().parse_args()

    # Every rank runs this; exist_ok makes the concurrent creation harmless.
    os.makedirs(args.output_dir, exist_ok=True)
    train_ddp(args)

# Start training only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

Overwriting ddp_modernbert_bilstm.py


In [5]:
# Launch the training script with two worker processes on this node; the flags
# below are options defined by the script's argparse parser.
# NOTE(review): --output_dir is a relative path ("./kaggle/..."), so checkpoints
# are written beneath the current working directory rather than the absolute
# /kaggle/working — confirm this is intended.
!torchrun --nproc_per_node=2 ddp_modernbert_bilstm.py \
    --output_dir ./kaggle/working/checkpoints \
    --batch_size 8 \
    --epochs 5 \
    --grad_accum_steps 4


W0221 09:19:04.540000 236 torch/distributed/run.py:793] 
W0221 09:19:04.540000 236 torch/distributed/run.py:793] *****************************************
W0221 09:19:04.540000 236 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0221 09:19:04.540000 236 torch/distributed/run.py:793] *****************************************
Train: 7058 samples, Val: 458 samples
Train: 7058 samples, Val: 458 samples
2025-02-21 09:19:12.250710: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-02-21 09:19:12.274255: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already 