# Training a DeeProtGO model

This notebook trains DeeProtGO to predict Biological Process (BP) Gene Ontology (GO) terms for *NK* proteins from Eukarya organisms. 

## Requirements

In [1]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
#from google.colab import drive


In [2]:
# Google Colab only: uncomment this line and the `google.colab` import to mount Google Drive.
# drive.mount('/content/drive')

In [2]:
# Move to the DeeProtGO repository root so that the `src2` package and the
# relative `data/` and `examples/` paths used in this notebook resolve.
# Set the DEEPROTGO_DIR environment variable to point at your own checkout;
# when it is unset, the original author's location is used.
project_dir = os.environ.get(
    'DEEPROTGO_DIR',
    '/home/gabriela/Insync/gmerino@sinc.unl.edu.ar/Google Drive/EBI-EMBL/GOAnnot/DeeProtGO/')
os.chdir(project_dir)

In [3]:
from src2.sampler import Sampler
from src2.dataloader import Dataloader
from src2.DNNModel import DNNModel
from src2.logger import Logger
from src2.DNN import DNN
from src2.earlyStop import EarlyStopping


### 1. Setting global parameters

In [4]:
# Re-applies the thread count PyTorch already reports; change the argument here
# to cap or raise the number of CPU threads.
torch.set_num_threads(torch.get_num_threads())

# Run on the GPU only when it is both requested and actually available.
# `and` (rather than bitwise `&`) short-circuits, so CUDA is not queried when
# the GPU has been switched off by hand.
use_GPU = True
use_GPU = use_GPU and torch.cuda.is_available()
if use_GPU:
    torch.cuda.manual_seed_all(1)

# Fixed seeds for NumPy and PyTorch.
np.random.seed(1)
torch.manual_seed(1)


<torch._C.Generator at 0x7fe097655f60>

In [7]:
# Directory for saving model fitting results
trainingDate = '26032021'  # tag appended to the train/valid/test log names
res_dir = "examples/train_NK_EUKA_BP/"
# makedirs creates missing parent folders too (plain mkdir raises when
# `examples/` does not exist yet) and is a no-op when the folder is already there.
os.makedirs(res_dir, exist_ok=True)

### 2.  Setup parameters and model hyperparameters

In [8]:
# --- Cross-validation and optimisation settings ---
Nfolds = 3                   # number of folds; one Sampler is built per fold
nepoch, patience = 100, 10   # epoch limit per fold; patience handed to EarlyStopping
nbatch = 128                 # batch size handed to Sampler and DNNModel
learning_rate = 0.005
thresh = 0.2                 # forwarded to DNNModel as `thresh`
pDrop = 0.5                  # forwarded to DNNModel as `pDrop`

# --- Building blocks, forwarded to DNNModel as given (not instantiated here) ---
activFunc = F.elu
optimMethod = torch.optim.Adam
criterion = nn.BCELoss

# --- Per-input factors, forwarded to DNNModel as pN*_1 / pN*_2 ---
# NOTE(review): presumably layer-size proportions for each input branch —
# confirm against DNNModel.
pPSD1, pPSD2 = 0.5, 0.35
pEmb1, pEmb2 = 0.7, 0.5
pTaxon1, pTaxon2 = 0.7, 0.5

# --- Factors forwarded to DNNModel as pNO_1 / pNO_2 ---
pHidden1, pHidden2 = 1, 0.7

# Share of the negative entries to sample: 0.1399 keeps 13.99 % of all negatives.
samplingNegPerc = 0.1399


### 3. Data loading 

In [9]:
# Folder with the processed training files (relative to the working directory).
dirData = 'data/processed/Training/'
# File with the relationships between GO terms.
propOutFile = f"{dirData}GOTermsPropRel_Euka_BP_train.tab"

In [10]:
# Build the data loader from the Eukarya / BP files named below.
# NOTE(review): the file names are presumably resolved inside `dirData` — confirm in Dataloader.
dloaderAll = Dataloader(
    dirData=dirData,
    posEntriesFile="PosEntries_Euka_BP.tab",
    negEntriesFile="NegEntries_Euka_BP.tab",
    netOutPath="netOut_BP_Euka.h5",
    PSDPath="LevSim_BP_Euka.h5",
    EmbPath="Emb_BP_Euka.h5",
    TaxonPath="Taxon_BP_Euka.h5",
    samplingNegPerc=samplingNegPerc,
)


