In [1]:
# mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os
import pandas as pd

def load_datasets(base_dir='/content/drive/MyDrive/text_analysis/datasets'):
    """
    Load train, validation and test datasets
    Return all three datasets
    """
    # Define file paths
    train_path = os.path.join(base_dir, "train_drug_reviews2.csv")
    val_path = os.path.join(base_dir, "val_drug_reviews2.csv")
    test_path = os.path.join(base_dir, "test_drug_reviews2.csv")

    # Load datasets
    train_df = pd.read_csv(train_path)
    val_df = pd.read_csv(val_path)
    test_df = pd.read_csv(test_path)


    return train_df, val_df, test_df

train_df, val_df, test_df = load_datasets()

In [4]:
import pandas as pd
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import random
import re
import nltk
from nltk.corpus import stopwords

# Set random seed to ensure reproducible results
def set_seed(seed_value=42):
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)

set_seed(42)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Text preprocessing function
def preprocess_text(text):
    """Preprocess drug reviews"""
    if not isinstance(text, str):
        return ""

    # Convert to lowercase
    text = text.lower()

    # Remove URLs
    text = re.sub(r'https?://\S+|www\.\S+', '', text)

    # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)

    # Keep letters, numbers, common punctuation (preserving drug information)
    text = re.sub(r'[^\w\s.,!?;:]', ' ', text)

    # Remove extra spaces
    text = re.sub(r'\s+', ' ', text).strip()

    return text

# Create custom Dataset class
class DrugReviewsDataset(Dataset):
    def __init__(self, reviews, labels, tokenizer, max_len=256):
        self.reviews = reviews
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        review = str(self.reviews[idx])
        label = int(self.labels[idx])

        # Process text using BERT tokenizer
        encoding = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'review_text': review,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Model training function
def train_model(model, train_dataloader, val_dataloader, epochs=4):
    # Optimizer
    optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

    # Calculate total training steps
    total_steps = len(train_dataloader) * epochs

    # Create learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=total_steps
    )

    # Record training history
    training_history = {
        'train_loss': [],
        'val_loss': [],
        'val_accuracy': []
    }

    # Record best validation accuracy
    best_val_accuracy = 0

    # Training loop
    for epoch in range(epochs):
        print(f'\n======== Epoch {epoch + 1} / {epochs} ========')

        # Training mode
        model.train()
        train_loss = 0

        # Training progress bar
        progress_bar = tqdm(train_dataloader, desc="Training", position=0, leave=True)

        for batch in progress_bar:
            # Load data to GPU
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Clear previous gradients
            model.zero_grad()

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            train_loss += loss.item()

            # Backward pass
            loss.backward()

            # Gradient clipping to prevent gradient explosion
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Update parameters
            optimizer.step()

            # Update learning rate
            scheduler.step()

            # Update progress bar
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        # Calculate average training loss
        avg_train_loss = train_loss / len(train_dataloader)
        training_history['train_loss'].append(avg_train_loss)

        print(f"Average Training Loss: {avg_train_loss:.4f}")

        # Evaluate on validation set
        val_loss, val_accuracy = evaluate_model(model, val_dataloader)
        training_history['val_loss'].append(val_loss)
        training_history['val_accuracy'].append(val_accuracy)

        print(f"Validation Loss: {val_loss:.4f}")
        print(f"Validation Accuracy: {val_accuracy:.4f}")

        # Save best model
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # Save model
            torch.save(model.state_dict(), 'best_bert_model.pt')
            print("Saved new best model")

    print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")

    # Plot training history
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(training_history['train_loss'], label='Training Loss')
    plt.plot(training_history['val_loss'], label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Training and Validation Loss')

    plt.subplot(1, 2, 2)
    plt.plot(training_history['val_accuracy'], label='Validation Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Validation Accuracy')

    plt.tight_layout()
    plt.savefig('bert_training_history.png')
    plt.close()

    return training_history, best_val_accuracy

# Model evaluation function
def evaluate_model(model, dataloader):
    # Evaluation mode
    model.eval()

    total_loss = 0
    all_preds = []
    all_labels = []

    # No gradient calculation
    with torch.no_grad():
        for batch in dataloader:
            # Load data to GPU
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss = outputs.loss
            logits = outputs.logits

            total_loss += loss.item()

            # Get prediction results
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    # Calculate average loss
    avg_loss = total_loss / len(dataloader)

    # Calculate accuracy
    accuracy = accuracy_score(all_labels, all_preds)

    return avg_loss, accuracy

# Test model and generate classification report
def test_model(model, test_dataloader):
    # Evaluation mode
    model.eval()

    all_preds = []
    all_labels = []

    # No gradient calculation
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Testing"):
            # Load data to GPU
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            logits = outputs.logits

            # Get prediction results
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    # Calculate accuracy
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy: {accuracy:.4f}")

    # Generate classification report
    target_names = ['Not Effective(0)', 'Moderately Effective(1)', 'Effective(2)']
    report = classification_report(all_labels, all_preds,
                                  target_names=target_names,
                                  digits=3)
    print("\nClassification Report:")
    print(report)

    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
               xticklabels=target_names,
               yticklabels=target_names)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('BERT Model Confusion Matrix')
    plt.tight_layout()
    plt.savefig('bert_confusion_matrix.png')
    plt.close()

    # Generate detailed classification report (dictionary form)
    report_dict = classification_report(all_labels, all_preds,
                                      target_names=target_names,
                                      output_dict=True)
    report_df = pd.DataFrame(report_dict).transpose()

    return accuracy, report_df, all_preds, all_labels

