In [3]:
import pickle
import ipywidgets as widgets
from IPython.display import display, clear_output
from transformers import BertTokenizer, BertForSequenceClassification
from iocextract import extract_iocs
import lime
from lime.lime_text import LimeTextExplainer
import re

# Technique metadata, looked up by predicted label index in
# on_submit_button_click (presumably label index -> details dict; verify
# against the script that wrote the pickle).
# NOTE: pickle.load can execute arbitrary code - only load a trusted file.
with open('technique_dictMitre.pkl', mode='rb') as pickle_file:
    loaded_technique_dict = pickle.load(pickle_file)

import os
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification


# Split user-supplied text into '#'-separated entries.
def chunk_training_data(text):
    """Return the non-empty, whitespace-stripped segments of *text* between '#' marks."""
    pieces = (piece.strip() for piece in text.split("#"))
    return [piece for piece in pieces if piece]

# Load the fine-tuned classifier and its tokenizer from one local directory.
# NOTE(review): the directory name says "distilbert" and DistilBert classes are
# imported above, yet the Bert classes are used here - confirm the saved
# checkpoint's architecture matches, otherwise the weights may not load.
model_path = './saved_distilbert_model_Sec_Tram7'
tokenizer, model = (
    BertTokenizer.from_pretrained(model_path),
    BertForSequenceClassification.from_pretrained(model_path),
)



def chunk_training_data(text):
    """Split *text* into the entries separated by '#'.

    Each entry is stripped of surrounding whitespace and empty entries are
    dropped, so ``"a # b ##"`` yields ``["a", "b"]``.

    NOTE(review): an identical definition appears earlier in this file; this
    one is the binding that takes effect, and only one is needed.
    """
    return [chunk.strip() for chunk in text.split("#") if chunk.strip()]
def chunk_sentences(text, chunk_size=3):
    """Group the sentences of *text* into chunks of up to *chunk_size* sentences.

    Sentence boundaries are whitespace following '.' or '?', except after an
    "e.g."-style pattern or a capitalised two-letter abbreviation such as "Mr.".
    Sentences within a chunk are joined with a single space. Non-string input
    is reported on stdout and yields an empty list.
    """
    if isinstance(text, str):
        boundary = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?)\s')
        sentences = boundary.split(text)
        chunks = []
        for start in range(0, len(sentences), chunk_size):
            chunks.append(' '.join(sentences[start:start + chunk_size]))
        return chunks

    print("Text passed to chunk_sentences is not a string:", type(text), text)
    return []


def predict_mitre_technique_and_extract_iocs(text, threshold=None):
    """Predict technique label indices for *text* and extract its IoCs.

    Args:
        text: Free-text threat summary to classify.
        threshold: Minimum per-label sigmoid probability for a label index to
            be returned. Defaults to the current value of the
            ``confidence_threshold`` slider widget.

    Returns:
        A tuple ``(predicted_classes, iocs, predicted_proba)``:
        ``predicted_classes`` is the list of label indices whose probability
        exceeds *threshold*; ``iocs`` is a list of the IoCs found by
        ``extract_iocs``; ``predicted_proba`` is the tensor of sigmoid
        probabilities (row 0 belongs to *text*).
    """
    import torch  # only needed to disable autograd during inference

    if threshold is None:
        threshold = confidence_threshold.value

    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    # Inference only: do not build an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Independent sigmoid per label (multi-label), not a softmax over labels.
    predicted_proba = logits.sigmoid()
    predicted_classes = [i for i, proba in enumerate(predicted_proba[0]) if proba > threshold]
    # extract_iocs returns a lazy iterator; materialise it so callers can
    # print it (its repr is not the IoCs) and iterate it more than once.
    iocs = list(extract_iocs(text))
    return predicted_classes, iocs, predicted_proba


def predict_proba(texts, batch_size=32):
    """Return softmax class probabilities for *texts* as a NumPy array.

    This is LIME's ``classifier_fn``. LIME passes all perturbed samples in a
    single call (5000 by default), so the texts are run through the model in
    batches of *batch_size* to bound memory use.

    NOTE(review): predict_mitre_technique_and_extract_iocs scores labels with
    a sigmoid, while this uses a softmax, so LIME's probabilities will not
    match the confidence scores printed to the user - confirm which is intended.

    Args:
        texts: A string or a list of strings.
        batch_size: Maximum number of texts per forward pass.

    Returns:
        A float array with one row per text and one column per model label.
    """
    import torch  # only needed to disable autograd during inference

    if isinstance(texts, str):
        texts = [texts]
    batch_probs = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            encoded = tokenizer(
                texts[start:start + batch_size],
                return_tensors='pt', truncation=True, padding=True, max_length=512,
            )
            batch_probs.append(model(**encoded).logits.softmax(dim=1))
    return torch.cat(batch_probs).numpy()

# One class name per model output column, so LIME's label names line up with
# the label indices (technique ids) produced by predict_proba. The previous
# hard-coded two-name placeholder only fits a two-label model.
explainer = LimeTextExplainer(class_names=[str(i) for i in range(model.config.num_labels)])

def explain_prediction(text, labels=(1,), num_features=10, num_samples=5000):
    """Render a LIME explanation for *text* in the notebook and return it.

    Args:
        text: The text whose prediction should be explained.
        labels: Label indices to explain. Defaults to ``(1,)`` (LIME's own
            default, so existing calls behave as before); pass the predicted
            technique index to explain that technique instead.
        num_features: Maximum number of tokens shown in the explanation.
        num_samples: Number of perturbed samples LIME scores with the model.

    Returns:
        The LIME explanation object, so callers can inspect it further.
    """
    explanation = explainer.explain_instance(
        text,
        predict_proba,
        labels=labels,
        num_features=num_features,
        num_samples=num_samples,
    )
    explanation.show_in_notebook()
    return explanation

