In [None]:
!pip install transformers git+https://github.com/zer0sh0t/zer0t0rch
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [2]:
import re
import torch
import numpy as np
import transformers
from torch.utils.data import Dataset
from zer0t0rch import Zer0t0rchWrapper
from zer0t0rch.utils import clear_cache
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding

clear_cache()

In [None]:
model_name = 'gpt2'
num_epochs = 1
batch_size = 16
max_seq_len = 32

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
model = AutoModelForCausalLM.from_pretrained(model_name)
collate_fn = DataCollatorWithPadding(tokenizer=tokenizer)

In [5]:
class GetDataset(Dataset):
    def __init__(self, file_path, tokenizer, max_seq_len):
        self.text = open(file_path, 'r').read()
        self.words = re.split(' ', self.text)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

    def __len__(self):
        return 25000 #len(self.words) - self.max_seq_len

    def __getitem__(self, index):
        content = self.words[index: index + self.max_seq_len]
        tok_con = self.tokenizer(' '.join(content), max_length=self.max_seq_len, truncation=True)
        return tok_con

In [8]:
data = GetDataset('input.txt', tokenizer, max_seq_len)
print(len(data), data.text[:1000])

25000 First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for reve

In [10]:
class ZTGPT2(Zer0t0rchWrapper):
    def __init__(self, model):
        super().__init__(model)

    def batch_mounting_logic(self, batch):
        batch = {k: v.to(self.device) for k, v in batch.items()}
        return batch

    def loop_forward_logic(self, batch):
        preds = self.model(**batch, labels=batch['input_ids']) # labels are shifted right inside the model
        outputs = (preds.loss,)
        return outputs

    def metric_calc_logic(self, metric_inps):
        _, loss = metric_inps
        metric_vals = {}
        for metric_name, metric_fn in self.metric_fns.items():
            metric_vals[metric_name] = metric_fn(loss.item())
        return metric_vals

In [12]:
loss_fn = lambda loss: loss
get_perplexity = lambda loss: np.exp(loss)
metric_fns = {'ppl': get_perplexity}

In [13]:
zt_model = ZTGPT2(model)
zt_model.prepare_data(data, batch_size=batch_size, collate_fn=collate_fn, val_pct=0.2)
zt_model.compile(loss_fn, metric_fns)

In [14]:
zt_model.fit(num_epochs)

train: epoch=0, loss=1.7622, ppl=10.8289: 100%|██████████| 1250/1250 [06:51<00:00,  3.04it/s]
 val : epoch=0, loss=0.5538, ppl=1.7438: 100%|██████████| 313/313 [00:29<00:00, 10.44it/s]


In [15]:
def generate(text, max_length):
    inputs = tokenizer(text, return_tensors='pt')
    inputs = inputs['input_ids'].to(device)
    outputs = zt_model.model.generate(inputs, max_length=max_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [18]:
output = generate('''hello there''', 100)
print(output)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


hello there's difference; but the fall of either
Makes the survivor heir of all.

AUFIDIUS:
I know it;
And my pretext to strike at him admits
A good construction. I raised him, and raised him:
The people cry him, and the nobility dined.


Citizens:
Noble Marcius, the great sword of the people!


SICINIUS:
The people cry him,