# Main function
def main():
    # Load datasets
    def load_datasets(base_dir='/content/drive/MyDrive/text_analysis/datasets'):
        print("Loading datasets...")
        train_path = os.path.join(base_dir, "train_drug_reviews2.csv")
        val_path = os.path.join(base_dir, "val_drug_reviews2.csv")
        test_path = os.path.join(base_dir, "test_drug_reviews2.csv")

        train_df = pd.read_csv(train_path)
        val_df = pd.read_csv(val_path)
        test_df = pd.read_csv(test_path)

        # Select only needed columns
        train_df = train_df[['review', 'effectiveness']]
        val_df = val_df[['review', 'effectiveness']]
        test_df = test_df[['review', 'effectiveness']]

        # Check data distribution
        print(f"Training set samples: {len(train_df)}")
        print(f"Validation set samples: {len(val_df)}")
        print(f"Test set samples: {len(test_df)}")

        print("\nLabel distribution:")
        print("Training set:", train_df['effectiveness'].value_counts().sort_index())
        print("Validation set:", val_df['effectiveness'].value_counts().sort_index())
        print("Test set:", test_df['effectiveness'].value_counts().sort_index())

        return train_df, val_df, test_df

    # Load data
    train_df, val_df, test_df = load_datasets()

    # Preprocess text
    print("\nPreprocessing text...")
    train_df['processed_review'] = train_df['review'].apply(preprocess_text)
    val_df['processed_review'] = val_df['review'].apply(preprocess_text)
    test_df['processed_review'] = test_df['review'].apply(preprocess_text)

    # Load BERT tokenizer
    print("\nLoading BERT tokenizer...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Create datasets
    print("\nCreating datasets...")
    train_dataset = DrugReviewsDataset(
        reviews=train_df['processed_review'].values,
        labels=train_df['effectiveness'].values,
        tokenizer=tokenizer
    )

    val_dataset = DrugReviewsDataset(
        reviews=val_df['processed_review'].values,
        labels=val_df['effectiveness'].values,
        tokenizer=tokenizer
    )

    test_dataset = DrugReviewsDataset(
        reviews=test_df['processed_review'].values,
        labels=test_df['effectiveness'].values,
        tokenizer=tokenizer
    )

    # Create DataLoaders
    print("\nCreating DataLoaders...")
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=16,  # Small batches can reduce memory requirements
        shuffle=True
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=32,
        shuffle=False
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False
    )

    # Calculate class weights (handling class imbalance)
    class_counts = train_df['effectiveness'].value_counts().sort_index()
    class_weights = 1.0 / torch.tensor(class_counts.values, dtype=torch.float)
    class_weights = class_weights / class_weights.sum()
    class_weights = class_weights.to(device)

    # Load pre-trained BERT model
    print("\nLoading BERT model...")
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=3,  # Three classes: 0, 1, 2
        problem_type="single_label_classification"
    )

    # Move model to GPU
    model = model.to(device)

    # Set weights
    # Note: Some versions of transformers may not support directly setting class_weights
    # If error occurs, consider manually weighting in loss function calculation

    # Train model
    print("\nStarting BERT model training...")
    training_history, best_val_accuracy = train_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        epochs=4  # BERT typically needs only a few epochs
    )

    # Load best model for testing
    print("\nLoading best model for testing...")
    model.load_state_dict(torch.load('best_bert_model.pt'))

    # Evaluate on test set
    print("\nEvaluating model on test set...")
    test_accuracy, report_df, all_preds, all_labels = test_model(
        model=model,
        test_dataloader=test_dataloader
    )

    print("\nFinal results:")
    print(f"Best validation accuracy: {best_val_accuracy:.4f}")
    print(f"Test accuracy: {test_accuracy:.4f}")
    print("\nDetailed classification report:")
    print(report_df)

    # Save prediction results (optional)
    test_results = pd.DataFrame({
        'review': test_df['review'],
        'true_label': all_labels,
        'predicted_label': all_preds
    })
    test_results.to_csv('bert_test_predictions.csv', index=False)

    return model, test_accuracy, report_df

