# 1. GENERACIÓN DE SECUENCIAS
En primer lugar vamos a generar una cantidad suficiente de secuencias para tener margen para filtrarlas y quedarnos con las más probables:

In [None]:
!pip install fair-esm
!pip install tqdm biopython numpy


Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0
Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m36.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.85


In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

MODEL_ID = "hugohrban/progen2-base"
# NOTE(review): trust_remote_code executes tokenizer/model code fetched from the
# Hub (see the "new version ... downloaded" warnings in the output); consider
# pinning `revision=` to a known commit.

tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
tok.pad_token = tok.eos_token          # padding needs a pad token; reuse EOS (fixes the ValueError)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="cuda",
    trust_remote_code=True
).eval()

# Quick smoke test: N_PROMPTS identical prompts, N_VARIANTS samples per prompt
# -> N_PROMPTS * N_VARIANTS sequences in total (2 * 50 = 100).
N_PROMPTS  = 2
N_VARIANTS = 50
prompts = ["1M"] * N_PROMPTS   # "1" is assumed to be ProGen2's N-terminus marker, then Met
batch   = tok(prompts, return_tensors="pt", padding=True).to(model.device)

with torch.no_grad():
    out = model.generate(
        **batch,
        max_length=100,                   # total length, prompt tokens included
        do_sample=True,
        top_p=0.9,
        temperature=0.8,
        num_return_sequences=N_VARIANTS,  # per prompt
        pad_token_id=tok.pad_token_id     # the pad token set above
    )

seqs = tok.batch_decode(out, skip_special_tokens=True)
print(len(seqs), "secuencias generadas")


config.json:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

configuration_progen.py:   0%|          | 0.00/2.63k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/hugohrban/progen2-base:
- configuration_progen.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


tokenizer.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

modeling_progen.py:   0%|          | 0.00/24.6k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/hugohrban/progen2-base:
- modeling_progen.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors:   0%|          | 0.00/3.06G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

100 secuencias generadas


In [None]:
from tqdm.auto import tqdm          # adaptive progress bar
import torch, gc

TOTAL           = 50000             # sequences to write
PROMPTS_PER_RUN = 4                 # identical "1M" prompts per generate() call
VARIANTS        = 100               # samples per prompt -> 4 x 100 = 400 sequences / iteration
MAX_LEN         = 100               # max total tokens per sample, prompt included
PAD_ID          = tok.pad_token_id
SEED            = 0                 # makes the sampling reproducible across runs

out_path = "/content/drive/MyDrive/TFM/progen_sequences2.fasta"
# NOTE(review): the ESM-1v scoring cell reads /content/progen_sequences2.fasta;
# copy this file there (or align both paths) before running it.


def clean_progen_sequence(decoded: str) -> str:
    """Strip ProGen2 terminal markers from one decoded sample.

    A leading "1" (the N-terminal marker that starts the prompt) is removed,
    and everything from the first "2" (C-terminal marker) onwards is dropped.
    generate() stops on tok.eos_token, not on "2", so the model may keep
    sampling past the C-terminus; rstrip("2") only handled a trailing "2"
    and left internal markers and run-on residues in the sequence.
    """
    body = decoded[1:] if decoded.startswith("1") else decoded
    return body.split("2", 1)[0]


# Sanity check, printed once instead of on every iteration. In the logged run
# the EOS token is "<|endoftext|>", i.e. NOT the "2" residue marker.
print("EOS token:", tok.eos_token)
print("Pad token:", tok.pad_token)

torch.manual_seed(SEED)

# The prompt batch is identical on every iteration: build it once.
batch = tok(["1M"] * PROMPTS_PER_RUN,
            return_tensors="pt",
            padding=True).to(model.device)

with open(out_path, "w") as fa, tqdm(total=TOTAL, unit="seq") as pbar:
    generated = 0

    while generated < TOTAL:
        # 1) Sample PROMPTS_PER_RUN * VARIANTS sequences
        with torch.no_grad():
            out = model.generate(
                **batch,
                max_length=MAX_LEN,
                temperature=0.8,
                top_p=0.9,
                do_sample=True,
                num_return_sequences=VARIANTS,
                eos_token_id=tok.eos_token_id,
                pad_token_id=PAD_ID
            )

        # 2) Decode, clean and write until TOTAL is reached
        for decoded in tok.batch_decode(out, skip_special_tokens=True):
            if generated >= TOTAL:
                break
            clean = clean_progen_sequence(decoded)
            if not clean:               # nothing left after stripping markers
                continue
            generated += 1
            pbar.update(1)
            fa.write(f">seq_{generated}\n{clean}\n")

        # 3) Free memory between iterations
        torch.cuda.empty_cache()
        gc.collect()


  0%|          | 0/50000 [00:00<?, ?seq/s]

EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>
EOS token: <|endoftext|>
Pad token: <|endoftext|>


# FILTRADO DE SECUENCIAS PLAUSIBLES
Una vez generadas las secuencias vamos a pasar a filtrarlas:
0. Inspección de homología
1. Inspección básica (Metionina N-terminal, sin repeticiones de un aminoácido ni de dímeros)
2. Detección de señales transmembrana y péptidos señal
3. ESM-1v score para quedarnos con las secuencias más plausibles.

## 0. INSPECCIÓN DE HOMOLOGÍA

In [None]:
# Homology search with DIAMOND against UniRef90 (commented out: slow, run once).
# NOTE(review): passed.fasta is written by the ESM-1v scoring cell further below,
# so that cell must run before the blastp step.
#!wget ftp://ftp.uniprot.org/pub/databases/uniprot/current_release/uniref/uniref90/uniref90.fasta.gz
#!wget https://github.com/bbuchfink/diamond/releases/download/v2.1.11/diamond-linux64.tar.gz
#!tar xzf diamond-linux64.tar.gz
#!gunzip uniref90.fasta.gz
#!./diamond makedb --in uniref90.fasta -d uniref
# Query the database built above (uniref.dmnd). The old line pointed at a
# swissprot.dmnd that is never created here and ended in a stray "!".
#!./diamond blastp --query passed.fasta --db uniref.dmnd --out diamond_output.tsv --outfmt 6 --evalue 1e-5 --max-target-seqs 1 --threads 4

diamond v2.1.11.165 (C) Max Planck Society for the Advancement of Science, Benjamin Buchfink, University of Tuebingen
Documentation, support and updates available at http://www.diamondsearch.org
Please cite: http://dx.doi.org/10.1038/s41592-021-01101-x Nature Methods (2021)

