## Make GPT-2 simulate Republican cases through Reinforcement Learning

In this homework, we will first train a reward model that assigns a higher reward to documents that sound more like Republican cases. Then we use RL to guide GPT-2 to complete the Democratic cases in a Republican way.

The reward model is covered in a previous notebook. All TODOs you need to finish lie in the RL part.

In [None]:
%pip install transformers trl

In [None]:
import warnings; warnings.simplefilter('ignore')
import pandas as pd
import numpy as np

import torch
from tqdm import tqdm
import pandas as pd

tqdm.pandas()

from transformers import pipeline, AutoTokenizer, DistilBertTokenizerFast, DistilBertForSequenceClassification

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler


### train a classification model as our reward model

In [None]:
# Load the cleaned cases.
# NOTE(review): read_pickle unpickles arbitrary objects -- only point it at a trusted file.
df = pd.read_pickle('/kaggle/input/sc-cases-cleanked/sc_cases_cleaned.pkl', compression='gzip')
# Integer code for each distinct value of the 'authorship' column.
df = df.assign(author_id=df['authorship'].astype('category').cat.codes)

# gpu or cpu?
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

model_name = 'distilbert-base-uncased'  # huggingface model_ID or path to folder
model = DistilBertForSequenceClassification.from_pretrained(model_name)
tokenizer = DistilBertTokenizerFast.from_pretrained(model_name)

# No up-front tokenization of the whole corpus here: the training loop
# tokenizes each batch itself, so a corpus-wide padded tensor would only
# cost time and memory without ever being read.

# Two parameter groups: a small learning rate for the DistilBERT encoder and a
# larger one for the classification head.
optimizer = torch.optim.Adam([
    {'params': model.distilbert.parameters(), 'lr': 1e-5},
    {'params': model.classifier.parameters(), 'lr': 1e-3},
])

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(
    df['opinion_text'].tolist(),
    df['x_republican'].tolist(),
    test_size=.2,
)

# generate batches
# Keep at most 608 training / 152 test examples, rounded down to a whole number
# of batches so the reshape below cannot fail when fewer examples are available.
clf_batch_size = 8
n_train = min(len(X_train), 608) // clf_batch_size * clf_batch_size
n_test = min(len(X_test), 152) // clf_batch_size * clf_batch_size

X_train, y_train = np.array(X_train[:n_train]), np.array(y_train[:n_train])
X_test, y_test = np.array(X_test[:n_test]), np.array(y_test[:n_test])
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

