In [1]:
import joblib

In [2]:
train_embeddings = joblib.load('train_embeddings_logits_only.joblib')

In [3]:
train_embeddings[0].keys()

dict_keys(['logits_and_indices', 'input_ids'])

### Training from embeddings

In [4]:
import torch
import transformers
import torch.nn.functional as F

from tqdm import tqdm
from transformers import GPT2LMHeadModel, AutoTokenizer

In [6]:
gpt_small = GPT2LMHeadModel.from_pretrained("gpt2")
checkpoint = torch.load('save_small' + str(1))
gpt_small.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [7]:
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

In [8]:
def create_text_from_summary_and_dialogue(summary, dialogue):
    text = f"""
A partial summary of the conversation is:
{summary}

With the dialogue being:
{dialogue}
    """.strip()
    
    return text.replace('\r\n', '\n')

In [9]:
summary = """
Alberto is a customer. Alberto ordered pizza and lasagne at Dominos one hour ago.
Alberto wants to know where his pizza is. Alberto is calling Dominos to know where his orders are.
John works at Dominos. John says that his pizza is almost ready. It will be at Alberto's home in 10 minutes. 
The lasagne is already on its way with the delivery person. Unfortunately the driver had an accident and is now in the hospital.
If alberto says "rabbit" then John replies "horse". If Alberto says "Tree" John says "red".
""".strip().replace('\n', ' ')

dialogue = """
Alberto: What happened to my pizza?
John: 
""".strip()

text = create_text_from_summary_and_dialogue(summary, dialogue)

prompt = tokenizer(text, return_tensors='pt').to('cuda')
prompt = {key: value for key, value in prompt.items()}

In [10]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [11]:
import json

val = json.load(open('../data/val.json'))

_limit = 1024
dev_data = []
total_skipped = 0
for item in val:
    text = create_text_from_summary_and_dialogue(item["summary"], item["dialogue"])
    tokens = tokenizer.encode(text, return_tensors='pt')
    if tokens.shape[1] > _limit:
        tokens = tokens[:, :_limit]
    dev_data.append(tokens)
    
print(f'Skipped {total_skipped} out of {len(val)}')

dev_batches = batchify(dev_data, 1)

def test(test_model, batches):
    test_model.eval()
    total_loss = 0.
    for i, batch in enumerate(batches):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs.cuda(), labels=inputs.cuda())[0]
        total_loss += loss.item()

    return total_loss / len(batches)

Skipped 0 out of 818


In [12]:
_ = gpt_small.cuda()

In [13]:
print('Dev loss:', test(gpt_small, dev_batches))

Dev loss: 2.2554801596114573


In [14]:
def get_probability_vector(log_prob_dict, temp):
    _vocab_size = 50257
    
    logits = torch.tensor(log_prob_dict['logits'])
    num_tokens = logits.shape[1]
    indices = torch.tensor(log_prob_dict['indices'])
    vectors = []
    
    for index_set, logs in zip(indices[0], logits[0]):
        v = torch.sparse_coo_tensor([index_set.tolist()], logs, (_vocab_size, )).to_dense().float()
        v[v == 0] = torch.tensor(float('-inf'))
        vectors.append(v)

    vectors = torch.stack(vectors, dim=0)
    return F.softmax(vectors / temp, dim=-1)

In [15]:
get_probability_vector(train_embeddings[0]['logits_and_indices'], temp=10).shape

torch.Size([56, 50257])

In [16]:
import random
from torch.optim.lr_scheduler import StepLR

lr = 1e-5
optimizer = torch.optim.Adam(gpt_small.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)
epochs = 5

steps = 0
best_model = None
best_loss = 1e6
for epoch_num in range(epochs):
    gpt_small.train()
    temp = 30
    random.shuffle(train_embeddings)
    
    for item in tqdm(train_embeddings):
        input_ids = torch.tensor([item['input_ids']]).cuda()
        label_p = get_probability_vector(item['logits_and_indices'], temp=temp).cuda()
        out_logits = gpt_small.forward(input_ids).logits
        out_p = F.softmax(out_logits / temp, dim=-1)
        
        loss = gpt_small(input_ids, labels=input_ids)[0]
        
        loss -=  temp * temp * torch.mean(torch.mul(torch.log(out_p).flatten(),
                                          label_p.flatten()))

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        
        steps += 1
        
        if steps % 2000 == 0:
            print("steps", steps)
            print('Dev loss:', test(gpt_small, dev_batches))
            if loss < best_loss:
                best_loss = loss
                best_model = gpt_small
            
    scheduler.step()

 14%|█████▏                                | 1999/14732 [03:48<24:35,  8.63it/s]

