# In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
import spacy
from torch.cuda.amp import autocast, GradScaler
import torch.nn.functional as F
from tqdm import tqdm
import os  # Import os module
import shutil # Import shutil for removing directory
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from captum.attr import LayerIntegratedGradients

# Initialize device: prefer the GPU, but fall back to the CPU so the module
# can still be imported and run on hosts without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models
# spaCy pipeline used for POS tagging; the components listed in `disable`
# are switched off when the pipeline is loaded.
nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'lemmatizer'])

# BERT sentiment checkpoint; the same encoder instance is used for the cached
# embeddings and as the backbone of the classifier defined below.
# NOTE(review): originally described as fine-tuned on Amazon reviews --
# confirm against the checkpoint's model card.
model_name = 'nlptown/bert-base-multilingual-uncased-sentiment'
tokenizer = BertTokenizer.from_pretrained(model_name)
bert_model = BertModel.from_pretrained(model_name).to(device)

class RegularizedBERT(nn.Module):
    """BERT encoder with an optional auxiliary-feature branch and an MLP head.

    The hidden state of the first token (the [CLS] position) is projected to
    384 dimensions. When auxiliary features are supplied they are projected
    to the same width and added element-wise before classification.

    Args:
        num_labels: Number of output classes.
        feature_dim: Width of the auxiliary feature vector; ``0`` disables
            the feature branch.
        hyperparams: Mapping that must contain ``"Dropout"`` (a probability).

    Note:
        ``self.bert`` is the module-level ``bert_model``, so every instance
        shares (and fine-tunes) the same encoder weights.
    """

    def __init__(self, num_labels, feature_dim, hyperparams):
        super().__init__()
        self.bert = bert_model
        self.dropout = nn.Dropout(hyperparams["Dropout"])
        self.feature_dim = feature_dim
        self.bert_projection = nn.Linear(768, 384)
        self.feature_projection = nn.Linear(feature_dim, 384) if feature_dim > 0 else None
        self.classifier = nn.Sequential(
            nn.Linear(384, 384),
            nn.LayerNorm(384),
            nn.ReLU(),
            nn.Dropout(hyperparams["Dropout"]),
            nn.Linear(384, num_labels)
        )
        # Buffers filled by the trainer during validation for error analysis.
        self.all_preds = []
        self.all_labels = []
        self.all_texts = []

    def forward(self, input_ids, attention_mask, features=None):
        """Return unnormalised class scores.

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            features: Optional auxiliary features; ignored when ``None`` or
                when the model was built with ``feature_dim == 0``.

        Returns:
            The output of the classifier head (one score per label).
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # The dropout layer was previously constructed but never applied;
        # use it on the pooled vector (it is the identity in eval mode).
        pooled_output = self.dropout(bert_output.last_hidden_state[:, 0, :])
        bert_projected = self.bert_projection(pooled_output)

        if features is not None and self.feature_dim > 0:
            feature_projected = self.feature_projection(features)
            combined_features = bert_projected + feature_projected
        else:
            combined_features = bert_projected

        return self.classifier(combined_features)

class GPUOptimizedTrainer:
    def __init__(self, df, text_column, labels, hyperparams, embedding_dir="embeddings"):
        # Validate hyperparameters
        required_params = ["Epochs", "Batch Size", "Learning Rate", "Dropout", "Weight Decay", 
                         "Label Smoothing", "Early Stopping Patience", "Gradient Accumulation Steps"]
        for param in required_params:
            if param not in hyperparams:
                raise ValueError(f"Missing required hyperparameter: {param}")
            
        if hyperparams["Batch Size"] <= 0:
            raise ValueError("Batch Size must be positive")
        if not 0 <= hyperparams["Dropout"] <= 1:
            raise ValueError("Dropout must be between 0 and 1")
            
        self.device = device
        self.hyperparams = hyperparams
        self.batch_size = hyperparams["Batch Size"]
        self.embedding_dir = embedding_dir
        self.text_column = text_column

        torch.backends.cudnn.benchmark = True
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        
        # Initialize feature_dim first
        sample_features = self.extract_features(df[text_column].head(1).tolist())
        self.feature_dim = sample_features.shape[1]
        
        # Create model with correct feature_dim
        self.model = RegularizedBERT(
            num_labels=5, 
            feature_dim=self.feature_dim, 
            hyperparams=hyperparams
        ).to(self.device)

        self.scaler = GradScaler()
        self.prepare_data(df, labels)
        self.setup_training()

    def extract_features(self, texts):
        print("Extracting features...")
        os.makedirs(self.embedding_dir, exist_ok=True) # Ensure embedding directory exists

        bert_embedding_file = os.path.join(self.embedding_dir, "bert_embeddings.npy")
        syntactic_feature_file = os.path.join(self.embedding_dir, "syntactic_features.npy")

        if os.path.exists(bert_embedding_file) and os.path.exists(syntactic_feature_file):
            print("Loading embeddings from disk...")
            contextual_features = np.load(bert_embedding_file)
            syntactic_features = np.load(syntactic_feature_file)
            print("Embeddings loaded.")
            return np.hstack([contextual_features, syntactic_features])
        else:
            print("Calculating and saving embeddings...")

            def extract_contextual(texts, batch_size=128):
                features = []
                for i in range(0, len(texts), batch_size):
                    batch = texts[i:i + batch_size]
                    # Ensure each text is a string before tokenizing
                    batch = [str(text) for text in batch] # <--- ENSURE STRING
                    inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt").to(self.device)
                    with torch.no_grad():
                        outputs = bert_model(**inputs)
                        embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
                    features.extend(embeddings)
                return np.array(features)

            def extract_syntactic(texts):
                features = []
                for text in texts:
                    # Ensure each text is a string before passing to nlp
                    text = str(text) # <--- ENSURE STRING
                    doc = nlp(text)
                    pos_tags = [token.pos_ for token in doc]
                    features.append([
                        len(doc),
                        len(set(pos_tags)) / len(pos_tags),
                        pos_tags.count('NOUN') / len(pos_tags),
                        pos_tags.count('VERB') / len(pos_tags),
                    ])
                return np.array(features)

            contextual_features = extract_contextual(texts)
            syntactic_features = extract_syntactic(texts)

            np.save(bert_embedding_file, contextual_features) # Save BERT embeddings
            np.save(syntactic_feature_file, syntactic_features) # Save syntactic features
            print("Embeddings calculated and saved.")
            return np.hstack([contextual_features, syntactic_features])

    def prepare_data(self, df, labels, val_split=0.2):
        """Tokenise the texts, build features, and split into train/validation.

        The split is a random permutation (not stratified). Features are
        standardised with statistics computed on the training rows only, so
        no information from the validation rows leaks into the scaling.

        Args:
            df: DataFrame holding the texts in ``self.text_column``.
            labels: Integer labels, one per row of ``df``.
            val_split: Fraction of rows reserved for validation.

        Raises:
            ValueError: If labels or features do not line up with the texts,
                if the feature width differs from the model's, or if either
                split would be empty.
        """
        print("Preparing data...")

        # Coerce to str so the tokenizer sees the same values that
        # extract_features() does.
        texts = [str(text) for text in df[self.text_column].tolist()]
        labels = np.asarray(labels)
        if len(labels) != len(texts):
            raise ValueError("labels must contain exactly one entry per text")

        # Tokenise everything in one call; padding=True pads to the longest text.
        encodings = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')
        input_ids = encodings['input_ids']
        attention_mask = encodings['attention_mask']

        features = self.extract_features(texts)
        if len(features) != len(texts):
            raise ValueError(
                f"Feature rows ({len(features)}) do not match number of texts ({len(texts)}); "
                "the embedding cache is probably stale"
            )
        if features.shape[1] != self.feature_dim:
            raise ValueError(
                f"Feature width {features.shape[1]} does not match the model's feature_dim {self.feature_dim}"
            )

        split_idx = int(len(input_ids) * (1 - val_split))
        indices = np.random.permutation(len(input_ids))

        train_idx = indices[:split_idx]
        val_idx = indices[split_idx:]
        if len(train_idx) == 0 or len(val_idx) == 0:
            raise ValueError("val_split leaves the training or validation set empty")

        # Fit the scaler on the training rows only, then apply it to both splits.
        scaler = StandardScaler()
        train_features = scaler.fit_transform(features[train_idx])
        val_features = scaler.transform(features[val_idx])

        self.train_input_ids = input_ids[train_idx]
        self.train_attention_mask = attention_mask[train_idx]
        self.train_features = torch.tensor(train_features, dtype=torch.float32)
        self.train_labels = torch.tensor(labels[train_idx], dtype=torch.long)
        self.train_texts = [texts[i] for i in train_idx]  # Store training texts

        self.val_input_ids = input_ids[val_idx]
        self.val_attention_mask = attention_mask[val_idx]
        self.val_features = torch.tensor(val_features, dtype=torch.float32)
        self.val_labels = torch.tensor(labels[val_idx], dtype=torch.long)
        self.val_texts = [texts[i] for i in val_idx]  # Store validation texts

        self.create_dataloaders()

    def create_dataloaders(self):
        train_dataset = TensorDataset(self.train_input_ids, self.train_attention_mask, self.train_features, self.train_labels)
        val_dataset = TensorDataset(self.val_input_ids, self.val_attention_mask, self.val_features, self.val_labels)

        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=4,
            prefetch_factor=3,
            persistent_workers=True
        )

        self.val_loader = DataLoader(
            val_dataset,
            batch_size=self.batch_size * 2,  # Validation batch size can be larger
            pin_memory=True,
            num_workers=4,
            persistent_workers=True
        )

    def setup_training(self):
        """Create the optimizer, the one-cycle LR schedule and early stopping.

        Raises:
            ValueError: If "Gradient Accumulation Steps" is not positive.
        """
        accumulation_steps = self.hyperparams["Gradient Accumulation Steps"]
        if accumulation_steps <= 0:
            raise ValueError("Gradient Accumulation Steps must be positive")

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.hyperparams["Learning Rate"],
            weight_decay=self.hyperparams["Weight Decay"],
            betas=(0.9, 0.999)
        )

        # train_epoch() steps the scheduler once per full accumulation group
        # AND once more for a trailing partial group, i.e. ceil(batches /
        # accumulation_steps) times per epoch. Floor division here would make
        # the schedule too short and OneCycleLR would raise when stepped past
        # its total (or be given zero steps for a small dataset).
        num_batches = len(self.train_loader)
        steps_per_epoch = (num_batches + accumulation_steps - 1) // accumulation_steps

        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=self.hyperparams["Learning Rate"] * 10,  # peak LR is 10x the configured rate
            steps_per_epoch=steps_per_epoch,
            epochs=self.hyperparams["Epochs"],
            pct_start=0.1
        )

        self.early_stopping = EarlyStopping(patience=self.hyperparams["Early Stopping Patience"])

    def train(self):
        """Run the train/validate loop and return the per-epoch metric history.

        After every epoch the weights are written to ``best_model.pt`` when
        the validation loss is the lowest seen so far, and the early-stopping
        helper decides whether to end the loop. Error analysis runs once
        after the loop finishes.

        Returns:
            Dict with the lists ``train_loss``, ``val_loss``, ``train_acc``
            and ``val_acc`` (one entry per completed epoch).
        """
        num_epochs = self.hyperparams["Epochs"]
        grad_accum_steps = self.hyperparams["Gradient Accumulation Steps"]

        # Clear previous predictions
        self.model.all_preds = []
        self.model.all_labels = []
        self.model.all_texts = []

        lowest_val_loss = float('inf')
        history = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': []
        }

        for epoch_number in range(1, num_epochs + 1):
            train_stats = self.train_epoch(grad_accum_steps)
            val_stats = self.validate()

            history['train_loss'].append(train_stats['loss'])
            history['val_loss'].append(val_stats['loss'])
            history['train_acc'].append(train_stats['acc'])
            history['val_acc'].append(val_stats['acc'])

            print(f"Epoch {epoch_number}/{num_epochs}")
            print(f"Train Loss: {train_stats['loss']:.4f}, Acc: {train_stats['acc']:.4f}")
            print(f"Val Loss: {val_stats['loss']:.4f}, Acc: {val_stats['acc']:.4f}")

            current_val_loss = val_stats['loss']

            # Checkpoint whenever the validation loss reaches a new minimum.
            if current_val_loss < lowest_val_loss:
                lowest_val_loss = current_val_loss
                torch.save(self.model.state_dict(), 'best_model.pt')

            if self.early_stopping(current_val_loss):
                print("Early stopping triggered")
                break

        # Runs once, after the final epoch.
        self.perform_error_analysis()
        return history

    def train_epoch(self, accumulation_steps):
        """Train for one pass over ``self.train_loader``.

        Uses mixed precision with gradient accumulation: the optimizer and
        LR scheduler step once every ``accumulation_steps`` batches, and once
        more for a trailing partial group.

        Args:
            accumulation_steps: Number of batches whose gradients are summed
                before each optimizer step.

        Returns:
            Dict with ``loss`` (mean per-batch loss) and ``acc`` (fraction of
            correctly classified samples).
        """
        # Release cached blocks once per epoch. Doing this after every batch
        # (as before) forces the allocator to re-request memory each step
        # and only slows training down.
        torch.cuda.empty_cache()
        self.model.train()
        total_loss = 0
        correct = 0
        total = 0
        num_batches = len(self.train_loader)

        for batch_idx, (input_ids, attention_mask, features, target) in enumerate(tqdm(self.train_loader, desc="Training", leave=False, dynamic_ncols=True, position=0)):
            input_ids = input_ids.to(self.device, non_blocking=True)
            attention_mask = attention_mask.to(self.device, non_blocking=True)
            features = features.to(self.device, non_blocking=True)
            target = target.to(self.device, non_blocking=True)

            with autocast():
                output = self.model(input_ids=input_ids, attention_mask=attention_mask, features=features)
                loss = F.cross_entropy(output, target, label_smoothing=self.hyperparams["Label Smoothing"])
                loss = loss / accumulation_steps  # Normalize loss for accumulation

            self.scaler.scale(loss).backward()

            is_last_batch = (batch_idx + 1) == num_batches
            if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
                # Gradients are still multiplied by the loss scale at this
                # point; unscale them first so max_norm applies to the true
                # gradient norm.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad(set_to_none=True)
                self.scheduler.step()

            total_loss += loss.item() * accumulation_steps  # Scale back for correct averaging
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()
            total += target.size(0)

        return {'loss': total_loss / num_batches, 'acc': correct / total}

    def validate(self):
        """Evaluate on ``self.val_loader`` and record predictions for analysis.

        The prediction buffers on the model (``all_preds``, ``all_labels``,
        ``all_texts``) are reset at the start of every call, so after
        training they describe the most recent validation pass only instead
        of a mixture of all epochs.

        Returns:
            Dict with ``loss`` (mean per-batch loss, without label smoothing)
            and ``acc`` (fraction of correctly classified samples).
        """
        torch.cuda.empty_cache()
        self.model.eval()
        total_loss = 0
        correct = 0
        total = 0

        # Start from empty buffers so repeated calls do not accumulate.
        self.model.all_preds = []
        self.model.all_labels = []
        self.model.all_texts = []

        with torch.no_grad():
            for input_ids, attention_mask, features, target in tqdm(self.val_loader, desc="Validating", leave=False, dynamic_ncols=True, position=0):
                input_ids = input_ids.to(self.device, non_blocking=True)
                attention_mask = attention_mask.to(self.device, non_blocking=True)
                features = features.to(self.device, non_blocking=True)
                target = target.to(self.device, non_blocking=True)

                with autocast():
                    output = self.model(input_ids=input_ids, attention_mask=attention_mask, features=features)
                    loss = F.cross_entropy(output, target)

                total_loss += loss.item()
                pred = output.argmax(dim=1)
                correct += pred.eq(target).sum().item()

                # The validation loader is not shuffled, so the samples seen
                # so far form a running offset into self.val_texts. This
                # avoids hard-coding the loader's batch size here.
                num_in_batch = target.size(0)
                batch_texts = self.val_texts[total:total + num_in_batch]
                total += num_in_batch

                # Store predictions and labels for error analysis
                self.model.all_preds.extend(pred.cpu().numpy())
                self.model.all_labels.extend(target.cpu().numpy())
                self.model.all_texts.extend(batch_texts)  # Store the corresponding texts

        return {'loss': total_loss / len(self.val_loader), 'acc': correct / total}
    def perform_error_analysis(self):
        """Report on the predictions stored on the model.

        Prints a classification report, shows a confusion-matrix heatmap and
        prints up to five misclassified texts with their true and predicted
        labels.
        """
        true_labels = self.model.all_labels
        predicted_labels = self.model.all_preds
        texts = self.model.all_texts

        # Per-class precision/recall/F1, printed with four decimals.
        print("Classification Report:\n", classification_report(true_labels, predicted_labels, digits=4))

        # Confusion matrix rendered as a heatmap.
        matrix = confusion_matrix(true_labels, predicted_labels)
        plt.figure(figsize=(8, 6))
        sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues')
        plt.xlabel('Predicted Labels')
        plt.ylabel('True Labels')
        plt.title('Confusion Matrix')
        plt.show()

        # Positions where the prediction differs from the true label.
        is_wrong = np.array(predicted_labels) != np.array(true_labels)
        wrong_positions = np.where(is_wrong)[0]

        print("\nMisclassified Text Examples:")
        for position in wrong_positions[:5]:  # Display up to 5 examples
            print(f"True Label: {true_labels[position]}, Predicted Label: {predicted_labels[position]}")
            print(f"Text: {texts[position]}\n")

        # Attribution analysis of these examples is currently disabled;
        # explain_instance() is a no-op stub.
    def explain_instance(self, text, true_label, model):
       #Explain instance no longer prints so is left empty
       pass

class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss. It returns
    ``True`` once ``patience`` calls in a row have failed to bring the loss
    down to ``best_loss - min_delta`` or lower.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum decrease that counts as an improvement.

    Attributes are plain state: ``counter`` (current run of non-improving
    calls), ``best_loss`` (lowest accepted loss so far, ``None`` before the
    first valid value) and ``early_stop`` (the most recent return value).
    """

    def __init__(self, patience=3, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and return whether training should stop."""
        # NaN is the only value that is not equal to itself. A NaN loss must
        # count as "no improvement"; otherwise it would reset the patience
        # counter and poison best_loss for every later comparison.
        is_nan = val_loss != val_loss
        improved = not is_nan and (
            self.best_loss is None
            or val_loss <= self.best_loss - self.min_delta
        )

        if improved:
            self.best_loss = val_loss
            self.counter = 0
            self.early_stop = False
            return False

        self.counter += 1
        # Keep the public flag in sync with the return value (it was
        # previously initialised but never updated).
        self.early_stop = self.counter >= self.patience
        return self.early_stop


if __name__ == "__main__":
    # Name of the column that holds the review text
    text_column = 'text'

    # Read the dataset and force the text column to str
    df = pd.read_csv("/content/modified_df_11.csv")
    df['text'] = df['text'].astype(str)

    # Ratings are shifted down by one to obtain the class labels
    labels = np.array(df["rating"].tolist()) - 1  # Adjust as needed

    # Settings the trainer validates and reads
    training_settings = {
        "Epochs": 1,  # Reduced for faster example
        "Batch Size": 70,
        "Learning Rate": 1e-5,
        "Dropout": 0.3,
        "Weight Decay": 0.1,
        "Label Smoothing": 0.1,
        "Early Stopping Patience": 3,
        "Gradient Accumulation Steps": 4,
    }
    # Descriptive entries carried along with the run configuration
    run_description = {
        "Optimizer": "AdamW",
        "Scheduler": "OneCycleLR",
        "Feature Dimension": 0,
        "Model": "BERT with Multiple Embeddings",
    }
    hyperparams = {**training_settings, **run_description}

    # Cached features are read from / written to this directory. To force a
    # recomputation, remove the directory (e.g. with shutil.rmtree) before
    # constructing the trainer.
    embedding_dir = "my_embeddings"

    # Build the trainer and run training
    trainer = GPUOptimizedTrainer(df, text_column, labels, hyperparams, embedding_dir=embedding_dir)
    metrics = trainer.train()

    # Print Dataframe
    print(df)

    # Plot results: loss on the left, accuracy on the right
    import matplotlib.pyplot as plt

    panels = (
        (1, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Training and Validation Loss', 'Loss'),
        (2, 'train_acc', 'val_acc', 'Train Acc', 'Val Acc', 'Training and Validation Accuracy', 'Accuracy'),
    )

    plt.figure(figsize=(12, 5))
    for position, train_key, val_key, train_label, val_label, title, y_label in panels:
        plt.subplot(1, 2, position)
        plt.plot(metrics[train_key], label=train_label)
        plt.plot(metrics[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()