#CPU threads: 8
Scoring parameters: (Matrix=BLOSUM62 Lambda=0.267 K=0.041 Penalties=11/1)
Database input file: uniref90.fasta
Opening the database file...  [0.034s]
Loading sequences...  [11.417s]
Masking sequences...  [13.26s]
Writing sequences...  [0.884s]
Hashing sequences...  [0.295s]
Loading sequences...  [11.473s]
Masking sequences...  [12.806s]
Writing sequences...  [0.888s]
Hashing sequences...  [0.296s]
Loading sequences...  [11.95s]
Masking sequences...  [12.791s]
Writing sequences...  [0.919s]
Hashing sequences...  [0.298s]
Loading sequences...  [11.831s]
Masking sequences...  [12.817s]
Writing sequences...  [1.041s]
Hashing sequences...  [0.314s]
Loading sequences...  [12.763s]
Masking sequences...  [12

## 1. Scores ESM-1v (primer filtro de probabilidad de secuencia)

In [None]:
#!/usr/bin/env python3
"""
Descripción:
 1) Lee un FASTA de entrada.
 2) Calcula para cada secuencia el score promedio (mean log-prob) con el modelo ESM.
 3) Filtra según un cuantíl o umbral fijado en variables.
 4) Escribe un FASTA con las secuencias que pasan el filtro y un CSV con id, score, passed.
"""
import torch
import torch.nn.functional as F
import esm
from tqdm import tqdm
from Bio import SeqIO
import numpy as np


def score_sequence(model, alphabet, seq, device):
    """Average per-residue log-probability of ``seq`` under an ESM model.

    Characters that are not tokens of ``alphabet`` are removed first. The
    sequence then goes through the model in a single, unmasked forward pass,
    and the log-softmax probability of each observed residue is averaged.

    NOTE(review): residues are not masked, so every position "sees" its own
    token. This differs from masked-marginal / pseudo-log-likelihood scoring
    and may explain why scores sit close to 0.

    Raises:
        ValueError: if no valid character remains after filtering.
    """
    vocab = alphabet.tok_to_idx
    kept = "".join(ch for ch in seq if ch in vocab)
    if not kept:
        raise ValueError("Sequence contains no valid amino acids after filtering.")
    n_res = len(kept)

    convert = alphabet.get_batch_converter()
    tokens = convert([("seq", kept)])[2].to(device)

    with torch.no_grad():
        logits = model(tokens, repr_layers=[], return_contacts=False)["logits"]

    # Batch row 0; position 0 is assumed to be the BOS/CLS token the converter prepends.
    residue_logp = F.log_softmax(logits[0, 1:n_res + 1], dim=-1)
    observed = tokens[0, 1:n_res + 1]
    picked = residue_logp[torch.arange(n_res), observed]
    return picked.mean().item()


def main():
    """Score every sequence in ``input_fasta`` with ESM-1v and filter by score.

    Writes ``output_csv`` (header ``id,score,passed`` followed by one
    newline-terminated row per scored record) and ``output_fasta`` (records
    whose score is >= the cutoff). The cutoff is ``threshold`` when it is not
    None, otherwise the ``quantile`` of all observed scores.
    """
    # --- Predefined parameters ---
    input_fasta   = "/content/progen_sequences2.fasta"
    model_name    = "esm1v_t33_650M_UR90S_1"
    device        = "cuda"  # "cpu" or "cuda"
    # Cutoff: a fixed threshold (e.g. -3.5) wins when not None; otherwise the
    # given quantile of the observed scores is used.
    threshold     = None
    quantile      = 0.9
    output_fasta  = "passed.fasta"
    output_csv    = "scores.csv"
    # -------------------------------

    # Fall back to CPU when CUDA was requested but is not available
    device = torch.device(device if torch.cuda.is_available() and device.startswith("cuda") else "cpu")

    # Load model and alphabet
    print(f"[INFO] Cargando modelo {model_name} en {device}")
    model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
    model = model.eval().to(device)

    # Read FASTA
    records = list(SeqIO.parse(input_fasta, "fasta"))
    print(f"[INFO] {len(records)} secuencias cargadas desde {input_fasta}")

    # Score each record; keep letters only (drops digits, '*', gaps, whitespace)
    results = []
    for rec in tqdm(records, desc="Scoring"):
        seq = ''.join(c for c in str(rec.seq) if c.isalpha())
        if not seq:
            print(f"[WARN] Secuencia {rec.id} vacía tras limpieza, se omite.")
            continue
        score = score_sequence(model, alphabet, seq, device)
        results.append((rec.id, rec.seq, score))

    if not results:
        # np.quantile would fail on an empty list
        print("[WARN] Ninguna secuencia puntuada; no se escriben resultados.")
        return

    # Determine the cutoff
    scores = [s for _, _, s in results]
    if threshold is None:
        thr = float(np.quantile(scores, quantile))
        print(f"[INFO] Umbral basado en cuantíl {quantile}: {thr:.6f}")
    else:
        thr = float(threshold)
        print(f"[INFO] Umbral fijo: {thr:.6f}")

    # Write the CSV: one line per record
    passed_ids = set()   # set: O(1) membership when filtering the FASTA below
    n_passed = 0
    with open(output_csv, "w") as csvf:
        csvf.write("id,score,passed\n")
        for ident, _seq, score in results:
            passed = int(score >= thr)
            csvf.write(f"{ident},{score:.6f},{passed}\n")
            if passed:
                passed_ids.add(ident)
                n_passed += 1

    # Write the FASTA with the records that pass
    with open(output_fasta, "w") as fh:
        SeqIO.write((rec for rec in records if rec.id in passed_ids), fh, "fasta")

    print(f"[DONE] {n_passed}/{len(results)} secuencias pasan el filtro")


if __name__ == "__main__":
    main()


[INFO] Cargando modelo esm1v_t33_650M_UR90S_1 en cuda
[INFO] 50000 secuencias cargadas desde /content/progen_sequences2.fasta


Scoring: 100%|██████████| 50000/50000 [54:51<00:00, 15.19it/s]


[INFO] Umbral basado en cuantíl 0.9: -0.121989
[DONE] 5000/50000 secuencias pasan el filtro


In [None]:
import csv

input_fasta    = "/content/progen_sequences2.fasta"
filter_csv     = "/content/progen_sequences_filter1.csv"
ESMthreshold   = -0.2   # keep only records whose ESM score is strictly above this
MAX_IDENTITY   = 30     # keep only records whose identity (%) is at most this
conservate     = set()  # FASTA IDs of the records to keep

# 1) Read the CSV (header row + 4 columns: <unused>, id, score, identity) and
#    collect the IDs that pass both thresholds. csv.reader also removes the
#    quotes around quoted fields.
with open(filter_csv, newline="") as f:
    reader = csv.reader(f)
    next(reader, None)              # skip the header row
    for row in reader:
        if not row:                 # tolerate blank lines (e.g. a trailing newline)
            continue
        _ignored, seq_id, score, ident = row
        # The 2nd column holds the record ID: it is matched against the FASTA
        # header IDs below.
        if float(ident) <= MAX_IDENTITY and float(score) > ESMthreshold:
            conservate.add(seq_id)

print(f"Secuencias a conservar: {len(conservate)}")

# 2) Filter the FASTA: copy every record whose header ID is in `conservate`
output_fasta = "/content/progen_sequences_conserved.fasta"
with open(input_fasta) as fin, open(output_fasta, "w") as fout:
    write = False
    for line in fin:
        if line.startswith(">"):
            # ID only (up to the first whitespace)
            seqid = line[1:].strip().split()[0]
            write = (seqid in conservate)
        if write:
            fout.write(line)

print(f"Escrito el FASTA filtrado en: {output_fasta}")


Secuencias a conservar: 9861
Escrito el FASTA filtrado en: /content/progen_sequences_conserved.fasta


## 2. Inspección básica y detección de señales transmembrana y péptidos señal

In [None]:
# Installs TMbed straight from the GitHub default branch (unpinned).
# NOTE(review): pin a commit (...TMbed.git@<sha>) for reproducible predictions.
!pip install -q git+https://github.com/BernhoferM/TMbed.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m115.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m99.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m59.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m21.3 MB/s[0m eta [36m0:

In [None]:
"""
La pipeline hace:
 1) Filtros de calidad (Met N-terminal, repeticiones simples y dímeros).
 2) Lanza TMbed (GPU) sobre el FASTA filtrado.
 3) Filtra las predicciones para descartar todas las proteínas que tengan
    al menos un dominio transmembrana (H/h/B/b) o signal peptide (S).
 4) Guarda un CSV con las secuencias **sin** TM ni SP y reporta totales.
 5) Genera un nuevo FASTA con las secuencias conservadas.
"""

import sys, os, csv, subprocess, tempfile
# Filtros de calidad


def pop_if_dimer_repeat(sequence, seq_id, seq_dict):
    """Remove ``seq_id`` from ``seq_dict`` if ``sequence`` has a dimer tandem repeat.

    Both reading frames (offset 0 and 1) are scanned in steps of two residues.
    Five or more identical consecutive dimers (e.g. ABABABABAB) count as a
    repeat. Returns True if the entry was removed, False otherwise.
    """
    min_repeats = 5
    for frame in (0, 1):
        last_dimer = sequence[frame:frame + 2]
        run = 1
        # stop before a trailing single residue: only complete dimers are compared
        for pos in range(frame + 2, len(sequence) - 1, 2):
            dimer = sequence[pos:pos + 2]
            if dimer != last_dimer:
                last_dimer, run = dimer, 1
                continue
            run += 1
            if run >= min_repeats:
                seq_dict.pop(seq_id, None)
                return True
    return False


def pop_if_single_repeat(sequence, seq_id, seq_dict, min_run=4):
    """Remove ``seq_id`` from ``seq_dict`` if ``sequence`` has a homopolymer run.

    A homopolymer run is ``min_run`` (default 4) or more identical consecutive
    residues, e.g. "LLLL".

    Returns:
        True if the entry was removed, False otherwise. An empty sequence
        returns False; the old ``sequence[0]`` access raised IndexError here.
    """
    run = 0
    prev = None
    for aa in sequence:
        run = run + 1 if aa == prev else 1
        prev = aa
        if run >= min_run:
            seq_dict.pop(seq_id, None)
            return True
    return False

# Lectura FASTA y filtrado


def load_fasta(path):
    """Read a FASTA file into a ``{id: SEQUENCE}`` dict.

    The ID is the first whitespace-separated token of each header line.
    Sequence lines are stripped, concatenated and upper-cased. Lines before
    the first header are ignored, and a repeated ID keeps the last record.
    """
    seqs = {}
    sid = None
    buf = []
    with open(path) as fh:   # close the handle deterministically
        for line in fh:
            if line.startswith(">"):
                if sid is not None:
                    seqs[sid] = "".join(buf).upper()
                sid = line[1:].split()[0]
                buf = []
            else:
                buf.append(line.strip())
    if sid is not None:
        seqs[sid] = "".join(buf).upper()
    return seqs


def quality_filter(seqs):
    """Return a copy of ``seqs`` without sequences that fail the sanity checks.

    A sequence is dropped when it does not start with Met ("M"), when
    ``pop_if_single_repeat`` flags it, or (only if it survived that check)
    when ``pop_if_dimer_repeat`` flags it. ``seqs`` itself is not modified.
    """
    kept = seqs.copy()
    for sid, seq in seqs.items():   # iterate the untouched input; pops go to `kept`
        if not seq.startswith("M"):
            kept.pop(sid, None)
        elif not pop_if_single_repeat(seq, sid, kept):
            pop_if_dimer_repeat(seq, sid, kept)
    return kept

# EJECUCIÓN de TMbed


def run_tmbed(fasta, out_pred):
    """Run ``tmbed predict`` on ``fasta`` and write predictions to ``out_pred``.

    The run uses the GPU and output format 0. Raises
    ``subprocess.CalledProcessError`` if the command exits non-zero.
    """
    args = ["tmbed", "predict"]
    args += ["--fasta", fasta, "--predictions", out_pred]
    args += ["--use-gpu", "--out-format", "0"]
    subprocess.run(args, check=True)
# Parseo y filtrado de predicciones


def summarize_and_filter(pred_file, out_csv):
    """Parse a TMbed format-0 prediction file and keep proteins without TM/SP.

    Each record is three lines: ``>header``, sequence, per-residue labels.
    A protein is discarded when its labels contain any transmembrane label
    (H/h/B/b) or a signal-peptide label (S). Blank lines between records are
    skipped instead of ending the parse early.

    Writes ``out_csv`` with header ``id,has_TM,has_SP`` and one row (both
    flags 0) per kept protein.

    Returns:
        (total, discarded, kept, kept_ids). ``kept_ids`` holds the first token
        of each kept header, matching the keys produced by ``load_fasta``.
    """
    total = discarded = kept = 0
    kept_ids = []
    with open(pred_file) as f, open(out_csv, "w", newline="") as out:
        w = csv.writer(out)
        w.writerow(["id", "has_TM", "has_SP"])
        stripped = (ln.strip() for ln in f)
        content = (ln for ln in stripped if ln)   # drop blank separator lines
        for header in content:
            next(content, "")                     # sequence line (unused)
            labels = next(content, "")
            pid = (header[1:].split() or [""])[0]  # ID only, no description
            has_tm = any(c in labels for c in "HhBb")
            has_sp = "S" in labels
            total += 1
            if has_tm or has_sp:
                discarded += 1
            else:
                kept += 1
                kept_ids.append(pid)
                w.writerow([pid, 0, 0])
    return total, discarded, kept, kept_ids
# Main pipeline


def main(inp_fasta="/content/progen_sequences_conserved.fasta", out_root="results"):
    """Run the quality + TMbed filtering pipeline end to end.

    1) Quality filters (N-terminal Met, homopolymer and dimer repeats).
    2) TMbed on the surviving sequences.
    3) Drop proteins predicted to have TM segments or a signal peptide.
    4) Write ``<out_root>/filtered_summary.csv`` and
       ``<out_root>/kept_sequences.fasta``.

    No command-line argument is read, so there is deliberately no
    ``sys.argv`` check.
    """
    os.makedirs(out_root, exist_ok=True)

    # 1) Load and apply the quality filters
    seqs = load_fasta(inp_fasta)
    print(f"[INFO] Secuencias totales: {len(seqs)}")
    seqs2 = quality_filter(seqs)
    print(f"[INFO] Tras filtros calidad: {len(seqs2)} "
          f"(descartadas {len(seqs)-len(seqs2)})")
    if not seqs2:
        print("[WARN] Ninguna secuencia supera los filtros de calidad; no se ejecuta TMbed.")
        return

    # 2) Temporary filtered FASTA, 3) TMbed on it. The prediction file lives in
    #    out_root (outside the temp dir) so it survives the cleanup.
    pred_file = os.path.join(out_root, "tmbed.pred")
    with tempfile.TemporaryDirectory() as td:
        filt_fasta = os.path.join(td, "filtered.fasta")
        with open(filt_fasta, "w") as fh:
            for sid, seq in seqs2.items():
                fh.write(f">{sid}\n{seq}\n")

        print("[INFO] Ejecutando TMbed...")
        run_tmbed(filt_fasta, pred_file)

    # 4) Filter predictions and summarize
    summary_csv = os.path.join(out_root, "filtered_summary.csv")
    total, discarded, kept, kept_ids = summarize_and_filter(pred_file, summary_csv)

    # 5) New FASTA with the kept sequences
    kept_fasta = os.path.join(out_root, "kept_sequences.fasta")
    with open(kept_fasta, "w") as fh:
        for pid in kept_ids:
            fh.write(f">{pid}\n{seqs2[pid]}\n")

    # 6) Report
    print(f"[OK] Todos los resultados en '{out_root}':")
    print(f"  - Predicción TMbed   : {pred_file}")
    print(f"  - Resumen CSV        : {summary_csv}")
    print(f"  - FASTA conservadas  : {kept_fasta}")
    print(f"  - Total procesadas   : {total}")
    print(f"  - Descartadas (TM/SP): {discarded}")
    print(f"  - Conservar          : {kept}")

if __name__ == "__main__":
    main()



[INFO] Secuencias totales: 9861
[INFO] Tras filtros calidad: 1206 (descartadas 8655)
[INFO] Ejecutando TMbed...
[OK] Todos los resultados en 'results':
  - Predicción TMbed   : results/tmbed.pred
  - Resumen CSV        : results/filtered_summary.csv
  - FASTA conservadas  : results/kept_sequences.fasta
  - Total procesadas   : 1206
  - Descartadas (TM/SP): 347
  - Conservar          : 859


## 3. Inmunogenicidad


In [None]:
!pip install fair-esm
!pip install biopython


Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0
Collecting biopython
  Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Downloading biopython-1.85-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m45.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: biopython
Successfully installed biopython-1.85


In [None]:
from pathlib import Path
from typing import Union, List, Optional
import torch, numpy as np
from esm import FastaBatchedDataset, pretrained
from tqdm import tqdm


def extract_embeddings_chunked(
    model_name: str,
    fasta_file: Union[str, Path],
    output_dir: Union[str, Path],
    tokens_per_batch: int = 3000,
    seq_length: int = 1022,
    repr_layer: int = 33,
    batches_per_file: int = 1000,   # flush to disk every 1000 batches
):
    """Compute one mean-pooled ESM embedding per FASTA record.

    The layer ``repr_layer`` representation is averaged over positions
    1..min(seq_length, len(seq)), which skips position 0. Results are flushed
    through ``_save_chunk`` as ``chunk_XXXX.npz`` files in ``output_dir``:
    one file every ``batches_per_file`` batches, plus a final file for any
    remainder. Each record's ID is the first token of its FASTA header.
    """
    fasta_file = Path(fasta_file)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    use_cuda = torch.cuda.is_available()

    # Model
    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()
    if use_cuda:
        model = model.cuda()

    # Token-budgeted batches
    dataset = FastaBatchedDataset.from_file(str(fasta_file))
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1),
        collate_fn=alphabet.get_batch_converter(seq_length),
        num_workers=4,
        pin_memory=True,
    )

    pending_ids, pending_embs = [], []
    n_chunks_written = 0

    with torch.inference_mode():
        for batch_no, (labels, strs, toks) in enumerate(tqdm(loader, desc="Batches"), start=1):
            if use_cuda:
                toks = toks.cuda(non_blocking=True)

            hidden = model(toks, repr_layers=[repr_layer], return_contacts=False)["representations"][repr_layer]

            for row, (label, residues) in enumerate(zip(labels, strs)):
                n_res = min(seq_length, len(residues))
                pending_ids.append(label.split()[0])
                pending_embs.append(hidden[row, 1:n_res + 1].mean(0).detach().cpu())

            if batch_no % batches_per_file == 0:
                _save_chunk(pending_ids, pending_embs, output_dir, n_chunks_written)
                n_chunks_written += 1
                pending_ids, pending_embs = [], []
                torch.cuda.empty_cache()

        # Remainder that did not fill a whole chunk
        if pending_ids:
            _save_chunk(pending_ids, pending_embs, output_dir, n_chunks_written)


