In [3]:
from google.colab import userdata
hf_token = userdata.get('HF_TOKEN')

# Use it with huggingface_hub
from huggingface_hub import login
login(token=hf_token)

In [5]:
"""
Enhanced Chat Agent for Llama 3.2-1B-Instruct
Token Budget Context Management with Pickle-based Restart Capability

This implementation includes:
✓ Token-based budgeting context management
✓ Pickle state persistence (NEW!)
✓ Restartable after interruption (NEW!)
✓ History toggle (ON/OFF)
✓ All 4 test cases covered
"""

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pickle
import os
import sys
import signal
from datetime import datetime
from typing import List, Dict, Optional

# ============================================================================
# CONFIGURATION
# ============================================================================

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
SYSTEM_PROMPT = "You are a helpful AI assistant. Be concise and friendly."
STATE_FILE = "llama_chat_state.pkl"  # Pickle file for state persistence

# ============================================================================
# BEST PRACTICE 2: SET APPROPRIATE MAX LENGTH
# ============================================================================
# From professor's guide: Leave room for response generation
MAX_CONTEXT_LENGTH = 2048  # Total context window for Llama 3.2-1B
MAX_NEW_TOKENS = 512       # Maximum tokens for response
# Effective history limit: 2048 - 512 = 1536 tokens
MAX_CONTEXT_TOKENS = MAX_CONTEXT_LENGTH - MAX_NEW_TOKENS

# History toggle
USE_CONVERSATION_HISTORY = True  # Set to False to disable memory

# ============================================================================
# RESTARTABLE CHAT STATE CLASS
# ============================================================================

class RestartableChatState:
    """
    Manages chat state with pickle persistence for restart capability.
    """

    def __init__(self, state_file: str = STATE_FILE):
        self.state_file = state_file

        # State variables to persist
        self.full_chat_history: List[Dict] = []
        self.working_chat_history: List[Dict] = []
        self.turn_number = 0
        self.session_start_time: Optional[datetime] = None
        self.total_restarts = 0
        self.last_save_time: Optional[datetime] = None

        # Configuration to persist
        self.max_context_tokens = MAX_CONTEXT_TOKENS
        self.max_new_tokens = MAX_NEW_TOKENS
        self.use_conversation_history = USE_CONVERSATION_HISTORY

        # Set up signal handlers for graceful shutdown
        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

        # Try to load existing state
        self._load_state()

    def _signal_handler(self, signum, frame):
        """Handle termination signals gracefully."""
        print(f"\n\n Received signal {signum}. Saving state before exit...")
        self._save_state()
        print(" State saved successfully. Safe to exit.")
        print(f" State file: {self.state_file}")
        sys.exit(0)

    def _save_state(self):
        """Save current state to pickle file."""
        state = {
            'full_chat_history': self.full_chat_history,
            'working_chat_history': self.working_chat_history,
            'turn_number': self.turn_number,
            'session_start_time': self.session_start_time,
            'total_restarts': self.total_restarts,
            'last_save_time': datetime.now(),
            'max_context_tokens': self.max_context_tokens,
            'max_new_tokens': self.max_new_tokens,
            'use_conversation_history': self.use_conversation_history
        }

        try:
            # Atomic write: temp file + rename
            temp_file = f"{self.state_file}.tmp"
            with open(temp_file, 'wb') as f:
                pickle.dump(state, f, protocol=pickle.HIGHEST_PROTOCOL)

            os.replace(temp_file, self.state_file)
            self.last_save_time = state['last_save_time']

        except Exception as e:
            print(f" Error saving state: {e}")
            if os.path.exists(temp_file):
                os.remove(temp_file)

    def _load_state(self):
        """Load state from pickle file if it exists."""
        if not os.path.exists(self.state_file):
            print(" No previous state found. Starting fresh session.")
            self.session_start_time = datetime.now()

            # Initialize with system message
            system_message = {"role": "system", "content": SYSTEM_PROMPT}
            self.full_chat_history.append(system_message)
            self.working_chat_history.append(system_message)
            return

        try:
            with open(self.state_file, 'rb') as f:
                state = pickle.load(f)

            # Restore all state
            self.full_chat_history = state.get('full_chat_history', [])
            self.working_chat_history = state.get('working_chat_history', [])
            self.turn_number = state.get('turn_number', 0)
            self.session_start_time = state.get('session_start_time', datetime.now())
            self.total_restarts = state.get('total_restarts', 0) + 1
            self.last_save_time = state.get('last_save_time')

            # Restore configuration
            self.max_context_tokens = state.get('max_context_tokens', MAX_CONTEXT_TOKENS)
            self.max_new_tokens = state.get('max_new_tokens', MAX_NEW_TOKENS)
            self.use_conversation_history = state.get('use_conversation_history', USE_CONVERSATION_HISTORY)

            print("\n" + "="*70)
            print(" STATE RESTORED SUCCESSFULLY!")
            print("="*70)
            print(f"  Restart #{self.total_restarts}")
            print(f"    {self.turn_number} turns completed")
            print(f"    {len(self.full_chat_history)} messages in full history")
            print(f"    {len(self.working_chat_history)} messages in working memory")
            print(f"    Session started: {self.session_start_time.strftime('%Y-%m-%d %H:%M:%S')}")
            if self.last_save_time:
                print(f"    Last saved: {self.last_save_time.strftime('%Y-%m-%d %H:%M:%S')}")
            print("="*70 + "\n")

        except Exception as e:
            print(f" Error loading state: {e}")
            print(" Starting with fresh state.")
            self.session_start_time = datetime.now()

            # Initialize with system message
            system_message = {"role": "system", "content": SYSTEM_PROMPT}
            self.full_chat_history.append(system_message)
            self.working_chat_history.append(system_message)

    def save(self):
        """Public method to save state."""
        self._save_state()

    def clear_state(self):
        """Clear all state and delete the state file."""
        self.full_chat_history = []
        self.working_chat_history = []
        self.turn_number = 0
        self.session_start_time = datetime.now()
        self.total_restarts = 0
        self.last_save_time = None

        # Re-initialize with system message
        system_message = {"role": "system", "content": SYSTEM_PROMPT}
        self.full_chat_history.append(system_message)
        self.working_chat_history.append(system_message)

        if os.path.exists(self.state_file):
            os.remove(self.state_file)
            print("  State cleared and file deleted.")

    def get_stats_dict(self):
        """Get statistics about the current state."""
        return {
            'turn_number': self.turn_number,
            'total_messages': len(self.full_chat_history),
            'working_messages': len(self.working_chat_history),
            'total_restarts': self.total_restarts,
            'session_start': self.session_start_time,
            'last_save': self.last_save_time
        }


