In [1]:
import torch
import os
from torch._C import device
from torch.utils.data import Dataset, DataLoader
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AdamW,
    get_linear_schedule_with_warmup,
)
import torch.nn.functional as F
from tqdm import tqdm, trange

In [None]:
def generate(
    model,
    tokenizer,
    prompt,
    entry_count=10,
    entry_length=100,
    top_p=0.8,
    temperature=1.0,
):
    """Sample ``entry_count`` continuations of ``prompt`` using top-p (nucleus) sampling.

    Each entry starts from the encoded prompt and grows one sampled token at a
    time until either the full ``"<|EOS|>"`` marker has been generated or
    ``entry_length`` tokens have been added.

    Args:
        model: Callable language model. ``model(input_ids)[0]`` must be the
            logits tensor of shape ``(batch, sequence, vocab)``.
        tokenizer: Object providing ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        prompt: Text every entry starts from.
        entry_count: Number of independent entries to sample.
        entry_length: Maximum number of tokens sampled per entry.
        top_p: Cumulative-probability threshold of the nucleus. The most likely
            token is always kept.
        temperature: Divisor applied to the logits before sampling. Values
            ``<= 0`` are treated as ``1.0``.

    Returns:
        A list of ``entry_count`` decoded strings (prompt included). Entries
        that hit ``entry_length`` before producing the end marker get
        ``"<|EOS|>"`` appended.

    Raises:
        ValueError: If ``prompt`` encodes to no tokens.
    """
    model.eval()

    # Build the inputs on the device the model lives on, so the function works
    # for CPU and GPU models alike.
    first_param = next(model.parameters(), None)
    model_device = (
        first_param.device if first_param is not None else torch.device("cpu")
    )

    # Both encodings are loop-invariant, so compute them once.
    prompt_ids = list(tokenizer.encode(prompt))
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    prompt_len = len(prompt_ids)

    # A tokenizer without a dedicated special token splits the marker into
    # several ids. The entry is finished only when the WHOLE id sequence has
    # been generated, not when any single fragment of it is sampled.
    eos_ids = list(tokenizer.encode("<|EOS|>"))
    eos_len = len(eos_ids)

    scale = temperature if temperature > 0 else 1.0
    filter_value = -float("Inf")
    generated_list = []

    with torch.no_grad():

        for _ in trange(entry_count):

            entry_finished = False
            generated = torch.tensor(
                prompt_ids, dtype=torch.long, device=model_device
            ).unsqueeze(0)

            # Using top-p (nucleus sampling): https://github.com/huggingface/transformers/blob/master/examples/run_generation.py

            for _ in range(entry_length):
                # No labels are passed: the loss is not needed for sampling.
                logits = model(generated)[0]
                logits = logits[:, -1, :] / scale

                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                # Drop tokens once the cumulative mass exceeds top_p. The mask
                # is shifted right by one so the token that crosses the
                # threshold is kept, and the top token is never dropped.
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = False

                # Map the mask from sorted order back to vocabulary order
                # row by row (correct for any batch size).
                indices_to_remove = sorted_indices_to_remove.scatter(
                    -1, sorted_indices, sorted_indices_to_remove
                )
                logits = logits.masked_fill(indices_to_remove, filter_value)

                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                # Only newly generated tokens can complete the end marker.
                if eos_len:
                    tail = generated[0, prompt_len:][-eos_len:].tolist()
                    if tail == eos_ids:
                        entry_finished = True
                        break

            # .tolist() works on any device and yields plain Python ints.
            output_text = tokenizer.decode(generated[0].tolist())
            if not entry_finished:
                output_text = f"{output_text}<|EOS|>"
            generated_list.append(output_text)

    return generated_list

In [7]:
# Build the GPT-2 LM-head model from the pretrained 'gpt2' weights, then replace
# its parameters with the state dict stored under ../trained_models.
# map_location remaps the checkpoint's tensors onto the CPU while loading.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.load_state_dict(
    torch.load(
        '../trained_models/medtext-final.pt',
        map_location=torch.device('cpu'),
    )
)

<All keys matched successfully>

In [8]:
# Load the stock 'gpt2' tokenizer; no extra special tokens are registered in this cell.
# NOTE(review): markers such as "<|BOS|>" / "<|EOS|>" are therefore presumably split into
# several ordinary BPE ids — the recorded output of the generation cell shows a six-id
# tensor per entry, which looks like the encoded "<|BOS|>" prompt. TODO confirm.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

In [9]:
# Sample ten entries on the CPU, starting each one from the "<|BOS|>" prompt.
generated_comments = generate(
    model.to('cpu'),
    tokenizer,
    "<|BOS|>",
    entry_count=10
)

# Create the output directory first so a fresh checkout does not fail with
# FileNotFoundError, and pin the encoding so the written bytes do not depend on
# the platform's default locale.
os.makedirs("results", exist_ok=True)
with open("results/generated_comments.txt", "w", encoding="utf-8") as file:
    file.write("\n".join(generated_comments))

  0%|          | 0/10 [00:00<?, ?it/s]tensor([[  27,   91,   33, 2640,   91,   29]])
 10%|█         | 1/10 [00:18<02:48, 18.73s/it]tensor([[  27,   91,   33, 2640,   91,   29]])
 20%|██        | 2/10 [00:41<02:50, 21.31s/it]tensor([[  27,   91,   33, 2640,   91,   29]])
 30%|███       | 3/10 [01:09<02:49, 24.26s/it]tensor([[  27,   91,   33, 2640,   91,   29]])
 40%|████      | 4/10 [01:46<02:55, 29.24s/it]tensor([[  27,   91,   33, 2640,   91,   29]])
 50%|█████     | 5/10 [02:20<02:33, 30.80s/it]tensor([[  27,   91,   33, 2640,   91,   29]])
 60%|██████    | 6/10 [02:52<02:05, 31.29s/it]tensor([[  27,   91,   33, 2640,   91,   29]])
tensor([[  27,   91,   33, 2640,   91,   29]])
 70%|███████   | 7/10 [03:19<01:25, 28.46s/it]


KeyboardInterrupt: 