In [35]:
import random
import os
# Seed the global RNG up front so every random augmentation below is repeatable.
random.seed(1)

# The input file is resolved against the directory the kernel is running in.
current_directory = os.getcwd()
in_path = os.path.join(current_directory, "data/train.txt")

# Relative path of the (original, altered) pair file written further down.
out_path = "part2.txt"


In [47]:
def switch_word(btext: bytes) -> bytes:
    """Swap two different words of ``btext`` and return the re-joined text.

    The text is split on ASCII whitespace and re-joined with single spaces,
    so runs of whitespace are normalised in the result.

    Args:
        btext: Raw sentence bytes.

    Returns:
        ``btext`` itself when it holds fewer than two words; otherwise the
        words joined by single spaces with two positions exchanged. When every
        word is identical no swap can change the text, so the words are
        returned in their original order.
    """
    words = btext.split()
    if len(words) < 2:
        return btext
    first = random.randrange(len(words))
    # Only positions holding a *different* word are eligible partners:
    # exchanging two equal words would leave the sentence unchanged.
    partners = [k for k, word in enumerate(words) if word != words[first]]
    if partners:
        second = random.choice(partners)
        words[first], words[second] = words[second], words[first]
    return b" ".join(words)


def switch_char(btext: bytes) -> bytes:
    """Overwrite one byte of ``btext`` with a random lowercase ASCII letter.

    A non-space position is preferred; if the input consists only of spaces,
    any position may be picked. The replacement always differs from the byte
    it replaces. Empty input is returned unchanged.
    """
    if not btext:
        return btext
    buf = bytearray(btext)
    non_space = [idx for idx, byte in enumerate(buf) if byte != 32]
    target = random.choice(non_space if non_space else list(range(len(buf))))
    previous = buf[target]
    alphabet = list(range(97, 123))  # byte values of 'a'..'z'
    replacement = random.choice(alphabet)
    # Re-draw until the new letter is actually different from the old byte.
    while replacement == previous:
        replacement = random.choice(alphabet)
    buf[target] = replacement
    return bytes(buf)


def delete_word(btext: bytes) -> bytes:
    """Drop one randomly chosen word from ``btext``.

    The text is split on whitespace; the surviving words are joined with
    single spaces. Input without any word is returned as-is, and a one-word
    input yields ``b""``.
    """
    tokens = btext.split()
    if len(tokens) == 0:
        return btext
    victim = random.randrange(len(tokens))
    survivors = tokens[:victim] + tokens[victim + 1:]
    return b" ".join(survivors)


def switch_punctuation(btext: bytes) -> bytes:
    """Change one punctuation mark, or insert one if none can be changed.

    The bytes are handled as latin-1 text. Punctuation found at any index
    other than 0 is a candidate: one candidate is replaced by a different
    mark. When there is no candidate, a random mark is inserted at a random
    position (position 0 for empty input).
    """
    decoded = btext.decode("latin-1", errors="ignore")
    marks = list(".,;:!?-'\"")
    # Index 0 is deliberately skipped, so a leading mark is never a candidate.
    hits = [idx for idx in range(1, len(decoded)) if decoded[idx] in marks]
    if hits:
        target = random.choice(hits)
        current = decoded[target]
        substitute = random.choice(marks)
        while substitute == current:
            substitute = random.choice(marks)
        mutated = decoded[:target] + substitute + decoded[target + 1:]
    else:
        insert_at = random.randrange(len(decoded) + 1) if decoded else 0
        extra = random.choice(marks)
        mutated = decoded[:insert_at] + extra + decoded[insert_at:]
    return mutated.encode("latin-1", errors="ignore")

def cut(btext: bytes) -> bytes:
    """Split ``btext`` at a random break point and swap the two halves.

    Break points are punctuation marks or spaces at any index other than 0.
    The part starting at the break point is moved to the front and the part
    before it to the back, separated by one space when both are non-empty
    after trimming. Without any break point the text is rotated at a random
    offset instead (texts shorter than two characters come back unchanged).
    """
    decoded = btext.decode("latin-1", errors="ignore")
    breakers = list(".,;:!?\" ")
    split_points = [idx for idx in range(1, len(decoded)) if decoded[idx] in breakers]
    if split_points:
        at = random.choice(split_points)
        head = decoded[:at].rstrip()
        tail = decoded[at:].lstrip()
        if head and tail:
            rotated = tail + " " + head
        else:
            rotated = tail + head
    else:
        at = random.randrange(1, len(decoded)) if len(decoded) > 1 else 0
        rotated = decoded[at:] + decoded[:at]
    return rotated.encode("latin-1", errors="ignore")

