# Prepare Datasets for Protein Reaction Benchmarking Tasks

> author: Shizhenkun   
> email: zhenkun.shi@tib.cas.cn   
> date: 2025-08-14  



## 1. Import packages

In [1]:
import os, sys, json
import numpy as np
import pandas as pd
from IPython.display import display_markdown
from pandarallel import pandarallel  # 导入 pandarallel
from tkinter import _flatten
sys.path.insert(0, f"{os.path.dirname(os.path.realpath('__file__'))}/../../../")
from rxnrecer.config import config as cfg
from rxnrecer.utils import file_utils as ftool
from rxnrecer.utils import uniprot_utils as uptool
from rxnrecer.utils import rhea_utils as rheatool
from rxnrecer.utils import ec_utils as ectool
from rxnrecer.data.sources import uniprot, rhea
from rxnrecer.lib.ml import mlcommon

pandarallel.initialize(progress_bar=False)
FIRST_TIME_RUN = False  # 初次运行置 True 可下载数据，耗时较长

%load_ext autoreload
%autoreload 2

INFO: Pandarallel will run on 192 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


## 2. Obtain original data from public web sites
### 2.1 Download snapshot data from Uniprot and extract protein records from snapshots

In [2]:
if FIRST_TIME_RUN:    
    
    # Download snapshot data from Uniprot
    ftool.downlod(download_url=cfg.URL_SPROT_SNAP201801, save_file=cfg.FILE_SPROT_SNAP201801) #snapshot 2018-01
    ftool.downlod(download_url=cfg.URL_SPROT_SNAP202401, save_file=cfg.FILE_SPROT_SNAP202401) #snapshot 2024-01
    
    #  Extract protein records from snapshots
    cmd_array = [
        #2018 data
        f'tar -zxvf {cfg.FILE_SPROT_SNAP201801} -C {cfg.UNIPROT_DATA_DIR}',
        f'mv {cfg.UNIPROT_DATA_DIR}uniprot_sprot.dat.gz {cfg.UNIPROT_DATA_DIR}sprot2018.data.gz', 
        f'rm -f {cfg.UNIPROT_DATA_DIR}uniprot_sprot.fasta.gz {cfg.UNIPROT_DATA_DIR}uniprot_sprot_varsplic.fasta.gz {cfg.UNIPROT_DATA_DIR}uniprot_sprot.xml.gz',
        
        #2024 data
        f'tar -zxvf {cfg.FILE_SPROT_SNAP202401} -C {cfg.UNIPROT_DATA_DIR}',
        f'mv {cfg.UNIPROT_DATA_DIR}uniprot_sprot.dat.gz {cfg.UNIPROT_DATA_DIR}sprot2024.data.gz', 
        f'rm -f {cfg.UNIPROT_DATA_DIR}uniprot_sprot.fasta.gz {cfg.UNIPROT_DATA_DIR}uniprot_sprot_varsplic.fasta.gz {cfg.UNIPROT_DATA_DIR}uniprot_sprot.xml.gz'
    ]

    [os.system(item) for item in cmd_array]

    #snapshot 2018-01
    uptool.run_exact_task(infile=f'{cfg.UNIPROT_DATA_DIR}sprot2018.data.gz', outfile=f'{cfg.UNIPROT_DATA_DIR}sprot2018.tsv')
    
    #snapshot 2024-01
    uptool.run_exact_task(infile=f'{cfg.UNIPROT_DATA_DIR}sprot2024.data.gz', outfile=f'{cfg.UNIPROT_DATA_DIR}sprot2024.tsv')
else:
    #sprot 20218
    sprot2018 = pd.read_csv(f'{cfg.UNIPROT_DATA_DIR}sprot2018.tsv', sep='\t',header=0) #读入文件
    date_columns = ['date_integraged', 'date_sequence_update', 'date_annotation_update']
    sprot2018[date_columns] = sprot2018[date_columns].apply(pd.to_datetime, errors='coerce', format='%d-%b-%Y')
    print(f'Records in sprot snapshot 2018-01:\t{len(sprot2018)}')

    #sprot 2024
    sprot2024 = pd.read_csv(f'{cfg.UNIPROT_DATA_DIR}sprot2024.tsv', sep='\t',header=0) #读入文件
    sprot2024[date_columns] = sprot2024[date_columns].apply(pd.to_datetime, errors='coerce', format='%d-%b-%Y')
    print(f'Records in sprot snapshot 2024-01:\t{len(sprot2024)}') 


