In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from transformers import GPTNeoForCausalLM, GPTNeoModel, GPT2Tokenizer
import pandas as pd
from tqdm import tqdm

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Token sampling, dropout and the value-head initialisation are all random;
# seed torch (CPU and CUDA) so Restart & Run All reproduces the same run.
SEED = 42
torch.manual_seed(SEED)

generator_model_path = "./models/pretrain-models"
reward_model_path = "./models/reward-models"

# === LOAD TOKENIZER ===
tokenizer = GPT2Tokenizer.from_pretrained(generator_model_path)
# GPT2Tokenizer defines no pad token by default; reuse EOS so padding=True works.
tokenizer.pad_token = tokenizer.eos_token

# === VALUE HEAD WRAPPER ===
class GPTNeoWithValueHead(nn.Module):
    """Causal LM wrapped with a scalar critic ("value") head.

    The critic reads the final-layer hidden state at the last position of
    ``input_ids`` and maps it to one scalar per sequence.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        width = self.base_model.config.hidden_size
        # Keep Dropout at index 0 and Linear at index 1: this fixes the
        # state_dict keys (value_head.1.weight / value_head.1.bias).
        layers = [nn.Dropout(0.1), nn.Linear(width, 1)]
        self.value_head = nn.Sequential(*layers)

    def forward(self, input_ids):
        """Run the language model and the critic.

        Returns:
            (outputs, value): the base model's output object (requested with
            ``output_hidden_states=True``) and a value tensor of shape (batch,).
        """
        lm_out = self.base_model(input_ids, output_hidden_states=True)
        final_layer = lm_out.hidden_states[-1]
        # No attention mask is applied: position -1 is assumed to be a real
        # token, which holds for the unpadded batch-size-1 inputs used here.
        per_seq_value = self.value_head(final_layer[:, -1, :])
        return lm_out, per_seq_value.squeeze(-1)

# === LOAD GENERATOR MODEL ===
base_model = GPTNeoForCausalLM.from_pretrained(generator_model_path).to(device)
# Wrapping adds a freshly initialised value head; moving the wrapper puts that
# head on the same device as the language model.
generator = GPTNeoWithValueHead(base_model).to(device)
# One Adam optimiser updates both the LM weights and the value head.
optimizer = optim.Adam(params=generator.parameters(), lr=1e-5)

# === LOAD REWARD MODEL ===
# eval() disables dropout in the scorer; nn.Module.eval() returns the module itself.
# NOTE(review): this loads a bare GPTNeoModel (no scoring head). If the
# checkpoint was saved from a model with a trained reward head, that head is
# not loaded here — confirm against how the reward model was trained.
reward_model = GPTNeoModel.from_pretrained(reward_model_path).to(device).eval()

In [None]:
# === GENERATE SEQUENCE ===
def generate_sequence_with_grad(prompt, max_length=50):
    """Sample a continuation of ``prompt`` token by token, keeping the autograd graph.

    Args:
        prompt: Text to condition on.
        max_length: Maximum number of new tokens to sample (must be >= 1).
            Sampling stops early once the EOS token is drawn.

    Returns:
        (generated, total_log_prob, value, response_ids):
            generated      -- prompt + sampled tokens, shape (1, prompt_len + k).
            total_log_prob -- sum of the sampled tokens' log-probabilities
                              (0-d tensor with grad) for the actor loss.
            value          -- critic estimate for the prompt state (0-d tensor
                              with grad), used as the A2C baseline.
            response_ids   -- only the k sampled tokens, shape (1, k), so a
                              reward computed on ``prompt + decode(response_ids)``
                              does not contain the prompt twice.

    Raises:
        ValueError: if ``max_length`` < 1 or the prompt tokenizes to nothing.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be >= 1, got {max_length}")

    generator.train()
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    prompt_len = input_ids.shape[1]
    if prompt_len == 0:
        raise ValueError("prompt tokenized to zero tokens; nothing to condition on")

    generated = input_ids.clone()
    log_probs = []
    prompt_value = None

    # NOTE: no KV cache is used — every step re-runs the whole sequence and
    # keeps that step's graph alive, so cost grows quadratically with length.
    for step in range(max_length):
        outputs, step_value = generator(generated)
        if step == 0:
            # The A2C baseline must depend only on the state (the prompt), not
            # on the sampled tokens, or the policy gradient is biased. The first
            # forward pass sees exactly the prompt, so reuse its value.
            prompt_value = step_value
        logits = outputs.logits[:, -1, :]
        # Passing logits directly avoids a separate softmax and is more stable.
        dist = Categorical(logits=logits)
        next_token = dist.sample()  # shape (1,) for this batch-size-1 input
        log_probs.append(dist.log_prob(next_token))
        generated = torch.cat([generated, next_token.unsqueeze(-1)], dim=1)

        if next_token.item() == tokenizer.eos_token_id:
            break

    total_log_prob = torch.stack(log_probs).sum()
    response_ids = generated[:, prompt_len:]
    return generated, total_log_prob, prompt_value.squeeze(), response_ids

