In [None]:
# Install SentencePiece; T5-family tokenizers can depend on it (presumably needed for the
# paraphrase model's tokenizer below — confirm). NOTE(review): `%pip install` is the explicit
# notebook magic that targets the running kernel's environment; this bare form relies on automagic.
pip install sentencepiece

In [None]:
import os, re, json
from tqdm import tqdm
import torch
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ---------------- CONFIG ----------------
PARA_MODEL  = "Vamsi/T5_Paraphrase_Paws"     # T5 checkpoint fine-tuned for paraphrasing
STYLE_MODEL = "google/flan-t5-base"          # instruction-tuned T5, used for the style rewrite

BATCH_SIZE = 32
MAX_SAMPLES = 5000          # total rows; main() takes MAX_SAMPLES // 2 per label (small for testing)
OUT_PATH = "data/imdb_triplets.jsonl"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# T5-family checkpoints are prone to activation overflow (inf/NaN) in float16, which shows up
# as degenerate generations. bfloat16 keeps float32's exponent range, so prefer it on GPUs
# that support it; fall back to float16 on older GPUs and float32 on CPU.
if DEVICE.startswith("cuda"):
    DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    DTYPE = torch.float32

MAX_IN_TOKENS = 256         # tokenizer truncation length for each prompt
MAX_NEW_TOKENS = 128        # cap on generated tokens per output

SEED = 42
# ----------------------------------------

# Any spelling/case of an HTML line break: <br>, <br/>, <br />, <BR>, <br / >, ...
_BR_TAG_RE = re.compile(r"<br\s*/?\s*>", re.IGNORECASE)
_WHITESPACE_RE = re.compile(r"\s+")

def clean_html(text: str) -> str:
    """Replace HTML line-break tags with spaces and normalise whitespace.

    IMDb reviews embed ``<br />`` tags between paragraphs; every variant of the
    ``<br>`` tag (any case, optional slash, optional inner spaces) becomes a
    single space. Runs of whitespace are then collapsed to one space and the
    result is stripped.

    Args:
        text: Raw review text.

    Returns:
        The cleaned text (possibly empty).
    """
    text = _BR_TAG_RE.sub(" ", text)
    return _WHITESPACE_RE.sub(" ", text).strip()

def shorten(text: str, max_chars: int = 900) -> str:
    """Return at most the first *max_chars* characters of *text*.

    A plain character cut (no word-boundary logic): cheap and deterministic,
    which keeps long IMDb reviews to a bounded size before tokenization.
    """
    # Slicing never raises; strings already short enough come back unchanged.
    keep = slice(None, max_chars)
    return text[keep]

# Degenerate "answers" an instruction model may emit instead of a rewrite.
_BOOLEAN_ANSWERS = frozenset({"true", "false"})

def is_bad(gen: str, min_chars: int = 20) -> bool:
    """Return True if a generated string looks unusable as a positive example.

    After stripping and lower-casing, a generation is rejected when it is a
    bare ``"true"``/``"false"`` answer or shorter than *min_chars* characters.

    Note: with the default ``min_chars=20`` the boolean-answer check is already
    implied by the length check; it only matters for smaller thresholds.

    Args:
        gen: Generated text.
        min_chars: Minimum acceptable length of the stripped text.
    """
    g = gen.strip().lower()
    return g in _BOOLEAN_ANSWERS or len(g) < min_chars

@torch.inference_mode()
def generate_batch(model, tok, prompts, *, max_in_tokens=None, max_new_tokens=None, num_beams=4):
    """Run deterministic beam-search generation for a batch of prompts.

    Args:
        model: A seq2seq model exposing ``.device`` and ``.generate``.
        tok: The matching tokenizer (callable, with ``batch_decode``).
        prompts: Sequence of prompt strings.
        max_in_tokens: Truncation length for the encoded prompts; defaults to
            the module-level ``MAX_IN_TOKENS``.
        max_new_tokens: Cap on generated tokens; defaults to the module-level
            ``MAX_NEW_TOKENS``.
        num_beams: Beam width for beam search (sampling is disabled).

    Returns:
        One stripped, special-token-free string per prompt, in input order.
    """
    if not prompts:
        return []
    # Resolve defaults at call time so the module constants are read when used.
    if max_in_tokens is None:
        max_in_tokens = MAX_IN_TOKENS
    if max_new_tokens is None:
        max_new_tokens = MAX_NEW_TOKENS

    enc = tok(
        list(prompts),
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_in_tokens,
    )
    # Follow the model's actual placement rather than a global device setting.
    enc = {k: v.to(model.device) for k, v in enc.items()}

    out_ids = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,   # bounds output length independently of input length
        num_beams=num_beams,
        do_sample=False,
        early_stopping=True,
    )
    return [text.strip() for text in tok.batch_decode(out_ids, skip_special_tokens=True)]

