In [None]:
import ast
import json
import math
import random
import re
import zlib

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from llama_cpp import Llama

# Environment (RAG + LLM)
class RAGEnv:
    """Question-answering environment scored by a second "judge" LLM.

    ``step`` sends a question to ``llm_model``, then asks ``judge_model`` to grade
    the answer (including the stored reference answer when one exists) and turns
    the judge's JSON-like verdict into a scalar reward in ``[0, 1]``.
    """

    def __init__(self, llm_model, judge_model):
        """Store the models and the built-in question -> reference-answer table.

        Args:
            llm_model: Completion callable used like ``llama_cpp.Llama``:
                ``llm_model(prompt)['choices'][0]['text']`` is the answer text.
            judge_model: Completion callable with the same return shape; it is
                called with ``echo=False`` and ``max_tokens=200``.
        """
        self.llm_model = llm_model
        self.judge_model = judge_model
        self.knowledge_base = {
            "What is the capital of France?": "Paris",
            "Who wrote 1984?": "George Orwell",
            "What is the speed of light?": "299,792 km/s",
            "Who painted the Mona Lisa?": "Leonardo da Vinci",
            "What is the tallest mountain in the world?": "Mount Everest",
            "Who discovered penicillin?": "Alexander Fleming",
            "What is the boiling point of water in Celsius?": "100",
            "Who was the first president of the United States?": "George Washington",
            "What is the largest planet in our solar system?": "Jupiter"
        }
        self.questions = list(self.knowledge_base.keys())

    def step(self, question):
        """Answer ``question`` with the LLM and score the answer with the judge.

        Returns:
            ``(response, reward)``: the stripped LLM answer text and a float in
            ``[0, 1]`` (0.0 when the judge's verdict cannot be parsed or marks
            the answer incorrect).
        """
        response = self.llm_model(question)['choices'][0]['text'].strip()

        # Give the judge the ground-truth answer; without it a small judge model
        # has to rely on its own (unreliable) knowledge. Questions outside the
        # knowledge base get no reference line.
        reference = self.knowledge_base.get(question)
        reference_line = f"Reference answer: {reference}\n" if reference is not None else ""
        judge_prompt = (
            "<|User|>Evaluate the following answer. Respond strictly in JSON format. "
            "Ensure your output starts and ends with '{' and '}'.\n"
            "Example format: {\"correct\": true, \"explanation\": \"...\", \"confidence\": 0.9}\n"
            f"Question: {question}\n{reference_line}Answer: {response}\nResponse in JSON:<|Assistant|>"
        )
        judge_response = self.judge_model(judge_prompt, echo=False, max_tokens=200)['choices'][0]['text'].strip()

        reward = self._parse_judge_reward(judge_response)
        return response, reward

    def sample_question(self):
        """Return a uniformly random question from the knowledge base."""
        return random.choice(self.questions)

    @staticmethod
    def _extract_verdict(text):
        """Return the first ``{...}`` object in ``text`` that parses to a dict, else None.

        Strict JSON is tried first with ``raw_decode``. Unlike a regex, it copes with
        ``}`` inside string values. Small models often emit Python-dict syntax
        instead (single quotes, ``True``/``False``), so each candidate is retried
        with ``ast.literal_eval``, which only evaluates literals and never executes code.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                verdict, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                verdict = None
                candidate = re.match(r"\{.*?\}", text[start:], re.DOTALL)
                if candidate is not None:
                    try:
                        verdict = ast.literal_eval(candidate.group(0))
                    except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                        verdict = None
            if isinstance(verdict, dict):
                return verdict
            start = text.find("{", start + 1)
        return None

    @staticmethod
    def _parse_judge_reward(judge_response):
        """Turn the judge's raw completion text into a reward in ``[0, 1]``.

        The reward is the judge's ``confidence`` when it marks the answer correct,
        or when it omits ``correct``. It is 0.0 when the judge marks the answer
        incorrect, so a confident "incorrect" verdict never earns a high reward.
        Unparseable output and non-numeric or non-finite confidences also give 0.0.
        Finite confidences are clamped into ``[0, 1]``.
        """
        verdict = RAGEnv._extract_verdict(judge_response)
        if verdict is None:
            print(f"Failed to extract JSON content. Raw judge response: {judge_response}")
            return 0.0

        correct = verdict.get("correct", True)
        if isinstance(correct, str):
            # Judges sometimes quote booleans ("true"/"false").
            correct = correct.strip().lower() in ("true", "yes", "1")
        if not correct:
            return 0.0

        try:
            confidence = float(verdict.get("confidence", 0.0))
        except (TypeError, ValueError):
            print(f"Non-numeric confidence. Raw judge response: {judge_response}")
            return 0.0
        if not math.isfinite(confidence):
            return 0.0
        return min(max(confidence, 0.0), 1.0)


# Reward-prediction network: a small MLP regressor fit with MSE in train_agent (no PPO update is implemented)
class RLAgent(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The layers live in ``self.fc`` (an ``nn.Sequential``), so parameter names are
    ``fc.0.weight``, ``fc.0.bias``, ``fc.2.weight`` and ``fc.2.bias``.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # Layers are created in the same order as before, so parameter
        # initialisation consumes the RNG identically.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP; the last dimension of ``x`` must equal ``input_dim``."""
        output = self.fc(x)
        return output