def _save_chunk(ids: List[str], embs: List[torch.Tensor], outdir: Path, idx: int):
    """Write one chunk of embeddings to ``outdir / chunk_{idx:04d}.npz``.

    The compressed archive holds ``ids`` (string array) and ``X`` (float32
    matrix built by stacking ``embs``, one row per id).
    """
    fname = f"chunk_{idx:04d}.npz"
    matrix = torch.stack(embs).numpy().astype("float32")
    np.savez_compressed(outdir / fname, ids=np.array(ids), X=matrix)
    print(f"🔹 Guardado {fname}  ({len(ids)} secuencias)")

import sys

# Embed the sequences kept by the TMbed step; chunks are written to ./immunology_prediction
model_name, fasta_file, output_dir = (
    'esm2_t33_650M_UR50D',                      # ESM-2 checkpoint (layer 33 used below)
    "/content/results/kept_sequences.fasta",    # FASTA written by the TMbed step
    'immunology_prediction',                    # relative to the CWD (/content on Colab)
)
extract_embeddings_chunked(model_name, fasta_file, output_dir)
import os
import numpy as np

# Directories and names
emb_dir = "/content/immunology_prediction"   # where the chunk_XXXX.npz files live
out_dir = "/content/"
os.makedirs(out_dir, exist_ok=True)

