# QELM Training on Google Colab

This notebook trains the embedding-based QELM system on a GPU.

**What it does:**
1. Clones the code from GitHub
2. Installs dependencies
3. Mounts Google Drive for checkpoints
4. Trains Stage 1 (supervised pretraining)
5. Trains the two-tower recommender
6. Tests the full system

**Prerequisites:**
- Push your code to GitHub
- An OpenAI API key (used for question generation; you enter it during Setup)
- A GPU runtime (Runtime → Change runtime type → GPU)

## 1. Setup

In [None]:
# Check GPU availability
!nvidia-smi
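
If you'd rather fail fast than read the `nvidia-smi` table by eye, a small check like the one below can guard the rest of the notebook. This is a sketch that uses only the standard library. `gpu_available` is a hypothetical helper, not part of the repo. Once PyTorch is imported, `torch.cuda.is_available()` is the more direct check.

```python
import shutil
import subprocess

def gpu_available() -> bool:
    """Return True if `nvidia-smi` exists and lists at least one GPU (hypothetical helper)."""
    if shutil.which("nvidia-smi") is None:
        return False  # no NVIDIA driver tools on PATH, e.g. a CPU-only runtime
    try:
        result = subprocess.run(
            ["nvidia-smi", "--list-gpus"],
            capture_output=True, text=True, timeout=10,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    # `--list-gpus` prints one "GPU <n>: ..." line per device
    return result.returncode == 0 and "GPU" in result.stdout

if not gpu_available():
    print("No GPU detected: switch to a GPU runtime before training.")
```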

### Clone Code

**For private repos**, you have three options:
1. Make the repo public temporarily
2. Use a personal access token (instructions are in the next cell)
3. Upload the code to Google Drive and copy it from there (Drive must be mounted first)

**For public repos**, just run the clone command directly.
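
For option 2, you can avoid pasting the token into a cell by building the clone URL from a hidden prompt. The sketch below assumes the `owner/name` form of the repo path. `build_clone_url` is a hypothetical helper. Git still stores the resulting URL, token included, in `.git/config`, so treat the cloned folder as sensitive.

```python
from typing import Optional

def build_clone_url(repo: str, token: Optional[str] = None) -> str:
    """Build an HTTPS clone URL for `owner/name`, optionally embedding a token."""
    if token:
        # Note: git saves this URL (with the token) in .git/config
        return f"https://{token}@github.com/{repo}.git"
    return f"https://github.com/{repo}.git"

# Usage in a Colab cell (private repo); IPython expands {url} in ! commands:
# from getpass import getpass
# url = build_clone_url("makarovaalexa-brch/qelm-crs", getpass("GitHub token: "))
# !git clone {url}
```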

In [None]:
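# Clone repository
# Option 1: Public repo (easiest)
!git clone https://github.com/makarovaalexa-brch/qelm-crs.git /content/qelm-crs

# Option 2: Private repo with a personal access token
# Create a token at https://github.com/settings/tokens (select the 'repo' scope), then use:
# !git clone https://YOUR_TOKEN@github.com/makarovaalexa-brch/qelm-crs.git /content/qelm-crs
# Caution: the token stays in this cell and in .git/config; don't share the notebook with it filled in.

# Option 3: Copy code you already uploaded to Google Drive
# Drive must be mounted first: run the "Mount Google Drive" cell below, then uncomment:
# !cp -r /content/drive/MyDrive/qelm-crs /content/

# Change to the repo directory (absolute path, so re-running this cell doesn't nest directories)
%cd /content/qelm-crs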

In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio
!pip install -q sentence-transformers
!pip install -q openai
!pip install -q pandas numpy scikit-learn tqdm
!pip install -q python-dotenv
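
Because `pip install -q` suppresses most output, a quick post-install check can confirm that every dependency is present before you spend GPU time. This sketch uses `importlib.metadata` from the standard library. `missing_packages` is a hypothetical helper, and the names it takes are pip distribution names (e.g. `scikit-learn`), not import names (`sklearn`).

```python
from importlib.metadata import PackageNotFoundError, version

def missing_packages(names):
    """Return the distribution names from `names` that are not installed."""
    missing = []
    for name in names:
        try:
            version(name)  # raises PackageNotFoundError if the distribution is absent
        except PackageNotFoundError:
            missing.append(name)
    return missing

required = ["torch", "sentence-transformers", "openai", "pandas",
            "numpy", "scikit-learn", "tqdm", "python-dotenv"]
print("Missing:", missing_packages(required) or "none")
```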

In [None]:
# Mount Google Drive for checkpoints
from google.colab import drive
drive.mount('/content/drive')

# Create checkpoint directory
!mkdir -p /content/drive/MyDrive/qelm_checkpoints

In [None]:
# Set OpenAI API key
import os
from getpass import getpass

# Enter your OpenAI API key when prompted
api_key = getpass('Enter OpenAI API Key: ')
os.environ['OPENAI_API_KEY'] = api_key
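
The cell above prompts every time it runs, even when the key is already set. A small variant that reuses an existing environment variable avoids re-entering it on reruns. This is a sketch: `ensure_api_key` is a hypothetical helper. Colab's Secrets panel (`from google.colab import userdata; userdata.get('OPENAI_API_KEY')`) is another way to avoid typing the key.

```python
import os
from getpass import getpass

def ensure_api_key(var: str = "OPENAI_API_KEY", prompt=getpass) -> str:
    """Return the key from the environment, prompting (hidden input) only if it is unset."""
    key = os.environ.get(var, "").strip()
    if not key:
        key = prompt(f"Enter {var}: ").strip()
        os.environ[var] = key
    return key

# Usage: ensure_api_key()  # prompts on the first run only
```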

## 2. Prepare Data

In [None]:
# Create sample Reddit data (or scrape real data)
%cd /content/qelm-crs

# Option 1: Use 5 sample posts (fast; for testing the pipeline)
!python src/qelm/data/reddit_scraper.py --sample --output-dir data/reddit

# Option 2: Scrape real Reddit data (100 posts per subreddit, ~500 total)
# Uncomment for actual training:
# !python src/qelm/data/reddit_scraper.py --max-posts 100 --output-dir data/reddit

In [None]:
# Verify data
import json

with open('data/reddit/sample_questions.json', 'r') as f:
    data = json.load(f)
    
print(f"Loaded {len(data)} sample questions")
print("\nExample:")
print(json.dumps(data[0], indent=2))
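
The scraper's output schema isn't documented here, so a check that doesn't assume any field names is the safest guard against an empty or malformed file. The sketch below only verifies that the file is a non-empty JSON list and reports the keys of its first record. `summarize_questions` is a hypothetical helper.

```python
import json
from pathlib import Path

def summarize_questions(path):
    """Return (number of records, sorted keys of the first record) for a JSON list file."""
    records = json.loads(Path(path).read_text(encoding="utf-8"))
    if not isinstance(records, list) or not records:
        raise ValueError(f"{path}: expected a non-empty JSON list")
    first = records[0]
    keys = sorted(first) if isinstance(first, dict) else []
    return len(records), keys

# Usage: n, keys = summarize_questions('data/reddit/sample_questions.json')
```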

## 3. Train Stage 1 (Supervised Pretraining)

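This supervised pretraining stage teaches the RL actor to predict embeddings in the same 384-dimensional SentenceBERT space (`all-MiniLM-L6-v2`) that the questions are encoded in.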
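
`Stage1Trainer`'s internals aren't shown in this notebook, so the following is only an illustration of the general idea, not the repo's method. A linear "actor" is regressed onto target embeddings with a mean-squared-error loss and evaluated by cosine distance. The toy 8-dimensional data, the linear model, and both helper names are assumptions. The real actor is a neural network over 384-dimensional vectors.

```python
import numpy as np

def cosine_loss(pred, target, eps=1e-12):
    """Mean of (1 - cosine similarity) between matching rows of pred and target."""
    pred_n = pred / (np.linalg.norm(pred, axis=1, keepdims=True) + eps)
    tgt_n = target / (np.linalg.norm(target, axis=1, keepdims=True) + eps)
    return float(np.mean(1.0 - np.sum(pred_n * tgt_n, axis=1)))

def train_linear_actor(states, targets, lr=0.1, epochs=500):
    """Fit W by gradient descent so that states @ W regresses onto targets (MSE)."""
    n, d_in = states.shape
    W = np.zeros((d_in, targets.shape[1]))
    for _ in range(epochs):
        residual = states @ W - targets
        # Gradient of (1/n) * sum_i ||states_i @ W - targets_i||^2 with respect to W
        grad = (2.0 / n) * states.T @ residual
        W -= lr * grad
    return W

# Toy data: 8-dim "dialogue states" and 8-dim target "question embeddings"
rng = np.random.default_rng(0)
states = rng.normal(size=(32, 8))
targets = states @ rng.normal(size=(8, 8))
W = train_linear_actor(states, targets)
print(f"cosine loss after training: {cosine_loss(states @ W, targets):.4f}")
```

Training on MSE but reporting cosine distance reflects that what usually matters downstream is the *direction* of the predicted embedding, since nearest-neighbour lookups in SentenceBERT space are typically cosine-based.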

In [None]:
# Import modules
import sys
sys.path.insert(0, '/content/qelm-crs/src')  # make the repo's src/ package importable, ahead of other paths

from qelm.models.embedding_qelm import SentenceBERTEmbeddingSpace, EmbeddingActorCritic
from qelm.training.stage1_supervised import Stage1Trainer
from sentence_transformers import SentenceTransformer

print("Imports successful!")

In [None]:
# Initialize components
print("Initializing SentenceBERT embedding space...")
embedding_space = SentenceBERTEmbeddingSpace(movielens_data_path=None)

print("\nInitializing RL actor...")
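rl_agent = EmbeddingActorCritic(
    state_dim=384,      # dialogue state: a 384-dim SentenceBERT (all-MiniLM-L6-v2) embedding
    embedding_dim=384   # actor output: an embedding in the same 384-dim space
)

print("\nInitializing encoder...")
encoder = SentenceTransformer('all-MiniLM-L6-v2')
# The dimensions above must match the encoder's output size
assert encoder.get_sentence_embedding_dimension() == 384

print("\n✓ Initialization complete")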

In [None]:
# Create trainer
trainer = Stage1Trainer(
    rl_agent=rl_agent,
    embedding_space=embedding_space,
    reddit_data_path='data/reddit',
    encoder=encoder
)

In [None]:
# Train Stage 1
trainer.train(
    epochs=20,      # 20 full passes over the data; affordable on a GPU
    batch_size=64,  # larger batches make better use of GPU parallelism
    learning_rate=0.001
)
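
Whether 20 epochs at batch size 64 is "a lot" depends on how many examples the data directory holds. With the 5-post sample set, each epoch may be a single batch. The arithmetic below assumes the trainer keeps the final partial batch, as PyTorch's `DataLoader` does by default (`drop_last=False`). The dataset sizes and helper names are hypothetical.

```python
import math

def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Optimizer steps per epoch, assuming the final partial batch is kept."""
    return math.ceil(num_examples / batch_size)

def total_steps(num_examples: int, batch_size: int, epochs: int) -> int:
    """Total optimizer steps across all epochs."""
    return epochs * steps_per_epoch(num_examples, batch_size)

print(total_steps(40, 64, 20))   # 20  (a small sample set: one batch per epoch)
print(total_steps(500, 64, 20))  # 160 (ceil(500 / 64) = 8 batches per epoch)
```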

In [None]:
# Evaluate Stage 1
trainer.evaluate(num_samples=10)

In [None]:
# Save checkpoint to Google Drive
import torch
from pathlib import Path

checkpoint_dir = Path('/content/drive/MyDrive/qelm_checkpoints')
checkpoint_path = checkpoint_dir / 'stage1_final.pt'

torch.save({
    'actor_state_dict': rl_agent.actor.state_dict(),
    'train_losses': trainer.train_losses,
}, checkpoint_path)

print(f"✓ Saved checkpoint to: {checkpoint_path}")

## 4. Train Two-Tower Recommender

In [None]:
# Import recommender
from qelm.models.two_tower_recommender import (
    TwoTowerRecommender,
    MovieCatalog,
    RecommenderTrainer
)

In [None]:
# Initialize recommender
movie_catalog = MovieCatalog(movielens_data_path=None)  # Use sample data

recommender = TwoTowerRecommender(
    state_dim=384,
    embedding_dim=128
)

rec_trainer = RecommenderTrainer(recommender, movie_catalog, encoder)
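
The internals of `TwoTowerRecommender` aren't shown here. The toy model below only illustrates the two-tower pattern the constructor arguments suggest. A user tower maps the 384-dim dialogue state to a 128-dim vector, an item tower maps movie features to the same 128-dim space, and relevance is their dot product. Using single linear layers, L2 normalization, and 384-dim item features are all assumptions, and `TwoTowerSketch` is a hypothetical class.

```python
import numpy as np

class TwoTowerSketch:
    """Toy two-tower model: each tower is a single linear map into a shared space."""

    def __init__(self, state_dim=384, item_dim=384, embedding_dim=128, seed=0):
        rng = np.random.default_rng(seed)
        self.user_W = rng.normal(scale=state_dim ** -0.5, size=(state_dim, embedding_dim))
        self.item_W = rng.normal(scale=item_dim ** -0.5, size=(item_dim, embedding_dim))

    @staticmethod
    def _normalize(x):
        return x / (np.linalg.norm(x, axis=-1, keepdims=True) + 1e-12)

    def user_embed(self, states):
        """(num_users, state_dim) -> (num_users, embedding_dim), unit length."""
        return self._normalize(states @ self.user_W)

    def item_embed(self, items):
        """(num_items, item_dim) -> (num_items, embedding_dim), unit length."""
        return self._normalize(items @ self.item_W)

    def score(self, states, items):
        """Cosine relevance scores of shape (num_users, num_items)."""
        return self.user_embed(states) @ self.item_embed(items).T
```

Keeping the towers separate means item embeddings can be precomputed once for the whole catalog, so recommending for a new dialogue state costs one user-tower pass plus a matrix product.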

In [None]:
# Sample training data
sample_conversations = [
    "I love action movies with great cinematography like The Dark Knight",
    "I enjoy mind-bending sci-fi films like Inception and Interstellar",
    "I prefer dark crime dramas with great dialogue like Pulp Fiction",
    "I like sci-fi movies with philosophical themes like The Matrix",
    "I want intense space exploration films like Interstellar",
] * 10  # Repeated 10x (50 examples): duplicates add no new information, only more optimizer steps

# Movie IDs refer to the sample MovieCatalog; index i lines up with sample_conversations[i]
sample_liked_movies = [
    [1],  # The Dark Knight
    [2, 5],  # Inception, Interstellar
    [3],  # Pulp Fiction
    [4],  # The Matrix
    [5],  # Interstellar
] * 10

print(f"Training on {len(sample_conversations)} examples")
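
Because conversations and labels are two parallel lists, a silent misalignment (e.g. repeating one list a different number of times) would train on wrong pairs without any error. A small check before training catches that. `check_training_pairs` is a hypothetical helper, and it assumes the sample catalog uses IDs 1 to 5, as the comments above indicate.

```python
def check_training_pairs(conversations, liked_movies, valid_ids):
    """Raise ValueError if the lists are misaligned or reference unknown movie IDs."""
    if len(conversations) != len(liked_movies):
        raise ValueError(
            f"{len(conversations)} conversations vs {len(liked_movies)} label lists")
    valid = set(valid_ids)
    for i, ids in enumerate(liked_movies):
        unknown = set(ids) - valid
        if not ids or unknown:
            raise ValueError(f"example {i}: empty label list or unknown IDs {sorted(unknown)}")
    return True

# Usage:
# check_training_pairs(sample_conversations, sample_liked_movies, valid_ids={1, 2, 3, 4, 5})
```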

In [None]:
# Train recommender
rec_trainer.train(
    conversations=sample_conversations,
    liked_movies=sample_liked_movies,
    epochs=20,
    batch_size=32
)
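
This notebook doesn't say which loss `RecommenderTrainer` optimizes. A common choice for two-tower models is an in-batch softmax, where each conversation's liked movie is the positive and the other movies in the same batch serve as negatives. The sketch below shows that loss for one positive per row. The temperature value and function name are assumptions.

```python
import numpy as np

def in_batch_softmax_loss(user_emb, item_emb, temperature=0.1):
    """Cross-entropy where row i's positive is item i and all other items in the batch are negatives."""
    logits = user_emb @ item_emb.T / temperature           # (B, B) similarity matrix
    logits = logits - logits.max(axis=1, keepdims=True)    # subtract row max for numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))             # positives sit on the diagonal
```

If the trainer does use an in-batch loss, the `* 10` repetition above matters. With 50 examples and `batch_size=32`, the same movie often appears in several rows of a batch, and those duplicates are counted as negatives for each other.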

In [None]:
# Save recommender checkpoint
rec_checkpoint_path = checkpoint_dir / 'recommender_final.pt'

torch.save({
    'user_tower': recommender.user_tower.state_dict(),
    'item_tower': recommender.item_tower.state_dict(),
}, rec_checkpoint_path)

print(f"✓ Saved recommender to: {rec_checkpoint_path}")

## 5. Test Full System

In [None]:
# Test recommender
test_conversation = "I want something like Inception with mind-bending sci-fi elements"
test_state = encoder.encode(test_conversation, convert_to_numpy=True)

recommendations = rec_trainer.recommend(test_state, top_k=5)

print(f"\nQuery: {test_conversation}\n")
print("Recommendations:")
for i, (movie_id, title, score) in enumerate(recommendations, 1):
    print(f"{i}. {title}: {score:.3f}")
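
Conceptually, a `recommend(state, top_k)` call on a two-tower model reduces to scoring every catalog item against the query and keeping the best `k`. The brute-force sketch below returns `(index, title, score)` tuples in the same shape the loop above unpacks. It assumes cosine scoring over precomputed item vectors, and `top_k_by_cosine` is a hypothetical helper rather than the repo's implementation. For a full MovieLens catalog, an approximate nearest-neighbour index would usually replace the full sort.

```python
import numpy as np

def top_k_by_cosine(query, item_matrix, titles, k=5):
    """Return (index, title, score) for the k rows of item_matrix most cosine-similar to query."""
    q = query / (np.linalg.norm(query) + 1e-12)
    items = item_matrix / (np.linalg.norm(item_matrix, axis=1, keepdims=True) + 1e-12)
    scores = items @ q
    k = min(k, len(scores))
    top = np.argsort(-scores)[:k]  # indices of the highest scores, best first
    return [(int(i), titles[i], float(scores[i])) for i in top]
```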

In [None]:
# Test full QELM system
from qelm.models.embedding_qelm import EmbeddingQLEM

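# Note: question generation calls the OpenAI API, so OPENAI_API_KEY must be set (see Setup)
qelm = EmbeddingQLEM(movielens_data_path=None)

# Load the trained Stage 1 actor weights from Google Drive
# (map_location='cpu' lets this load even on a runtime without a GPU)
checkpoint = torch.load(checkpoint_path, map_location='cpu')
qelm.rl_agent.actor.load_state_dict(checkpoint['actor_state_dict'])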

print("✓ Loaded trained RL actor")

In [None]:
# Simulate a conversation
print("\n" + "="*60)
print("QELM CONVERSATION DEMO")
print("="*60)

# First question
question1 = qelm.select_next_question(explore=False, verbose=True)
print(f"\n🤖 QELM: {question1}")

# Simulate response
response1 = "I love Christopher Nolan films, especially Inception and Interstellar"
qelm.process_user_response(response1)
print(f"\n👤 User: {response1}")

# Second question
question2 = qelm.select_next_question(explore=False, verbose=True)
print(f"\n🤖 QELM: {question2}")

# Get recommendations based on conversation
conv_state = qelm.encode_conversation_state()
final_recs = rec_trainer.recommend(conv_state, top_k=5)

print("\n\n📽️ RECOMMENDATIONS:")
for i, (movie_id, title, score) in enumerate(final_recs, 1):
    print(f"{i}. {title}")
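
The demo above hard-codes two questions and one scripted answer. To try longer conversations, the same ask/answer/update pattern can be wrapped in a loop over any number of turns. `run_dialogue` is a hypothetical helper that takes plain callables, so it doesn't depend on QELM's internals. The usage line assumes the `qelm` methods shown above and uses `input` for live answers.

```python
def run_dialogue(ask, answer, process, turns=2):
    """Alternate agent questions and user answers, updating state after each answer."""
    transcript = []
    for _ in range(turns):
        question = ask()           # agent picks the next question
        reply = answer(question)   # user (or a script) responds
        process(reply)             # agent updates its dialogue state
        transcript.append((question, reply))
    return transcript

# Usage (interactive):
# run_dialogue(lambda: qelm.select_next_question(explore=False),
#              input, qelm.process_user_response, turns=3)
```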

## 6. Download Checkpoints (Optional)

In [None]:
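# Download checkpoints to your local machine
from google.colab import files

# Zip the checkpoint folder (cd first so the archive contains qelm_checkpoints/, not the full Drive path)
!cd /content/drive/MyDrive && zip -r /content/qelm_checkpoints.zip qelm_checkpoints

# Download
files.download('/content/qelm_checkpoints.zip')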
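
If you prefer not to shell out to `zip`, the standard library's `shutil.make_archive` produces the same archive from Python. Its `root_dir`/`base_dir` arguments control how paths are stored inside the zip. The helpers below are hypothetical names, and the paths in the usage comment are the ones this notebook uses.

```python
import shutil
import zipfile
from pathlib import Path

def zip_directory(src_dir, archive_base):
    """Create archive_base + '.zip' whose entries start with src_dir's folder name."""
    src = Path(src_dir)
    return shutil.make_archive(str(archive_base), "zip",
                               root_dir=str(src.parent), base_dir=src.name)

def archive_members(archive_path):
    """List the entry names stored in a zip archive."""
    with zipfile.ZipFile(archive_path) as zf:
        return sorted(zf.namelist())

# Usage in Colab:
# path = zip_directory('/content/drive/MyDrive/qelm_checkpoints', '/content/qelm_checkpoints')
# print(archive_members(path))
# files.download(path)
```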

## Summary

**What we trained:**
1. ✅ Stage 1: RL actor to predict semantic embeddings
2. ✅ Two-Tower Recommender: Dialogue state → Movie recommendations

**Checkpoints saved to:**
- `/content/drive/MyDrive/qelm_checkpoints/stage1_final.pt`
- `/content/drive/MyDrive/qelm_checkpoints/recommender_final.pt`

**Next steps:**
- Train on real MovieLens data
- Scrape more Reddit conversations
- Train Stage 3 (end-to-end RL with reward)