# Fine-Tuning and Evaluation of DistilGPT2 Language Model

In [None]:
# List the GPU(s) visible to this runtime.
!nvidia-smi -L

In [None]:
%%capture
!pip install transformers datasets

In [None]:
import inspect
import os
import re
import sys

import torch
from datasets import Dataset
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, set_seed

In [None]:
text_file = './combine.txt'
# Read the whole training corpus as one document; it is split into
# fixed-length token windows by the tokenization cell. Explicit UTF-8 so the
# result does not depend on the platform's default locale encoding.
with open(text_file, encoding="utf-8") as f:
    text = f.read()
dataset = Dataset.from_dict({"text": [text]})
dataset

Dataset({
    features: ['text'],
    num_rows: 1
})

In [None]:
# Pretrained base model and the token-window size used for training chunks.
checkpoint = "distilgpt2"
context_length = 128

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# Use the EOS id as the pad id (the tokenizer's pad token is set to EOS later).
# NOTE(review): recent GPT2Config versions may ignore `n_ctx` — confirm it has an effect.
config = AutoConfig.from_pretrained(
    checkpoint,
    n_ctx=context_length,
    pad_token_id=tokenizer.eos_token_id,
)
model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config)

In [None]:
def tokenize(element):
    """Cut each text in the batch into windows of exactly `context_length` tokens.

    Truncation with `return_overflowing_tokens=True` makes the tokenizer emit
    the remainder of a long text as additional rows; any row whose length is
    not `context_length` (normally the last, partial window) is dropped so no
    padding is needed during training.
    """
    enc = tokenizer(
        element["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_windows = [
        ids
        for n_tokens, ids in zip(enc["length"], enc["input_ids"])
        if n_tokens == context_length
    ]
    return {"input_ids": full_windows}


# Swap the raw "text" column for fixed-length "input_ids" windows.
dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
dataset

  0%|          | 0/1 [00:00<?, ?ba/s]

Dataset({
    features: ['input_ids'],
    num_rows: 841
})

In [None]:
# Reuse EOS as the padding token so the collator can pad batches.
tokenizer.pad_token = tokenizer.eos_token
# mlm=False selects causal-LM labels instead of masked-LM masking.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [None]:
# NOTE: the output directory says "mlm" although this is causal-LM fine-tuning;
# the name is kept because the evaluation cells load "./distilgpt2-mlm".
training_args = TrainingArguments(
    output_dir=f"{checkpoint}-mlm",
    per_device_train_batch_size=32,
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    logging_strategy="epoch",
    # No eval dataset: evaluation stays at its default ("no"). The former
    # `evaluation_strategy=` keyword was renamed upstream and later removed,
    # so passing it fails on current transformers.
    save_strategy="no",
    num_train_epochs=50,
    seed=42,
    log_level="error",
    report_to="none",
)

# Newer transformers take the tokenizer as `processing_class`, older ones as
# `tokenizer`; pick whichever this installed version accepts.
_tok_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
    **{_tok_kwarg: tokenizer},
)

# NOTE(review): the logged losses below end at step 270 = 10 epochs x 27 steps
# (841 rows / batch 32), not num_train_epochs=50 — that output may be stale.
trainer.train()

trainer.save_model()
# Save the tokenizer next to the weights explicitly; the evaluation cells
# reload it from the same directory.
tokenizer.save_pretrained(training_args.output_dir)



Step,Training Loss
27,3.5837
54,3.272
81,3.1079
108,2.9872
135,2.8678
162,2.7621
189,2.6671
216,2.5662
243,2.4696
270,2.3743


## Evaluating Fine-Tuned Text Generation on a Sample Sentence

In [None]:
# Reload the fine-tuned checkpoint written by the training cell.
checkpoint = "./distilgpt2-mlm"
context_length = 128
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
config = AutoConfig.from_pretrained(checkpoint, n_ctx=context_length, pad_token_id=tokenizer.eos_token_id)
model = AutoModelForCausalLM.from_pretrained(checkpoint, config=config).to(device)

In [None]:
def decode_output(output, start_tok=0, tok=None):
    """Print each generated sequence, decoded from position `start_tok` onward.

    Args:
        output: 2-D tensor of token ids, one generated sequence per row.
        start_tok: index of the first token to decode; pass the prompt length
            to print only the newly generated continuation.
        tok: tokenizer used for decoding; defaults to the notebook-global
            `tokenizer` so existing calls keep working.
    """
    tok = tokenizer if tok is None else tok
    for row in output:
        print(tok.decode(row[start_tok:]))