Records in sprot snapshot 2018-01:	556568
Records in sprot snapshot 2024-01:	570830


### 2.2 Retrieval PR realtion information from Uniprot

In [3]:
if FIRST_TIME_RUN:
    api_url =f'https://rest.uniprot.org/uniprotkb/search?query=reviewed=true&format=tsv&fields=accession,ec,rhea&size=500'

    def is_COMP(stritem):
        return 'RHEA-COMP' not in  stritem

    uniprot_rhea_relation = uniprot.get_batch_data_from_uniprot_rest_api(url=api_url)    #从网站获取数据
    uniprot_rhea_relation = pd.DataFrame(uniprot_rhea_relation, columns=['uniprot_id', 'ec', 'reaction_id'])    #添加列名
    uniprot_rhea_relation = uniprot_rhea_relation.replace('','-')   #处理空值
    uniprot_rhea_relation.reaction_id = uniprot_rhea_relation.reaction_id.apply(lambda x: (';').join([item for item in filter(is_COMP, x.split(' '))]))    #去除化合物
    uniprot_rhea_relation.to_feather(cfg.FILE_UNIPROT_PROTEIN_REACTION_RELATION)    #保存文件
else:
    uniprot_rhea_relation = pd.read_feather(cfg.FILE_UNIPROT_PROTEIN_REACTION_RELATION)
    print(f'Records in uniprot-rhea relation:\t{len(uniprot_rhea_relation)}')

Records in uniprot-rhea relation:	571282


### 2.3 Download Reaction Data

In [4]:
#rhea_web_reactions
if FIRST_TIME_RUN:
    rhea_reactions=rhea.get_rhea_reactions().rename(columns={
        'Reaction identifier':'reaction_id', 
        'Equation':'equation', 
        'ChEBI identifier':'chebi_id', 
        'EC number':'ec_number'})
   
    chebi_cpd = pd.read_feather(cfg.FILE_DS_CHEBI_CPD) # 读取CHEBI 化合物
    rhea_reactions[['equation_chebi', 'equation_smiles']]=rhea_reactions.parallel_apply(lambda x: rheatool.format_equation_chebi(rheaid=x.reaction_id,equation=x.equation, chebiid=x.chebi_id, chebi_cmp_df=chebi_cpd), axis=1, result_type='expand') #格式化反应 # type: ignore

    rhea_reactions.to_feather(cfg.FILE_RHEA_REACTION)
else:
    #rhea-reaction
    rhea_reactions = pd.read_feather(cfg.FILE_RHEA_REACTION)
    print(f'Records in rhea-reaction:\t\t{len(rhea_reactions)}')

Records in rhea-reaction:		16410


## 3. Preprocessing

### 3.1 Filter proteins records with changed sequences

In [5]:
common_set =  sprot2024[['id', 'seq']].merge(sprot2018[['id', 'seq']], on='id',  how='left')
common_set['match']=common_set.apply(lambda x: x.seq_x == x.seq_y, axis=1)
common_set = common_set[(common_set.match==False) & (~common_set.seq_y.isnull())].reset_index(drop=True)

p18 = sprot2018[~sprot2018.id.isin(common_set.id)].reset_index(drop=True)
print(f'Records in p18: {len(p18)}')
p24 = sprot2024[~sprot2024.id.isin(common_set.id)].reset_index(drop=True)
print(f'Records in p24: {len(p24)}')

Records in p18: 554424
Records in p24: 568686


### 3.2 Format EC

In [6]:
#sprot2018
p18['ec_number'] = p18.ec_number.parallel_apply(lambda x: ectool.format_ec(x))
p18['ec_number'] = p18.ec_number.parallel_apply(lambda x: ectool.specific_ecs(x))
p18['functionCounts'] = p18.ec_number.parallel_apply(lambda x: 0 if x=='-'  else len(x.split(cfg.SPLITER)))

#sprot2024
p24['ec_number'] = p24.ec_number.parallel_apply(lambda x: ectool.format_ec(x))
p24['ec_number'] = p24.ec_number.parallel_apply(lambda x: ectool.specific_ecs(x))
p24['functionCounts'] = p24.ec_number.parallel_apply(lambda x: 0 if x=='-'  else len(x.split(cfg.SPLITER)))

