In [3]:
# pip install torch transformers pandas numpy matplotlib seaborn scikit-learn tqdm


In [5]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, AutoConfig
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm
import os

# ======================== Enhanced Core Components ========================

# 1. Data Preparation & SRL Processing
class SRLProcessor:
    """Inline semantic-role markers into sentence text.

    Each sentence record carries ``srl_raw['words']`` and a list of
    ``srl_raw['verbs']`` whose ``tags`` are per-word BIO labels
    (e.g. ``B-ARG0``, ``I-ARGM-MNR``). Words whose role appears in
    ``role_tags`` are wrapped as ``[name] word [/name]``.
    """

    def __init__(self):
        # SRL role label -> marker name emitted in the output text.
        self.role_tags = {
            'V': 'verb', 'ARG0': 'subject', 'ARG1': 'object',
            'ARG2': 'indirect-object', 'ARGM-MNR': 'manner'
        }

    def process_srl(self, srl_data):
        """Process SRL data and add semantic role labels to text.

        Args:
            srl_data: iterable of sentence dicts. Entries lacking
                ``srl_raw`` or ``srl_raw['words']`` are skipped; verb
                entries lacking ``tags`` are ignored.

        Returns:
            The tagged sentences joined by single spaces ('' if none).
        """
        augmented = []
        for sent_info in srl_data:
            if 'srl_raw' not in sent_info or 'words' not in sent_info['srl_raw']:
                continue  # Skip malformed entries

            words = sent_info['srl_raw']['words']
            tags = [[] for _ in words]

            for verb_info in sent_info['srl_raw'].get('verbs', []):
                if 'tags' not in verb_info:
                    continue
                # zip stops at the shorter sequence, so a tag list longer
                # than the word list cannot index past the end.
                for word_roles, tag in zip(tags, verb_info['tags']):
                    if tag == 'O' or '-' not in tag:
                        continue
                    # Split on the first '-' only: 'B-ARGM-MNR' must yield
                    # 'ARGM-MNR' (not 'ARGM') to match role_tags.
                    role = tag.split('-', 1)[1]
                    if role in self.role_tags:  # Only keep roles we map
                        word_roles.append(role)

            # Open markers before the word, close them in reverse order after.
            tagged_sentence = []
            for word, roles in zip(words, tags):
                tagged_sentence.extend(f"[{self.role_tags[role]}]" for role in roles)
                tagged_sentence.append(word)
                tagged_sentence.extend(f"[/{self.role_tags[role]}]" for role in reversed(roles))

            augmented.append(' '.join(tagged_sentence))

        return ' '.join(augmented)

def load_srl_dataset(json_path):
    """Load dataset from JSON and apply SRL processing.

    The file is expected to hold a JSON object whose 'samples' key is a list
    of dicts. Each dict may provide 'CVE_srl', 'Technique_srl', 'label' and
    'role_match_score'; missing keys default to [], [], 0 and 0.0.

    Args:
        json_path: path to the UTF-8 encoded JSON file.

    Returns:
        DataFrame with columns CVE_text, Technique_text, label and
        role_score. An empty DataFrame is returned in three cases: the file
        cannot be opened, it cannot be decoded or parsed, or its top level
        is not a JSON object. Samples that raise during processing are
        reported and skipped.
    """
    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) can fail on, or silently garble, non-ASCII text.
        with open(json_path, encoding='utf-8') as f:
            payload = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
        print(f"Error loading JSON data: {e}")
        return pd.DataFrame()

    if not isinstance(payload, dict):
        # Previously an uncaught AttributeError on `.get`; report it like
        # any other unreadable input instead.
        print(f"Error loading JSON data: expected a JSON object, got {type(payload).__name__}")
        return pd.DataFrame()
    data = payload.get('samples', [])

    processor = SRLProcessor()
    samples = []

    for sample in tqdm(data, desc="Processing SRL data"):
        try:
            samples.append({
                'CVE_text': processor.process_srl(sample.get('CVE_srl', [])),
                'Technique_text': processor.process_srl(sample.get('Technique_srl', [])),
                'label': sample.get('label', 0),
                'role_score': sample.get('role_match_score', 0.0)
            })
        except Exception as e:
            # Best-effort: one malformed sample must not abort the load.
            print(f"Error processing sample: {e}")
            continue

    df = pd.DataFrame(samples)
    print(f"Loaded {len(df)} valid samples")
    return df

