In [1]:
import torch
import os
from torch._C import device
from torch.utils.data import Dataset, DataLoader
from transformers import (
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AdamW,
    get_linear_schedule_with_warmup,
)
import torch.nn.functional as F
from tqdm import tqdm, trange

In [40]:
def _top_p_filter(logits, top_p, filter_value=-float("Inf")):
    """Mask every logit outside the nucleus (top-p set) of each row.

    Args:
        logits: (batch, vocab) tensor of next-token logits.
        top_p: cumulative-probability cutoff; the most likely token is always kept.
        filter_value: value written into masked positions.

    Returns:
        A new tensor with the same shape as ``logits``; the input is not modified.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    # Shift the mask right by one so the token that crosses the threshold is kept,
    # and never drop the single most likely token.
    sorted_to_remove = cumulative_probs > top_p
    sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
    sorted_to_remove[..., 0] = False

    # Map the mask from sorted order back to vocabulary order, row by row.
    to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
    return logits.masked_fill(to_remove, filter_value)


def generate(
    model,
    tokenizer,
    prompt,
    entry_count=10,
    entry_length=100,
    top_p=0.8,
    temperature=1.0,
    eos_token="<|EOS|>",
):
    """Sample ``entry_count`` continuations of ``prompt`` with nucleus (top-p) sampling.

    Each entry grows one token at a time for at most ``entry_length`` new tokens.
    It stops early once the newly generated ids end with the full encoding of
    ``eos_token``.

    Args:
        model: causal LM called as ``model(input_ids)``; element 0 of its output is
            used as the logits, assumed (batch, seq, vocab) — TODO confirm for
            non-Hugging-Face models. If it exposes ``.device`` (as Hugging Face
            models do), inputs are created on that device.
        tokenizer: object with ``encode(str) -> list[int]`` and
            ``decode(list[int]) -> str``.
        prompt: text generation starts from; must encode to at least one token.
        entry_count: number of independent samples to draw.
        entry_length: maximum number of new tokens per sample.
        top_p: cumulative-probability cutoff for nucleus sampling.
        temperature: logits divisor; values <= 0 fall back to 1.0.
        eos_token: end-of-entry marker; may encode to several BPE pieces.

    Returns:
        list[str]: one decoded string per entry, prompt included. Entries that reach
        ``entry_length`` without producing ``eos_token`` get it appended.

    Raises:
        ValueError: if ``prompt`` encodes to no tokens.
    """
    model.eval()

    prompt_ids = list(tokenizer.encode(prompt))
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")

    # The marker may be split into several BPE pieces (e.g. when it was never
    # registered as a special token). Compare the whole id sequence at the tail.
    # Testing single-token membership would stop on any one piece.
    eos_ids = list(tokenizer.encode(eos_token))
    device = getattr(model, "device", None)

    generated_list = []

    with torch.no_grad():
        for _ in trange(entry_count):
            generated = torch.tensor([prompt_ids], dtype=torch.long, device=device)
            entry_finished = False

            # Nucleus sampling as in the Hugging Face run_generation.py example.
            for _ in range(entry_length):
                # No `labels=`: the loss was never used, so computing it is wasted work.
                logits = model(generated)[0]
                next_logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                next_logits = _top_p_filter(next_logits, top_p)

                next_token = torch.multinomial(F.softmax(next_logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                # Only look at newly generated ids so the prompt cannot complete the marker.
                new_ids = generated[0, len(prompt_ids):].tolist()
                if eos_ids and new_ids[-len(eos_ids):] == eos_ids:
                    entry_finished = True
                    break

            # .tolist() works for CPU and accelerator tensors alike.
            text = tokenizer.decode(generated[0].tolist())
            generated_list.append(text if entry_finished else f"{text}{eos_token}")

    return generated_list

In [6]:
# Build the stock GPT-2 architecture, then overwrite its weights with the fine-tuned
# checkpoint, mapped onto the CPU.
# NOTE(review): torch.load unpickles the file — only point this at trusted checkpoints.
model = GPT2LMHeadModel.from_pretrained('gpt2')
checkpoint_state = torch.load(
    '../trained_models/medtext-final.pt',
    map_location=torch.device('cpu'),
)
model.load_state_dict(checkpoint_state)

<All keys matched successfully>

In [7]:
# Stock GPT-2 BPE tokenizer. This cell adds no special tokens, so markers such as
# "<|BOS|>" are tokenized as ordinary text.
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path='gpt2')

In [44]:
# Draw 100 samples on CPU, each starting from the "<|BOS|>" marker.
generated_comments = generate(
    model.to('cpu'),
    tokenizer,
    "<|BOS|>",
    entry_count=100
)

# Save the samples joined by newlines. Use explicit UTF-8 so that non-ASCII model
# output does not depend on the platform's default encoding. Create the results
# folder so the cell also works on a fresh checkout.
RESULTS_PATH = "../results/generated_comments.txt"
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)
with open(RESULTS_PATH, "w", encoding="utf-8") as file:
    file.write("\n".join(generated_comments))

100%|██████████| 100/100 [15:47<00:00,  9.48s/it]


In [36]:
# Put the prompt on whatever device the model lives on instead of hard-coding
# 'cuda'. The model is loaded with map_location=CPU and moved to CPU earlier, so a
# fixed 'cuda' input fails with a device mismatch on Restart & Run All.
input_ids = tokenizer.encode('<|BOS|>', return_tensors='pt').to(model.device)

# Three independent sampled sequences. pad_token_id is passed explicitly with the
# value generate() otherwise falls back to (eos_token_id), which silences the
# per-call warning.
sentences = [
    model.generate(
        input_ids=input_ids,
        do_sample=True,
        max_length=150,
        top_k=20,
        top_p=0.92,
        pad_token_id=model.config.eos_token_id,
    )
    for _ in range(3)
]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [37]:
# Each entry of `sentences` is a batch of token-id rows returned by model.generate().
# Decode every row of every batch; the displayed output shows one string per batch.
decoded = []
for batch in sentences:
    decoded.append([tokenizer.decode(row) for row in batch])

In [38]:
# Display the second sample (index 1): a list holding its decoded string(s).
decoded[1]

["<|BOS|>Anthropogenic susceptibility of the skin with a p53-encoded gene. The p53-encoded variant of the P53A gene (PRK) is known to be a potent inhibitor of the skin's natural immune response. This is demonstrated by studies involving an 11-year-old girl and a patient with an immunosuppressed skin. In this study, we examined the effects of PRK in the presence or absence of P53A in the skin of both the 11-year-old and the patient with an immunosuppressed skin. In all cases, P53A significantly inhibited the immune response in both the patients and those with a p53-encoded gene, while a P"]