In [5]:
import os
from pathlib import Path
import pandas as pd

# ====== CONFIG ======
# Directory holding the gzipped OAS dumps. It can be overridden with the
# OAS_DATA_DIR environment variable so the notebook is not tied to a single
# machine's filesystem layout; the default is the original location.
in_dir = Path(os.environ.get("OAS_DATA_DIR", "/home/ec2-user/SageMaker/InterPLM/oas_data"))
pattern = "*.csv.gz"   # glob applied inside in_dir; adjust if needed
out_path = Path("all_paired_antibodies_compact.csv")  # combined output table

# Columns we want to extract: output column name -> column name in the OAS table
wanted_cols = {
    "VH_full": "sequence_alignment_aa_heavy",
    "VL_full": "sequence_alignment_aa_light",
    "CDRH1": "cdr1_aa_heavy",
    "CDRH2": "cdr2_aa_heavy",
    "CDRH3": "cdr3_aa_heavy",
    "CDRL1": "cdr1_aa_light",
    "CDRL2": "cdr2_aa_light",
    "CDRL3": "cdr3_aa_light",
}

# ====== HELPERS ======
def load_oas_table(fp: Path) -> pd.DataFrame:
    """
    Load a gzipped OAS file that has a JSON metadata line first, then a table.

    The first line is skipped and the second line is used as the header.
    Every column is read as ``str``. Delimiters are tried in this order:

    1. tab-separated;
    2. if that yields a single column, delimiter sniffing (Python engine);
    3. if parsing or sniffing fails, a plain comma-separated read.

    Parameters
    ----------
    fp : Path
        Path to the gzip-compressed table.

    Returns
    -------
    pd.DataFrame
        The parsed table with all values as strings (missing values as NaN).

    Raises
    ------
    Any non-parsing error (missing file, corrupt gzip, ...) propagates
    unchanged instead of triggering another read of the same file. Errors
    from the final comma-separated attempt also propagate.
    """
    import csv  # local import: only needed to recognise delimiter-sniffing failures

    # Options shared by every attempt: gzip input, skip the JSON metadata row,
    # use the next row as header, keep everything as strings.
    read_kwargs = dict(compression="gzip", header=0, skiprows=1, dtype=str)

    try:
        df = pd.read_csv(fp, sep="\t", **read_kwargs)
        if df.shape[1] == 1:
            # A single column means the file is not tab-delimited; let the
            # Python engine sniff the delimiter instead.
            df = pd.read_csv(fp, sep=None, engine="python", **read_kwargs)
        return df
    except (pd.errors.ParserError, csv.Error):
        # Only parsing/sniffing failures are worth a retry; as a last resort
        # read the file as plain comma-separated values.
        return pd.read_csv(fp, **read_kwargs)

def extract_compact(df: pd.DataFrame, source_name: str) -> pd.DataFrame:
    """
    Reduce a raw table to the columns listed in `wanted_cols`.

    Each output column is named after a key of `wanted_cols` and filled from
    the input column given by the matching value. A leading 'source_file'
    column holding `source_name` is added for traceability, and rows whose
    'VH_full' or 'VL_full' is NaN are removed.

    Raises KeyError when any required input column is absent from `df`.
    """
    # Validate up front so the error names every absent column at once.
    missing = []
    for required in wanted_cols.values():
        if required not in df.columns:
            missing.append(required)
    if missing:
        raise KeyError(f"Missing required columns: {missing}")

    # Copy the selected columns under their compact names.
    selected = {}
    for compact_name, raw_name in wanted_cols.items():
        selected[compact_name] = df[raw_name]
    compact = pd.DataFrame(selected)
    compact.insert(0, "source_file", source_name)

    # Keep only rows where both chains are present (relax this if needed).
    return compact.dropna(subset=["VH_full", "VL_full"])

# ====== MAIN ======
# Walk every file in `in_dir` matching `pattern`, reduce each one to the
# compact column set, and combine the per-file results into one CSV.
# Failures are recorded and reported rather than aborting the whole run.
all_rows = []          # one compact DataFrame per successfully processed file
broken = []            # (file name, error message) for every file that failed
total_sequences = 0    # running count of rows kept across all files
files_seen = 0
files_ok = 0

