In [2]:
import os
from collections import defaultdict
from typing import Dict, List, Set, Tuple

import numpy as np
from tqdm.auto import tqdm

import torch
from transformers import AutoTokenizer, AutoModel
from sklearn.neighbors import NearestNeighbors


def parse_train_fasta(path: str) -> Dict[str, str]:
    """Parse train_sequences.fasta -> {EntryID: sequence}."""
    seqs = {}
    with open(path, "r") as f:
        current_id = None
        current_seq_parts = []
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                # save previous
                if current_id is not None:
                    seqs[current_id] = "".join(current_seq_parts)
                # header format: sp|A0A0C5B5G6|MOTSC_HUMAN ...
                header = line[1:]
                parts = header.split("|")
                if len(parts) >= 2:
                    entry_id = parts[1]
                else:
                    entry_id = header.split()[0]
                current_id = entry_id
                current_seq_parts = []
            else:
                current_seq_parts.append(line)
        # last record
        if current_id is not None:
            seqs[current_id] = "".join(current_seq_parts)
    return seqs


def parse_test_fasta(path: str) -> Tuple[Dict[str, str], Dict[str, str]]:
    """Parse testsuperset.fasta -> ({EntryID: sequence}, {EntryID: taxon_id})."""
    seqs = {}
    taxa = {}
    with open(path, "r") as f:
        current_id = None
        current_seq_parts = []
        for line in f:
            line = line.strip()
            if not line:
                continue
            if line.startswith(">"):
                if current_id is not None:
                    seqs[current_id] = "".join(current_seq_parts)
                header = line[1:]
                parts = header.split()
                entry_id = parts[0]
                tax_id = parts[1] if len(parts) > 1 else None
                current_id = entry_id
                taxa[entry_id] = tax_id
                current_seq_parts = []
            else:
                current_seq_parts.append(line)
        if current_id is not None:
            seqs[current_id] = "".join(current_seq_parts)
    return seqs, taxa


def load_train_terms(path: str) -> Dict[str, Set[str]]:
    """train_terms.tsv -> {EntryID: set(GO_id)}."""
    labels = defaultdict(set)
    with open(path, "r") as f:
        header = f.readline()
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry_id, term, aspect = line.split("\t")
            labels[entry_id].add(term)
    return labels


def load_ia(path: str) -> Dict[str, float]:
    """IA.tsv -> {GO_id: ia_value}."""
    ia = {}
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            term, val = line.split("\t")
            ia[term] = float(val)
    return ia


def build_id_lists(
    train_seqs: Dict[str, str],
    train_labels: Dict[str, Set[str]],
) -> Tuple[List[str], List[str]]:
    """Return ordered list of train IDs that have at least one label,
    plus list of all GO terms used.
    """
    ids = []
    all_terms = set()
    for pid, seq in train_seqs.items():
        terms = train_labels.get(pid, set())
        if not terms:
            continue
        ids.append(pid)
        all_terms.update(terms)
    term_list = sorted(all_terms)
    return ids, term_list


def build_label_sets_for_ids(
    ids: List[str],
    train_labels: Dict[str, Set[str]],
    term_to_idx: Dict[str, int],
) -> List[Set[int]]:
    """Convert GO term sets to index sets for given IDs."""
    label_sets: List[Set[int]] = []
    for pid in ids:
        gos = train_labels.get(pid, set())
        idxs = {term_to_idx[t] for t in gos if t in term_to_idx}
        label_sets.append(idxs)
    return label_sets