Loading data...
Dataset ready with 7180 cases .


In [11]:
# Labels returned by the data loader; the samplers partition these.
cases = dloaderAll.get_labels()
# Tag for this run; it is passed to EarlyStopping as the model name and
# prefixes the checkpoint file that is reloaded after training.
case = "PSD_Emb_Taxon"

### 4. Model training

In [12]:
# One sampler per cross-validation fold.
# NOTE(review): partsize is presumably the train / validation / test split — confirm in Sampler.
sampler_list = [
    Sampler(cases, fold, Nfolds, nbatch, partsize=[.70, .10, .20])
    for fold in range(Nfolds)
]

# Separate loggers for the training, validation and test records; each log name
# carries the training date.
loggerTrain = Logger(res_dir)
loggerTrain.start(f"train_{trainingDate}")
loggerValid = Logger(res_dir)
loggerValid.start(f"valid_{trainingDate}")
loggerTest = Logger(res_dir)
loggerTest.start(f"test_{trainingDate}")

# Header of the per-fold test table (one row is appended per fold later on).
loggerTest.log("fold\tF1CAFA\tpreCAFA\trecCAFA\tthresh\tF1\tpre\trec\n")

fold	F1CAFA	preCAFA	recCAFA	thresh	F1	pre	rec


In [13]:
# Cross-validation loop. For each fold: train with early stopping driven by the
# validation loss, reload the saved checkpoint, and log the test metrics.

# The input sizes depend only on the data loader, not on the fold, so they are
# looked up once instead of on every fold.
N_in_1 = len(dloaderAll.get_PSD(cases[0]))
N_in_2 = len(dloaderAll.get_Taxon(cases[0]))
N_in_3 = len(dloaderAll.get_Emb(cases[0]))
# NOTE(review): N_in_2 is the Taxon size and N_in_3 the Emb size, yet below
# pN2_* receives pEmb* and pN3_* receives pTaxon*. This has no effect while
# pEmb* == pTaxon* (0.7 / 0.5), but confirm the intended input order in DNNModel.

for fold in range(Nfolds):
    # A fresh early-stopping tracker per fold; every fold uses the same
    # res_dir / modelName pair.
    early_stopping = EarlyStopping(patience = patience, verbose = True, res_dir = res_dir,
                                   modelName = case)
    # Each fold starts a new table in the training and validation logs.
    loggerTrain.log("epoch\tloss\n")
    loggerValid.log("epoch\tloss\tF1\tpre\trec\n")

    # Batch indices for the three partitions of this fold.
    trainingSampler = sampler_list[fold].batch_ind("train")
    total_batch_train = len(trainingSampler)
    validSampler = sampler_list[fold].batch_ind("validation")
    total_batch_val = len(validSampler)
    testSampler = sampler_list[fold].batch_ind("test")

    # A new, untrained model for every fold.
    DeeProtGO_Euka_NK_BP = DNNModel(res_dir, propOutFile = propOutFile, nbatch = nbatch, N_in_1 = N_in_1,
                                    N_in_2 = N_in_2, N_in_3 = N_in_3, pN1_1 = pPSD1, pN1_2 = pPSD2,
                                    pN2_1 = pEmb1, pN2_2 = pEmb2, pN3_1 = pTaxon1, pN3_2 = pTaxon2,
                                    pNO_1 = pHidden1, pNO_2 = pHidden2, thresh = thresh, pDrop = pDrop,
                                    activFunc = activFunc, optimMethod = optimMethod,criterion = criterion,
                                    learning_rate = learning_rate, useGPU = use_GPU)

    for epoch in range(nepoch):
        # --- training: one pass over all training batches ---
        train_loss=0
        for it in range(total_batch_train):
            inData1, inData2, inData3, _, _, _, outData = dloaderAll.get_batch(trainingSampler[it])
            loss, _ = DeeProtGO_Euka_NK_BP.train(outData, inData1, inData2, inData3)
            train_loss += loss
        # Logged value is the mean of the per-batch losses.
        msgA="%d\t%.6f\n" %(epoch, train_loss/total_batch_train)
        loggerTrain.log(msgA)

        # --- validation: per-batch metrics, averaged with equal weight per batch ---
        valid_loss = 0
        valid_f1 = 0
        valid_pre = 0
        valid_rec = 0
        for it in range(total_batch_val):
            inData1,inData2,inData3,_,_,_,outData = dloaderAll.get_batch(validSampler[it])
            loss, f1, pre, rec = DeeProtGO_Euka_NK_BP.test(outData, inData1,inData2, inData3, 
                                              CAFAerror = False)
            valid_loss += loss
            valid_f1 += f1
            valid_pre += pre
            valid_rec += rec
        valid_loss = valid_loss/total_batch_val
        valid_pre = valid_pre/total_batch_val
        valid_rec = valid_rec/total_batch_val
        valid_f1 = valid_f1/total_batch_val
        msg="%d\t%.5f\t%.3f\t%.3f\t%.3f\n" %(epoch,valid_loss, valid_f1, valid_pre, valid_rec)
        loggerValid.log(msg)

        # The early-stopping tracker sees the validation loss and the network;
        # the loop ends as soon as it raises its `early_stop` flag.
        early_stopping(valid_loss, DeeProtGO_Euka_NK_BP.net)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    # Restore the checkpoint for this run (presumably the best-validation
    # weights saved by EarlyStopping — confirm in src2/earlyStop).
    DeeProtGO_Euka_NK_BP.net.load_state_dict(torch.load(res_dir+case+'checkpoint.pt'))

    # NOTE(review): only the first test batch is evaluated; verify that
    # batch_ind("test") returns the whole test partition as a single batch.
    inData1,inData2,inData3,_,_,_,outData = dloaderAll.get_batch(testSampler[0])
    _, fCafa, preCafa, recCafa, thresh_A, test_f1, test_pre, test_rec = DeeProtGO_Euka_NK_BP.test(outData, inData1,
                                                                                                  inData2, inData3,
                                                                                                  CAFAerror = True,
                                                                                                  propPrediction = True)
    # One row per fold, matching the header written to the test log.
    msg="%d\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\n" %(fold,fCafa, preCafa, recCafa, thresh_A,
                                                                test_f1, test_pre, test_rec)
    loggerTest.log(msg)