def uncapitalize(btext: bytes) -> bytes:
    """Flip the case of one ASCII letter in ``btext``.

    If the text contains uppercase ASCII letters, one of them is lowercased.
    Otherwise one lowercase letter that directly follows a space is
    uppercased. When neither kind of letter exists the input is returned
    unchanged.
    """
    data = bytearray(btext)
    upper_positions = [idx for idx, byte in enumerate(data) if 65 <= byte <= 90]
    if upper_positions:
        data[random.choice(upper_positions)] += 32  # 'A'..'Z' -> 'a'..'z'
        return bytes(data)
    word_starts = [
        idx
        for idx in range(1, len(data))
        if data[idx - 1] == 32 and 97 <= data[idx] <= 122
    ]
    if not word_starts:
        return btext
    data[random.choice(word_starts)] -= 32  # 'a'..'z' -> 'A'..'Z'
    return bytes(data)

# logic from sample.py
def replace_last_word_with_lm(
    btext: bytes,
    gpt,
    meta_path: str = "data/train_lm/meta.pkl",
    device: str | None = None,
    max_new_bytes: int = 20,
) -> bytes:
    import torch, pickle, random

    meta = pickle.load(open(meta_path, "rb"))
    BOS, EOS = meta.get("bos_id", 256), meta.get("eos_id", 257)
    block_size = meta.get("block_size", getattr(gpt.config, "block_size", 1024))
    device = next(gpt.parameters()).device if device is None else torch.device(device)
    if device != next(gpt.parameters()).device:
        gpt = gpt.to(device)

    if not btext:
        return btext

    # preserve trailing whitespace and punctuation
    i = len(btext)
    while i > 0 and btext[i - 1] == 32:
        i -= 1
    trailing_ws = btext[i:]
    core = btext[:i]

    punct_bytes = b".,;:!?-'\""
    trailing_punct = b""
    while core and core[-1] in punct_bytes:
        trailing_punct = bytes([core[-1]]) + trailing_punct
        core = core[:-1]

    words = core.split()
    if not words:
        return btext
    orig_last = words[-1]
    prefix = b" ".join(words[:-1])
    ctx = prefix + (b" " if prefix else b"")

    ctx_ids = [BOS] + list(ctx)
    if len(ctx_ids) > block_size - 1:
        ctx_ids = ctx_ids[-(block_size - 1):]
    ctx_tensor = torch.tensor(ctx_ids, dtype=torch.long, device=device).unsqueeze(0)

    def sample_word():
        ids = ctx_tensor.clone()
        generated = []
        with torch.no_grad():
            for _ in range(max_new_bytes):
                if ids.size(1) > block_size - 1:
                    ids = ids[:, -(block_size - 1):]
                logits, _ = gpt(ids)
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                token = int(next_id.item())
                if token in (32, EOS):
                    break
                if 0 <= token <= 255:
                    generated.append(token)
                ids = torch.cat([ids, next_id], dim=1)
        return bytes(generated)

    punct_set = set(punct_bytes)
    new_word = orig_last
    for _ in range(10):
        cand = sample_word()
        while cand and cand[-1] in punct_set:
            cand = cand[:-1]
        if cand and cand != orig_last:
            new_word = cand
            break

    # if still identical, change one byte of the original last word
    if new_word == orig_last:
        buf = bytearray(orig_last)
        if buf:
            idx = random.randrange(len(buf))
            buf[idx] = (buf[idx] + random.randint(1, 25)) % 256
            new_word = bytes(buf)

    if prefix:
        return prefix + b" " + new_word + trailing_punct + trailing_ws
    else:
        return new_word + trailing_punct + trailing_ws


In [51]:
import torch
from LM.model import GPT, GPTConfig
# Upper bound on how many lines get the language-model based rewrite.
LIMIT = 10000

