In [None]:
import torch
import esm
import pandas as pd

# Start from the CPU and only switch to the first GPU when PyTorch reports
# CUDA as usable. Asking first avoids relying on the transfer below to fail
# with a RuntimeError: on builds without CUDA support it may raise a different
# exception type, which the original fallback did not catch.
device = torch.device("cpu")
if torch.cuda.is_available():
    try:
        # Probe with a tiny transfer so an unusable GPU is detected up front.
        torch.zeros(1).to("cuda:0")
    except RuntimeError:
        # The GPU was reported but cannot be used; keep the CPU fallback.
        pass
    else:
        device = torch.device("cuda:0")

print(f"Using device: {device}")

In [None]:
from torch import Tensor
import torch.nn as nn
from tqdm import tqdm
import os

class ESMFeaturesExtractor:
    """Turn protein sequences into fixed-width embeddings with a pretrained ESM model.

    For every sequence, the representations of the layers listed under
    ``sample_layers`` for the selected model are averaged over the token axis
    and concatenated, giving one feature row per sequence.
    """

    # Per-model settings: number of layers, which layers are sampled for
    # features, and the embedding width recorded for the model.
    MODEL_CONFIGS = {
        "esm2_t6_8M_UR50D": {
            "layers": 6,
            "sample_layers": [3, 6],
            "feature_dim": 320,
        },
        "esm2_t12_35M_UR50D": {
            "layers": 12,
            "sample_layers": [6, 12],
            "feature_dim": 480,
        },
        "esm2_t30_150M_UR50D": {
            "layers": 30,
            "sample_layers": [10, 20, 30],
            "feature_dim": 640,
        },
        "esm2_t33_650M_UR50D": {
            "layers": 33,
            "sample_layers": [11, 22, 33],
            "feature_dim": 1280,
        },
        "esm2_t36_3B_UR50D": {
            "layers": 36,
            "sample_layers": [12, 24, 36],
            "feature_dim": 2560,
        },
    }

    def __init__(self, model_name="esm2_t30_150M_UR50D", device="cpu"):
        """Load the named pretrained model and move it to ``device``.

        Args:
            model_name: Key of ``MODEL_CONFIGS``; also the name of the loader
                looked up on ``esm.pretrained``.
            device: Target device (string or ``torch.device``) for the model
                and for the tokens fed to it.

        Raises:
            ValueError: If ``model_name`` is not a key of ``MODEL_CONFIGS``.
        """
        self.model_name = model_name
        if model_name not in self.MODEL_CONFIGS:
            raise ValueError(f"not support: {model_name}")

        # Set the allocator option before the model is placed on the device
        # rather than at extraction time. NOTE(review): PyTorch appears to read
        # this variable when the CUDA caching allocator is first initialised,
        # so it has no effect if CUDA memory was already allocated earlier in
        # the process - confirm.
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

        # Load the requested model together with its alphabet.
        model, alphabet = getattr(esm.pretrained, model_name)()
        self.batch_converter = alphabet.get_batch_converter()
        self.model = model.eval().to(device)

        self.config = self.MODEL_CONFIGS[model_name]
        self.device = device

    def extract_features(self, sequences):
        """Compute one concatenated, mean-pooled feature row per sequence.

        Sequences are processed one at a time (batch size 1).

        Args:
            sequences: Sized iterable of protein sequence strings.

        Returns:
            A CPU tensor with one row per input sequence. For empty input an
            empty tensor of shape ``(0, feature_dim * len(sample_layers))`` is
            returned instead of failing inside ``torch.cat``.
        """
        sample_layers = self.config["sample_layers"]
        n_sequences = len(sequences)
        if n_sequences == 0:
            # Assumes every sampled layer yields ``feature_dim`` channels, as
            # recorded in MODEL_CONFIGS.
            width = self.config["feature_dim"] * len(sample_layers)
            return torch.empty((0, width))

        # Mixed precision is only enabled when the model actually runs on a
        # CUDA device; on any other device the autocast context is a no-op.
        use_cuda_amp = torch.device(self.device).type == "cuda"

        features_list = []
        for i, seq in tqdm(
            enumerate(sequences),
            total=n_sequences,
            desc="Extracting Features",
        ):
            # Tokenise a single labelled sequence.
            _, _, tokens = self.batch_converter([(f"protein_{i}", seq)])

            with torch.amp.autocast("cuda", enabled=use_cuda_amp), torch.no_grad():
                tokens = tokens.to(self.device)
                # Only the layer representations are consumed below, so contact
                # prediction is no longer requested from the model.
                results = self.model(tokens, repr_layers=sample_layers)

                # Average each sampled layer over dim 1.
                # NOTE(review): this averages over every position the batch
                # converter emits, which presumably includes special start/end
                # tokens - confirm that is intended.
                pooled = [
                    results["representations"][layer].mean(dim=1)
                    for layer in sample_layers
                ]
                # Concatenate the layers and move the row to the CPU right away.
                features_list.append(torch.cat(pooled, dim=1).cpu())

        return torch.cat(features_list, dim=0)