epoch	loss
epoch	loss	F1	pre	rec
0	2.165125
0	1.11681	0.006	0.003	0.526
Validation loss decreased (inf --> 1.116807).  Saving model ...
1	0.401084
1	0.19270	0.272	0.203	0.418
Validation loss decreased (1.116807 --> 0.192700).  Saving model ...
2	0.204410
2	0.10673	0.340	0.283	0.429
Validation loss decreased (0.192700 --> 0.106725).  Saving model ...
3	0.198276
3	0.11562	0.323	0.258	0.434
EarlyStopping counter: 1 out of 10.000000
4	0.131033
4	0.08051	0.369	0.328	0.423
Validation loss decreased (0.106725 --> 0.080513).  Saving model ...
5	0.089963
5	0.05742	0.353	0.291	0.449
Validation loss decreased (0.080513 --> 0.057424).  Saving model ...
6	0.076800
6	0.05491	0.354	0.297	0.440
Validation loss decreased (0.057424 --> 0.054908).  Saving model ...
7	0.061794
7	0.05016	0.380	0.340	0.432
Validation loss decreased (0.054908 --> 0.050165).  Saving model ...
8	0.063115
8	0.04809	0.378	0.333	0.437
Validation loss decreased (0.050165 --> 0.048093).  Saving model ...
9	0.055177
9	0.04434	0.393	

