In [1]:
import joblib

In [2]:
train_embeddings = joblib.load('train_embeddings_logits_only.joblib')

In [3]:
train_embeddings[0].keys()

dict_keys(['logits_and_indices', 'input_ids'])

### Training from embeddings

In [4]:
import torch
import transformers
import torch.nn.functional as F

from tqdm import tqdm
from transformers import GPT2LMHeadModel, AutoTokenizer

In [5]:
import torch
import transformers
import torch.nn.functional as F

from tqdm import tqdm
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJConfig

config = transformers.GPTJConfig(n_embd=128,
                                 n_layer=6,
                                 n_head=4)

gpt_small = GPTJForCausalLM(config)
gpt_small.cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

In [6]:
def create_text_from_summary_and_dialogue(summary, dialogue):
    text = f"""
A partial summary of the conversation is:
{summary}

With the dialogue being:
{dialogue}
    """.strip()
    
    return text.replace('\r\n', '\n')

In [7]:
summary = """
Alberto is a customer. Alberto ordered pizza and lasagne at Dominos one hour ago.
Alberto wants to know where his pizza is. Alberto is calling Dominos to know where his orders are.
John works at Dominos. John says that his pizza is almost ready. It will be at Alberto's home in 10 minutes. 
The lasagne is already on its way with the delivery person. Unfortunately the driver had an accident and is now in the hospital.
If alberto says "rabbit" then John replies "horse". If Alberto says "Tree" John says "red".
""".strip().replace('\n', ' ')

dialogue = """
Alberto: What happened to my pizza?
John: 
""".strip()

text = create_text_from_summary_and_dialogue(summary, dialogue)

prompt = tokenizer(text, return_tensors='pt').to('cuda')
prompt = {key: value for key, value in prompt.items()}

In [8]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [9]:
import json

val = json.load(open('../data/val.json'))

_limit = 1024
dev_data = []
total_skipped = 0
for item in val:
    text = create_text_from_summary_and_dialogue(item["summary"], item["dialogue"])
    tokens = tokenizer.encode(text, return_tensors='pt')
    if tokens.shape[1] > _limit:
        tokens = tokens[:, :_limit]
    dev_data.append(tokens)
    
print(f'Skipped {total_skipped} out of {len(val)}')

dev_batches = batchify(dev_data, 1)

def test(test_model, batches):
    test_model.eval()
    total_loss = 0.
    for i, batch in enumerate(batches):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs.cuda(), labels=inputs.cuda())[0]
        total_loss += loss.item()

    return total_loss / len(batches)

Skipped 0 out of 818


In [10]:
_ = gpt_small.cuda()

In [11]:
print('Dev loss:', test(gpt_small, dev_batches))

Dev loss: 10.836287398210073


In [12]:
def get_probability_vector(log_prob_dict, temp):
    _vocab_size = 50400
    
    logits = torch.tensor(log_prob_dict['logits'])
    num_tokens = logits.shape[1]
    indices = torch.tensor(log_prob_dict['indices'])
    vectors = []
    
    for index_set, logs in zip(indices[0], logits[0]):
        v = torch.sparse_coo_tensor([index_set.tolist()], logs, (_vocab_size, )).to_dense().float()
        v[v == 0] = torch.tensor(float('-inf'))
        vectors.append(v)

    vectors = torch.stack(vectors, dim=0)
    return F.softmax(vectors / temp, dim=-1)

In [13]:
get_probability_vector(train_embeddings[0]['logits_and_indices'], temp=10).shape

torch.Size([56, 50400])

In [14]:
import random
from torch.optim.lr_scheduler import StepLR

lr = 1e-4
optimizer = torch.optim.Adam(gpt_small.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)
epochs = 50

steps = 0