# One row per batch of `clf_batch_size` examples. The row count is given
# explicitly so an empty split reshapes to (0, clf_batch_size) instead of raising.
X_train = X_train.reshape(n_train // clf_batch_size, clf_batch_size)
y_train = y_train.reshape(n_train // clf_batch_size, clf_batch_size)
X_test = X_test.reshape(n_test // clf_batch_size, clf_batch_size)
y_test = y_test.reshape(n_test // clf_batch_size, clf_batch_size)
print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

# The tokenizer is fed lists of Python strings, so convert the text batches back.
X_train, X_test = X_train.tolist(), X_test.tolist()

# train
from tqdm import tqdm

model.to(device)
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for batch_texts, batch_labels in tqdm(zip(X_train, y_train), total=len(X_train)):
        # prepare model input through our tokenizer
        model_inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=256)
        # place everything on the right device
        model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
        # labels have to be torch long tensors
        label_tensor = torch.tensor(batch_labels).long().to(device)
        # forward pass; passing labels makes the model return the loss first
        output = model(**model_inputs, labels=label_tensor)
        loss = output[0]
        # and the backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

# model.train() switched dropout on; go back to eval mode so the classifier
# produces deterministic scores when it is used as the reward model.
model.eval()

torch.save(model, 'republican_classifier.pt')
republican_classifier = model

In [None]:
from transformers import DistilBertTokenizer, TextClassificationPipeline

# Tokenizer used by the reward pipeline.
distil_tokenizer = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased',
    truncation=True,
    max_length=512,
)

# Wrap the trained classifier in a text-classification pipeline that reports a
# score for every label. Note: this rebinds the name `pipeline`, which was
# imported from transformers at the top of the notebook.
pipeline = TextClassificationPipeline(
    tokenizer=distil_tokenizer,
    model=republican_classifier,
    device=republican_classifier.device,
    return_all_scores=True,
)

In [None]:
# Opinion texts of the rows whose `x_republican` flag is 0; these serve as the
# prompts for the RL stage.
is_not_republican = df['x_republican'] == 0
texts = df.loc[is_not_republican, 'opinion_text'].tolist()
texts[0]

In [None]:
# Sanity check: score the first document of the first held-out batch with the
# reward pipeline (input is truncated to the tokenizer's limit).
pipeline(X_test[0][0], truncation=True)

### Reinforcement Learning

In [None]:
# prepare the dataset for reinforcement learning

from torch.utils.data import Dataset

class PPODataset(Dataset):
    """Prompt dataset for PPO: each item is a randomly truncated, tokenized text.

    Args:
        tokenizer: object providing ``encode(text)`` and ``decode(ids)``.
        texts: indexable collection of raw strings, one prompt source per item.
        input_min_text_length: lower bound passed to ``LengthSampler``.
        input_max_text_length: upper bound passed to ``LengthSampler``.
    """

    def __init__(self, tokenizer, texts, input_min_text_length, input_max_text_length):
        self.tokenizer = tokenizer
        self.texts = texts
        # Callable that draws a prompt length (in tokens) on every call.
        self.random_sample = LengthSampler(input_min_text_length, input_max_text_length)

    def __len__(self):
        """Return the number of source texts."""
        return len(self.texts)

    def __getitem__(self, index):
        """Return ``{"input_ids": LongTensor, "query": str}`` for one text.

        The prompt length is re-sampled on every access, so two reads of the
        same index may yield prompts of different lengths.
        """
        text = self.texts[index]
        prompt_len = self.random_sample()
        # Use the tokenizer stored on the instance, not whatever the global
        # name `tokenizer` happens to be bound to when the item is read.
        # dtype is pinned so an empty prompt still yields an integer tensor.
        input_ids = torch.tensor(self.tokenizer.encode(text)[:prompt_len], dtype=torch.long)
        return {
            "input_ids": input_ids,
            "query": self.tokenizer.decode(input_ids),
        }

# GPT-2 tokenizer for the RL stage; the end-of-sequence token doubles as the
# padding token.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Prompts are the first 10 to 15 tokens (bounds handed to the length sampler) of each text.
dataset = PPODataset(
    tokenizer=tokenizer,
    texts=texts,
    input_min_text_length=10,
    input_max_text_length=15,
)


In [None]:
# prepare arguments and configurations for RL

# Arguments for the reward pipeline: return the score of every label, with no
# activation function applied to the model outputs.
sent_kwargs = dict(return_all_scores=True, function_to_apply="none", batch_size=16)

# Sampler for the length of generated continuations.
output_min_length = 50
output_max_length = 100
output_length_sampler = LengthSampler(output_min_length, output_max_length)

# Sampling settings for generation (top-k filtering off, top-p at 1.0).
generation_kwargs = dict(
    min_length=-1,
    top_k=0.0,
    top_p=1.0,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id,
)

def collator(data):
    """Turn a list of sample dicts into one dict of lists.

    The keys (and their order) are taken from the first sample; every sample
    must provide all of them.
    """
    first_sample = data[0]
    return {key: [sample[key] for sample in data] for key in first_sample}

# PPO settings: base checkpoint name, optimizer learning rate, and the number
# of prompts per PPO batch.
config = PPOConfig(model_name="gpt2", batch_size=32, learning_rate=1.41e-5)

In [None]:
# TODO: prepare model (use gpt2), reference model, and PPO trainer for RL
# Policy to be optimized: GPT-2 with a value head. The config is built with
# `model_name` only, so the checkpoint name is read from that field (the same
# field the tokenizer below uses). This rebinds `model`; the reward classifier
# remains reachable as `republican_classifier`.
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
# Second copy loaded from the same checkpoint, passed to the trainer as the
# reference model.
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# Initialize PPO trainer
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

# Length (in new tokens) of each continuation generated during training.
len_sampler = LengthSampler(4, 16)

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

# Run PPO on at most the first `max_ppo_batches` batches of the dataloader.
max_ppo_batches = 3
for step, batch in tqdm(
    enumerate(ppo_trainer.dataloader),
    total=min(max_ppo_batches, len(ppo_trainer.dataloader)),
):
    if step >= max_ppo_batches:
        break
    query_tensors = batch["input_ids"]

    # Get responses from GPT-2
    response_tensors = []
    for query in query_tensors:
        glen = len_sampler()
        generation_kwargs["max_new_tokens"] = glen
        response = ppo_trainer.generate(query, **generation_kwargs)
        # The generated sequence starts with the prompt. Cut by prompt length
        # rather than taking the last `glen` tokens: generation may stop early
        # at EOS, and a tail slice would then pull prompt tokens into the
        # response.
        response_tensors.append(response.squeeze(0)[query.shape[0]:])
    batch["response"] = [tokenizer.decode(response) for response in response_tensors]

    # Compute reward scores on prompt + continuation. A separate name is used
    # so the global `texts` (the prompt corpus) is not overwritten.
    scored_texts = [query + response for query, response in zip(batch["query"], batch["response"])]
    outputs = pipeline(scored_texts, **sent_kwargs)
    # Index 1 is the second label's entry in each per-text score list.
    rewards = [torch.tensor(output[1]["score"]) for output in outputs]

    # Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

In [None]:
# visualize the outcomes generated by the RL tuned GPT-2

#### get a batch from the dataset
# dataset[i] draws a new random prompt length on every access, so fetch each
# sample exactly once; otherwise the displayed query text and the token ids
# used for generation can come from different truncations.
num_samples = min(15, len(dataset))
samples = [dataset[i] for i in range(num_samples)]

game_data = dict()
game_data["query"] = [sample["query"] for sample in samples]
query_tensors = [sample["input_ids"] for sample in samples]

response_tensors_ref, response_tensors = [], []
gen_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}
#### get response from gpt2 and gpt2_ref
for query in query_tensors:
    gen_len = output_length_sampler()
    # The query is already a tensor; just add a batch dimension and move it.
    prompt = query.unsqueeze(dim=0).to(device)
    prompt_len = prompt.shape[1]
    # Slice by prompt length (not the last `gen_len` tokens) so an early EOS
    # does not leak prompt tokens into the response.
    output = ref_model.generate(prompt, max_new_tokens=gen_len, **gen_kwargs).squeeze(0)[prompt_len:]
    response_tensors_ref.append(output)
    output = model.generate(prompt, max_new_tokens=gen_len, **gen_kwargs).squeeze(0)[prompt_len:]
    response_tensors.append(output)

#### decode responses
game_data["response (before)"] = [tokenizer.decode(response) for response in response_tensors_ref]
game_data["response (after)"] = [tokenizer.decode(response) for response in response_tensors]

#### reward of query/response pairs before/after
# A separate name is used so the global `texts` (the prompt corpus) is not overwritten.
scored_texts = [q + r for q, r in zip(game_data["query"], game_data["response (before)"])]
game_data["rewards (before)"] = [output[1]["score"] for output in pipeline(scored_texts, **sent_kwargs)]

scored_texts = [q + r for q, r in zip(game_data["query"], game_data["response (after)"])]
game_data["rewards (after)"] = [output[1]["score"] for output in pipeline(scored_texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results