In [10]:
import pandas as pd
from transformers import AutoModel, AutoTokenizer, Trainer, TrainingArguments
import torch
import torch.nn as nn
from torch.nn import functional as F
#from torch.utils.data import Dataset
from datasets import Dataset

# Load the per-residue training labels and replace every missing value with 0.
# NOTE(review): fillna(0) applies to all columns, including x_1/y_1/z_1, so any
# residue with missing coordinates gets the label (0, 0, 0) — the all-zero
# 8Z1F_T rows in the displayed output may be such residues. Confirm this is
# intended before training on these labels.
df = pd.read_csv('train_labels.csv').fillna(0)
df

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1
0,1SCL_A_1,G,1,13.760,-25.974001,0.102
1,1SCL_A_2,G,2,9.310,-29.638000,2.669
2,1SCL_A_3,G,3,5.529,-27.813000,5.878
3,1SCL_A_4,U,4,2.678,-24.900999,9.793
4,1SCL_A_5,G,5,1.827,-20.136000,11.793
...,...,...,...,...,...,...
137090,8Z1F_T_82,U,82,0.000,0.000000,0.000
137091,8Z1F_T_83,C,83,0.000,0.000000,0.000
137092,8Z1F_T_84,A,84,0.000,0.000000,0.000
137093,8Z1F_T_85,U,85,0.000,0.000000,0.000


In [11]:
# molecule_id: the text of ID before its first "_" (e.g. "1SCL" from "1SCL_A_1").
df['molecule_id'] = df['ID'].map(lambda row_id: row_id.partition('_')[0])
# residue_id: a copy of the residue number column.
df['residue_id'] = df['resid']
df

Unnamed: 0,ID,resname,resid,x_1,y_1,z_1,molecule_id,residue_id
0,1SCL_A_1,G,1,13.760,-25.974001,0.102,1SCL,1
1,1SCL_A_2,G,2,9.310,-29.638000,2.669,1SCL,2
2,1SCL_A_3,G,3,5.529,-27.813000,5.878,1SCL,3
3,1SCL_A_4,U,4,2.678,-24.900999,9.793,1SCL,4
4,1SCL_A_5,G,5,1.827,-20.136000,11.793,1SCL,5
...,...,...,...,...,...,...,...,...
137090,8Z1F_T_82,U,82,0.000,0.000000,0.000,8Z1F,82
137091,8Z1F_T_83,C,83,0.000,0.000000,0.000,8Z1F,83
137092,8Z1F_T_84,A,84,0.000,0.000000,0.000,8Z1F,84
137093,8Z1F_T_85,U,85,0.000,0.000000,0.000,8Z1F,85


In [12]:
# Build one (sequence, coordinates) pair per RNA chain.
# The chain-level target id is ID with its trailing residue number stripped
# (e.g. "1SCL_A_1" -> "1SCL_A"). Grouping by the PDB prefix alone ("1SCL")
# would merge different chains of the same entry, and sorting by residue
# number would then interleave their residues into a single bogus sequence.
target_ids = df['ID'].str.rsplit('_', n=1).str[0]
data = []
for target_id, group in df.groupby(target_ids, sort=True):
    # Stable sort: any ties in residue number keep their original file order.
    group = group.sort_values('resid', kind='stable')
    sequence = ''.join(group['resname'].values)
    # Coordinate array with one (x, y, z) row per residue, in sequence order.
    labels = group[['x_1', 'y_1', 'z_1']].values
    data.append((sequence, labels))

data[0]