32	0.038895
32	0.04225	0.414	0.381	0.457
EarlyStopping counter: 1 out of 10.000000
33	0.038607
33	0.04236	0.403	0.358	0.466
EarlyStopping counter: 2 out of 10.000000
34	0.038315
34	0.04229	0.405	0.362	0.463
EarlyStopping counter: 3 out of 10.000000
35	0.038172
35	0.04173	0.413	0.376	0.462
Validation loss decreased (0.041980 --> 0.041732).  Saving model ...
36	0.038086
36	0.04191	0.403	0.357	0.467
EarlyStopping counter: 1 out of 10.000000
37	0.037847
37	0.04247	0.403	0.363	0.457
EarlyStopping counter: 2 out of 10.000000
38	0.037897
38	0.04175	0.399	0.348	0.472
EarlyStopping counter: 3 out of 10.000000
39	0.037661
39	0.04171	0.401	0.354	0.466
Validation loss decreased (0.041732 --> 0.041713).  Saving model ...
40	0.036907
40	0.04162	0.409	0.373	0.457
Validation loss decreased (0.041713 --> 0.041624).  Saving model ...
41	0.037025
41	0.04227	0.401	0.354	0.467
EarlyStopping counter: 1 out of 10.000000
42	0.037086
42	0.04144	0.395	0.342	0.472
Validation loss decreased (0.041624 --> 0.041439

### 5. Model fitting

The model will now be trained on all the training data, using the optimal number of epochs found during the cross-validation above.

In [14]:
# Set-up for the final fit: epoch counter, epoch budget, and a sampler that puts
# every case in the training partition.
epoch = 0
nepoch = 43  # NOTE(review): hand-set; presumably read off the cross-validation logs
cases = dloaderAll.get_labels()

# A single fold whose partition sizes are [1, 0, 0], i.e. everything goes to "train".
# Despite its name, `sampler_list` holds one Sampler here, not a list.
sampler_list = Sampler(cases, 0, 1, nbatch, partsize=[1, 0, 0])
trainingSampler = sampler_list.batch_ind("train")
total_batch_train = len(trainingSampler)

# Size of each input, measured on the first case.
N_in_1, N_in_2, N_in_3 = (
    len(dloaderAll.get_PSD(cases[0])),
    len(dloaderAll.get_Taxon(cases[0])),
    len(dloaderAll.get_Emb(cases[0])),
)

# Fresh model built with the same hyperparameters as in the cross-validation.
DeeProtGO_Euka_NK_BP = DNNModel(
    res_dir,
    propOutFile=propOutFile,
    nbatch=nbatch,
    N_in_1=N_in_1,
    N_in_2=N_in_2,
    N_in_3=N_in_3,
    pN1_1=pPSD1,
    pN1_2=pPSD2,
    pN2_1=pEmb1,
    pN2_2=pEmb2,
    pN3_1=pTaxon1,
    pN3_2=pTaxon2,
    pNO_1=pHidden1,
    pNO_2=pHidden2,
    thresh=thresh,
    pDrop=pDrop,
    activFunc=activFunc,
    optimMethod=optimMethod,
    criterion=criterion,
    learning_rate=learning_rate,
    useGPU=use_GPU,
)

In [15]:
# Final fit: every epoch runs once over all training batches and prints the mean
# per-batch loss. `epoch` and `nepoch` are set in the previous cell, so running
# this cell again after it has finished does nothing.
while epoch < nepoch:
    train_loss_A = 0
    for batch_pos in range(total_batch_train):
        batch_indices = trainingSampler[batch_pos]
        inData1, inData2, inData3, _, _, _, outData = dloaderAll.get_batch(batch_indices)
        batch_loss, _ = DeeProtGO_Euka_NK_BP.train(outData, inData1, inData2, inData3)
        train_loss_A += batch_loss
    train_loss_A = train_loss_A/total_batch_train
    # The message ends in "\n" and print adds another, hence the blank line
    # between epochs in the output.
    msgA = "%d\t%.6f\n" % (epoch, train_loss_A)
    print(msgA)
    epoch += 1

0	2.201572

1	0.887377

2	0.511639

3	0.442043

4	0.229409

5	0.115245

6	0.085406

7	0.063761

8	0.055834

9	0.051877

10	0.048662

11	0.046085

12	0.044583

13	0.044209

14	0.043724

15	0.043215

16	0.042643

17	0.042263

18	0.042148

19	0.041647

20	0.041475

21	0.041277

22	0.040742

23	0.040903

24	0.040548

25	0.040663

26	0.040289

27	0.039926

28	0.039969

29	0.039501

30	0.040026

31	0.039340

32	0.039014

33	0.038890

34	0.039034

35	0.039401

36	0.038788

37	0.038794

38	0.038560

39	0.038186

40	0.038308

41	0.037926

42	0.037697



In [16]:
# Save the final network's state_dict inside the results folder.
final_weights_path = res_dir + 'DeeProtGO_PSD_Emb_Taxon_Euka_BP_NK.pt'
torch.save(DeeProtGO_Euka_NK_BP.net.state_dict(), final_weights_path)
