# In [1]:
import random
import re

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

class QNetwork(nn.Module):
    """Two-layer perceptron mapping a feature vector to a single scalar.

    The network is ``Linear(input_dim, hidden_dim) -> ReLU ->
    Linear(hidden_dim, 1)``, stored as ``self.layers`` so parameter names
    are ``layers.0.*`` and ``layers.2.*``.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the hidden layer. Defaults to 64.
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Return the scalar prediction for ``x``; output has last dim 1."""
        return self.layers(x)

def extract_features(code):
    """Turn a source-code string into a 4-element float32 feature vector.

    The features are, in order: the character length of ``code`` and the
    number of occurrences of the words ``for``, ``if`` and ``return``.

    Keywords are matched as whole words, so identifiers that merely contain
    a keyword (``format``, ``diff``, ``returns``) are not counted.

    Args:
        code: The source text to featurize.

    Returns:
        A ``numpy.ndarray`` of shape ``(4,)`` and dtype ``float32``.
    """
    def count_word(word):
        # \b anchors keep e.g. "for" from matching inside "format".
        return len(re.findall(r"\b" + word + r"\b", code))

    return np.array([
        len(code),
        count_word("for"),
        count_word("if"),
        count_word("return"),
    ], dtype=np.float32)

def generate_dataset():
    """Build the toy training set of (feature vector, reward) pairs.

    Each hard-coded snippet is passed through ``extract_features`` and
    paired with its hand-assigned reward.

    Returns:
        A list of ``(features, reward)`` tuples, one per snippet, in the
        order the snippets are listed below.
    """
    labelled_snippets = (
        ("def add(x, y): return x + y", 1.0),
        ("def loop(n): for i in range(n): print(i)", 0.5),
        ("def cond(x): return x if x > 0 else -x", 0.8),
        ("def bad(): return 1 / 0", -1.0),
    )

    dataset = []
    for snippet, reward in labelled_snippets:
        dataset.append((extract_features(snippet), reward))
    return dataset

def train_q(epochs=100, lr=0.01, log_every=10):
    """Fit a ``QNetwork`` to the toy dataset with per-sample MSE regression.

    Every epoch the samples are shuffled with ``random.shuffle`` and one
    Adam step is taken per ``(features, reward)`` pair.

    Args:
        epochs: Number of passes over the dataset. Defaults to 100.
        lr: Learning rate passed to Adam. Defaults to 0.01.
        log_every: Print the summed loss of an epoch whenever the epoch
            index is a multiple of this value (epoch 0 included). A falsy
            value (``0`` or ``None``) disables printing. Defaults to 10.

    Returns:
        The trained ``QNetwork`` instance.
    """
    data = generate_dataset()
    # Size the input layer from the data so it tracks the feature extractor.
    net = QNetwork(input_dim=len(data[0][0]))
    opt = optim.Adam(net.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Tensor conversion is loop-invariant, so do it once instead of per step.
    # Each sample becomes a (1, n_features) input and a (1, 1) target.
    samples = [
        (
            torch.tensor(feat).float().unsqueeze(0),
            torch.tensor([[reward]]).float(),
        )
        for feat, reward in data
    ]

    for epoch in range(epochs):
        random.shuffle(samples)
        total_loss = 0.0
        for x, y in samples:
            pred = net(x)
            loss = loss_fn(pred, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        if log_every and epoch % log_every == 0:
            print(f"Epoch {epoch} | Loss: {total_loss:.4f}")
    return net