In [1]:
import joblib

In [2]:
train_embeddings = joblib.load('train_embeddings_logits_only.joblib')

In [3]:
train_embeddings[0].keys()

dict_keys(['logits_and_indices', 'input_ids'])

### Training from embeddings

In [4]:
import torch
import transformers
import torch.nn.functional as F

from tqdm import tqdm
from transformers import GPT2LMHeadModel, AutoTokenizer

In [5]:
import torch
import transformers
import torch.nn.functional as F

from tqdm import tqdm
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJConfig

config = transformers.GPTJConfig(n_embd=256,
                                 n_layer=28,
                                 n_head=16)

gpt_small = GPTJForCausalLM(config)
gpt_small.cuda()
tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

In [6]:
def create_text_from_summary_and_dialogue(summary, dialogue):
    text = f"""
A partial summary of the conversation is:
{summary}

With the dialogue being:
{dialogue}
    """.strip()
    
    return text.replace('\r\n', '\n')

In [7]:
summary = """
Alberto is a customer. Alberto ordered pizza and lasagne at Dominos one hour ago.
Alberto wants to know where his pizza is. Alberto is calling Dominos to know where his orders are.
John works at Dominos. John says that his pizza is almost ready. It will be at Alberto's home in 10 minutes. 
The lasagne is already on its way with the delivery person. Unfortunately the driver had an accident and is now in the hospital.
If alberto says "rabbit" then John replies "horse". If Alberto says "Tree" John says "red".
""".strip().replace('\n', ' ')

dialogue = """
Alberto: What happened to my pizza?
John: 
""".strip()

text = create_text_from_summary_and_dialogue(summary, dialogue)

prompt = tokenizer(text, return_tensors='pt').to('cuda')
prompt = {key: value for key, value in prompt.items()}

In [8]:
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def batchify(data, n):
    len_dict = {}
    for item in data:
        length = item.shape[1]
        try:
            len_dict[length].append(item)
        except:
            len_dict[length] = [item]

    batch_chunks = []
    for k in len_dict.keys():
        vectors = len_dict[k]
        batch_chunks += chunks(vectors, n)

    batches = []
    for chunk in batch_chunks:
        inputs = torch.stack([item[0] for item in chunk])
        batches.append((inputs))

    return batches

In [9]:
import json

val = json.load(open('../data/val.json'))

_limit = 1024
dev_data = []
total_skipped = 0
for item in val:
    text = create_text_from_summary_and_dialogue(item["summary"], item["dialogue"])
    tokens = tokenizer.encode(text, return_tensors='pt')
    if tokens.shape[1] > _limit:
        tokens = tokens[:, :_limit]
    dev_data.append(tokens)
    
print(f'Skipped {total_skipped} out of {len(val)}')

dev_batches = batchify(dev_data, 1)

def test(test_model, batches):
    test_model.eval()
    total_loss = 0.
    for i, batch in enumerate(batches):
        test_model.eval()
        inputs = batch
        loss = test_model(inputs.cuda(), labels=inputs.cuda())[0]
        total_loss += loss.item()

    return total_loss / len(batches)

Skipped 0 out of 818


In [10]:
_ = gpt_small.cuda()

In [11]:
print('Dev loss:', test(gpt_small, dev_batches))

Dev loss: 10.890254822803302


In [12]:
def get_probability_vector(log_prob_dict, temp):
    _vocab_size = 50400
    
    logits = torch.tensor(log_prob_dict['logits'])
    num_tokens = logits.shape[1]
    indices = torch.tensor(log_prob_dict['indices'])
    vectors = []
    
    for index_set, logs in zip(indices[0], logits[0]):
        v = torch.sparse_coo_tensor([index_set.tolist()], logs, (_vocab_size, )).to_dense().float()
        v[v == 0] = torch.tensor(float('-inf'))
        vectors.append(v)

    vectors = torch.stack(vectors, dim=0)
    return F.softmax(vectors / temp, dim=-1)

In [13]:
get_probability_vector(train_embeddings[0]['logits_and_indices'], temp=10).shape

torch.Size([56, 50400])

In [14]:
import random
from torch.optim.lr_scheduler import StepLR

lr = 1e-4
optimizer = torch.optim.Adam(gpt_small.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=2, gamma=0.5)
epochs = 50

