In [None]:
# Install the ESM package, which provides the `esm.pretrained` / `FastaBatchedDataset` names imported in the next cell.
# NOTE(review): unpinned; the captured output shows fair-esm 2.0.0 was installed — consider pinning `fair-esm==2.0.0`.
!pip install fair-esm
# Install PyTorch wheels from the CUDA 11.8 wheel index.
# NOTE(review): unpinned, and `!pip3` may not target the running kernel's environment; `%pip install ...` does.
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118


Collecting fair-esm
  Downloading fair_esm-2.0.0-py3-none-any.whl.metadata (37 kB)
Downloading fair_esm-2.0.0-py3-none-any.whl (93 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/93.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fair-esm
Successfully installed fair-esm-2.0.0
Looking in indexes: https://download.pytorch.org/whl/cu118
INFO: pip is looking at multiple versions of torch to determine which version is compatible with other requirements. This could take a while.
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.7.0%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (28 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading https://download.pytorch.org/whl/sympy-1.13.3-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu11==11.8.89 (from torch)
  Downloading htt

In [None]:
# coding: utf-8
from pathlib import Path
from typing import Union, List, Optional

import torch
from esm import FastaBatchedDataset, pretrained


def extract_embeddings(
    model_name: str,
    fasta_file: Union[str, Path],
    output_dir: Union[str, Path],
    tokens_per_batch: int = 4096,
    seq_length: int = 1022,
    repr_layers: Optional[List[int]] = None,
    skip_existing: bool = False,
    verbose: bool = True,
):
    """Compute per-sequence mean ESM embeddings for every record in a FASTA file.

    One ``<entry_id>.pt`` file is written per sequence into ``output_dir``, where
    ``entry_id`` is the first whitespace-separated token of the FASTA header. Each
    file holds ``{"entry_id": str, "mean_representations": {layer: tensor}}``.
    Each tensor is the mean over residue positions ``1..truncate_len`` of that
    layer's per-token representation.

    Args:
        model_name: Model name passed to ``esm.pretrained.load_model_and_alphabet``.
        fasta_file: Path to the input FASTA file.
        output_dir: Directory for the ``.pt`` files; created if it does not exist.
        tokens_per_batch: Token budget used to group sequences into batches.
        seq_length: Truncation length given to the batch converter; also bounds
            the averaged span. ``None`` disables truncation.
        repr_layers: Layers to extract. Defaults to ``[33]``, which assumes a
            33-layer model. Negative indices count from the end (``-1`` is the
            last layer), mirroring ESM's ``extract.py``.
        skip_existing: If True, sequences whose output file already exists are
            not recomputed or re-saved (to resume an interrupted run). A batch
            whose sequences are all done skips the forward pass entirely.
        verbose: If True (default), print per-batch and per-file progress.

    Raises:
        ValueError: If a requested layer is outside the model's layer range.
    """
    import warnings

    if repr_layers is None:
        repr_layers = [33]

    # Normalize paths to Path objects.
    fasta_file = Path(fasta_file)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # 1. Load the model and alphabet; use the GPU when one is available.
    model, alphabet = pretrained.load_model_and_alphabet(model_name)
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Validate layers up front: a bad index would otherwise only surface as a
    # KeyError after the first forward pass. Negative indices are mapped to
    # absolute ones the same way ESM's extract.py does.
    num_layers = getattr(model, "num_layers", None)
    if num_layers is not None:
        bad_layers = [i for i in repr_layers if not -(num_layers + 1) <= i <= num_layers]
        if bad_layers:
            raise ValueError(
                f"repr_layers {bad_layers} out of range for {model_name} "
                f"(valid: {-(num_layers + 1)}..{num_layers})"
            )
        repr_layers = [(i + num_layers + 1) % (num_layers + 1) for i in repr_layers]

    # 2. Build the DataLoader over token-budgeted batches.
    dataset = FastaBatchedDataset.from_file(str(fasta_file))
    batches = dataset.get_batch_indices(tokens_per_batch, extra_toks_per_seq=1)
    collate_fn = alphabet.get_batch_converter(seq_length)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=batches,
        collate_fn=collate_fn,
    )

    # 3. Extract and save embeddings.
    seen_ids = set()
    with torch.no_grad():
        for batch_idx, (labels, strs, toks) in enumerate(data_loader):
            if verbose:
                print(f"Processing batch {batch_idx + 1} of {len(batches)}")

            entry_ids = [label.split()[0] for label in labels]
            for entry_id in entry_ids:
                if entry_id in seen_ids:
                    # Two records with the same id map to the same output file.
                    warnings.warn(
                        f"Duplicate entry id {entry_id!r}: both records map to the same output file"
                    )
                seen_ids.add(entry_id)
            filenames = [output_dir / f"{entry_id}.pt" for entry_id in entry_ids]

            pending = [
                i for i, filename in enumerate(filenames)
                if not (skip_existing and filename.exists())
            ]
            if not pending:
                continue  # every sequence in this batch was saved by a previous run

            toks = toks.to(device=device, non_blocking=True)
            out = model(toks, repr_layers=repr_layers, return_contacts=False)

            representations = {
                layer: t.to(device="cpu") for layer, t in out["representations"].items()
            }

            for i in pending:
                seq_len = len(strs[i])
                truncate_len = seq_len if seq_length is None else min(seq_length, seq_len)

                # Average positions 1..truncate_len; index 0 is skipped (the BOS
                # token the ESM batch converter prepends, as in ESM's extract.py).
                mean_reps = {
                    layer: t[i, 1 : truncate_len + 1].mean(0).clone()
                    for layer, t in representations.items()
                }
                if verbose:
                    print(f"--> guardando {filenames[i]}")
                    print(mean_reps[repr_layers[0]][:5])
                torch.save(
                    {
                        "entry_id": entry_ids[i],
                        "mean_representations": mean_reps,
                    },
                    filenames[i],
                )

import sys  # NOTE(review): not used in this cell

# Inputs for the embedding run: model name, source FASTA and destination folder.
model_name = "esm2_t33_650M_UR50D"
fasta_file = "/content/drive/MyDrive/TFM/DATABASE/unlabeled_clean.fasta"
output_dir = "/content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled"

extract_embeddings(
    model_name=model_name,
    fasta_file=fasta_file,
    output_dir=output_dir,
)

[1;30;43mSe han truncado las últimas 5000 líneas del flujo de salida.[0m
Processing batch 4766 of 6077
--> guardando /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1CBM9|A0A0V1CBM9_TRIBR.pt
tensor([-0.0136, -0.0338,  0.0048,  0.0305, -0.0229])
--> guardando /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1CLF5|A0A0V1CLF5_TRIBR.pt
tensor([-0.0123, -0.0392,  0.0303, -0.0749, -0.0260])
Processing batch 4767 of 6077
--> guardando /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1DFW7|A0A0V1DFW7_TRIBR.pt
tensor([ 0.0172,  0.0014, -0.0294,  0.0188,  0.0820])
--> guardando /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|F1KR09|F1KR09_ASCSU.pt
tensor([ 0.0237, -0.0596, -0.0163,  0.0089, -0.0174])
Processing batch 4768 of 6077
--> guardando /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|F1KQM1|F1KQM1_ASCSU.pt
tensor([-0.0020, -0.0302, -0.0035, -0.0040, -0.1004])
--> guardando /content/drive/MyDrive

A continuación apilamos los embeddings medios (uno por archivo `.pt`) en un único array de NumPy, que se guarda como `.npz`, para entrenar el primer clasificador.

In [None]:
# Leftover kernel sanity check: only prints a fixed string. NOTE(review): safe to delete.
print("hola", end="\n")

hola


In [None]:
import os
import torch
import numpy as np

# Paths and settings
emb_dir = "/content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled"
out_dir = "/content/drive/MyDrive/TFM/ESM_arrays"
name = "unlabeled_ESM_mean"
layer = 33  # ESM layer to use
log_every = 500  # print progress every N files instead of once per file

# Sort so the row order of X is deterministic across runs (os.listdir order is arbitrary).
pt_files = sorted(f for f in os.listdir(emb_dir) if f.endswith(".pt"))

train_X = []
train_ids = []  # entry id of each row of X, in the same order
failed = []  # files that could not be loaded
n = 0

for n, filename in enumerate(pt_files, start=1):
    pt_file = os.path.join(emb_dir, filename)
    try:
        # Load on CPU: the vectors end up in NumPy anyway, and map_location="cuda"
        # fails on runtimes without a GPU.
        record = torch.load(pt_file, map_location="cpu")
        emb = record["mean_representations"][layer]
    except Exception as e:
        # Best-effort: skip unreadable files, but keep track of them.
        print(f"Error en {pt_file}: {e}", flush=True)
        failed.append(filename)
        continue

    train_X.append(emb)
    # Store the id next to each row so embeddings can be matched back to sequences.
    train_ids.append(record.get("entry_id", filename[: -len(".pt")]))

    if n % log_every == 0 or n == len(pt_files):
        print(f"Procesando {n}/{len(pt_files)}: {pt_file}", flush=True)

if not train_X:
    raise RuntimeError(
        f"No embeddings could be loaded from {emb_dir} "
        f"({len(pt_files)} .pt files found, {len(failed)} failed)"
    )

# Stack into a single (n_sequences, dim) matrix and convert to NumPy.
X_tensor = torch.stack(train_X)
X_np = X_tensor.cpu().numpy().astype("float32")
ids_np = np.array(train_ids)
assert X_np.shape[0] == ids_np.shape[0], "rows and ids are misaligned"

print(
    f"Stacked {X_np.shape[0]} embeddings of dim {X_np.shape[1]}; "
    f"{len(failed)} files failed",
    flush=True,
)

# Save both the matrix and the row ids (load with np.load(path)["X"] / ["ids"]).
os.makedirs(out_dir, exist_ok=True)
np.savez_compressed(os.path.join(out_dir, f"{name}.npz"), X=X_np, ids=ids_np)






[1;30;43mSe han truncado las últimas 5000 líneas del flujo de salida.[0m
Procesando 9265: /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1BX19|A0A0V1BX19_TRISP.pt
Procesando 9266: /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1BIM0|A0A0V1BIM0_TRISP.pt
Procesando 9267: /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1AWX8|A0A0V1AWX8_TRISP.pt
Procesando 9268: /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1BWD8|A0A0V1BWD8_TRISP.pt
Procesando 9269: /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1AXC8|A0A0V1AXC8_TRISP.pt
Procesando 9270: /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1BK15|A0A0V1BK15_TRISP.pt
Procesando 9271: /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1CRC6|A0A0V1CRC6_TRIBR.pt
Procesando 9272: /content/drive/MyDrive/TFM/DATABASE/ESM_embeddings_unlabeled/tr|A0A0V1CHX6|A0A0V1CHX6_TRIBR.pt
Procesando 9273: /content/dri