# 11: Generative Chatbot

**Duration:** 4-5 hours | **Difficulty:** Advanced

## Learning Objectives
- Generative model training strategies
- Advanced sampling and decoding methods
- Safety and bias considerations
- End-to-end generative chatbot implementation

In [None]:
import json
import re
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

# Make the project-local `utils` package (one directory up) importable.
sys.path.append('../')
from utils.text_utils import SimpleTokenizer
from utils.model_helpers import get_device, count_parameters

# Resolve the compute device once via the project helper; the model is moved
# onto it further down.
device = get_device("auto")
print("Using device: {}".format(device))

## Advanced Generation Strategies

**Key methods**: temperature, top-k, and nucleus (top-p) sampling. Beam search is a common deterministic alternative, but it is not implemented in this notebook.

In [None]:
class AdvancedGenerator:
    """Advanced text generation methods.

    Stateless next-token samplers. Each method takes ``logits`` whose last
    dimension indexes the vocabulary and returns sampled token ids with a
    trailing dimension of size 1 (a ``(batch, vocab)`` input yields a
    ``(batch, 1)`` LongTensor). ``torch.multinomial`` only accepts 1-D or 2-D
    probabilities, so higher-rank logits are not supported.

    ``temperature`` divides the logits before sampling: values below 1 sharpen
    the distribution, values above 1 flatten it, and ``0`` is treated as its
    limit, greedy (argmax) decoding. Negative temperatures raise ``ValueError``.
    """

    @staticmethod
    def _greedy(logits):
        """Return the argmax token id, shaped like the samplers' output."""
        return torch.argmax(logits, dim=-1, keepdim=True)

    @staticmethod
    def _validate_temperature(temperature):
        """Raise ``ValueError`` for a negative temperature (it would invert the ranking)."""
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

    @staticmethod
    def temperature_sampling(logits, temperature=1.0):
        """Temperature-based sampling for controlling randomness.

        Samples one token from ``softmax(logits / temperature)``.
        ``temperature=0`` returns the argmax instead of dividing by zero,
        which used to produce NaN probabilities and a ``multinomial`` error.
        """
        AdvancedGenerator._validate_temperature(temperature)
        if temperature == 0:
            return AdvancedGenerator._greedy(logits)
        probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, 1)

    @staticmethod
    def top_k_sampling(logits, k=50, temperature=1.0):
        """Sample from top-k most likely tokens.

        ``k`` is clamped to the vocabulary size, so a ``k`` larger than the
        vocabulary samples from the full distribution instead of making
        ``torch.topk`` raise.

        Raises:
            ValueError: if ``k < 1`` or ``temperature < 0``.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        AdvancedGenerator._validate_temperature(temperature)
        if temperature == 0:
            return AdvancedGenerator._greedy(logits)

        k = min(int(k), logits.size(-1))
        scaled_logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(scaled_logits, k)

        # Every token outside the top k gets -inf logits, i.e. zero probability.
        filtered_logits = torch.full_like(scaled_logits, -float('inf'))
        filtered_logits.scatter_(-1, top_k_indices, top_k_logits)

        probs = F.softmax(filtered_logits, dim=-1)
        return torch.multinomial(probs, 1)

    @staticmethod
    def nucleus_sampling(logits, p=0.9, temperature=1.0):
        """Nucleus (top-p) sampling for dynamic vocabulary.

        Walks tokens in descending probability and keeps them up to and
        including the first one at which the cumulative probability exceeds
        ``p``. The top token is always kept. The kept probabilities are
        renormalised and one token is sampled from them.
        """
        AdvancedGenerator._validate_temperature(temperature)
        if temperature == 0:
            return AdvancedGenerator._greedy(logits)
        probs = F.softmax(logits / temperature, dim=-1)

        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        # Keep token i when the mass of the tokens *before* it is still <= p.
        # Shifting the mask right by one keeps the token that crosses p.
        nucleus_mask = cumulative_probs <= p
        nucleus_mask[..., 1:] = nucleus_mask[..., :-1].clone()
        nucleus_mask[..., 0] = True  # never leave the nucleus empty

        filtered_probs = torch.where(nucleus_mask, sorted_probs, torch.zeros_like(sorted_probs))
        filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)

        # Sample a position in sorted order, then map it back to a vocabulary id.
        sampled_idx = torch.multinomial(filtered_probs, 1)
        return torch.gather(sorted_indices, -1, sampled_idx)

# Load data. JSON is UTF-8 by specification, so decode explicitly rather than
# relying on the platform's locale-dependent default encoding.
with open('../data/conversations/simple_qa_pairs.json', 'r', encoding='utf-8') as f:
    conversations = [(item['question'], item['answer']) for item in json.load(f)]

# Fit the tokenizer on questions and answers alike. vocab_size=2000 is
# presumably an upper bound; the actual size is read back from tokenizer.vocab.
tokenizer = SimpleTokenizer(vocab_size=2000)
all_text = [text for conv in conversations for text in conv]
tokenizer.fit(all_text)
vocab_size = len(tokenizer.vocab)

print(f"Loaded {len(conversations)} conversations")
print(f"Vocabulary size: {vocab_size}")

# Test generation strategies on random logits (no model involved) to show
# that each sampler returns a single token id.
sample_logits = torch.randn(1, vocab_size)
generator = AdvancedGenerator()

print("\nGeneration Strategy Comparison:")
print(f"Temperature (0.5): {generator.temperature_sampling(sample_logits, 0.5).item()}")
print(f"Temperature (1.5): {generator.temperature_sampling(sample_logits, 1.5).item()}")
print(f"Top-k (k=10): {generator.top_k_sampling(sample_logits, k=10).item()}")
print(f"Nucleus (p=0.9): {generator.nucleus_sampling(sample_logits, p=0.9).item()}")

## Generative Chatbot Implementation

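A complete chatbot pipeline with conversation-history management and a basic keyword safety filter. The model is randomly initialised and is not trained in this notebook, so its replies are essentially random. The goal is to show the generation and safety plumbing, not response quality.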

In [None]:
class SimpleGenerativeModel(nn.Module):
    """Simple generative model for demonstration.

    Token embedding plus a learned positional embedding, a stacked LSTM, and
    a linear projection back to vocabulary logits.

    Args:
        vocab_size: number of token ids covered by the embedding and output layers.
        d_model: embedding size and LSTM hidden size.
        n_layers: number of stacked LSTM layers.
        max_len: number of learned positions. Longer inputs are rejected.
            Defaults to 512, the previously hard-coded table size.
    """

    def __init__(self, vocab_size, d_model=256, n_layers=4, max_len=512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional table, randomly initialised; one row per position.
        self.pos_encoding = nn.Parameter(torch.randn(max_len, d_model))
        self.lstm = nn.LSTM(d_model, d_model, n_layers, batch_first=True)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        """Map token ids ``(batch, seq)`` to logits ``(batch, seq, vocab_size)``.

        Raises:
            ValueError: if ``seq`` exceeds the positional table length. Without
                this check the addition would fail with an opaque broadcast error.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of the positional table"
            )
        x = self.embedding(x) + self.pos_encoding[:seq_len]
        x, _ = self.lstm(x)
        return self.output(x)

class GenerativeChatbot:
    """Generative chatbot with safety features.

    Wraps a next-token model and a tokenizer, keeps a rolling conversation
    history, and runs every reply through a simple keyword/length filter.

    Args:
        model: module mapping token ids ``(1, seq)`` to logits
            ``(1, seq, vocab)``. It is switched to eval mode on every
            ``generate_response`` call.
        tokenizer: object exposing ``encode(text, add_special_tokens=...,
            max_length=...)`` and ``decode(ids)``.
        max_turns: maximum number of (user, bot) exchanges kept in ``history``.
    """

    # Token id that stops generation. NOTE(review): assumed to be the
    # tokenizer's EOS id; confirm against SimpleTokenizer's special tokens.
    eos_token_id = 2

    def __init__(self, model, tokenizer, max_turns=5):
        self.model = model
        self.tokenizer = tokenizer
        self.max_turns = max_turns
        self.history = []
        self.generator = AdvancedGenerator()

        # Basic safety filter; entries are matched at the start of words.
        self.unsafe_words = ['hate', 'violence', 'harmful']

    def _model_device(self):
        """Return the device of the model's parameters (CPU if it has none)."""
        param = next(self.model.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

    def generate_response(self, user_input, strategy='nucleus', max_length=30, **kwargs):
        """Generate response using specified strategy.

        The prompt is the last two history turns (if any) followed by
        ``"User: <input> Bot:"``.

        Args:
            user_input: the user's message.
            strategy: ``'temperature'``, ``'top_k'`` or ``'nucleus'``. Any
                other value falls back to greedy (argmax) decoding.
            max_length: maximum number of new tokens to generate.
            **kwargs: sampling options. ``temperature`` defaults to 0.8 for the
                ``'temperature'`` strategy and 1.0 otherwise. ``k`` defaults
                to 50 and ``p`` defaults to 0.9.

        Returns:
            The decoded, stripped reply, or a fallback sentence when the
            model emits EOS immediately.
        """
        # Add context from the two most recent exchanges.
        if self.history:
            context = " ".join(f"User: {h['user']} Bot: {h['bot']}" for h in self.history[-2:])
            full_input = f"{context} User: {user_input} Bot:"
        else:
            full_input = f"User: {user_input} Bot:"

        # Tokenize. NOTE(review): if SimpleTokenizer truncates from the end,
        # a long context could cut off the current user turn; confirm.
        tokens = self.tokenizer.encode(full_input, add_special_tokens=True, max_length=100)
        input_tensor = torch.tensor([tokens], device=self._model_device())

        self.model.eval()
        generated = []

        with torch.no_grad():
            current = input_tensor

            for _ in range(max_length):
                # The full prefix is re-run every step; only the last position's logits are used.
                next_logits = self.model(current)[:, -1, :]

                # Apply generation strategy. The temperature is now honoured by
                # every sampler, not just the 'temperature' strategy.
                if strategy == 'temperature':
                    next_token = self.generator.temperature_sampling(
                        next_logits, kwargs.get('temperature', 0.8)
                    )
                elif strategy == 'top_k':
                    next_token = self.generator.top_k_sampling(
                        next_logits, kwargs.get('k', 50), kwargs.get('temperature', 1.0)
                    )
                elif strategy == 'nucleus':
                    next_token = self.generator.nucleus_sampling(
                        next_logits, kwargs.get('p', 0.9), kwargs.get('temperature', 1.0)
                    )
                else:
                    next_token = torch.argmax(next_logits, dim=-1, keepdim=True)

                token_id = next_token.item()
                if token_id == self.eos_token_id:
                    break

                generated.append(token_id)
                current = torch.cat([current, next_token], dim=1)

        if not generated:
            return "I'm not sure how to respond."
        return self.tokenizer.decode(generated).strip()

    def safety_check(self, response):
        """Basic safety filtering.

        Returns ``False`` if ``response`` contains a word that starts with one
        of ``self.unsafe_words`` (case-insensitive). It also returns ``False``
        if the response has more than 50 whitespace-separated words or is
        shorter than 2 characters after stripping. Otherwise it returns
        ``True``.

        Matching at word starts still catches inflections such as "hateful".
        It no longer flags innocent words that merely contain a blocked term,
        e.g. "whatever" (contains "hate") or "chateau".
        """
        response_lower = response.lower()

        for word in self.unsafe_words:
            if re.search(r"\b" + re.escape(word.lower()), response_lower):
                return False

        if len(response.split()) > 50 or len(response.strip()) < 2:
            return False

        return True

    def respond(self, user_input, strategy='nucleus', **kwargs):
        """Generate safe response to user input.

        Generates a reply and replaces it with a neutral deflection if it
        fails ``safety_check``. The exchange is recorded in ``history``, which
        keeps at most ``max_turns`` entries, and the reply is returned.
        """
        response = self.generate_response(user_input, strategy, **kwargs)

        if not self.safety_check(response):
            response = "I'd prefer to discuss something else. What can I help you with?"

        self.history.append({'user': user_input, 'bot': response})
        # Drop the oldest turns. Computing the excess explicitly also handles
        # max_turns == 0, where the slice history[-0:] would keep everything.
        excess = len(self.history) - self.max_turns
        if excess > 0:
            self.history = self.history[excess:]

        return response

    def reset(self):
        """Reset conversation history."""
        self.history = []

# Create model and chatbot. NOTE: nothing in this notebook trains the model,
# so replies are effectively random tokens (or the safety fallback). The demo
# exercises the decoding and safety pipeline, not response quality.
model = SimpleGenerativeModel(vocab_size).to(device)
chatbot = GenerativeChatbot(model, tokenizer)

print("\nGenerative Chatbot Created:")
print(f"Model parameters: {count_parameters(model)['total']:,}")

# Demo conversation
print("\n=== Chatbot Demo ===")
test_inputs = [
    "Hello, how are you?",
    "What is machine learning?",
    "Can you help me learn programming?"
]

for strategy in ['nucleus', 'temperature', 'top_k']:
    print(f"\nUsing {strategy} strategy:")
    chatbot.reset()  # start every strategy from an empty history

    # Only the first two prompts are used, to keep the demo output short.
    for user_input in test_inputs[:2]:
        response = chatbot.respond(user_input, strategy=strategy, temperature=0.8, p=0.9, k=50)
        print(f"User: {user_input}")
        print(f"Bot: {response}")

## Safety and Ethics Considerations

**Critical aspects**: Content filtering, bias mitigation, privacy, human oversight

In [None]:
# Safety and ethics discussion: emit the whole summary as one block of lines.
print("\n".join([
    "=== Safety and Ethics in Generative AI ===",
    "\nKey Considerations:",
    "• Content Filtering: Prevent harmful, biased, or inappropriate outputs",
    "• Bias Mitigation: Ensure fair and inclusive responses",
    "• Privacy Protection: Avoid memorizing personal information",
    "• Factual Accuracy: Implement fact-checking and uncertainty measures",
    "• Human Oversight: Maintain human-in-the-loop validation",
    "\nImplementation Strategies:",
    "• Multi-layer safety filters (keyword, ML-based, rule-based)",
    "• Diverse training data and bias auditing",
    "• Clear disclaimers about model limitations",
    "• Regular monitoring and model updates",
    "• User feedback and reporting mechanisms",
    "\nBest Practices:",
    "• Safety-first design from the beginning",
    "• Transparent communication about capabilities",
    "• Continuous improvement based on real-world usage",
    "• Collaboration with ethicists and domain experts",
]))

# Compare generation methods
def compare_generation_quality():
    """Print a summary of each decoding strategy to stdout.

    For every strategy this prints its name, a one-line description, and its
    comma-separated pros and cons. Returns ``None``.
    """
    print("\n=== Generation Strategy Analysis ===")

    # (name, description, pros, cons), in display order.
    rows = (
        ("Greedy", "Always picks most likely token",
         ("Deterministic", "Fast"),
         ("Repetitive", "Less creative")),
        ("Temperature", "Controls randomness in sampling",
         ("Tunable creativity", "Simple"),
         ("Can be too random", "Hard to tune")),
        ("Top-k", "Sample from k most likely tokens",
         ("Balanced quality", "Prevents unlikely words"),
         ("Fixed vocabulary size", "Context-independent")),
        ("Nucleus (Top-p)", "Dynamic vocabulary based on probability mass",
         ("Adaptive", "High quality", "Context-aware"),
         ("More complex", "Computational overhead")),
    )

    for name, description, pros, cons in rows:
        print(f"\n{name}:")
        print("  Description: " + description)
        print("  Pros: " + ", ".join(pros))
        print("  Cons: " + ", ".join(cons))

compare_generation_quality()

# Closing recap of the notebook, printed as one block of lines.
print("\n".join((
    "\n=== Generative Chatbot Complete ===",
    "Key Concepts Learned:",
    "• Advanced sampling strategies (temperature, top-k, nucleus)",
    "• Safety filtering and content moderation",
    "• Conversation context management",
    "• Ethical considerations in AI deployment",
    "• Quality vs creativity trade-offs in generation",
    "\nNext: Model fine-tuning and deployment strategies!",
)))