steps = 0

for epoch_num in range(epochs):
    gpt_small.train()
    temp = 30
    random.shuffle(train_embeddings)
    
    for item in tqdm(train_embeddings):
        input_ids = torch.tensor([item['input_ids']]).cuda()
        label_p = get_probability_vector(item['logits_and_indices'], temp=temp).cuda()
        output = gpt_small.forward(input_ids, labels=input_ids)
        out_logits = output.logits
        loss = output.loss
        out_p = F.softmax(out_logits / temp, dim=-1)
        
        loss -=  temp * temp * torch.mean(torch.mul(torch.log(out_p).flatten(),
                                          label_p.flatten()))

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()
        
        steps += 1
        
        if steps % 2000 == 0:
            print("steps", steps)
            print('Dev loss:', test(gpt_small, dev_batches))
            
    scheduler.step()

 14%|█████▏                                | 1999/14732 [04:43<29:45,  7.13it/s]

steps 2000


 14%|████▊                              | 2001/14732 [05:12<21:33:53,  6.10s/it]

Dev loss: 4.69627631555851


 27%|██████████▎                           | 3999/14732 [09:56<26:22,  6.78it/s]

steps 4000


 27%|█████████▌                         | 4001/14732 [10:25<18:18:17,  6.14s/it]

Dev loss: 4.391282944924091


 41%|███████████████▍                      | 5999/14732 [15:09<20:36,  7.06it/s]

steps 6000


 41%|██████████████▎                    | 6001/14732 [15:37<14:54:24,  6.15s/it]

Dev loss: 4.2294482942898295


 54%|████████████████████▋                 | 7999/14732 [20:21<14:58,  7.49it/s]

steps 8000


 54%|███████████████████                | 8001/14732 [20:50<11:29:34,  6.15s/it]

Dev loss: 4.142407263986639


 68%|█████████████████████████▊            | 9999/14732 [25:33<10:35,  7.45it/s]

steps 10000


 68%|███████████████████████▊           | 10001/14732 [26:02<8:01:05,  6.10s/it]

Dev loss: 4.0790380615476876


 81%|██████████████████████████████▏      | 11999/14732 [30:46<06:04,  7.51it/s]

steps 12000


 81%|████████████████████████████▌      | 12001/14732 [31:15<4:36:15,  6.07s/it]

Dev loss: 4.0198832607502455


 95%|███████████████████████████████████▏ | 13999/14732 [35:59<01:47,  6.81it/s]

steps 14000


 95%|█████████████████████████████████▎ | 14001/14732 [36:28<1:14:50,  6.14s/it]

Dev loss: 3.9691521854738734


100%|█████████████████████████████████████| 14732/14732 [38:12<00:00,  6.43it/s]
  9%|███▎                                  | 1267/14732 [03:00<32:26,  6.92it/s]

steps 16000


  9%|███                                | 1269/14732 [03:29<22:56:23,  6.13s/it]

Dev loss: 3.9203810117646944


 22%|████████▍                             | 3267/14732 [08:13<27:40,  6.90it/s]

steps 18000


 22%|███████▊                           | 3269/14732 [08:42<19:35:50,  6.15s/it]

Dev loss: 3.872706271616929


 36%|█████████████▌                        | 5267/14732 [13:27<22:48,  6.92it/s]

steps 20000


 36%|████████████▌                      | 5269/14732 [13:56<16:12:22,  6.17s/it]

Dev loss: 3.8371031150258257


 49%|██████████████████▋                   | 7267/14732 [18:41<18:02,  6.89it/s]

steps 22000


 49%|█████████████████▎                 | 7269/14732 [19:10<12:48:05,  6.18s/it]

Dev loss: 3.817702235831608


 63%|███████████████████████▉              | 9267/14732 [23:54<13:13,  6.88it/s]

steps 24000


 63%|██████████████████████▋             | 9269/14732 [24:23<9:20:28,  6.16s/it]

Dev loss: 3.7863733732029976


 76%|████████████████████████████▎        | 11267/14732 [29:08<08:19,  6.94it/s]

steps 26000


 76%|██████████████████████████▊        | 11269/14732 [29:37<5:51:24,  6.09s/it]

Dev loss: 3.747645101657998


 90%|█████████████████████████████████▎   | 13267/14732 [34:21<03:31,  6.92it/s]