steps 2000


 14%|████▉                               | 2001/14732 [03:57<6:26:48,  1.82s/it]

Dev loss: 2.2579836736098478


 27%|██████████▎                           | 3999/14732 [07:43<19:58,  8.96it/s]

steps 4000


 27%|█████████▊                          | 4001/14732 [07:52<5:16:15,  1.77s/it]

Dev loss: 2.257600620206819


 41%|███████████████▍                      | 5999/14732 [11:39<16:54,  8.61it/s]

steps 6000


 41%|██████████████▋                     | 6001/14732 [11:48<4:48:43,  1.98s/it]

Dev loss: 2.257007570459091


 54%|████████████████████▋                 | 7999/14732 [15:33<12:33,  8.93it/s]

steps 8000


 54%|███████████████████▌                | 8001/14732 [15:42<3:24:40,  1.82s/it]

Dev loss: 2.261356037228498


 68%|█████████████████████████▊            | 9999/14732 [19:27<08:50,  8.93it/s]

steps 10000


 68%|███████████████████████▊           | 10001/14732 [19:36<2:14:59,  1.71s/it]

Dev loss: 2.255610811681211


 81%|██████████████████████████████▏      | 11999/14732 [23:22<05:19,  8.54it/s]

steps 12000


 81%|████████████████████████████▌      | 12001/14732 [23:31<1:28:48,  1.95s/it]

Dev loss: 2.2528563258408916


 95%|███████████████████████████████████▏ | 13999/14732 [27:15<01:21,  8.94it/s]

steps 14000


 95%|███████████████████████████████████▏ | 14001/14732 [27:24<23:00,  1.89s/it]

Dev loss: 2.2522450959478437


100%|█████████████████████████████████████| 14732/14732 [28:47<00:00,  8.53it/s]
  9%|███▎                                  | 1267/14732 [02:25<26:09,  8.58it/s]

steps 16000


  9%|███                                 | 1269/14732 [02:34<7:00:41,  1.87s/it]

Dev loss: 2.2726549818055264


 22%|████████▍                             | 3267/14732 [06:20<22:15,  8.58it/s]

steps 18000


 22%|███████▉                            | 3269/14732 [06:29<6:11:18,  1.94s/it]

Dev loss: 2.2910467864248747


 36%|█████████████▌                        | 5267/14732 [10:15<17:17,  9.12it/s]

steps 20000


 36%|████████████▉                       | 5269/14732 [10:24<4:30:13,  1.71s/it]

Dev loss: 2.292764984016605


 49%|██████████████████▋                   | 7267/14732 [14:08<14:25,  8.63it/s]

steps 22000


 49%|█████████████████▊                  | 7269/14732 [14:17<3:58:06,  1.91s/it]

Dev loss: 2.2901895054978088


 63%|███████████████████████▉              | 9267/14732 [18:03<10:20,  8.81it/s]

steps 24000


 63%|██████████████████████▋             | 9269/14732 [18:12<2:47:52,  1.84s/it]

Dev loss: 2.2979643369653697


 76%|████████████████████████████▎        | 11266/14732 [21:57<06:17,  9.18it/s]

steps 26000


 76%|██████████████████████████▊        | 11269/14732 [22:06<1:26:41,  1.50s/it]

Dev loss: 2.297754973537473


 90%|█████████████████████████████████▎   | 13267/14732 [25:51<02:37,  9.29it/s]

steps 28000


 90%|█████████████████████████████████▎   | 13269/14732 [26:00<41:46,  1.71s/it]

Dev loss: 2.2932780080436204


