In [107]:
from models import Generator, Predictor
generator = Generator("/home/ubuntu/models/generator/")

In [142]:
import pandas as pd
df = pd.DataFrame(data = {
    "utterance": ["<|client|>I don't want to prepare for the final exams anymore, but my parents forced me to.",],
    "generator_input_ids": [[],],
    "is_listener": [False,],
})
code_scores = [(code, 0.5) for code in Predictor.CODES]
print(Predictor.CODES)
gens = generator.predict(df, code_scores)
gens

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


['AF', 'SUP', 'PR', 'QUC', 'RF', 'QUO', 'INT', 'GR']


['You are a very strong person',
 "That's understandable",
 'You can do it',
 'Do you think that they will be able to help?',
 'That is a lot of pressure',
 'Why do you think they are forcing you?',
 'Hey',
 'Oh no']

# Compute seq scores

In [201]:
import torch

device = torch.device("cuda:0")

input_ids = [y for x in df.iloc[:]["generator_input_ids"].tolist() for y in x] + [-1,]  # placeholder for a code
input_ids = torch.tensor(input_ids).view(1, -1)  # torch.LongTensor of shape (batch_size, sequence_length)
input_ids = input_ids.repeat(len(code_scores), 1)
input_ids[:, -1] = torch.LongTensor([generator.CODE_TOKEN_IDS[code] for (code, _) in code_scores])

In [229]:
decode = generator.tokenizer.decode

generator.tokenizer.eos_token_id = generator.tokenizer.pad_token_id
outputs = generator.model.generate(
    input_ids.to(device), 
    do_sample=True, 
    max_length=input_ids.shape[1] + Generator.MAX_NEW_LEN,
    top_p=0.95, 
    top_k=50,

    length_penalty=0.9,
    temperature=0.3,
    repetition_penalty=1.2,
    no_repeat_ngram_size=2,
    forced_eos_token_id=generator.tokenizer.eos_token_id,

    bad_word_ids=generator.bad_words_ids,
    num_return_sequences=3,
    return_dict_in_generate=True,
    output_scores=True,
)

# Get scores of generated tokens
ids = outputs.sequences[:, input_ids.shape[1]:]
aranged = torch.arange(ids.shape[1]).repeat(ids.shape[0], 1)
scores = torch.stack(outputs.scores)
token_scores = scores[aranged, 0, ids]

# Average scores of non-special tokens
n_tokens = (token_scores >= 0).sum(axis=1)
n_tokens[n_tokens == 0] = 1  # avoid nan
token_scores[token_scores < 0] = 0.0
seq_scores = token_scores.sum(axis=1) / n_tokens

decoded = [decode(x, skip_special_tokens=True) for x in ids]
# score_seq_tuples = [(score, d) for score, d in zip(seq_scores, decoded)]
# score_seq_tuples.sort(key=lambda x: -x[0])
# score_seq_tuples

per_code_score_seq = []
inc = 3
for i in range(0, len(seq_scores), inc):
    code = Predictor.CODES[i // inc]
    idx = seq_scores[i:i+inc].argmax()
    # for score, seq in zip(seq_scores, outputs.sequences):
    # print("{:4.1f} {}".format(seq_scores[i+idx], decoded[i+idx]))
    per_code_score_seq.append((code, seq_scores[i+idx], decoded[i+idx]))

per_code_score_seq.sort(key=lambda x: -x[1])
for t in per_code_score_seq:
    print("{:5} {:4.1f} {}".format(*t))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


PR    80.8 You can always talk with your family about it
RF    78.4 It's okay you can do it
QUC   73.6 Do you feel like your parents are being too hard on you?
AF    59.7 That's a good decision
SUP   51.5 That's understandable
INT   22.5 Hey
GR    22.5 Ohh
QUO    0.0 Why did they force you?


# Reverse scoring -- can't do it without actually training a model

In [158]:
input_ids.shape, ids.shape

(torch.Size([8, 21]), torch.Size([24, 17]))

In [182]:
ro = generator.model.generate(
    ids.to(device),
    return_dict_in_generate=True,
    output_scores=True,
)
r_scores = torch.stack(ro.scores).mean(axis=0)
r_scores.shape

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


torch.Size([24, 50267])

In [196]:
r_seq_score = r_scores[torch.arange(r_scores.shape[0]).reshape(-1, 1), input_ids.repeat(3, 1)].mean(axis=0)
r_seq_score

tensor([10.9348, 22.4214, 15.3724, 18.8695, 15.9672, 19.5457, 15.2774, 18.4659,
        17.2179, 14.7451, 13.8014, 16.4618, 23.8505, 16.8800, 17.2270, 11.7806,
        13.1184, 15.5898, 19.5457, 27.0341, 11.9337], device='cuda:0')

In [197]:
r_rs = [(r, decode(s, skip_special_tokens=True)) for r, s in zip(r_seq_score, ids)]
r_rs.sort(key=lambda x: -x[0])
for r, s in r_rs:
    print("{:4.1f} {}".format(r, s))

27.0 Hey
23.9 It sounds like you are going through a lot right now
22.4 That's a good start
19.5 Oh I'm sorry
19.5 Hey
18.9 That's understandable
18.5 you shouldnt worry about that
17.2 You are in a tough situation
17.2 You should go back and study
16.9 That's a lot of pressure on you
16.5 Are you going to be able to get a job?
16.0 Oh I understand
15.6 Why did they force you?
15.4 It's a good idea
15.3 You can do it!
14.7 So you are afraid of your future?
13.8 Have you tried talking with your teachers about this?
13.1 Why do you feel like they are forcing you?
11.9 Hey
11.8 How long have you been in this relationship?
10.9 You're doing great
