# Chapter 41: Database Design for AI Systems

Run this notebook directly in Google Colab - no local Python needed!

**Full code**: [GitHub](https://github.com/eduardd76/AI_for_networking_and_security_engineers/tree/main/CODE/Volume-3-Production-Systems/Chapter-41-Database-Design)

## Setup

Install dependencies. The examples run against a local SQLite file, so no database server or connection setup is required.

In [None]:
# Install dependencies
!pip install -q psycopg2-binary sqlalchemy alembic

import os
from getpass import getpass

# Confirm the environment is ready; the SQLite fallback means the examples
# below work even without a PostgreSQL server.
print(
    '✓ Dependencies installed',
    '\n✅ Setup complete! Ready to run examples.',
    '\n⚠️  Note: PostgreSQL examples require a running database.',
    '   For testing without PostgreSQL, examples use SQLite.',
    sep='\n',
)

## Example 1: SQLAlchemy Models for AI Conversations

Define database models for storing AI conversations and messages.

In [None]:
from sqlalchemy import create_engine, Column, String, Integer, Float, Text, DateTime, Boolean, ForeignKey, JSON
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
from datetime import datetime
import uuid

# Create base class for models.
# declarative_base lives in sqlalchemy.orm since SQLAlchemy 1.4; the
# sqlalchemy.ext.declarative version imported above is a legacy alias that
# emits MovedIn20Warning when called under SQLAlchemy 2.x.
from sqlalchemy.orm import declarative_base
Base = declarative_base()

# Define models
class User(Base):
    """An account that owns conversations (table ``users``).

    Primary keys are UUID4 strings produced in Python by the column default
    when the row is INSERTed. ``created_at`` is stamped with
    ``datetime.utcnow``, i.e. a naive UTC timestamp.
    """
    __tablename__ = 'users'
    
    # 36 characters = the hyphenated string form of a UUID.
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    # UNIQUE: inserting the same address twice raises an integrity error.
    email = Column(String(255), unique=True, nullable=False)
    name = Column(String(255), nullable=False)
    # Free-form role label; defaults to 'user'.
    role = Column(String(50), default='user')
    is_active = Column(Boolean, default=True)
    created_at = Column(DateTime, default=datetime.utcnow)
    
    # Relationships
    # One-to-many; kept in sync with Conversation.user via back_populates.
    conversations = relationship('Conversation', back_populates='user')

class Conversation(Base):
    """A chat session between one user and one model (table ``conversations``).

    The JSON column is named ``metadata`` in the database but mapped on the
    class as ``metadata_``. The Declarative API reserves the attribute name
    ``metadata`` for the ``MetaData`` collection on ``Base`` and raises
    ``InvalidRequestError`` if a column is assigned to it. Callers may still
    pass ``metadata=`` to the constructor.

    ``created_at``/``updated_at`` use ``datetime.utcnow`` (naive UTC);
    ``updated_at`` is refreshed on every UPDATE via ``onupdate``.
    """
    __tablename__ = 'conversations'
    
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    user_id = Column(String(36), ForeignKey('users.id'), nullable=False)
    title = Column(String(500))
    model = Column(String(100), nullable=False)
    created_at = Column(DateTime, default=datetime.utcnow)
    updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
    is_archived = Column(Boolean, default=False)
    # Free-form JSON attributes (e.g. source, device). default=dict builds a
    # fresh empty dict per INSERT instead of sharing one mutable object.
    metadata_ = Column('metadata', JSON, default=dict)
    
    # Relationships
    user = relationship('User', back_populates='conversations')
    # Deleting a conversation (or removing a message from this collection)
    # deletes the corresponding messages.
    messages = relationship('Message', back_populates='conversation', cascade='all, delete-orphan')

    def __init__(self, **kwargs):
        """Build a conversation, accepting ``metadata=`` as an alias for ``metadata_``.

        All keyword arguments are then handed to the default declarative
        constructor, which assigns them as attributes.
        """
        if 'metadata' in kwargs:
            kwargs['metadata_'] = kwargs.pop('metadata')
        super().__init__(**kwargs)

class Message(Base):
    """One turn of a conversation plus its usage figures (table ``messages``).

    ``model``, the token counts and ``cost_usd`` are nullable; in this
    notebook only assistant turns set them.

    The JSON column is named ``metadata`` in the database but mapped as
    ``metadata_``, because Declarative reserves the attribute name
    ``metadata`` for ``Base.metadata`` and raises ``InvalidRequestError``
    otherwise. Callers may still pass ``metadata=`` to the constructor.
    """
    __tablename__ = 'messages'
    
    id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
    conversation_id = Column(String(36), ForeignKey('conversations.id'), nullable=False)
    role = Column(String(20), nullable=False)  # 'user', 'assistant', 'system'
    content = Column(Text, nullable=False)
    model = Column(String(100))
    tokens_input = Column(Integer)
    tokens_output = Column(Integer)
    cost_usd = Column(Float)
    created_at = Column(DateTime, default=datetime.utcnow)
    # Free-form JSON attributes. default=dict avoids a shared mutable default.
    metadata_ = Column('metadata', JSON, default=dict)
    
    # Relationships
    conversation = relationship('Conversation', back_populates='messages')

    def __init__(self, **kwargs):
        """Build a message, accepting ``metadata=`` as an alias for ``metadata_``.

        All keyword arguments are then handed to the default declarative
        constructor, which assigns them as attributes.
        """
        if 'metadata' in kwargs:
            kwargs['metadata_'] = kwargs.pop('metadata')
        super().__init__(**kwargs)

class APIUsage(Base):
    """One row per recorded model API call (table ``api_usage``).

    Unlike the other tables, the primary key is an auto-incrementing integer.
    ``user_id`` has no ``nullable=False``, so rows may omit the user.
    """
    __tablename__ = 'api_usage'
    
    id = Column(Integer, primary_key=True, autoincrement=True)
    # Optional FK to users.id.
    user_id = Column(String(36), ForeignKey('users.id'))
    model = Column(String(100), nullable=False)
    # API path that was called, e.g. '/v1/messages'.
    endpoint = Column(String(200), nullable=False)
    tokens_input = Column(Integer, nullable=False)
    tokens_output = Column(Integer, nullable=False)
    # NOTE(review): Float accumulates rounding error for money; Numeric would be exact.
    cost_usd = Column(Float, nullable=False)
    # Latency in milliseconds (per the _ms suffix); optional.
    latency_ms = Column(Integer)
    # Outcome label, e.g. 'success'.
    status = Column(String(20), nullable=False)
    # Naive UTC timestamp (datetime.utcnow); compare it against UTC cutoffs.
    recorded_at = Column(DateTime, default=datetime.utcnow)

# Create SQLite database (for testing). The file persists between runs, so
# everything below must be safe to execute more than once.
DATABASE_URL = 'sqlite:///ai_system.db'
engine = create_engine(DATABASE_URL, echo=False)
# create_all() only issues CREATE TABLE for tables that do not exist yet.
Base.metadata.create_all(engine)

# Create session
Session = sessionmaker(bind=engine)
session = Session()

print("Database Models Created:")
print(f"  Tables: {', '.join([table.name for table in Base.metadata.sorted_tables])}")
print(f"  Database: {DATABASE_URL}")

# Get or create the sample user. ``email`` is UNIQUE, so inserting it again
# on a re-run (the SQLite file survives) would raise IntegrityError.
SAMPLE_EMAIL = "engineer@company.com"
user = session.query(User).filter_by(email=SAMPLE_EMAIL).one_or_none()
if user is None:
    user = User(
        email=SAMPLE_EMAIL,
        name="John Engineer",
        role="user"
    )
    session.add(user)
    session.commit()
    print(f"\nCreated user: {user.email} (ID: {user.id})")
else:
    print(f"\nUsing existing user: {user.email} (ID: {user.id})")

## Example 2: Creating and Querying Conversations

Store AI conversations with messages and track token usage.

In [None]:
from sqlalchemy import func
from datetime import datetime, timedelta

# Open a conversation for the sample user. Committing right away runs the
# INSERT, which is when the id column default assigns the UUID.
conversation = Conversation(
    metadata={"source": "cli", "device": "router01"},
    model="claude-sonnet-4-5",
    title="BGP Configuration Analysis",
    user_id=user.id,
)
session.add(conversation)
session.commit()

print(
    f"Created conversation: {conversation.title}",
    f"  ID: {conversation.id}",
    f"  Model: {conversation.model}",
    f"  Created: {conversation.created_at}",
    sep="\n",
)

# Add a user prompt and the assistant's reply to the conversation. Only the
# assistant turn carries model/token/cost figures. Both rows are persisted
# with a single commit.
user_message = Message(
    role='user',
    conversation_id=conversation.id,
    metadata={"device": "router01"},
    content="Analyze this BGP config and identify potential issues",
)
assistant_message = Message(
    role='assistant',
    conversation_id=conversation.id,
    metadata={"confidence": 0.95},
    content="I found 3 issues: 1) Missing route-map on neighbor...",
    model="claude-sonnet-4-5",
    cost_usd=0.015,
    tokens_input=1250,
    tokens_output=890,
)
session.add_all([user_message, assistant_message])
session.commit()

print(f"\nAdded {len(conversation.messages)} messages to conversation")

# Record one usage row for the model call behind the assistant reply above
# (same model and token/cost figures), plus latency and outcome.
api_usage = APIUsage(
    status="success",
    latency_ms=2345,
    cost_usd=0.015,
    tokens_output=890,
    tokens_input=1250,
    endpoint="/v1/messages",
    model="claude-sonnet-4-5",
    user_id=user.id,
)
session.add(api_usage)
session.commit()
print("\nRecorded API usage")

# List the five newest non-archived conversations of the sample user.
print("\n" + "=" * 60)
print("Recent Conversations:")
recent_query = (
    session.query(Conversation)
    .filter_by(user_id=user.id, is_archived=False)
    .order_by(Conversation.created_at.desc())
    .limit(5)
)
recent = recent_query.all()

for conv in recent:
    # len(conv.messages) lazily loads each conversation's message collection.
    print(
        f"\n  {conv.title}",
        f"    ID: {conv.id}",
        f"    Model: {conv.model}",
        f"    Messages: {len(conv.messages)}",
        f"    Created: {conv.created_at}",
        sep="\n",
    )

# Calculate total costs
print("\n" + "=" * 60)
print("Cost Analysis:")

# One aggregate query over an explicit JOIN instead of three round trips that
# each joined implicitly through the WHERE clause. COALESCE turns the NULL
# that SUM() returns when no rows match into 0.
total_cost, total_tokens_in, total_tokens_out = (
    session.query(
        func.coalesce(func.sum(Message.cost_usd), 0),
        func.coalesce(func.sum(Message.tokens_input), 0),
        func.coalesce(func.sum(Message.tokens_output), 0),
    )
    .select_from(Message)
    .join(Message.conversation)
    .filter(Conversation.user_id == user.id)
    .one()
)

print(f"  Total cost: ${total_cost:.3f}")
print(f"  Input tokens: {total_tokens_in:,}")
print(f"  Output tokens: {total_tokens_out:,}")
print(f"  Total tokens: {total_tokens_in + total_tokens_out:,}")

## Example 3: Database Migrations with Alembic

Manage schema changes with Alembic migrations.

In [None]:
# Illustrates the Alembic workflow without executing it. In production, run
# `alembic init migrations` once and drive schema changes from the CLI.
print(
    "Alembic Migration Workflow:\n",
    "1. Initialize Alembic:",
    "   $ alembic init migrations",
    "",
    "2. Configure alembic.ini:",
    "   sqlalchemy.url = postgresql://user:pass@localhost/ai_system",
    "",
    "3. Create migration:",
    "   $ alembic revision --autogenerate -m 'add_embedding_column'",
    "",
    "4. Apply migration:",
    "   $ alembic upgrade head",
    "",
    "5. Rollback migration:",
    "   $ alembic downgrade -1",
    sep="\n",
)

# Sample revision file, shown as text only. Note that sa.ARRAY is not
# available on SQLite, so this migration targets PostgreSQL.
migration_example = '''
"""add_embedding_column

Revision ID: abc123def456
"""
from alembic import op
import sqlalchemy as sa

def upgrade():
    # Add embedding column for vector search
    op.add_column('messages',
        sa.Column('embedding', sa.ARRAY(sa.Float), nullable=True)
    )
    
    # Add model_version column
    op.add_column('messages',
        sa.Column('model_version', sa.String(50), nullable=True)
    )

def downgrade():
    # Remove columns
    op.drop_column('messages', 'embedding')
    op.drop_column('messages', 'model_version')
'''

print("\n" + "=" * 60, "Example Migration File:\n", migration_example, sep="\n")

## Example 4: Backup and Performance Optimization

Demonstrate backup strategies and query optimization.

In [None]:
from sqlalchemy import text, inspect
import time

print("Database Performance and Backup\n")
print("=" * 60)

# Create additional test data for performance testing
print("\n1. Creating test data...")
for i in range(10):
    conv = Conversation(
        user_id=user.id,
        title=f"Test Conversation {i+1}",
        model="claude-sonnet-4-5",
        metadata={"test": True}
    )
    # Attach messages through the relationship, not via conversation_id=conv.id.
    # The id default only runs at INSERT time, so conv.id is still None here,
    # and the NOT NULL foreign key would make the commit fail. Through the
    # relationship, the flush fills in the FK after the parent row is inserted.
    conv.messages.extend(
        Message(
            role='user' if j % 2 == 0 else 'assistant',
            content=f"Message {j+1}",
            tokens_input=100,
            tokens_output=150,
            cost_usd=0.002
        )
        for j in range(5)
    )
    # The 'all' cascade on Conversation.messages also adds the attached messages.
    session.add(conv)

session.commit()
print("   Created 10 conversations with 50 messages")

# Query with joins: lazy vs. eager loading of Message.conversation
print("\n2. Query performance comparison:")

from sqlalchemy.orm import contains_eager, joinedload

# Read the id up front: expire_all() below would otherwise turn user.id into
# a refresh SELECT inside the timed region.
user_id = user.id

# Both runs read msg.conversation for every row, as a real caller would.
# Without that, the lazy variant never issues its extra SELECTs and the
# comparison measures nothing. expire_all() discards state loaded by earlier
# cells so each run starts cold. perf_counter() is the high-resolution clock
# intended for timing; time.time() may tick too coarsely and report 0 ms.

# Lazy loading: one SELECT for messages, then more as conversations are touched.
session.expire_all()
start = time.perf_counter()
results = session.query(Message).join(Conversation).filter(
    Conversation.user_id == user_id
).all()
for msg in results:
    _ = msg.conversation.title
duration1 = (time.perf_counter() - start) * 1000
print(f"   Query without eager loading: {duration1:.2f}ms ({len(results)} messages)")

# Eager loading: the explicit JOIN already returns the conversation columns,
# so contains_eager() populates Message.conversation from them. Using
# joinedload() together with join() would add a second, redundant JOIN
# against an aliased conversations table.
session.expire_all()
start = time.perf_counter()
results = (
    session.query(Message)
    .join(Message.conversation)
    .options(contains_eager(Message.conversation))
    .filter(Conversation.user_id == user_id)
    .all()
)
for msg in results:
    _ = msg.conversation.title
duration2 = (time.perf_counter() - start) * 1000
print(f"   Query with eager loading: {duration2:.2f}ms ({len(results)} messages)")
if duration1 > 0:
    print(f"   Performance improvement: {((duration1 - duration2) / duration1 * 100):.1f}%")
else:
    print("   Performance improvement: n/a (baseline too fast to measure)")

# Database statistics
print("\n3. Database Statistics:")
from sqlalchemy import func, table as table_clause

inspector = inspect(engine)

for table_name in inspector.get_table_names():
    # The SQLite file persists across runs and may contain tables this notebook
    # does not declare, so indexing Base.metadata.tables directly could raise
    # KeyError. Fall back to a lightweight table clause, which quotes the name.
    tbl = Base.metadata.tables.get(table_name)
    if tbl is None:
        tbl = table_clause(table_name)
    count = session.query(func.count()).select_from(tbl).scalar()
    print(f"   {table_name}: {count} rows")

# Backup strategy (conceptual)
print("\n" + "=" * 60)
print("Backup Strategy:\n")
print("PostgreSQL backup commands:")
print("  Full backup:")
print("    $ pg_dump -h localhost -U postgres -F c -f backup.dump ai_system")
print("")
print("  Restore:")
print("    $ pg_restore -h localhost -U postgres -d ai_system -c backup.dump")
print("")
print("  Schema only:")
print("    $ pg_dump -h localhost -U postgres --schema-only -f schema.sql ai_system")
print("")
print("Recommended schedule:")
print("  - Daily full backups (retain 7 days)")
print("  - Hourly WAL archiving (point-in-time recovery)")
print("  - Weekly uploads to S3/Azure Blob")

# Indexing recommendations
print("\n" + "=" * 60)
print("Index Recommendations:\n")
print("CREATE INDEX idx_messages_conversation_id ON messages(conversation_id);")
print("CREATE INDEX idx_messages_created_at ON messages(created_at DESC);")
print("CREATE INDEX idx_conversations_user_date ON conversations(user_id, created_at DESC);")
print("CREATE INDEX idx_api_usage_user_time ON api_usage(user_id, recorded_at DESC);")
# SQLAlchemy's generic JSON type creates a plain `json` column on PostgreSQL,
# and GIN has no operator class for `json`. Convert the column to `jsonb`
# before building the GIN index.
print("-- GIN requires jsonb; the generic JSON column is created as plain json:")
print("ALTER TABLE conversations ALTER COLUMN metadata TYPE jsonb USING metadata::jsonb;")
print("CREATE INDEX idx_conversations_metadata ON conversations USING GIN (metadata);")

print("\n✅ Examples complete!")

## Example 5: Cost Tracking and Analytics

Query cost data and generate analytics reports.

In [None]:
from sqlalchemy import func, extract, case
from datetime import datetime, timedelta

print("Cost Analytics Report\n")
print("=" * 60)

# Total costs by model
print("\n1. Cost by Model:")
# SUM over a group whose values are all NULL returns NULL, which would crash
# the `:,` / `:.3f` formatting below. COALESCE maps it to 0.
model_costs = session.query(
    Message.model,
    func.count(Message.id).label('message_count'),
    func.coalesce(func.sum(Message.tokens_input), 0).label('total_input'),
    func.coalesce(func.sum(Message.tokens_output), 0).label('total_output'),
    func.coalesce(func.sum(Message.cost_usd), 0).label('total_cost')
).filter(
    Message.model.isnot(None)
).group_by(Message.model).all()

for row in model_costs:
    print(f"\n   {row.model}:")
    print(f"     Messages: {row.message_count}")
    print(f"     Tokens: {row.total_input:,} in / {row.total_output:,} out")
    print(f"     Cost: ${row.total_cost:.3f}")

# Daily cost trends (last 7 days)
print("\n2. Daily Cost Trends:")
# APIUsage.recorded_at defaults to datetime.utcnow (naive UTC), so the cutoff
# must be naive UTC as well. datetime.now() would shift the window by the
# machine's local UTC offset.
week_ago = datetime.utcnow() - timedelta(days=7)

daily_costs = session.query(
    func.date(APIUsage.recorded_at).label('date'),
    func.count(APIUsage.id).label('requests'),
    func.sum(APIUsage.cost_usd).label('cost'),
    func.avg(APIUsage.latency_ms).label('avg_latency')
).filter(
    APIUsage.recorded_at >= week_ago
).group_by(
    func.date(APIUsage.recorded_at)
).order_by(
    # Chronological order so the output reads as a trend.
    func.date(APIUsage.recorded_at)
).all()

if daily_costs:
    for row in daily_costs:
        # latency_ms is nullable; AVG over only NULLs is NULL.
        latency = "n/a" if row.avg_latency is None else f"{row.avg_latency:.0f}ms"
        print(f"   {row.date}: {row.requests} requests, ${row.cost:.3f}, {latency} avg")
else:
    print("   No data for last 7 days (new database)")

# User statistics
print("\n3. User Activity:")
# The JOIN yields one row per message, so a plain COUNT(conversations.id)
# would count each conversation once per message. Count distinct ids instead.
# COALESCE guards against users whose messages all have a NULL cost.
user_stats = session.query(
    User.email,
    func.count(Conversation.id.distinct()).label('conversations'),
    func.count(Message.id).label('messages'),
    func.coalesce(func.sum(Message.cost_usd), 0).label('total_cost')
).join(Conversation).join(Message).group_by(User.email).all()

for row in user_stats:
    print(f"\n   {row.email}:")
    print(f"     Conversations: {row.conversations}")
    print(f"     Messages: {row.messages}")
    print(f"     Total cost: ${row.total_cost:.3f}")

# Cost projections
print("\n" + "=" * 60)
print("Monthly Projections:\n")

# Average the cost over the span the data actually covers. Assuming a fixed
# 30 days would make the monthly projection equal the all-time total no
# matter how long data had been collected.
first_at, last_at, total_cost_all = session.query(
    func.min(Message.created_at),
    func.max(Message.created_at),
    func.coalesce(func.sum(Message.cost_usd), 0),
).one()
# Inclusive count of calendar days between first and last message; 0 when
# there are no messages at all.
total_days = (last_at.date() - first_at.date()).days + 1 if first_at is not None else 0
daily_avg = total_cost_all / total_days if total_days > 0 else 0

monthly_projection = daily_avg * 30
annual_projection = monthly_projection * 12

print(f"   Based on {total_days} day(s) of message data")
print(f"   Daily average: ${daily_avg:.2f}")
print(f"   Monthly projection: ${monthly_projection:.2f}")
print(f"   Annual projection: ${annual_projection:.2f}")

# Break down spend by message role, skipping rows that carry no cost.
print("\n4. Cost by Message Role:")
role_query = (
    session.query(
        Message.role,
        func.count(Message.id).label('count'),
        func.sum(Message.cost_usd).label('cost'),
    )
    .filter(Message.cost_usd.isnot(None))
    .group_by(Message.role)
)
role_costs = role_query.all()

for row in role_costs:
    print(f"   {row.role}: {row.count} messages, ${row.cost:.3f}")

print("\n✅ Analytics complete!")

## Next Steps

- Full code: [Chapter 41 on GitHub](https://github.com/eduardd76/AI_for_networking_and_security_engineers/tree/main/CODE/Volume-3-Production-Systems/Chapter-41-Database-Design)
- Learn more: [vExpertAI.com](https://vexpertai.com)
- Author: Eduard Dulharu ([@eduardd76](https://github.com/eduardd76))

**Production Deployment:**
- Use PostgreSQL for production
- Implement connection pooling with PgBouncer
- Set up automated backups with pg_dump
- Configure WAL archiving for point-in-time recovery
- Add indexes for common query patterns
- Partition large tables by date
- Monitor query performance with EXPLAIN ANALYZE