for epoch_num in range(epochs):
    gpt_small.train()
    temp = 30
    random.shuffle(train_embeddings)
    
    for item in tqdm(train_embeddings):
        input_ids = torch.tensor([item['input_ids']]).cuda()
        label_p = get_probability_vector(item['logits_and_indices'], temp=temp).cuda()
        out_logits = gpt_small.forward(input_ids).logits
        out_p = F.softmax(out_logits / temp, dim=-1)
        
        loss = gpt_small(input_ids, labels=input_ids)[0]
        
        loss -=  temp * temp * torch.mean(torch.mul(torch.log(out_p).flatten(),
                                          label_p.flatten()))

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        
        steps += 1
        
        if steps % 2000 == 0:
            print("steps", steps)
            print('Dev loss:', test(gpt_small, dev_batches))
            
    scheduler.step()

 14%|█████▏                                | 1998/14732 [02:34<16:11, 13.11it/s]

steps 2000


 14%|████▉                               | 2002/14732 [02:42<3:00:45,  1.17it/s]

Dev loss: 4.938633865715531


 27%|██████████▎                           | 3998/14732 [05:17<14:07, 12.67it/s]

steps 4000


 27%|█████████▊                          | 4002/14732 [05:25<2:30:54,  1.19it/s]

Dev loss: 4.553118632591732


 41%|███████████████▍                      | 5998/14732 [07:59<11:04, 13.14it/s]

steps 6000


 41%|██████████████▋                     | 6002/14732 [08:07<2:03:29,  1.18it/s]

Dev loss: 4.353118068723049


 54%|████████████████████▋                 | 7998/14732 [10:41<08:15, 13.60it/s]

steps 8000


 54%|███████████████████▌                | 8002/14732 [10:49<1:35:10,  1.18it/s]

Dev loss: 4.227549059758268


 68%|█████████████████████████▊            | 9998/14732 [13:23<06:05, 12.96it/s]

steps 10000


 68%|███████████████████████▊           | 10002/14732 [13:31<1:07:01,  1.18it/s]

Dev loss: 4.135350516489491


 81%|██████████████████████████████▏      | 11998/14732 [16:05<03:44, 12.19it/s]

steps 12000


 81%|██████████████████████████████▏      | 12002/14732 [16:13<40:22,  1.13it/s]

Dev loss: 4.0844615257748185


 95%|███████████████████████████████████▏ | 13999/14732 [18:57<00:58, 12.44it/s]

steps 14000


 95%|███████████████████████████████████▏ | 14001/14732 [19:05<15:50,  1.30s/it]

Dev loss: 4.034545751539011


100%|█████████████████████████████████████| 14732/14732 [20:07<00:00, 12.20it/s]
  9%|███▎                                  | 1266/14732 [01:47<17:39, 12.71it/s]

steps 16000


  9%|███                                 | 1270/14732 [01:56<3:24:35,  1.10it/s]

Dev loss: 4.00626885686935


 22%|████████▍                             | 3266/14732 [04:36<15:41, 12.18it/s]

steps 18000


 22%|███████▉                            | 3270/14732 [04:45<2:56:47,  1.08it/s]

Dev loss: 3.969237920998944


 36%|█████████████▌                        | 5267/14732 [07:23<12:04, 13.06it/s]

steps 20000


 36%|████████████▉                       | 5269/14732 [07:30<3:04:11,  1.17s/it]

Dev loss: 3.9378377671346687


 49%|██████████████████▋                   | 7267/14732 [10:04<09:28, 13.12it/s]

steps 22000


 49%|█████████████████▊                  | 7269/14732 [10:12<2:24:13,  1.16s/it]

Dev loss: 3.8997517847490193


 63%|███████████████████████▉              | 9267/14732 [12:47<06:35, 13.81it/s]

steps 24000


 63%|██████████████████████▋             | 9269/14732 [12:54<1:45:56,  1.16s/it]

Dev loss: 3.8704656216800943


 76%|████████████████████████████▎        | 11267/14732 [15:28<04:34, 12.64it/s]

steps 26000


 76%|██████████████████████████▊        | 11269/14732 [15:36<1:07:11,  1.16s/it]