# === GET REWARD ===
def get_reward(prompt, response_ids):
    """Score ``prompt`` followed by the decoded ``response_ids[0]``.

    The text is tokenized (truncated to 128 tokens) and run through the
    frozen reward model without gradients. The score is the sigmoid of the
    mean of the final hidden state at the last position, returned as a
    Python float between 0 and 1.
    """
    response_text = tokenizer.decode(response_ids[0], skip_special_tokens=True)
    encoded = tokenizer(
        prompt + response_text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    ).to(device)

    # NOTE(review): truncation side is presumably the tokenizer default
    # ("right"); if so, texts over 128 tokens lose their tail — the response
    # being scored. Confirm tokenizer.truncation_side.
    # NOTE(review): reward_model is loaded as a bare GPTNeoModel, so this score
    # comes from raw hidden activations, not a trained scoring head — verify
    # this matches how the reward model was trained.
    with torch.no_grad():
        final_hidden = reward_model(**encoded).last_hidden_state
        return torch.sigmoid(final_hidden[:, -1, :].mean(dim=-1)).item()

# === A2C UPDATE ===
def a2c_update(log_probs, rewards, values, opt=None, value_coef=1.0):
    """Take one advantage actor-critic optimisation step.

    Per sample: ``advantage = reward - value`` and
    ``loss = -log_prob * advantage.detach() + value_coef * advantage**2``.
    Per-sample losses are summed (not averaged) before a single step.

    Args:
        log_probs: per-sample summed log-probabilities (0-d tensors with grad).
        rewards: per-sample rewards (0-d tensors or Python floats).
        values: per-sample critic estimates (0-d tensors with grad).
        opt: optimiser to step; defaults to the module-level ``optimizer``.
        value_coef: weight of the critic (squared-advantage) term.

    Returns:
        The summed loss as a float, or None if the batch is empty (no step taken).

    Raises:
        ValueError: if the three sequences have different lengths.
    """
    log_probs, rewards, values = list(log_probs), list(rewards), list(values)
    if not (len(log_probs) == len(rewards) == len(values)):
        # zip() would silently drop the unmatched tail.
        raise ValueError(
            "log_probs, rewards and values must have equal lengths, got "
            f"{len(log_probs)}, {len(rewards)}, {len(values)}"
        )
    if not log_probs:
        # An empty batch would otherwise call .backward() on the float 0.0.
        return None
    if opt is None:
        opt = optimizer

    total_loss = 0.0
    for log_prob, reward, value in zip(log_probs, rewards, values):
        advantage = reward - value
        # Actor: the advantage is a constant weight, so it does not train the critic.
        actor_loss = -log_prob * advantage.detach()
        # Critic: regress the value estimate toward the observed reward.
        critic_loss = advantage.pow(2)
        total_loss = total_loss + actor_loss + value_coef * critic_loss

    opt.zero_grad()
    total_loss.backward()
    opt.step()
    return total_loss.item()