def maybe_resume_count(path):
    """Return how many complete (newline-terminated) lines *path* holds.

    Used to resume an append-only JSONL file. If a previous run was killed
    mid-write, the file can end with a partial line; that fragment is not
    counted and is truncated away, so the next append starts on a clean line
    instead of fusing onto the fragment and corrupting it. The file is read
    as bytes so a half-written UTF-8 character cannot raise a decode error.

    Args:
        path: Path of the output file; a missing file counts as 0 lines.

    Returns:
        Number of complete lines.
    """
    if not os.path.exists(path):
        return 0

    complete = 0
    valid_bytes = 0
    with open(path, "rb") as f:
        for line in f:
            # Only the final line can lack a terminator; stop there.
            if not line.endswith(b"\n"):
                break
            complete += 1
            valid_bytes += len(line)

    if valid_bytes < os.path.getsize(path):
        with open(path, "r+b") as f:
            f.truncate(valid_bytes)
        print(f"Truncated incomplete trailing line in {path}")
    return complete

def _load_seq2seq(model_name):
    """Load a tokenizer and seq2seq model by name, placed on DEVICE in DTYPE, in eval mode."""
    tok = AutoTokenizer.from_pretrained(model_name)
    mod = (
        AutoModelForSeq2SeqLM.from_pretrained(model_name, torch_dtype=DTYPE)
        .to(DEVICE)
        .eval()
    )
    return tok, mod


def _balanced_imdb_subset(ds, n_total, seed):
    """Return a shuffled dataset with ``n_total // 2`` rows of label 1 and of label 0.

    The selection order (label 1 first, then label 0, then a seeded shuffle of
    the concatenation) must stay fixed: resume relies on the row order being
    identical across runs. Only the ``text`` and ``label`` columns are kept.

    Raises:
        ValueError: if either label has fewer than ``n_total // 2`` rows.
    """
    per_class = n_total // 2
    parts = []
    for label in (1, 0):
        subset = ds.filter(lambda x, lbl=label: x["label"] == lbl).shuffle(seed=seed)
        if len(subset) < per_class:
            raise ValueError(
                f"Requested {per_class} examples with label {label}, "
                f"but only {len(subset)} are available."
            )
        parts.append(subset.select(range(per_class)))

    data = concatenate_datasets(parts).shuffle(seed=seed)
    return data.remove_columns([c for c in data.column_names if c not in ("text", "label")])


def main():
    """Build (anchor, paraphrase, style-rewrite) triplets from a balanced IMDb subset.

    Appends one JSON object per line to OUT_PATH with keys ``anchor``,
    ``positive_para``, ``positive_style`` and ``label``. If OUT_PATH already
    has lines, processing restarts after that many rows of the subset.
    """
    out_dir = os.path.dirname(OUT_PATH)
    if out_dir:  # os.makedirs("") raises when OUT_PATH has no directory part
        os.makedirs(out_dir, exist_ok=True)
    print("DEVICE:", DEVICE)

    print("Loading models...")
    para_tok, para_mod = _load_seq2seq(PARA_MODEL)
    style_tok, style_mod = _load_seq2seq(STYLE_MODEL)

    print("Loading IMDb...")
    ds = load_dataset("imdb", split="train")
    data = _balanced_imdb_subset(ds, MAX_SAMPLES, SEED)

    # resume support: each existing line corresponds to one already-processed row
    already = maybe_resume_count(OUT_PATH)
    if already > 0:
        print(f"Resuming: {already} lines already in {OUT_PATH}")
    start_idx = already
    if start_idx >= len(data):
        print("Nothing to do; file already complete.")
        return

    n_fallback_para = 0
    n_fallback_style = 0
    with open(OUT_PATH, "a", encoding="utf-8") as f:
        # iterate in contiguous slices (fast + deterministic)
        for start in tqdm(range(start_idx, len(data), BATCH_SIZE)):
            batch = data[start : start + BATCH_SIZE]
            anchors = [shorten(clean_html(t)) for t in batch["text"]]
            labels = batch["label"]

            # " </s>" suffix matches the paraphrase model card's prompt style
            para_prompts = [f"paraphrase: {a} </s>" for a in anchors]
            style_prompts = [f"Rewrite as a short casual tweet, keep sentiment the same: {a}" for a in anchors]

            paras = generate_batch(para_mod, para_tok, para_prompts)
            styles = generate_batch(style_mod, style_tok, style_prompts)

            lines = []
            for a, p, s, y in zip(anchors, paras, styles, labels):
                # Fall back to the anchor itself when a generation looks broken.
                if is_bad(p):
                    p = a
                    n_fallback_para += 1
                if is_bad(s):
                    s = a
                    n_fallback_style += 1

                rec = {
                    "anchor": a,
                    "positive_para": p,
                    "positive_style": s,
                    "label": int(y),
                }
                lines.append(json.dumps(rec, ensure_ascii=False) + "\n")

            f.writelines(lines)
            # Push each finished batch to the OS so a killed kernel keeps the progress
            # that the line-count based resume will rely on.
            f.flush()

    print("Done:", OUT_PATH)
    print(f"Fallbacks to anchor (this run): paraphrase={n_fallback_para}, style={n_fallback_style}")

if __name__ == "__main__":
    # Entry point: runs when executed as a script or directly in the notebook kernel.
    main()


DEVICE: cuda
Loading models...
Loading IMDb...


 22%|██▏       | 35/157 [05:54<20:48, 10.23s/it]

Note: you may need to restart the kernel to use updated packages.
