In [1]:
import random
from protein import *
from util import *
from model import *
import warnings
warnings.filterwarnings('ignore')

In [2]:
def read_pkpdb(csv_path='D:\\pdb\\final_pka.csv'):
    """Load the training pKa labels from a CSV file into a lookup dict.

    Parameters
    ----------
    csv_path : str, optional
        CSV file to read. The default is the location this notebook was
        originally hard-wired to, so a bare ``read_pkpdb()`` behaves as before.

    Returns
    -------
    dict
        For every data row, maps the string
        ``col0 + col1 + str(col2) + col3`` to ``float(col4)``.
        When two rows produce the same key the later row wins.

    Raises
    ------
    TypeError
        If column 0, 1 or 3 of a row is not a string (e.g. an empty cell).
    ValueError
        If column 4 of a row cannot be converted to ``float``.

    Notes
    -----
    ``pd.read_csv`` consumes the first line of the file as a header, so it
    never becomes an entry. Columns are addressed by position only.
    # NOTE(review): the column order looks like (pdb id, chain, residue
    # number, residue name, pKa) judging by how the key is used by
    # ``Protein.align_pka`` -- confirm against the CSV.
    """
    table = pd.read_csv(csv_path, index_col=False)
    pka = {}
    # Plain positional tuples: no index column, no namedtuple field mangling.
    for row in table.itertuples(index=False, name=None):
        pka[row[0] + row[1] + str(row[2]) + row[3]] = float(row[4])
    return pka

def read_pkad(csv_path='D:\\pdb\\pkad\\PKAD2_DOWNLOAD.csv'):
    """Load the PKAD evaluation pKa labels from a CSV file into a lookup dict.

    Parameters
    ----------
    csv_path : str, optional
        CSV file to read. The default is the location this notebook was
        originally hard-wired to, so a bare ``read_pkad()`` behaves as before.

    Returns
    -------
    dict
        For every accepted row, maps the string
        ``col0 + col2 + str(col3) + col1`` to ``float(col4)``.
        When two rows produce the same key the later row wins.

    Notes
    -----
    Rows are skipped (never raised on) when:

    * column 4 is missing (NaN);
    * column 4 contains ``<``, ``>``, ``-`` or ``~`` -- i.e. bounds, ranges,
      approximations and, as a side effect of the ``-`` test, negative values;
    * column 1 is ``'N-term'`` or ``'C-term'``;
    * the key cannot be built or column 4 is not a number.

    ``pd.read_csv`` consumes the first line of the file as a header.
    # NOTE(review): the column order looks like (pdb id, residue name, chain,
    # residue number, pKa) -- confirm against the PKAD download.
    """
    table = pd.read_csv(csv_path, index_col=False)
    pka = {}
    for row in table.itertuples(index=False, name=None):
        raw_value = row[4]
        if pd.isna(raw_value):
            continue
        # pandas yields floats when the whole column is numeric; normalise to
        # text so the substring tests below cannot raise TypeError.
        text = str(raw_value)
        if any(marker in text for marker in ('<', '>', '-', '~')):
            continue
        if row[1] == 'N-term' or row[1] == 'C-term':
            continue
        try:
            pka[row[0] + row[2] + str(row[3]) + row[1]] = float(text)
        except (TypeError, ValueError):
            # Malformed row (non-string id/chain/name or non-numeric pKa):
            # best-effort loader, so drop it and keep going.
            continue

    return pka