# Prefer the GPU when one is visible to torch; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

ckpt = torch.load("out_lm/ckpt.pt", map_location=device)
margs = ckpt["model_args"]
sd = ckpt["model"]
gpt = GPT(GPTConfig(**margs))

# Strip the "_orig_mod." prefix from state-dict keys so they match the freshly
# built model (presumably the prefix comes from a compiled model -- TODO confirm).
for k in list(sd.keys()):
    if not k.startswith("_orig_mod."):
        continue
    sd[k[len("_orig_mod."):]] = sd.pop(k)

gpt.load_state_dict(sd)
gpt.to(device).eval()

# Read the input as raw bytes; every augmentation works on bytes.
with open(in_path, "rb") as f:
    raw = f.read()

# `pairs` is a flat list laid out as [orig0, altered0, orig1, altered1, ...];
# the cells below rely on that strict alternation.
choices = [1, 2, 3, 4, 5, 6, 7]
pairs = []
count = 0
for line in raw.splitlines():
    if not line.strip():
        continue
    parts = line.split(b"\t")
    if len(parts) != 2:
        # Skip lines that are not exactly two tab-separated fields. Appending
        # only an altered text for them would shift every later
        # (original, altered) pair by one position.
        continue
    pairs.append(parts[0])
    choice = random.choice(choices)
    if choice == 1:
        pairs.append(switch_word(parts[0]))
    elif choice == 2:
        pairs.append(switch_char(parts[0]))
    elif choice == 3:
        pairs.append(delete_word(parts[0]))
    elif choice == 4:
        pairs.append(cut(parts[0]))
    elif choice == 5:
        pairs.append(uncapitalize(parts[0]))
    elif choice == 6 and count < LIMIT:
        pairs.append(replace_last_word_with_lm(parts[0], gpt, device=device))
        count += 1
    else:
        # Choice 7, and choice 6 once LIMIT model rewrites have been used.
        pairs.append(switch_punctuation(parts[0]))

# Every original must be followed by exactly one altered version.
assert len(pairs) % 2 == 0, "pairs must alternate original/altered"
             

  ckpt = torch.load("out_lm/ckpt.pt", map_location=device)


number of parameters: 85.15M


In [52]:
# One line per pair: original, a tab, the altered text, then a newline.
with open(out_path, "wb") as f:
    # Even slots hold originals and odd slots their altered versions; zip stops
    # at the shorter slice, so a dangling unpaired last entry is not written.
    for orig, altered in zip(pairs[0::2], pairs[1::2]):
        f.write(b"\t".join((orig, altered)) + b"\n")

In [53]:
# Build labelled rows: with label 1 the original is in column A, with label 0
# it is in column B. Which side it lands on is a fair coin flip.
rows = []
# Stop at len(pairs) - 1 so an unpaired trailing entry is ignored instead of
# raising IndexError on pairs[i + 1].
for i in range(0, len(pairs) - 1, 2):
    if random.random() < 0.5:
        label, A, B = 1, pairs[i], pairs[i + 1]
    else:
        label, A, B = 0, pairs[i + 1], pairs[i]
    rows.append((label, A, B))

random.shuffle(rows)

# Use a dedicated name instead of rebinding `out_path`: the cell that writes
# the pair file reads `out_path`, and re-running it after this cell would
# otherwise overwrite this file.
val_out_path = "val_part2.tsv"
with open(val_out_path, "wb") as f:
    for lbl, A, B in rows:
        f.write(f"{lbl}\t".encode("ascii") + A + b"\t" + B + b"\n")

In [None]:
# Sanity check for eval on identical pairs: both columns of every row hold the
# same original text and the label is a coin flip.
rows = []
for i in range(0, len(pairs), 2):
    orig = pairs[i]
    label = random.randint(0, 1)
    # A and B are the same text whatever the label is, so no branching needed.
    rows.append((label, orig, orig))

# Dedicated name so the shared `out_path` keeps pointing at the pair file.
identical_out_path = "identical_pairs.tsv"
with open(identical_out_path, "wb") as f:
    for lbl, A, B in rows:
        f.write(f"{lbl}\t".encode("ascii") + A + b"\t" + B + b"\n")