if __name__ == "__main__":
    main()

Using device: cuda
Loading datasets...
Training set samples: 148966
Validation set samples: 31818
Test set samples: 31904

Label distribution:
Training set: effectiveness
0    37081
1    13254
2    98631
Name: count, dtype: int64
Validation set: effectiveness
0     7920
1     2831
2    21067
Name: count, dtype: int64
Test set: effectiveness
0     7942
1     2838
2    21124
Name: count, dtype: int64

Preprocessing text...

Loading BERT tokenizer...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]


Creating datasets...

Creating DataLoaders...

Loading BERT model...


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Starting BERT model training...



Training: 100%|██████████| 9311/9311 [27:32<00:00,  5.64it/s, loss=0.1634]


Average Training Loss: 0.4944
Validation Loss: 0.4077
Validation Accuracy: 0.8435
Saved new best model



Training: 100%|██████████| 9311/9311 [27:21<00:00,  5.67it/s, loss=0.4478]


Average Training Loss: 0.3349
Validation Loss: 0.3714
Validation Accuracy: 0.8729
Saved new best model



Training: 100%|██████████| 9311/9311 [27:23<00:00,  5.67it/s, loss=0.0080]


Average Training Loss: 0.2238
Validation Loss: 0.4050
Validation Accuracy: 0.8900
Saved new best model



Training:   1%|          | 74/9311 [00:13<27:23,  5.62it/s, loss=0.1055]


KeyboardInterrupt: 

In [5]:
import pandas as pd
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import re

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Text preprocessing function (same as in training)
def preprocess_text(text):
    """Preprocess drug reviews"""
    if not isinstance(text, str):
        return ""

    # Convert to lowercase
    text = text.lower()

    # Remove URLs
    text = re.sub(r'https?://\S+|www\.\S+', '', text)

    # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)

    # Keep letters, numbers, common punctuation (preserving drug information)
    text = re.sub(r'[^\w\s.,!?;:]', ' ', text)

    # Remove extra spaces
    text = re.sub(r'\s+', ' ', text).strip()

    return text

# Dataset class (same as in training)
class DrugReviewsDataset(Dataset):
    def __init__(self, reviews, labels, tokenizer, max_len=256):
        self.reviews = reviews
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        review = str(self.reviews[idx])
        label = int(self.labels[idx])

        # Process text using BERT tokenizer
        encoding = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'review_text': review,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Test model and generate classification report