steps 28000


 90%|███████████████████████████████▌   | 13269/14732 [34:50<2:30:02,  6.15s/it]

Dev loss: 3.714235064161436


100%|█████████████████████████████████████| 14732/14732 [38:18<00:00,  6.41it/s]
  4%|█▍                                     | 535/14732 [01:16<35:24,  6.68it/s]

steps 30000


  4%|█▎                                  | 537/14732 [01:44<24:06:53,  6.12s/it]

Dev loss: 3.663079282619551


 17%|██████▌                               | 2535/14732 [06:29<29:02,  7.00it/s]

steps 32000


 17%|██████                             | 2537/14732 [06:58<20:44:52,  6.12s/it]

Dev loss: 3.6485285493738786


 31%|███████████▋                          | 4535/14732 [11:42<23:27,  7.24it/s]

steps 34000


 31%|██████████▊                        | 4537/14732 [12:11<17:26:44,  6.16s/it]

Dev loss: 3.639877280133861


 44%|████████████████▊                     | 6535/14732 [16:55<18:57,  7.20it/s]

steps 36000


 44%|███████████████▌                   | 6537/14732 [17:24<13:52:53,  6.10s/it]

Dev loss: 3.6205853534211156


 58%|██████████████████████                | 8535/14732 [22:08<14:30,  7.12it/s]

steps 38000


 58%|████████████████████▎              | 8537/14732 [22:37<10:35:31,  6.16s/it]

Dev loss: 3.603582900718838


 72%|██████████████████████████▍          | 10535/14732 [27:22<10:14,  6.82it/s]

steps 40000


 72%|█████████████████████████          | 10537/14732 [27:51<7:06:42,  6.10s/it]

Dev loss: 3.598429406038998


 85%|███████████████████████████████▍     | 12535/14732 [32:36<05:11,  7.05it/s]

steps 42000


 85%|█████████████████████████████▊     | 12537/14732 [33:04<3:42:39,  6.09s/it]

Dev loss: 3.5819547250684725


 99%|████████████████████████████████████▌| 14535/14732 [37:47<00:29,  6.78it/s]

steps 44000


 99%|████████████████████████████████████▌| 14537/14732 [38:16<19:59,  6.15s/it]

Dev loss: 3.563764197027771


100%|█████████████████████████████████████| 14732/14732 [38:44<00:00,  6.34it/s]
 12%|████▋                                 | 1803/14732 [04:16<31:42,  6.80it/s]

steps 46000


 12%|████▎                              | 1805/14732 [04:45<22:03:36,  6.14s/it]

Dev loss: 3.579329483520722


 26%|█████████▊                            | 3803/14732 [09:30<25:55,  7.03it/s]

steps 48000


 26%|█████████                          | 3805/14732 [09:58<18:37:39,  6.14s/it]

Dev loss: 3.577633103414969


 39%|██████████████▉                       | 5803/14732 [14:43<21:36,  6.88it/s]

steps 50000


 39%|█████████████▊                     | 5805/14732 [15:12<15:07:47,  6.10s/it]

Dev loss: 3.565684823360303


 53%|████████████████████▏                 | 7803/14732 [19:56<16:31,  6.99it/s]

steps 52000


 53%|██████████████████▌                | 7805/14732 [20:24<11:43:59,  6.10s/it]

Dev loss: 3.556143230651585


 67%|█████████████████████████▎            | 9803/14732 [25:09<10:55,  7.52it/s]

steps 54000


 67%|███████████████████████▉            | 9805/14732 [25:38<8:26:54,  6.17s/it]

Dev loss: 3.5513140897296167


 80%|█████████████████████████████▋       | 11803/14732 [30:24<07:05,  6.89it/s]

steps 56000


 80%|████████████████████████████       | 11805/14732 [30:54<5:04:21,  6.24s/it]

Dev loss: 3.521962439372662


 94%|██████████████████████████████████▋  | 13803/14732 [35:35<02:13,  6.94it/s]

steps 58000


 94%|████████████████████████████████▊  | 13805/14732 [36:04<1:35:12,  6.16s/it]

Dev loss: 3.5210172780277094


100%|█████████████████████████████████████| 14732/14732 [38:16<00:00,  6.41it/s]
  7%|██▊                                   | 1071/14732 [02:31<32:40,  6.97it/s]

