In [None]:
import torch

import random
import pandas as pd

from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch, gc
from rouge_score import rouge_scorer

from src.data_utils import clean_string, split_x_target_by_words
import warnings
warnings.filterwarnings('ignore')

In [3]:
random.seed(42)
torch.manual_seed(42)

<torch._C.Generator at 0x114e8acf0>

## Get the data

### Data Preparation



In [4]:
df = pd.read_csv("./data/tweets_cleaned.csv")

In [5]:
val_test_size = 0.20
test_size = 0.50

train_texts, val_test_texts = train_test_split(list(df["cleaned_text"]), test_size=val_test_size, random_state=42)
print(f"Train texts: {len(train_texts)}, Val_Test texts: {len(val_test_texts)}")
val_texts, test_texts = train_test_split(val_test_texts, test_size=test_size, random_state=42)
print(f"Val texts: {len(val_texts)}, Test texts: {len(test_texts)}")

Train texts: 1280000, Val_Test texts: 320000
Val texts: 160000, Test texts: 160000


### Pre-trained

In [9]:
def build_eval_pairs(texts, max_examples: int = 2000):
    """Build validation x/targets from splited df parts"""
    pairs = []
    for t in texts:
        t = clean_string(t)
        if not t:
            continue
        x, target = split_x_target_by_words(t)
        if x and target:
            pairs.append((x, target))
        if len(pairs) >= max_examples:
            break
    X = [p for p,_ in pairs]
    targets = [r for _,r in pairs]
    return X, targets

val_prefixes, val_refs = build_eval_pairs(test_texts, max_examples=1500)

In [10]:
device = torch.device("cuda" if torch.cuda.is_available()
          else "mps" if torch.backends.mps.is_available()
          else "cpu")

model_name = "distilbert/distilgpt2"

gpt_tk = AutoTokenizer.from_pretrained(model_name, padding_side="left")

if gpt_tk.pad_token_id is None:
    gpt_tk.pad_token = gpt_tk.eos_token

gpt_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=(torch.float16 if device.type in {"cuda","mps"} else None),
    low_cpu_mem_usage=True
    ).to(device).eval()

gpt_model.config.pad_token_id = gpt_tk.pad_token_id

In [None]:
@torch.inference_mode()
def complete_with_distilgpt2_batched(X, *, max_new_tokens=64, do_sample=True,
                                     top_k=50, top_p=0.95, temperature=0.8,
                                     repetition_penalty=1.1):
    enc = gpt_tk(X, return_tensors="pt", padding=True, truncation=True)
    input_ids = enc["input_ids"].to(device)
    attention_mask = enc["attention_mask"].to(device)
    gen_cfg = GenerationConfig(
        max_new_tokens=max_new_tokens, do_sample=do_sample,
        top_k=top_k, top_p=top_p, temperature=temperature,
        repetition_penalty=repetition_penalty,
        eos_token_id=gpt_tk.eos_token_id, pad_token_id=gpt_tk.pad_token_id,
        use_cache=True  # default, keep for speed
    )
    out = gpt_model.generate(input_ids=input_ids,
                             attention_mask=attention_mask,
                             generation_config=gen_cfg)
    
    # slice tail per example using true X lengths
    lens = attention_mask.sum(dim=1)
    gens = []
    for i in range(out.size(0)):
        start = int(lens[i].item())
        gen_ids = out[i, start:]
        gens.append(gpt_tk.decode(gen_ids, skip_special_tokens=True))
    return gens

In [12]:
scorer = rouge_scorer.RougeScorer(["rouge1","rouge2"], use_stemmer=False)

def eval_rouge_model_batched(X, targets, gen_params, bs=32, max_new_tokens=64):
    r1 = r2 = n = 0
    for s in range(0, len(X), bs):
        batch_p = X[s:s+bs]
        batch_r = targets[s:s+bs]
        preds = complete_with_distilgpt2_batched(batch_p, max_new_tokens=max_new_tokens, **gen_params)
        for p, g in zip(preds, batch_r):
            s_ = scorer.score(g, p)
            r1 += s_["rouge1"].fmeasure
            r2 += s_["rouge2"].fmeasure
            n  += 1
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()
    return {"rouge1": r1/max(n,1), "rouge2": r2/max(n,1)}

In [13]:
decode_grid = [
               dict(do_sample=True, top_k=50, top_p=0.95, temperature=0.8),
               dict(do_sample=True, top_k=0, top_p=0.90, temperature=0.7), # nucleus only 
               dict(do_sample=True, top_k=100, top_p=0.95, temperature=0.9, repetition_penalty=1.1),
               ]

In [14]:
val_results = []
for cfg in decode_grid:
    scores = eval_rouge_model_batched(val_prefixes, val_refs, cfg, bs=16, max_new_tokens=64)
    val_results.append((cfg, scores))
    print("CFG:", cfg, "-> ROUGE1={:.3f} ROUGE2={:.3f}".format(scores["rouge1"], scores["rouge2"]))
    if device.type == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

CFG: {'do_sample': True, 'top_k': 50, 'top_p': 0.95, 'temperature': 0.8} -> ROUGE1=0.028 ROUGE2=0.001
CFG: {'do_sample': True, 'top_k': 0, 'top_p': 0.9, 'temperature': 0.7} -> ROUGE1=0.029 ROUGE2=0.001
CFG: {'do_sample': True, 'top_k': 100, 'top_p': 0.95, 'temperature': 0.9, 'repetition_penalty': 1.1} -> ROUGE1=0.026 ROUGE2=0.001


In [15]:
best_cfg, best_scores = sorted(val_results, key=lambda x: x[1]["rouge2"], reverse=True)[0]
print("Best DistilGPT2 VAL:", best_cfg, best_scores)

Best DistilGPT2 VAL: {'do_sample': True, 'top_k': 0, 'top_p': 0.9, 'temperature': 0.7} {'rouge1': 0.029377097270387, 'rouge2': 0.001326424644015844}


In [17]:
def show_examples(prefixes, references, generator_fn, n=5, seed=0):
    rnd = random.Random(seed)
    idxs = rnd.sample(range(min(len(prefixes), len(references))), k=min(n, len(prefixes)))
    for i in idxs:
        pred = generator_fn([prefixes[i]])[0]
        print(f"\n--- Example {i} ---")
        print("INPUT (X):    ", prefixes[i])
        print("PREDICTED: ", pred)
        print("TARGET: ", references[i])

print("\nSAMPLE VAL COMPLETIONS (DistilGPT2):")
show_examples(val_prefixes, val_refs,
              lambda p: complete_with_distilgpt2_batched(p, max_new_tokens=64, **best_cfg), n=5)


SAMPLE VAL COMPLETIONS (DistilGPT2):

--- Example 788 ---
INPUT (X):     gaylib1986 it was just areply on you facebook status that you
PREDICTED:   can now delete your posts without the permission of any third party.
I would have had to go back and forth with them, because I know they could do anything at all in this way so if anyone ever gets a violation from someone else (like me) please contact us directly or by email - we'll take whatever
TARGET:  were a little upset

--- Example 861 ---
INPUT (X):     setuid oh i learned my lesson from seinfeld was
PREDICTED:   an excellent example of how to get better at it.
In the future I'll be working on a post about Seinfeld's interaction with John Gillis (another man who has been teaching me for years) that will help us understand what is really happening in our lives and show you why we are here, along with
TARGET:  tempted but didnt

--- Example 82 ---
INPUT (X):     ashmrx
PREDICTED:  vwCqYXU9RQdVkG7Zn3t6LxJ1bP4Wg+cBj2z8yMpOGaSlv5eNl33ab