In [3]:
def run(device="cpu", threads=None, epochs=1, files_per_step=50):
    """Train ``Net`` on the pkPDB structures and score it on the PKAD set.

    For each epoch the function walks the (shuffled) training directory,
    builds a ``Protein`` graph per file, accumulates predictions/labels and
    takes one optimizer step every ``files_per_step`` successfully processed
    files. It then evaluates on the PKAD directory, prints the mean absolute
    error, and writes two checkpoints: one per epoch and one for the best MAE
    seen so far.

    Parameters
    ----------
    device : str, optional
        Forwarded unchanged to ``Protein.predict_pkas`` and ``Protein.test``.
        # NOTE(review): the model itself is never moved with ``.to(device)``
        # and the error bookkeeping calls ``.numpy()`` directly, so only
        # "cpu" is expected to work here -- confirm before using a GPU.
    threads : int or None, optional
        When truthy, passed to ``torch.set_num_threads``.
    epochs : int, optional
        Number of train + evaluate passes (previously a hard-coded 1).
    files_per_step : int, optional
        Number of successfully processed training files accumulated before
        each optimizer step (previously a hard-coded 50).

    Returns
    -------
    None
        Results are reported via ``print`` and saved checkpoints.

    Notes
    -----
    * Files whose ``read_pssm()`` returns -1, or whose preprocessing or
      prediction raises, are skipped silently.
    * Files left over at the end of an epoch (fewer than ``files_per_step``)
      are discarded without an optimizer step, as in the original code.
    """
    if threads:
        torch.set_num_threads(threads)

    model = Net()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
    loss_func = nn.MSELoss()

    pdb_path = 'D:\\pdb\\pkPDB\\data\\'
    pssm_path = 'D:\\pdb\\pkPDB\\pssm\\'
    pkad_path = 'D:\\pdb\\pkad\\data\\'
    pkad_pssm_path = 'D:\\pdb\\pkad\\pssm\\'
    save_path = 'C:\\Users\\dx\\Documents\\dx_pkai_model_gat+20pssm\\dx_pkai_model_gat_pssm\\save_model\\'
    best_model_save_name = 'pka_net_best_mae.pt'

    # The label tables do not change between epochs, so read them once. This
    # also surfaces a missing CSV before hours of training instead of after.
    pka = read_pkpdb()
    pkad = read_pkad()

    # Make sure the checkpoints have somewhere to go before training starts.
    os.makedirs(save_path, exist_ok=True)

    min_mae = 100
    for epoch in range(epochs):
        # The evaluation phase below switches to eval mode; switch back so
        # every epoch after the first still trains in training mode.
        model.train()

        predictions_tr = torch.Tensor()
        labels_tr = torch.Tensor()
        err = []
        file_num = 0  # successfully processed files since the last step

        start = time.time()

        files = os.listdir(pdb_path)
        random.shuffle(files)
        file_length = len(files)

        for file_count, file_name in enumerate(files, start=1):
            prot = Protein(pdb_path, file_name, pssm_path)
            if prot.read_pssm() == -1:
                continue

            try:
                prot.align_pssm()
                prot.align_pka(pka, file_name.replace('.pdb', ''))
                prot.apply_cutoff()
                prot.creat_graph()

                optimizer.zero_grad()
                preds, pkas = prot.predict_pkas(model, device, loss_func, optimizer)
            except Exception:
                # Best-effort: skip structures that cannot be processed.
                # ``Exception`` (not a bare except) so Ctrl-C still stops the run.
                continue

            predictions_tr = torch.cat((predictions_tr, preds), 0)
            labels_tr = torch.cat((labels_tr, pkas), 0)

            file_num += 1
            if file_num >= files_per_step:
                loss = loss_func(predictions_tr, labels_tr)

                err.extend(abs(predictions_tr.view(-1).detach().numpy() - labels_tr.view(-1).detach().numpy()))

                # NOTE(review): ``requires_grad_(True)`` is only needed when the
                # loss is detached from the graph, in which case no gradient
                # would reach the model from here. Check what
                # ``Protein.predict_pkas`` returns and whether it already
                # performs its own backward/step.
                loss.requires_grad_(True).backward()
                optimizer.step()

                predictions_tr = torch.Tensor()
                labels_tr = torch.Tensor()
                file_num = 0

                print('files:',file_count,'/',file_length,' time:', str(time.time() - start), '---MAE:', str(np.average(err)))
                start = time.time()
                err = []

        print('test')
        model.eval()

        ave = []
        # No gradients are needed while scoring the held-out set.
        with torch.no_grad():
            for file_name in os.listdir(pkad_path):
                prot_test = Protein(pkad_path, file_name, pkad_pssm_path)
                if prot_test.read_pssm() == -1:
                    continue
                try:
                    prot_test.align_pssm()
                    prot_test.align_pka(pkad, file_name.replace('.pdb', ''))
                    prot_test.apply_cutoff()
                    prot_test.creat_graph()

                    result = prot_test.test(model, device)
                except Exception:
                    continue

                for i in result:
                    # An entry whose second element is -1 is excluded from the
                    # MAE (presumably "no experimental value" -- TODO confirm).
                    if i[1] == -1:
                        continue
                    ave.append(abs(i[0] - i[1]))

        # An empty list averages to NaN either way; spelling it out avoids
        # numpy's empty-slice warning path.
        mae = np.average(ave) if ave else float('nan')
        print(mae)

        epoch_model_save_name = f'pka_net_epoch{epoch}_mae{mae:.5f}.pt'
        torch.save(model.state_dict(), save_path + epoch_model_save_name)
        # A NaN MAE never compares smaller, so it cannot replace the best model.
        if min_mae > mae:
            torch.save(model.state_dict(), save_path + best_model_save_name)
            min_mae = mae

In [None]:
# Kick off training and evaluation on the CPU; threads=None leaves PyTorch's
# thread count at its default (torch.set_num_threads is not called).
run(device="cpu", threads=None)

