In [1]:

import numpy as np
import pandas as pd
import torch
import os
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
import torch.nn.functional as F
from tqdm import tqdm, trange

import gc

gc.collect()

torch.cuda.empty_cache()

In [5]:
class MedTextDataset(Dataset):
    
    def __init__(self, path_file, truncate=False, gpt2_type="gpt2", max_length=768):

        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.comments = []
        
        with open(path_file) as file:
            comments = file.readlines()
            
            self.comments = [self.comment_to_tensor(c, max_length) for c in comments]
            
                
        if truncate:
            self.comments = self.comments[:20000]
        self.comments_count = len(self.comments)
        
        
    def comment_to_tensor(self, comment, max_length):
        return torch.tensor(self.tokenizer.encode(comment[:max_length]))
        
    def __len__(self):
        return self.comments_count

    def __getitem__(self, item):
        return self.comments[item]

In [6]:
DATA_TRAIN_PRUEBA = '../tagged_files/train_tagged_prueba.txt'
train_p_dataset = MedTextDataset(DATA_TRAIN_PRUEBA)

In [7]:
len(train_p_dataset)

9

In [8]:
def pack_tensor(new_tensor, packed_tensor, max_seq_len):
    if packed_tensor is None:
        return new_tensor, True, None
    if new_tensor.size()[1] + packed_tensor.size()[1] > max_seq_len:
        return packed_tensor, False, new_tensor
    else:
        packed_tensor = torch.cat([new_tensor, packed_tensor[:, 1:]], dim=1)
        return packed_tensor, True, None

In [9]:
def train(
    dataset,
    model,
    batch_size=16,
    epochs=4,
    lr=2e-5,
    max_seq_len=400,
    warmup_steps=5000,
    gpt2_type="gpt2",
    device="cuda",
    output_dir=".",
    output_prefix="medtex",
    test_mode=False,
    save_model_on_epoch=False,
):

    acc_steps = 100

    torch.cuda.empty_cache()

    model = model.to(device)
    model.train()

    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=-1
    )

    train_dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

    accumulating_batch_count = 0
    input_tensor = None

    for epoch in range(epochs):

        print(f"Training epoch {epoch}")
        for idx, entry in tqdm(enumerate(train_dataloader)):
            (input_tensor, carry_on, remainder) = pack_tensor(entry, input_tensor, 768)

            if carry_on and idx != len(train_dataloader) - 1:
                continue

            input_tensor = input_tensor.to(device)
            outputs = model(input_tensor, labels=input_tensor)
            loss = outputs[0]
            loss.backward()

            if (accumulating_batch_count % batch_size) == 0:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                model.zero_grad()

            accumulating_batch_count += 1
            input_tensor = None
        if save_model_on_epoch:
            torch.save(
                model.state_dict(),
                os.path.join(output_dir, f"{output_prefix}-{epoch}.pt"),
            )
    return model

In [10]:
gpt2_type = "gpt2"

In [14]:

model = train(
    train_p_dataset,
    GPT2LMHeadModel.from_pretrained(gpt2_type),
    batch_size=64,
    epochs=3,
    lr=2e-5,
    max_seq_len=140,
    warmup_steps=200,
    gpt2_type=gpt2_type,
    device="cpu",
    output_dir="trained_models",
    output_prefix="medtext",
    save_model_on_epoch=False
)

0it [00:00, ?it/s]Training epoch 0
9it [00:28,  3.22s/it]
0it [00:00, ?it/s]Training epoch 1
9it [00:31,  3.53s/it]
0it [00:00, ?it/s]Training epoch 2
9it [00:30,  3.39s/it]


In [15]:
def generate(
    model,
    tokenizer,
    prompt,
    entry_count=10,
    entry_length=100,
    top_p=0.8,
    temperature=1.,
):

    model.eval()

    generated_num = 0
    generated_list = []

    filter_value = -float("Inf")

    with torch.no_grad():

        for entry_idx in trange(entry_count):

            entry_finished = False

            generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)

            # Using top-p (nucleus sampling): https://github.com/huggingface/transformers/blob/master/examples/run_generation.py

            for i in range(entry_length):
                outputs = model(generated, labels=generated)
                loss, logits = outputs[:2]
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)

                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(
                    F.softmax(sorted_logits, dim=-1), dim=-1
                )

                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
                    ..., :-1
                ].clone()
                sorted_indices_to_remove[..., 0] = 0

                indices_to_remove = sorted_indices[sorted_indices_to_remove]
                logits[:, indices_to_remove] = filter_value

                next_token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
                generated = torch.cat((generated, next_token), dim=1)

                if next_token in tokenizer.encode("<|EOS|>"):
                    entry_finished = True

                if entry_finished:

                    generated_num = generated_num + 1

                    output_list = list(generated.squeeze().numpy())
                    output_text = tokenizer.decode(output_list)

                    generated_list.append(output_text)
                    break
            
            if not entry_finished:
                output_list = list(generated.squeeze().numpy())
                output_text = f"{tokenizer.decode(output_list)}<|EOS|>" 
                generated_list.append(output_text)
                
    return generated_list

In [16]:
generated_tweets = generate(model.to("cpu"), GPT2Tokenizer.from_pretrained(gpt2_type), "<|BOS|>", entry_count=10)

100%|██████████| 10/10 [00:50<00:00,  5.01s/it]


In [17]:
generated_tweets

["<|BOS|> (but still quite complex) * To me, this whole project is rather simple: Let's figure out what the three top boxes are about. Now, we have to draw the bases of each in our basic 3D cube: Point-by-point.\n\n* Point-by-point construction of triangular bases * Point-by-point construction of black box bases * Point-by-point construction of hard box bases * Point-by-point construction of base shapes (We could write<|EOS|>",
 '<|BOS|>=-|',
 '<|BOS|>0.1|',
 "<|BOS|>/r0 -z '<",
 '<|BOS|>|',
 '<|BOS|>|',
 '<|BOS|>--|',
 '<|BOS|> the product value $O:\n\nI can\'t help but think that maybe the functional side (such as interface passing or parsing) could have done this and probably replaced something like:\n\n-- ; Promise wrapping the $O -> event. pipe ([]( " request ", function (){ return Promise. create ( this, $ O ); });\n\nI just don\'t see how that could\'ve looked any better, and I\'m sure some people are still confused about this "feature" and<|EOS|>',
 '<|BOS|> =|',
 '<|BOS|>BOS']