def test_model(model, test_dataloader):
    """
    Evaluate a trained model on the test dataset

    Args:
        model: The trained BERT model
        test_dataloader: DataLoader containing the test dataset

    Returns:
        accuracy: Test accuracy
        report_df: DataFrame with detailed classification metrics
        all_preds: List of model predictions
        all_labels: List of true labels
    """
    # Set model to evaluation mode
    model.eval()

    all_preds = []
    all_labels = []

    # No gradient calculation needed for evaluation
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Testing"):
            # Load data to device (CPU or GPU)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )

            logits = outputs.logits

            # Get prediction results
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    # Calculate accuracy
    accuracy = accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy: {accuracy:.4f}")

    # Generate classification report
    target_names = ['Not Effective(0)', 'Moderately Effective(1)', 'Effective(2)']
    report = classification_report(all_labels, all_preds,
                                  target_names=target_names,
                                  digits=3)
    print("\nClassification Report:")
    print(report)

    # Generate confusion matrix
    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
               xticklabels=target_names,
               yticklabels=target_names)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('BERT Model Confusion Matrix')
    plt.tight_layout()
    plt.savefig('bert_evaluation_confusion_matrix.png')
    plt.close()

    # Generate detailed classification report (dictionary form)
    report_dict = classification_report(all_labels, all_preds,
                                      target_names=target_names,
                                      output_dict=True)
    report_df = pd.DataFrame(report_dict).transpose()

    return accuracy, report_df, all_preds, all_labels

def evaluate_saved_model(model_path, test_data_path):
    """
    Load a saved model and evaluate it on test data

    Args:
        model_path: Path to the saved model file
        test_data_path: Path to the test CSV file

    Returns:
        test_accuracy: Accuracy on test set
        report_df: Detailed classification report
    """
    print(f"Loading test data from: {test_data_path}")

    # Load test dataset
    test_df = pd.read_csv(test_data_path)
    test_df = test_df[['review', 'effectiveness']]  # Select only needed columns

    print(f"Test set samples: {len(test_df)}")
    print(f"Label distribution: {test_df['effectiveness'].value_counts().sort_index()}")

    # Preprocess text
    print("Preprocessing test data...")
    test_df['processed_review'] = test_df['review'].apply(preprocess_text)

    # Load BERT tokenizer
    print("Loading BERT tokenizer...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Create test dataset
    test_dataset = DrugReviewsDataset(
        reviews=test_df['processed_review'].values,
        labels=test_df['effectiveness'].values,
        tokenizer=tokenizer
    )

    # Create DataLoader
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False
    )

    # Initialize model architecture
    print("Initializing model architecture...")
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased',
        num_labels=3,  # Three classes: 0, 1, 2
        problem_type="single_label_classification"
    )

    # Load saved weights
    print(f"Loading saved model from: {model_path}")
    model.load_state_dict(torch.load(model_path, map_location=device))

    # Move model to device
    model = model.to(device)

    # Evaluate model
    print("Evaluating model...")
    test_accuracy, report_df, all_preds, all_labels = test_model(
        model=model,
        test_dataloader=test_dataloader
    )

    # Save results to CSV
    results_df = pd.DataFrame({
        'review': test_df['review'],
        'true_label': all_labels,
        'predicted_label': all_preds
    })

    results_path = 'bert_evaluation_results.csv'
    results_df.to_csv(results_path, index=False)
    print(f"Saved prediction results to: {results_path}")

    return test_accuracy, report_df

if __name__ == "__main__":
    # Set paths
    model_path = 'best_bert_model.pt'  # Path to saved model
    test_data_path = '/content/drive/MyDrive/text_analysis/datasets/test_drug_reviews2.csv'  # Path to test data

    # Run evaluation
    test_accuracy, report_df = evaluate_saved_model(model_path, test_data_path)

    print("\nFinal Results:")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print("\nDetailed Classification Report:")
    print(report_df)

Using device: cuda
Loading test data from: /content/drive/MyDrive/text_analysis/datasets/test_drug_reviews2.csv
Test set samples: 31904
Label distribution: effectiveness
0     7942
1     2838
2    21124
Name: count, dtype: int64
Preprocessing test data...
Loading BERT tokenizer...
Initializing model architecture...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loading saved model from: best_bert_model.pt
Evaluating model...


Testing: 100%|██████████| 997/997 [02:10<00:00,  7.65it/s]


Test Accuracy: 0.8865