Dev loss: 3.8390351332778745


 90%|█████████████████████████████████▎   | 13267/14732 [18:11<01:53, 12.92it/s]

steps 28000


 90%|█████████████████████████████████▎   | 13269/14732 [18:18<29:22,  1.20s/it]

Dev loss: 3.8119348023806343


100%|█████████████████████████████████████| 14732/14732 [20:11<00:00, 12.16it/s]
  4%|█▍                                     | 534/14732 [00:41<19:25, 12.18it/s]

steps 30000


  4%|█▎                                   | 538/14732 [00:48<3:21:06,  1.18it/s]

Dev loss: 3.775694073558145


 17%|██████▌                               | 2534/14732 [03:23<15:35, 13.04it/s]

steps 32000


 17%|██████▏                             | 2538/14732 [03:30<2:50:19,  1.19it/s]

Dev loss: 3.7674856097016183


 31%|███████████▋                          | 4534/14732 [06:04<13:26, 12.64it/s]

steps 34000


 31%|███████████                         | 4538/14732 [06:12<2:25:39,  1.17it/s]

Dev loss: 3.7380736578064617


 44%|████████████████▊                     | 6534/14732 [08:46<10:59, 12.44it/s]

steps 36000


 44%|███████████████▉                    | 6538/14732 [08:54<1:54:40,  1.19it/s]

Dev loss: 3.7151037230176858


 58%|██████████████████████                | 8534/14732 [11:29<08:09, 12.65it/s]

steps 38000


 58%|████████████████████▊               | 8538/14732 [11:37<1:27:33,  1.18it/s]

Dev loss: 3.6944738814475193


 59%|██████████████████████▌               | 8745/14732 [11:53<08:08, 12.26it/s]


KeyboardInterrupt: 

In [15]:
_ = gpt_small.eval()

In [16]:
device = 'cuda'

In [18]:
dialogue = """
John: Hello, how can I help?
""".strip()


print(dialogue)

while True:
    user_input = input()
    dialogue += "\nAlberto: " + user_input + "\nJohn: "
    text = create_text_from_summary_and_dialogue(summary, dialogue)
    prompt = tokenizer(text, return_tensors='pt')
    prompt = {key: value.to(device) for key, value in prompt.items()}
    out = gpt_small.generate(**prompt, max_length=prompt['input_ids'].shape[1] + 25, do_sample=False)
    out = out[0][prompt['input_ids'].shape[1]:]
    answer = tokenizer.decode(out)
    answer = answer[:answer.find('\n')].strip()
    print(answer)
    dialogue += answer

John: Hello, how can I help?
hello what did I order


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I'm not going to the cinema with the last time, I'm not sure, but I'm not sure, but


KeyboardInterrupt: Interrupted by user

In [None]:
%%time

summary = """
Alberto is a customer. Alberto ordered pizza and lasagne at Dominos one hour ago.
Alberto wants to know where his pizza is. Alberto is calling Dominos to know where his orders are.
John works at Dominos. John says that his pizza is almost ready. It will be at Alberto's home in 10 minutes. 
The lasagne is already on its way with the delivery person. Unfortunately the driver had an accident and is now in the hospital.
If alberto says "rabbit" then John replies "horse". If Alberto says "Tree" John says "red".
""".strip().replace('\n', ' ')

dialogue = """
Alberto: What happened to my pizza?
John: It's in the delivery man's car.
Alberto: And where is the delivery man?
""".strip()

text = create_text_from_summary_and_dialogue(summary, dialogue)

prompt = tokenizer(text, return_tensors='pt').to('cuda')
prompt = {key: value for key, value in prompt.items()}
out = gpt_small.generate(**prompt, max_length=prompt['input_ids'].shape[1] + 10, do_sample=False)
print(tokenizer.decode(out[0]))

In [None]:
gpt_small.save_pretrained('./gptj_small')