#Trim string
with pd.option_context('mode.chained_assignment', None):
    p18.seq = p18.seq.parallel_apply(lambda x : str(x).strip()) #seq trim
    p24.seq = p24.seq.parallel_apply(lambda x : str(x).strip()) #seq trim

### 3.3 Filter reactions without proteins

In [7]:
uids = list(set(p18.id) | set(p24.id))
uniprot_rhea_relation = uniprot_rhea_relation[uniprot_rhea_relation.uniprot_id.isin(uids)].reset_index(drop=True) # type: ignore
print(f'uniprot-rhea relation:\t{len(uniprot_rhea_relation)}')

uniprot-rhea relation:	568685


### 3.4 Filter reactions with relation

In [8]:
# cacl appeared ids
rxn_ids = uniprot_rhea_relation[uniprot_rhea_relation.reaction_id!='-'].reaction_id.to_list()
rxn_ids = _flatten([item.split(';') for item in rxn_ids])
rhea_reactions  = rhea_reactions[rhea_reactions.reaction_id.isin(list(set(rxn_ids)))].reset_index(drop=True)
print(f'Reaction records in RHEA database:\t\t{len(rhea_reactions)}')

Reaction records in RHEA database:		12198


### 3.5 Map reaction and protein

In [9]:
p18 = p18.merge(uniprot_rhea_relation, left_on='id', right_on='uniprot_id', how='left')
p24 = p24.merge(uniprot_rhea_relation, left_on='id', right_on='uniprot_id', how='left')

#del with NULL cells
p18 = p18.fillna('-')
p24 = p24.fillna('-')

#update isenzyme, isMultiFunctional	functionCounts flag
p18['isenzyme'] = p18.reaction_id.parallel_apply(lambda x: True if x!='-' else False)
p24['isenzyme'] = p24.reaction_id.parallel_apply(lambda x: True if x!='-' else False)

p18['functionCounts'] = p18.reaction_id.parallel_apply(lambda x: 0 if x=='-' else len(x.split(';')))
p24['functionCounts'] = p24.reaction_id.parallel_apply(lambda x: 0 if x=='-' else len(x.split(';')))

p18['isMultiFunctional'] = p18.functionCounts.parallel_apply(lambda x: True if x>1 else False)
p24['isMultiFunctional'] = p24.functionCounts.parallel_apply(lambda x: True if x>1 else False)

### 3.6 split train and test

In [10]:
train = p18[['id', 'seq', 'reaction_id', 'ec_number', 'functionCounts', 'ec_specific_level', 'isenzyme']]
print(f'Protein in train: \t{len(train)}')
test = p24[~p24.id.isin(p18.id)][['id', 'seq', 'reaction_id', 'ec_number', 'functionCounts', 'ec_specific_level', 'isenzyme']].reset_index(drop=True)
print(f'Protein in test: \t{len(test)}')


Protein in train: 	554424
Protein in test: 	14782


### 3.7 Filter records with new reaction ids and filter records in sprot 2018 which have EC but without Reaction

In [11]:
rxn_id_train = set(_flatten([item.split(';') for item in train.reaction_id.astype('str')]))
print(f'Reactions in train:\t {len(rxn_id_train)-1}')

rxn_id_test = set(_flatten([item.split(';') for item in test.reaction_id]))
print(f'Reactions in test:\t {len(rxn_id_test)-1}')



def filer_new_ids(rxn_ids, rxn_id_train):
    """
    Filter new reaction ids that appeared in sprot24 but not in sprot18
    """
    if rxn_ids == '-':
        return '-'
    else:
        rxn_ids = set(rxn_ids.split(';'))
        
        res = rxn_ids.intersection(rxn_id_train)
        
        if len(rxn_ids) - len(res) == 0: #所有反应id完全在训练集中
            return ';'.join(rxn_ids)
        elif len(res) == 0:              #所有反应id完全不在训练集中
            return '---'
        else:                            #所有反应id有部分在训练集中
            return  ';'.join(res)