Classification Report:
                         precision    recall  f1-score   support

       Not Effective(0)      0.863     0.850     0.856      7942
Moderately Effective(1)      0.573     0.492     0.530      2838
           Effective(2)      0.930     0.953     0.942     21124

               accuracy                          0.887     31904
              macro avg      0.789     0.765     0.776     31904
           weighted avg      0.882     0.887     0.884     31904

Saved prediction results to: bert_evaluation_results.csv

Final Results:
Test Accuracy: 0.8865

Detailed Classification Report:
                         precision    recall  f1-score       support
Not Effective(0)          0.862898  0.849534  0.856164   7942.000000
Moderately Effective(1)   0.573011  0.492248  0.529568   2838.000000
Effective(2)              0.930337  0.953371  0.941713  21124.000000
accuracy                  0.886503  0.886503  0.886503      0.886503
macro avg              

In [6]:
import pandas as pd
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import re
import matplotlib.cm as cm
from IPython.display import HTML, display
import warnings
warnings.filterwarnings('ignore')

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Text preprocessing function (same as in training)
def preprocess_text(text):
    """Preprocess drug reviews"""
    if not isinstance(text, str):
        return ""

    # Convert to lowercase
    text = text.lower()

    # Remove URLs
    text = re.sub(r'https?://\S+|www\.\S+', '', text)

    # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)

    # Keep letters, numbers, common punctuation (preserving drug information)
    text = re.sub(r'[^\w\s.,!?;:]', ' ', text)

    # Remove extra spaces
    text = re.sub(r'\s+', ' ', text).strip()

    return text

# Dataset class with attention output flag
class DrugReviewsDataset(Dataset):
    def __init__(self, reviews, labels, tokenizer, max_len=256):
        self.reviews = reviews
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, idx):
        review = str(self.reviews[idx])
        label = int(self.labels[idx])

        # Process text using BERT tokenizer
        encoding = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'review_text': review,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

# Custom BERT model to extract attention weights
class BertWithAttention(torch.nn.Module):
    """
    Custom BERT model that returns attention weights along with predictions
    """
    def __init__(self, model_path=None, num_labels=3):
        super(BertWithAttention, self).__init__()

        # If model path is provided, we'll load it later
        self.model_path = model_path

        # Initialize model configuration
        config = BertConfig.from_pretrained(
            'bert-base-uncased',
            num_labels=num_labels,
            output_attentions=True  # Set this to get attention weights
        )

        # Initialize BERT with the configuration
        self.bert = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            config=config
        )

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass that returns both outputs and attention weights
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=True  # Ensure attention outputs
        )

        # Return all outputs for further processing
        return outputs

    def load_trained_weights(self):
        """
        Load weights from a trained model file
        """
        if self.model_path:
            # Load model weights into the BERT model
            state_dict = torch.load(self.model_path, map_location=device)
            self.bert.load_state_dict(state_dict)
            print(f"Loaded trained weights from {self.model_path}")