final_name = "protein_embeddings"      # final file: protein_embeddings.npz

# Every *.npz in emb_dir is merged in name order (zero-padded names, so numeric order).
# NOTE(review): stale chunks left by an earlier, larger run would be merged too;
# empty emb_dir before re-running the extraction.
npz_files = sorted(f for f in os.listdir(emb_dir) if f.endswith(".npz"))
if not npz_files:
    raise FileNotFoundError(f"No .npz chunks found in {emb_dir}")

# Buffers
all_ids, all_X, chunk_index = [], [], []

for idx, fname in enumerate(npz_files, 1):
    path = os.path.join(emb_dir, fname)
    print(f"Cargando {idx}/{len(npz_files)} → {fname}")

    # `with` closes the NpzFile; np.load keeps the archive open otherwise
    with np.load(path, allow_pickle=False) as data:
        X_chunk   = data["X"]                # (m, D) float32
        ids_chunk = data["ids"]              # (m,)   str
    if len(X_chunk) != len(ids_chunk):
        raise ValueError(f"{fname}: {len(ids_chunk)} ids but {len(X_chunk)} embeddings")

    all_X.append(X_chunk)
    all_ids.append(ids_chunk)
    chunk_index.append(np.full(len(ids_chunk), idx-1, dtype=np.int16))  # source chunk of each row