In [None]:
# === TRAINING LOOP ===
def train_a2c(prompts, epochs=2, batch_size=8, log_param_grads=False):
    """Fine-tune ``generator`` with A2C over ``prompts``.

    For each batch, one continuation is sampled per prompt and scored with the
    reward model, then a single ``a2c_update`` step is taken on the batch.

    Args:
        prompts: list of prompt strings.
        epochs: number of passes over ``prompts``.
        batch_size: prompts per optimiser step (must be >= 1).
        log_param_grads: if True, also print the mean gradient of every
            parameter after each epoch (hundreds of lines for GPT-Neo). By
            default only a one-line gradient summary is printed.

    Returns:
        List of per-epoch mean rewards (NaN for an epoch with no prompts).

    Raises:
        ValueError: if ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    epoch_mean_rewards = []
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        epoch_rewards = []
        progress = tqdm(range(0, len(prompts), batch_size), desc=f"Epoch {epoch+1}")
        for i in progress:
            batch_prompts = prompts[i:i+batch_size]
            log_probs, rewards, values = [], [], []
            batch_rewards = []

            for prompt in batch_prompts:
                _, log_prob, value, generated_ids = generate_sequence_with_grad(prompt)
                reward = get_reward(prompt, generated_ids)
                log_probs.append(log_prob)
                rewards.append(torch.tensor(reward, device=device))
                values.append(value)
                batch_rewards.append(reward)

            a2c_update(log_probs, rewards, values)
            epoch_rewards.extend(batch_rewards)
            # Reward is the main signal that training is actually progressing.
            progress.set_postfix(reward=f"{sum(batch_rewards) / len(batch_rewards):.4f}")

        mean_reward = sum(epoch_rewards) / len(epoch_rewards) if epoch_rewards else float("nan")
        epoch_mean_rewards.append(mean_reward)

        # Gradients still held here are those from the epoch's final batch.
        print(f"\n=== GRAD CHECK (Epoch {epoch+1}) ===")
        grads = [(name, param.grad) for name, param in generator.named_parameters()
                 if param.grad is not None]
        if log_param_grads:
            for name, grad in grads:
                print(f"{name} grad mean: {grad.mean().item():.6f}")
        if grads:
            global_norm = torch.norm(
                torch.stack([grad.detach().float().norm() for _, grad in grads])
            ).item()
            print(f"{len(grads)} tensors with grads | global grad L2 norm: {global_norm:.6f}")
        else:
            print("no gradients recorded")
        print(f"mean reward: {mean_reward:.4f}")

    return epoch_mean_rewards

# === LOAD PROMPTS ===
# The dataset is plain text with one XSS payload per line. Payloads routinely
# contain commas and double quotes, which pd.read_csv treats as field
# separators / quoting (splitting, mangling or skipping those lines), so the
# lines are read verbatim instead.
N_PROMPTS = 1000  # first N non-blank lines (was nrows=1000)
prompt_path = './dataset/XSS_Dataset-1m.txt'
prompt_lines = []
with open(prompt_path, encoding='utf-8', errors='replace') as fh:
    for raw_line in fh:
        line = raw_line.rstrip('\r\n')
        if not line.strip():
            continue  # blank lines carry no prompt
        prompt_lines.append(line)
        if len(prompt_lines) >= N_PROMPTS:
            break
df = pd.DataFrame({"prompt": prompt_lines})
prompts = df['prompt'].tolist()
assert prompts, f"no prompts loaded from {prompt_path}"

# === START TRAINING ===
reward_history = train_a2c(prompts, epochs=10, batch_size=8)

In [None]:
# === SAVE ===
save_path = "./models/finetune-models/gpt-neo-a2c-XSS"
# save_pretrained writes only the causal LM (and creates save_path).
generator.base_model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
# Also persist the critic; otherwise resuming A2C training starts from a
# freshly initialised value head. Restore with
# generator.value_head.load_state_dict(torch.load(f"{save_path}/value_head.pt")).
torch.save(generator.value_head.state_dict(), f"{save_path}/value_head.pt")