In [22]:
import sys
sys.path.append("../")  # src 상위 폴더 기준 상대경로

from src.datasets.dataprocess import get_Data
from src.utils import get_RSA
from Bio.PDB.MMCIFParser import MMCIFParser
from Bio.PDB.PDBParser import PDBParser
from src.models.Model import SAPP_Model
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="Bio.PDB.DSSP")
import torch 
from collections import defaultdict
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import os
from tqdm import tqdm 
import torch


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
# Tab-separated input; each line is later split into six columns:
# protein id, sequence, site, label, PTM type, structure file path.
input_path = '/home/bis/230419_CYJ_Methyl/SAPP/inference_test_wAF.txt'
output_path = 'inference_result.csv'
batch_size = 128
device = 'cuda:2'

# Keep the raw lines (newline characters included); parsing happens later.
with open(input_path, 'r') as f:
    lists = list(f)

# Residue letters accepted at the target site for each PTM type key.
PTMtype_to_residue = {
    'SAPPPhos': ['S', 'T'],
    'SAPPmethylR': ['R'],
    'SAPPphosY': ['Y'],
    'SAPPsumoK': ['K'],
    'SAPPmethylK': ['K'],
    'SAPPacetylK': ['K'],
    'SAPPubiquitinK': ['K'],
}



# Containers holding the parsed input, grouped by PTM type:
#   grouped_test_info[PTMtype]   -> list of (proteinid, site, label, PTMtype)
#   grouped_RSA_dic[PTMtype]     -> {proteinid: RSA value returned by get_RSA}
#   grouped_protein_dic[PTMtype] -> {proteinid: protein sequence}
grouped_test_info = defaultdict(list)
grouped_RSA_dic = defaultdict(dict)
grouped_protein_dic = defaultdict(dict)

# (pdb_seq, RSA) keyed by structure file path, so a structure that is shared
# by several input lines is parsed and passed to get_RSA only once.
# NOTE(review): assumes get_RSA is deterministic for a given file and that
# downstream code does not mutate the returned RSA in place - TODO confirm.
structure_cache = {}

for line_no, line in enumerate(lists, start=1):
    line = line.strip()
    if not line:
        # Tolerate blank lines (e.g. a trailing newline at the end of the file).
        continue

    fields = line.split('\t')
    if len(fields) != 6:
        raise ValueError(
            f"Line {line_no}: expected 6 tab-separated fields, got {len(fields)}"
        )
    proteinid, protein_seq, site, label, PTMtype, AF_file = fields

    site = int(site)
    label = int(label)

    # Explicit checks instead of `assert`, which is stripped under `python -O`
    # and reports no context when it fails.
    if PTMtype not in PTMtype_to_residue:
        raise ValueError(f"Line {line_no}: unknown PTM type {PTMtype!r}")
    if not 0 <= site < len(protein_seq):
        raise ValueError(
            f"Line {line_no}: site {site} is outside the sequence of "
            f"{proteinid} (length {len(protein_seq)})"
        )
    if protein_seq[site] not in PTMtype_to_residue[PTMtype]:
        raise ValueError(
            f"Line {line_no}: residue {protein_seq[site]!r} at site {site} of "
            f"{proteinid} is not valid for {PTMtype} "
            f"(expected one of {PTMtype_to_residue[PTMtype]})"
        )

    # Parse the structure (mmCIF when the path ends with '.cif', PDB otherwise).
    if AF_file not in structure_cache:
        parser = MMCIFParser() if AF_file.endswith('.cif') else PDBParser()
        structure = parser.get_structure('test', AF_file)
        model = structure[0]
        structure_cache[AF_file] = get_RSA(model, AF_file)
    pdb_seq, RSA = structure_cache[AF_file]

    if protein_seq != pdb_seq:
        raise ValueError(
            f"Line {line_no}: sequence given for {proteinid} does not match "
            f"the sequence extracted from {AF_file}"
        )

    # Store per PTM type.
    grouped_test_info[PTMtype].append((proteinid, site, label, PTMtype))
    grouped_RSA_dic[PTMtype][proteinid] = RSA
    grouped_protein_dic[PTMtype][proteinid] = protein_seq

# Result table (DataFrame) per PTM type.
grouped_test_result = dict()