# Concatenate everything
X_final   = np.vstack(all_X)             # (N, D)
ids_final = np.concatenate(all_ids)      # (N,)

# Save
out_path = os.path.join(out_dir, f"{final_name}.npz")
np.savez_compressed(out_path,
                    X=X_final,
                    ids=ids_final,
                    chunks=np.concatenate(chunk_index))   # 'chunks' is optional

print(f"\n✔️  Guardado archivo único: {out_path}")
print("   Embeddings totales:", X_final.shape[0])




Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to /root/.cache/torch/hub/checkpoints/esm2_t33_650M_UR50D-contact-regression.pt
Batches: 100%|██████████| 29/29 [00:08<00:00,  3.52it/s]


🔹 Guardado chunk_0000.npz  (859 secuencias)
Cargando 1/1 → chunk_0000.npz

✔️  Guardado archivo único: /content/protein_embeddings.npz
   Embeddings totales: 859


In [None]:
import pandas as pd
from tensorflow.keras.models import load_model
from Bio import SeqIO  # pip install biopython
import numpy as np

# Load the merged embeddings once, closing the archive afterwards
EMB_PATH = '/content/protein_embeddings.npz'
with np.load(EMB_PATH) as emb:
    X_final = emb['X']
    ids_final = emb['ids']
if len(X_final) != len(ids_final):
    raise ValueError(f"{EMB_PATH}: {len(ids_final)} ids but {len(X_final)} embeddings")

