# Cisco Foundation Model Quickstart - SageMaker Endpoint
This notebook demonstrates basic usage of the Cisco Foundation Security model through an Amazon SageMaker endpoint.

**Note**: Update the configuration variables below to match your deployment.

## Configuration
Update these variables to match your SageMaker deployment:

In [None]:
# UPDATE THESE VARIABLES TO MATCH YOUR DEPLOYMENT
endpoint_name = 'foundation-sec-8b-endpoint'  # Name of your SageMaker endpoint
aws_region = 'us-east-1'  # AWS region that hosts the endpoint

# Echo the active settings so they can be checked before any endpoint call is made.
print("\n".join([
    "Configuration:",
    f"Endpoint: {endpoint_name}",
    f"Region: {aws_region}",
]))

## Install Prerequisites

In [None]:
# Install required packages
!pip install transformers torch --quiet

print("Required packages installed successfully!")

In [None]:
import boto3
import json
import re
from IPython.display import display, Markdown

# Initialize the SageMaker runtime client that inference() uses to call the endpoint.
# boto3.client() makes no network request, so this step does not check that the
# endpoint exists or is reachable. The first inference() call is the real check.
sagemaker_runtime = boto3.client('sagemaker-runtime', region_name=aws_region)

print(f"SageMaker runtime client ready (region: {aws_region}); target endpoint: {endpoint_name}")

## Generation Configuration
Configure the model's text generation behavior:

In [None]:
# Default generation parameters, sent to the endpoint as the "parameters" object.
# Decoding is greedy (do_sample=False), which is what makes outputs reproducible.
generation_args = dict(
    max_new_tokens=1024,
    temperature=None,  # sent as JSON null, i.e. no temperature is set
    repetition_penalty=1.2,
    do_sample=False,  # greedy decoding: no random sampling
    use_cache=True,  # NOTE(review): a transformers generate() option; confirm the TGI container accepts it
    # eos_token_id / pad_token_id are not sent; the TGI server handles them
)

print("Default generation configuration:")
for name, setting in generation_args.items():
    print("  {}: {}".format(name, setting))

In [None]:
# System prompt configuration
# This short system prompt is for demonstration purposes only. A more detailed
# system prompt for general user interaction was developed and, in internal
# testing, was found to improve user satisfaction and safety.
DEFAULT_SYSTEM_PROMPT = "You are a cybersecurity expert."

# Chat roles the prompt formatter understands, mapped to their transcript labels.
_ROLE_LABELS = {"system": "System", "user": "User", "assistant": "Assistant"}

_MALFORMED_REQUEST = (
    "Request is not well formed. It must be a string or list of dict with correct format."
)

# One or more special tokens such as "<|eot_id|>" at the very end of the reply.
# [^|]* keeps each match inside a single token, so text that sits between an
# inline token and the trailing one is never removed.
_TRAILING_SPECIAL_TOKENS = re.compile(r'(?:<\|[^|]*\|>\s*)+$')


def _build_messages(request, system_prompt):
    """Normalize `request` into a chat message list that starts with a system turn.

    Raises:
        ValueError: if `request` is not a string or a non-empty list of dicts,
            each with a known "role" and a "content" key.
    """
    if isinstance(request, str):
        return [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": request},
        ]
    well_formed = (
        isinstance(request, list)
        and len(request) > 0
        and all(
            isinstance(message, dict)
            and message.get("role") in _ROLE_LABELS
            and "content" in message
            for message in request
        )
    )
    if not well_formed:
        raise ValueError(_MALFORMED_REQUEST)
    if request[0]["role"] != "system":
        return [{"role": "system", "content": system_prompt}] + request
    return request


def _format_prompt(messages):
    """Render messages as a "Role: content" transcript ending with an open Assistant turn."""
    turns = [f"{_ROLE_LABELS[m['role']]}: {m['content']}\n\n" for m in messages]
    return "".join(turns) + "Assistant: "


def _extract_generated_text(result):
    """Return the generated text from a decoded response (list or dict form)."""
    if isinstance(result, list) and len(result) > 0:
        return result[0].get('generated_text', '')
    if isinstance(result, dict):
        return result.get('generated_text', str(result))
    return str(result)


def _clean_response(generated_text, formatted_prompt):
    """Drop an echoed prompt, surrounding whitespace and trailing special tokens."""
    if generated_text.startswith(formatted_prompt):
        generated_text = generated_text[len(formatted_prompt):]
    return _TRAILING_SPECIAL_TOKENS.sub('', generated_text.strip()).strip()