steps 60000


  7%|██▌                                | 1073/14732 [03:00<23:23:13,  6.16s/it]

Dev loss: 3.5286517893772547


 21%|███████▉                              | 3071/14732 [07:45<27:38,  7.03it/s]

steps 62000


 21%|███████▎                           | 3073/14732 [08:14<19:54:40,  6.15s/it]

Dev loss: 3.5322389334513096


 34%|█████████████                         | 5071/14732 [12:58<23:04,  6.98it/s]

steps 64000


 34%|████████████                       | 5073/14732 [13:27<16:19:36,  6.09s/it]

Dev loss: 3.5466338624872615


 48%|██████████████████▏                   | 7071/14732 [18:13<18:30,  6.90it/s]

steps 66000


 48%|████████████████▊                  | 7073/14732 [18:42<13:10:14,  6.19s/it]

Dev loss: 3.5310269849224603


 62%|███████████████████████▍              | 9071/14732 [23:41<13:46,  6.85it/s]

steps 68000


 62%|█████████████████████▌             | 9073/14732 [24:12<10:06:56,  6.44s/it]

Dev loss: 3.5400714648382006


 75%|███████████████████████████▊         | 11071/14732 [28:58<09:06,  6.69it/s]

steps 70000


 75%|██████████████████████████▎        | 11073/14732 [29:29<6:32:18,  6.43s/it]

Dev loss: 3.525749698973518


 89%|████████████████████████████████▊    | 13071/14732 [34:15<04:03,  6.83it/s]

steps 72000


 89%|███████████████████████████████    | 13073/14732 [34:44<2:49:16,  6.12s/it]

Dev loss: 3.528985671571531


100%|█████████████████████████████████████| 14732/14732 [38:44<00:00,  6.34it/s]
  2%|▉                                      | 339/14732 [00:47<33:26,  7.17it/s]

steps 74000


  2%|▊                                   | 341/14732 [01:16<24:17:29,  6.08s/it]

Dev loss: 3.543033439547625


 16%|██████                                | 2339/14732 [06:00<30:10,  6.85it/s]

steps 76000


 16%|█████▌                             | 2341/14732 [06:30<22:22:48,  6.50s/it]

Dev loss: 3.5690258428053636


 29%|███████████▏                          | 4339/14732 [11:16<26:16,  6.59it/s]

steps 78000


 29%|██████████▎                        | 4341/14732 [11:46<18:19:08,  6.35s/it]

Dev loss: 3.5739358602351254


 43%|████████████████▎                     | 6339/14732 [16:33<20:16,  6.90it/s]

steps 80000


 43%|███████████████                    | 6341/14732 [17:01<14:12:18,  6.09s/it]

Dev loss: 3.5666958659085783


 57%|█████████████████████▌                | 8339/14732 [21:46<15:26,  6.90it/s]

steps 82000


 57%|███████████████████▊               | 8341/14732 [22:15<10:50:32,  6.11s/it]

Dev loss: 3.555265073962783


 70%|█████████████████████████▉           | 10339/14732 [27:14<11:09,  6.56it/s]

steps 84000


 70%|████████████████████████▌          | 10341/14732 [27:45<7:59:34,  6.55s/it]

Dev loss: 3.5492198190362645


 84%|██████████████████████████████▉      | 12339/14732 [32:49<06:09,  6.48it/s]

steps 86000


 84%|█████████████████████████████▎     | 12341/14732 [33:19<4:16:03,  6.43s/it]

Dev loss: 3.5577098438675656


 97%|████████████████████████████████████ | 14339/14732 [38:16<01:06,  5.88it/s]

steps 88000


 97%|████████████████████████████████████ | 14341/14732 [38:45<40:53,  6.27s/it]

Dev loss: 3.555036128150222


100%|█████████████████████████████████████| 14732/14732 [39:42<00:00,  6.18it/s]
 11%|████▏                                 | 1607/14732 [03:49<30:39,  7.14it/s]

steps 90000


 11%|███▊                               | 1609/14732 [04:18<22:18:28,  6.12s/it]

Dev loss: 3.5795010791722603


 24%|█████████▎                            | 3607/14732 [09:09<26:29,  7.00it/s]

steps 92000


 24%|████████▌                          | 3609/14732 [09:38<19:14:11,  6.23s/it]

