# Testing Llama 3.1 8B for n2c2 2018 data leakage (manuscript appendix D)

In [None]:
# Load libraries
import random, torch, difflib, numpy as np
from unsloth import FastLanguageModel
from datasets import load_dataset
import tqdm

In [None]:
# Folder holding the local copy of the n2c2 2018 Track 2 data; it is passed as
# `data_dir` to the BigBio loader below. Set the N2C2_2018_TRACK2_DIR environment
# variable to point at a different copy (e.g. when re-running on another machine);
# otherwise the original project path is used.
import os

data_dir = os.environ.get(
    "N2C2_2018_TRACK2_DIR",
    "/prj/doctoral_letters/notebooks/MIEQA/i2b22018/n2c2_2018_track2/",
)

In [None]:
# Run configuration
DEVICE = 'cuda'    # device the tokenised prompts are moved to before generation
SAMPLE_SIZE = 20   # number of notes probed for memorisation
SEED = 42

# Seed Python's and PyTorch's global RNGs.
random.seed(SEED)
torch.manual_seed(SEED)

# Load Llama 3.1 8B (the non-Instruct checkpoint) through Unsloth without 4-bit
# quantisation; dtype=None is presumably auto-selected by Unsloth (confirm).
# The model is then switched to Unsloth's inference mode.
MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B"
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=2048,
    load_in_4bit=False,
    dtype=None,
)
FastLanguageModel.for_inference(model)

In [None]:
# Load the training split of n2c2 2018 Track 2 (source schema) via the BigBio
# loader, reading the locally stored files from `data_dir`. Only this split is
# sampled below.
# NOTE(review): BigBio loaders are script-based; recent `datasets` releases may
# require trust_remote_code=True or no longer support loading scripts — confirm
# the pinned `datasets` version.
dataset = load_dataset(
    "bigbio/n2c2_2018_track2",
    name="n2c2_2018_track2_source",
    data_dir=data_dir,
    split="train",
)

In [None]:
# Collect the raw note texts and draw SAMPLE_SIZE of them without replacement.
# A dedicated RNG seeded with SEED is used instead of the global `random` state:
# the global state was seeded before model and dataset loading ran, so any
# library use of it in between would silently change which notes are drawn.
# When nothing consumed the global state, this selects exactly the same notes as
# `random.seed(SEED); random.sample(notes, k=SAMPLE_SIZE)`.
notes = [record["text"] for record in dataset]
sampled_notes = random.Random(SEED).sample(notes, k=SAMPLE_SIZE)

