# Training a DeeProtGO model

This notebook trains DeeProtGO to predict Biological Process (BP) Gene Ontology (GO) terms for *NK* (no-knowledge) proteins from Eukarya organisms. 

## Requirements

In [None]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F


Download the GitHub repository and run its data-preparation script to set up the training data directory.

In [None]:
# Clone the DeeProtGO repository and run its data-preparation script for the training data.
# NOTE: `!` shell escapes do not stop the notebook when a command fails (e.g. when the
# clone directory already exists on a re-run), so check this cell's output.
!git clone https://github.com/gamerino/DeeProtGO.git
!./DeeProtGO/data/processed/Training/dataPreparation.sh

In [None]:
# Work from the repository root so the `src.*` imports and the relative data/result
# paths used below resolve. Guarded so that re-running this cell does not try to
# descend into DeeProtGO/DeeProtGO.
if os.path.basename(os.path.normpath(os.getcwd())) != 'DeeProtGO':
    os.chdir('DeeProtGO/')

In [None]:
from src.sampler import Sampler
from src.dataloader import Dataloader
from src.DNNModel import DNNModel
from src.logger import Logger
from src.DNN import DNN
from src.earlyStop import EarlyStopping


### 1. Setting global parameters

In [None]:
# Seed every RNG used here so sampling and weight initialisation are reproducible.
SEED = 1

# Keep PyTorch's current CPU thread count; replace with an int to pin it.
torch.set_num_threads(torch.get_num_threads())

# Request the GPU, but fall back to CPU when CUDA is not available.
use_GPU = True
use_GPU = use_GPU and torch.cuda.is_available()
if use_GPU:
    torch.cuda.manual_seed_all(SEED)

np.random.seed(SEED)
# Trailing ';' keeps the notebook from echoing the returned torch.Generator.
torch.manual_seed(SEED);


<torch._C.Generator at 0x7fe5a2be3f60>

In [None]:
# Directory for saving model fitting results
trainingDate = '13052021'  # change this date to avoid overwriting the current result files
res_dir = "examples/train_NK_EUKA_BP/"  # trailing '/' matters: later cells build paths as res_dir + name
# makedirs also creates missing parent directories and is a no-op when the
# directory already exists, so this cell can be re-run safely.
os.makedirs(res_dir, exist_ok=True)

### 2. Set up training parameters and model hyperparameters

In [None]:
# ---- Cross-validation / optimisation ----
Nfolds = 3          # number of cross-validation folds
nepoch = 100        # upper bound on epochs per fold (early stopping may end a fold sooner)
patience = 10       # patience handed to EarlyStopping
nbatch = 128        # mini-batch size handed to Sampler and DNNModel
learning_rate = 0.005
optimMethod = torch.optim.Adam   # optimiser class (not an instance) handed to DNNModel
criterion = nn.BCELoss           # loss class (not an instance) handed to DNNModel

# ---- Network ----
activFunc = F.elu
thresh = 0.2        # DNNModel `thresh` (presumably the score cut-off for F1/pre/rec — TODO confirm)
pDrop = 0.5         # DNNModel `pDrop` (presumably a dropout probability — TODO confirm)

# Per-branch factors handed to DNNModel as pN1_* (PSD), pN2_* (Emb), pN3_* (Taxon)
# and pNO_* (shared hidden layers); their meaning is defined in src/DNNModel.
pPSD1, pPSD2 = 0.5, 0.35
pEmb1, pEmb2 = 0.7, 0.5
pTaxon1, pTaxon2 = 0.7, 0.5
pHidden1, pHidden2 = 1, 0.7

# ---- Data sampling ----
# Fraction of all negative entries to keep (0.1399 -> about 14%).
samplingNegPerc = 0.1399


### 3. Data loading 

In [None]:
# Root of the prepared training data (relative to the repository root).
dirData = 'data/processed/Training/'
# File with the relationships between GO terms, handed to DNNModel as propOutFile.
propOutFile = os.path.join(dirData, "GOTermsPropRel_Euka_BP_train.tab")

In [None]:
# Build the data loader for Eukarya / BP: positive and negative entry lists,
# the target file, and the three input files later returned as inData1..inData3.
dloaderAll = Dataloader(
    dirData=dirData,
    posEntriesFile="PosEntries_Euka_BP.tab",
    negEntriesFile="NegEntries_Euka_BP.tab",
    netOutFile="netOut_BP_Euka.h5",
    inputData1="LevSim_BP_Euka.h5",
    inputData2="Emb_BP_Euka.h5",
    inputData3="Taxon_BP_Euka.h5",
    samplingNegPerc=samplingNegPerc,  # subsample of negatives configured in section 2
)


