In [1]:
import torch
from transformers import AutoTokenizer

import os
import warnings

warnings.filterwarnings("ignore")

# Read every regular file in data/ as one UTF-8 string.
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting
# the names makes `files` -- and therefore the "last three files are
# validation" split in the next cell -- reproducible across machines and runs.
DATA_DIR = "data"

files = []
for filename in sorted(os.listdir(DATA_DIR)):
    path = os.path.join(DATA_DIR, filename)
    # Skip sub-directories (e.g. a Jupyter `.ipynb_checkpoints` folder),
    # which `open` cannot read.
    if not os.path.isfile(path):
        continue
    with open(path, encoding="UTF-8") as f:
        files.append(f.read())

In [2]:
# Split by file, not by token: all files except the last three become the
# training text, and the last three are held out for validation. Each split is
# a single string with the files joined by one space.
train_text = " ".join(files[:-3])
validation_text = " ".join(files[-3:])

In [4]:
# Checkpoint name used for both the tokenizer and the model; swap it for
# another model if desired.
model_name = "gpt2-medium"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

Downloading:   0%|          | 0.00/718 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"


class LMDataset(torch.utils.data.Dataset):
    """Fixed-length blocks of token ids cut from one long text.

    The whole ``text`` is tokenized once and split into consecutive,
    non-overlapping blocks of ``block_size`` ids. A trailing block shorter
    than ``block_size`` is dropped, so every item has the same length and
    all blocks stack into a single 2-D ``torch.long`` tensor.

    Args:
        text: Raw text to tokenize.
        tokenizer: Callable such that ``tokenizer(text)["input_ids"]`` is a
            flat list of token ids (e.g. a Hugging Face tokenizer).
        block_size: Number of token ids per item. Must be positive.
        device: Optional device to move the data to. ``None`` (the default)
            keeps it on CPU.

    Raises:
        ValueError: If ``block_size`` is not positive.
    """

    def __init__(self, text, tokenizer, block_size=98, device=None):
        if block_size <= 0:
            raise ValueError(f"block_size must be positive, got {block_size}")

        # Tokenize the `text` argument itself. Reading a global here would
        # make every split contain the same data.
        input_ids = tokenizer(text)["input_ids"]

        # Keep only full-length blocks. Only a ragged tail is discarded; a
        # final block that is exactly block_size long is kept.
        n_full = len(input_ids) // block_size
        blocks = [
            input_ids[i * block_size:(i + 1) * block_size] for i in range(n_full)
        ]

        # Build as int64 directly. torch.Tensor(...) would go through float32,
        # which cannot represent ids above 2**24 exactly. The reshape keeps
        # the result 2-D even when there are no full blocks.
        data = torch.tensor(blocks, dtype=torch.long).reshape(n_full, block_size)

        # NOTE(review): data stays on CPU unless a device is given. Trainer
        # moves batches to the model's device itself, and DataLoaders with
        # pin_memory=True reject CUDA tensors.
        self.data = data if device is None else data.to(device)

    def __getitem__(self, idx):
        """Return block ``idx``: a 1-D tensor of ``block_size`` token ids."""
        return self.data[idx]

    def __len__(self):
        """Return the number of full-length blocks."""
        return len(self.data)
    
# One dataset per split, each built with the default block size.
train_dataset = LMDataset(text=train_text, tokenizer=tokenizer)
val_dataset = LMDataset(text=validation_text, tokenizer=tokenizer)

In [6]:
from transformers import (
    AutoModelWithLMHead,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# NOTE(review): AutoModelWithLMHead is deprecated in transformers 4.x;
# AutoModelForCausalLM is its replacement for GPT-2-style checkpoints.
model = AutoModelWithLMHead.from_pretrained(model_name)

# `data_collator` was previously passed to Trainer without being defined in
# any cell, which raises NameError on a fresh kernel. This collator stacks the
# equal-length blocks into a batch. With mlm=False it uses the input ids as
# labels, so the model returns a causal-LM loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

training_args = TrainingArguments(
    output_dir="results/",  # checkpoints/logs; no default in most transformers releases
    num_train_epochs=3,  # number of training epochs
    per_device_train_batch_size=16,  # batch size for training
    per_device_eval_batch_size=64,  # batch size for evaluation
    # NOTE(review): eval_steps only takes effect when periodic evaluation is
    # enabled (evaluate_during_training=True in transformers 3.x,
    # evaluation_strategy="steps" in 4.x). The training log shows no eval loss.
    eval_steps=400,  # number of update steps between two evaluations
    warmup_steps=500,  # number of warmup steps for learning rate scheduler
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)

Downloading:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

In [7]:
# Run fine-tuning. The returned TrainOutput (final step and average training
# loss) is displayed as the cell's result.
train_output = trainer.train()
train_output

Step,Training Loss
500,2.796303


TrainOutput(global_step=633, training_loss=2.7388785363750245)

In [8]:
# Save the fine-tuned model and the tokenizer together, so that "model/" can be
# reloaded on its own with AutoTokenizer / AutoModel*.from_pretrained("model/").
# Trainer.save_model writes tokenizer files only when the Trainer was given a
# tokenizer, and the Trainer above was not.
trainer.save_model("model/")
tokenizer.save_pretrained("model/")