# 2. Dataset Class with Hinge Labels
class HingeSiameseDataset(Dataset):
    """Paired CVE/technique texts tokenized for a siamese hinge-loss model.

    Args:
        df: DataFrame with 'CVE_text', 'Technique_text', 'label' (0/1) and
            'role_score' columns. Rows are addressed positionally (iloc).
        tokenizer: callable with the Hugging Face tokenizer call signature
            returning 'input_ids' and 'attention_mask' tensors.
        max_len: pad/truncate length applied to each text.
    """

    def __init__(self, df, tokenizer, max_len=128):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len

        if len(df) > 0:
            # Min-max normalise role scores to [0, 1] within this DataFrame;
            # the epsilon avoids division by zero when all scores are equal.
            role_min = df['role_score'].min()
            role_max = df['role_score'].max()
            self.role_weights = (df['role_score'] - role_min) / (role_max - role_min + 1e-8)

            # Convert 0/1 labels to -1/+1 for the hinge loss.
            self.labels = 2 * df['label'].values - 1
        else:
            # Explicit dtype: a dtype-less empty Series is object-typed
            # (or warns, depending on the pandas version).
            self.role_weights = pd.Series(dtype=float)
            self.labels = np.array([])

    def __len__(self):
        return len(self.df)

    def _encode(self, text):
        """Tokenize one text and drop the leading batch axis -> ([max_len], [max_len])."""
        enc = self.tokenizer(
            text,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # squeeze(0) rather than squeeze(): keep the sequence axis even when
        # max_len == 1, otherwise collation yields [B] instead of [B, 1].
        return enc['input_ids'].squeeze(0), enc['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        cve_ids, cve_mask = self._encode(row['CVE_text'])
        tech_ids, tech_mask = self._encode(row['Technique_text'])

        return {
            'cve_input_ids': cve_ids,
            'cve_attention_mask': cve_mask,
            'tech_input_ids': tech_ids,
            'tech_attention_mask': tech_mask,
            'labels': torch.tensor(self.labels[idx], dtype=torch.float),
            'role_weights': torch.tensor(self.role_weights.iloc[idx], dtype=torch.float),
            'CVE_text': row['CVE_text'],
            'Technique_text': row['Technique_text']
        }

# 3. Enhanced Model Architecture with Contrastive Learning
class EnhancedContrastiveSRLModel(nn.Module):
    """Siamese encoder with soft-alignment, masked pooling and a pair classifier.

    Both texts go through the same encoder. Each token sequence is
    soft-aligned against the other, averaged with its aligned view,
    mask-mean-pooled and projected to 256 dimensions. The classifier scores
    [u, v, |u - v|, u * v] with one logit, and a cosine-embedding loss on
    (u, v) is returned alongside when labels are given.

    Args:
        model_name: identifier passed to ``AutoModel.from_pretrained``.
        hidden_size: encoder hidden width; must match the loaded model.
        margin: margin of ``nn.CosineEmbeddingLoss``.
        contrastive_weight: stored for callers to scale the contrastive
            loss; ``forward`` returns the loss unscaled.
    """

    def __init__(self, model_name="bert-base-uncased", hidden_size=768, margin=0.4, contrastive_weight=0.3):
        super().__init__()
        try:
            self.bert = AutoModel.from_pretrained(model_name)
        except Exception as e:
            print(f"Error loading pretrained model: {e}")
            raise RuntimeError("Failed to initialize the model") from e

        # Shared projection applied to both pooled sequences.
        self.srl_proj = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.LayerNorm(256)
        )

        # Classifier over [u, v, |u - v|, u * v] -> a single logit.
        self.classifier = nn.Sequential(
            nn.Linear(256 * 4, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 1)
        )

        self.contrastive_loss = nn.CosineEmbeddingLoss(margin=margin)
        self.contrastive_weight = contrastive_weight

    def soft_align_attention(self, a, b, mask_a, mask_b):
        """
        Compute soft-alignment between two sequences:
          a: [batch, seq_len_a, hidden]
          b: [batch, seq_len_b, hidden]
          mask_a: [batch, seq_len_a]  (0 marks padding)
          mask_b: [batch, seq_len_b]  (0 marks padding)
        Returns:
          aligned_a: [batch, seq_len_a, hidden], weighted sum of b for each token in a
          aligned_b: [batch, seq_len_b, hidden], weighted sum of a for each token in b
        """
        similarity = torch.bmm(a, b.transpose(1, 2))  # [B, L_a, L_b]
        # Most negative finite value of the dtype: -1e9 overflows in fp16.
        neg_fill = torch.finfo(similarity.dtype).min

        # a-tokens attend over b-tokens: hide padded b positions (last axis).
        attn_weights_a = F.softmax(
            similarity.masked_fill(mask_b.unsqueeze(1) == 0, neg_fill), dim=2
        )

        # b-tokens attend over a-tokens: the transposed scores are
        # [B, L_b, L_a], so a's mask must broadcast along the last axis.
        # (Expanding mask_a to [B, L_a, L_b] masked the wrong axis and
        # crashed whenever L_a != L_b.)
        attn_weights_b = F.softmax(
            similarity.transpose(1, 2).masked_fill(mask_a.unsqueeze(1) == 0, neg_fill), dim=2
        )

        aligned_a = torch.bmm(attn_weights_a, b)  # [B, L_a, hidden]
        aligned_b = torch.bmm(attn_weights_b, a)  # [B, L_b, hidden]
        return aligned_a, aligned_b

    def pooling(self, token_embeddings, mask):
        """
        Apply masked average pooling to token embeddings.
          token_embeddings: [B, L, hidden]
          mask: [B, L] with 1 for valid tokens and 0 for padding.
        Returns:
          pooled embedding [B, hidden]; all-zero rows for fully masked inputs.
        """
        mask = mask.unsqueeze(2).float()  # [B, L, 1]
        summed = torch.sum(token_embeddings * mask, dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)

        # Zero out rows whose mask is all zeros.
        valid_counts = (counts > 1e-8).float()
        return (summed / counts) * valid_counts

    def forward(self, cve_input, tech_input, labels=None):
        """Score CVE/technique pairs.

        Args:
            cve_input, tech_input: keyword arguments for the encoder; each
                must contain 'attention_mask'.
            labels: optional [B] tensor; values > 0 are treated as positive
                pairs for the contrastive loss.

        Returns:
            (logits [B], contrastive loss). The loss is a zero scalar when
            labels is None.
        """
        cve_seq = self.bert(**cve_input).last_hidden_state     # [B, L_cve, hidden]
        tech_seq = self.bert(**tech_input).last_hidden_state   # [B, L_tech, hidden]

        cve_mask = cve_input['attention_mask']     # [B, L_cve]
        tech_mask = tech_input['attention_mask']   # [B, L_tech]

        # Cross-align the sequences, then average each with its aligned view.
        aligned_cve, aligned_tech = self.soft_align_attention(cve_seq, tech_seq, cve_mask, tech_mask)
        cve_combined = (cve_seq + aligned_cve) / 2.0
        tech_combined = (tech_seq + aligned_tech) / 2.0

        cve_emb = self.srl_proj(self.pooling(cve_combined, cve_mask))     # [B, 256]
        tech_emb = self.srl_proj(self.pooling(tech_combined, tech_mask))  # [B, 256]

        diff = torch.abs(cve_emb - tech_emb)
        prod = cve_emb * tech_emb
        combined_features = torch.cat([cve_emb, tech_emb, diff, prod], dim=1)  # [B, 256*4]

        # squeeze(-1), not squeeze(): a batch of one must stay shape [1],
        # not collapse to a 0-d tensor that callers cannot index.
        classifier_out = self.classifier(combined_features).squeeze(-1)

        cont_loss = torch.tensor(0.0, device=classifier_out.device)
        if labels is not None:
            contrastive_labels = torch.where(labels > 0, 1.0, -1.0).to(labels.device)
            cont_loss = self.contrastive_loss(cve_emb, tech_emb, contrastive_labels)

        # Always return both values to keep the return structure consistent.
        return classifier_out, cont_loss

    def get_embeddings(self, input_dict):
        """Project one batch of texts to 256-d numpy embeddings.

        Runs under ``torch.no_grad`` and skips soft-alignment, because there
        is no paired sequence. The module's train/eval mode is not changed,
        so callers should call ``model.eval()`` first.
        """
        with torch.no_grad():
            outputs = self.bert(**input_dict)
            pooled = self.pooling(outputs.last_hidden_state, input_dict['attention_mask'])
            return self.srl_proj(pooled).cpu().numpy()

# 4. Enhanced Hinge Loss with Weighting
class WeightedHingeLoss(nn.Module):
    """Batch-mean hinge loss with per-sample importance scaling.

    For logits f, targets y in {-1, +1} and weights w, the loss is
    mean(max(0, margin - y * f) * (1 + w)).
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, outputs, labels, weights):
        scale = weights + 1
        raw = (self.margin - outputs * labels).clamp(min=0)
        return torch.mean(raw * scale)

# 6. Data Splitting with Stratification
def get_splits_for_model(df, test_size=0.15, val_size=0.15, random_state=42):
    """Split data into train/val/test with stratification on 'label'.

    ``test_size`` and ``val_size`` are fractions of the whole DataFrame.

    Returns:
        (train_df, val_df, test_df)

    Raises:
        ValueError: if ``df`` is empty.
    """
    if df.empty:
        raise ValueError("DataFrame is empty, cannot split")

    train_val_df, test_df = train_test_split(
        df, test_size=test_size, random_state=random_state, stratify=df['label']
    )

    # val_size refers to the full dataset, so rescale it to the remainder
    # left after the test split.
    remaining_fraction = 1 - test_size
    train_df, val_df = train_test_split(
        train_val_df,
        test_size=val_size / remaining_fraction,
        random_state=random_state,
        stratify=train_val_df['label'],
    )

    print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")
    train_counts = train_df['label'].value_counts().to_dict()
    val_counts = val_df['label'].value_counts().to_dict()
    test_counts = test_df['label'].value_counts().to_dict()
    print(f"Train label distribution: {train_counts}")
    print(f"Val label distribution: {val_counts}")
    print(f"Test label distribution: {test_counts}")

    return train_df, val_df, test_df

# 7. Enhanced Training with Contrastive Loss
def train_hinge_model(train_df, val_df=None, epochs=10, batch_size=16,
                      margin=1.0, patience=3, lr=2e-5, contrastive_weight=0.4):
    """Train model with contrastive learning objective using the enhanced model with soft-align attention.

    Per-batch objective: WeightedHingeLoss(logits) + contrastive_weight *
    cosine-embedding loss.

    Args:
        train_df: training pairs (columns required by HingeSiameseDataset).
        val_df: optional validation pairs. When given, they enable
            checkpointing on validation accuracy and early stopping.
        epochs: maximum number of epochs.
        batch_size: DataLoader batch size for train and validation.
        margin: hinge-loss margin.
        patience: consecutive epochs without a validation-accuracy
            improvement before stopping.
        lr: AdamW learning rate.
        contrastive_weight: weight of the contrastive term.

    Returns:
        The trained model. With a validation set, it carries the weights of
        the best validation epoch, which are also saved to 'best_model.pth'.
        Unless training stopped early, the last-epoch weights are saved to
        'final_model.pth'. Curves are written via plot_training_history.
    """
    best_path = "best_model.pth"
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model = EnhancedContrastiveSRLModel(contrastive_weight=contrastive_weight).to(device)

    train_dataset = HingeSiameseDataset(train_df, tokenizer)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    val_loader = None
    if val_df is not None:
        val_dataset = HingeSiameseDataset(val_df, tokenizer)
        val_loader = DataLoader(val_dataset, batch_size=batch_size)

    optimizer = optim.AdamW(model.parameters(), lr=lr)
    criterion = WeightedHingeLoss(margin=margin)

    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    # Start at -inf (not 0) so the first validated epoch is always
    # checkpointed. With 0, a run stuck at 0 accuracy never wrote
    # best_model.pth, and early stopping then tried to load it.
    best_val_acc = float('-inf')
    have_checkpoint = False
    patience_counter = 0
    early_stopped = False

    for epoch in range(epochs):
        model.train()
        train_loss = cont_loss_total = 0
        train_correct = train_total = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs} - Training"):
            cve_input = {
                'input_ids': batch['cve_input_ids'].to(device),
                'attention_mask': batch['cve_attention_mask'].to(device)
            }
            tech_input = {
                'input_ids': batch['tech_input_ids'].to(device),
                'attention_mask': batch['tech_attention_mask'].to(device)
            }
            labels = batch['labels'].to(device)
            weights = batch['role_weights'].to(device)

            optimizer.zero_grad()
            # The model returns (classifier logits, unscaled contrastive loss).
            outputs, cont_loss = model(cve_input, tech_input, labels)

            hinge_loss = criterion(outputs, labels, weights)
            total_loss = hinge_loss + cont_loss * model.contrastive_weight

            total_loss.backward()
            optimizer.step()

            train_loss += total_loss.item()
            cont_loss_total += cont_loss.item()
            preds = torch.sign(outputs)
            train_correct += (preds == labels).sum().item()
            train_total += labels.size(0)

        avg_train_loss = train_loss / len(train_loader)
        avg_cont_loss = cont_loss_total / len(train_loader)
        train_acc = train_correct / train_total
        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_acc)

        print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Contrastive Loss: {avg_cont_loss:.4f}, Train Acc: {train_acc:.4f}")

        if val_loader is not None:
            val_acc, val_loss = validate_model(model, val_loader, criterion, device)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

            print(f"Epoch {epoch+1}/{epochs} - Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                torch.save(model.state_dict(), best_path)
                have_checkpoint = True
                print(f"Saved new best model with validation accuracy: {best_val_acc:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    early_stopped = True
                    break

    if not early_stopped:
        torch.save(model.state_dict(), "final_model.pth")

    # Return the best validated weights, not whatever the last epoch left
    # behind. This covers both the early-stop and the full-run case.
    if have_checkpoint:
        model.load_state_dict(torch.load(best_path, map_location=device))

    plot_training_history(history)
    return model


def validate_model(model, val_loader, criterion, device):
    """Validate model on validation set.

    Uses the same objective as training (hinge loss plus the contrastive
    term scaled by ``model.contrastive_weight``) under ``torch.no_grad``.
    The model is left in eval mode.

    Returns:
        (accuracy, mean per-batch loss). A prediction is sign(logit), so a
        logit of exactly 0 counts as wrong.
    """
    model.eval()
    loss_sum = 0
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation"):
            cve_input = {name: batch['cve_' + name].to(device) for name in ('input_ids', 'attention_mask')}
            tech_input = {name: batch['tech_' + name].to(device) for name in ('input_ids', 'attention_mask')}
            labels = batch['labels'].to(device)
            weights = batch['role_weights'].to(device)

            outputs, cont_loss = model(cve_input, tech_input, labels)
            batch_loss = criterion(outputs, labels, weights) + cont_loss * model.contrastive_weight
            loss_sum += batch_loss.item()

            hits += torch.sign(outputs).eq(labels).sum().item()
            seen += labels.size(0)

    accuracy = hits / seen
    return accuracy, loss_sum / len(val_loader)

def plot_training_history(history):
    """Plot training and validation metrics to 'training_history.png'.

    Args:
        history: dict with 'train_loss' and 'train_acc' lists. 'val_loss'
            and 'val_acc' are drawn only when present and non-empty.

    Loss goes in the left panel and accuracy in the right; the figure is
    closed after saving.
    """
    panels = [
        ('train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss', 'Loss Curves'),
        ('train_acc', 'val_acc', 'Train Accuracy', 'Val Accuracy', 'Accuracy', 'Accuracy Curves'),
    ]

    plt.figure(figsize=(12, 5))
    for slot, (train_key, val_key, train_name, val_name, y_name, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.plot(history[train_key], label=train_name)
        val_curve = history.get(val_key)
        if val_curve:
            plt.plot(val_curve, label=val_name)
        plt.xlabel('Epoch')
        plt.ylabel(y_name)
        plt.legend()
        plt.title(title)

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.close()

# 8. Modified Evaluation without FAISS - Focus on Test Results
def _print_prediction_examples(title, rows):
    """Print label, prediction, score and 100-char text excerpts for each row."""
    print(title)
    for _, row in rows.iterrows():
        print(f"True label: {int(row['true_label'])}, Predicted: {int(row['predicted_label'])}, Confidence: {row['confidence_score']:.4f}")
        print(f"CVE excerpt: {row['cve_text'][:100]}...")
        print(f"Technique excerpt: {row['tech_text'][:100]}...")
        print("-" * 50)


def evaluate_model(model, test_loader, device):
    """Evaluate model and output detailed results on test set.

    The function prints accuracy, per-class precision/recall/F1, example
    correct and incorrect predictions, and an error summary. It writes these
    files to the working directory:
      - 'test_results_detailed.csv'
      - 'confusion_matrix.png'
      - 'confidence_distribution.png'

    Args:
        model: callable as ``model(cve_input, tech_input, labels)``
            returning ``(logits, contrastive_loss)``.
        test_loader: DataLoader over HingeSiameseDataset (labels in {-1, +1}).
        device: device holding the model.

    Returns:
        (test_acc, test_results_df). Labels in the DataFrame are mapped back
        to 0/1.
    """
    model.eval()

    all_preds = []
    all_true = []
    all_outputs = []
    test_correct = test_total = 0

    # Per-sample records for the detailed CSV and error analysis.
    test_samples = {
        'cve_text': [],
        'tech_text': [],
        'true_label': [],
        'predicted_label': [],
        'confidence_score': []
    }

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating on test set"):
            cve_input = {
                'input_ids': batch['cve_input_ids'].to(device),
                'attention_mask': batch['cve_attention_mask'].to(device)
            }
            tech_input = {
                'input_ids': batch['tech_input_ids'].to(device),
                'attention_mask': batch['tech_attention_mask'].to(device)
            }
            labels = batch['labels'].to(device)

            outputs, _ = model(cve_input, tech_input, labels)
            # Flatten so a single-sample batch (possibly squeezed to a 0-d
            # tensor by the model) can still be handled per sample.
            outputs = outputs.reshape(-1)

            preds = torch.sign(outputs)
            test_correct += (preds == labels).sum().item()
            test_total += labels.size(0)

            # One device->host copy per batch instead of one per sample.
            true_np = labels.cpu().numpy()
            pred_np = preds.cpu().numpy()
            out_np = outputs.cpu().numpy()
            all_true.extend(true_np)
            all_preds.extend(pred_np)
            all_outputs.extend(out_np)

            test_samples['cve_text'].extend(batch['CVE_text'])
            test_samples['tech_text'].extend(batch['Technique_text'])
            test_samples['true_label'].extend(float(v) for v in true_np)
            test_samples['predicted_label'].extend(float(v) for v in pred_np)
            test_samples['confidence_score'].extend(float(v) for v in out_np)

    test_acc = test_correct / test_total
    print(f"Test Accuracy: {test_acc:.4f}")

    test_results_df = pd.DataFrame(test_samples)
    # Convert labels from -1/1 to 0/1 for readability.
    for col in ('true_label', 'predicted_label'):
        test_results_df[col] = (test_results_df[col] + 1) / 2
    test_results_df['correct'] = test_results_df['true_label'] == test_results_df['predicted_label']

    test_results_df.to_csv('test_results_detailed.csv', index=False)
    print("Saved detailed test results to 'test_results_detailed.csv'")

    all_true_01 = [(label + 1) / 2 for label in all_true]
    all_preds_01 = [(pred + 1) / 2 for pred in all_preds]

    # Pin the label set so that, even when a class is absent or never
    # predicted:
    #   - the report always has both classes;
    #   - the confusion matrix stays 2x2, matching its tick labels.
    class_labels = [0.0, 1.0]
    class_report = classification_report(
        all_true_01, all_preds_01, labels=class_labels, output_dict=True, zero_division=0
    )
    print("\nClassification Report:")
    for label, metrics in class_report.items():
        if label in ['0.0', '1.0']:
            print(f"Class {label}: Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}, F1: {metrics['f1-score']:.4f}")

    cm = confusion_matrix(all_true_01, all_preds_01, labels=class_labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=['Negative', 'Positive'], yticklabels=['Negative', 'Positive'])
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.savefig('confusion_matrix.png')
    plt.close()

    correct_mask = test_results_df['correct']
    _print_prediction_examples("\n=== Sample Correct Predictions ===", test_results_df[correct_mask].head(5))
    _print_prediction_examples("\n=== Sample Incorrect Predictions ===", test_results_df[~correct_mask].head(5))

    plt.figure(figsize=(10, 6))
    plt.hist(test_results_df[correct_mask]['confidence_score'], alpha=0.7, label='Correct predictions', bins=20)
    plt.hist(test_results_df[~correct_mask]['confidence_score'], alpha=0.7, label='Incorrect predictions', bins=20)
    plt.title('Distribution of Model Confidence Scores')
    plt.xlabel('Confidence Score')
    plt.ylabel('Count')
    plt.legend()
    plt.savefig('confidence_distribution.png')
    plt.close()

    n_total = len(test_results_df)
    n_correct = int(correct_mask.sum())
    n_incorrect = n_total - n_correct
    print("\n=== Error Analysis Summary ===")
    print(f"Total test samples: {n_total}")
    print(f"Correct predictions: {n_correct} ({n_correct/n_total*100:.2f}%)")
    print(f"Incorrect predictions: {n_incorrect} ({n_incorrect/n_total*100:.2f}%)")

    false_positives = test_results_df[(test_results_df['true_label'] == 0) & (test_results_df['predicted_label'] == 1)]
    false_negatives = test_results_df[(test_results_df['true_label'] == 1) & (test_results_df['predicted_label'] == 0)]

    print(f"False positives: {len(false_positives)} ({len(false_positives)/n_total*100:.2f}%)")
    print(f"False negatives: {len(false_negatives)} ({len(false_negatives)/n_total*100:.2f}%)")

    return test_acc, test_results_df

# Check if we should load an existing model
def _visualize_embeddings(model, test_df, tokenizer, device, max_points=500):
    """Save a t-SNE scatter of CVE and technique embeddings to 'embeddings_visualization.png'.

    Up to ``max_points`` test pairs are sampled and each side is encoded with
    ``model.get_embeddings``. In the plot:
      - colour shows the pair label (red negative, blue positive);
      - marker shows the side (circle CVE, triangle technique).
    """
    # t-SNE can be slow on large inputs, so visualize a subset.
    vis_sample = test_df.sample(min(max_points, len(test_df))).reset_index(drop=True)
    vis_loader = DataLoader(HingeSiameseDataset(vis_sample, tokenizer), batch_size=16)

    cve_embeddings, tech_embeddings, labels = [], [], []
    model.eval()
    with torch.no_grad():
        for batch in tqdm(vis_loader, desc="Extracting embeddings"):
            cve_input = {
                'input_ids': batch['cve_input_ids'].to(device),
                'attention_mask': batch['cve_attention_mask'].to(device)
            }
            tech_input = {
                'input_ids': batch['tech_input_ids'].to(device),
                'attention_mask': batch['tech_attention_mask'].to(device)
            }
            cve_embeddings.extend(model.get_embeddings(cve_input))
            tech_embeddings.extend(model.get_embeddings(tech_input))
            labels.extend(batch['labels'].numpy())

    cve_embeddings = np.array(cve_embeddings)
    tech_embeddings = np.array(tech_embeddings)
    labels = np.array(labels)

    print("Applying t-SNE for visualization...")
    combined_embeddings = np.vstack([cve_embeddings, tech_embeddings])
    combined_labels = np.concatenate([labels, labels])
    # Type indicator: 0 for CVE, 1 for Technique.
    types = np.concatenate([np.zeros(len(cve_embeddings)), np.ones(len(tech_embeddings))])

    # scikit-learn requires perplexity < n_samples; the fixed value of 30
    # failed for test sets with fewer than 16 pairs.
    perplexity = min(30, len(combined_embeddings) - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    reduced_embeddings = tsne.fit_transform(combined_embeddings)

    plt.figure(figsize=(12, 10))
    binary_labels = (combined_labels + 1) / 2  # -1/+1 -> 0/1
    markers = {0: 'o', 1: '^'}      # circle for CVE, triangle for Technique
    colors = {0: 'red', 1: 'blue'}  # red for negative, blue for positive

    for t in [0, 1]:          # type (CVE or Technique)
        for lab in [0, 1]:    # label (0 or 1)
            mask = (types == t) & (binary_labels == lab)
            plt.scatter(
                reduced_embeddings[mask, 0],
                reduced_embeddings[mask, 1],
                marker=markers[t],
                c=colors[lab],
                alpha=0.7,
                label=f"{'CVE' if t == 0 else 'Technique'} - {'Negative' if lab == 0 else 'Positive'}"
            )

    plt.legend()
    plt.title('t-SNE Visualization of CVE and Technique Embeddings')
    plt.savefig('embeddings_visualization.png')
    plt.close()
    print("Saved embeddings visualization to 'embeddings_visualization.png'")


def main(data_path="siamese_samples_with_srl (6).json"):
    """Main function to run the SRL model training and evaluation pipeline.

    The pipeline has four stages:
      1. Load and split the data.
      2. Load 'best_model.pth' if it exists, otherwise train a model.
      3. Evaluate on the test split.
      4. Save a t-SNE plot of test embeddings.

    Args:
        data_path: JSON dataset path (defaults to the original file name).
    """
    torch.manual_seed(42)
    np.random.seed(42)

    print("Loading SRL dataset...")
    try:
        df = load_srl_dataset(data_path)
        if df.empty:
            print("Error: Dataset is empty. Please check the data file.")
            return
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    try:
        train_df, val_df, test_df = get_splits_for_model(df)
    except Exception as e:
        print(f"Error splitting data: {e}")
        return

    try:
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    except Exception as e:
        print(f"Error loading tokenizer: {e}")
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_path = "best_model.pth"

    # Hyper-parameters shared by both training paths below.
    train_kwargs = dict(
        train_df=train_df,
        val_df=val_df,
        batch_size=16,
        margin=1.0,
        patience=3,
        lr=2e-5,
        contrastive_weight=0.4,
    )

    if os.path.exists(model_path):
        print(f"Loading pre-trained model from {model_path}")
        try:
            model = EnhancedContrastiveSRLModel().to(device)
            model.load_state_dict(torch.load(model_path, map_location=device))
        except Exception as e:
            print(f"Error loading model: {e}")
            print("Training a new model instead...")
            # NOTE(review): this fallback trains for 10 epochs versus 50 for
            # a fresh run (values preserved from the original); confirm the
            # difference is intended.
            model = train_hinge_model(epochs=10, **train_kwargs)
    else:
        print("No pre-trained model found. Training a new model...")
        model = train_hinge_model(epochs=50, **train_kwargs)

    print("Preparing test data...")
    test_dataset = HingeSiameseDataset(test_df, tokenizer)
    test_loader = DataLoader(test_dataset, batch_size=16)

    print("Evaluating model on test set...")
    test_acc, test_results_df = evaluate_model(model, test_loader, device)

    print("Generating embedding visualizations...")
    try:
        _visualize_embeddings(model, test_df, tokenizer, device)
    except Exception as e:
        # Visualization is optional; report and continue.
        print(f"Error generating visualizations: {e}")

    print("Pipeline completed successfully!")

# Run the full pipeline when executed as a script or notebook cell
# (__name__ == "__main__"); importing this module does not trigger it.
if __name__ == "__main__":
    main()

Loading SRL dataset...


Processing SRL data: 100%|██████████| 6759/6759 [00:01<00:00, 4927.67it/s]


Loaded 6759 valid samples
Train: 4731, Validation: 1014, Test: 1014
Train label distribution: {0: 3581, 1: 1150}
Val label distribution: {0: 767, 1: 247}
Test label distribution: {0: 767, 1: 247}
No pre-trained model found. Training a new model...
Using device: cpu


Epoch 1/50 - Training:  14%|█▍        | 41/296 [03:21<20:52,  4.91s/it]


KeyboardInterrupt: 