<a href="https://colab.research.google.com/github/kutyadog/ai_notebooks/blob/main/Cybersecurity_AI_Helper_Chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# 1. Install necessary libraries
# This cell needs to be run first to ensure all dependencies are available.
# groq, transformers and torch are the packages imported by the chatbot cell below.
# The -q flag keeps pip's output short.
print("Installing required libraries...")
!pip install -q groq transformers torch
print("Libraries installed.")

In [None]:
# -*- coding: utf-8 -*-
"""
This Colab notebook provides a proof-of-concept for integrating a Groq AI chatbot
with a local cybersecurity base transformer model (fdtn-ai/Foundation-Sec-8B).

The chatbot will primarily use the Groq AI model for general conversation.
However, if the user's query is identified as cybersecurity-related,
the query will be routed to the Foundation-Sec-8B model for a specialized response.

This code is designed to be run in a Google Colab environment, ideally with a GPU
runtime for faster inference with the Foundation-Sec-8B model.
"""



# 2. Set up your GROQ_API_KEY
# The key is resolved once, in this order of preference:
#   1. a Colab Secret named 'GROQ_API_KEY' (recommended),
#   2. the GROQ_API_KEY environment variable,
#   3. the 'YOUR_GROQ_API_KEY' placeholder (replace it only for a quick test).
import os

_GROQ_KEY_PLACEHOLDER = "YOUR_GROQ_API_KEY"
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", _GROQ_KEY_PLACEHOLDER)

# google.colab only exists inside Colab; keep the cell usable elsewhere.
try:
    from google.colab import userdata
except ImportError:
    userdata = None

if userdata is not None:
    try:
        # A configured secret wins over the environment variable / placeholder.
        GROQ_API_KEY = userdata.get('GROQ_API_KEY') or GROQ_API_KEY
    except Exception as e:
        # Raised e.g. when the secret is missing or notebook access is not granted;
        # fall back to whatever was resolved from the environment.
        print(f"\nCould not read Colab Secret 'GROQ_API_KEY': {e}")

if GROQ_API_KEY == _GROQ_KEY_PLACEHOLDER:
    print("\nWARNING: Please replace 'YOUR_GROQ_API_KEY' with your actual Groq API Key.")
    print("You can get one from: https://console.groq.com/keys")
    print("Alternatively, set it as a Colab Secret named 'GROQ_API_KEY'.")
    # For this PoC we proceed anyway; requests will fail later if the key is invalid.

# 3. Import and initialize Groq client
from groq import Groq

try:
    # Use the key resolved above so the secret, env var and warning all agree.
    groq_client = Groq(api_key=GROQ_API_KEY)
    print("\nGroq client initialized.")
except Exception as e:
    print(f"\nError initializing Groq client: {e}")
    print("Please ensure your GROQ_API_KEY is correct and valid.")
    groq_client = None  # Downstream code checks for None before calling the API.

# 4. Load the Cybersecurity Transformer Model
# Fetching and loading the checkpoint can take several minutes, depending on the
# Colab runtime and the network.
print("\nLoading Foundation-Sec-8B model and tokenizer...")
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

try:
    # Half-precision (float16) weights are requested to reduce RAM usage.
    cybersecurity_tokenizer = AutoTokenizer.from_pretrained("fdtn-ai/Foundation-Sec-8B")
    cybersecurity_model = AutoModelForCausalLM.from_pretrained(
        "fdtn-ai/Foundation-Sec-8B", torch_dtype=torch.float16
    )
    # Use the GPU when one is present; otherwise the model stays on the CPU.
    if not torch.cuda.is_available():
        print("CUDA not available, Foundation-Sec-8B model will run on CPU (may be slower).")
    else:
        cybersecurity_model.to("cuda")
        print("Foundation-Sec-8B model moved to GPU.")
    print("Foundation-Sec-8B model loaded successfully.")
except Exception as load_error:
    print(f"Error loading Foundation-Sec-8B model: {load_error}")
    print("Please check the model path, your internet connection, and available RAM.")
    print("If the issue persists, consider a Colab Pro subscription or using a smaller model.")
    # Leave both globals as None so callers can detect that the model is unavailable.
    cybersecurity_model = None
    cybersecurity_tokenizer = None

