# Extract LaTa FF1 post-activation embeddings

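This notebook extracts the encoder FF1 post-activation representations (the tensor fed into each feed-forward block's `wo` projection) for encoder layers 1..12, mean-pools them over non-padding tokens, and saves both the raw and the L2-normalized embeddings for every layer.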


In [None]:
from pathlib import Path
import os
import sys


def in_colab() -> bool:
    """Report whether this notebook is running inside Google Colab.

    Detection relies only on the presence of the ``COLAB_GPU`` or
    ``COLAB_RELEASE_TAG`` environment variables; their values are ignored.
    """
    colab_markers = ("COLAB_GPU", "COLAB_RELEASE_TAG")
    return any(marker in os.environ for marker in colab_markers)


if in_colab():
    from google.colab import drive

    drive.mount("/content/drive")

    REPO_URL = os.environ.get("REPO_URL", "https://github.com/ianrowe12/localLatin.git")
    REPO_ROOT = Path(os.environ.get("REPO_ROOT", "/content/localLatin"))
    if not REPO_ROOT.exists():
        !git clone {REPO_URL}

    CANON_ROOT = Path(os.environ.get("CANON_ROOT", "/content/drive/MyDrive/localLatin_data/canon"))
    RUNS_ROOT = Path(os.environ.get("RUNS_ROOT", "/content/drive/MyDrive/localLatin_runs/ff1_lata_postact"))
else:
    REPO_ROOT = Path(os.environ.get("REPO_ROOT", "/Users/ianrowe/git/localLatin"))
    CANON_ROOT = Path(os.environ.get("CANON_ROOT", str(REPO_ROOT / "canon")))
    RUNS_ROOT = Path(os.environ.get("RUNS_ROOT", str(REPO_ROOT / "runs" / "ff1_lata_postact")))

sys.path.append(str(REPO_ROOT / "src"))

print(f"REPO_ROOT: {REPO_ROOT}")
print(f"CANON_ROOT: {CANON_ROOT}")
print(f"RUNS_ROOT: {RUNS_ROOT}")


In [None]:
from __future__ import annotations

from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from canon_retrieval import l2_normalize, load_texts, save_json

# Extraction settings.
META_CSV = str(RUNS_ROOT / "meta.csv")
MODEL_NAME = "bowphs/LaTa"
MAX_LENGTH = 512  # passed to the tokenizer as the truncation length
BATCH_SIZE = 12  # number of documents per encoder forward pass

# Load the corpus metadata before creating anything on disk: if meta.csv is
# missing or has no "path" column, the cell fails here without leaving an
# empty timestamped run directory behind.
meta = pd.read_csv(META_CSV)
paths = meta["path"].tolist()

# Each execution of this cell gets its own timestamped output directory.
run_id = datetime.now().strftime("run_%Y%m%d_%H%M%S")
RUN_DIR = str(RUNS_ROOT / run_id)
Path(RUN_DIR).mkdir(parents=True, exist_ok=True)

print(f"Run dir: {RUN_DIR}")
print(f"Files: {len(paths)}")


In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print(f"Device: {device}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

# Only the encoder is moved to the device and run by the cells below.
if hasattr(model, "get_encoder"):
    encoder = model.get_encoder()
else:
    encoder = model.encoder
encoder.to(device)
encoder.eval()

# Snapshot of the settings used for this run; layers are numbered from 1.
config = dict(
    model_name=MODEL_NAME,
    max_length=MAX_LENGTH,
    batch_size=BATCH_SIZE,
    hook_point="ff1_post_activation_input_to_wo",
    layers=list(range(1, len(encoder.block) + 1)),
    timestamp=run_id,
)

# Store the settings and a copy of the metadata next to the run's outputs.
save_json(f"{RUN_DIR}/config.json", config)
meta.to_csv(f"{RUN_DIR}/meta.csv", index=False)


In [None]:
def get_ffn_wo(layer_module: torch.nn.Module) -> torch.nn.Module:
    """Locate the ``wo`` sub-module of a feed-forward layer.

    The lookup tries, in order, ``layer_module.DenseReluDense.wo``,
    ``layer_module.DenseGatedGeluDense.wo`` and ``layer_module.wo``.

    Raises:
        AttributeError: if none of the three locations exists.
    """
    # Containers that may wrap the projection, checked in priority order.
    for container_name in ("DenseReluDense", "DenseGatedGeluDense"):
        if hasattr(layer_module, container_name):
            return getattr(layer_module, container_name).wo

    # Fall back to a projection attached directly to the layer module.
    if hasattr(layer_module, "wo"):
        return layer_module.wo

    raise AttributeError("Could not find FFN wo module")


@torch.no_grad()
def extract_layer_embeddings(layer_idx: int) -> np.ndarray:
    """Return one mean-pooled FF1 post-activation vector per document.

    A forward pre-hook on the feed-forward ``wo`` module of encoder block
    ``layer_idx`` captures the tensor passed into ``wo``. The documents listed
    in ``paths`` are run through the encoder in batches of ``BATCH_SIZE``, and
    each captured tensor is averaged over the positions where the attention
    mask is non-zero.

    Args:
        layer_idx: Zero-based index into ``encoder.block``.

    Returns:
        A float32 array holding the pooled vectors of all batches concatenated
        along axis 0, in the order of ``paths``.

    Raises:
        RuntimeError: if the hook does not fire exactly once in a forward pass.
        AttributeError: if no ``wo`` module can be found in the layer.
    """
    # NOTE(review): assumes sub-layer index 1 of a block is its feed-forward
    # layer (T5-style encoder layout) -- confirm for other architectures.
    ffn_layer = encoder.block[layer_idx].layer[1]
    wo = get_ffn_wo(ffn_layer)
    captured: list[torch.Tensor] = []

    def hook(module, inputs):
        # A forward pre-hook receives the positional inputs as a tuple; the
        # first entry is the tensor about to be fed into `wo`.
        captured.append(inputs[0])

    handle = wo.register_forward_pre_hook(hook)

    embeddings = []
    try:
        for start in range(0, len(paths), BATCH_SIZE):
            batch_paths = paths[start : start + BATCH_SIZE]
            batch_texts = load_texts(batch_paths)
            enc = tokenizer(
                batch_texts,
                truncation=True,
                max_length=MAX_LENGTH,
                padding=True,
                return_tensors="pt",
            )
            input_ids = enc["input_ids"].to(device)
            attention_mask = enc["attention_mask"].to(device)

            # Drop the previous batch's capture so the count check below
            # reflects only this forward pass.
            captured.clear()
            _ = encoder(input_ids=input_ids, attention_mask=attention_mask)
            if len(captured) != 1:
                raise RuntimeError(f"Expected 1 capture, got {len(captured)}")
            ff_act = captured[0]

            # Masked mean over dim 1 (presumably the sequence axis -- TODO
            # confirm); the clamp guards against division by zero for an
            # all-padding row.
            mask = attention_mask.unsqueeze(-1).float()
            pooled = (ff_act * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
            embeddings.append(pooled.detach().cpu().numpy().astype(np.float32))
    finally:
        # Always detach the hook, even when a batch fails. Otherwise it stays
        # registered on `wo` and keeps the captured tensors alive for later
        # forward passes.
        handle.remove()

    return np.concatenate(embeddings, axis=0)


# Walk the encoder blocks using the 1-based numbering that appears in the
# output file names; the extractor itself takes the 0-based index.
for layer_num in range(1, len(encoder.block) + 1):
    layer_idx = layer_num - 1
    print(f"Extracting layer {layer_num}...")
    emb = extract_layer_embeddings(layer_idx)
    emb_norm = l2_normalize(emb)
    # Persist both the raw pooled embeddings and their normalized counterpart.
    np.save(f"{RUN_DIR}/ff1_layer{layer_num}_embeddings.npy", emb)
    np.save(f"{RUN_DIR}/ff1_layer{layer_num}_embeddings_norm.npy", emb_norm)

print("Done.")


In [None]:
from pathlib import Path
import os
import subprocess

# Optionally upload the finished run through the project's `src.drive_sync`
# module. The sync runs only when a service-account key file is present.
key_path = os.environ.get("DRIVE_SA_KEY_PATH", "/content/sa_drive_key.json")
if Path(key_path).exists():
    # Hand the resolved key location to the child process explicitly, so the
    # default path is honoured even when the variable was not set.
    env = os.environ.copy()
    env["DRIVE_SA_KEY_PATH"] = key_path
    subprocess.run(
        # Use the kernel's own interpreter (`sys` is imported in the first
        # cell). A bare "python" may resolve to a different environment or not
        # exist on PATH at all.
        [sys.executable, "-m", "src.drive_sync", "--local_run_dir", RUN_DIR],
        cwd=str(REPO_ROOT),  # `-m src.drive_sync` is resolved from the repo root
        env=env,
        check=True,  # surface a failed sync as CalledProcessError
    )
    print("Synced run to Drive.")
else:
    print(f"Drive key not found at {key_path}; skipping sync.")