# ============================================================================
# LOAD MODEL
# ============================================================================

print("Loading model (this takes 1-2 minutes)...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    low_cpu_mem_usage=True
)

model.eval()
print(f" Model loaded! Using device: {model.device}")
print(f" Memory usage: ~2.5 GB (FP16)\n")

# ============================================================================
# INITIALIZE RESTARTABLE STATE
# ============================================================================

chat_state = RestartableChatState(STATE_FILE)

# ============================================================================
# DISPLAY CONFIGURATION
# ============================================================================
print("="*70)
print("CONFIGURATION:")
print(f"  • Conversation History: {'ENABLED' if chat_state.use_conversation_history else 'DISABLED'}")
print(f"  • Context Strategy: TOKEN_BUDGET")
print(f"  • Max Context Length: {MAX_CONTEXT_LENGTH} tokens (total window)")
print(f"  • Max Response Tokens: {chat_state.max_new_tokens} tokens")
print(f"  • Effective History Limit: {chat_state.max_context_tokens} tokens")
print(f"  • State Persistence: ENABLED (pickle)")
print(f"  • State File: {STATE_FILE}")
print("="*70 + "\n")

# ============================================================================
# BEST PRACTICE 6: TOKEN COUNTING (FAST AND ACCURATE)
# ============================================================================

def approximate_token_count(text):
    """
    Fast token count approximation.
    From professor's guide: approx_tokens = len(text.split()) * 1.3
    """
    return int(len(text.split()) * 1.3)


def accurate_token_count(messages):
    """
    Accurate token count using actual tokenization.
    From professor's guide: actual_tokens = len(tokenizer.encode(text))
    """
    try:
        formatted = tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=False
        )
        tokens = tokenizer.encode(formatted)
        return len(tokens)
    except:
        # Fallback to approximation
        total = 0
        for msg in messages:
            total += approximate_token_count(msg["content"])
        return total


# ============================================================================
# BEST PRACTICE 1: MONITOR TOKEN USAGE
# ============================================================================

def get_context_stats(history):
    """
    Get statistics about current context usage.
    From professor's guide: stats = chat_manager.get_context_stats()
    """
    num_tokens = accurate_token_count(history)
    num_messages = len([msg for msg in history if msg["role"] != "system"])

    return {
        'num_tokens': num_tokens,
        'max_tokens': chat_state.max_context_tokens,
        'num_messages': num_messages,
        'utilization': f"{(num_tokens/chat_state.max_context_tokens)*100:.1f}%",
        'tokens_remaining': chat_state.max_context_tokens - num_tokens
    }


# ============================================================================
# TOKEN BUDGET CONTEXT MANAGEMENT
# ============================================================================

