__Abstract__: An implementation of some ideas from the paper 'WebGPT: Browser-assisted question-answering with human feedback.' The goal is to use human feedback to improve a summarization model: we fine-tune a modified copy of the model to score candidate summaries, then use those scores for rejection sampling.

Dataset: https://github.com/openai/summarize-from-feedback

Model: sshleifer/distilbart-xsum-12-1 (https://huggingface.co/sshleifer/distilbart-xsum-12-1), pretrained on XSum and CNN/DailyMail

In [1]:
# !pip install transformers
# !pip install torch

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-1")

model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-1")

In [3]:
model

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
   

In [4]:
import torch
from torch import nn

class GraderModel(nn.Module):
    """Maps a 'post|summary' string to a single scalar score."""

    def __init__(self, model_max_len=512):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-1")

        self.model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-1")
        # Linear head over the flattened logits (model_max_len positions x 50264 vocabulary entries).
        self.stack_layer = nn.Sequential(
          nn.Flatten(),
          nn.Linear(model_max_len * 50264, 1)
        )

    def forward(self, x, device):
        # Pad/truncate to a fixed length; max_length must equal model_max_len so the
        # flattened logits match the input size of the linear head.
        encoded_input = self.tokenizer.encode(x, padding='max_length', truncation=True, max_length=512, return_tensors="pt").to(device)
        answer = self.model(encoded_input)

        return self.stack_layer(answer.logits)

grader_model = GraderModel()
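
For a sense of scale: the head above flattens the full logits tensor, so its single linear layer is large. The sketch below only does the arithmetic, taking the vocabulary size (50264) and hidden size (1024) from the model printout. The pooled head it compares against is a hypothetical alternative for illustration, not something this notebook uses.

```python
# Parameter-count arithmetic for the scoring head (weights plus one bias).
MODEL_MAX_LEN = 512   # default model_max_len in GraderModel
VOCAB_SIZE = 50264    # from Embedding(50264, 1024) in the model printout
HIDDEN_SIZE = 1024    # from the same printout


def linear_params(in_features, out_features=1):
    """Number of parameters in nn.Linear(in_features, out_features) with bias."""
    return in_features * out_features + out_features


flattened_head = linear_params(MODEL_MAX_LEN * VOCAB_SIZE)
# Hypothetical alternative: a linear layer on one pooled hidden-state vector.
pooled_head = linear_params(HIDDEN_SIZE)

print(f"flattened-logits head: {flattened_head:,} parameters")
print(f"pooled hidden-state head (assumed alternative): {pooled_head:,} parameters")
```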


In [5]:
from torch.utils.data import DataLoader, Dataset
import json

class ComparisonDataset(Dataset):
    def __init__(self, files):
        self.files = files
        self.data = []

        for file in files:
            with open(file) as json_file:
                lines = json_file.readlines()

                for line in lines:
                    sample = json.loads(line)
                    self.data.append([sample['info']['post'], sample['summaries'][0]['text'], sample['summaries'][1]['text'], sample['choice']])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
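
The loader above expects one JSON object per line and reads only four fields from it. A minimal sketch of that record shape follows. The field names are the ones the loader accesses; the sample values are invented placeholders, not real dataset rows, and `parse_comparison` is a hypothetical helper name.

```python
import json

# Invented placeholder record; only the field names mirror what ComparisonDataset reads.
line = json.dumps({
    "info": {"post": "Original post text."},
    "summaries": [{"text": "First candidate summary."},
                  {"text": "Second candidate summary."}],
    "choice": 1,  # used as the class target in training (index of the chosen summary)
})


def parse_comparison(raw_line):
    """Turn one JSONL line into [post, summary_0, summary_1, choice]."""
    sample = json.loads(raw_line)
    return [sample["info"]["post"],
            sample["summaries"][0]["text"],
            sample["summaries"][1]["text"],
            sample["choice"]]


row = parse_comparison(line)
print(row)
```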


In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)

# model = model.to(device)
grader_model = grader_model.to(device)
grader_model.model = grader_model.model.to(device)

dataset = ComparisonDataset(['./data/batch4.json', './data/batch5.json', './data/batch9.json'])
# dataset = ComparisonDataset(['/content/batch4.json'])
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

cuda


In [6]:
from tqdm import tqdm
import numpy as np

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, grader_model.parameters()))

# Originally num_epoch = 2; set to 0 here because a saved checkpoint is loaded instead of retraining
num_epoch = 0

loss_history = []

for epoch in tqdm(range(num_epoch), position=0, leave=True):
    loss_history.append([])

    for reference, desc1, desc2, target in tqdm(dataloader, position=0, leave=True):
        prompt1 = reference[0] + '|' + desc1[0]
        prompt2 = reference[0] + '|' + desc2[0]

        # The DataLoader already collates `choice` into a tensor of shape [1]
        target = target.to(device)
        optimizer.zero_grad()

        score1 = grader_model(prompt1, device).squeeze()
        score2 = grader_model(prompt2, device).squeeze()

        comp = torch.stack([score1, score2]).unsqueeze(0)

        loss = criterion(comp, target)
        loss_history[-1].append(loss.item())

        loss.backward()
        optimizer.step()

    print('Epoch: {}\t|\tLoss: {}'.format(epoch, np.mean(loss_history[-1])))


0it [00:00, ?it/s]
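
The training loop treats the two scores as the logits of a two-way classification, with the human choice as the target class. For two logits, cross-entropy reduces to -log(sigmoid(s_chosen - s_other)), the usual pairwise preference loss. Below is a small pure-Python check of that identity; the function names are illustrative and the scores are made up.

```python
import math


def pairwise_cross_entropy(score1, score2, choice):
    """Cross-entropy over the two scores treated as logits; choice is 0 or 1."""
    scores = [score1, score2]
    # log-sum-exp with the max subtracted for numerical stability
    m = max(scores)
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[choice]


def neg_log_sigmoid_diff(score1, score2, choice):
    """Equivalent form: -log(sigmoid(s_chosen - s_other))."""
    scores = [score1, score2]
    diff = scores[choice] - scores[1 - choice]
    return math.log1p(math.exp(-diff))


# Made-up scores, purely for illustration.
a = pairwise_cross_entropy(0.3, 1.2, choice=1)
b = neg_log_sigmoid_diff(0.3, 1.2, choice=1)
print(a, b)
```

The loss is small when the chosen summary already has the higher score and grows as the scores disagree with the human choice.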


In [7]:
import numpy as np

def simple_sampling(model, text, ans_n=3):
    with torch.no_grad():
        prompt = text
        prompt = tokenizer.encode(prompt, padding='max_length', truncation=True, max_length=512, return_tensors="pt")

        answers = model.generate(prompt, 
                                  do_sample=True,   
                                  min_length=50, 
                                  max_length=768,
                                  top_k=30,                                 
                                  top_p=0.7,   
                                  temperature=0.9,
                                  repetition_penalty=2.0,
                                  num_return_sequences=ans_n)
        
        ans = []

        for answer in answers:
            ans.append(tokenizer.decode(answer))

        return ans
        
def rej_sampling(model, reward_model, text, ans_n=3):
    with torch.no_grad():
        prompt = text
        prompt = tokenizer.encode(prompt, padding='max_length', truncation=True, max_length=512, return_tensors="pt")

        answers = model.generate(prompt, 
                                  do_sample=True,   
                                  min_length=50, 
                                  max_length=768,
                                  top_k=30,                                 
                                  top_p=0.7,   
                                  temperature=0.9,
                                  repetition_penalty=2.0,
                                  num_return_sequences=15)
        
        rewards = []

        for answer in answers:
            prompt = text + '|' + tokenizer.decode(answer)

            reward = reward_model(prompt, device)
            rewards.append(reward.item())

        idx = np.argsort(rewards)
        ans = []
        
        for i in range(ans_n):
            # argsort is ascending, so the best candidates are at idx[-1], idx[-2], ...
            ans.append(tokenizer.decode(answers[idx[-(i + 1)]]))

        return ans
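
The selection step of rejection sampling can be looked at separately from generation and scoring. A minimal sketch, assuming the candidates have already been generated and scored; `best_of_n` is a hypothetical helper name, and the candidate strings and rewards are made up.

```python
def best_of_n(candidates, rewards, n=3):
    """Return the n candidates with the highest rewards, best first."""
    if len(candidates) != len(rewards):
        raise ValueError("candidates and rewards must have the same length")
    # Sort indices by reward, highest first, and keep the top n.
    order = sorted(range(len(rewards)), key=lambda i: rewards[i], reverse=True)
    return [candidates[i] for i in order[:n]]


# Made-up candidates and rewards, purely for illustration.
candidates = ["summary A", "summary B", "summary C", "summary D"]
rewards = [0.1, 2.5, -1.0, 0.7]
top = best_of_n(candidates, rewards, n=2)
print(top)
```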

In [8]:
# torch.save(grader_model, f='kek')
grader_model = torch.load('kek')

In [9]:
texts = [
    '''The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.''',
    '''lagos, nigeria (cnn) a day after winning nigeria’s presidency, muhammadu buhari told cnn’s christiane amanpour that he plans to aggressively fight corruption that has long plagued nigeria and go after the root of the nation’s unrest. buhari said he’ll “rapidly give attention” to curbing violence in the northeast part of nigeria, where the terrorist group boko haram operates. by cooperating with neighboring nations chad, cameroon and niger, he said his administration is confident it will be able to thwart criminals and others contributing to nigeria’s instability. for the first time in nigeria’s history, the opposition defeated the ruling party in democratic elections. buhari defeated incumbent goodluck jonathan by about 2 million votes, according to nigeria’s independent national electoral commission. the win comes after a long history of military rule, coups and botched attempts at democracy in africa’s most populous nation.''',
    '''What kind of exercise do lazy people do? Diddly-squats''',
    '''What is Forrest Gump's password? 1Forrest1.''',
    '''What do you call bears with no ears? B.''',
    '''What's a foot long and slippery? A slipper!''',
    '''What are a shark's two most favorite words? Man overboard!''',
]

for text in texts:
    simple_abs = simple_sampling(model, text)
    rej_abs = rej_sampling(model, grader_model, text)

    print('Text: {}\nSimple abs:'.format(text))

    # `summary` avoids shadowing the built-in abs()
    for summary in simple_abs:
        print(summary)

    print('Rejection abs:')

    for summary in rej_abs:
        print(summary)



Text: The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct.
Simple abs:
</s> The Eiffel Tower in Paris has become the tallest in the world, more than 60 years after it was built on the roof of the French capital's (19.5m) at the end of this week. "9th Century".</s>
</s> The Eiffel Tower in Paris has become th

Text: What's a foot long and slippery? A slipper!
Simple abs:
</s> A foot-long has been found in the UK's biggest ever, but it can't take a step closer to the end of the year. and is now known as 'Slipper' or slippers - that they're going to be?</s>
</s> A foot-long walker has been described as "the world's most slippery" - but it is not yet to be able to find out what happens when they're a slipper. £40,000 (6m) long</s>
</s> A foot-long walker has been described as "the world's most slippery" in the US, but it is not too easy to find out how to get a slipper. and that they can't be used to make your feet.</s>
Rejection abs:
</s> A foot-long walker has been described as " the world's most slippery" - but how do you know what happens when they're going to be in a slipper. and is now known as 'Slipper' for'? of...</s><pad>
</s> A foot, a foot-long walker has been described as "the world's most slippery" - but what happens when it comes to an end up of this year's on the World War Two? 