"""
fine_tune_gpt2_lora.py

Fine-tune GPT-2 Small on Enron email–reply pairs with LoRA adapters,
using either full-thread or subject+last-email prompts and tone conditioning.
"""
TODO: Test the pipeline end to end, from the local environment to GitHub.

In [None]:
pip install transformers datasets peft accelerate

Defaulting to user installation because normal site-packages is not writeable
Collecting peft
  Downloading peft-0.16.0-py3-none-any.whl.metadata (14 kB)
Collecting accelerate
  Downloading accelerate-1.9.0-py3-none-any.whl.metadata (19 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.13.0->peft)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.13.0->peft)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.13.0->peft)
  Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.13.0->peft)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.13.0->peft)
  Downloading nvidia_cublas_cu1

In [None]:
import argparse
import logging
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, TaskType

Data Preparation

In [None]:
# data_prep_step1.py

import pandas as pd
import json
import re
import warnings

def extract_tag_index(tag: str) -> int:
    """
    Map a message tag to an integer index used for ordering:
      "<|original|>" → 0
      "<|replyN|>"   → N (N parsed as a decimal integer)
    Anything else → 9999 (sorts last).

    Non-string values (e.g. None, or the float NaN that pandas uses for an
    empty cell) count as "anything else" and return 9999 rather than
    raising TypeError from the regex call.
    """
    # Guard first: re.match() only accepts str/bytes.
    if not isinstance(tag, str):
        return 9999
    if tag == "<|original|>":
        return 0
    # re.match anchors at the start only, so text after the tag is tolerated.
    m = re.match(r"<\|reply(\d+)\|>", tag)
    if m:
        return int(m.group(1))
    return 9999

def _cell_to_text(value) -> str:
    """Return a CSV cell as text: str unchanged, missing → "", else str(value)."""
    if isinstance(value, str):
        return value
    if pd.isna(value):
        return ""
    return str(value)


def prepare_pairs(input_csv: str, output_jsonl: str) -> None:
    """
    Build (context, reply) examples from a CSV of tagged messages and write
    them to a JSON Lines file.

    The CSV must provide the columns "message_id", "tag" and
    "clean_message"; a "subject" column is optional and is created empty
    when absent. Rows are grouped by "message_id", ordered inside each group
    by extract_tag_index(tag), and one example is emitted for every message
    after the first in its group.

    Each output line is a JSON object with the keys:
      "thread"  – all earlier messages of the group, joined by single spaces
      "subject" – subject of the group's "<|original|>" row, or ""
      "email"   – the message immediately preceding the reply
      "reply"   – the reply itself
      "tone"    – the fixed placeholder "[formal]"

    Missing or non-string message/subject cells are converted to text
    (missing → "") so that joining and JSON serialisation cannot fail or
    emit the non-standard token NaN.

    Args:
        input_csv: Path of the CSV file to read.
        output_jsonl: Path of the JSONL file to write (overwritten).
    """
    # 1) Load the CSV, skipping malformed lines (parser warnings are silenced)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        df = pd.read_csv(
            input_csv,
            engine='python',
            on_bad_lines='skip'
        )

    # 2) Ensure there's a 'subject' column (blank for now)
    if 'subject' not in df.columns:
        df['subject'] = ""

    examples = []

    # 3) Process each thread; sort=False keeps groups in file order
    for _thread_id, group in df.groupby("message_id", sort=False):
        grp = group.copy()
        # 4) Sort messages by tag index. A stable sort keeps rows that share
        #    an index (duplicate or unrecognised tags) in their file order;
        #    the default quicksort gives no such guarantee.
        grp['idx'] = grp['tag'].apply(extract_tag_index)
        grp = grp.sort_values('idx', kind='stable')

        msgs = [_cell_to_text(m) for m in grp['clean_message']]
        # 5a) Grab the subject from the original message row
        subject_rows = grp.loc[grp['idx'] == 0, 'subject']
        subject = _cell_to_text(subject_rows.iloc[0]) if not subject_rows.empty else ""

        # 5b) For each reply position i (i >= 1), build one example
        for i in range(1, len(msgs)):
            examples.append({
                "thread": " ".join(msgs[:i]),  # all messages before reply i
                "subject": subject,            # blank or to-be-filled later
                "email": msgs[i-1],            # immediate predecessor
                "reply": msgs[i],              # this reply
                "tone": "[formal]"             # placeholder tone
            })

    # 6) Write examples to JSONL (one JSON object per line)
    with open(output_jsonl, 'w', encoding='utf-8') as fout:
        for ex in examples:
            fout.write(json.dumps(ex, ensure_ascii=False) + "\n")

if __name__ == "__main__":
    # Script entry point: convert the cleaned CSV into JSONL training pairs
    # and report how many examples were written.
    input_csv    = "enron_cleaned_v7.csv"
    output_jsonl = "enron_pairs.jsonl"

    print(f"Reading {input_csv}…")
    prepare_pairs(input_csv, output_jsonl)
    # Count output lines inside a context manager so the file handle is
    # closed deterministically instead of being left to the garbage collector.
    with open(output_jsonl, encoding='utf-8') as fin:
        count = sum(1 for _ in fin)
    print(f"Wrote {count} examples to {output_jsonl}.")


Reading enron_cleaned_v7.csv…
Wrote 44847 examples to enron_pairs.jsonl.