def embed_sequences(
    model_name: str,
    ids: List[str],
    seqs: Dict[str, str],
    batch_size: int = 8,
    device: str = None,
) -> np.ndarray:
    """Compute per-sequence embeddings using a protein LM from HuggingFace.

    Returns array of shape (len(ids), hidden_dim).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(model_name, do_lower_case=False)
    model = AutoModel.from_pretrained(model_name)
    model.to(device)
    model.eval()

    all_embs = []
    for start in tqdm(range(0, len(ids), batch_size), desc="Embedding sequences"):
        batch_ids = ids[start : start + batch_size]
        batch_seqs = [seqs[i] for i in batch_ids]
        tokens = tokenizer(
            batch_seqs,
            return_tensors="pt",
            padding=True,
            truncation=True,
        )
        tokens = {k: v.to(device) for k, v in tokens.items()}
        with torch.no_grad():
            out = model(**tokens)
            # last_hidden_state: [B, L, D]
            hidden = out.last_hidden_state
            mask = tokens["attention_mask"].unsqueeze(-1)  # [B, L, 1]
            hidden = hidden * mask
            summed = hidden.sum(dim=1)  # [B, D]
            lengths = mask.sum(dim=1)   # [B, 1]
            pooled = summed / lengths.clamp(min=1)
        all_embs.append(pooled.cpu().numpy())
    return np.concatenate(all_embs, axis=0)


def fit_knn(
    train_embs: np.ndarray,
    n_neighbors: int = 50,
) -> NearestNeighbors:
    """Fit a cosine-distance k-NN index on train embeddings."""
    nn = NearestNeighbors(
        n_neighbors=n_neighbors,
        metric="cosine",
        algorithm="brute",
        n_jobs=-1,
    )
    nn.fit(train_embs)
    return nn


def knn_predict_scores(
    query_embs: np.ndarray,
    nn: NearestNeighbors,
    train_label_sets: List[Set[int]],
) -> List[Dict[int, float]]:
    """For each query embedding, transfer labels from its k neighbors.

    Returns a list of dicts {term_idx: score in (0, 1]}.
    """
    dists, indices = nn.kneighbors(query_embs, return_distance=True)
    # cosine distance in [0, 2] -> similarity in [1, -1]
    sims = 1.0 - dists
    all_scores: List[Dict[int, float]] = []
    for neigh_idxs, neigh_sims in zip(indices, sims):
        scores: Dict[int, float] = defaultdict(float)
        # keep only positive similarities
        sim_pos = np.clip(neigh_sims, a_min=0.0, a_max=None)
        denom = float(sim_pos.sum())
        if denom <= 0:
            all_scores.append({})
            continue
        for idx, s in zip(neigh_idxs, sim_pos):
            if s <= 0:
                continue
            labels = train_label_sets[idx]
            for t in labels:
                scores[t] += float(s)
        # normalize to (0, 1]
        for t in list(scores.keys()):
            scores[t] = scores[t] / denom
            # clip away exact zeros just in case
            if scores[t] <= 0:
                del scores[t]
        all_scores.append(scores)
    return all_scores


def ia_weighted_fmax(
    gold_label_sets: List[Set[int]],
    pred_score_dicts: List[Dict[int, float]],
    ia_per_idx: Dict[int, float],
    thresholds: np.ndarray = None,
) -> Tuple[float, float, float, float]:
    """Approximate CAFA IA-weighted Fmax on a validation split.

    Returns (best_F, best_threshold, precision_at_best, recall_at_best).
    """
    if thresholds is None:
        thresholds = np.linspace(0.05, 0.9, 18)

    best_F = 0.0
    best_thr = 0.5
    best_P = 0.0
    best_R = 0.0

    for thr in thresholds:
        num = 0.0  # IA of true positives
        den_p = 0.0  # IA of predicted positives
        den_r = 0.0  # IA of actual positives

        for gold, pred_scores in zip(gold_label_sets, pred_score_dicts):
            pred = {t for t, s in pred_scores.items() if s >= thr}
            if not gold and not pred:
                continue
            if gold:
                den_r += sum(ia_per_idx.get(t, 0.0) for t in gold)
            if pred:
                den_p += sum(ia_per_idx.get(t, 0.0) for t in pred)
            if gold and pred:
                inter = gold & pred
                if inter:
                    num += sum(ia_per_idx.get(t, 0.0) for t in inter)

        P = num / den_p if den_p > 0 else 0.0
        R = num / den_r if den_r > 0 else 0.0
        F = 2 * P * R / (P + R) if (P + R) > 0 else 0.0

        if F > best_F:
            best_F = F
            best_thr = thr
            best_P = P
            best_R = R

    return best_F, best_thr, best_P, best_R


def write_submission(
    out_path: str,
    protein_ids: List[str],
    pred_score_dicts: List[Dict[int, float]],
    idx_to_term: List[str],
    max_terms_per_protein: int = 1500,
    min_score: float = 1e-3,
):
    """Write submission file with GO term predictions only."""
    with open(out_path, "w") as out:
        for pid, scores in zip(protein_ids, pred_score_dicts):
            if not scores:
                continue
            # convert idx->(GO, score) and apply filters
            items = [
                (idx_to_term[t], s)
                for t, s in scores.items()
                if s >= min_score
            ]
            if not items:
                continue
            # sort by score desc and keep top-K
            items.sort(key=lambda x: x[1], reverse=True)
            items = items[:max_terms_per_protein]

            for go, s in items:
                # scores must be in (0, 1.000] with up to 3 significant figures
                s_clipped = min(max(s, min_score), 1.0)
                s_str = f"{s_clipped:.3g}"
                out.write(f"{pid}\t{go}\t{s_str}\n")


if __name__ == "__main__":
    # Example end-to-end usage.
    # Adjust DATA_DIR to point to the unzipped competition files.
    DATA_DIR = "cafa-6-protein-function-prediction"

    train_seq_path = os.path.join(DATA_DIR, "Train", "train_sequences.fasta")
    train_terms_path = os.path.join(DATA_DIR, "Train", "train_terms.tsv")
    ia_path = os.path.join(DATA_DIR, "IA.tsv")
    test_fasta_path = os.path.join(DATA_DIR, "Test", "testsuperset.fasta")

    # 1. Load data
    print("Loading data...")
    train_seqs = parse_train_fasta(train_seq_path)
    train_labels = load_train_terms(train_terms_path)
    ia_by_term = load_ia(ia_path)

    train_ids, term_list = build_id_lists(train_seqs, train_labels)
    term_to_idx = {t: i for i, t in enumerate(term_list)}
    idx_to_term = term_list
    train_label_sets = build_label_sets_for_ids(train_ids, train_labels, term_to_idx)

    # Map IA per term index (useful if the user later wants to call ia_weighted_fmax)
    ia_per_idx = {term_to_idx[t]: ia_by_term.get(t, 0.0) for t in term_list}

    # 2. Compute embeddings for train and test
    model_name = "facebook/esm2_t12_35M_UR50D"  # small-ish ESM2; can be swapped for a larger model
    print("Embedding training proteins...")
    train_embs = embed_sequences(model_name, train_ids, train_seqs, batch_size=8)

    print("Embedding test proteins...")
    test_seqs, test_taxa = parse_test_fasta(test_fasta_path)
    test_ids = sorted(test_seqs.keys())
    test_embs = embed_sequences(model_name, test_ids, test_seqs, batch_size=8)

    # 3. Fit k-NN on full training set
    print("Fitting k-NN index...")
    nn = fit_knn(train_embs, n_neighbors=50)

    # 4. Predict scores for test proteins
    print("Predicting GO term scores for test proteins...")
    test_score_dicts = knn_predict_scores(test_embs, nn, train_label_sets)

    # 5. Write submission file
    out_path = "knn_esm2_submission.tsv"
    print(f"Writing submission to {out_path} ...")
    write_submission(out_path, test_ids, test_score_dicts, idx_to_term)
    print("Done.")


Loading data...
Embedding training proteins...


tokenizer_config.json:   0%|          | 0.00/95.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/93.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/778 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/136M [00:00<?, ?B/s]

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Embedding sequences:   0%|          | 0/10301 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


OutOfMemoryError: CUDA out of memory. Tried to allocate 41.92 GiB. GPU 0 has a total capacity of 47.26 GiB of which 41.05 GiB is free. Process 1729276 has 310.00 MiB memory in use. Process 2213346 has 404.00 MiB memory in use. Process 2747726 has 3.82 GiB memory in use. Including non-PyTorch memory, this process has 1.67 GiB memory in use. Of the allocated memory 1.00 GiB is allocated by PyTorch, and 484.50 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)