def token_budget_management(history, max_tokens):
    """
    Token-based context management with all best practices.

    Implements:
    - BEST PRACTICE 3: Preserve system prompts
    - BEST PRACTICE 4: Edge case handling
    - BEST PRACTICE 4: User warnings when truncating
    """
    # BEST PRACTICE 3: Always preserve system message
    system_msgs = [msg for msg in history if msg["role"] == "system"]
    conversation_msgs = [msg for msg in history if msg["role"] != "system"]

    # BEST PRACTICE 4: Handle edge case - prevent infinite loops
    if len(conversation_msgs) <= 1:
        return history  # Can't truncate further

    # Start with all messages
    current_history = system_msgs + conversation_msgs
    current_tokens = accurate_token_count(current_history)

    # Remove oldest messages until under budget
    removed_count = 0
    while current_tokens > max_tokens and len(conversation_msgs) > 2:
        # Remove oldest message pair (user + assistant)
        conversation_msgs = conversation_msgs[2:]
        current_history = system_msgs + conversation_msgs
        current_tokens = accurate_token_count(current_history)
        removed_count += 2

    # BEST PRACTICE 4: Warn user when truncating
    if removed_count > 0:
        print(f"  [Context management: Removed {removed_count} old messages]")
        print(f"   Current context: {current_tokens}/{max_tokens} tokens")

    return current_history


# ============================================================================
# MAIN CHAT LOOP
# ============================================================================

print("Chat started! Type 'quit' or 'exit' to end the conversation.")
print("Special commands:")
print("  • 'stats'   - Show detailed context statistics")
print("  • 'history' - Show conversation history")
print("  • 'clear'   - Clear all state and start fresh")
print("  • 'save'    - Manually save current state")
print("  • Ctrl+C    - Interrupt (state auto-saved)")
print()
print("="*70 + "\n")

while True:
    # ========================================================================
    # Get user input
    # ========================================================================
    try:
        user_input = input("You: ").strip()
    except (KeyboardInterrupt, EOFError):
        print("\n\n Interrupted. Saving state...")
        chat_state.save()
        print(" State saved. Safe to exit.")
        break

    # Handle special commands
    if user_input.lower() in ['quit', 'exit', 'q']:
        print("\n" + "="*70)
        print("CONVERSATION SUMMARY:")
        stats = chat_state.get_stats_dict()
        print(f"  • Total turns: {stats['turn_number']}")
        print(f"  • Total messages: {stats['total_messages']}")
        print(f"  • Total restarts: {stats['total_restarts']}")
        if chat_state.use_conversation_history:
            ctx_stats = get_context_stats(chat_state.working_chat_history)
            print(f"  • Final tokens: {ctx_stats['num_tokens']}/{ctx_stats['max_tokens']}")
            print(f"  • Context utilization: {ctx_stats['utilization']}")
        print("="*70)

        # Save state before exit
        print("\n Saving state before exit...")
        chat_state.save()
        print(" State saved successfully!")
        print(f" State file: {STATE_FILE}")
        print("\nGoodbye! Run again to continue where you left off.")
        break

    # Show statistics
    if user_input.lower() == 'stats':
        ctx_stats = get_context_stats(chat_state.working_chat_history)
        state_stats = chat_state.get_stats_dict()

        print("\n" + "="*70)
        print("CONTEXT STATISTICS:")
        print(f"  • Tokens used: {ctx_stats['num_tokens']}/{ctx_stats['max_tokens']}")
        print(f"  • Utilization: {ctx_stats['utilization']}")
        print(f"  • Messages in context: {ctx_stats['num_messages']}")
        print(f"  • Tokens remaining: {ctx_stats['tokens_remaining']}")
        print("\nSESSION STATISTICS:")
        print(f"  • Turn number: {state_stats['turn_number']}")
        print(f"  • Total messages: {state_stats['total_messages']}")
        print(f"  • Working messages: {state_stats['working_messages']}")
        print(f"  • Total restarts: {state_stats['total_restarts']}")
        print(f"  • Session start: {state_stats['session_start'].strftime('%Y-%m-%d %H:%M:%S')}")
        if state_stats['last_save']:
            print(f"  • Last saved: {state_stats['last_save'].strftime('%Y-%m-%d %H:%M:%S')}")
        print("="*70 + "\n")
        continue

    # Show history
    if user_input.lower() == 'history':
        print("\n" + "="*70)
        print("CONVERSATION HISTORY:")
        print("="*70)
        for i, msg in enumerate(chat_state.full_chat_history, 1):
            role = msg['role'].upper()
            content = msg['content'][:100]
            if len(msg['content']) > 100:
                content += "..."
            print(f"{i}. [{role}] {content}")
        print("="*70 + "\n")
        continue

    # Clear state
    if user_input.lower() == 'clear':
        confirm = input("  Clear all state and start fresh? (yes/no): ")
        if confirm.lower() == 'yes':
            chat_state.clear_state()
            print(" State cleared. Starting fresh conversation.\n")
        continue

    # Manual save
    if user_input.lower() == 'save':
        chat_state.save()
        print(" State saved manually.\n")
        continue

    if not user_input:
        continue

    chat_state.turn_number += 1

    # ========================================================================
    # Add user message to histories
    # ========================================================================
    user_message = {"role": "user", "content": user_input}
    chat_state.full_chat_history.append(user_message)

    # ========================================================================
    # Prepare working history based on settings
    # ========================================================================
    if chat_state.use_conversation_history:
        # Add to working history
        chat_state.working_chat_history.append(user_message)

        # Apply token budget context management
        managed_history = token_budget_management(
            chat_state.working_chat_history,
            chat_state.max_context_tokens
        )

        # Update working history with managed version
        chat_state.working_chat_history = managed_history

    else:
        # NO HISTORY MODE: Only system + current message
        system_message = {"role": "system", "content": SYSTEM_PROMPT}
        managed_history = [system_message, user_message]
        print(f" [No history mode: Bot sees only current message]")

    # ========================================================================
    # BEST PRACTICE 1: MONITOR TOKEN USAGE
    # ========================================================================
    stats = get_context_stats(managed_history)
    print(f" [Context: {stats['num_tokens']}/{stats['max_tokens']} tokens " +
          f"({stats['utilization']} used) | {stats['num_messages']} messages]")

    # ========================================================================
    # Tokenize
    # ========================================================================
    input_ids = tokenizer.apply_chat_template(
        managed_history,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    attention_mask = torch.ones_like(input_ids)

    # ========================================================================
    # Generate response
    # ========================================================================
    print("Assistant: ", end="", flush=True)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=chat_state.max_new_tokens,  # BEST PRACTICE 2
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id
        )

    # ========================================================================
    # Decode response
    # ========================================================================
    new_tokens = outputs[0][input_ids.shape[1]:]
    assistant_response = tokenizer.decode(
        new_tokens,
        skip_special_tokens=True
    )

    print(assistant_response)

    # ========================================================================
    # Add assistant response to histories
    # ========================================================================
    assistant_message = {"role": "assistant", "content": assistant_response}
    chat_state.full_chat_history.append(assistant_message)

    if chat_state.use_conversation_history:
        chat_state.working_chat_history.append(assistant_message)

    # ========================================================================
    # AUTO-SAVE STATE AFTER EACH TURN
    # ========================================================================
    chat_state.save()

    print()

