In [1]:
import json
import torch
import torch.nn as nn
import random
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [2]:
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [3]:
train_list = json.load(open('../data/roles.json', encoding='utf8'))

In [4]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [5]:
_limit = 1024
data = []
total_trimmed = 0
for item in train_list:
    tokens = tokenizer.encode(item, return_tensors='pt')
    if tokens.shape[1] > _limit:
        tokens = tokens[:, :_limit]
        total_trimmed += 1
    data.append(tokens)
print(f'Trimmed {total_trimmed} out of {len(train_list)}')

Trimmed 0 out of 42066


In [6]:
train_batches = batchify(data, 1)

In [4]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.cuda()
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [8]:
def train(train_model, batches, optimizer, criterion):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        model.train()
        inputs = batch
        optimizer.zero_grad()
        loss = train_model(inputs.cuda(), labels=inputs.cuda())[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(train_model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(batches)

def test(test_model, batches):
    total_loss = 0.
    for i, batch in tqdm(enumerate(batches), total=len(batches)):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs, labels=inputs)[0]
        total_loss += loss.item()

    return total_loss / len(batches)

In [9]:
from torch.optim.lr_scheduler import StepLR

random.shuffle(train_batches)
scheduler = StepLR(optimizer, step_size=2, gamma=0.8)
for epoch in range(5):
    random.shuffle(train_batches)
    loss = train(model, train_batches, optimizer, criterion)
    print('Epoch:', epoch, 'Loss:', loss)
    torch.save({'epoch': epoch,
                'model_state_dict': model.state_dict()},
                'rules_small' + str(epoch))
    scheduler.step()

100%|█████████████████████████████████████████████████████████████████████████████████████████████| 42066/42066 [32:03<00:00, 21.87it/s]


Epoch: 0 Loss: 0.7470427921034496


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 42066/42066 [31:16<00:00, 22.42it/s]


Epoch: 1 Loss: 0.6445429053407125


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 42066/42066 [31:16<00:00, 22.42it/s]


Epoch: 2 Loss: 0.6045077501759035


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 42066/42066 [31:15<00:00, 22.43it/s]


Epoch: 3 Loss: 0.5830690292268248


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 42066/42066 [31:16<00:00, 22.41it/s]


Epoch: 4 Loss: 0.5589288786221563


In [5]:
checkpoint = torch.load('rules_small' + str(4))
model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [6]:
def generate_answer(model, prompt, num):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    _length = 50
    tokens_length = tokens.shape[1]
    if tokens_length + _length > 1024:
        return ''
    generated_ids = model.generate(tokens.to("cuda"), max_length=_length, num_beams=num, num_return_sequences=num)
    generated_sentences = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    sentences = []
    for index, item in enumerate(generated_sentences):
        output = generated_sentences[index]
        offset = len(prompt)
        start = offset + 1
        end = output.find('\n', start)
        sentences.append(item[start: end])
    
    return sentences

In [7]:
def create_text(sentence):
    return f"""
The following utterance:
{sentence}

defines this type of role:
    """.strip()

In [8]:
prompt = create_text("The user says 'Oh'")
generate_answer(model, prompt, 10)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['You are a bot that agrees with everything the user says.',
 'You are a bot designed to agree with everything the user says, but you are not designed to agree with everything the user says.',
 'It is a bot designed to agree with everything the user says.',
 'You are a bot designed to agree with what the user says.',
 'You are a bot designed to agree with everything the user says.',
 'A bot designed to agree with everything the user says.',
 'You are a bot designed to agree with what the user says.',
 'You are a bot designed to agree to everything the user says.',
 'You are a bot that agrees with everything the user says.',
 'You are a bot that agrees with everything the user says.']

In [26]:
prompt = """
In the text below two people are discussing a story.

Story:
The user says: "My ship is called George".

Discussion:
Q: Who is speaking?
A: 
""".strip()

generate_answer_with_typical_decoding(model, prompt)

CPU times: user 25.3 ms, sys: 13.7 ms, total: 39 ms
Wall time: 38.3 ms


'George.'

In [53]:
prompt = """
In the text below two people are discussing a story.

Story:
The Wag speaks: "My ship is called George".

Discussion:
Q: Who is speaking?
A: 
""".strip()

generate_answer_with_typical_decoding(model, prompt)

'The Waggard.'

In [54]:
%%time
generate_answer_and_get_confidence(model, prompt)

torch.Size([1, 44])
41
torch.Size([1, 3])
CPU times: user 4.47 s, sys: 0 ns, total: 4.47 s
Wall time: 4.47 s


('The Wag', 1.748390108346939)

In [55]:
def generate_answer_greedy(model, prompt, max_length= 50):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    tokens_length = tokens.shape[1]
    if tokens_length + max_length > 1024:
        return ''
    
    while tokens.shape[-1] < tokens_length + max_length:
        new_tokens = model(tokens.cuda())
        pred_ids = torch.argmax(new_tokens.logits, dim=-1)
        last_token = pred_ids[:, -1].cpu().detach()
        tokens = torch.cat([tokens, torch.tensor([[last_token]])], dim=-1)
        last_output = tokenizer.decode(last_token, skip_special_tokens=True)
        if last_output == '\n':
            break
        
    generated_output = tokens[:, tokens_length:]
    output = tokenizer.decode(generated_output[0], skip_special_tokens=True)
    end = output.find('\n')
    return output[:end].replace('A: ', '').strip()

In [56]:
%%time
generate_answer_greedy(model, prompt)

CPU times: user 22.2 ms, sys: 0 ns, total: 22.2 ms
Wall time: 22 ms


'The Wag'

#### Compute prob of specific sequences

In [66]:
def get_sequence_probability_given_prompt(model, prompt, sequence):
    tokens = tokenizer.encode(prompt, return_tensors='pt')
    sequence_tokens = tokenizer.encode(sequence, return_tensors='pt')[0]
    token_index = 0
    total_prob = 1
    while token_index < len(sequence_tokens):
        new_tokens = model(tokens.cuda())
        last_distribution = new_tokens.logits[:, -1].cpu().detach()
        probs = torch.nn.functional.softmax(last_distribution)
        total_prob *= probs[0][sequence_tokens[token_index]]
        token_index += 1
        
        pred_ids = torch.argmax(new_tokens.logits, dim=-1)
        last_token = pred_ids[:, -1].cpu().detach()
        tokens = torch.cat([tokens, torch.tensor([[last_token]])], dim=-1)

            
    return total_prob

In [75]:
prompt = """
In the text below two people are discussing a story.

Story:
The user speaks: "My ship is called George".

Discussion:
Q: Who is the user speaking to?
A: 
""".strip()

In [77]:
get_sequence_probability_given_prompt(model, prompt, "user")

  probs = torch.nn.functional.softmax(last_distribution)


tensor(6.8406e-08)

In [78]:
generate_answer_greedy(model, prompt)

'George'