# 5. Define a function to interact with the Cybersecurity Model
def get_cybersecurity_response(prompt: str) -> str:
    """
    Generate a response from the Foundation-Sec-8B model for a cybersecurity query.

    Args:
        prompt: The user's raw question; it is wrapped in an instruction template
            before being sent to the model.

    Returns:
        The newly generated text (without the prompt), or a human-readable
        message when the model is not loaded or generation fails.
    """
    if cybersecurity_model is None or cybersecurity_tokenizer is None:
        return "Cybersecurity model is not loaded or initialized. Cannot process this request."

    # For this PoC the prompt is framed with a fixed template. A real application
    # would need more sophisticated prompt engineering or fine-tuning.
    formatted_prompt = f"Analyze the following cybersecurity context: {prompt}\n\nProvide relevant cybersecurity insights or information."

    try:
        # Put the inputs on whatever device the model actually lives on instead of
        # assuming "cuda whenever it is available".
        inputs = cybersecurity_tokenizer(formatted_prompt, return_tensors="pt").to(
            cybersecurity_model.device
        )

        outputs = cybersecurity_model.generate(
            # Pass every tokenizer output (input_ids AND attention_mask); the mask
            # matters because the pad token is set to the EOS token below.
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=cybersecurity_tokenizer.eos_token_id,
        )

        # The output sequence starts with the prompt tokens. Slice them off by
        # length rather than string-replacing the prompt, which silently fails when
        # the decoded text does not reproduce the prompt character for character.
        prompt_length = inputs["input_ids"].shape[-1]
        new_tokens = outputs[0][prompt_length:]
        return cybersecurity_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error generating response from cybersecurity model: {e}"

# 6. Define a function to interact with the Groq AI Model
def get_groq_response(prompt: str, chat_history: list) -> str:
    """
    Ask the Groq chat model to answer a general (non-cybersecurity) query.

    Args:
        prompt: The user's latest message.
        chat_history: Earlier turns as ``{"role": ..., "content": ...}`` dicts;
            the list is not modified.

    Returns:
        The content of the first completion choice, or an error message when the
        client is missing or the request fails.
    """
    if groq_client is None:
        return "Groq AI client is not initialized. Cannot process this request."

    # Build a new list so the caller's history is left untouched.
    latest_turn = {"role": "user", "content": prompt}
    conversation = chat_history + [latest_turn]

    request_options = {
        "messages": conversation,
        "model": "llama3-8b-8192",  # Or "mixtral-8x7b-32768" or "gemma-7b-it"
        "temperature": 0.7,
        "max_tokens": 1024,
        "top_p": 1,
        "stop": None,
        "stream": False,
    }
    try:
        completion = groq_client.chat.completions.create(**request_options)
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as e:
        return f"Error generating response from Groq AI: {e}"

# 7. Main Chatbot Loop
# Longer terms are matched as substrings so inflected forms ("exploits",
# "exploited") still route to the cybersecurity model.
_CYBERSECURITY_TERMS = (
    "cve", "cwe", "vulnerability", "exploit", "malware", "phishing",
    "ransomware", "firewall", "threat intelligence",
    "incident response", "mitre att&ck", "zero-day", "security posture",
    "compliance", "penetration test", "red team", "blue team", "cyberattack",
)

# Short acronyms must match as whole words; as substrings they would fire on
# ordinary text ("kids" contains "ids", "tips" contains "ips", "social" contains "soc").
_CYBERSECURITY_ACRONYMS = ("ids", "ips", "soc")


def _is_cybersecurity_query(text: str) -> bool:
    """Return True when *text* mentions a cybersecurity term or whole-word acronym."""
    import re  # local import keeps this block self-contained

    lowered = text.lower()
    if any(term in lowered for term in _CYBERSECURITY_TERMS):
        return True
    # An optional trailing "s" lets simple plurals of an acronym match as well.
    acronym_pattern = r"\b(?:" + "|".join(map(re.escape, _CYBERSECURITY_ACRONYMS)) + r")s?\b"
    return re.search(acronym_pattern, lowered) is not None