def inference(request, system_prompt=DEFAULT_SYSTEM_PROMPT, custom_args=None):
    """Send a chat request to the SageMaker endpoint and return the model's reply.

    Mimics the local-model helper, but calls the endpoint named by the
    module-level `endpoint_name` through the `sagemaker_runtime` client.

    Args:
        request: a user prompt string, or a list of {"role", "content"} dicts
            with roles "system", "user" or "assistant". A system turn built
            from `system_prompt` is prepended unless the list already starts
            with one.
        system_prompt: system message used when `request` does not supply one.
        custom_args: generation parameters sent as the payload's "parameters".
            They replace `generation_args` rather than merge with it. Any falsy
            value (None or an empty dict) falls back to `generation_args`.

    Returns:
        The cleaned reply text. If the endpoint call or response parsing fails,
        the error is printed and "Error: <message>" is returned instead of
        raising, so the demo cells keep running.

    Raises:
        ValueError: if `request` is malformed (wrong type, empty list, a
            non-dict message, or a message without a known role or content).
    """
    messages = _build_messages(request, system_prompt)
    formatted_prompt = _format_prompt(messages)

    # Truthiness check: an empty dict also falls back to the defaults.
    args_to_use = custom_args if custom_args else generation_args

    payload = {
        "inputs": formatted_prompt,
        "parameters": args_to_use
    }

    try:
        response = sagemaker_runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType='application/json',
            Body=json.dumps(payload)
        )
        result = json.loads(response['Body'].read().decode())
        return _clean_response(_extract_generated_text(result), formatted_prompt)
    except Exception as e:
        # Deliberate best-effort: report the failure as text rather than raising.
        print(f"Error invoking endpoint: {str(e)}")
        return f"Error: {str(e)}"

# Smoke test: one round trip to the endpoint using the default system prompt.
test_response = inference("Hello, can you help me with cybersecurity?")
print("Test Response:", test_response, sep="\n")

## Example Usage

In [None]:
# Example 1: a single-turn security question using the default generation settings.
segmentation_question = "Explain the importance of network segmentation in cybersecurity."
response1 = inference(segmentation_question)
print("=== Network Segmentation Question ===")
display(Markdown(response1))
print(f"\n{'=' * 60}\n")

In [None]:
# Example 2: threat analysis with custom generation args.
# Sampling is enabled here, so repeated runs may return different answers.
custom_args = dict(
    max_new_tokens=400,
    temperature=0.3,  # low temperature: a little variation while staying focused
    repetition_penalty=1.1,
    do_sample=True,
)

threat_query = """Analyze this network activity and identify potential security concerns:

- Multiple connections from IP 192.168.1.100 to various external IPs on port 443
- Unusual data transfer volumes (10GB outbound in 1 hour)
- Connections occurring outside business hours (2-4 AM)
- User account: john.doe@company.com

What could this indicate and what steps should be taken?"""

response2 = inference(threat_query, custom_args=custom_args)
print("=== Threat Analysis ===")
display(Markdown(response2))
print(f"\n{'=' * 60}\n")

In [None]:
# Example 3: multi-turn conversation.
# The list has no system turn, so inference() prepends DEFAULT_SYSTEM_PROMPT.
zero_day_answer = (
    "A zero-day vulnerability is a security flaw in software that is unknown to the vendor "
    "and has no available patch. Attackers can exploit these vulnerabilities before developers "
    "can create and distribute fixes."
)
conversation = [
    dict(role="user", content="What is a zero-day vulnerability?"),
    dict(role="assistant", content=zero_day_answer),
    dict(role="user", content="How can organizations protect themselves against zero-day attacks?"),
]

response3 = inference(conversation)
print("=== Multi-turn Conversation ===")
display(Markdown(response3))

## Advanced Configuration Examples

In [None]:
# Compare two generation presets on the same question.

# Conservative preset for factual answers: greedy decoding, shorter output.
conservative_args = dict(
    max_new_tokens=300,
    temperature=None,  # sent as JSON null; decoding is greedy
    repetition_penalty=1.2,
    do_sample=False,
)

# Creative preset for brainstorming: sampling at a higher temperature.
creative_args = dict(
    max_new_tokens=500,
    temperature=0.8,
    repetition_penalty=1.1,
    do_sample=True,
)

question = "List 5 innovative ways to improve cybersecurity awareness in an organization."

print("=== Conservative Response ===")
conservative_response = inference(question, custom_args=conservative_args)
display(Markdown(conservative_response))

print("\n=== Creative Response ===")
creative_response = inference(question, custom_args=creative_args)
display(Markdown(creative_response))

## Custom System Prompt Example

In [None]:
# Example with a custom system prompt.
# It replaces DEFAULT_SYSTEM_PROMPT for this one call only.
custom_system_prompt = (
    "You are a senior cybersecurity consultant with 15+ years of experience in enterprise security. \n"
    "You provide detailed, actionable advice and always consider both technical and business implications. \n"
    "Your responses should be professional but accessible to both technical and non-technical stakeholders."
)

consultant_query = (
    "Our company wants to implement zero-trust architecture. "
    "What are the key steps and considerations?"
)

response4 = inference(consultant_query, system_prompt=custom_system_prompt)
print("=== Senior Consultant Response ===")
display(Markdown(response4))

## Interactive Testing
Use this cell to test your own prompts:

In [None]:
# Add your own prompt here for testing: edit the string and re-run this cell.
your_prompt = "Explain the MITRE ATT&CK framework and its main components."

your_response = inference(your_prompt)
print("=== Your Custom Prompt ===", f"Prompt: {your_prompt}", "\nResponse:", sep="\n")
display(Markdown(your_response))