# Sorted so the processing (and row) order is stable between runs.
for fp in sorted(in_dir.glob(pattern)):
    files_seen += 1
    try:
        # NOTE(review): `df` stays bound after the loop to the table loaded in
        # the last iteration that reached this line; later cells reference `df`
        # without defining it, presumably relying on this — confirm.
        df = load_oas_table(fp)
        compact = extract_compact(df, fp.name)
        all_rows.append(compact)
        n = len(compact)
        total_sequences += n
        files_ok += 1
        print(f"[OK] {fp.name}: {n} sequences")
    except Exception as e:
        # Deliberate best-effort: any error in one file is logged and the loop
        # moves on to the next file.
        broken.append((fp.name, str(e)))
        print(f"[FAIL] {fp.name}: {e}")

# Concatenate and save if we got anything.
# `final_df` is only defined when at least one file succeeded.
if all_rows:
    final_df = pd.concat(all_rows, ignore_index=True)
    final_df.to_csv(out_path, index=False)
    print(f"\nSaved {len(final_df)} sequences from {files_ok}/{files_seen} files to {out_path}")
else:
    print(f"\nNo sequences extracted. {files_ok}/{files_seen} files succeeded.")

# Report failures (file name and the stringified exception)
if broken:
    print("\nBroken files:")
    for name, msg in broken:
        print(f" - {name} :: {msg}")

# Summary counts
print("\n=== Summary ===")
print(f"Files seen:        {files_seen}")
print(f"Files succeeded:   {files_ok}")
print(f"Files failed:      {len(broken)}")
print(f"Total sequences:   {total_sequences}")


[OK] 1279049_1_Paired_All.csv.gz: 8954 sequences
[OK] 1279050_1_Paired_All.csv.gz: 15196 sequences
[OK] 1279051_1_Paired_All.csv.gz: 11508 sequences
[OK] 1279052_1_Paired_All.csv.gz: 1112 sequences
[OK] 1279053_1_Paired_All.csv.gz: 10175 sequences
[OK] 1279054_1_Paired_All.csv.gz: 9723 sequences
[OK] 1279055_1_Paired_All.csv.gz: 5498 sequences
[OK] 1279057_1_Paired_All.csv.gz: 9463 sequences
[OK] 1279058_1_Paired_All.csv.gz: 14965 sequences
[OK] 1279059_1_Paired_All.csv.gz: 13651 sequences
[OK] 1279060_1_Paired_All.csv.gz: 1275 sequences
[OK] 1279061_1_Paired_All.csv.gz: 10493 sequences
[OK] 1279062_1_Paired_All.csv.gz: 10486 sequences
[OK] 1279063_1_Paired_All.csv.gz: 4661 sequences
[OK] 1279065_1_Paired_All.csv.gz: 8369 sequences
[OK] 1279066_1_Paired_All.csv.gz: 12381 sequences
[OK] 1279067_1_Paired_All.csv.gz: 8437 sequences
[OK] 1279068_1_Paired_All.csv.gz: 843 sequences
[OK] 1279069_1_Paired_All.csv.gz: 12228 sequences
[OK] 1279070_1_Paired_All.csv.gz: 11482 sequences
[OK] 127907

In [7]:
# Count missing VH or VL in the combined table.
# A value counts as missing when it is NaN or an empty string. Every check
# reads `final_df`; the earlier version mixed in `out_df`, a name that is not
# defined anywhere in this notebook and therefore fails on a fresh kernel.
vh_missing = final_df["VH_full"].isna() | final_df["VH_full"].eq("")
vl_missing = final_df["VL_full"].isna() | final_df["VL_full"].eq("")

n_missing_vh = vh_missing.sum()
n_missing_vl = vl_missing.sum()

print(f"Missing VH: {n_missing_vh}")
print(f"Missing VL: {n_missing_vl}")

# Rows where either VH or VL is missing, using the same definition of
# "missing" as the counts above (NaN or empty string).
incomplete = final_df[vh_missing | vl_missing]
print(f"Incomplete pairs: {len(incomplete)} out of {len(final_df)} total")


Missing VH: 0
Missing VL: 0
Incomplete pairs: 0 out of 3003127 total


In [3]:
import gzip
import pandas as pd
from pathlib import Path

# ====== CONFIG ======
# Input table: the compact paired-antibody CSV (from your earlier script).
CSV_PATH = Path("all_paired_antibodies_compact.csv")

# Columns of that table holding the heavy- and light-chain sequences.
VH_COL = "VH_full"
VL_COL = "VL_full"

# Number of shard files to write, and the length at which clean_seq
# truncates each sequence.
NUM_SHARDS = 8
MAX_LEN = 1024

# Directory that receives the shard files; created (with parents) right away
# and left alone if it already exists.
OUT_DIR = Path("data") / "sharded_oas"
OUT_DIR.mkdir(parents=True, exist_ok=True)