def main_chatbot():
    """
    Run the interactive console chat loop.

    Each user message is routed to the Foundation-Sec-8B model when it looks
    cybersecurity-related (simple keyword detection) and to the Groq model
    otherwise. Every exchange, from either model, is appended to the history
    that is sent to Groq on later turns. The loop ends on 'exit' or 'quit'.
    """
    print("\n--- Groq AI Chatbot with Cybersecurity Integration ---")
    print("Type 'exit' or 'quit' to end the conversation.")
    print("Try asking cybersecurity-related questions (e.g., 'What is CVE-2021-44228?', 'Explain MITRE ATT&CK').")
    print("For general questions, the Groq AI will respond.")

    chat_history = []

    while True:
        user_input = input("\nYou: ").strip()
        if not user_input:
            # Nothing was typed; do not spend a model call on an empty message.
            continue
        if user_input.lower() in ("exit", "quit"):
            print("Chatbot: Goodbye!")
            break

        # Simple keyword-based detection for cybersecurity queries. This could be
        # made more sophisticated with NLP techniques or by asking the Groq model
        # itself to classify the query.
        if _is_cybersecurity_query(user_input):
            print("Chatbot (Cybersecurity Model): Thinking...")
            response = get_cybersecurity_response(user_input)
        else:
            print("Chatbot (Groq AI): Thinking...")
            response = get_groq_response(user_input, chat_history)

        print(f"Chatbot: {response}")
        chat_history.append({"role": "user", "content": user_input})
        chat_history.append({"role": "assistant", "content": response})

# Run the chatbot
# Starts the chat loop when this code runs as the main module; the guard is
# false when the code is imported from elsewhere.
if __name__ == "__main__":
    main_chatbot()



In [None]:
# Run these commands in a Colab cell to install necessary libraries for TPU.
# IMPORTANT: After running this cell, go to Runtime -> Restart runtime,
# then proceed to the next cells. This is crucial for torch_xla to initialize correctly.

# 1. Aggressively uninstall potentially conflicting pre-installed packages.
#    This ensures a clean slate for TPU-specific installations.
#    Added 'thinc' to the uninstall list to resolve NumPy 2.x conflict.
#    The -y flag answers pip's confirmation prompt automatically.
!pip uninstall -y fastai torchaudio torchvision torch accelerate transformers numpy thinc

# 2. Install a compatible version of NumPy (less than 2.0).
!pip install -q numpy==1.26.4 # Explicitly install a NumPy 1.x version

# 3. Install PyTorch and TPU-specific libraries.
#    Note: We are not using CUDA-specific PyTorch here, as we are on TPU.
#    torch, torchaudio and torch_xla are all pinned to the 2.2 release line.
!pip install -q torch==2.2.1 torchvision==0.17.1 torchaudio==2.2.1
!pip install -q cloud-tpu-client==0.10 torch_xla==2.2

# 4. Install other required libraries.
#    accelerate and transformers were removed in step 1 and are reinstalled here (unpinned).
!pip install -q accelerate transformers requests



In [None]:
!pip uninstall -y thinc

# Optional (currently disabled): install SciPy.
# !pip install -q scipy

Found existing installation: thinc 8.3.6
Uninstalling thinc-8.3.6:
  Successfully uninstalled thinc-8.3.6


In [None]:
# --- Colab Setup and Library Installations (for TPU) ---


# --- Imports ---
import json
import asyncio
import requests # This is needed for making HTTP requests to the GROQ API
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch # For general PyTorch operations
import torch_xla.core.xla_model as xm # For TPU device management
from google.colab import userdata


# --- Global Variables for Model and Tokenizer ---
# Populated by load_foundation_sec_model(); None means "not loaded (yet)".
foundation_sec_model = None
foundation_sec_tokenizer = None

# Select the TPU (XLA) device, falling back to "cpu" when none can be acquired.
# xm.xla_device() is called once and guarded, so the CPU fallback is reachable
# even if the call itself raises on a runtime without an XLA device.
try:
    _xla_device = xm.xla_device()
