<a href="https://colab.research.google.com/github/despotZZ/colab/blob/main/protein_language_modeling_tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install Transformers as well as some other libraries.

In [None]:
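# Install the libraries:
# 1. transformers: Hugging Face's library of pretrained Transformer models. The Transformer architecture underlies
#    most current state-of-the-art deep learning models. Its key mechanism, self-attention, learns how strongly each
#    position should attend to every other position, so it can focus on relevant features and down-weight others.
#    Because every pair of positions is compared directly, long-range relationships are not lost, which matters for
#    proteins: residues far apart in the sequence can be close together in the folded structure.
# 2. pandas: loading and manipulating the data.
# 3. datasets: building the training and test sets in the format Transformers models expect.
# 4. scikit-learn: machine-learning utilities, used here for data preprocessing (label encoding, train/test split).
! pip install transformers pandas datasets scikit-learn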



In [None]:
# Mount Google Drive (the data file is read from it and the fine-tuned model is saved to it)
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Fine-Tuning Protein Language Models

The specific model we're going to use is ESM-2. The citation for this model is [Lin et al., 2022](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).

There are several ESM-2 checkpoints with differing model sizes. Larger models will generally have better accuracy, but they require more GPU memory and will take much longer to train. The available ESM-2 checkpoints are:

| Checkpoint name | Num layers | Num parameters |
|------------------------------|----|----------|
| `esm2_t48_15B_UR50D`         | 48 | 15B     |
| `esm2_t36_3B_UR50D`          | 36 | 3B      |
| `esm2_t33_650M_UR50D`        | 33 | 650M    |
| `esm2_t30_150M_UR50D`        | 30 | 150M    |
| `esm2_t12_35M_UR50D`         | 12 | 35M     |
| `esm2_t6_8M_UR50D`           | 6  | 8M      |

Note that the larger checkpoints may be very difficult to train without a large cloud GPU like an A100 or H100, and the largest 15B parameter checkpoint will probably be impossible to train on **any** single GPU! Also, note that memory usage for attention during training will scale as `O(batch_size * num_layers * seq_len^2)`, so larger models on long sequences will use quite a lot of memory! We will use the `esm2_t12_35M_UR50D` checkpoint for this notebook, which should train on any Colab instance or modern GPU.
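To make the `O(batch_size * num_layers * seq_len^2)` estimate concrete, the sketch below counts only the bytes taken by the attention score matrices. It is a rough lower bound under stated assumptions: one `seq_len x seq_len` matrix per head, per layer, per example, stored in float32, with the head count (a constant factor hidden inside the O() expression) passed in as a parameter. Activations, weights, and optimizer state are ignored, and the helper names are hypothetical. The 20 heads and the two extra special tokens in the usage example are assumptions for illustration.

```python
def attention_score_bytes(batch_size: int, num_layers: int, num_heads: int,
                          seq_len: int, bytes_per_value: int = 4) -> int:
    """Rough size of the attention score matrices kept for backprop (hypothetical helper).

    Assumes one seq_len x seq_len matrix per head, per layer, per example,
    with `bytes_per_value` bytes per entry (4 for float32).
    """
    return batch_size * num_layers * num_heads * seq_len * seq_len * bytes_per_value


def to_gib(num_bytes: int) -> float:
    """Convert bytes to GiB (2**30 bytes)."""
    return num_bytes / 2**30


if __name__ == "__main__":
    # Illustrative settings: batch of 16, 12 layers, an assumed 20 heads,
    # and 1002 tokens (1000 residues plus two assumed special tokens).
    full = attention_score_bytes(16, 12, 20, 1002)
    half = attention_score_bytes(16, 12, 20, 501)
    print(f"seq_len=1002: {to_gib(full):.1f} GiB")
    print(f"seq_len=501:  {to_gib(half):.1f} GiB")
    # Halving the sequence length cuts this term by 4x (quadratic scaling).
    assert full == 4 * half
```

The quadratic term is why trimming or bucketing long sequences usually saves more memory than shrinking the batch.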

In [None]:
model_checkpoint = "facebook/esm2_t12_35M_UR50D"
# To start from the fine-tuned model saved at the end of this notebook instead:
# model_checkpoint = "drive/MyDrive/esm_go_function"

## Data preparation

In [None]:
import pandas

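# Load the data: a gzip-compressed, tab-separated file stored on Google Drive
df = pandas.read_csv('drive/MyDrive/humen_data.tsv.gz', compression='gzip', sep='\t')
df.drop(['Entry','Gene Ontology (biological process)','Gene Ontology (cellular component)'], axis=1, inplace=True)

# Drop proteins with a missing sequence or molecular-function annotation
df = df.dropna()
# New: keep sequences of at most 1000 residues to bound memory use (attention cost grows with seq_len^2)
df = df[df['Sequence'].str.len() <= 1000]

# Split each 'Gene Ontology (molecular function)' string on '; ' into a list of GO terms
df['Gene Ontology (molecular function)'] = df['Gene Ontology (molecular function)'].apply(lambda x: x.split('; '))
# Count how many times each GO term occurs across all proteins
label_counts = pandas.Series([label for sublist in df['Gene Ontology (molecular function)'] for label in sublist]).value_counts()
# Frequency threshold. In this rough approach, terms seen fewer than `threshold` times
# do not get a label of their own
threshold = 100
# Terms kept as individual labels (a Series indexed by term, with counts as values)
labels_to_keep = label_counts[label_counts >= threshold]

# Rebuild the labels as a new column
# Old version: drop rare terms entirely
# df['molecular function'] = df['Gene Ontology (molecular function)'].apply(lambda labels: [label for label in labels if label in labels_to_keep])
# New: collapse every rare term into one catch-all 'Other' label
# (`label in labels_to_keep` tests the Series index, i.e. the term names)
df['molecular function'] = df['Gene Ontology (molecular function)'].apply(lambda labels: [label if label in labels_to_keep else 'Other(One of Uncommon Functions)' for label in labels])
# Drop the old column
df.drop(['Gene Ontology (molecular function)'], axis=1, inplace=True)
# Keep only proteins with more than one label; repeated 'Other' entries count toward the length
df = df[df['molecular function'].map(len) > 1]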

df

|       | Sequence | molecular function |
|-------|----------|--------------------|
| 0 | MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL... | [Other(One of Uncommon Functions), heme bindin... |
| 2 | MVIMSEFSADPAGQGQGQQKPLRVGFYDIERTLGKGNFAVVKLARH... | [ATP binding [GO:0005524], magnesium ion bindi... |
| 3 | MRWQEMGYIFYPRKLR | [DNA binding [GO:0003677], DNA-binding transcr... |
| 6 | MLLLLLLLLLLPPLVLRVAASRCLHDETQKSVSLLRPPFSQLPSKS... | [metal ion binding [GO:0046872], Other(One of ... |
| 8 | MTAEDSTAAMSSDSAAGSSAKVPEGVAGAPNEAALLALMERTGYSM... | [enzyme binding [GO:0019899], Other(One of Unc... |
| ... | ... | ... |
| 19884 | MALSQGLLTFRDVAIEFSQEEWKCLDPAQRTLYRDVMLENYRNLVS... | [DNA-binding transcription activator activity,... |
| 19958 | MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP... | [DNA binding [GO:0003677], metal ion binding [... |
| 19966 | MADKRAGTPEAAARPPPGLAREGDARTVPAARAREAGGRGSLHPAA... | [chromatin binding [GO:0003682], histone bindi... |
| 19992 | MLMPKKNRIAIHELLFKEGVMVAKKDVHMPKHPELADKNVPNLHVM... | [RNA binding [GO:0003723], structural constitu... |
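The thresholding step above can be illustrated without pandas. This minimal sketch, on made-up toy annotation strings and with a hypothetical helper name, counts term frequencies across all proteins, keeps the terms that reach the threshold, and replaces every rarer term with the catch-all 'Other' label, mirroring the `# New` lambda. As in the notebook, a protein with two rare terms ends up with two 'Other' entries.

```python
from collections import Counter

OTHER = "Other(One of Uncommon Functions)"


def collapse_rare_labels(annotations, threshold):
    """Split '; '-separated annotation strings and map rare terms to OTHER.

    annotations: list of strings, one per protein (toy input).
    Returns (per-protein label lists, set of kept terms).
    """
    label_lists = [a.split("; ") for a in annotations]
    # Count occurrences of every term across all proteins
    counts = Counter(label for labels in label_lists for label in labels)
    kept = {label for label, n in counts.items() if n >= threshold}
    relabeled = [[label if label in kept else OTHER for label in labels]
                 for labels in label_lists]
    return relabeled, kept


if __name__ == "__main__":
    toy = ["ATP binding; zinc ion binding",
           "ATP binding; heme binding",
           "ATP binding; zinc ion binding; RNA binding"]
    relabeled, kept = collapse_rare_labels(toy, threshold=2)
    print(sorted(kept))
    for labels in relabeled:
        print(labels)
    # Mirror the notebook's filter: keep proteins with more than one label
    print(len([labels for labels in relabeled if len(labels) > 1]))
```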


In [None]:
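# New: inverse-frequency class weights, weight = total / (n_classes * count), so rarer labels get larger
# weights. n_classes = number of kept labels + 1 for the 'Other' bucket.
total_labels = sum(label_counts)
label_weights = labels_to_keep.apply(lambda count: total_labels / ((len(labels_to_keep)+1) * count))

# 'Other' is weighted by the combined count of all rare terms merged into it
label_weights['Other(One of Uncommon Functions)'] = total_labels / ((len(labels_to_keep)+1) * (total_labels - sum(labels_to_keep)))

label_weight_dict = label_weights.to_dict()
# Sort by label name so the order matches MultiLabelBinarizer.classes_ (which is also sorted)
label_weight_dict = {k: label_weight_dict[k] for k in sorted(label_weight_dict)}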

label_weight_dict

{'ATP binding [GO:0005524]': 0.655970470955922,
 'ATP hydrolysis activity [GO:0016887]': 2.8778959810874705,
 'DNA binding [GO:0003677]': 0.8898757309941521,
 'DNA-binding transcription activator activity, RNA polymerase II-specific [GO:0001228]': 1.533572688334593,
 'DNA-binding transcription factor activity [GO:0003700]': 1.4638648388648388,
 'DNA-binding transcription factor activity, RNA polymerase II-specific [GO:0000981]': 0.6060085623257666,
 'DNA-binding transcription factor binding [GO:0140297]': 4.569632132132132,
 'DNA-binding transcription repressor activity, RNA polymerase II-specific [GO:0001227]': 2.2173952641165755,
 'G protein-coupled receptor activity [GO:0004930]': 1.0154738071404739,
 'GTP binding [GO:0005525]': 1.8734225915666358,
 'GTPase activator activity [GO:0005096]': 3.864603174603175,
 'GTPase activity [GO:0003924]': 2.1402074542897327,
 'Other(One of Uncommon Functions)': 0.027380791722896987,
 'RNA binding [GO:0003723]': 0.5775453079039757,
 'RNA polymeras
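The weighting formula above is inverse frequency: each kept label gets `total / (n_classes * count)`, where `n_classes` is the number of kept labels plus one for 'Other', and 'Other' uses the summed count of every rare term. Since the weight is inversely proportional to the count, 'Other' having the smallest weight among the entries shown means it is the most frequent label. A property worth checking is that `weight * count` is the same for every class, so each class contributes equal total weight; this is the same idea as scikit-learn's 'balanced' class weights, applied to label occurrences. The pure-Python sketch below reproduces the formula on toy counts; the function name is hypothetical.

```python
def inverse_frequency_weights(label_counts, threshold,
                              other="Other(One of Uncommon Functions)"):
    """Inverse-frequency weights with rare labels pooled into `other`.

    label_counts: dict mapping label -> number of occurrences (toy input).
    Returns a dict sorted by label name, like label_weight_dict in the notebook.
    """
    total = sum(label_counts.values())
    kept = {label: n for label, n in label_counts.items() if n >= threshold}
    n_classes = len(kept) + 1  # kept labels plus the 'Other' bucket
    weights = {label: total / (n_classes * n) for label, n in kept.items()}
    other_count = total - sum(kept.values())
    if other_count > 0:  # guard: with no rare labels there is nothing to pool
        weights[other] = total / (n_classes * other_count)
    return {label: weights[label] for label in sorted(weights)}


if __name__ == "__main__":
    counts = {"ATP binding": 6, "DNA binding": 3, "heme binding": 1}
    for label, weight in inverse_frequency_weights(counts, threshold=2).items():
        print(f"{label}: {weight:.3f}")
```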

In [None]:
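# Import scikit-learn's MultiLabelBinarizer for encoding multi-label targets
from sklearn.preprocessing import MultiLabelBinarizer
import numpy as np

mlb = MultiLabelBinarizer()

# Turn the 'molecular function' list column into a multi-hot encoding: one 0/1 column per label,
# with a 1 for every label the protein has. The model cannot take a list of labels as its target,
# so each label becomes its own binary target.
# New: this is multi-hot, not one-hot (one-hot has exactly one 1 per row; here a protein can have
# several). One-hot background: https://cloud.tencent.com/developer/article/1688022
multi_label_encoded = mlb.fit_transform(df['molecular function'])
# New: float32 targets for the floating-point loss computation
multi_label_encoded = multi_label_encoded.astype(np.float32)
# Wrap in a DataFrame with the label names (mlb.classes_, sorted) as column headers
multi_label_df = pandas.DataFrame(multi_label_encoded, columns=mlb.classes_)
# New: drop the list column and reset both indices so the rows line up when concatenating
df_reset = df.drop('molecular function', axis=1).reset_index(drop=True)
multi_label_df_reset = multi_label_df.reset_index(drop=True)
# Concatenate the sequences with the new label columns
result_df = pandas.concat([df_reset, multi_label_df_reset], axis=1)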

result_df

| | Sequence | ATP binding [GO:0005524] | ATP hydrolysis activity [GO:0016887] | DNA binding [GO:0003677] | DNA-binding transcription activator activity, RNA polymerase II-specific [GO:0001228] | DNA-binding transcription factor activity [GO:0003700] | DNA-binding transcription factor activity, RNA polymerase II-specific [GO:0000981] | DNA-binding transcription factor binding [GO:0140297] | DNA-binding transcription repressor activity, RNA polymerase II-specific [GO:0001227] | G protein-coupled receptor activity [GO:0004930] | ... | transcription cis-regulatory region binding [GO:0000976] | transcription coactivator activity [GO:0003713] | transcription corepressor activity [GO:0003714] | transmembrane signaling receptor activity [GO:0004888] | transmembrane transporter binding [GO:0044325] | ubiquitin protein ligase activity [GO:0061630] | ubiquitin protein ligase binding [GO:0031625] | ubiquitin-protein transferase activity [GO:0004842] | unfolded protein binding [GO:0051082] | zinc ion binding [GO:0008270] |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 1 | MVIMSEFSADPAGQGQGQQKPLRVGFYDIERTLGKGNFAVVKLARH... | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 2 | MRWQEMGYIFYPRKLR | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 3 | MLLLLLLLLLLPPLVLRVAASRCLHDETQKSVSLLRPPFSQLPSKS... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 4 | MTAEDSTAAMSSDSAAGSSAKVPEGVAGAPNEAALLALMERTGYSM... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 10866 | MALSQGLLTFRDVAIEFSQEEWKCLDPAQRTLYRDVMLENYRNLVS... | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 10867 | MPTNCAAAGCATTYNKHINISFHRFPLDPKRRKEWVRLVRRKNFVP... | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 10868 | MADKRAGTPEAAARPPPGLAREGDARTVPAARAREAGGRGSLHPAA... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
| 10869 | MLMPKKNRIAIHELLFKEGVMVAKKDVHMPKHPELADKNVPNLHVM... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
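What `MultiLabelBinarizer` produces here can be sketched in plain Python: the classes are the sorted set of all labels, and each protein becomes a row with 1.0 in the column of every label it carries (multi-hot), so a row may contain several 1s. Because both this sketch and `classes_` are sorted by label name, the column order should line up with the sorted `label_weight_dict`. The function name is hypothetical, and this is an illustration, not scikit-learn's implementation.

```python
def multi_hot(label_lists):
    """Encode lists of labels as multi-hot float rows (illustrative sketch).

    Returns (classes, rows): classes is the sorted list of distinct labels and
    rows[i][j] is 1.0 if protein i has label classes[j], else 0.0.
    Duplicate labels in one list (e.g. two 'Other' entries) still give a single 1.0.
    """
    classes = sorted({label for labels in label_lists for label in labels})
    index = {label: j for j, label in enumerate(classes)}
    rows = []
    for labels in label_lists:
        row = [0.0] * len(classes)
        for label in labels:
            row[index[label]] = 1.0
        rows.append(row)
    return classes, rows


if __name__ == "__main__":
    classes, rows = multi_hot([["DNA binding", "zinc ion binding"],
                               ["ATP binding", "Other", "Other"]])
    print(classes)
    for row in rows:
        print(row)
```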


In [None]:
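# New
# 1. Separate the features (sequences) from the targets (multi-hot label vectors)
# 2. Split into training and test sets (80% / 20%)
from sklearn.model_selection import train_test_split

# Every column except 'Sequence' is a label
num_labels = len(result_df.columns)-1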

X = result_df['Sequence'].tolist()

Y = result_df.drop(['Sequence'], axis=1).values.tolist()

X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=42)

print(X_test)
print(Y_test)


['MRPTLLWSLLLLLGVFAAAAAAPPDPLSQLPAPQHPKIRLYNAEQVLSWEPVALSNSTRPVVYQVQFKYTDSKWFTADIMSIGVNCTQITATECDFTAASPSAGFPMDFNVTLRLRAELGALHSAWVTMPWFQHYRNVTVGPPENIEVTPGEGSLIIRFSSPFDIADTSTAFFCYYVHYWEKGGIQQVKGPFRSNSISLDNLKPSRVYCLQVQAQLLWNKSNIFRVGHLSNISCYETMADASTELQQVILISVGTFSLLSVLAGACFFLVLKYRGLIKYWFHTPPSIPLQIEEYLKDPTQPILEALDKDSSPKDDVWDSVSIISFPEKEQEDVLQTL', 'MQQNNSVPEFILLGLTQDPLRQKIVFVIFLIFYMGTVVGNMLIIVTIKSSRTLGSPMYFFLFYLSFADSCFSTSTAPRLIVDALSEKKIITYNECMTQVFALHLFGCMEIFVLILMAVDRYVAICKPLRYPTIMSQQVCIILIVLAWIGSLIHSTAQIILALRLPFCGPYLIDHYCCDLQPLLKLACMDTYMINLLLVSNSGAICSSSFMILIISYIVILHSLRNHSAKGKKKALSACTSHIIVVILFFGPCIFIYTRPPTTFPMDKMVAVFYTIGTPFLNPLIYTLRNAEVKNAMRKLWHGKIISENKG', 'MNHKSKKRIREAKRSARPELKDSLDWTRHNYYESFSLSPAAVADNVERADALQLSVEEFVERYERPYKPVVLLNAQEGWSAQEKWTLERLKRKYRNQKFKCGEDNDGYSVKMKMKYYIEYMESTRDDSPLYIFDSSYGEHPKRRKLLEDYKVPKFFTDDLFQYAGEKRRPPYRWFVMGPPRSGTGIHIDPLGTSAWNALVQGHKRWCLFPTSTPRELIKVTRDEGGNQQDEAITWFNVIYPRTQLPTWPPEFKPLEILQKPGETVFVPGGWWHVVLNLDTTIAITQNFASSTNFPVVWHKTVRGRPKLSRKWYRILKQEHPELAVLADSVDLQESTGIASDSS

## Tokenizing the data

In [None]:
# Load the tokenizer that matches the pretrained checkpoint. The language model takes token IDs as input, so the sequences must be tokenized first.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
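# Tokenize the training and test sequences (no padding here; batches are padded later by prepare_tf_dataset)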
train_tokenized = tokenizer(X_train)
test_tokenized = tokenizer(X_test)
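ESM tokenizers treat each amino acid as one token and wrap the sequence in special start and end tokens; without padding, `tokenizer(X_train)` returns one variable-length `input_ids` list per protein plus an all-ones `attention_mask`. The toy character-level tokenizer below only illustrates that shape. Its vocabulary, token IDs, and the `<cls>`/`<eos>` token names are assumptions for illustration, not the real ESM-2 vocabulary (`tokenizer.get_vocab()` shows the real one), and the helper names are hypothetical.

```python
# Toy vocabulary (assumption): special tokens first, then the 20 standard amino acids.
SPECIAL = ["<cls>", "<pad>", "<eos>", "<unk>"]
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB = {tok: i for i, tok in enumerate(SPECIAL + list(AMINO_ACIDS))}


def toy_tokenize(sequence):
    """Character-level tokenization: <cls> + one ID per residue + <eos>.

    Unknown residues map to <unk>. Returns a dict shaped like a tokenizer
    output for a single unpadded sequence.
    """
    ids = [VOCAB["<cls>"]]
    ids += [VOCAB.get(residue, VOCAB["<unk>"]) for residue in sequence]
    ids.append(VOCAB["<eos>"])
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}


def toy_tokenize_batch(sequences):
    """Tokenize a list of sequences without padding (lengths differ)."""
    encoded = [toy_tokenize(s) for s in sequences]
    return {key: [e[key] for e in encoded] for key in ("input_ids", "attention_mask")}


if __name__ == "__main__":
    batch = toy_tokenize_batch(["MRWQ", "MKX"])
    print(batch["input_ids"])
    print([len(ids) for ids in batch["input_ids"]])
```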

## Dataset creation

In [None]:
# Build Hugging Face Dataset objects from the tokenized outputs so they can be used for training
from datasets import Dataset
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

# Attach the multi-hot targets as a 'labels' column
train_dataset = train_dataset.add_column("labels", Y_train)
test_dataset = test_dataset.add_column("labels", Y_test)

train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 8696
})

## Model loading

In [None]:
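# TFAutoModelForSequenceClassification adds a (newly initialized) classification head to the pretrained model. The Auto
# class reads the checkpoint's config and picks the matching class, here TFEsmForSequenceClassification, so it need not be chosen by hand.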
from transformers import TFAutoModelForSequenceClassification
model = TFAutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFEsmForSequenceClassification: ['lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'esm.embeddings.position_ids']
- This IS expected if you are initializing TFEsmForSequenceClassification from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFEsmForSequenceClassification from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
Some weights or buffers of the TF 2.0 model TFEsmForSequenceClassification were not initialized from the PyTorch model and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.b

In [None]:
# Convert the Datasets into batched tf.data pipelines for Keras; passing the tokenizer pads each batch to its longest sequence
tf_train_set = model.prepare_tf_dataset(
    train_dataset,
    batch_size=16,
    shuffle=True,
    tokenizer=tokenizer
)

tf_test_set = model.prepare_tf_dataset(
    test_dataset,
    batch_size=16,
    shuffle=False,
    tokenizer=tokenizer
)
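Passing `tokenizer=` to `prepare_tf_dataset` pads each batch only up to its longest sequence, and padded positions get `attention_mask` 0 so attention ignores them. The sketch below shows that right-padding step on plain token ID lists. The pad ID of 1 and the helper name are assumptions for illustration (`tokenizer.pad_token_id` gives the real value); this is not the collator's actual code.

```python
def pad_batch(batch_input_ids, pad_id=1):
    """Right-pad a batch of token ID lists to the longest one in the batch.

    pad_id=1 is an assumed padding ID. Returns (input_ids, attention_mask),
    where the mask is 1 for real tokens and 0 for padding.
    """
    max_len = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = max_len - len(ids)
        input_ids.append(list(ids) + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask


if __name__ == "__main__":
    ids, mask = pad_batch([[0, 14, 18, 2], [0, 14, 2]])
    print(ids)
    print(mask)
```

Because each batch is padded independently, a batch of short proteins stays short, which keeps the `seq_len^2` attention cost below what padding everything to 1000 residues would cost.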

In [None]:
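# New
import tensorflow as tf
initial_learning_rate = 4e-5
# Staircase exponential decay: lr = initial_learning_rate * decay_rate ** (step // decay_steps).
# 8696 training examples / batch size 16 ≈ 543 steps per epoch, so the learning rate
# drops by 4% once per epoch.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate,
    decay_steps=543,
    decay_rate=0.96,
    staircase=True)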

optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
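With `staircase=True`, `ExponentialDecay` computes `initial_learning_rate * decay_rate ** floor(step / decay_steps)`, so the rate stays flat within each block of 543 steps and drops by 4% at each boundary. The plain-Python function below reproduces that documented formula (it does not call TensorFlow) to show the learning rate at the start of each of the 5 training epochs, assuming 543 optimizer steps per epoch. The function name is hypothetical.

```python
def staircase_lr(step, initial_lr=4e-5, decay_rate=0.96, decay_steps=543):
    """Learning rate of a staircase exponential decay at a given optimizer step."""
    return initial_lr * decay_rate ** (step // decay_steps)


if __name__ == "__main__":
    steps_per_epoch = 543  # assumption: 8696 examples // batch size 16
    for epoch in range(5):
        step = epoch * steps_per_epoch
        print(f"epoch {epoch + 1}: lr = {staircase_lr(step):.3e}")
```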

In [None]:
# # New
# from transformers import AdamWeightDecay

# class WeightedBinaryCrossEntropy(tf.keras.losses.Loss):
#     def __init__(self, weights, from_logits=False, name='weighted_binary_crossentropy'):
#         super(WeightedBinaryCrossEntropy, self).__init__(name=name)
#         self.weights = weights  # label -> weight dict
#         self.from_logits = from_logits

#     def call(self, y_true, y_pred):
#         if not self.from_logits:
#             epsilon = tf.keras.backend.epsilon()
#             y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
#             y_pred = tf.math.log(y_pred / (1 - y_pred))

#         weights_list = [self.weights[label] for label in sorted(self.weights.keys())]
#         weights_tensor = tf.constant(weights_list, dtype=tf.float32)

#         # Compute the weighted binary cross-entropy
#         bce = tf.keras.backend.binary_crossentropy(y_true, y_pred, from_logits=True)

#         # Reshape weights_tensor to match the shape of bce
#         weights_tensor = tf.reshape(weights_tensor, [1, -1])  # from [num_labels] to [1, num_labels]
#         weights_tensor = tf.broadcast_to(weights_tensor, tf.shape(bce))  # broadcast to the same shape as bce

#         weighted_bce = weights_tensor * bce  # apply the weights

#         return tf.reduce_mean(weighted_bce)


# model.compile(optimizer=optimizer, loss=WeightedBinaryCrossEntropy(weights=label_weight_dict, from_logits=True), metrics=['Precision', 'Recall'])

In [None]:
# New
from transformers import AdamWeightDecay

class WeightedFocalLoss(tf.keras.losses.Loss):
    def __init__(self, weights, gamma=2.0, alpha=0.25, from_logits=False, name="weighted_focal_loss"):
        super(WeightedFocalLoss, self).__init__(name=name)
        self.weights = tf.constant([weights[label] for label in sorted(weights.keys())], dtype=tf.float32)
        self.gamma = gamma
        self.alpha = alpha
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        if self.from_logits:
            y_pred = tf.sigmoid(y_pred)
        else:
            epsilon = tf.keras.backend.epsilon()
            y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)

        cross_entropy = tf.keras.backend.binary_crossentropy(y_true, y_pred, from_logits=False)
        probs = tf.where(y_true == 1, y_pred, 1 - y_pred)
        alpha = tf.where(y_true == 1, self.alpha, 1 - self.alpha)
        focal_weight = alpha * tf.pow((1 - probs), self.gamma)

        # Scale the focal weight by each label's class weight
        weighted_focal_weight = self.weights * focal_weight

        focal_loss = weighted_focal_weight * cross_entropy
        return tf.reduce_mean(focal_loss)

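# The ESM classification head returns raw logits (no sigmoid), so the loss needs from_logits=True;
# with from_logits=False the logits would be clipped to [eps, 1 - eps] and treated as probabilities,
# which also zeroes the gradient for any logit outside that range.
model.compile(optimizer=optimizer,
              loss=WeightedFocalLoss(weights=label_weight_dict, from_logits=True),
              metrics=['Precision', 'Recall'])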
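For a single (example, label) pair, `WeightedFocalLoss` computes `w * alpha_t * (1 - p_t) ** gamma * BCE(y, p)`, where `p_t` is the predicted probability of the true class and `alpha_t` is `alpha` for positives and `1 - alpha` for negatives; the batch loss is the mean over all examples and labels. The plain-Python sketch below recomputes these terms so the effect of `gamma` can be checked by hand: confident, correct predictions are strongly down-weighted. It starts from logits and applies a sigmoid first, since that is what the model head returns, and it mirrors the class up to small epsilon details (the 1e-7 clip is an assumed match for the Keras default). The helper names are hypothetical, not part of Keras.

```python
import math

EPSILON = 1e-7  # assumed to match the default of tf.keras.backend.epsilon()


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def focal_term(y_true, prob, weight=1.0, gamma=2.0, alpha=0.25):
    """Weighted focal loss for one (example, label) pair, given a probability."""
    p = min(max(prob, EPSILON), 1.0 - EPSILON)  # clip like the Keras BCE does
    bce = -(y_true * math.log(p) + (1 - y_true) * math.log(1 - p))
    p_t = p if y_true == 1 else 1 - p              # probability of the true class
    alpha_t = alpha if y_true == 1 else 1 - alpha  # alpha for positives, 1 - alpha for negatives
    return weight * alpha_t * (1 - p_t) ** gamma * bce


def focal_loss_from_logits(y_true_rows, logit_rows, weights, gamma=2.0, alpha=0.25):
    """Mean weighted focal loss over every (example, label) pair, starting from logits."""
    terms = [focal_term(y, sigmoid(z), w, gamma, alpha)
             for ys, zs in zip(y_true_rows, logit_rows)
             for y, z, w in zip(ys, zs, weights)]
    return sum(terms) / len(terms)


if __name__ == "__main__":
    # A positive label predicted confidently (p = 0.9) vs. badly (p = 0.1)
    print(f"easy: {focal_term(1, 0.9):.6f}  hard: {focal_term(1, 0.1):.6f}")
    print(focal_loss_from_logits([[1.0, 0.0]], [[2.0, -1.0]], weights=[1.0, 0.5]))
```

Setting `gamma=0` removes the focusing term, leaving an `alpha`-weighted, label-weighted binary cross-entropy.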

In [None]:
# Train for 5 epochs, evaluating on the test set after each epoch
model.fit(tf_train_set, validation_data=tf_test_set, epochs=5)

Epoch 1/5




Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x7cc9119ea1d0>

In [None]:
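# Save the fine-tuned weights and the tokenizer to Google Drive so they can be reloaded with from_pretrained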
model.save_pretrained('drive/MyDrive/esm_go_function')
tokenizer.save_pretrained('drive/MyDrive/esm_go_function')

('drive/MyDrive/esm_go_function/tokenizer_config.json',
 'drive/MyDrive/esm_go_function/special_tokens_map.json',
 'drive/MyDrive/esm_go_function/vocab.txt',
 'drive/MyDrive/esm_go_function/added_tokens.json')
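Once saved, the model outputs one logit per label. A common way to turn these into GO-term sets is to apply a sigmoid and keep every label whose probability clears a threshold. The 0.5 default below is an assumption; per-label thresholds tuned on held-out data often work better for imbalanced labels. The helper is hypothetical: in this notebook, the logits would come from the model's `logits` output and `label_names` from `mlb.classes_`.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def decode_predictions(logit_rows, label_names, threshold=0.5):
    """Map each row of per-label logits to the list of predicted label names."""
    predictions = []
    for logits in logit_rows:
        probs = [sigmoid(z) for z in logits]
        predictions.append([name for name, p in zip(label_names, probs) if p >= threshold])
    return predictions


if __name__ == "__main__":
    names = ["ATP binding [GO:0005524]", "DNA binding [GO:0003677]", "zinc ion binding [GO:0008270]"]
    print(decode_predictions([[2.3, -1.0, 0.4], [-3.0, -2.0, -0.1]], names))
```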