Dev loss: 3.590947471504398


 38%|██████████████▍                       | 5607/14732 [14:23<22:01,  6.90it/s]

steps 94000


 38%|█████████████▎                     | 5609/14732 [14:52<15:33:19,  6.14s/it]

Dev loss: 3.6075165229496573


 52%|███████████████████▌                  | 7607/14732 [19:36<16:42,  7.11it/s]

steps 96000


 52%|██████████████████                 | 7609/14732 [20:05<12:02:47,  6.09s/it]

Dev loss: 3.5982551297815037


 65%|████████████████████████▊             | 9607/14732 [24:51<12:10,  7.01it/s]

steps 98000


 65%|███████████████████████▍            | 9609/14732 [25:20<8:46:09,  6.16s/it]

Dev loss: 3.605612377374271


 79%|█████████████████████████████▏       | 11607/14732 [30:05<06:57,  7.49it/s]

steps 100000


 79%|███████████████████████████▌       | 11609/14732 [30:34<5:19:06,  6.13s/it]

Dev loss: 3.6022543947970664


 92%|██████████████████████████████████▏  | 13607/14732 [35:19<02:38,  7.11it/s]

steps 102000


 92%|████████████████████████████████▎  | 13609/14732 [35:47<1:54:24,  6.11s/it]

Dev loss: 3.6092896942404487


100%|█████████████████████████████████████| 14732/14732 [38:28<00:00,  6.38it/s]
  6%|██▎                                    | 875/14732 [02:04<33:26,  6.91it/s]

steps 104000


  6%|██▏                                 | 877/14732 [02:33<23:39:29,  6.15s/it]

Dev loss: 3.6287426971864583


 20%|███████▍                              | 2875/14732 [07:17<28:50,  6.85it/s]

steps 106000


 20%|██████▊                            | 2877/14732 [07:46<20:13:02,  6.14s/it]

Dev loss: 3.6495630286433585


 33%|████████████▌                         | 4875/14732 [12:31<23:50,  6.89it/s]

steps 108000


 33%|███████████▌                       | 4877/14732 [13:00<16:40:02,  6.09s/it]

Dev loss: 3.65369100442434


 47%|█████████████████▋                    | 6875/14732 [17:45<19:18,  6.78it/s]

steps 110000


 47%|████████████████▎                  | 6877/14732 [18:14<13:19:51,  6.11s/it]

Dev loss: 3.6616971571462953


 60%|██████████████████████▉               | 8875/14732 [23:07<14:52,  6.56it/s]

steps 112000


 60%|█████████████████████              | 8877/14732 [23:38<10:38:12,  6.54s/it]

Dev loss: 3.6555186304603056


 74%|███████████████████████████▎         | 10875/14732 [28:26<09:00,  7.13it/s]

steps 114000


 74%|█████████████████████████▊         | 10877/14732 [28:54<6:34:01,  6.13s/it]

Dev loss: 3.652967402725173


 87%|████████████████████████████████▎    | 12875/14732 [33:38<04:31,  6.84it/s]

steps 116000


 87%|██████████████████████████████▌    | 12877/14732 [34:07<3:09:35,  6.13s/it]

Dev loss: 3.6461047454393287


100%|█████████████████████████████████████| 14732/14732 [38:31<00:00,  6.37it/s]
  1%|▍                                      | 143/14732 [00:20<35:26,  6.86it/s]

steps 118000


  1%|▎                                   | 145/14732 [00:49<24:51:23,  6.13s/it]

Dev loss: 3.6541491539962134


 15%|█████▌                                | 2143/14732 [05:33<31:18,  6.70it/s]

steps 120000


 15%|█████                              | 2145/14732 [06:02<21:27:22,  6.14s/it]

Dev loss: 3.675419067112333


 28%|██████████▋                           | 4143/14732 [10:46<24:54,  7.09it/s]

steps 122000


 28%|█████████▊                         | 4145/14732 [11:15<18:00:59,  6.13s/it]

Dev loss: 3.694505612104038


 42%|███████████████▊                      | 6143/14732 [15:59<19:35,  7.31it/s]

steps 124000


 42%|██████████████▌                    | 6145/14732 [16:28<14:27:56,  6.06s/it]