Loading data...
Dataset ready with 7180 cases .


In [None]:
# Labels of every case in the loaded dataset; these are what the Samplers partition.
cases = dloaderAll.get_labels()
# Model name handed to EarlyStopping and used in the checkpoint file name.
case = "PSD_Emb_Taxon"

### 4. Model training

In [None]:
# One Sampler per cross-validation fold; partsize=[.70, .10, .20] presumably gives
# the train / validation / test fractions returned by batch_ind().
sampler_list = [Sampler(cases, k, Nfolds, nbatch, partsize=[.70, .10, .20])
                for k in range(Nfolds)]

# Separate loggers for training, validation and test metrics, each tagged with trainingDate.
loggerTrain = Logger(res_dir)
loggerTrain.start(f"train_{trainingDate}")
loggerValid = Logger(res_dir)
loggerValid.start(f"valid_{trainingDate}")
loggerTest = Logger(res_dir)
loggerTest.start(f"test_{trainingDate}")
# Header row of the per-fold test results table.
loggerTest.log("fold\tF1CAFA\tpreCAFA\trecCAFA\tthresh\tF1\tpre\trec\n")

fold	F1CAFA	preCAFA	recCAFA	thresh	F1	pre	rec


In [None]:
# Input widths are read from the first case and do not depend on the fold,
# so they are computed once instead of once per fold.
N_in_1 = len(dloaderAll.get_inData1(cases[0]))
N_in_2 = len(dloaderAll.get_inData2(cases[0]))
N_in_3 = len(dloaderAll.get_inData3(cases[0]))

# 0-based epoch with the lowest mean validation loss in each fold. Adding one turns
# that index into an epoch count, which is what `nepoch` in section 5 should be based on.
best_epochs = []

for fold in range(Nfolds):
    # Fresh early stopper per fold; the best weights are reloaded below from
    # res_dir + case + 'checkpoint.pt' (the file name EarlyStopping is expected to write).
    early_stopping = EarlyStopping(patience = patience, verbose = True, res_dir = res_dir,
                                   modelName = case)
    loggerTrain.log( "epoch\tloss\n" )
    loggerValid.log( "epoch\tloss\tF1\tpre\trec\n" )
    trainingSampler = sampler_list[fold].batch_ind( "train" )
    total_batch_train = len(trainingSampler)
    validSampler = sampler_list[fold].batch_ind( "validation" )
    total_batch_val = len(validSampler)
    testSampler = sampler_list[fold].batch_ind("test")
    # A new, untrained model for every fold.
    DeeProtGO_Euka_NK_BP = DNNModel(res_dir, propOutFile = propOutFile, nbatch = nbatch, N_in_1 = N_in_1,
                                    N_in_2 = N_in_2, N_in_3 = N_in_3, pN1_1 = pPSD1, pN1_2 = pPSD2,
                                    pN2_1 = pEmb1, pN2_2 = pEmb2, pN3_1 = pTaxon1, pN3_2 = pTaxon2,
                                    pNO_1 = pHidden1, pNO_2 = pHidden2, thresh = thresh, pDrop = pDrop,
                                    activFunc = activFunc, optimMethod = optimMethod,criterion = criterion,
                                    learningRate = learning_rate, useGPU = use_GPU)
    best_valid_loss = float("inf")
    best_epoch = -1
    for epoch in range(nepoch):
        # Training: mean loss over all training batches.
        train_loss=0
        for it in range(total_batch_train):
            inData1, inData2, inData3, _, _, _, outData = dloaderAll.get_batch(trainingSampler[it])
            loss, _ = DeeProtGO_Euka_NK_BP.train(outData, inData1, inData2, inData3)
            train_loss += loss
        msgA="%d\t%.6f\n" %(epoch, train_loss/total_batch_train)
        loggerTrain.log(msgA)
        # Validation: per-batch loss/F1/precision/recall, averaged over batches.
        valid_loss = 0
        valid_f1 = 0
        valid_pre = 0
        valid_rec = 0
        for it in range(total_batch_val):
            inData1, inData2, inData3, _, _, _, outData = dloaderAll.get_batch(validSampler[it])
            loss, f1, pre, rec = DeeProtGO_Euka_NK_BP.test(outData, inData1,inData2, inData3,
                                              CAFAerror = False)
            valid_loss += loss
            valid_f1 += f1
            valid_pre += pre
            valid_rec += rec
        valid_loss = valid_loss/total_batch_val
        valid_pre = valid_pre/total_batch_val
        valid_rec = valid_rec/total_batch_val
        valid_f1 = valid_f1/total_batch_val
        msg="%d\t%.5f\t%.3f\t%.3f\t%.3f\n" %(epoch,valid_loss, valid_f1, valid_pre, valid_rec)
        loggerValid.log(msg)
        # Track the best epoch independently of EarlyStopping so it can be reported.
        if float(valid_loss) < best_valid_loss:
            best_valid_loss, best_epoch = float(valid_loss), epoch
        early_stopping(valid_loss, DeeProtGO_Euka_NK_BP.net)
        if early_stopping.early_stop:
            print("Early stopping")
            break
    best_epochs.append(best_epoch)
    # Restore the best checkpoint saved by early stopping before testing.
    DeeProtGO_Euka_NK_BP.net.load_state_dict(torch.load(res_dir + case + 'checkpoint.pt'))
    # NOTE(review): only the first test batch is evaluated. If batch_ind("test") returns
    # more than one batch, the logged test metrics cover only part of the test
    # partition — confirm against Sampler.batch_ind.
    inData1, inData2, inData3, _, _, _, outData = dloaderAll.get_batch(testSampler[0])
    _, fCafa, preCafa, recCafa, thresh_A, test_f1, test_pre, test_rec = DeeProtGO_Euka_NK_BP.test(outData, inData1,
                                                                                                  inData2, inData3,
                                                                                                  CAFAerror = True,
                                                                                                  propPrediction = True)
    msg="%d\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\n" %(fold, fCafa, preCafa, recCafa, thresh_A,
                                                                test_f1, test_pre, test_rec)
    loggerTest.log(msg)