except Exception as e:
    print(f"Could not acquire an XLA device ({e}); falling back to CPU.")
    _xla_device = None
DEVICE = _xla_device if _xla_device is not None and 'xla' in str(_xla_device) else "cpu"


# --- GROQ API Configuration ---
# The key is read from the Colab Secret named 'GROQ_API_KEY'
# (create one at https://console.groq.com/keys and add it via the Colab Secrets panel).
GROQ_API_KEY = userdata.get('GROQ_API_KEY')
GROQ_CHAT_MODEL = "llama3-8b-8192" # A capable and cheap model from GROQ

# --- Model Loading Function ---
async def load_foundation_sec_model():
    """
    Load the Foundation-Sec-8B tokenizer and model onto the TPU, once.

    The result is stored in the module-level ``foundation_sec_model`` and
    ``foundation_sec_tokenizer``. The call is a no-op when a model is already
    loaded, and it returns without loading when ``DEVICE`` is not an XLA device.
    On any loading error both globals are reset to ``None``.
    """
    global foundation_sec_model, foundation_sec_tokenizer

    # Already loaded in this session: nothing to do.
    if foundation_sec_model is not None:
        return

    print("Loading fdtn-ai/Foundation-Sec-8B model and tokenizer for TPU...")

    # Refuse to load when the selected device is not a TPU/XLA device.
    tpu_ready = 'xla' in str(DEVICE)
    if not tpu_ready:
        print("ERROR: TPU is not available. Please ensure you have a TPU runtime enabled in Colab.")
        print("Go to Runtime -> Change runtime type -> Hardware accelerator -> TPU.")
        return

    try:
        foundation_sec_tokenizer = AutoTokenizer.from_pretrained("fdtn-ai/Foundation-Sec-8B")
        # Weights are loaded in float16 and then moved to the TPU device.
        # 4-bit quantization (bitsandbytes) is not used with TPUs.
        foundation_sec_model = AutoModelForCausalLM.from_pretrained(
            "fdtn-ai/Foundation-Sec-8B", torch_dtype=torch.float16
        )
        foundation_sec_model.to(DEVICE)
        foundation_sec_model.eval()  # inference only
        print("Model loaded successfully!")
    except Exception as load_error:
        print(f"Error loading Foundation-Sec-8B model: {load_error}")
        print("Please ensure you have a TPU runtime enabled in Colab (Runtime -> Change runtime type).")
        print("If issues persist, try restarting the Colab runtime or checking TPU memory usage.")
        # Leave a consistent "nothing loaded" state behind.
        foundation_sec_model = None
        foundation_sec_tokenizer = None

# --- Foundation-Sec-8B Inference Function ---
async def generate_cyber_response(prompt: str, max_new_tokens: int = 200) -> str:
    """
    Generates a response using the loaded Foundation-Sec-8B model on TPU.

    If the model or tokenizer is not loaded yet, one load attempt is made first.

    Args:
        prompt (str): The input prompt for the model.
        max_new_tokens (int): The maximum number of tokens to generate.

    Returns:
        str: The newly generated text (without the prompt), or an error message
        when the model is unavailable or generation fails.
    """
    global foundation_sec_model, foundation_sec_tokenizer

    if foundation_sec_model is None or foundation_sec_tokenizer is None:
        print("Foundation-Sec-8B model not loaded. Attempting to load now...")
        await load_foundation_sec_model()
        # Both pieces are needed below, so check both after the load attempt.
        if foundation_sec_model is None or foundation_sec_tokenizer is None:
            return "Error: Cybersecurity model not available. Please ensure it loaded correctly."

    try:
        # Encode the prompt and move inputs to the TPU device
        inputs = foundation_sec_tokenizer(prompt, return_tensors="pt").to(DEVICE)

        # Generate response
        with torch.no_grad(): # Disable gradient calculation for inference
            outputs = foundation_sec_model.generate(
                # Pass every tokenizer output (input_ids AND attention_mask); the
                # mask matters because the pad token is set to the EOS token.
                **inputs,
                max_new_tokens=max_new_tokens,
                num_return_sequences=1,
                pad_token_id=foundation_sec_tokenizer.eos_token_id,
                do_sample=True, # Enable sampling for more diverse responses
                temperature=0.1,
                top_k=50,
                top_p=0.9,
            )

        # The output sequence starts with the prompt tokens. Drop them by length
        # instead of relying on the decoded text starting with the exact prompt
        # string, which is not guaranteed to survive an encode/decode round trip.
        prompt_length = inputs["input_ids"].shape[-1]
        generated_text = foundation_sec_tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        return generated_text.strip()

    except Exception as e:
        print(f"Error generating response from Foundation-Sec-8B: {e}")
        return "An error occurred while getting cybersecurity information. Check model loading and TPU."