files: 55 / 46896  time: 380.67513847351074 ---MAE: 6.8899136
files: 109 / 46896  time: 406.56704354286194 ---MAE: 6.709753
files: 167 / 46896  time: 422.50311374664307 ---MAE: 6.4352508
files: 223 / 46896  time: 324.85463643074036 ---MAE: 6.0086007
files: 275 / 46896  time: 397.6251723766327 ---MAE: 4.588406
files: 328 / 46896  time: 428.95722246170044 ---MAE: 3.9249573
files: 382 / 46896  time: 420.98835611343384 ---MAE: 7.0240474
files: 438 / 46896  time: 374.3790657520294 ---MAE: 3.9568
files: 496 / 46896  time: 356.3319888114929 ---MAE: 3.7465951
files: 554 / 46896  time: 363.7449231147766 ---MAE: 4.4785814
files: 610 / 46896  time: 393.65129375457764 ---MAE: 4.879106
files: 666 / 46896  time: 377.92378282546997 ---MAE: 5.0624924
files: 720 / 46896  time: 279.40457344055176 ---MAE: 5.0985427
files: 777 / 46896  time: 390.70204854011536 ---MAE: 4.90468
files: 836 / 46896  time: 333.8731553554535 ---MAE: 4.7130375
files: 894 / 46896  time: 355.0323312282562 ---MAE: 4.1978965
files: 

files: 7190 / 46896  time: 293.29805755615234 ---MAE: 2.0692778
files: 7246 / 46896  time: 237.93314385414124 ---MAE: 1.8950689
files: 7301 / 46896  time: 252.9747154712677 ---MAE: 1.8734617
files: 7357 / 46896  time: 255.98583936691284 ---MAE: 1.8773094
files: 7411 / 46896  time: 265.8071217536926 ---MAE: 1.9230189
files: 7465 / 46896  time: 245.71649432182312 ---MAE: 1.8753824
files: 7522 / 46896  time: 240.80026173591614 ---MAE: 1.8519754
files: 7578 / 46896  time: 239.45622515678406 ---MAE: 1.9511809
files: 7635 / 46896  time: 267.4038004875183 ---MAE: 1.9710716
files: 7690 / 46896  time: 231.00980734825134 ---MAE: 1.8571588
files: 7747 / 46896  time: 222.01650285720825 ---MAE: 1.8697584
files: 7801 / 46896  time: 243.032452583313 ---MAE: 1.736099
files: 7858 / 46896  time: 271.7847933769226 ---MAE: 1.8371533
files: 7916 / 46896  time: 269.3820219039917 ---MAE: 1.9676098
files: 7968 / 46896  time: 273.54774165153503 ---MAE: 1.9803114
files: 8020 / 46896  time: 245.78315258026123 --

files: 14221 / 46896  time: 285.239284992218 ---MAE: 1.7519377
files: 14278 / 46896  time: 294.45757937431335 ---MAE: 1.7948312
files: 14332 / 46896  time: 250.4795172214508 ---MAE: 1.8388764
files: 14386 / 46896  time: 294.453857421875 ---MAE: 1.730971
files: 14443 / 46896  time: 244.02511191368103 ---MAE: 1.7883344
files: 14496 / 46896  time: 267.6068687438965 ---MAE: 1.7877053
files: 14550 / 46896  time: 229.70225429534912 ---MAE: 1.6811339
files: 14605 / 46896  time: 282.3291049003601 ---MAE: 1.7385707
files: 14657 / 46896  time: 242.52487897872925 ---MAE: 1.6505331
files: 14711 / 46896  time: 273.0398337841034 ---MAE: 1.8381484
files: 14763 / 46896  time: 248.67700719833374 ---MAE: 1.7415428
files: 14817 / 46896  time: 277.8181176185608 ---MAE: 1.7283587
files: 14871 / 46896  time: 253.55071759223938 ---MAE: 1.8151735
files: 14928 / 46896  time: 234.68469762802124 ---MAE: 1.7505044
files: 14981 / 46896  time: 247.5030961036682 ---MAE: 1.7008543
files: 15039 / 46896  time: 276.4054

files: 21250 / 46896  time: 411.62691950798035 ---MAE: 1.6648213
files: 21302 / 46896  time: 385.1593589782715 ---MAE: 1.6401383
files: 21358 / 46896  time: 334.3768661022186 ---MAE: 1.7745464
files: 21410 / 46896  time: 391.0223813056946 ---MAE: 1.668668
files: 21465 / 46896  time: 336.9799540042877 ---MAE: 1.6701845
files: 21516 / 46896  time: 400.0452935695648 ---MAE: 1.628667
files: 21570 / 46896  time: 423.0308895111084 ---MAE: 1.5500427
files: 21628 / 46896  time: 410.0967667102814 ---MAE: 1.6236835
files: 21680 / 46896  time: 463.47481870651245 ---MAE: 1.7820687
files: 21736 / 46896  time: 399.7903206348419 ---MAE: 1.668521
files: 21790 / 46896  time: 457.8240249156952 ---MAE: 1.6612613
files: 21844 / 46896  time: 356.8891215324402 ---MAE: 1.6851323
files: 21896 / 46896  time: 329.8011622428894 ---MAE: 1.626815
files: 21950 / 46896  time: 414.76183247566223 ---MAE: 1.6081977
files: 22001 / 46896  time: 421.730845451355 ---MAE: 1.6518296
files: 22053 / 46896  time: 336.9631981849

