In [1]:
#!/usr/bin/env python3
# chatGPT

import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, Sampler
from transformers import AutoModel, AutoTokenizer
from sklearn.model_selection import train_test_split
import os
import random
from itertools import cycle
import time
from torch.cuda.amp import autocast, GradScaler

# Seed every RNG the script draws from: torch (weight init, dropout) and
# Python's `random`, which BalancedBatchSampler uses to shuffle indices.
# torch.manual_seed alone leaves the batch order different on every run.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
# Let cuDNN benchmark and cache the fastest kernels per input shape; note this
# can make results non-deterministic across runs.
torch.backends.cudnn.benchmark = True

# Check GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")

# Load datasets (both CSVs are expected to provide 'text' and 'class' columns,
# which SuicideDataset reads).
english_df = pd.read_csv("/kaggle/input/english-ds-binary/BinaryCsv_English_dataset_UTF8.csv", encoding='utf-8')
telugu_df = pd.read_csv("/kaggle/input/telugu-ds-binary/Binary_telugu_dataset.csv", encoding='utf-8')

# Add domain labels: 0 = English, 1 = Telugu. These are the targets of the
# adversarial domain classifier.
english_df['domain'] = 0
telugu_df['domain'] = 1

# Combine datasets. ignore_index gives every row a unique label; without it the
# two frames' RangeIndexes overlap and labels 0, 1, ... appear twice.
full_df = pd.concat([english_df, telugu_df], ignore_index=True)

# 80/20 split stratified jointly on (class, domain), so both splits keep the
# same label mix within each language. random_state makes the split repeatable.
train_df, val_df = train_test_split(
    full_df,
    test_size=0.2,
    random_state=SEED,
    stratify=full_df[['class', 'domain']],
)

# Custom Dataset class
class SuicideDataset(Dataset):
    """Map-style dataset of texts with a task label and a domain label.

    The ``text``, ``class`` and ``domain`` columns of ``dataframe`` are copied
    into Python lists once. Tokenization is lazy: one sample per
    ``__getitem__`` call, padded/truncated to exactly ``max_length`` tokens.

    Args:
        dataframe: frame holding ``text``, ``class`` and ``domain`` columns.
        tokenizer: callable tokenizer accepting ``return_tensors``,
            ``max_length``, ``truncation`` and ``padding`` keyword arguments
            and returning a mapping with ``input_ids`` / ``attention_mask``.
        max_length: fixed sequence length of every encoded sample.
    """

    def __init__(self, dataframe, tokenizer, max_length=128):
        self.texts, self.labels, self.domains = (
            dataframe[column].tolist() for column in ('text', 'class', 'domain')
        )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its tensors.

        Returns a dict with ``input_ids`` and ``attention_mask`` (size-1
        dimensions squeezed away; presumably the tokenizer's leading batch
        dimension of 1) plus scalar ``torch.long`` tensors ``labels`` and
        ``domains``.
        """
        encoded = self.tokenizer(
            str(self.texts[idx]),  # coerce non-string cells (e.g. NaN) to text
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        item = {name: encoded[name].squeeze() for name in ('input_ids', 'attention_mask')}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        item['domains'] = torch.tensor(self.domains[idx], dtype=torch.long)
        return item

# Balanced Batch Sampler
class BalancedBatchSampler(Sampler):
    """Yield index batches with an equal number of English and Telugu samples.

    Every batch is ``batch_size_per_lang`` English indices (domain 0) followed
    by ``batch_size_per_lang`` Telugu indices (domain 1). One epoch is a single
    pass over the shuffled English indices; a trailing partial English batch
    is dropped. Telugu indices come from a fresh shuffle each epoch and wrap
    around when the epoch needs more Telugu samples than exist.

    Args:
        dataset: object exposing a ``domains`` sequence (0 = English,
            1 = Telugu), such as ``SuicideDataset``.
        batch_size_per_lang: number of indices drawn from each language per batch.

    Raises:
        ValueError: if ``batch_size_per_lang`` is not positive, or if at least
            one batch would be produced but there are no Telugu samples.
    """

    def __init__(self, dataset, batch_size_per_lang):
        if batch_size_per_lang <= 0:
            raise ValueError(f"batch_size_per_lang must be positive, got {batch_size_per_lang}")
        self.eng_indices = [i for i, domain in enumerate(dataset.domains) if domain == 0]
        self.tel_indices = [i for i, domain in enumerate(dataset.domains) if domain == 1]
        self.batch_size_per_lang = batch_size_per_lang
        self.total_batches = len(self.eng_indices) // batch_size_per_lang
        # Fail early: otherwise next() on an empty cycle raises StopIteration
        # inside __iter__, which surfaces as an opaque RuntimeError mid-epoch.
        if self.total_batches and not self.tel_indices:
            raise ValueError("BalancedBatchSampler needs at least one Telugu (domain 1) sample")

    def __iter__(self):
        random.shuffle(self.eng_indices)
        # Shuffle Telugu as well. Without this, every epoch pairs English with
        # the same fixed Telugu order. And if Telugu outnumbers the
        # English-driven quota, the tail of the Telugu list is never sampled.
        random.shuffle(self.tel_indices)
        tel_iter = cycle(self.tel_indices)
        size = self.batch_size_per_lang
        for i in range(self.total_batches):
            eng_batch = self.eng_indices[i * size:(i + 1) * size]
            tel_batch = [next(tel_iter) for _ in range(size)]
            yield eng_batch + tel_batch

    def __len__(self):
        return self.total_batches

# Gradient Reversal Layer
class GradientReversal(torch.autograd.Function):
    """Gradient reversal layer.

    Forward pass: returns the input unchanged. Backward pass: multiplies the
    incoming gradient by ``-lambda_``. No gradient is produced for ``lambda_``.
    Use via ``GradientReversal.apply(x, lambda_)``.
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # lambda_ is only needed as a multiplier in backward, so stash it on
        # ctx directly (assumed to be a Python scalar, not a tensor).
        ctx.lambda_ = lambda_
        return x

    @staticmethod
    def backward(ctx, grad_output):
        reversed_grad = grad_output.neg() * ctx.lambda_
        return reversed_grad, None

