# RL for Human Feedback

In [1]:
import torch as t
import transformers

In [2]:
# Use GPU index 6 when CUDA is available; otherwise run everything on the CPU.
if t.cuda.is_available():
    device = t.device('cuda:6')
else:
    device = t.device('cpu')

## Train policy on simple reward function

In [17]:
def count_periods(s: str):
    """Return how many '.' characters occur in ``s`` (used as a toy reward)."""
    period = '.'
    n_periods = s.count(period)
    return n_periods

In [18]:
def normalize_rewards(rewards, eps=1e-5):
    """Center a batch of rewards and scale it by its standard deviation.

    Args:
        rewards: tensor of per-sample rewards.
        eps: small constant added to the standard deviation so that a batch
            of identical rewards maps to zeros instead of dividing by zero.

    Returns:
        A tensor of the same shape as ``rewards``: ``(r - mean) / (std + eps)``.
        Batches with fewer than two elements are only centered, because the
        sample standard deviation is undefined there and would otherwise turn
        every reward into NaN.
    """
    centered = rewards - rewards.mean()
    if rewards.numel() < 2:
        # std() of a single sample is NaN; skip the scaling step entirely.
        return centered
    return centered / (centered.std() + eps)

In [19]:
# GPT-2 tokenizer plus two copies of the pretrained GPT-2 LM head model:
# `model` is the policy that gets trained (it pads with the EOS token),
# `ref_model` is a second pretrained copy used as a reference.
tokenizer = transformers.GPT2Tokenizer.from_pretrained('gpt2')

model = transformers.GPT2LMHeadModel.from_pretrained(
    'gpt2',
    pad_token_id=tokenizer.eos_token_id,
)
model = model.to(device)

ref_model = transformers.GPT2LMHeadModel.from_pretrained('gpt2')
ref_model = ref_model.to(device)

In [20]:
# Token ids of the fixed prompt every sample starts from, moved to `device`.
prompt_ids = tokenizer.encode('This is', return_tensors='pt')
generic_prompt = prompt_ids.to(device)
generic_prompt

tensor([[1212,  318]], device='cuda:6')

In [21]:
# Training hyperparameters.
batch_size, gen_len = 4, 20  # sequences sampled per step, tokens generated per sequence
learning_rate, n_steps = 3e-5, 80  # Adam step size, number of training steps

In [22]:
def get_logprobs(sample_ids, logits):
    """Return the log-probability the logits assign to each sampled token.

    Args:
        sample_ids: integer tensor of token ids, indexed as (batch, position).
        logits: tensor indexed as (batch, position, vocab); ``logits[b, i]``
            must be the distribution from which ``sample_ids[b, i]`` was drawn.

    Returns:
        Tensor indexed as (batch, position) holding
        ``log_softmax(logits)[b, i, sample_ids[b, i]]``.

    The input ``logits`` tensor is left unmodified. A previous version added a
    tiny constant in place; softmax is invariant to a constant shift, so that
    had no numerical effect but did mutate the caller's tensor (and fails on
    leaf tensors that require grad).
    """
    all_logprobs = t.nn.functional.log_softmax(logits, dim=-1)
    # Pick, for every position, the entry belonging to the sampled token id.
    picked = t.gather(all_logprobs, -1, sample_ids[:, :, None])
    return picked[:, :, 0]

def get_loss(sample_ids, old_logprobs, ref_logits, rewards, prefix_len, clip_range=0.2):
    """Reward-weighted negative log-likelihood of the generated tokens.

    Runs the module-level ``model`` on ``sample_ids`` and scores every token
    from position ``prefix_len`` onward, weighting each sequence's token
    log-probabilities by that sequence's reward.

    Args:
        sample_ids: token ids, one row per sampled sequence.
        old_logprobs: accepted but not used by this implementation.
        ref_logits: accepted but not used by this implementation.
        rewards: one reward per sequence (broadcast over token positions).
        prefix_len: number of leading tokens excluded from the loss.
        clip_range: accepted but not used by this implementation.

    Returns:
        Scalar tensor: ``-mean(logprob * reward)`` over all scored tokens.
    """
    policy_logits = model(sample_ids).logits  # (batch_size, seq_len, vocab_size)

    # The logits at position i score the token at position i + 1, hence the
    # one-step shift between the two slices.
    scored_tokens = sample_ids[:, prefix_len:]
    scoring_logits = policy_logits[:, prefix_len - 1:-1]
    token_logprobs = get_logprobs(scored_tokens, scoring_logits).to(device)

    weighted = token_logprobs * rewards[:, None]
    return -weighted.mean()
    