test['keep_flag']=test.reaction_id.apply(lambda x:filer_new_ids(rxn_ids=x, rxn_id_train=rxn_id_train))
test = test[test.keep_flag!='---'].reset_index(drop=True)
test['reaction_id'] = test['keep_flag']
test = test[['id','seq','reaction_id', 'ec_number', 'functionCounts', 'ec_specific_level', 'isenzyme']]

# delete records in sprot 2018 which have EC but without Reaction
train = train[~train.id.isin(train[(train.ec_number!='-') & (train.reaction_id=='-')].id)].reset_index(drop=True)


print(f'Protein in train: \t {len(train)}')
print(f'Protein in test: \t {len(test)}')

Reactions in train:	 10478
Reactions in test:	 3905
Protein in train: 	 508587
Protein in test: 	 13515


## 5. Make labels

### 5.1 make reaction2id dict

In [12]:
# 反应列表，包括'-' 非酶
label_list = train.reaction_id.to_list() + test.reaction_id.to_list()
label_list = list(set(_flatten([item.split(';') for item in label_list])))
dict_rxn2id =dict(zip(label_list, range(len(label_list))))             # 反应对应编号字典
dict_id2rxn = {value: key for key, value in dict_rxn2id.items()}  # 编号对反应号字典

### 5.2 Make one-hot label

In [13]:
with pd.option_context('mode.chained_assignment', None):
    train['label'] = train.reaction_id.parallel_apply(lambda x: mlcommon.make_label(reaction_id=x, rxn_label_dict=dict_rxn2id))
    test['label'] = test.reaction_id.parallel_apply(lambda x: mlcommon.make_label(reaction_id=x, rxn_label_dict=dict_rxn2id))

## 6. Save data

In [14]:
# Save trainning and testing sets

train = train.rename(columns={'id': 'uniprot_id'})
test = test.rename(columns={'id': 'uniprot_id'})

# 统一更新分隔符
train.ec_number = train.ec_number.apply(lambda x: x.replace(',', cfg.SPLITER))
test.ec_number = test.ec_number.apply(lambda x: x.replace(',', cfg.SPLITER))

train.to_feather(cfg.FILE_DS_TRAIN)
test.to_feather(cfg.FILE_DS_TEST)

# # Write fasta files
ftool.table2fasta(table=train[['uniprot_id','seq']], file_out=cfg.FILE_DS_TRAIN_FASTA)
ftool.table2fasta(table=test[['uniprot_id','seq']], file_out=cfg.FILE_DS_TEST_FASTA)

# Reaction data
rhea_reactions.to_feather(cfg.FILE_DS_RHEA_REACTIONS)

# 将字典保存为 JSON 文件
with open(cfg.FILE_DS_DICT_RXN2ID, "w") as json_file:
    json.dump(dict_rxn2id, json_file)
    
with open(cfg.FILE_DS_DICT_ID2RXN, "w") as json_file:
    json.dump(dict_id2rxn, json_file)


# 从 JSON 文件加载字典数据
with open(cfg.FILE_DS_DICT_RXN2ID, "r") as json_file:
    dict_rxn2id = json.load(json_file)
    print(dict_rxn2id)  # 打印加载的数据

with open(cfg.FILE_DS_DICT_ID2RXN, "r") as json_file:
    dict_id2rxn = json.load(json_file)
    print(dict_id2rxn)