# Function to extract attention importance
def extract_attention_importance(model, dataloader, tokenizer, top_k=20):
    """
    Extract and aggregate attention weights to find important tokens

    Args:
        model: The BERT model with attention output
        dataloader: DataLoader with the dataset
        tokenizer: BERT tokenizer for decoding tokens
        top_k: Number of top tokens to identify

    Returns:
        word_importance_df: DataFrame with top important words and scores
        sample_attentions: Dictionary of sample reviews with their attention weights
    """
    model.eval()

    # Store aggregated attention scores for unique tokens
    token_attention_scores = {}

    # Store sample reviews with attention for visualization
    sample_attentions = {}

    # Process batches to extract attention
    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Extracting attention")):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            reviews = batch['review_text']

            # Forward pass to get attention weights
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            # Extract attention weights - shape is [batch_size, num_heads, seq_len, seq_len]
            # We get attention from the last layer (closest to classification decision)
            attentions = outputs.attentions[-1]  # Last layer's attention

            # Average over attention heads
            attentions = attentions.mean(dim=1)  # Now shape is [batch_size, seq_len, seq_len]

            # Process each sample in the batch
            for sample_idx in range(input_ids.size(0)):
                review = reviews[sample_idx]

                # Get tokens for this sample
                tokens = tokenizer.convert_ids_to_tokens(input_ids[sample_idx])

                # Get attention for this sample - we look at attention from CLS token (first token)
                # This shows how much the CLS token attends to each other token when making classification
                token_attn = attentions[sample_idx, 0, :].cpu().numpy()  # Attention from CLS token

                # Store a few samples for visualization
                if batch_idx < 3 and sample_idx < 3:
                    # Store only the tokens that have attention mask of 1 (not padding)
                    mask = batch['attention_mask'][sample_idx].cpu().numpy()
                    valid_tokens = [t for i, t in enumerate(tokens) if mask[i] == 1]
                    valid_attention = token_attn[:len(valid_tokens)]

                    sample_attentions[f"sample_{batch_idx}_{sample_idx}"] = {
                        "review": review,
                        "tokens": valid_tokens,
                        "attention": valid_attention
                    }

                # Aggregate attention scores for each unique token
                for i, token in enumerate(tokens):
                    if token in ['[PAD]', '[CLS]', '[SEP]']:
                        continue  # Skip special tokens

                    # Attention score from CLS to this token
                    score = token_attn[i]

                    # Aggregate scores for unique tokens
                    if token in token_attention_scores:
                        token_attention_scores[token].append(score)
                    else:
                        token_attention_scores[token] = [score]

    # Compute average attention for each token
    token_avg_attention = {
        token: np.mean(scores)
        for token, scores in token_attention_scores.items()
    }

    # Merge subword tokens (those starting with ##)
    merged_token_attention = {}
    for token, score in token_avg_attention.items():
        if token.startswith('##'):
            # Find the previous token and merge
            prefix = list(merged_token_attention.keys())[-1] if merged_token_attention else None
            if prefix:
                merged_token = prefix + token[2:]  # Remove ## and concatenate
                merged_score = (merged_token_attention[prefix] + score) / 2  # Average score

                # Update with merged token and remove the prefix
                merged_token_attention.pop(prefix)
                merged_token_attention[merged_token] = merged_score
            else:
                merged_token_attention[token] = score
        else:
            merged_token_attention[token] = score

    # Sort by attention score and get top k
    sorted_tokens = sorted(merged_token_attention.items(), key=lambda x: x[1], reverse=True)
    top_tokens = sorted_tokens[:top_k]

    # Create DataFrame for visualization
    word_importance_df = pd.DataFrame(top_tokens, columns=['Token', 'Importance'])

    return word_importance_df, sample_attentions

# Function to visualize attention for a sample
def visualize_attention(sample_data, title="Attention Visualization", save_path=None):
    """
    Create a visualization of token attention weights

    Args:
        sample_data: Dictionary with tokens and their attention weights
        title: Title for the visualization
        save_path: Path to save the visualization (if None, just display)
    """
    tokens = sample_data["tokens"]
    attention = sample_data["attention"]
    review = sample_data["review"]

    # Limit to non-padding tokens
    if len(tokens) > len(attention):
        tokens = tokens[:len(attention)]

    # Skip special tokens at the beginning and end ([CLS] and [SEP])
    tokens = tokens[1:-1]
    attention = attention[1:-1]

    plt.figure(figsize=(12, 6))

    # Plot bars for attention weights
    plt.bar(range(len(tokens)), attention, color=cm.viridis(attention / max(attention)))

    # Customize the plot
    plt.xticks(range(len(tokens)), tokens, rotation=45, ha="right")
    plt.xlabel("Tokens")
    plt.ylabel("Attention Weight")
    plt.title(f"{title}\nReview: {review[:100]}...")
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

# Function to visualize top important tokens
def visualize_top_tokens(word_importance_df, title="Top Important Tokens", save_path=None):
    """
    Create a visualization of top important tokens

    Args:
        word_importance_df: DataFrame with tokens and their importance scores
        title: Title for the visualization
        save_path: Path to save the visualization (if None, just display)
    """
    plt.figure(figsize=(10, 8))

    # Create horizontal bar chart
    bars = plt.barh(
        range(len(word_importance_df)),
        word_importance_df['Importance'],
        color=cm.viridis(word_importance_df['Importance'] / word_importance_df['Importance'].max())
    )

    # Customize the plot
    plt.yticks(range(len(word_importance_df)), word_importance_df['Token'])
    plt.xlabel("Average Attention Weight")
    plt.title(title)
    plt.gca().invert_yaxis()  # Highest importance at the top
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