100%|█████████████████████████████████████| 14732/14732 [28:45<00:00,  8.54it/s]
  4%|█▍                                     | 535/14732 [01:01<27:23,  8.64it/s]

steps 30000


  4%|█▎                                   | 537/14732 [01:10<7:28:21,  1.90s/it]

Dev loss: 2.2973292445203786


 17%|██████▌                               | 2535/14732 [04:56<22:38,  8.98it/s]

steps 32000


 17%|██████▏                             | 2537/14732 [05:05<5:47:33,  1.71s/it]

Dev loss: 2.3563650125104814


 31%|███████████▋                          | 4534/14732 [08:51<18:46,  9.05it/s]

steps 34000


 31%|███████████                         | 4537/14732 [09:00<4:06:54,  1.45s/it]

Dev loss: 2.3639763159390474


 44%|████████████████▊                     | 6535/14732 [12:46<15:31,  8.80it/s]

steps 36000


 44%|███████████████▉                    | 6537/14732 [12:55<4:21:07,  1.91s/it]

Dev loss: 2.367737333902811


 58%|██████████████████████                | 8535/14732 [16:41<12:14,  8.44it/s]

steps 38000


 58%|████████████████████▊               | 8537/14732 [16:50<3:20:38,  1.94s/it]

Dev loss: 2.3664583180527816


 72%|██████████████████████████▍          | 10534/14732 [20:34<07:52,  8.89it/s]

steps 40000


 72%|█████████████████████████          | 10538/14732 [20:43<1:27:23,  1.25s/it]

Dev loss: 2.3597797955452378


 85%|███████████████████████████████▍     | 12535/14732 [24:30<04:20,  8.44it/s]

steps 42000


 85%|█████████████████████████████▊     | 12537/14732 [24:39<1:11:48,  1.96s/it]

Dev loss: 2.3662582026717134


 99%|████████████████████████████████████▌| 14535/14732 [28:23<00:22,  8.66it/s]

steps 44000


 99%|████████████████████████████████████▌| 14538/14732 [28:32<04:52,  1.51s/it]

Dev loss: 2.3658582785310256


100%|█████████████████████████████████████| 14732/14732 [28:54<00:00,  8.49it/s]
 12%|████▋                                 | 1803/14732 [03:26<22:38,  9.52it/s]

steps 46000


 12%|████▍                               | 1805/14732 [03:35<5:57:01,  1.66s/it]

Dev loss: 2.3316217537031196


 26%|█████████▊                            | 3803/14732 [07:20<21:08,  8.61it/s]

steps 48000


 26%|█████████▎                          | 3805/14732 [07:29<5:46:32,  1.90s/it]

Dev loss: 2.4235706993886486


 39%|██████████████▉                       | 5803/14732 [11:14<17:23,  8.55it/s]

steps 50000


 39%|██████████████▏                     | 5805/14732 [11:23<4:50:03,  1.95s/it]

Dev loss: 2.4373007386121306


 53%|████████████████████                  | 7802/14732 [15:09<12:53,  8.96it/s]

steps 52000


 53%|███████████████████                 | 7805/14732 [15:18<2:51:17,  1.48s/it]

Dev loss: 2.436643770533844


 67%|█████████████████████████▎            | 9803/14732 [19:02<09:15,  8.87it/s]

steps 54000


 67%|███████████████████████▉            | 9805/14732 [19:11<2:29:32,  1.82s/it]

Dev loss: 2.431935532897492


 80%|█████████████████████████████▋       | 11803/14732 [22:57<05:17,  9.23it/s]

steps 56000


 80%|████████████████████████████       | 11805/14732 [23:06<1:12:44,  1.49s/it]

Dev loss: 2.4359890920608724


 94%|██████████████████████████████████▋  | 13803/14732 [26:51<01:47,  8.67it/s]

steps 58000


 94%|██████████████████████████████████▋  | 13805/14732 [27:00<29:21,  1.90s/it]

Dev loss: 2.431170636983839


100%|█████████████████████████████████████| 14732/14732 [28:45<00:00,  8.54it/s]
  7%|██▊                                   | 1071/14732 [02:02<26:15,  8.67it/s]

steps 60000


  7%|██▌                                 | 1073/14732 [02:12<7:27:03,  1.96s/it]