In [23]:
def train_simple_policy_gradient():
    """Fine-tune the module-level ``model`` with a plain policy-gradient loop.

    For each of ``n_steps`` steps: sample ``batch_size`` continuations of
    ``generic_prompt`` (``gen_len`` new tokens each), print them, reward each
    by its number of periods, normalize the rewards, print them, and take one
    Adam step on ``get_loss`` with gradients clipped to norm 1.0.

    Reads the module-level ``model``, ``ref_model``, ``tokenizer``,
    ``generic_prompt``, ``device`` and the hyperparameters ``batch_size``,
    ``gen_len``, ``learning_rate`` and ``n_steps``. Returns nothing; ``model``
    is updated in place.
    """
    optimizer = t.optim.Adam(model.parameters(), learning_rate)

    # Prompt length in tokens. generic_prompt is indexed (batch, tokens), so
    # the last dimension is the length; len() would give the batch dimension.
    prefix_len = generic_prompt.shape[-1]
    total_len = prefix_len + gen_len

    for _step in range(n_steps):
        # min_length == max_length forces exactly gen_len generated tokens.
        sample_ids = model.generate(
            generic_prompt,
            max_length=total_len,
            min_length=total_len,
            do_sample=True,
            temperature=0.6,
            top_k=len(tokenizer),
            top_p=1.0,
            num_return_sequences=batch_size,
        )
        # Tensor.to is not in-place, so the result has to be kept.
        sample_ids = sample_ids.to(device)

        # Snapshot quantities need no autograd graph.
        with t.no_grad():
            old_logits = model(sample_ids).logits
            old_logprobs = get_logprobs(
                sample_ids[:, prefix_len:],
                old_logits[:, prefix_len - 1:-1],
            )
            ref_logits = ref_model(sample_ids).logits

        list_of_sentences = tokenizer.batch_decode(sample_ids)
        print(list_of_sentences)

        rewards = t.tensor(
            [count_periods(s) for s in list_of_sentences],
            dtype=t.float32,
        ).to(device)
        rewards = normalize_rewards(rewards)
        print(rewards)

        # Single optimization pass per batch of samples.
        for _epoch in range(1):
            loss = get_loss(sample_ids, old_logprobs, ref_logits, rewards, prefix_len=prefix_len)
            loss.backward()
            t.nn.utils.clip_grad_norm_(model.parameters(), 1.0, norm_type=2.0, error_if_nonfinite=True)
            optimizer.step()
            optimizer.zero_grad()

In [24]:
# Run the training loop; each step prints the sampled sentences followed by
# their normalized rewards.
train_simple_policy_gradient()

['This is a case where he can be very strong. He is a very good player and has won the league', 'This is an interesting thought experiment, and I think it would be a great idea to set up a lab for', 'This is a massive increase in the amount of fuel that people are using to transport their cars to work," said', 'This is a silly question: How can they possibly know that the person who said "Hey, I love that']
tensor([ 1.5000, -0.5000, -0.5000, -0.5000], device='cuda:6')
["This is the kind of thing that we want to do, and one of the things that's going to be", 'This is a very small number of people who are engaged in the campaign to save jobs and our economy, and', "This is what I'm about to say, because when I saw the video of the moment, I was like", 'This is why the PPP is the best indicator of how far government has come in the last decade of the']
tensor([0., 0., 0., 0.], device='cuda:6')
["This is the case with the P.A.A. in the 1970's. The P.A.", 'This is a very different type of s