# --- Updated get_cybersecurity_info to use Foundation-Sec-8B ---
async def get_cybersecurity_info(query: str) -> str:
    """
    Ask the Foundation-Sec-8B model for background on a cybersecurity query.

    Args:
        query (str): The user's cybersecurity-related question.

    Returns:
        str: The text produced by ``generate_cyber_response``, or a fixed
        fallback sentence when that text is empty.
    """
    print(f"Querying Foundation-Sec-8B for cybersecurity context: '{query}'")

    # Foundation-Sec-8B is a base model, so the query is wrapped in an explicit
    # instruction that ends with an "Answer:" cue.
    cyber_prompt = f"Provide detailed information about the following cybersecurity topic: {query}\n\nAnswer:"
    model_answer = await generate_cyber_response(cyber_prompt)

    # An empty answer is replaced by the fallback sentence.
    return model_answer or "No specific cybersecurity context found from Foundation-Sec-8B for this query."

# --- Updated chat_with_cyber_helper to use GROQ API ---
# Longer terms are matched as substrings of the lower-cased message.
_CYBER_HELPER_TERMS = (
    'cybersecurity', 'vulnerability', 'threat', 'malware', 'exploit',
    'phishing', 'ransomware', 'firewall', 'encryption', 'security',
    'attack', 'breach', 'incident response', 'compliance',
    'patch', 'zero-day',
    'authentication', 'authorization', 'audit', 'forensics', 'risk assessment',
    'cyber attack', 'data breach', 'security policy', 'penetration testing',
    'ethical hacking', 'threat intelligence', 'vulnerability management',
)

# Acronyms are stored lower-cased (the message is lower-cased before matching)
# and matched as whole words, so "https" does not count as "ttp" and "adapt"
# does not count as "apt".
_CYBER_HELPER_ACRONYMS = ('soc', 'ttp', 'ddos', 'apt', 'siem', 'ids', 'ips', 'vpn')


def _is_cyber_related(message: str) -> bool:
    """Return True when *message* mentions a cybersecurity term or whole-word acronym."""
    import re  # local import keeps this block self-contained

    lowered = message.lower()
    if any(term in lowered for term in _CYBER_HELPER_TERMS):
        return True
    # An optional trailing "s" lets simple plurals of an acronym match as well.
    acronym_pattern = r"\b(?:" + "|".join(map(re.escape, _CYBER_HELPER_ACRONYMS)) + r")s?\b"
    return re.search(acronym_pattern, lowered) is not None