Dev loss: 2.3656516860633143


 21%|███████▉                              | 3071/14732 [05:57<22:26,  8.66it/s]

steps 62000


 21%|███████▌                            | 3073/14732 [06:06<6:12:31,  1.92s/it]

Dev loss: 2.5030678698660984


 34%|█████████████                         | 5071/14732 [09:50<18:30,  8.70it/s]

steps 64000


 34%|████████████▍                       | 5073/14732 [09:59<4:59:29,  1.86s/it]

Dev loss: 2.5160546503906436


 48%|██████████████████▏                   | 7071/14732 [13:44<14:51,  8.59it/s]

steps 66000


 48%|█████████████████▎                  | 7073/14732 [13:53<3:59:55,  1.88s/it]

Dev loss: 2.5128355273990586


 62%|███████████████████████▍              | 9071/14732 [17:38<10:37,  8.88it/s]

steps 68000


 62%|██████████████████████▏             | 9073/14732 [17:47<2:48:02,  1.78s/it]

Dev loss: 2.5151401408435663


 75%|███████████████████████████▊         | 11070/14732 [21:33<07:14,  8.43it/s]

steps 70000


 75%|██████████████████████████▎        | 11074/14732 [21:42<1:20:14,  1.32s/it]

Dev loss: 2.5142844481106783


 89%|████████████████████████████████▊    | 13071/14732 [25:26<03:09,  8.77it/s]

steps 72000


 89%|████████████████████████████████▊    | 13073/14732 [25:35<53:31,  1.94s/it]

Dev loss: 2.523418012487276


100%|█████████████████████████████████████| 14732/14732 [28:42<00:00,  8.55it/s]


In [18]:
print('Dev loss:', test(best_model, dev_batches))

Dev loss: 2.518438814758963


In [19]:
best_model.save_pretrained(f'gpt_small_checkpoint_temp{temp}_lr{lr}')

In [None]:
gpt_small = GPT2LMHeadModel.from_pretrained("gpt_small_checkpoint_temp30_lr1e-05")
_ = gpt_small.cuda()

In [20]:
_ = gpt_small.eval()

In [21]:
device = 'cuda'

In [23]:
dialogue = """
John: Hello, how can I help?
""".strip()


print(dialogue)

while True:
    user_input = input()
    dialogue += "\nAlberto: " + user_input + "\nJohn: "
    text = create_text_from_summary_and_dialogue(summary, dialogue)
    prompt = tokenizer(text, return_tensors='pt')
    prompt = {key: value.to(device) for key, value in prompt.items()}
    out = gpt_small.generate(**prompt, max_length=prompt['input_ids'].shape[1] + 25, do_sample=False)
    out = out[0][prompt['input_ids'].shape[1]:]
    answer = tokenizer.decode(out)
    answer = answer[:answer.find('\n')].strip()
    print(answer)
    dialogue += answer

John: Hello, how can I help?
tree


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I'm a customer of Dominos.


KeyboardInterrupt: Interrupted by user

In [None]:
%%time

summary = """
Alberto is a customer. Alberto ordered pizza and lasagne at Dominos one hour ago.
Alberto wants to know where his pizza is. Alberto is calling Dominos to know where his orders are.
John works at Dominos. John says that his pizza is almost ready. It will be at Alberto's home in 10 minutes. 
The lasagne is already on its way with the delivery person. Unfortunately the driver had an accident and is now in the hospital.
If alberto says "rabbit" then John replies "horse". If Alberto says "Tree" John says "red".
""".strip().replace('\n', ' ')

dialogue = """
Alberto: What happened to my pizza?
John: It's in the delivery man's car.
Alberto: And where is the delivery man?
""".strip()

text = create_text_from_summary_and_dialogue(summary, dialogue)

prompt = tokenizer(text, return_tensors='pt').to('cuda')
prompt = {key: value for key, value in prompt.items()}
out = gpt_small.generate(**prompt, max_length=prompt['input_ids'].shape[1] + 10, do_sample=False)
print(tokenizer.decode(out[0]))

In [None]:
gpt_small.save_pretrained('./gptj_small')