In [None]:
txt = "Ukrainian forces carried out counter offensives against Russian positions on Wednesday, seeking to inflict what"
# Put the prompt on whatever device the model was loaded onto (GPU or CPU).
input_ids = tokenizer(txt, return_tensors='pt')['input_ids'].to(model.device)
input_ids

In [None]:
# Beam search (10 beams): the 10 best single-token continuations, printed
# together with their sequence scores.
set_seed(42)
output = model.generate(
    input_ids,
    max_new_tokens=1,
    do_sample=False,
    num_beams=10,
    num_return_sequences=10,
    output_scores=True,
    return_dict_in_generate=True,
)
print(output['sequences_scores'])
decode_output(output['sequences'], input_ids.shape[-1])

In [None]:
# Same 10-beam search, but generate() returns the bare token-id tensor.
set_seed(42)
output = model.generate(
    input_ids,
    do_sample=False,
    num_beams=10,
    num_return_sequences=10,
    max_new_tokens=1,
)
decode_output(output, input_ids.shape[-1])

In [None]:
# Top-k sampling: draw 10 next tokens from the 100 most likely candidates.
set_seed(42)
output = model.generate(
    input_ids,
    do_sample=True,
    top_k=100,
    num_return_sequences=10,
    max_new_tokens=1,
)
decode_output(output, input_ids.shape[-1])

In [None]:
# Nucleus (top-p) sampling with p = 0.95.
# NOTE(review): generate() still applies its default top_k alongside top_p
# unless top_k=0 is passed, so this may not be pure nucleus sampling — confirm.
set_seed(42)
output = model.generate(
    input_ids,
    do_sample=True,
    top_p=0.95,
    num_return_sequences=10,
    max_new_tokens=1,
)
decode_output(output, input_ids.shape[-1])

In [None]:
# Temperature sampling (T = 0.5); top_k=0 switches off top-k filtering so the
# whole vocabulary is sampled from the sharpened distribution.
set_seed(42)
output = model.generate(
    input_ids,
    do_sample=True,
    temperature=0.5,
    top_k=0,
    num_return_sequences=10,
    max_new_tokens=1,
)
decode_output(output, input_ids.shape[-1])

### Evaluating Inference Time (important for interactive use in a text editor)

In [None]:
# Time 10-beam search for one new token. set_seed only seeds the first of the
# repeated %timeit runs; beam search is deterministic here anyway.
set_seed(42)
%timeit output = model.generate(input_ids, max_new_tokens=1, num_return_sequences=10, num_beams=10, do_sample=False)

In [None]:
# Time top-k (k = 100) sampling of 10 single-token continuations.
set_seed(42)
%timeit output = model.generate(input_ids, max_new_tokens=1, num_return_sequences=10, do_sample=True, top_k=100)

In [None]:
# Time nucleus (top-p = 0.95) sampling of 10 single-token continuations.
set_seed(42)
%timeit output = model.generate(input_ids, max_new_tokens=1, num_return_sequences=10, do_sample=True, top_p=0.95)

In [None]:
# Time temperature (T = 0.5, no top-k) sampling of 10 continuations.
# NOTE(review): this cell generates 3 new tokens while the other timing cells
# generate 1, so its timing is not directly comparable — confirm intended.
set_seed(42)
%timeit output = model.generate(input_ids, max_new_tokens=3, num_return_sequences=10, do_sample=True, temperature=0.5, top_k=0)

## Evaluating Prediction Accuracy on Out-of-Sample Articles

In [None]:
# Fine-tuned checkpoint written by the training cell.
MODEL_CKPT = "./distilgpt2-mlm"
# Token ids the evaluation loop passes as `bad_words_ids` when generating the
# first token of each predicted word; each inner list is a one-token sequence.
# The trailing comments give the decoded text of the ids on each row.
BANNED_TOKENS = [
    [12], [438], [532], [784], [851], [960], [1377], [11420],       # dashes
    [0], [1], [4], [6], [11], [13], [14], [25], [26],               # ! " % ' , . / : ;
    [338], [357], [366], [526], [553], [705], [720], [737], [828],  # 's ( " ." ," ' $ ). ),
    [1539], [1600], [1911], [2474], [2637], [7874], [14004]]        # ., ", ". !" .' .- ,''
# NOTE(review): recent GPT2Config versions may ignore `n_ctx` — confirm it has an effect.
CTX_LEN = 128

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_CKPT)
config = AutoConfig.from_pretrained(MODEL_CKPT, n_ctx=CTX_LEN, pad_token_id=tokenizer.eos_token_id)
model = AutoModelForCausalLM.from_pretrained(MODEL_CKPT, config=config).to(device)

