# DistilBERT Model: Disaster Tweet Classification

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset, RandomSampler
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from torch.optim import AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from tqdm import tqdm
import os

In [None]:
# Shared tokenizer for the uncased DistilBERT checkpoint (lower-casing enabled).
# It is used by tokenize_texts() and handed to save_model() whenever a checkpoint
# is written.
tokenizer = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased',
    do_lower_case=True,
)

# Function to save model checkpoints
def save_model(model, tokenizer, output_dir, epoch):
    """Save ``model`` and ``tokenizer`` to ``output_dir`` via their ``save_pretrained``.

    Args:
        model: Object exposing ``save_pretrained(path)`` (e.g. a Hugging Face model).
        tokenizer: Object exposing ``save_pretrained(path)``.
        output_dir: Destination directory; created (including parents) if missing.
        epoch: Zero-based epoch index, reported 1-based in the log message. A
            negative value marks a save that is not tied to an epoch (e.g. the
            final save), and the epoch is then omitted from the message.
    """
    # exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    if epoch >= 0:
        print(f"Model saved at epoch {epoch+1} to {output_dir}")
    else:
        print(f"Model saved to {output_dir}")

# Training function with progress bar
def train_model(model, train_dataloader, optimizer, device, epoch, num_epochs, max_grad_norm=1.0):
    """Run one training epoch and return ``(avg_loss, accuracy)``.

    Args:
        model: Classification model called as
            ``model(input_ids, attention_mask=..., labels=...)``; its output must
            expose ``.loss`` and ``.logits``.
        train_dataloader: Iterable of batches indexable as
            ``(input_ids, attention_mask, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch tensor is moved to.
        epoch: Zero-based epoch index (used only for the progress-bar label).
        num_epochs: Total number of epochs (used only for the progress-bar label).
        max_grad_norm: Threshold passed to ``clip_grad_norm_`` before each step.

    Returns:
        Tuple of the mean per-batch loss and the accuracy over all samples seen.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    num_batches = 0

    progress_bar = tqdm(train_dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=False)

    for batch in progress_bar:
        batch_input_ids = batch[0].to(device)
        batch_attention_mask = batch[1].to(device)
        batch_labels = batch[2].to(device)

        model.zero_grad()

        outputs = model(
            batch_input_ids,
            attention_mask=batch_attention_mask,
            labels=batch_labels,
        )
        loss = outputs.loss

        # Read each scalar once: every .item() forces a host/device synchronization.
        batch_loss = loss.item()
        preds = torch.argmax(outputs.logits, dim=1)
        batch_correct = (preds == batch_labels).sum().item()
        n_in_batch = batch_labels.size(0)

        total_loss += batch_loss
        total_correct += batch_correct
        total_samples += n_in_batch
        num_batches += 1

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        # Update progress bar with this batch's metrics
        progress_bar.set_postfix({
            'loss': f'{batch_loss:.4f}',
            'acc': f'{batch_correct / n_in_batch:.4f}',
        })

    if num_batches == 0:
        raise ValueError("train_model received an empty dataloader")

    avg_loss = total_loss / num_batches
    accuracy = total_correct / total_samples

    return avg_loss, accuracy

# Evaluation function
def evaluate_model(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` with gradient tracking disabled.

    Args:
        model: Classification model called as
            ``model(input_ids, attention_mask=..., labels=...)``; its output must
            expose ``.loss`` and ``.logits``.
        dataloader: Iterable of batches indexable as
            ``(input_ids, attention_mask, labels)``.
        device: Device each batch tensor is moved to.

    Returns:
        ``(avg_loss, accuracy, all_preds, all_labels)``:
        - ``avg_loss`` is the mean of the per-batch losses.
        - ``accuracy`` comes from ``sklearn.metrics.accuracy_score``.
        - ``all_preds`` and ``all_labels`` are flat lists of per-sample predicted and
          true class indices (NumPy scalars), in dataloader order.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    all_preds = []
    all_labels = []
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            batch_input_ids = batch[0].to(device)
            batch_attention_mask = batch[1].to(device)
            batch_labels = batch[2].to(device)

            outputs = model(
                batch_input_ids,
                attention_mask=batch_attention_mask,
                labels=batch_labels,
            )

            total_loss += outputs.loss.item()
            num_batches += 1

            preds = torch.argmax(outputs.logits, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())

    # Fail loudly instead of with an opaque ZeroDivisionError.
    if num_batches == 0:
        raise ValueError("evaluate_model received an empty dataloader")

    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(all_labels, all_preds)

    return avg_loss, accuracy, all_preds, all_labels


# Function to tokenize text data
def tokenize_texts(texts, labels, max_len=128, text_tokenizer=None):
    """Tokenize ``texts`` into fixed-length tensors and pair them with ``labels``.

    Args:
        texts: Iterable of strings to encode.
        labels: Sequence or array of integer class labels, one per text.
        max_len: Every sequence is truncated or padded to exactly this many tokens.
        text_tokenizer: Optional Hugging Face-style tokenizer to use instead of the
            module-level ``tokenizer``.

    Returns:
        ``(input_ids, attention_masks, labels)`` as PyTorch tensors. ``labels`` is
        always ``torch.long``, the dtype cross-entropy expects for class indices.
    """
    active_tokenizer = tokenizer if text_tokenizer is None else text_tokenizer

    # Calling the tokenizer on a list is the batched API; batch_encode_plus is deprecated.
    encodings = active_tokenizer(
        list(texts),
        add_special_tokens=True,
        max_length=max_len,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )

    input_ids = encodings['input_ids']
    attention_masks = encodings['attention_mask']
    # Force int64. An int32 label array would otherwise produce an int32 tensor,
    # which PyTorch's cross-entropy loss rejects as a class-index target.
    labels = torch.tensor(labels, dtype=torch.long)

    return input_ids, attention_masks, labels


In [None]:
# Main training loop with early stopping
def train_with_early_stopping(model, train_dataloader, val_dataloader, optimizer, device,
                              num_epochs=3, patience=1,
                              output_dir='../models/best_distilbert_model',
                              restore_best=False):
    """Train for up to ``num_epochs`` epochs, stopping once validation accuracy stalls.

    After each epoch the model is evaluated on ``val_dataloader``. Whenever validation
    accuracy strictly improves, the model and the module-level ``tokenizer`` are saved
    to ``output_dir``. Training stops after ``patience`` consecutive epochs without
    improvement.

    Args:
        model, train_dataloader, val_dataloader, optimizer, device: Forwarded to
            ``train_model`` / ``evaluate_model``.
        num_epochs: Maximum number of epochs to run.
        patience: Number of consecutive non-improving epochs tolerated before stopping.
        output_dir: Directory where the best checkpoint is written.
        restore_best: If True, load the best epoch's weights back into ``model``
            before returning. By default ``model`` keeps the weights of the last
            epoch run, which is not the best one when early stopping triggers.

    Returns:
        ``(train_losses, train_accuracies, val_losses, val_accuracies)``, with one
        entry per epoch actually run.
    """
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    # Start at -inf (not 0) so the first epoch always counts as an improvement and
    # is checkpointed, even if its accuracy happens to be 0.0.
    best_val_accuracy = float('-inf')
    best_state = None
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        # Train
        train_loss, train_acc = train_model(model, train_dataloader, optimizer, device, epoch, num_epochs)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        # Validate
        val_loss, val_acc, _, _ = evaluate_model(model, val_dataloader, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")

        # Save best model
        if val_acc > best_val_accuracy:
            best_val_accuracy = val_acc
            epochs_no_improve = 0
            save_model(model, tokenizer, output_dir, epoch)
            if restore_best:
                # Snapshot on CPU so the copy does not occupy accelerator memory.
                best_state = {
                    name: tensor.detach().cpu().clone()
                    for name, tensor in model.state_dict().items()
                }
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"\nEarly stopping after {epoch + 1} epochs!")
                break

    if restore_best and best_state is not None:
        # load_state_dict copies into the existing parameters, so the model's device is kept.
        model.load_state_dict(best_state)

    return train_losses, train_accuracies, val_losses, val_accuracies

# Load your data
df = pd.read_csv('../data/disaster_tweets.csv')
X = df['text'].values
y = df['target'].values

# Seed torch's RNGs (all devices) so classifier-head initialization and RandomSampler
# shuffling are repeatable, matching the fixed random_state of the splits below.
# This alone does not force deterministic CUDA kernels.
torch.manual_seed(42)

# Split into train (70%), validation (15%) and test (15%) sets. Stratifying keeps the
# class ratio of the binary target the same in every split.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
)

# Tokenize data
train_input_ids, train_attention_masks, train_labels = tokenize_texts(X_train, y_train)
val_input_ids, val_attention_masks, val_labels = tokenize_texts(X_val, y_val)
test_input_ids, test_attention_masks, test_labels = tokenize_texts(X_test, y_test)

# Create DataLoaders: only the training set is shuffled
batch_size = 16
train_data = TensorDataset(train_input_ids, train_attention_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

val_data = TensorDataset(val_input_ids, val_attention_masks, val_labels)
val_dataloader = DataLoader(val_data, batch_size=batch_size)

test_data = TensorDataset(test_input_ids, test_attention_masks, test_labels)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False
).to(device)

# Set up optimizer
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

# Train with early stopping
train_losses, train_accuracies, val_losses, val_accuracies = train_with_early_stopping(
    model, train_dataloader, val_dataloader, optimizer, device, num_epochs=4, patience=1
)


In [None]:
# Plot training curves; the x-axis uses 1-based epoch numbers to match the training log
epochs_run = range(1, len(train_losses) + 1)

plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(epochs_run, train_losses, label='Train Loss')
plt.plot(epochs_run, val_losses, label='Val Loss')
plt.xticks(epochs_run)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(epochs_run, train_accuracies, label='Train Accuracy')
plt.plot(epochs_run, val_accuracies, label='Val Accuracy')
plt.xticks(epochs_run)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.tight_layout()

# Figures are written next to the checkpoints; make sure that folder exists.
os.makedirs('../models', exist_ok=True)
plt.savefig('../models/training_curves.png')
plt.show()

# Final evaluation on test set.
# NOTE: this evaluates the in-memory model, i.e. the weights after the last epoch that
# ran. When early stopping triggered, that is not the best checkpoint written to
# ../models/best_distilbert_model.
test_loss, test_acc, test_preds, test_labels = evaluate_model(model, test_dataloader, device)
print("\nFinal Test Results:")
print(f"Test Loss: {test_loss:.4f} | Test Accuracy: {test_acc:.4f}")
print("\nClassification Report:\n", classification_report(test_labels, test_preds))

# Confusion Matrix (rows = actual class, columns = predicted class)
cm = confusion_matrix(test_labels, test_preds)
plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Test Set Confusion Matrix')
plt.savefig('../models/confusion_matrix.png')
plt.show()

# Save final model (-1 is passed as a placeholder epoch index)
save_model(model, tokenizer, '../models/final_distilbert_model', -1)