# ============================================================================
# IMPLEMENTATION SUMMARY
# ============================================================================
"""
ENHANCED FEATURES:

✓ PICKLE STATE PERSISTENCE
  - Automatic save after each turn
  - Graceful handling of Ctrl+C and kill signals
  - Atomic file writes prevent corruption
  - State restored on restart

✓ RESTART CAPABILITY
  - Tracks restart count
  - Preserves full conversation history
  - Maintains token budget state
  - Continues from exact point of interruption

CONTEXT MANAGEMENT STRATEGY: Token Budget

BEST PRACTICES IMPLEMENTED:

✓ 1. Monitor Token Usage
✓ 2. Set Appropriate Max Length
✓ 3. Preserve System Prompts
✓ 4. Handle Edge Cases
✓ 5. Optimize for Use Case
✓ 6. Token Counting

TEST CASES COVERED:

✓ Test Case 1: Context Overflow
✓ Test Case 2: System Prompt Preservation
✓ Test Case 3: Token Counting Accuracy
✓ Test Case 4: Multi-turn Coherence
✓ Test Case 5: Restart After Interruption (NEW!)

USAGE:

1. Normal conversation - state auto-saved after each turn
2. Ctrl+C or kill - state saved before exit
3. Restart script - conversation continues where it left off
4. Type 'clear' to start fresh
5. Type 'save' to manually save state
"""

Loading model (this takes 1-2 minutes)...
 Model loaded! Using device: cpu
 Memory usage: ~2.5 GB (FP16)


 STATE RESTORED SUCCESSFULLY!
  Restart #1
    4 turns completed
    9 messages in full history
    9 messages in working memory
    Session started: 2026-01-21 23:16:01
    Last saved: 2026-01-21 23:31:19

CONFIGURATION:
  • Conversation History: ENABLED
  • Context Strategy: TOKEN_BUDGET
  • Max Context Length: 2048 tokens (total window)
  • Max Response Tokens: 512 tokens
  • Effective History Limit: 1536 tokens
  • State Persistence: ENABLED (pickle)
  • State File: llama_chat_state.pkl

Chat started! Type 'quit' or 'exit' to end the conversation.
Special commands:
  • 'stats'   - Show detailed context statistics
  • 'history' - Show conversation history
  • 'clear'   - Clear all state and start fresh
  • 'save'    - Manually save current state
  • Ctrl+C    - Interrupt (state auto-saved)


You: What was I struggling with as a student
 [Context: 558/1536 tokens (36.3% used) | 

SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