def clean_seq(s: str) -> str:
    """
    Normalise one sequence string.

    The input is uppercased, every '-' and '*' character is removed, and the
    result is cut to at most MAX_LEN characters. Anything that is not a `str`
    (e.g. NaN or None) yields an empty string.
    """
    if isinstance(s, str):
        cleaned = s.upper()
        for unwanted in ("-", "*"):
            cleaned = cleaned.replace(unwanted, "")
        return cleaned[:MAX_LEN]
    return ""

def csv_to_8shards():
    """
    Split the paired VH/VL table at CSV_PATH into NUM_SHARDS FASTA shards.

    Steps:
      1. Read only the VH_COL and VL_COL columns as strings and drop rows
         where either is NaN.
      2. Clean both columns with `clean_seq` and drop pairs where either
         cleaned sequence is empty.
      3. Cut the remaining rows into NUM_SHARDS contiguous slices of
         ``n // NUM_SHARDS`` rows; the last slice also takes the remainder.
      4. Write each slice to ``OUT_DIR / pairs_shard_{i}.fa`` with two records
         per pair, ``>{pid}|H`` and ``>{pid}|L``, where ``pid`` is ``p``
         followed by the pair's 1-based position in the cleaned table,
         zero-padded to 8 digits.

    Prints the total pair count and one "[done]" line per shard.
    Returns None.

    NOTE(review): the shard files are written through ``gzip.open`` although
    their names end in ``.fa`` rather than ``.fa.gz`` — confirm downstream
    readers expect gzip-compressed content under that name.
    """
    pairs = pd.read_csv(CSV_PATH, usecols=[VH_COL, VL_COL], dtype=str)
    pairs = pairs.dropna(subset=[VH_COL, VL_COL])

    # Clean both chains, then keep only pairs where both are still non-empty.
    for chain_col in (VH_COL, VL_COL):
        pairs[chain_col] = pairs[chain_col].map(clean_seq)
    heavy_ok = pairs[VH_COL].str.len() > 0
    light_ok = pairs[VL_COL].str.len() > 0
    pairs = pairs[heavy_ok & light_ok]

    total = len(pairs)
    per_shard = total // NUM_SHARDS
    print(f"Total pairs: {total}, ≈{per_shard} per shard")

    last_shard = NUM_SHARDS - 1
    for shard_idx in range(NUM_SHARDS):
        lo = shard_idx * per_shard
        # The final shard absorbs whatever is left after integer division.
        hi = total if shard_idx == last_shard else lo + per_shard
        heavy = pairs[VH_COL].iloc[lo:hi]
        light = pairs[VL_COL].iloc[lo:hi]

        shard_path = OUT_DIR / f"pairs_shard_{shard_idx}.fa"
        with gzip.open(shard_path, "wt") as fh:
            for offset, (vh, vl) in enumerate(zip(heavy, light)):
                # Pair ids are global (continue across shards) and 1-based.
                pid = f"p{lo + 1 + offset:08d}"
                fh.write(f">{pid}|H\n{vh}\n>{pid}|L\n{vl}\n")
        print(f"[done] {shard_path} ({hi - lo} pairs)")


## Sharding OAS data into FASTA shards for SAE training

In [4]:
# Run the sharding step: reads CSV_PATH and writes NUM_SHARDS shard files into OUT_DIR.
# NOTE(review): the recorded output below lists pairs_shard_1..8, while the
# function as written names its files pairs_shard_0..7 — the output looks
# stale; re-run this cell to confirm.
csv_to_8shards()

Total pairs: 3003127, ≈375390 per shard
[done] data/sharded_oas/pairs_shard_1.fa (375390 pairs)
[done] data/sharded_oas/pairs_shard_2.fa (375390 pairs)
[done] data/sharded_oas/pairs_shard_3.fa (375390 pairs)
[done] data/sharded_oas/pairs_shard_4.fa (375390 pairs)
[done] data/sharded_oas/pairs_shard_5.fa (375390 pairs)
[done] data/sharded_oas/pairs_shard_6.fa (375390 pairs)
[done] data/sharded_oas/pairs_shard_7.fa (375390 pairs)
[done] data/sharded_oas/pairs_shard_8.fa (375397 pairs)