# Same FASTA the embeddings were computed from (written by the TMbed step)
input_fasta = "/content/results/kept_sequences.fasta"
output_fasta = "/content/immupig_filtered.fasta"

# 1) Load the model and predict
modelo_immun = load_model('immupig_epoch500.keras')
probs = modelo_immun.predict(X_final).ravel()   # shape (n,); assumes a single-output model

# 2) DataFrame and CSV
df = pd.DataFrame({
    'id':    ids_final,
    'score': probs
})
df.to_csv('scores_immupig.csv', index=False)
print("CSV escrito: scores_immupig.csv")

# 3) Filter the FASTA by threshold
threshold = 0.5
# IDs with score > threshold are kept.
# NOTE(review): the set is named "ids_bajos" (low), yet it keeps HIGH scores.
# Confirm which side of the ImmuPIG score means "non-immunogenic".
ids_bajos = set(df.loc[df['score'] > threshold, 'id'])

total = 0
kept  = 0

with open(output_fasta, 'w') as out_f:
    for rec in SeqIO.parse(input_fasta, 'fasta'):
        total += 1
        if rec.id in ids_bajos:
            SeqIO.write(rec, out_f, 'fasta')
            kept += 1

discarded = total - kept

print(f"Secuencias originales: {total}")
print(f"Secuencias mantenidas (score > {threshold}): {kept}")
print(f"Secuencias descartadas: {discarded}")
print(f"FASTAs filtrados escritos en: {output_fasta}")