for ptm_type in grouped_test_info:
    print(f"Processing {ptm_type}...")

    # Prepare the data.
    seq_list, rsa_list, mask_list, rsamask_list, label_list = get_Data(
        grouped_test_info[ptm_type],
        grouped_protein_dic[ptm_type],
        grouped_RSA_dic[ptm_type]
    )

    test_dataset = TensorDataset(
        seq_list,
        rsa_list,
        mask_list,
        rsamask_list,
        label_list
    )

    # shuffle=False so the prediction order follows the dataset order.
    # NOTE(review): the result table below assumes get_Data returns rows in
    # the same order as grouped_test_info[ptm_type] - TODO confirm.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # One model instance per PTM type; every weight file is loaded into it in turn.
    model = SAPP_Model(
        22, hidden=256, n_layers=2, attn_heads=4, dropout=0.2,
        feed_forward_dim=758, device=device
    ).to(device)

    model_path = f'../data/models/{ptm_type}'
    # Sort for a reproducible ensemble order and skip non-file entries
    # (e.g. a `.ipynb_checkpoints` directory), which torch.load cannot read.
    weight_files = sorted(
        name for name in os.listdir(model_path)
        if os.path.isfile(os.path.join(model_path, name))
    )
    if not weight_files:
        raise FileNotFoundError(f"No model weight files found in {model_path}")

    # Ensemble: collect one prediction vector per weight file, then average.
    all_preds = []
    for weight_file in weight_files:
        # map_location puts the tensors on `device` regardless of the device
        # the checkpoint was saved from.
        state_dict = torch.load(
            os.path.join(model_path, weight_file), map_location=device
        )
        model.load_state_dict(state_dict)
        model.eval()

        pred_y = []
        # Inference only: do not build the autograd graph.
        with torch.no_grad():
            for data in tqdm(test_loader, desc=f"{ptm_type} - {weight_file}"):
                data = [b.to(device) for b in data]
                # The label tensor (data[4]) is not passed to the model.
                pred, _ = model(data[0], data[1], data[2], data[3])
                pred_y.append(pred.view(-1).cpu().numpy())

        all_preds.append(np.concatenate(pred_y))

    pred_y_list = np.mean(np.stack(all_preds), axis=0)

    protein_ids = [info[0] for info in grouped_test_info[ptm_type]]
    sites = [info[1] for info in grouped_test_info[ptm_type]]

    # Store the results for this PTM type.
    result_df = pd.DataFrame({
        'ProteinID': protein_ids,
        'Site': sites,
        'Pred': pred_y_list,
        'Label': label_list.cpu().numpy(),
        'PTMType': ptm_type
    })
    grouped_test_result[ptm_type] = result_df

# Merge the results of all PTM types into one table and write it out.
if not grouped_test_result:
    # pd.concat would raise a less helpful "No objects to concatenate".
    raise ValueError(f"No input entries were found in {input_path}")
total_df = pd.concat(list(grouped_test_result.values()), ignore_index=True)
total_df.to_csv(output_path)

In [44]:
# Load every line of the example input file, newline characters included.
with open('../example_input_wAF.txt', 'r') as f:
    lists = list(f)

In [46]:
# Rebuild each input line with its sixth (last) column replaced by a fixed
# RSA .npy path; the first five columns are carried over unchanged.
new_list = []
for line in lists:
    proteinid, seq, site, label, ptmtype, mmcif = line.replace('\n', '').split('\t')
    new_list.append(
        '\t'.join((proteinid, seq, site, label, ptmtype, 'data/RSA_files/Q9ESV1.npy'))
        + '\n'
    )

In [47]:
# Write the rewritten example lines. Every entry of new_list already ends with
# '\n', so the list is passed to writelines() directly instead of being joined
# into one string that writelines() would then iterate character by character.
with open('../example_input_wRSA.txt', 'w') as f:
    f.writelines(new_list)

In [28]:
# Each distinct value of test_df['Protein'] once, paired position-by-position
# with its entry in total_RSA.
protein_list = list(set(test_df['Protein']))
rsa_list = [total_RSA[name] for name in protein_list]

In [29]:
# Two-column table: one row per element of protein_list / rsa_list.
df = pd.DataFrame({'Protein':protein_list,'RSA':rsa_list})

In [31]:
# Write the table to CSV; with pandas' defaults the index is written as the first column.
df.to_csv('/home/bis/230419_CYJ_Methyl/SAPP/data/RSA_files/testdata_RSA.csv')

In [34]:
# Read the table back from CSV; index_col=0 uses the first column as the index.
df = pd.read_csv('../data/RSA_files/testdata_RSA.csv',index_col=0)

In [43]:
# Save the total_RSA entry stored under key 'Q9ESV1' to a .npy file.
np.save('/home/bis/230419_CYJ_Methyl/SAPP/data/RSA_files/Q9ESV1.npy', total_RSA['Q9ESV1'])

In [48]:
import pickle  # not among the notebook's top-level imports, so import it here

# Load the pickled test results.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load files from a trusted source.
with open('/data1/CYJ_Methyl/paper_data_2025/Phospho_ST/SAPP_test_result.pickle', 'rb') as f:
    SAPP = pickle.load(f)

In [49]:
# Split the loaded results into four parallel lists. zip() stops at the
# shortest of the three sequences, so all lists end up the same length.
rows = list(zip(SAPP['info'], SAPP['pred_y'], SAPP['real_y']))

protein_list = [info[0] for info, _, _ in rows]
site_list = [info[1] for info, _, _ in rows]
pred_list = [pred for _, pred, _ in rows]
real_list = [real for _, _, real in rows]

In [50]:
# Combine the four parallel lists into one table, one row per list position.
SAPP_df = pd.DataFrame({'Protein':protein_list,'Site':site_list, 'Pred':pred_list, 'Label':real_list})

In [53]:
import pickle  # not among the notebook's top-level imports, so import it here

# Load the pickled RSA data.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load files from a trusted source.
with open('/data1/CYJ_Methyl/paper_data_2025/Phospho_ST/Total_RSA.pickle', 'rb') as f:
    SAPPPhos_RSA = pickle.load(f)


NameError: name 'Total_RSA' is not defined