{'RHEA:32487': 0, 'RHEA:14673': 1, 'RHEA:11516': 2, 'RHEA:22872': 3, 'RHEA:34739': 4, 'RHEA:77219': 5, 'RHEA:36555': 6, 'RHEA:13941': 7, 'RHEA:34627': 8, 'RHEA:29007': 9, 'RHEA:18461': 10, 'RHEA:69416': 11, 'RHEA:28170': 12, 'RHEA:18653': 13, 'RHEA:12665': 14, 'RHEA:22456': 15, 'RHEA:60912': 16, 'RHEA:11916': 17, 'RHEA:21404': 18, 'RHEA:15189': 19, 'RHEA:31495': 20, 'RHEA:72099': 21, 'RHEA:24931': 22, 'RHEA:51072': 23, 'RHEA:16225': 24, 'RHEA:34619': 25, 'RHEA:15297': 26, 'RHEA:25540': 27, 'RHEA:50924': 28, 'RHEA:37647': 29, 'RHEA:46084': 30, 'RHEA:14189': 31, 'RHEA:30927': 32, 'RHEA:29043': 33, 'RHEA:43152': 34, 'RHEA:66784': 35, 'RHEA:71091': 36, 'RHEA:58892': 37, 'RHEA:54540': 38, 'RHEA:26349': 39, 'RHEA:25840': 40, 'RHEA:56580': 41, 'RHEA:41556': 42, 'RHEA:61488': 43, 'RHEA:68232': 44, 'RHEA:65276': 45, 'RHEA:70963': 46, 'RHEA:11500': 47, 'RHEA:37451': 48, 'RHEA:75103': 49, 'RHEA:76379': 50, 'RHEA:43440': 51, 'RHEA:63312': 52, 'RHEA:15305': 53, 'RHEA:40431': 54, 'RHEA:23904': 55, '

## 7. Describe dataset

In [15]:
display_markdown(
   f'''## Summary of the Protein-Reaction (PR) Dataset 
The PR dataset is consists of two parts: <u>a training set</u> and <u>a testing set</u>, include ***{len(dict_rxn2id)-1}*** distinct reactions, of ***{len(rxn_id_test)-1}*** are appeared in the testing set.  

The training set is from snapshot Jan-2018 and ***excludes*** those <u>deleted items</u> and <u>sequences changed items</u> in snapshot Jan-2024.
(We utilize the most recent 6 years\' data from the start of work as test data, while the remaining data is used for training purposes.)  
The Dataset **ds_rcv** is consists of ***{len(train)}*** 
records, of which ***{len(train[train.reaction_id=='-'])}*** without reactions, we treat this part of proteins as ***non-enzyme***, and ***{len(train[train.reaction_id!='-'])}*** with reactions have clear reaction reponses), we treat this part as ***enzymes***.   

The dataset **ds_rcp** is from snapshot Jan-2024 and ***excludes*** those items that appeared in snapshot Jan-2018 and ***filtered*** new reactions which are not appeared in the traning set.   
The dataset **ds_rcp** is consists of ***{len(test)}*** records, of wich ***{len(test[test.reaction_id=='-'])}*** records without reactions, the rest ***{len(test[test.reaction_id!='-'])}*** have reaactions.

    '''
,raw=True)

## Summary of the Protein-Reaction (PR) Dataset 
The PR dataset is consists of two parts: <u>a training set</u> and <u>a testing set</u>, include ***10478*** distinct reactions, of ***3905*** are appeared in the testing set.  

The training set is from snapshot Jan-2018 and ***excludes*** those <u>deleted items</u> and <u>sequences changed items</u> in snapshot Jan-2024.
(We utilize the most recent 6 years' data from the start of work as test data, while the remaining data is used for training purposes.)  
The Dataset **ds_rcv** is consists of ***508587*** 
records, of which ***282920*** without reactions, we treat this part of proteins as ***non-enzyme***, and ***225667*** with reactions have clear reaction reponses), we treat this part as ***enzymes***.   

The dataset **ds_rcp** is from snapshot Jan-2024 and ***excludes*** those items that appeared in snapshot Jan-2018 and ***filtered*** new reactions which are not appeared in the traning set.   
The dataset **ds_rcp** is consists of ***13515*** records, of wich ***10310*** records without reactions, the rest ***3205*** have reaactions.

    

In [16]:
train

Unnamed: 0,uniprot_id,seq,reaction_id,ec_number,functionCounts,ec_specific_level,isenzyme,label
0,Q6GZX4,MAFSAEDVLKEYDRRRRMEALLLSLYYPNDRKLLDYKEWSPPRVQV...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,Q6GZX3,MSIIGATRLQNDKSDTYSAGPCYAGGCSAFTPRGTCGKDWDLGEQT...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,Q197F8,MASNTVSAQGGSNRPVRDFSNIQDVAQFLLFDPIWNEQPGSIVPWK...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,Q197F7,MYQAINPCPQSWYGSPQLEREIVCKMSGAPHYPNYYPVHPNALGGA...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,Q6GZX2,MARPLLGKTSSVRRRLESLSACSIFFFLRKFCQKMASLVFLNSPVY...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...
508582,Q6UY62,MGNSKSKSKLSANQYEQQTVNSTKQVAILKRQAEPSLYGRHNCRCC...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
508583,P08105,MSSSLEITSFYSFIWTPHIGPLLFGIGLWFSMFKEPSHFCPCQHPH...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
508584,Q88470,MGNCNRTQKPSSSSNNLEKPPQAAEFRRTAEPSLYGRYNCKCCWFA...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
508585,A9JR22,MGLRYSKEVRDRHGDKDPEGRIPITQTMPQTLYGRYNCKSCWFANK...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [17]:
test

Unnamed: 0,uniprot_id,seq,reaction_id,ec_number,functionCounts,ec_specific_level,isenzyme,label
0,A9JLI2,MLGLQIFTLLSIPTLLYTYEIEPLERTSTPPEKEFGYWCTYANHCR...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,A9JLI3,MRFFSYLGLLLAGLTSLQGFSTDNLLEEELRYWCQYVKNCRFCWTC...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,A9JLI5,MLVIFLGILGLLANQVLGLPTQAEGHLRSTDNPPQEELGYWCTYME...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,A9JLI7,MLVIILGVIGLLANQVLGLPTQAGGHLRSTDNPPQEELGYWCTYME...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,B5KVH4,MAKPILLSIYLCLIIVALFNGCLAQSGGRQQHKFGQCQLNRLDALE...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...
13510,P0DW91,MSGAEEAGGGGPAAGPAGSVPAGVGVGAGAGAGVGVGAGPGAAAGP...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
13511,P0DTL6,MSGAEEAGGGGPAAGPAGSVPAGVGVGVGAGPGAAAGQAAAAALGE...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
13512,P0DW87,MSGAEEAGGGGPAAGPAGAVPAGVGVGAGPGAAAGPAAAALGEAAG...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
13513,P0DW89,MSGAEEAGGGGPAAGPAGAVPAGVGVGVGPGAAAGPAAAALGEAAG...,-,-,0,0,False,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


## 8. Make fixed validation set

In [18]:
def make_train_valid_set(dataset, valid_ratio=0.1, seed=42):
    # 设置随机种子以确保可重复性
    np.random.seed(seed)

    # 计算验证集大小
    valid_size = int(len(dataset) * valid_ratio)

    # 随机生成验证集的索引
    valid_indices = np.random.choice(dataset.index, size=valid_size, replace=False)

    # 划分验证集和训练集
    valid_set = dataset.loc[valid_indices]
    train_set = dataset.drop(valid_indices)

    return train_set, valid_set

In [19]:
# 读取训练和测试集
train_data = pd.read_feather(cfg.FILE_DS_TRAIN)

# 定义多个随机种子
seeds = [1, 42, 66, 90, 99, 999, 10000, 301, 789, 615]

# 遍历种子并生成不同的验证集
for fold_index, seed in enumerate(seeds, start=1):
    # 定义保存路径并创建目录
    write_dir = os.path.join(cfg.DIR_DATASET, f'validation/fold{fold_index}')
    os.makedirs(write_dir, exist_ok=True)

    # 生成训练集和验证集
    train_set, valid_set = make_train_valid_set(train_data, valid_ratio=0.1, seed=seed)

    # 保存训练集和验证集
    train_set.to_feather(os.path.join(write_dir, 'train.feather'))
    valid_set.to_feather(os.path.join(write_dir, 'valid.feather'))
    
        # # Write fasta files
    ftool.table2fasta(table=train_set[['uniprot_id','seq']], file_out=f'{write_dir}/train.fasta')
    ftool.table2fasta(table=valid_set[['uniprot_id','seq']], file_out=f'{write_dir}/valid.fasta')

    # 打印进度信息
    print(f"Fold {fold_index}: train size[{len(train_set)}], valid size[{len(valid_set)}], Finished")


Fold 1: train size[457729], valid size[50858], Finished
Fold 2: train size[457729], valid size[50858], Finished
Fold 3: train size[457729], valid size[50858], Finished
Fold 4: train size[457729], valid size[50858], Finished
Fold 5: train size[457729], valid size[50858], Finished
Fold 6: train size[457729], valid size[50858], Finished
Fold 7: train size[457729], valid size[50858], Finished
Fold 8: train size[457729], valid size[50858], Finished
Fold 9: train size[457729], valid size[50858], Finished
Fold 10: train size[457729], valid size[50858], Finished
