# Phase 2 Training Monitor (Stage 1.5: Mixed Pre-training)

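This notebook tracks the progress of `train_phase2_stage1_mixed.py`.
It plots the training **Cross-Entropy (CE) loss** on the combined WikiText + Alpaca dataset. It then loads a saved checkpoint and prints greedy completions for a few sample prompts as a qualitative sanity check.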


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import os

# Config
log_file = "../logs/training_metrics_stage1_mixed.csv"
window = 50  # rolling-mean window (in logged rows) for the smoothed curve


def load_loss_log(path):
    """Load the training-metrics CSV, keeping only rows whose ``Loss`` is numeric.

    The training script may still be appending to this file while the notebook reads
    it, so a row can be incomplete or malformed. ``Loss`` is coerced with
    ``pd.to_numeric(..., errors="coerce")`` and non-numeric / missing values are dropped.
    This stops one bad row from turning the column into ``object`` dtype (which breaks
    ``rolling().mean()``) or showing up as a NaN "current" loss. The original row index
    is kept so x positions still correspond to row numbers in the file.

    Args:
        path: Path to the CSV log; must contain a ``Loss`` column.

    Returns:
        DataFrame with the same columns, ``Loss`` as float, bad rows removed.

    Raises:
        OSError / ValueError: propagated from ``pd.read_csv`` (pandas' ``ParserError``
            and ``EmptyDataError`` are ``ValueError`` subclasses).
        KeyError: if the file has no ``Loss`` column.
    """
    raw = pd.read_csv(path)
    loss = pd.to_numeric(raw["Loss"], errors="coerce")
    return raw.assign(Loss=loss).loc[loss.notna()]


# Load Data — only read failures are reported here; plotting errors surface normally.
df = None
if not os.path.exists(log_file):
    print(f"Log file not found: {log_file}")
else:
    try:
        df = load_loss_log(log_file)
    except (OSError, ValueError, KeyError) as e:
        print(f"Error reading log file: {e}")
    else:
        print(f"Loaded {len(df)} steps from {log_file}")

# Plot
if df is not None:
    fig, ax = plt.subplots(figsize=(14, 6))
    ax.plot(df["Loss"], label="Raw Loss", alpha=0.3, color="gray")
    ax.plot(
        df["Loss"].rolling(window).mean(),
        label=f"Smoothed (Window={window})",
        color="purple",
        linewidth=2,
    )

    ax.set_title("Stage 1.5: Mixed Pre-training (WikiText + Alpaca)")
    ax.set_xlabel("Batch Step")
    ax.set_ylabel("Cross Entropy Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Stats
    if len(df) > 0:
        curr_loss = df["Loss"].iloc[-1]
        min_loss = df["Loss"].min()
        # Place the label just past the last plotted point (its index, not len(df),
        # since dropped rows leave gaps in the index).
        ax.text(
            df.index[-1] + 1,
            curr_loss,
            f" Current: {curr_loss:.4f}",
            color="purple",
            fontweight="bold",
        )
        print(f"Current Loss: {curr_loss:.4f} | Min Loss: {min_loss:.4f}")

    plt.show()

In [None]:
# Test Checkpoint Cell
import torch
import sys
import os

sys.path.append(os.path.abspath(".."))
from indra.models.quantum_model_v2 import IndraQuantumPhase2
from transformers import AutoTokenizer

# --- Config ---
checkpoint_path = "../checkpoints/phase2_stage1_mixed/checkpoint_stage1_mixed_epoch_3.pt"  # Change Epoch here
prompts = ["The future of AI is", "Once upon a time", "Quantum physics explains"]
device = "cuda" if torch.cuda.is_available() else "cpu"
max_new_tokens = 50  # upper bound on generated tokens per prompt
# ----------------


def generate_greedy(model, tokenizer, prompt, max_new_tokens=50, device="cpu"):
    """Greedily extend ``prompt`` by up to ``max_new_tokens`` tokens.

    The model is called on the full sequence each step and is expected to return a
    3-tuple whose first element is the logits; only the last position's logits are used.
    Generation stops early when the argmax token equals ``tokenizer.eos_token_id``. The
    EOS token itself is not included in the result.

    The continuation is decoded once, as a whole sequence, instead of token by token:
    Llama-family (SentencePiece) tokenizers drop a token's leading-space marker when it
    is decoded on its own, which glues every generated word together.

    Args:
        model: Callable ``model(ids) -> (logits, _, _)``; logits indexed as [batch, pos, vocab].
        tokenizer: Object with ``encode(text, return_tensors="pt")``, ``decode(ids, ...)``
            and ``eos_token_id`` (e.g. a Hugging Face tokenizer).
        prompt: Text to condition on.
        max_new_tokens: Maximum number of tokens to generate.
        device: Device the prompt ids are moved to.

    Returns:
        The decoded continuation (prompt excluded), special tokens skipped.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
    prompt_len = input_ids.shape[1]
    output_ids = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits, _, _ = model(output_ids)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            # .item() assumes a single prompt (batch size 1), as produced by encode() above.
            if next_token.item() == tokenizer.eos_token_id:
                break
            output_ids = torch.cat([output_ids, next_token], dim=1)
    return tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)


if os.path.exists(checkpoint_path):
    print(f"Loading {checkpoint_path}...")
    # NOTE(review): positional args presumably (vocab_size, model dim); 32000 matches the
    # TinyLlama tokenizer vocab — confirm against the training script.
    model = IndraQuantumPhase2(32000, 128).to(device)
    # weights_only=True avoids unpickling arbitrary objects (and silences torch's
    # FutureWarning). If the checkpoint stores non-tensor Python objects, this raises;
    # allowlist them with torch.serialization.add_safe_globals.
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    for p in prompts:
        print(f"\nPrompt: {p}")
        completion = generate_greedy(
            model, tokenizer, p, max_new_tokens=max_new_tokens, device=device
        )
        print(f"Output: {completion}")
else:
    print(f"Checkpoint not found: {checkpoint_path}")

Loading ../checkpoints/phase2_stage1_mixed/checkpoint_stage1_mixed_epoch_3.pt...


  ckpt = torch.load(checkpoint_path, map_location=device)



Prompt: The future of AI is
Output: exambeginningofBElöclauseNoRAgoalgoaltermJsoncitdenAlscientfacesievesladyTradeofephighly}Topsacredwinterladyalledasymtickladycolonialagingladycit%).teamdencompanillerytickladyaking}scientZeitconnectionErrorgoal

Prompt: Once upon a time
Output: isingLaurdegreesisingthrowsFrançacontributedpurposesCTroadDatstudiotiedhäufigquestionsKoreahornpenatom—thermalPUlokuzdivwinterheimerÖpassageworsecatalogynaremarkswinterhelping-,args’ColorgangKoreaappropriatePUollsomeoneregardschangeancementKeepaylor

Prompt: Quantum physics explains
Output: DataphysicsDataDataquallementánromInfloscOptDataSurDataeyeSurDataentitiesMauricebonenordistingYKofunctionsSurmanufact@StevenPanalledlareyDatahrlivecalculusdarUnionNASADataDataPanriansofteliminaterouteuerowSomething


: 