In [None]:
# Out-of-sample article to score; switch the path to evaluate the others.
TEXT_FILE = 'new1.txt'
# TEXT_FILE = 'new2.txt'
# TEXT_FILE = 'new3.txt'
# `with` closes the file even if reading fails; explicit UTF-8 keeps decoding
# independent of the platform's default encoding.
with open(TEXT_FILE, 'r', encoding='utf-8') as fin:
    text_in = fin.read()

In [None]:
class Colors:
    """ANSI SGR escape sequences used to colour the annotated prediction output."""

    Endc = "\x1b[0m"          # reset to default attributes
    LightRed = "\x1b[91m"     # candidate list written for a miss
    LightGreen = "\x1b[92m"   # top-1 hit
    LightYellow = "\x1b[93m"  # top-10 hit
    LightBlue = "\x1b[94m"    # (not used by the evaluation loop)
    LightMagenta = "\x1b[95m" # (not used by the evaluation loop)
    LightCyan = "\x1b[96m"    # top-5 hit
    LightGray = "\x1b[37m"    # (not used by the evaluation loop)
    White = "\x1b[97m"        # (not used by the evaluation loop)

# ---- prediction settings ----
LOOKBACK = 28           # number of preceding words fed to the model as context
NEW_TOKENS = 3          # 1 starter token + 2 continuation tokens per candidate word
NUM_PREDS = 10          # candidates kept per position (scored as top-1/5/10)
NUM_BEAMS = NUM_PREDS   # beam width of the starter-token search
EXTRA_PREDS = 0         # surplus beams requested to offset punctuation filtering

OUTPUT_MISSES = False   # also write the rejected candidate list for each miss

# ---- text-processing regexes ----
pat_mkr = re.compile(r"(\^+)")                # a run of '^' newline markers
pat_mkr_end = re.compile(r"\^+$")             # '^' markers at the end of a word
pat_sep = re.compile(r"[\s-]+")               # word separators: whitespace and hyphens
pat_punct = re.compile(r"[()\".,;:!$%@–—-]")  # punctuation; apostrophe deliberately excluded

# Characters trimmed from the edges of both the target word and each candidate
# before comparison, so both sides are normalised the same way.
_EDGE_PUNCT = "()\"'.,;:!$%@–—-"
# Also echo the annotated text to stdout (the output file is always written).
ECHO_STDOUT = False


def _split_markers(word):
    """Strip trailing '^' newline markers from `word`.

    Returns (word_without_markers, separator): one newline per stripped marker,
    or a single space when the word carried no markers.
    """
    mkrs = pat_mkr_end.search(word)
    if not mkrs:
        return word, " "
    num_mkrs = len(mkrs.group())
    return word[:-num_mkrs], "\n" * num_mkrs


def _emit(s, fout):
    """Write `s` to the annotated output file, and to stdout if ECHO_STDOUT."""
    print(s, end="", file=fout)
    if ECHO_STDOUT:
        print(s, end="", flush=True)


def _predict_next_words(context):
    """Return up to NUM_PREDS candidate next words for `context`, best first.

    Stage 1 beam-searches the first token of the next word (BANNED_TOKENS
    excluded) and drops beams whose first token decodes to punctuation.
    Stage 2 greedily extends every surviving beam by NEW_TOKENS - 1 tokens so
    multi-token words can be completed. Each candidate is then cut at the first
    separator (whitespace/hyphen) and right-stripped of edge punctuation.
    """
    input_ids = tokenizer(context, return_tensors="pt")["input_ids"].to(device)
    prompt_len = input_ids.shape[1]

    output = model.generate(
        input_ids,
        max_new_tokens=1,
        # min_length counts the prompt too, so prompt_len + 1 forbids EOS as
        # the starter token (EOS can never be a valid word prediction).
        min_length=prompt_len + 1,
        num_return_sequences=NUM_PREDS + EXTRA_PREDS,
        num_beams=NUM_BEAMS + EXTRA_PREDS,
        do_sample=False,
        bad_words_ids=BANNED_TOKENS,
    )

    # Keep only beams whose starter token is not punctuation.
    keep = [
        i
        for i in range(output.shape[0])
        if not pat_punct.match(tokenizer.decode(output[i, prompt_len:prompt_len + 1]).strip())
    ]
    if not keep:  # every starter token was punctuation: nothing to extend
        return []
    output = output[keep]

    if NEW_TOKENS > 1:
        output = model.generate(
            output,
            # The batch has no padding; pass an all-ones mask explicitly rather
            # than letting generate() infer one (pad id equals EOS here).
            attention_mask=torch.ones_like(output),
            max_new_tokens=NEW_TOKENS - 1,
            num_return_sequences=1,
            num_beams=1,
            do_sample=False,
        )

    candidates = []
    for row in output:
        text = tokenizer.decode(row[prompt_len:]).strip()
        if not text:
            continue
        cand = pat_sep.split(text)[0].rstrip(_EDGE_PUNCT)
        if cand:
            candidates.append(cand)
    return candidates[:NUM_PREDS]


