In [None]:
from transformers import AutoTokenizer, pipeline, GPT2Tokenizer, GPT2LMHeadModel
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from tqdm import tqdm
import torch
import os
import json
import gc
from gpt_wrapper.chat import Chat
from reward_model import GPTRewardModel
from datasets import Dataset

In [None]:
import getpass
import os

import gpt_wrapper

# Endpoint of the gpt_wrapper backend; can be overridden via GPT_WRAPPER_API_BASE.
gpt_wrapper.api_base = os.environ.get(
    "GPT_WRAPPER_API_BASE",
    "http://mnlp-backend-938795011.eu-central-1.elb.amazonaws.com",
)
# Credentials must not live in the notebook (they get committed and shared):
# read the key from the environment, or prompt for it (without echo) when
# running interactively.
gpt_wrapper.api_key = os.environ.get("GPT_WRAPPER_API_KEY") or getpass.getpass("gpt_wrapper API key: ")

In [None]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _has_cuda else torch.device("cpu")

In [None]:
# Policy tokenizer: gpt2-large vocabulary extended with a dedicated padding
# token and the "<bot>: " speaker prefix used in the prompts.
pad_spec = {"pad_token": "<pad>"}
speaker_tokens = ["<bot>: "]

tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
tokenizer.add_special_tokens(pad_spec)
tokenizer.add_tokens(speaker_tokens)

In [None]:
model = GPT2LMHeadModel.from_pretrained("gpt2-large")
# The tokenizer gained extra tokens above, so grow the embedding matrix to the
# new vocabulary size; presumably this is the shape the checkpoint was saved with.
model.resize_token_embeddings(len(tokenizer))
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
sft_state_dict = torch.load("weights/model_state_2_large_v2.pt", map_location=device)
model.load_state_dict(sft_state_dict)

In [None]:
# Wrap the fine-tuned LM with trl's value head (required by PPO) and move it to `device`.
# NOTE: this rebinds `model`; re-running only this cell would pass the already
# wrapped model back in, so re-run from the checkpoint-loading cell instead.
wrapped_model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
model = wrapped_model.to(device)


In [None]:
# Reward model built from the 'gpt2' base, with weights restored from a local checkpoint.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
reward_checkpoint = torch.load("reward_model.pth", map_location=device)
reward_model = GPTRewardModel('gpt2').to(device)
reward_model.load_state_dict(reward_checkpoint)

In [None]:
# Load the prompts used for PPO. Only the queries are needed: responses are
# generated on the fly by the policy during training.
# A `with` block guarantees the file handle is closed (the bare open() leaked it).
with open("data/data_merged.json", "r", encoding="utf-8") as data_file:
    data = json.load(data_file)

dataset = Dataset.from_dict({"query": list(data["queries"])})

In [None]:
# PPO hyper-parameters: one sample per optimisation step (batch and
# mini-batch of size 1), a very small learning rate, and a loose ratio threshold.
config = PPOConfig(
    seed=42,
    steps=100,
    learning_rate=1e-8,
    batch_size=1,
    mini_batch_size=1,
    ratio_threshold=50,
    # log_with="wandb",  # uncomment to log to Weights & Biases
)

In [None]:
# No ref_model is passed here, so trl handles the reference policy itself.
ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer, dataset=dataset)

In [None]:
# Tokenizer matching the reward model's 'gpt2' base. GPT-2 has no pad token,
# so the EOS token is reused for padding.
reward_tokenizer = AutoTokenizer.from_pretrained('gpt2')
eos_token = reward_tokenizer.eos_token
reward_tokenizer.pad_token = eos_token


In [None]:
# The reward model is only used for scoring and is never optimised here.
# Freeze its parameters so calling it in the PPO loop does not build autograd
# graphs, and switch it to eval mode (e.g. disables dropout).
reward_model.requires_grad_(False)
reward_model.eval()

In [None]:
# PPO training loop: for each prompt, generate a response with the policy,
# score it with the frozen reward model, then take one PPO step.
MAX_CONTEXT = 1024      # GPT-2 positional limit (prompt + response must fit)
MAX_NEW_TOKENS = 150    # budget for the generated response only; prompt excluded

for batch in tqdm(dataset, desc="PPO"):
    query = batch["query"]

    # --- Generate a response with the policy ---
    # Truncate the prompt so prompt + generated tokens stay within the context window.
    policy_inputs = tokenizer(
        query,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_CONTEXT - MAX_NEW_TOKENS,
    ).to(device)
    # PPOTrainer.step takes lists of 1-D token tensors, one per sample.
    query_tensor = policy_inputs["input_ids"][0]
    output = model.generate(
        policy_inputs["input_ids"],
        attention_mask=policy_inputs["attention_mask"],
        # max_length would count prompt tokens, leaving long prompts no room to answer.
        max_new_tokens=MAX_NEW_TOKENS,
        pad_token_id=tokenizer.pad_token_id,
    )
    # generate() returns prompt + continuation; keep only the newly generated tokens
    # (single unpadded sequence, so the prompt occupies the first positions).
    response_tensor = output[0, query_tensor.shape[0]:]
    answer = tokenizer.decode(response_tensor, skip_special_tokens=True)

    # --- Score the (query, response) pair with the reward model ---
    reward_inputs = reward_tokenizer(
        query + "\n" + answer,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_CONTEXT,
    ).to(device)
    with torch.no_grad():
        raw_reward = reward_model(
            reward_inputs["input_ids"],
            attention_mask=reward_inputs["attention_mask"],
        )
    # NOTE(review): assumes GPTRewardModel returns one scalar score per sequence — confirm.
    # as_tensor avoids the copy-construct warning of torch.tensor(tensor); squeeze gives a
    # 0-dim score as PPOTrainer.step expects.
    reward = torch.as_tensor(raw_reward, dtype=torch.float32, device=device).squeeze()

    # --- PPO update (batch_size=1, so every argument is a one-element list) ---
    stats = ppo_trainer.step([query_tensor], [response_tensor], [reward])
    ppo_trainer.log_stats(stats, {"query": [query], "response": [answer]}, [reward])