In [None]:
import ast
import json
import logging
import os
from collections import Counter, OrderedDict
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd
import spacy
from spacy.lang.en import English
from tqdm.auto import tqdm

# ---------------------------- Logging Setup ----------------------------
# Root logging: INFO and above, one line per record with timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger shared by every class below.
logger = logging.getLogger(__name__)

# ---------------------------- NLP Processor ----------------------------
class NLPProcessor:
    """Wrapper around the spaCy pipelines used by the data-preparation code.

    Holds the full ``en_core_web_lg`` pipeline (``self.nlp``), a blank English
    pipeline with only the rule-based sentencizer (``self.sentencizer``), and an
    in-memory cache of loaded pipelines keyed by model name plus the names of
    the disabled components.
    """

    # Components disabled for each specialised task understood by get_model().
    _TASK_DISABLES: Dict[str, List[str]] = {
        "ner": ["parser", "tagger"],
        "pos": ["ner", "parser"],
        "parse": ["ner"],
        "dep": ["ner"],
    }

    def __init__(self, cache_dir: Optional[str] = None):
        """Load the base pipeline and the sentencizer.

        Args:
            cache_dir: Optional directory. Stored as ``self.cache_dir`` but not
                read by any method of this class.
        """
        self.cache_dir = Path(cache_dir) if cache_dir else None
        self.models = {}
        # lowercased word -> lemma. lemmatize() runs the whole pipeline per call
        # and is invoked for every verb of every sample, so repeats are memoized.
        self._lemma_cache: Dict[str, str] = {}

        # Load base model (en_core_web_lg) with every component enabled.
        self.nlp = self._load_spacy_model("en_core_web_lg")

        # Sentence splitting uses a blank English pipeline that only contains
        # the rule-based sentencizer component.
        self.sentencizer = English()
        self.sentencizer.add_pipe("sentencizer")

        logger.info("NLP processor initialized successfully")

    def _load_spacy_model(
        self, model_name: str, disable: Optional[List[str]] = None
    ) -> "spacy.language.Language":
        """Load ``model_name`` with ``disable`` components off, reusing cached loads.

        Args:
            model_name: Installed spaCy package name, e.g. ``"en_core_web_lg"``.
            disable: Component names to disable; ``None`` or empty disables none.

        Returns:
            The pipeline. Each distinct (model, disabled-components) combination
            is loaded once and afterwards served from ``self.models``.
        """
        disable = list(disable) if disable else []
        model_key = f"{model_name}_{'-'.join(disable)}"

        if model_key in self.models:
            return self.models[model_key]

        logger.info(f"Loading SpaCy model: {model_name} (disabled: {disable})")
        model = spacy.load(model_name, disable=disable)
        self.models[model_key] = model
        return model

    def get_model(self, task: str) -> "spacy.language.Language":
        """Return a pipeline tailored to ``task``.

        ``"ner"``, ``"pos"``, ``"parse"`` and ``"dep"`` map to ``en_core_web_lg``
        with the components listed in ``_TASK_DISABLES`` turned off. Any other
        task (including ``"base"``) returns the full base pipeline ``self.nlp``.
        """
        disable = self._TASK_DISABLES.get(task)
        if disable is None:
            return self.nlp
        return self._load_spacy_model("en_core_web_lg", disable=disable)

    def separate_sentences(self, text: str) -> List[str]:
        """Split ``text`` into whitespace-stripped sentences via the sentencizer."""
        doc = self.sentencizer(text)
        return [sent.text.strip() for sent in doc.sents]

    def lemmatize(self, word: str, is_verb: bool = False) -> str:
        """Return the spaCy lemma of ``word`` (lowercased before processing).

        Args:
            word: A single word; it is run through the full base pipeline and the
                lemma of the first resulting token is returned.
            is_verb: Accepted for API compatibility; it currently has no effect
                (the lemma depends only on spaCy's own analysis of the word).

        Returns:
            The lemma, or the lowercased word if spaCy produced no tokens or
            raised. Successful results are memoized per lowercased word.
        """
        key = word.lower()
        if key in self._lemma_cache:
            return self._lemma_cache[key]
        try:
            doc = self.nlp(key)
        except Exception as e:
            # Best-effort: a failed lemmatization falls back to the input word.
            logger.warning(f"SpaCy lemmatization failed: {e}")
            return key
        lemma = doc[0].lemma_ if len(doc) > 0 else key
        self._lemma_cache[key] = lemma
        return lemma

    @staticmethod
    def cos_sim(x: np.ndarray, y: np.ndarray) -> float:
        """Return the absolute cosine similarity of ``x`` and ``y``.

        Any array-like input is accepted. Returns ``0.0`` when either vector
        has zero norm. The absolute value is taken, so anti-parallel vectors
        score ``1.0``.
        """
        a = np.asarray(x, dtype=float)
        b = np.asarray(y, dtype=float)
        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0
        return float(abs(np.dot(a, b) / (norm_a * norm_b)))