files: 28214 / 46896  time: 335.72029733657837 ---MAE: 1.8239471
files: 28269 / 46896  time: 365.6333601474762 ---MAE: 1.7925855
files: 28322 / 46896  time: 277.8701128959656 ---MAE: 1.7229227
files: 28374 / 46896  time: 342.5592749118805 ---MAE: 1.726214
files: 28424 / 46896  time: 368.2768144607544 ---MAE: 1.71306
files: 28483 / 46896  time: 344.8782870769501 ---MAE: 1.8202924
files: 28536 / 46896  time: 310.34906220436096 ---MAE: 1.8154403
files: 28592 / 46896  time: 349.48417472839355 ---MAE: 1.7963866
files: 28643 / 46896  time: 425.8077049255371 ---MAE: 1.7729418
files: 28698 / 46896  time: 328.33212065696716 ---MAE: 1.6801058
files: 28751 / 46896  time: 395.5644793510437 ---MAE: 1.9041433
files: 28801 / 46896  time: 360.9061117172241 ---MAE: 1.8150557
files: 28858 / 46896  time: 405.2363750934601 ---MAE: 1.7645504
files: 28912 / 46896  time: 278.9703760147095 ---MAE: 1.7784926
files: 28965 / 46896  time: 332.83838653564453 ---MAE: 1.791504
files: 29025 / 46896  time: 404.9627795

files: 35138 / 46896  time: 361.1603696346283 ---MAE: 1.6800061
files: 35194 / 46896  time: 349.0521593093872 ---MAE: 1.7502346
files: 35247 / 46896  time: 409.4975996017456 ---MAE: 1.7147645
files: 35301 / 46896  time: 374.8398141860962 ---MAE: 1.7156345
files: 35354 / 46896  time: 362.1363756656647 ---MAE: 1.6818904
files: 35406 / 46896  time: 361.1615250110626 ---MAE: 1.6475109
files: 35461 / 46896  time: 373.96128821372986 ---MAE: 1.652099
files: 35518 / 46896  time: 380.93480801582336 ---MAE: 1.6766788
files: 35570 / 46896  time: 369.13081979751587 ---MAE: 1.652909
files: 35624 / 46896  time: 353.0544877052307 ---MAE: 1.630183
files: 35679 / 46896  time: 383.7532968521118 ---MAE: 1.6735445
files: 35736 / 46896  time: 385.1389467716217 ---MAE: 1.7215109
files: 35791 / 46896  time: 449.493910074234 ---MAE: 1.6579003
files: 35847 / 46896  time: 389.37977933883667 ---MAE: 1.7195994
files: 35899 / 46896  time: 364.79582691192627 ---MAE: 1.6705165
files: 35951 / 46896  time: 383.5622804

files: 42119 / 46896  time: 343.10996866226196 ---MAE: 1.5929091
files: 42173 / 46896  time: 343.7021162509918 ---MAE: 1.5798186
files: 42227 / 46896  time: 350.8728606700897 ---MAE: 1.6354455
files: 42288 / 46896  time: 365.8657445907593 ---MAE: 1.5550997
files: 42342 / 46896  time: 337.2127892971039 ---MAE: 1.5975678
files: 42400 / 46896  time: 362.83055567741394 ---MAE: 1.5294796
files: 42455 / 46896  time: 401.48314595222473 ---MAE: 1.6262395
files: 42510 / 46896  time: 333.9124209880829 ---MAE: 1.5459197
files: 42564 / 46896  time: 407.514502286911 ---MAE: 1.6195831
files: 42618 / 46896  time: 347.9115369319916 ---MAE: 1.5857875
files: 42674 / 46896  time: 375.8001506328583 ---MAE: 1.694697
files: 42728 / 46896  time: 345.80916237831116 ---MAE: 1.6117082
files: 42781 / 46896  time: 354.73585653305054 ---MAE: 1.5622855
files: 42833 / 46896  time: 439.4876925945282 ---MAE: 1.5808846
files: 42887 / 46896  time: 358.9317090511322 ---MAE: 1.6235396
files: 42942 / 46896  time: 375.99127