# Function to create colored text based on attention weights
def attention_html(tokens, attention):
    """
    Create HTML with colored text based on attention weights

    Args:
        tokens: List of tokens
        attention: Attention weights for each token

    Returns:
        HTML representation of colored text
    """
    html = ""

    # Normalize attention to [0, 1] for coloring
    if max(attention) > 0:
        attention = attention / max(attention)

    for token, attn in zip(tokens, attention):
        if token.startswith("##"):
            token = token[2:]  # Remove ## prefix for subwords
            space = ""  # No space for subword pieces
        else:
            space = " "  # Add space between words

        # Skip special tokens
        if token in ["[CLS]", "[SEP]", "[PAD]"]:
            continue

        # Get color intensity based on attention score
        r = int(255 * attn)
        g = int(100 * (1 - attn))
        b = int(100 * (1 - attn))

        html += f'{space}<span style="background-color:rgba({r},{g},{b},0.3)">{token}</span>'

    return html

# Main evaluation function
def evaluate_and_explain(model_path, test_data_path, output_dir="bert_attention_outputs"):
    """
    Load a saved model, evaluate it, and extract attention-based feature importance

    Args:
        model_path: Path to the saved model file
        test_data_path: Path to the test CSV file
        output_dir: Directory to save outputs

    Returns:
        test_accuracy: Accuracy on test set
        top_tokens_df: DataFrame with top important tokens
    """
    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading test data from: {test_data_path}")

    # Load test dataset
    test_df = pd.read_csv(test_data_path)
    test_df = test_df[['review', 'effectiveness']]  # Select only needed columns

    print(f"Test set samples: {len(test_df)}")
    print(f"Label distribution: {test_df['effectiveness'].value_counts().sort_index()}")

    # Preprocess text
    print("Preprocessing test data...")
    test_df['processed_review'] = test_df['review'].apply(preprocess_text)

    # Load BERT tokenizer
    print("Loading BERT tokenizer...")
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Create test dataset
    test_dataset = DrugReviewsDataset(
        reviews=test_df['processed_review'].values,
        labels=test_df['effectiveness'].values,
        tokenizer=tokenizer
    )

    # Create DataLoader
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=16,  # Smaller batch size for attention analysis
        shuffle=False
    )

    # Initialize custom BERT model with attention outputs
    print("Initializing model architecture with attention outputs...")
    model = BertWithAttention(model_path=model_path, num_labels=3)
    model.load_trained_weights()

    # Move model to device
    model = model.to(device)

    # Evaluate model
    print("Evaluating model and extracting attention...")

    # Run standard evaluation
    model.eval()

    all_preds = []
    all_labels = []

    # Evaluate accuracy
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Evaluating accuracy"):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            # Forward pass
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

            logits = outputs.logits

            # Get prediction results
            preds = torch.argmax(logits, dim=1).cpu().numpy()
            all_preds.extend(preds)
            all_labels.extend(labels.cpu().numpy())

    # Calculate accuracy
    test_accuracy = accuracy_score(all_labels, all_preds)
    print(f"Test Accuracy: {test_accuracy:.4f}")

    # Generate classification report
    target_names = ['Not Effective(0)', 'Moderately Effective(1)', 'Effective(2)']
    report = classification_report(all_labels, all_preds,
                                  target_names=target_names,
                                  digits=3)
    print("\nClassification Report:")
    print(report)

    # Extract attention importance
    print("\nExtracting token importance...")
    top_tokens_df, sample_attentions = extract_attention_importance(
        model=model,
        dataloader=test_dataloader,
        tokenizer=tokenizer,
        top_k=20
    )

    # Print top tokens
    print("\nTop 20 Important Tokens:")
    print(top_tokens_df)

    # Save top tokens to CSV
    top_tokens_path = os.path.join(output_dir, "top_important_tokens.csv")
    top_tokens_df.to_csv(top_tokens_path, index=False)
    print(f"Saved top tokens to: {top_tokens_path}")

    # Visualize top tokens
    visualize_top_tokens(
        top_tokens_df,
        title="Top 20 Important Tokens for BERT Drug Effectiveness Classification",
        save_path=os.path.join(output_dir, "top_tokens_visualization.png")
    )

    # Visualize sample attention patterns
    print("\nVisualizing sample attention patterns...")
    for i, (sample_key, sample_data) in enumerate(sample_attentions.items()):
        # Visualize attention weights
        visualize_attention(
            sample_data,
            title=f"Sample {i+1} Attention Weights",
            save_path=os.path.join(output_dir, f"sample_{i+1}_attention.png")
        )

        # Create colored text HTML
        tokens = sample_data["tokens"]
        attention = sample_data["attention"]
        review = sample_data["review"]

        html_content = f"""
        <h3>Sample {i+1} Review:</h3>
        <p><b>Original:</b> {review[:300]}...</p>
        <p><b>Attention Highlighted:</b> {attention_html(tokens, attention)}</p>
        <hr>
        """

        # Save HTML visualization
        with open(os.path.join(output_dir, f"sample_{i+1}_highlighted.html"), "w") as f:
            f.write(html_content)

    # Generate class-specific important tokens (if time allows)

    return test_accuracy, top_tokens_df