# ---------------------------- SpaCy-based SRL Processor ----------------------------
class SpacySRLProcessor:
    """Heuristic Semantic Role Labeling (SRL) built on spaCy dependency parses.

    No trained SRL model is involved: BIO tags such as ``B-ARG0`` or
    ``I-ARGM-TMP`` are assigned from the dependency labels of each predicate's
    direct children.
    """

    # Dependency labels of a predicate's children and the role they are tagged as.
    _ARG0_DEPS = ("nsubj", "nsubjpass", "csubj", "csubjpass", "agent")
    _ARG1_DEPS = ("dobj", "attr", "oprd")
    _ADVERBIAL_DEPS = ("advmod", "npadvmod")
    # Entity types that turn an adverbial into a temporal (ARGM-TMP) argument.
    _TEMPORAL_ENTS = ("DATE", "TIME")

    def __init__(self, nlp: "Optional[spacy.language.Language]" = None):
        """Prepare the dependency-parsing pipeline.

        Args:
            nlp: Optional already-loaded spaCy pipeline to reuse (for example the
                ``nlp`` attribute of an ``NLPProcessor``), so that
                ``en_core_web_lg`` is not loaded into memory again. When omitted,
                ``en_core_web_lg`` is loaded here. If its parser is disabled it
                is re-enabled, which also affects other users of a shared pipeline.

        On any failure the error is logged and ``self.nlp`` is set to ``None``;
        ``extract_srl`` then returns empty output.
        """
        try:
            if nlp is None:
                logger.info("Loading SpaCy model for SRL processing")
                nlp = spacy.load("en_core_web_lg")
            self.nlp = nlp
            if "parser" not in self.nlp.pipe_names:
                logger.warning("Dependency parser not found in pipeline, enabling it")
                self.nlp.enable_pipe("parser")
            logger.info("SpaCy SRL processor initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize SpaCy SRL processor: {e}")
            self.nlp = None

    @staticmethod
    def _is_predicate(token) -> bool:
        """True for VERB tokens and for AUX tokens that govern an ``xcomp``."""
        if token.pos_ == "VERB":
            return True
        return token.pos_ == "AUX" and any(child.dep_ == "xcomp" for child in token.children)

    @staticmethod
    def _tag_span(tags: List[str], tokens, role: str) -> None:
        """Write ``B-role`` on the first token of ``tokens`` and ``I-role`` on the rest."""
        for position, tok in enumerate(tokens):
            tags[tok.i] = f"{'B' if position == 0 else 'I'}-{role}"

    @classmethod
    def _tag_arguments(cls, predicate, tags: List[str]) -> None:
        """Tag the argument spans headed by ``predicate``'s children in ``tags``.

        Subtrees of distinct children are disjoint, so a single pass over the
        children assigns each token at most one role.
        """
        for child in predicate.children:
            dep = child.dep_
            if dep in cls._ARG0_DEPS:
                # Subjects (ARG0)
                cls._tag_span(tags, child.subtree, "ARG0")
            elif dep in cls._ARG1_DEPS:
                # Direct objects / attributes (ARG1)
                cls._tag_span(tags, child.subtree, "ARG1")
            elif dep == "iobj":
                # Indirect objects (ARG2)
                cls._tag_span(tags, child.subtree, "ARG2")
            elif dep == "prep":
                # Prepositional objects (ARG2): the preposition plus its object's subtree.
                for pobj in child.children:
                    if pobj.dep_ == "pobj":
                        cls._tag_span(tags, [child, *pobj.subtree], "ARG2")
            elif dep in cls._ADVERBIAL_DEPS:
                # Adverbials containing a DATE/TIME entity are temporal; the whole
                # modifier subtree (e.g. "last week", "yesterday") becomes ARGM-TMP.
                # Any other adverbial marks just its head token as manner.
                if any(tok.ent_type_ in cls._TEMPORAL_ENTS for tok in child.subtree):
                    cls._tag_span(tags, child.subtree, "ARGM-TMP")
                else:
                    tags[child.i] = "B-ARGM-MNR"

    def extract_srl(self, text: str) -> Dict:
        """Extract heuristic semantic roles from ``text``.

        Returns:
            ``{"words": [token texts], "verbs": [frame, ...]}`` where each frame
            is ``{"verb", "description", "id", "tags"}``: ``description`` is the
            input text, ``id`` is ``f"{verb_text}_{token_index}"``, and ``tags``
            is one BIO tag per word (``"B-V"`` on the predicate itself). If the
            processor failed to initialize, both lists are empty.
        """
        if not self.nlp:
            logger.error("SRL processor not properly initialized")
            return {"words": [], "verbs": []}

        doc = self.nlp(text)
        words = [token.text for token in doc]
        verbs = []

        for token in doc:
            if not self._is_predicate(token):
                continue
            tags = ["O"] * len(words)
            tags[token.i] = "B-V"
            self._tag_arguments(token, tags)
            verbs.append({
                "verb": token.text,
                "description": text,
                "id": f"{token.text}_{token.i}",
                "tags": tags,
            })

        return {"words": words, "verbs": verbs}

    def srl_to_dict(self, srl_output: Dict) -> Dict:
        """Collapse BIO tags into ``{verb_id: {role: {"text": ...}}}``.

        ``I-`` tokens are space-joined onto the current text of their role. A
        further ``B-`` span for a role already present is appended after
        ``"/ "``. ``I-`` tags whose role has not been opened are dropped.
        """
        srl_dict = {}
        for verb in srl_output.get('verbs', []):
            verb_str = verb.get('id', verb.get('verb', 'unknown'))
            roles = srl_dict[verb_str] = {}
            for ind, tag in enumerate(verb.get('tags', [])):
                if tag == 'O':
                    continue
                role = tag[tag.find('-') + 1:]
                word = srl_output['words'][ind]
                if tag.startswith('B'):
                    if role not in roles:
                        roles[role] = {'text': word}
                    else:
                        roles[role]['text'] += f"/ {word}"
                elif role in roles:
                    roles[role]['text'] += f" {word}"
        return srl_dict