[1m27/27[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 117ms/step
CSV escrito: scores_immupig.csv
Secuencias originales: 859
Secuencias mantenidas (score > 0.5): 499
Secuencias descartadas: 360
FASTAs filtrados escritos en: /content/immupig_filtered.fasta


# FILTRADO DE PLAUSIBILIDAD ESTRUCTURAL

## 1. Designable vs Undesignable folds




In [None]:
%%time
#@title install
#@markdown install ESMFold, OpenFold and download Params (~2min 30s)
version = "1" # @param ["0", "1"]
# "0" selects the esmfold_v0.model weights, "1" the esmfold.model weights.
model_name = "esmfold_v0.model" if version == "0" else "esmfold.model"
import os, time
# Everything below is skipped once the weights file exists locally.
# NOTE(review): os.system return codes are ignored, so failed installs are not reported.
if not os.path.isfile(model_name):
  # download esmfold params in the background (trailing "&") while the libs install
  os.system("apt-get install aria2 -qq")
  os.system(f"aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/{model_name} &")

  # "finished_install" is a marker file so the pip installs run only once
  if not os.path.isfile("finished_install"):
    # install libs
    print("installing libs...")
    os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol modelcif")
    os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

    print("installing openfold...")
    # install openfold
    os.system(f"pip install -q git+https://github.com/sokrypton/openfold.git")

    print("installing esmfold...")
    # install esmfold
    # NOTE(review): this ESM fork may replace the fair-esm package installed earlier
    # in the notebook; confirm the ESM-1v / ESM-2 cells still work after this cell.
    os.system(f"pip install -q git+https://github.com/sokrypton/esm.git")
    os.system("touch finished_install")

  # wait for Params to finish downloading...
  # NOTE(review): no timeout; if aria2c fails, this first loop never exits.
  while not os.path.isfile(model_name):
    time.sleep(5)
  # aria2c keeps a "<file>.aria2" control file until the download completes
  if os.path.isfile(f"{model_name}.aria2"):
    print("downloading params...")
  while os.path.isfile(f"{model_name}.aria2"):
    time.sleep(5)

installing libs...
installing openfold...
installing esmfold...
CPU times: user 616 ms, sys: 61.7 ms, total: 678 ms
Wall time: 3min 47s