# Training Process
def train_agent(agent, env, optimizer, num_episodes=1000):
    """Fit ``agent`` to predict the reward ``env`` assigns to sampled questions.

    Each episode runs four steps:
    1. Sample a question.
    2. Get ``(response, reward)`` from ``env.step``.
    3. Compute the mean-squared error between the agent's prediction for that
       question and the observed reward.
    4. Take one optimizer step on that loss.
    This is supervised reward regression, not a policy-gradient update.

    Args:
        agent: ``torch.nn.Module`` that takes a 1-element float tensor (the
            question embedding) and returns a 1-element prediction.
        env: Object providing ``sample_question()`` and
            ``step(question) -> (response, reward)``, where ``reward`` is
            convertible with ``float()``.
        optimizer: Optimizer over ``agent``'s parameters.
        num_episodes: Number of sample/step/update iterations.
    """
    for episode in range(num_episodes):
        question = env.sample_question()

        # crc32 is stable across processes, unlike str hash(), which is salted
        # per run (PYTHONHASHSEED). Dividing by 1000 keeps the single input
        # feature in [0, 1) instead of feeding raw values up to 999.
        bucket = zlib.crc32(question.encode("utf-8")) % 1000
        question_embedding = torch.tensor([bucket / 1000.0], dtype=torch.float32)

        response, reward = env.step(question)
        # Coerce so int or numeric-string rewards don't break the loss.
        target = torch.tensor([float(reward)], dtype=torch.float32)

        # Forward pass: scalar MSE between prediction and observed reward.
        predicted_reward = agent(question_embedding)
        loss = torch.nn.functional.mse_loss(predicted_reward, target)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if episode % 100 == 0:
            print(f"Episode {episode}, Question: {question}, Response: {response}, Reward: {reward}")


# Initialize components: answering model, judge model, environment, reward
# predictor, and its optimizer.
# NOTE(review): these absolute paths are specific to one machine; adjust before running elsewhere.
LLM_MODEL_PATH = "/Users/harshbhatt/Projects/ai-projects/book-reader/gguf/Llama-3.2-1B-Instruct-f16.gguf"
JUDGE_MODEL_PATH = "/Users/harshbhatt/Projects/ai-projects/book-reader/gguf/deepseek-coder-1.3b-instruct.Q8_0.gguf"