print("Best epoch (0-based) per fold:", best_epochs)



epoch	loss
epoch	loss	F1	pre	rec
0	2.773918
0	14.31079	0.004	0.002	0.520
Validation loss decreased (inf --> 14.310792).  Saving model ...
1	0.724964
1	0.58338	0.148	0.090	0.430
Validation loss decreased (14.310792 --> 0.583381).  Saving model ...
2	0.434526
2	0.20428	0.340	0.284	0.426
Validation loss decreased (0.583381 --> 0.204284).  Saving model ...
3	0.316381
3	0.25314	0.209	0.145	0.387
EarlyStopping counter: 1 out of 10.000000
4	0.313401
4	0.20369	0.340	0.292	0.409
Validation loss decreased (0.204284 --> 0.203692).  Saving model ...
5	0.278024
5	0.22112	0.339	0.295	0.401
EarlyStopping counter: 1 out of 10.000000
6	0.223540
6	0.15606	0.208	0.138	0.433
Validation loss decreased (0.203692 --> 0.156063).  Saving model ...
7	0.158314
7	0.09642	0.332	0.284	0.405
Validation loss decreased (0.156063 --> 0.096424).  Saving model ...
8	0.145634
8	0.08480	0.367	0.327	0.421
Validation loss decreased (0.096424 --> 0.084798).  Saving model ...
9	0.104297
9	0.06918	0.379	0.339	0.434
Validation l

27	0.043796
27	0.04319	0.391	0.351	0.445
EarlyStopping counter: 1 out of 10.000000
28	0.043610
28	0.04282	0.398	0.361	0.448
Validation loss decreased (0.043152 --> 0.042824).  Saving model ...
29	0.043657
29	0.04334	0.393	0.346	0.458
EarlyStopping counter: 1 out of 10.000000
30	0.043714
30	0.04340	0.399	0.362	0.449
EarlyStopping counter: 2 out of 10.000000
31	0.043069
31	0.04267	0.399	0.356	0.456
Validation loss decreased (0.042824 --> 0.042671).  Saving model ...
32	0.043053
32	0.04329	0.392	0.351	0.448
EarlyStopping counter: 1 out of 10.000000
33	0.042456
33	0.04274	0.395	0.351	0.456
EarlyStopping counter: 2 out of 10.000000
34	0.042085
34	0.04209	0.396	0.356	0.451
Validation loss decreased (0.042671 --> 0.042092).  Saving model ...
35	0.041888
35	0.04229	0.396	0.350	0.461
EarlyStopping counter: 1 out of 10.000000
36	0.041793
36	0.04244	0.399	0.355	0.459
EarlyStopping counter: 2 out of 10.000000
37	0.041793
37	0.04276	0.401	0.363	0.453
EarlyStopping counter: 3 out of 10.000000
38	0.0