('GGCGUAAGGAUUACCUAUGCC',
 array([[ 35.85699844, -10.76900005,  -7.54799986],
        [ 30.22999954, -12.07499981,  -8.61400032],
        [ 23.96800041, -11.35599995,  -7.69000006],
        [ 19.29599953,  -9.8739996 ,  -4.77799988],
        [ 16.36199951,  -6.04699993,  -0.70599997],
        [ 15.63599968,  -1.54900002,   2.46300006],
        [ 16.96999931,   2.89299989,   4.62599993],
        [ 20.39100075,   6.86199999,   5.54899979],
        [ 24.37000084,   9.63000011,   3.34800005],
        [ 26.34199905,  12.36499977,  -0.59399998],
        [ 23.91799927,  16.02300072,  -5.41800022],
        [ 24.93799973,  15.56499958, -11.24300003],
        [ 25.58799934,  10.09500027, -10.00399971],
        [ 28.33300018,   7.8039999 ,  -6.25500011],
        [ 28.9489994 ,   4.83599997,  -0.75599998],
        [ 26.74900055,   1.38999999,   3.28699994],
        [ 24.11899948,  -2.8210001 ,   6.02600002],
        [ 22.77099991,  -7.66499996,   5.35500002],
        [ 22.32999992, -13.6260004 ,  

In [13]:
# Inspect the first (sequence, coordinates) pair: sequence on one line, array after it.
(sequence, labels) = data[0]
print(sequence, labels, sep="\n")

GGCGUAAGGAUUACCUAUGCC
[[ 35.85699844 -10.76900005  -7.54799986]
 [ 30.22999954 -12.07499981  -8.61400032]
 [ 23.96800041 -11.35599995  -7.69000006]
 [ 19.29599953  -9.8739996   -4.77799988]
 [ 16.36199951  -6.04699993  -0.70599997]
 [ 15.63599968  -1.54900002   2.46300006]
 [ 16.96999931   2.89299989   4.62599993]
 [ 20.39100075   6.86199999   5.54899979]
 [ 24.37000084   9.63000011   3.34800005]
 [ 26.34199905  12.36499977  -0.59399998]
 [ 23.91799927  16.02300072  -5.41800022]
 [ 24.93799973  15.56499958 -11.24300003]
 [ 25.58799934  10.09500027 -10.00399971]
 [ 28.33300018   7.8039999   -6.25500011]
 [ 28.9489994    4.83599997  -0.75599998]
 [ 26.74900055   1.38999999   3.28699994]
 [ 24.11899948  -2.8210001    6.02600002]
 [ 22.77099991  -7.66499996   5.35500002]
 [ 22.32999992 -13.6260004    3.10700011]
 [ 25.37299919 -17.3560009   -0.29300001]
 [ 29.96199989 -19.02499962  -3.30900002]]


In [14]:
# Prepare model inputs and labels.
def prepare_input_labels(sequence, labels, max_len=1024):
    """Expand one RNA chain into per-residue (prefix, coordinate) examples.

    For each residue position ``i`` the example input is the sequence prefix
    ``sequence[:i + 1]`` with every "U" replaced by "T", truncated to its last
    ``max_len`` characters. The label is ``labels[i].tolist()``.

    Args:
        sequence: RNA sequence string, one character per residue.
        labels: per-residue coordinates indexed by position; ``labels[i]`` must
            support ``.tolist()`` (presumably a numpy array with one (x, y, z)
            row per residue — TODO confirm against callers).
        max_len: keep at most this many trailing characters of each prefix.
            ``None`` disables truncation. Defaults to 1024, the previously
            hard-coded limit.

    Returns:
        A list of ``{"seq": str, "label": list}`` dicts, one per residue.

    Raises:
        ValueError: if ``max_len`` is not ``None`` and is smaller than 1.
    """
    if max_len is not None and max_len < 1:
        raise ValueError("max_len must be a positive integer or None")
    # U -> T is a per-character substitution, so converting the whole chain once
    # gives the same prefixes as converting every prefix separately.
    converted = sequence.replace("U", "T")
    examples = []
    for i in range(len(sequence)):
        end = i + 1
        # Slice only the trailing window instead of building the full prefix first.
        start = 0 if max_len is None else max(0, end - max_len)
        examples.append({"seq": converted[start:end], "label": labels[i].tolist()})
    return examples

In [15]:
# Concatenate the per-residue examples of every chain into one flat training list.
seq_pos_list_train = []
for sequence, labels in data:
    seq_pos_list = prepare_input_labels(sequence, labels)
    seq_pos_list_train += seq_pos_list  # list += is an in-place extend


In [16]:
# Display the first flattened training example (its "seq" prefix and "label" coordinates).
seq_pos_list_train[0]

{'seq': 'G',
 'label': [35.85699844360352, -10.769000053405762, -7.547999858856201]}

In [17]:
import json

# Write the examples as JSON Lines: one JSON object per line.
filename = "rna_pos_1024.jsonl"
# The context manager closes the file even if a write raises midway.
with open(filename, "w", encoding="utf-8") as rna_file:
    for item in seq_pos_list_train:
        # allow_nan=False: fail loudly rather than emit non-standard NaN/Infinity
        # tokens that strict JSON readers reject.
        rna_file.write(json.dumps(item, allow_nan=False) + "\n")