# DAT Model
class DATModel(nn.Module):
    """Domain-adversarial classifier on top of a transformer encoder.

    The encoder's first-token hidden state (presumably the [CLS] position for
    BERT-style tokenizers) goes through dropout and then into two linear heads.
    One is a task head. The other is a domain head placed behind
    ``GradientReversal``, so gradients it sends back to the encoder are
    negated and scaled by ``lambda_`` (the DANN setup).

    Args:
        base_model: encoder called as
            ``base_model(input_ids=..., attention_mask=...)`` whose output
            exposes ``last_hidden_state``.
        hidden_size: width of that hidden state; must match the encoder.
        num_classes: number of task labels.
        num_domains: number of domain labels.
    """

    def __init__(self, base_model, hidden_size=768, num_classes=2, num_domains=2):
        super().__init__()
        # Registration order is kept so parameter init order and state_dict
        # key order stay the same.
        self.base_model = base_model
        self.task_classifier = nn.Linear(hidden_size, num_classes)
        self.domain_classifier = nn.Linear(hidden_size, num_domains)
        self.dropout = nn.Dropout(0.1)

    def forward(self, input_ids, attention_mask, lambda_=1.0):
        """Return ``(task_logits, domain_logits)``.

        Shapes are (batch, num_classes) and (batch, num_domains).
        ``lambda_`` is the gradient-reversal strength for the domain branch.
        """
        hidden = self.base_model(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        pooled = self.dropout(hidden[:, 0, :])
        reversed_pooled = GradientReversal.apply(pooled, lambda_)
        return self.task_classifier(pooled), self.domain_classifier(reversed_pooled)

# Load model and tokenizer
MODEL_NAME = "google/muril-base-cased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
base_model = AutoModel.from_pretrained(MODEL_NAME)
# Size the heads from the encoder's own config instead of hard-coding 768, so
# swapping in a different checkpoint cannot silently mismatch the Linear layers.
model = DATModel(base_model, hidden_size=base_model.config.hidden_size).to(device)

if torch.cuda.device_count() > 1:
    # DataParallel splits each batch along dim 0 across GPUs. NOTE: batches
    # from BalancedBatchSampler are English-first, so each replica may see a
    # single language.
    model = nn.DataParallel(model)

# Create datasets
train_dataset = SuicideDataset(train_df, tokenizer)
val_dataset = SuicideDataset(val_df, tokenizer)

# DataLoader settings
batch_size_per_lang = 64  # 64 English + 64 Telugu = 128 samples per batch
train_sampler = BalancedBatchSampler(train_dataset, batch_size_per_lang)
train_loader = DataLoader(train_dataset, batch_sampler=train_sampler, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

# Loss functions
criterion_task = nn.CrossEntropyLoss()
criterion_domain = nn.CrossEntropyLoss()

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# AMP (Mixed Precision Training). Loss scaling only applies on CUDA; enabling it
# explicitly there avoids the "CUDA is not available" warning on CPU runs.
scaler = GradScaler(enabled=device.type == "cuda")

# Training loop
num_epochs = 3
lambda_ = 1.0  # gradient-reversal strength for the domain branch

# Autocast and loss scaling only apply on CUDA; on CPU they are disabled.
use_amp = device.type == "cuda"


def _to_device(batch):
    """Return (input_ids, attention_mask, labels, domains) from a SuicideDataset batch, moved to ``device``."""
    return tuple(batch[key].to(device) for key in ('input_ids', 'attention_mask', 'labels', 'domains'))


def _evaluate(loader):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(task_loss, domain_loss, task_accuracy)``. Losses are averaged per
        sample: each batch's mean loss is weighted by its size, so a short
        final batch does not skew the result. An empty loader yields NaNs
        instead of raising ZeroDivisionError.

    Leaves the model in eval mode; the training loop calls ``model.train()``
    at the start of every epoch.
    """
    model.eval()
    task_loss_sum = 0.0
    domain_loss_sum = 0.0
    correct = 0
    samples = 0
    with torch.no_grad():
        for batch in loader:
            input_ids, attention_mask, labels, domains = _to_device(batch)
            with torch.autocast(device_type=device.type, enabled=use_amp):
                task_logits, domain_logits = model(input_ids, attention_mask, lambda_)
                loss_task = criterion_task(task_logits, labels)
                loss_domain = criterion_domain(domain_logits, domains)
            batch_n = labels.size(0)
            task_loss_sum += loss_task.item() * batch_n
            domain_loss_sum += loss_domain.item() * batch_n
            correct += (task_logits.argmax(dim=1) == labels).sum().item()
            samples += batch_n
    if samples == 0:
        return float('nan'), float('nan'), float('nan')
    return task_loss_sum / samples, domain_loss_sum / samples, correct / samples


# The Telugu slice of the validation split never changes, so build its loader
# once instead of re-creating the dataset and worker pool every epoch.
telugu_val_df = val_df[val_df['domain'] == 1]
telugu_val_dataset = SuicideDataset(telugu_val_df, tokenizer)
telugu_val_loader = DataLoader(telugu_val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

for epoch in range(num_epochs):
    model.train()
    total_batches = len(train_loader)

    # Initialize epoch metrics
    total_task_loss = 0.0
    total_domain_loss = 0.0
    total_correct_task = 0
    total_samples = 0
    start_time = time.time()

    for batch_idx, batch in enumerate(train_loader, 1):
        input_ids, attention_mask, labels, domains = _to_device(batch)

        optimizer.zero_grad()

        with torch.autocast(device_type=device.type, enabled=use_amp):
            task_logits, domain_logits = model(input_ids, attention_mask, lambda_)
            task_loss = criterion_task(task_logits, labels)
            domain_loss = criterion_domain(domain_logits, domains)
            # lambda_ is already applied inside the model by the gradient-
            # reversal layer. Weighting the domain loss by it as well would give
            # the encoder an effective -lambda_**2 gradient and scale the domain
            # classifier's own gradient by lambda_ (identical only for 1.0).
            total_loss = task_loss + domain_loss

        # Scale loss and backpropagate. scaler.step skips the optimizer step if
        # the gradients contain infs/NaNs; scaler.update adjusts the scale.
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Update epoch metrics
        total_task_loss += task_loss.item()
        total_domain_loss += domain_loss.item()
        predictions = torch.argmax(task_logits, dim=1)
        total_correct_task += torch.sum(predictions == labels).item()
        total_samples += labels.size(0)

        # Print metrics every 100 batches
        if batch_idx % 100 == 0:
            progress = (batch_idx / total_batches) * 100
            print(f"Epoch {epoch+1}, Batch {batch_idx}/{total_batches} ({progress:.2f}%), "
                  f"Task Loss: {task_loss.item():.4f}, Domain Loss: {domain_loss.item():.4f}, "
                  f"Task Accuracy: {(predictions == labels).float().mean():.4f}")

    # Epoch summary metrics. BalancedBatchSampler makes every training batch the
    # same size, so the per-batch mean of losses equals the per-sample mean.
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    epoch_task_loss = total_task_loss / max(total_batches, 1)
    epoch_domain_loss = total_domain_loss / max(total_batches, 1)
    epoch_task_accuracy = total_correct_task / max(total_samples, 1)
    epoch_time = time.time() - start_time
    print(f"Epoch {epoch+1} Completed: Task Loss: {epoch_task_loss:.4f}, "
          f"Domain Loss: {epoch_domain_loss:.4f}, Task Accuracy: {epoch_task_accuracy:.4f}, "
          f"Time: {epoch_time:.2f}s")

    # Overall Validation metrics calculation (on entire 20% split)
    val_task_loss, val_domain_loss, val_accuracy = _evaluate(val_loader)
    print(f"Validation (Overall) - Task Loss: {val_task_loss:.4f}, Domain Loss: {val_domain_loss:.4f}, "
          f"Task Accuracy: {val_accuracy:.4f}")

    # Telugu-specific Validation metrics calculation (domain == 1 rows only)
    telugu_val_task_loss, telugu_val_domain_loss, telugu_val_accuracy = _evaluate(telugu_val_loader)
    print(f"Validation (Telugu Only) - Task Loss: {telugu_val_task_loss:.4f}, Domain Loss: {telugu_val_domain_loss:.4f}, "
          f"Task Accuracy: {telugu_val_accuracy:.4f}")

    # Save model checkpoint. Unwrap nn.DataParallel so the keys are not prefixed
    # with "module." and the file loads straight into a bare DATModel.
    torch.save(getattr(model, 'module', model).state_dict(), f'model_epoch_{epoch+1}.pth')


Using device: cuda
Using 2 GPUs


tokenizer_config.json:   0%|          | 0.00/206 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/3.16M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/113 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/953M [00:00<?, ?B/s]

  scaler = GradScaler()
  with autocast():


Epoch 1, Batch 100/2916 (3.43%), Task Loss: 0.6274, Domain Loss: 0.8337, Task Accuracy: 0.7969
Epoch 1, Batch 200/2916 (6.86%), Task Loss: 0.5736, Domain Loss: 0.8170, Task Accuracy: 0.8516
Epoch 1, Batch 300/2916 (10.29%), Task Loss: 0.4981, Domain Loss: 0.7931, Task Accuracy: 0.9453
Epoch 1, Batch 400/2916 (13.72%), Task Loss: 0.4731, Domain Loss: 0.7795, Task Accuracy: 0.9141
Epoch 1, Batch 500/2916 (17.15%), Task Loss: 0.4030, Domain Loss: 0.7628, Task Accuracy: 0.9453
Epoch 1, Batch 600/2916 (20.58%), Task Loss: 0.3735, Domain Loss: 0.7543, Task Accuracy: 0.9297
Epoch 1, Batch 700/2916 (24.01%), Task Loss: 0.3282, Domain Loss: 0.7429, Task Accuracy: 0.9219
Epoch 1, Batch 800/2916 (27.43%), Task Loss: 0.3229, Domain Loss: 0.7331, Task Accuracy: 0.9219
Epoch 1, Batch 900/2916 (30.86%), Task Loss: 0.2317, Domain Loss: 0.7188, Task Accuracy: 0.9531
Epoch 1, Batch 1000/2916 (34.29%), Task Loss: 0.2239, Domain Loss: 0.7156, Task Accuracy: 0.9531
Epoch 1, Batch 1100/2916 (37.72%), Task L

  with autocast():


Validation (Overall) - Task Loss: 0.1029, Domain Loss: 0.6933, Task Accuracy: 0.9651


  with autocast():


Validation (Telugu Only) - Task Loss: 0.1848, Domain Loss: 0.6930, Task Accuracy: 0.9318
Epoch 2, Batch 100/2916 (3.43%), Task Loss: 0.0664, Domain Loss: 0.6932, Task Accuracy: 0.9844
Epoch 2, Batch 200/2916 (6.86%), Task Loss: 0.0361, Domain Loss: 0.6932, Task Accuracy: 0.9922
Epoch 2, Batch 300/2916 (10.29%), Task Loss: 0.0866, Domain Loss: 0.6930, Task Accuracy: 0.9609
Epoch 2, Batch 400/2916 (13.72%), Task Loss: 0.1068, Domain Loss: 0.6933, Task Accuracy: 0.9609
Epoch 2, Batch 500/2916 (17.15%), Task Loss: 0.1034, Domain Loss: 0.6926, Task Accuracy: 0.9766
Epoch 2, Batch 600/2916 (20.58%), Task Loss: 0.0625, Domain Loss: 0.6933, Task Accuracy: 0.9922
Epoch 2, Batch 700/2916 (24.01%), Task Loss: 0.0245, Domain Loss: 0.6927, Task Accuracy: 1.0000
Epoch 2, Batch 800/2916 (27.43%), Task Loss: 0.1456, Domain Loss: 0.6932, Task Accuracy: 0.9453
Epoch 2, Batch 900/2916 (30.86%), Task Loss: 0.0633, Domain Loss: 0.6922, Task Accuracy: 0.9766
Epoch 2, Batch 1000/2916 (34.29%), Task Loss: 0.0