# --- Input widgets ----------------------------------------------------------
# Free-text box; on submit its value is split into '#'-separated summaries.
ti_data_input = widgets.Textarea(
    description='Text Data:',
    placeholder='Enter text data here...',
    value='',
    disabled=False,
    layout=widgets.Layout(width='90%', height='200px'),
)

# Probability threshold read by predict_mitre_technique_and_extract_iocs at
# prediction time; only updated when the handle is released.
confidence_threshold = widgets.FloatSlider(
    description='Confidence:',
    value=0.6,
    min=0.0,
    max=1.0,
    step=0.01,
    readout=True,
    readout_format='.2f',
    orientation='horizontal',
    continuous_update=False,
    disabled=False,
)

# When ticked, a LIME explanation is rendered alongside the predictions.
model_explainability = widgets.Checkbox(
    description='Model Explainability',
    value=False,
    disabled=False,
)

submit_button = widgets.Button(description="Submit")




def process_chunk(chunk):
    """Split a training-data chunk into ``(title, link, description, tech_id)``.

    The chunk is split on whitespace followed by ``https://``. When that yields
    at least three parts:
    - the first part is the title;
    - the second part, re-prefixed with ``https://``, is the link;
    - the remaining parts are joined with single spaces. The last
      space-separated word of that remainder is the technique ID, and
      everything before it is the description.

    Returns:
        The 4-tuple above, or ``(None, None, chunk, None)`` when the chunk does
        not fit that layout. This includes a remainder with no space to
        separate a description from a technique ID, which previously raised
        ValueError.
    """
    parts = re.split(r'\s+https://', chunk)

    if len(parts) >= 3:
        title = parts[0]
        link = "https://" + parts[1]
        # NOTE(review): joining with ' ' drops the "https://" prefix of the
        # third and later links - confirm this is intended.
        rest_of_the_data = ' '.join(parts[2:])
        # rpartition never raises: with no space it returns ('', '', rest),
        # and the empty description then fails the validity check below.
        description, _, tech_id = rest_of_the_data.rpartition(" ")

        if all([title, link, description, tech_id]):
            return title, link, description, tech_id

    # Anything that doesn't fit the expected layout is treated as a description.
    return None, None, chunk, None


# Technique IDs displayed by the most recent submission (reset on each click).
saved_ids = []

# Predictions are printed only when their probability is above this floor.
# predict_mitre_technique_and_extract_iocs already filters by the slider, so
# the effective display threshold is the larger of the slider value and this.
DISPLAY_CONFIDENCE_FLOOR = 0.95


def _print_technique(tech_id, score, iocs):
    """Print one predicted technique with its IoCs and dictionary details; return its ID."""
    technique_details = loaded_technique_dict.get(tech_id, {})

    print(f"Predicted Technique ID: {tech_id}")
    print(f"Confidence Score: {score:.2f}")
    print(f"Extracted IoCs: {iocs}")

    # str() guards against a non-string 'ID' value before stripping spaces.
    technique_id = str(technique_details.get('ID', 'N/A')).strip()
    print(f"ID: {technique_id}")
    print(f"Technique Name: {technique_details.get('Technique Name', 'N/A')}")
    key = "Technique Name's Webpage Link"
    print(f"Technique Name's Webpage Link: {technique_details.get(key, 'N/A')}")
    print(f"Technique Description: {technique_details.get('Technique Description', 'N/A')}")
    return technique_id


def on_submit_button_click(button):
    """Classify each '#'-separated summary in the text box and print the results.

    Args:
        button: The clicked ``widgets.Button`` (unused; required by ``on_click``).
    """
    # saved_ids is only mutated in place, so no ``global`` statement is needed.
    saved_ids.clear()
    clear_output(wait=True)

    # chunk_training_data already drops empty entries.
    for chunk in chunk_training_data(ti_data_input.value):
        technique_ids, iocs, confidence_scores = predict_mitre_technique_and_extract_iocs(chunk)
        # extract_iocs yields a lazy iterator whose repr is not the IoCs;
        # materialise it once so it prints correctly for every technique.
        iocs = list(iocs)

        print("\n" + "**" * 50 + "\n")
        print(f"Threat_Summary: {chunk}")

        shown_any = False
        for tech_id in technique_ids:
            score = confidence_scores[0][tech_id].item()
            if score > DISPLAY_CONFIDENCE_FLOOR:
                saved_ids.append(_print_technique(tech_id, score, iocs))
                print("\n" + "-" * 50 + "\n")
                shown_any = True

        # The explanation depends only on the chunk text, so run LIME once per
        # chunk rather than repeating the same expensive run per technique.
        if shown_any and model_explainability.value:
            explain_prediction(chunk)

    print("Saved IDs:", saved_ids)

    # clear_output removed everything, including the intro text; show it all again.
    display(intro_text, ti_data_input, confidence_threshold, model_explainability, submit_button)


# Route Submit clicks to the handler.
submit_button.on_click(on_submit_button_click)

# Instruction line shown above the input widgets, then the initial render.
intro_text = widgets.HTML("Add a summary of your threat using <code>#</code>:")
display(
    intro_text,
    ti_data_input,
    confidence_threshold,
    model_explainability,
    submit_button,
)


HTML(value='Add a summary of your threat using <code>#</code>:')

Textarea(value='', description='Text Data:', layout=Layout(height='200px', width='90%'), placeholder='Enter te…

FloatSlider(value=0.6, continuous_update=False, description='Confidence:', max=1.0, step=0.01)

Checkbox(value=False, description='Model Explainability')

Button(description='Submit', style=ButtonStyle())

In [2]:
%cd /Users/fardin/mitregpt

/Users/fardin/mitregpt