if __name__ == "__main__":
    # Set paths
    model_path = 'best_bert_model.pt'  # Path to saved model
    test_data_path = '/content/drive/MyDrive/text_analysis/datasets/test_drug_reviews2.csv'  # Path to test data
    output_dir = 'bert_attention_outputs'  # Directory to save outputs

    # Run evaluation with attention analysis
    test_accuracy, top_tokens_df = evaluate_and_explain(
        model_path=model_path,
        test_data_path=test_data_path,
        output_dir=output_dir
    )

    print("\nFinal Results:")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print("\nTop 20 Important Tokens:")
    print(top_tokens_df)

Using device: cuda
Loading test data from: /content/drive/MyDrive/text_analysis/datasets/test_drug_reviews2.csv
Test set samples: 31904
Label distribution: effectiveness
0     7942
1     2838
2    21124
Name: count, dtype: int64
Preprocessing test data...
Loading BERT tokenizer...
Initializing model architecture with attention outputs...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Loaded trained weights from best_bert_model.pt
Evaluating model and extracting attention...


Evaluating accuracy: 100%|██████████| 1994/1994 [02:25<00:00, 13.67it/s]


Test Accuracy: 0.8865

Classification Report:
                         precision    recall  f1-score   support

       Not Effective(0)      0.863     0.850     0.856      7942
Moderately Effective(1)      0.573     0.492     0.530      2838
           Effective(2)      0.930     0.953     0.942     21124

               accuracy                          0.887     31904
              macro avg      0.789     0.765     0.776     31904
           weighted avg      0.882     0.887     0.884     31904


Extracting token importance...


Extracting attention: 100%|██████████| 1994/1994 [04:33<00:00,  7.29it/s]



Top 20 Important Tokens:
           Token  Importance
0      suspended    0.121173
1          greed    0.118603
2     delightful    0.084953
3            uta    0.083168
4         dieter    0.080312
5           ohio    0.078077
6         coward    0.078074
7         carole    0.078070
8          roche    0.076103
9    netherlands    0.074857
10     vancouver    0.074343
11       mapping    0.071492
12        cheryl    0.071328
13     excellent    0.070921
14   exceptional    0.069072
15          iowa    0.068792
16      zimbabwe    0.066827
17     celebrate    0.066560
18  recognizable    0.066176
19          katy    0.065271
Saved top tokens to: bert_attention_outputs/top_important_tokens.csv

Visualizing sample attention patterns...

Final Results:
Test Accuracy: 0.8865

Top 20 Important Tokens:
           Token  Importance
0      suspended    0.121173
1          greed    0.118603
2     delightful    0.084953
3            uta    0.083168
4         dieter    0.080312
5           ohio