In [None]:
# Helper functions
def slice_mid(tokens, keep=50):
    """Return the window of ``tokens`` centred on its middle element.

    The window spans ``[len(tokens)//2 - keep//2, len(tokens)//2 + keep//2)``,
    so it holds ``keep`` items for even ``keep`` and ``keep - 1`` for odd
    ``keep``. The end index is the offset ``prepare_prompt`` uses as the start
    of the gold continuation, so the two must stay in sync.

    Args:
        tokens: Any sliceable sequence (here, whitespace-split words).
        keep: Requested window width.

    Returns:
        A slice of ``tokens``; shorter than requested when ``tokens`` itself is
        short (empty input yields an empty slice).
    """
    mid = len(tokens) // 2
    # Clamp at 0: for sequences shorter than ``keep`` the raw start is negative,
    # and Python slicing would treat it as an offset from the END, returning a
    # tail fragment instead of the middle window.
    start = max(0, mid - keep // 2)
    return tokens[start : mid + keep // 2]

def prepare_prompt(text, slice_len=50):
    """Split a note into a mid-document prompt and the text that follows it.

    The note is split on whitespace. The prompt is the middle window chosen by
    ``slice_mid`` (``slice_len`` words, ``slice_len - 1`` for odd values), joined
    with single spaces and terminated by a newline. The gold continuation is
    every word after that window, joined with single spaces.

    NOTE(review): whitespace splitting drops the note's original line breaks
    from both prompt and gold, and the trailing "\n" is not taken from the
    note. Either may make verbatim recall harder to trigger — confirm this is
    the intended protocol.

    Args:
        text: Full note text.
        slice_len: Width of the prompt window, in whitespace-delimited words.

    Returns:
        ``(prompt, gold)`` tuple of strings.
    """
    words = text.split()
    # First word after the prompt window; equals the end index used by slice_mid.
    gold_start = len(words) // 2 + slice_len // 2
    window = slice_mid(words, slice_len)
    prompt = f'{" ".join(window)}\n'
    return prompt, " ".join(words[gold_start:])

@torch.inference_mode()
def greedy_generate(prompt, max_new=256):
    """Decode a continuation of ``prompt`` with the loaded model, without sampling.

    Relies on the module-level ``tokenizer``, ``model`` and ``DEVICE``. Runs
    under ``torch.inference_mode`` so no autograd state is recorded.

    Args:
        prompt: Text to continue.
        max_new: Maximum number of newly generated tokens (model tokens).

    Returns:
        Only the newly generated text; the prompt tokens are cut off and special
        tokens are skipped during decoding.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    n_prompt = encoded["input_ids"].shape[-1]
    sequences = model.generate(
        **encoded,
        max_new_tokens=max_new,
        do_sample=False,  # greedy, assuming the generation config sets no beam search
    )
    new_tokens = sequences[0, n_prompt:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

def longest_exact_match(a, b, min_run=20):
    """Return the length of the longest run of consecutive words shared by ``a`` and ``b``.

    Both strings are split on whitespace and compared word by word.

    Args:
        a: First text (the caller passes the model's generated continuation).
        b: Second text (the caller passes the gold continuation).
        min_run: Shortest run worth reporting.

    Returns:
        Length, in words, of the longest common contiguous run, or 0 when that
        run is shorter than ``min_run`` (also 0 if either text is empty).
    """
    a_words, b_words = a.split(), b.split()
    # autojunk=False: with the default heuristic, when ``b`` has >= 200 items,
    # SequenceMatcher discards words occurring in more than ~1% of it (e.g.
    # "the", "and", "mg" in a long note). Genuine verbatim runs made of common
    # words can then be missed or shortened, undercounting memorisation.
    matcher = difflib.SequenceMatcher(None, a_words, b_words, autojunk=False)
    # Over the full ranges this is the longest common contiguous block, i.e. the
    # largest block get_matching_blocks() would report.
    size = matcher.find_longest_match(0, len(a_words), 0, len(b_words)).size
    return size if size >= min_run else 0

In [None]:
# Memorisation test + debug prints.
# For each sampled note, prompt the model with the note's middle window and
# count a hit when the greedy continuation shares a verbatim run of at least
# MIN_RUN consecutive words with the note's actual continuation.
# "token" in the printed labels means a whitespace-delimited word, since
# prepare_prompt and longest_exact_match both split on whitespace.
DEBUG_N = 5    # print prompt/continuation previews for the first DEBUG_N samples
MIN_RUN = 20   # verbatim-run length (in words) that counts as a hit

hits = 0
longest_runs = []  # per-sample result of longest_exact_match, kept for inspection
# Wrap the list itself (not the enumerate object) so tqdm knows the total length.
for idx, note in enumerate(tqdm.tqdm(sampled_notes)):
    prompt, gold = prepare_prompt(note)
    gen = greedy_generate(prompt)

    if idx < DEBUG_N:
        print("="*70)
        print(f"SAMPLE {idx+1}/{SAMPLE_SIZE}")
        print("- PROMPT (50-token slice) -")
        print(prompt.strip())
        print("- GENERATED CONTINUATION (first 20 chars) -")
        print(gen.strip()[:20])
        print("- EXPECTED GOLD CONTINUATION (first 20 chars) -")
        print(gold.strip()[:20])
        print()

    run = longest_exact_match(gen, gold, min_run=MIN_RUN)
    longest_runs.append(run)
    if run >= MIN_RUN:
        hits += 1

print("="*70)
print(f"Verbatim ≥{MIN_RUN}-token continuations: {hits}/{SAMPLE_SIZE}")