Dev loss: 3.696727312135813


 55%|█████████████████████                 | 8143/14732 [21:12<15:32,  7.07it/s]

steps 126000


 55%|███████████████████▎               | 8145/14732 [21:41<11:19:13,  6.19s/it]

Dev loss: 3.701027455714629


 69%|█████████████████████████▍           | 10143/14732 [26:26<10:46,  7.10it/s]

steps 128000


 69%|████████████████████████           | 10145/14732 [26:54<7:44:52,  6.08s/it]

Dev loss: 3.701933394580132


 82%|██████████████████████████████▍      | 12143/14732 [31:37<06:06,  7.07it/s]

steps 130000


 82%|████████████████████████████▊      | 12145/14732 [32:06<4:22:14,  6.08s/it]

Dev loss: 3.7095223975647924


 96%|███████████████████████████████████▌ | 14143/14732 [36:50<01:25,  6.88it/s]

steps 132000


 96%|█████████████████████████████████▌ | 14145/14732 [37:19<1:00:05,  6.14s/it]

Dev loss: 3.71284573527012


100%|█████████████████████████████████████| 14732/14732 [38:42<00:00,  6.34it/s]
 10%|███▋                                  | 1411/14732 [03:21<33:15,  6.68it/s]

steps 134000


 10%|███▎                               | 1413/14732 [03:49<22:42:05,  6.14s/it]

Dev loss: 3.73190962379311


 23%|████████▊                             | 3411/14732 [08:34<25:37,  7.36it/s]

steps 136000


 23%|████████                           | 3413/14732 [09:03<19:10:04,  6.10s/it]

Dev loss: 3.73450775980075


 37%|█████████████▉                        | 5411/14732 [13:47<21:29,  7.23it/s]

steps 138000


 37%|████████████▊                      | 5413/14732 [14:15<15:52:52,  6.14s/it]

Dev loss: 3.7475992301274044


 50%|███████████████████                   | 7411/14732 [18:59<16:57,  7.20it/s]

steps 140000


 50%|█████████████████▌                 | 7413/14732 [19:28<12:30:41,  6.15s/it]

Dev loss: 3.751663234822616


 52%|███████████████████▊                  | 7680/14732 [20:06<18:28,  6.36it/s]


KeyboardInterrupt: 

In [15]:
_ = gpt_small.eval()

In [16]:
device = 'cuda'

In [17]:
dialogue = """
John: Hello, how can I help?
""".strip()


print(dialogue)

while True:
    user_input = input()
    dialogue += "\nAlberto: " + user_input + "\nJohn: "
    text = create_text_from_summary_and_dialogue(summary, dialogue)
    prompt = tokenizer(text, return_tensors='pt')
    prompt = {key: value.to(device) for key, value in prompt.items()}
    out = gpt_small.generate(**prompt, max_length=prompt['input_ids'].shape[1] + 25, do_sample=False)
    out = out[0][prompt['input_ids'].shape[1]:]
    answer = tokenizer.decode(out)
    answer = answer[:answer.find('\n')].strip()
    print(answer)
    dialogue += answer

John: Hello, how can I help?
hello what did I order


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


I'm at home already, but the last time is?


KeyboardInterrupt: Interrupted by user

In [None]:
%%time

summary = """
Alberto is a customer. Alberto ordered pizza and lasagne at Dominos one hour ago.
Alberto wants to know where his pizza is. Alberto is calling Dominos to know where his orders are.
John works at Dominos. John says that his pizza is almost ready. It will be at Alberto's home in 10 minutes. 
The lasagne is already on its way with the delivery person. Unfortunately the driver had an accident and is now in the hospital.
If alberto says "rabbit" then John replies "horse". If Alberto says "Tree" John says "red".
""".strip().replace('\n', ' ')

dialogue = """
Alberto: What happened to my pizza?
John: It's in the delivery man's car.
Alberto: And where is the delivery man?
""".strip()

text = create_text_from_summary_and_dialogue(summary, dialogue)

prompt = tokenizer(text, return_tensors='pt').to('cuda')
prompt = {key: value for key, value in prompt.items()}
out = gpt_small.generate(**prompt, max_length=prompt['input_ids'].shape[1] + 10, do_sample=False)
print(tokenizer.decode(out[0]))

In [None]:
gpt_small.save_pretrained('./gptj_small')