### 5. Model fitting

The model is now trained on all of the training data for the optimal number of epochs, as estimated from the cross-validation runs above. Set `nepoch` in the next cell to that value.

In [None]:
# Final fit: a single model trained on every case for a fixed number of epochs.
epoch = 0  # epoch counter consumed by the training loop in the next cell
# NOTE: this overwrites the `nepoch` used for cross-validation; re-run section 2
# before re-running the cross-validation cell.
nepoch = 51 # Please, consider to change to the optimum value
cases = dloaderAll.get_labels()

# Single sampler with partsize=[1, 0, 0] (presumably: every case in the training
# partition). It gets its own name so that `sampler_list`, the per-fold list indexed
# by the cross-validation cell, is not replaced by a single Sampler object.
fullSampler = Sampler(cases, 0, 1, nbatch, partsize=[1, 0, 0])
trainingSampler = fullSampler.batch_ind("train")
total_batch_train = len(trainingSampler)

# Input widths, read from the first case.
N_in_1 = len(dloaderAll.get_inData1(cases[0]))
N_in_2 = len(dloaderAll.get_inData2(cases[0]))
N_in_3 = len(dloaderAll.get_inData3(cases[0]))

# Fresh, untrained model with the same hyperparameters as in cross-validation.
DeeProtGO_Euka_NK_BP = DNNModel(res_dir, propOutFile = propOutFile, nbatch = nbatch, N_in_1 = N_in_1, 
                       N_in_2 = N_in_2, N_in_3 = N_in_3, pN1_1 = pPSD1, pN1_2 = pPSD2, 
                       pN2_1 = pEmb1, pN2_2 = pEmb2, pN3_1 = pTaxon1, pN3_2 = pTaxon2,
                       pNO_1 = pHidden1, pNO_2 = pHidden2, thresh = thresh, pDrop = pDrop, 
                       activFunc = activFunc, optimMethod = optimMethod,criterion = criterion, 
                       learningRate = learning_rate, useGPU = use_GPU)

In [None]:
# Train on all the data until `epoch` (set to 0 in the previous cell) reaches `nepoch`.
# The counter lives outside this cell, so re-running it after an interruption resumes
# from the unfinished epoch, and re-running it after completion trains nothing.
while epoch < nepoch:
    train_loss_A = 0
    for it in range(total_batch_train):
        inData1, inData2, inData3, _, _, _, outData = dloaderAll.get_batch(trainingSampler[it])
        lossA, _ = DeeProtGO_Euka_NK_BP.train(outData, inData1, inData2, inData3)
        train_loss_A += lossA
    # Mean training loss over the epoch's batches.
    train_loss_A = train_loss_A/total_batch_train
    msgA = "%d\t%.6f\n" %(epoch, train_loss_A)
    # msgA already ends with a newline (same format as the CV log); don't add a second one.
    print(msgA, end="")
    epoch += 1

0	2.130178

1	0.661186

2	0.394147

3	0.119690

4	0.077502

5	0.104366

6	0.067994

7	0.053280

8	0.048065

9	0.050607

10	0.047800

11	0.044651

12	0.043922

13	0.043873

14	0.043209

15	0.042112

16	0.041883

17	0.041389

18	0.041073

19	0.040713

20	0.040512

21	0.040605

22	0.040594

23	0.041091

24	0.040065

25	0.039833

26	0.039334

27	0.038924

28	0.038984

29	0.038734

30	0.038555

31	0.038458

32	0.038774

33	0.038633

34	0.038381

35	0.037881

36	0.037618

37	0.037451

38	0.037291

39	0.037312

40	0.037204

41	0.037232

42	0.036800

43	0.036736

44	0.036464

45	0.036272

46	0.036203

47	0.036054

48	0.035808

49	0.035582

50	0.035631



In [None]:
# Persist only the final network's weights (its state_dict) into the results directory.
final_model_path = os.path.join(res_dir, 'DeeProtGO_PSD_Emb_Taxon_Euka_BP_NK.pt')
torch.save(DeeProtGO_Euka_NK_BP.net.state_dict(), final_model_path)