# Encode newlines as '^' markers glued to the preceding word, plus a space so
# each marker run terminates its word when splitting. A new name is used so
# re-running this cell does not re-process `text_in`.
text_marked = pat_mkr.sub(r"\1 ", text_in.replace("\n", "^"))
words = pat_sep.split(text_marked)

first_word, first_sep = _split_markers(words[0])
LB_BUF = [""] * (LOOKBACK - 1) + [first_word]

top1_hits, top5_hits, top10_hits = 0, 0, 0
# Words that were actually predicted: excludes the first word (never
# predicted) and tokens that are empty once edge punctuation is removed.
num_scored = 0

with open("_" + os.path.basename(TEXT_FILE), "w", encoding="utf-8") as fout:
    _emit(first_word + first_sep, fout)

    for raw_word in words[1:]:
        word, sep = _split_markers(raw_word)
        target = word.strip(_EDGE_PUNCT)

        attr = None
        if target:
            num_scored += 1
            # Skip the empty padding slots so early predictions are not
            # conditioned on a run of spaces.
            context = " ".join(w for w in LB_BUF if w).replace("^", "")
            # An empty context cannot be fed to generate(); fall back to the
            # BOS token. NOTE(review): assumes the tokenizer defines bos_token.
            next_words = _predict_next_words(context or tokenizer.bos_token)

            if target in next_words:
                idx = next_words.index(target)
                if idx == 0:
                    attr = Colors.LightGreen   # Top-1
                    top1_hits += 1
                elif idx < 5:
                    attr = Colors.LightCyan    # Top-5
                    top5_hits += 1
                else:
                    attr = Colors.LightYellow  # Top-10
                    top10_hits += 1
            elif OUTPUT_MISSES:
                s = f"{Colors.LightRed}{'|'.join(next_words)}{Colors.Endc}{sep}"
                print(s, end="", file=fout)
                print(s, end="", flush=True)

        if attr is None:
            _emit(word + sep, fout)
        else:
            _emit(f"{attr}{word}{Colors.Endc}{sep}", fout)
        LB_BUF = LB_BUF[1:] + [word]

    total_hits = top1_hits + top5_hits + top10_hits
    hit_pct = 100 * total_hits / num_scored if num_scored else 0.0
    s = (
        f"\n{TEXT_FILE}>> LKBACK: {LOOKBACK}, PREDS: {NUM_PREDS}, EXTRA_PREDS: {EXTRA_PREDS}, NEW_TOKS: {NEW_TOKENS}"
        f" => #HITS: {top1_hits},{top5_hits},{top10_hits} / {num_scored} ({hit_pct:.1f})%"
    )
    print(s, file=fout)
print(s)


new1.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 2, NEW_TOKS: 3 => #HITS: 439,343,135 / 1547 (59.3)%


## Evaluating Look-Back Length (in Words) & Early Stopping

In [None]:
### 'train_loss': 0.839 ###
# old1.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 1019,177,38 / 1374 (89.8)%
# new1.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 439,343,135 / 1547 (59.3)%
# new2.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 235,180,69 / 852 (56.8)%
# new3.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 321,260,106 / 1202 (57.2)%

# new1.txt>> LKBACK: 30, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 433,341,142 / 1547 (59.2)%
# new1.txt>> LKBACK: 26, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 429,359,130 / 1547 (59.3)%

# new1.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 2, NEW_TOKS: 3 => #HITS: 439,343,135 / 1547 (59.3)%

In [None]:
### 'train_loss': 0.072 ###
# new1.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 411,348,128 / 1547 (57.3)%
# new2.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 222,172,73 / 852 (54.8)%
# new3.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 314,240,111 / 1202 (55.3)%

In [None]:
### 'train_loss': 0.043 ###
# new1.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 406,334,138 / 1547 (56.8)%
# new2.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 210,177,68 / 852 (53.4)%
# new3.txt>> LKBACK: 28, PREDS: 10, EXTRA_PREDS: 0, NEW_TOKS: 3 => #HITS: 319,223,100 / 1202 (53.4)%