# ---------------------------- File Utilities ----------------------------
class FileUtils:
    """Utilities for file operations such as saving and reading data."""

    @staticmethod
    def _json_default(obj):
        """Encode NumPy scalars/arrays that the ``json`` module rejects.

        Raises:
            TypeError: for any other unsupported object, as ``json`` itself would.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    @staticmethod
    def save_dict_as_json(output_path: str, data_dict: Dict, note: str = '') -> str:
        """Save ``data_dict`` as indented JSON at ``output_path``.

        NumPy scalars and arrays are converted to plain Python values. If
        ``note`` is non-empty it is written to a sibling file with a ``.txt``
        suffix.

        Returns:
            The string ``'done'``.

        Raises:
            TypeError: if ``data_dict`` contains a value JSON cannot encode; in
                that case nothing is written.
        """
        output_path = Path(output_path)
        # Serialize fully before touching the file, so an unencodable value
        # cannot leave a truncated JSON file behind (json.dump streams).
        payload = json.dumps(data_dict, indent=2, default=FileUtils._json_default)
        output_path.write_text(payload, encoding='utf-8')
        if note:
            output_path.with_suffix('.txt').write_text(note, encoding='utf-8')
        return 'done'

    @staticmethod
    def read_json_as_dict(input_path: str) -> Dict:
        """Read a UTF-8 JSON file and return the decoded object."""
        with open(Path(input_path), 'r', encoding='utf-8') as json_file:
            return json.load(json_file)

# ---------------------------- Siamese Data Preparation ----------------------------
class SiameseDataPreparer:
    """Prepares data for a Siamese network for CVE-Technique matching.

    Builds labelled (CVE, technique) pairs from two Excel sheets and enriches
    each pair with sentence splits, heuristic SRL frames, a shared-verb count
    and a role-text similarity score.
    """

    def __init__(self, nlp_processor: "NLPProcessor", srl_processor: "SpacySRLProcessor"):
        """Initialize using NLP processor for text and SRL processor for semantic roles."""
        self.nlp = nlp_processor
        self.srl = srl_processor

    def prepare_data(
        self,
        cve_file_path: str,
        technique_file_path: str,
        neg_samples_per_cve: int = 3,
        output_path: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> List[Dict]:
        """
        Prepare data for Siamese network training.

        Args:
            cve_file_path: Excel file with columns ``CVE_Description``,
                ``MITRE_Technique_Numbers`` and optionally ``CVE_ID``.
            technique_file_path: Excel file whose first two columns are the
                technique ID and its description.
            neg_samples_per_cve: Number of negative samples to generate per CVE.
            output_path: Optional path to save the processed JSON output.
            seed: Optional seed for negative sampling. When ``None`` NumPy's
                global RNG is used, so ``np.random.seed`` still controls it.

        Returns:
            List of dictionaries with prepared sample pairs.
        """
        logger.info(f"Loading CVE data from {cve_file_path}")
        df_cve = pd.read_excel(cve_file_path)

        logger.info(f"Loading technique data from {technique_file_path}")
        df_tech = pd.read_excel(technique_file_path)

        tech_dict = self._build_technique_dict(df_tech)
        rng = np.random.default_rng(seed) if seed is not None else np.random

        samples = self._build_pairs(df_cve, tech_dict, neg_samples_per_cve, rng)
        logger.info(f"Generated {len(samples)} Siamese pairs.")

        logger.info("Processing SRL for text pairs...")
        self._add_srl_features(samples)

        # Optionally save the prepared data to a JSON file
        if output_path:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)
            positives = sum(s['label'] for s in samples)
            FileUtils.save_dict_as_json(
                str(output_path),
                {"samples": samples},
                note=f"Generated {len(samples)} Siamese pairs with {positives} positive and {len(samples) - positives} negative samples."
            )
            logger.info(f"Data saved to: {output_path}")

        return samples

    # ------------------------------------------------------------------ pairs

    @staticmethod
    def _build_technique_dict(df_tech: pd.DataFrame) -> Dict:
        """Map technique ID -> description from the first two columns of ``df_tech``.

        Rows without a description are dropped, since they could not be
        sentence-split later.
        """
        df_tech = df_tech.iloc[:, :2].copy()
        df_tech.columns = ["Technique_ID", "Technique_Description"]
        missing = df_tech["Technique_Description"].isna()
        if missing.any():
            logger.warning(f"Dropping {int(missing.sum())} techniques without a description")
            df_tech = df_tech[~missing]
        return pd.Series(df_tech.Technique_Description.values, index=df_tech.Technique_ID).to_dict()

    @staticmethod
    def _parse_technique_list(raw, cve_id) -> list:
        """Normalize a ``MITRE_Technique_Numbers`` cell into a list of IDs.

        Lists pass through. A string that looks like a Python list literal
        (``"['T1059', 'T1203']"``) is parsed with ``ast.literal_eval`` (never
        ``eval``). Any other string is treated as a single stripped ID. Other
        types (e.g. NaN for empty cells) yield an empty list.
        """
        if isinstance(raw, list):
            return raw
        if not isinstance(raw, str):
            logger.warning(f"Unexpected type for MITRE_Technique_Numbers for {cve_id}: {type(raw)}")
            return []
        text = raw.strip()
        if not (text.startswith("[") and text.endswith("]")):
            return [text]
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            logger.warning(f"Error parsing list for {cve_id}: {text}, error: {e}")
            return [text]

    @staticmethod
    def _sample_negatives(candidates: list, k: int, rng) -> list:
        """Draw up to ``k`` distinct items from ``candidates`` without replacement.

        Indices are sampled rather than the items themselves, so the returned
        values keep their original Python types (``rng.choice`` on the list
        would return NumPy scalars, which ``json`` cannot encode).
        """
        k = min(k, len(candidates))
        if k <= 0:
            return []
        return [candidates[int(i)] for i in rng.choice(len(candidates), k, replace=False)]

    @staticmethod
    def _make_pair(cve_id, cve_text, tech_id, tech_text, label: int) -> Dict:
        """Build one sample dict; ``label`` is 1 for a positive pair, 0 for a negative one."""
        return {
            "CVE_ID": cve_id,
            "CVE_text": cve_text,
            "Technique_ID": tech_id,
            "Technique_text": tech_text,
            "label": label,
        }

    def _build_pairs(self, df_cve: pd.DataFrame, tech_dict: Dict, neg_samples_per_cve: int, rng) -> List[Dict]:
        """Create positive pairs for listed techniques and sampled negatives for each CVE row."""
        samples = []
        logger.info("Building Siamese pairs...")

        for idx, row in tqdm(df_cve.iterrows(), total=len(df_cve), desc="Processing CVEs"):
            cve_text = row["CVE_Description"]
            cve_id = row.get("CVE_ID", f"Unknown_{idx}")
            # A missing/blank description would only crash later in sentence splitting.
            if not isinstance(cve_text, str) or not cve_text.strip():
                logger.warning(f"Skipping {cve_id}: missing CVE_Description")
                continue

            technique_list = self._parse_technique_list(row["MITRE_Technique_Numbers"], cve_id)

            # Positive samples: listed techniques that exist in the technique sheet.
            for tech_id in technique_list:
                if tech_id in tech_dict:
                    samples.append(self._make_pair(cve_id, cve_text, tech_id, tech_dict[tech_id], 1))

            # Negative samples: techniques not listed for this CVE.
            negative_techs = [t for t in tech_dict if t not in technique_list]
            for tech_id in self._sample_negatives(negative_techs, neg_samples_per_cve, rng):
                samples.append(self._make_pair(cve_id, cve_text, tech_id, tech_dict[tech_id], 0))

        return samples

    # -------------------------------------------------------------- features

    def _sentences_and_srl(self, text: str, cache: Dict) -> tuple:
        """Return ``(sentences, srl_entries)`` for ``text``, computing each text once.

        The same CVE text recurs in every pair built from it and technique texts
        recur across many CVEs, so results are cached per text. Deep copies are
        returned so samples never share mutable structures.
        """
        import copy

        if text not in cache:
            sentences = self.nlp.separate_sentences(text)
            entries = []
            for sentence in sentences:
                srl_output = self.srl.extract_srl(sentence)
                entries.append({
                    "sentence": sentence,
                    "srl_raw": srl_output,
                    "srl_structured": self.srl.srl_to_dict(srl_output),
                })
            cache[text] = (sentences, entries)
        sentences, entries = cache[text]
        return list(sentences), copy.deepcopy(entries)

    def _add_srl_features(self, samples: List[Dict]) -> None:
        """Add sentence, SRL, verb-match and role-match fields to every sample in place."""
        cache: Dict = {}
        for sample in tqdm(samples, desc="Extracting SRL"):
            cve_sentences, cve_srl = self._sentences_and_srl(sample["CVE_text"], cache)
            tech_sentences, tech_srl = self._sentences_and_srl(sample["Technique_text"], cache)
            sample["CVE_sentences"] = cve_sentences
            sample["Technique_sentences"] = tech_sentences
            sample["CVE_srl"] = cve_srl
            sample["Technique_srl"] = tech_srl
            # Basic count of verb lemmas shared by the two texts.
            sample["verb_match_count"] = self._calculate_verb_match(cve_srl, tech_srl)
            # Mean best-match similarity over role types present in both texts.
            sample["role_match_score"] = self._calculate_role_match_score(cve_srl, tech_srl)

    @staticmethod
    def _verb_from_id(verb_id: str) -> str:
        """Recover the verb text from an SRL id of the form ``f"{text}_{index}"``.

        Splits on the last underscore so verb texts containing ``_`` survive.
        """
        return verb_id.rsplit('_', 1)[0]

    def _verb_lemmas(self, srl_list) -> set:
        """Lemmas of every predicate found in ``srl_list``."""
        return {
            self.nlp.lemmatize(self._verb_from_id(verb_id), is_verb=True)
            for srl_item in srl_list
            for verb_id in srl_item["srl_structured"]
        }

    def _calculate_verb_match(self, cve_srl_list, tech_srl_list) -> int:
        """Calculate the number of matching verbs between CVE and Technique texts."""
        return len(self._verb_lemmas(cve_srl_list) & self._verb_lemmas(tech_srl_list))

    @staticmethod
    def _collect_roles(srl_list) -> Dict[str, List[str]]:
        """Group all role texts in ``srl_list`` by role type across every predicate."""
        by_role: Dict[str, List[str]] = {}
        for srl_item in srl_list:
            for roles in srl_item["srl_structured"].values():
                for role_type, role_info in roles.items():
                    by_role.setdefault(role_type, []).append(role_info["text"])
        return by_role

    def _calculate_role_match_score(self, cve_srl_list, tech_srl_list) -> float:
        """Average, over role types shared by both texts (except ``V``), of the best text similarity."""
        cve_roles = self._collect_roles(cve_srl_list)
        tech_roles = self._collect_roles(tech_srl_list)

        role_similarities = []
        for role_type in set(cve_roles).intersection(tech_roles):
            if role_type == "V":  # Verbs are already covered by verb_match_count
                continue
            role_similarities.append(
                self._calculate_text_list_similarity(cve_roles[role_type], tech_roles[role_type])
            )

        if role_similarities:
            return sum(role_similarities) / len(role_similarities)
        return 0.0

    @staticmethod
    def _vector_docs(model, texts: List[str]) -> list:
        """Parse ``texts`` with ``model``, keeping only docs with a non-zero vector."""
        docs = []
        for text in texts:
            try:
                doc = model(text)
            except Exception as e:
                logger.warning(f"Error calculating similarity: {e}")
                continue
            if doc.vector_norm:
                docs.append(doc)
        return docs

    def _calculate_text_list_similarity(self, text_list1: List[str], text_list2: List[str]) -> float:
        """Highest spaCy vector similarity between any text of each list (0.0 if none comparable).

        Each text is parsed once rather than once per pair.
        """
        if not text_list1 or not text_list2:
            return 0.0
        model = self.nlp.get_model("base")
        docs1 = self._vector_docs(model, text_list1)
        docs2 = self._vector_docs(model, text_list2) if docs1 else []

        max_sim = 0.0
        for doc1 in docs1:
            for doc2 in docs2:
                try:
                    max_sim = max(max_sim, doc1.similarity(doc2))
                except Exception as e:
                    logger.warning(f"Error calculating similarity: {e}")
        return max_sim

# ---------------------------- Main Execution ----------------------------
def main():
    """Run the end-to-end data preparation with the hard-coded file locations.

    Builds the NLP and SRL processors, prepares CVE/technique pairs (three
    negatives per CVE), writes them to JSON, then logs per-label averages of
    the verb-match and role-match features.
    """
    preparer = SiameseDataPreparer(NLPProcessor(), SpacySRLProcessor())

    # Input/output locations (adjust as needed).
    cve_path = "cve_single_technique.xlsx"
    tech_path = "techniques.xlsx"
    out_path = "siamese_samples_with_srl.json"

    Path(out_path).parent.mkdir(parents=True, exist_ok=True)

    samples = preparer.prepare_data(
        cve_file_path=cve_path,
        technique_file_path=tech_path,
        neg_samples_per_cve=3,
        output_path=out_path,
    )

    n_pos = sum(1 for s in samples if s["label"] == 1)
    n_neg = len(samples) - n_pos

    def label_mean(field, label, count):
        """Mean of ``field`` over samples carrying ``label``; 0 when ``count`` is 0."""
        if not count:
            return 0
        return sum(s[field] for s in samples if s["label"] == label) / count

    verb_pos = label_mean("verb_match_count", 1, n_pos)
    verb_neg = label_mean("verb_match_count", 0, n_neg)
    role_pos = label_mean("role_match_score", 1, n_pos)
    role_neg = label_mean("role_match_score", 0, n_neg)

    summary = [
        "Data preparation complete:",
        f"  - Total samples: {len(samples)}",
        f"  - Positive pairs: {n_pos}",
        f"  - Negative pairs: {n_neg}",
        f"  - Average verb matches (positive pairs): {verb_pos:.2f}",
        f"  - Average verb matches (negative pairs): {verb_neg:.2f}",
        f"  - Average role match score (positive pairs): {role_pos:.2f}",
        f"  - Average role match score (negative pairs): {role_neg:.2f}",
        f"  - Data saved to: {out_path}",
    ]
    for line in summary:
        logger.info(line)

if __name__ == "__main__":
    main()


In [None]:
import nltk

# Fetch the WordNet corpus and its Open Multilingual WordNet companion data.
for _resource in ('wordnet', 'omw-1.4'):
    nltk.download(_resource)


In [None]:
import nltk

# Stopword lists plus WordNet data; the WordNet entries repeat the previous
# cell so this cell also works when run on its own.
for _resource in ('stopwords', 'wordnet', 'omw-1.4'):
    nltk.download(_resource)

In [None]:
!python -m spacy download en_core_web_lg