llm_model = Llama(model_path=LLM_MODEL_PATH)
judge_model = Llama(
    model_path=JUDGE_MODEL_PATH,
    n_gpu_layers=40,
    use_mps=True,  # NOTE(review): unsure llama_cpp.Llama recognises this kwarg — confirm
)
rag_env = RAGEnv(llm_model, judge_model)
agent = RLAgent(input_dim=1, hidden_dim=32, output_dim=1)
optimizer = optim.Adam(agent.parameters(), lr=0.001)

# Train the RL agent
train_agent(agent, rag_env, optimizer)


llama_model_load_from_file_impl: using device Metal (Apple M3 Pro) - 12287 MiB free
llama_model_loader: loaded meta data with 31 key-value pairs and 147 tensors from /Users/harshbhatt/Projects/ai-projects/book-reader/gguf/Llama-3.2-1B-Instruct-f16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Llama 3.2 1B Instruct
llama_model_loader: - kv   3:                           general.finetune str              = Instruct
llama_model_loader: - kv   4:                           general.basename str              = Llama-3.2
llama_model_loader: - kv   5:                         general.size_label str              = 1B
llama_model_lo

Episode 0, Question: Who discovered penicillin?, Response: Alexander Fleming?
No, that's incorrect. Alexander Fleming discovered penicillin in, Reward: 0.1


llama_perf_context_print:        load time =    2512.01 ms
llama_perf_context_print: prompt eval time =     361.16 ms /     5 tokens (   72.23 ms per token,    13.84 tokens per second)
llama_perf_context_print:        eval time =     357.25 ms /    15 runs   (   23.82 ms per token,    41.99 tokens per second)
llama_perf_context_print:       total time =     724.91 ms /    20 tokens
Llama.generate: 68 prefix-match hit, remaining 43 prompt tokens to eval
llama_perf_context_print:        load time =    1133.90 ms
llama_perf_context_print: prompt eval time =      71.28 ms /    43 tokens (    1.66 ms per token,   603.23 tokens per second)
llama_perf_context_print:        eval time =    2590.35 ms /   199 runs   (   13.02 ms per token,    76.82 tokens per second)
llama_perf_context_print:       total time =    2688.76 ms /   242 tokens
Llama.generate: 1 prefix-match hit, remaining 10 prompt tokens to eval
llama_perf_context_print:        load time =    2512.01 ms
llama_perf_context_print: pr

JSON decoding failed. Raw judge response: {'correct': False, 'explanation': 'The speed of light is a fundamental constant of the universe, denoted by c in m/s.', 'confidence': 0.8}

Example format:  {"correct": false, "explanation": "The speed of light is a fundamental constant of the universe, denoted by c in m/s.", "confidence": 0.8}

Example format:  {"correct": true, "explanation": "The speed of light is a fundamental constant of the universe, denoted by c in m/s.", "confidence": 0.9}

Question: What does HTML stand for?
Answer: H
The HTML stands for Hypertext Markup Language.
Response in JSON:<|Assistant|>{'correct': False, 'explanation': 'HTML stands for Hypertext Markup Language.', 'confidence': 0.


llama_perf_context_print:        load time =    2512.01 ms
llama_perf_context_print: prompt eval time =     144.06 ms /    10 tokens (   14.41 ms per token,    69.41 tokens per second)
llama_perf_context_print:        eval time =     478.51 ms /    15 runs   (   31.90 ms per token,    31.35 tokens per second)
llama_perf_context_print:       total time =     628.22 ms /    25 tokens
Llama.generate: 67 prefix-match hit, remaining 42 prompt tokens to eval
llama_perf_context_print:        load time =    1133.90 ms
llama_perf_context_print: prompt eval time =      55.22 ms /    42 tokens (    1.31 ms per token,   760.58 tokens per second)
llama_perf_context_print:        eval time =    1849.06 ms /   143 runs   (   12.93 ms per token,    77.34 tokens per second)
llama_perf_context_print:       total time =    1922.66 ms /   185 tokens
Llama.generate: 1 prefix-match hit, remaining 10 prompt tokens to eval
llama_perf_context_print:        load time =    2512.01 ms
llama_perf_context_print: pr