def get_data(data_dir="../data", max_len=1500):
    """Load the train/test CSVs, apply the published corrections, and filter.

    The corrections file is matched to the training rows by ``seq_id``.
    ``DataFrame.update`` aligns on the index, so both frames are indexed by
    ``seq_id`` here; with the default positional index the corrections would
    be written onto unrelated rows.

    NOTE(review): assumes ``train_updates_20220929.csv`` carries a ``seq_id``
    column and that a row whose other fields are all missing marks a training
    row to be removed, as described in the dataset's update notes - confirm.

    Args:
        data_dir: Directory containing ``train.csv``,
            ``train_updates_20220929.csv`` and ``test.csv``.
        max_len: Training sequences longer than this are discarded.

    Returns:
        Tuple ``(train_sequences, train_tm, test_sequences)`` of NumPy arrays.
    """
    train_df = pd.read_csv(f"{data_dir}/train.csv", index_col="seq_id")
    train_updates_df = pd.read_csv(
        f"{data_dir}/train_updates_20220929.csv", index_col="seq_id"
    )
    test_df = pd.read_csv(f"{data_dir}/test.csv")

    # Rows of the corrections file with every field missing flag training rows
    # to delete; the remaining rows carry corrected values.
    all_missing = train_updates_df.isna().all(axis="columns")
    train_df = train_df.drop(
        index=train_updates_df.index[all_missing], errors="ignore"
    )
    train_df.update(train_updates_df.loc[~all_missing])

    # NOTE(review): as in the original, rows are dropped for a missing value in
    # any column (including pH and data_source) before those columns are
    # removed - confirm this is intended.
    train_df = train_df.dropna().drop(columns=["pH", "data_source"])

    # Keep only sequences of at most ``max_len`` residues.
    train_df = train_df[train_df["protein_sequence"].str.len() <= max_len]

    return (
        train_df["protein_sequence"].values,
        train_df["tm"].values,
        test_df["protein_sequence"].values,
    )

In [None]:
# Load the training sequences, their tm targets, and the test sequences.
(train_sequences, train_labels, test_sequences) = get_data()

In [None]:
# Build the feature extractor from the esm2_t12_35M_UR50D checkpoint on the
# device selected at the top of the notebook.
features_extractor = ESMFeaturesExtractor(
    device=device,
    model_name="esm2_t12_35M_UR50D",
)

In [None]:
# Embed every training sequence (one feature row per sequence).
train_features = features_extractor.extract_features(train_sequences)

# Put sequence, feature vector and tm side by side and write the table to CSV.
train_features_df = pd.DataFrame({"protein_sequence": train_sequences}).assign(
    features=train_features.tolist(),
    tm=train_labels,
)
train_features_df.to_csv("train_features.csv", index=False)

In [None]:
print(f"train_features.shape: {train_features.shape}")

# train_features_df already holds the sequence, features and tm columns built
# from these same variables, so it is written out directly instead of being
# rebuilt column by column a second time.
train_features_df.to_csv("train_features_with_tm.csv", index=False)

In [None]:
# Call the methods: without parentheses only the bound-method reprs were
# printed. info() writes its report itself and returns None, so it is not
# wrapped in print().
print(train_features_df.describe())
train_features_df.info()