async def chat_with_cyber_helper(user_message: str) -> str:
    """
    Acts as the main chatbot logic, integrating general chat capabilities
    with specialized cybersecurity knowledge from Foundation-Sec-8B,
    and using the GROQ API for the final chat response.

    Cybersecurity-related messages (keyword detection) first obtain context from
    Foundation-Sec-8B, which is embedded in the system prompt sent to GROQ.
    Other messages are sent to GROQ with a generic system prompt.

    Args:
        user_message (str): The message from the user.

    Returns:
        str: The chatbot's response, or an apology/error sentence when the GROQ
        request fails or returns an unexpected payload.
    """
    if _is_cyber_related(user_message):
        # If the query is cybersecurity-related, get context from Foundation-Sec-8B
        cyber_context = await get_cybersecurity_info(user_message)

        # Construct a detailed prompt for the GROQ chat model
        system_prompt = f"""
        You are a helpful and knowledgeable AI assistant specializing in cybersecurity.
        Your goal is to provide accurate and relevant information based on the user's query.

        Here is some cybersecurity context that might be relevant to the user's question,
        generated by a specialized cybersecurity model:
        ---
        {cyber_context}
        ---

        Please use this context to inform your answer. If the context is not directly relevant or is insufficient,
        please answer based on your general knowledge about cybersecurity.
        """
    else:
        # If not cybersecurity-related, use a general prompt for the GROQ model
        system_prompt = "You are a helpful AI assistant. Answer the user's question."

    # Prepare the messages for the GROQ API
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message}
    ]

    # Call the GROQ API to generate the response
    api_url = "https://api.groq.com/openai/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {GROQ_API_KEY}",
        "Content-Type": "application/json"
    }
    payload = {
        "model": GROQ_CHAT_MODEL,
        "messages": messages,
        "temperature": 0.7,
        "max_tokens": 500 # Adjust as needed
    }

    try:
        print(f"Calling GROQ API with model '{GROQ_CHAT_MODEL}' for general chat response...")
        # Without a timeout a stalled connection would block the chat loop forever.
        response = requests.post(api_url, headers=headers, json=payload, timeout=30)
        response.raise_for_status() # Raise an exception for HTTP errors (4xx or 5xx)
        result = response.json()

        if result.get("choices") and result["choices"][0].get("message") and \
           result["choices"][0]["message"].get("content"):
            return result["choices"][0]["message"]["content"]
        else:
            print(f"Unexpected GROQ API response structure: {result}")
            return "I couldn't generate a response based on your query. The API response was unexpected."

    # The JSON handler must come first: in current versions of requests the JSON
    # decoding error is also a RequestException, so the broader handler would
    # otherwise catch it and this more specific message could never be shown.
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON from GROQ API response: {e}")
        return "I apologize, but I received an unreadable response from the general AI."
    except requests.exceptions.RequestException as e:
        print(f"Error calling GROQ API: {e}")
        return "I apologize, but I encountered an error while trying to process your request to the general AI."
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return "An unexpected error occurred while processing your request."


async def main():
    """
    Main function to run the chatbot interaction loop.

    Loads the Foundation-Sec-8B model once, then reads user input until the
    user types 'exit' (or 'quit'). Blank input is ignored. KeyboardInterrupt is
    handled so the session can be stopped gracefully in Colab.
    """
    print("Welcome to the Cybersecurity AI Helper!")
    print("Initializing models. This might take a moment...")
    await load_foundation_sec_model() # Load the Foundation-Sec-8B model at startup

    if foundation_sec_model is None:
        print("Failed to load cybersecurity model. Chatbot will operate in general mode.")

    print("\nType 'exit' to end the chat.")
    print("Press Ctrl+C or use Colab's 'Stop' button to gracefully interrupt.")

    try:
        while True:
            # Strip surrounding whitespace so " exit " still ends the chat.
            user_input = input("\nYou: ").strip()
            if not user_input:
                # Nothing was typed; do not spend an API call on an empty message.
                continue
            if user_input.lower() in ('exit', 'quit'):
                print("Goodbye!")
                break

            response = await chat_with_cyber_helper(user_input)
            print(f"Cyber Helper: {response}")
    except KeyboardInterrupt:
        print("\nChatbot interrupted by user. Exiting gracefully.")
    except Exception as e:
        print(f"\nAn unexpected error occurred during chat: {e}")
    finally:
        print("Chat session ended.")


# To run this in Google Colab, execute the cells sequentially:
# 1. The pip install commands.
# 2. The rest of the code.
# 3. Run 'await main()' in a new cell.
if __name__ == "__main__":
    # The chat is not started from here; this cell only defines the functions and
    # tells the user how to start it from a separate cell.
    print("\nTo start the chat in Colab, run 'await main()' in a new cell after executing this one.")
    # The key is read from Colab Secrets in this cell, so point the user there.
    print("Remember to add your Groq API key as a Colab Secret named 'GROQ_API_KEY'.")


In [None]:
# Start the interactive chat loop, as instructed by the message printed by the
# previous cell.
await main()