In [6]:
def build_variable_domain(df, prefix="heavy"):
    """
    Concatenate a chain's framework and CDR segments into one string per row.

    The columns ``{region}_aa_{prefix}`` are joined in the order
    fwr1, cdr1, fwr2, cdr2, fwr3, cdr3, fwr4. Missing values are treated as
    empty strings.

    Parameters
    ----------
    df : DataFrame containing the seven region columns for `prefix`.
    prefix : chain suffix used in the column names (default "heavy").

    Returns
    -------
    Series with one concatenated sequence per row of `df`.
    """
    region_order = ("fwr1", "cdr1", "fwr2", "cdr2", "fwr3", "cdr3", "fwr4")
    region_cols = [f"{region}_aa_{prefix}" for region in region_order]

    # Blank out missing segments so the join never sees a non-string value.
    regions = df[region_cols].fillna("")
    return regions.agg("".join, axis=1)

# Rebuild each chain's variable domain from its seven annotated segments and
# store the result as new columns on `df` (this mutates `df` in place).
# NOTE(review): `df` is not defined in this cell; presumably it is the raw OAS
# table left bound by the file-loading loop — confirm, since this cell fails
# on a fresh kernel if nothing has defined `df` first.
df["VH_clean"] = build_variable_domain(df, "heavy")
df["VL_clean"] = build_variable_domain(df, "light")

In [8]:
# Show the heavy-chain aligned amino-acid sequence of the first row, for comparison with the rebuilt VH_clean value.
df.iloc[0]["sequence_alignment_aa_heavy"]

'QLQLQESGPGLVKPSETLSLTCTVSGGSISSSSYYWGWIRQPPGKGLEWIGSIYYSGSTYYNPSLKSRVTISVDTSKNQFSLKLSSVTAADTAVYYCARGAAAAGTLPLDYWGQGTLVTVSS'

In [7]:
# Show the first row's VH_clean value (the concatenation of its heavy-chain FWR/CDR segments).
df.iloc[0]["VH_clean"]

'QLQLQESGPGLVKPSETLSLTCTVSGGSISSSSYYWGWIRQPPGKGLEWIGSIYYSGSTYYNPSLKSRVTISVDTSKNQFSLKLSSVTAADTAVYYCARGAAAAGTLPLDYWGQGTLVTVSS'

In [10]:
def print_antibody_segments(df, row_idx=0):
    """Pretty-print FWR/CDR segments for heavy and light chains of one antibody.

    Parameters
    ----------
    df : DataFrame
        Table with segment columns named ``{segment}_aa_{chain}``, where
        segment is one of fwr1, cdr1, fwr2, cdr2, fwr3, cdr3, fwr4 and chain
        is ``heavy`` or ``light``. Segment columns that are absent are skipped.
    row_idx : int, default 0
        Zero-based *position* of the row to show (0 is the first row). The
        lookup is positional, so it works on frames whose index is not
        0..n-1, for example after filtering rows. The earlier label-based
        lookup failed or showed the wrong row on such frames.

    Returns
    -------
    None. One header line per chain and one line per available segment are
    written to stdout; missing values print as an empty string.

    Raises
    ------
    IndexError
        If `row_idx` is outside the frame and a segment column is present.
    """
    segments = ("fwr1", "cdr1", "fwr2", "cdr2", "fwr3", "cdr3", "fwr4")

    def show_chain(prefix, label):
        print(f"\n=== {label} Chain ({prefix}) ===")
        for seg in segments:
            col = f"{seg}_aa_{prefix}"
            if col in df.columns:
                # Positional access: independent of the frame's index labels.
                seq = df[col].iloc[row_idx]
                print(f"{seg.upper():<5}: {seq if pd.notna(seq) else ''}")

    show_chain("heavy", "Heavy (VH)")
    show_chain("light", "Light (VL)")


# Example usage: print the heavy- and light-chain FWR/CDR segments of the first row of `df`.
print_antibody_segments(df, row_idx=0)



=== Heavy (VH) Chain (heavy) ===
FWR1 : QLQLQESGPGLVKPSETLSLTCTVS
CDR1 : GGSISSSSYY
FWR2 : WGWIRQPPGKGLEWIGS
CDR2 : IYYSGST
FWR3 : YYNPSLKSRVTISVDTSKNQFSLKLSSVTAADTAVYYC
CDR3 : ARGAAAAGTLPLDY
FWR4 : WGQGTLVTVSS

=== Light (VL) Chain (light) ===
FWR1 : EIVLTQSPATLSLSPGERATLSCRAS
CDR1 : QSVSSY
FWR2 : LAWYQQKPGQAPRLLIY
CDR2 : DAS
FWR3 : NRATGIPARFSGSGSGTDFTLTISSLEPEDFAVYYC
CDR3 : QQRSNWPPSLT
FWR4 : FGGGTKVEIK
