# Protein classifier training

This notebook trains classifiers for applying ULMFit to protein sequences. The dataproc and ulmptrain notebooks should be run first to create datasets and train language models.

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

### Import libraries:

In [2]:
from fastai import *
from fastai.text import *

import pickle

sys.path.append("..")
import readseq
import ulmputils

# Directory holding the protein-sequence data: the 'pseq' subdirectory of Config().data_path().
seqDirPath = Config().data_path()/'pseq'
# Model directory is <seqDirPath>/models. The same path is bound to a notebook-level name
# and assigned onto the ulmputils module attribute modelDirPath
# (presumably read by the ulmputils save/load helpers - TODO confirm).
ulmputils.modelDirPath = modelDirPath = seqDirPath/'models'

### Create classifier data bunch:
The data bunch is built from the dataset CSV files generated by `makeClasDataset` in the dataproc notebook.

In [3]:
bs = 64
def makeClasDatabunch(lmFName,dsName) :
    """
    Build the classifier data bunch for dataset dsName of language model lmFName.

    Reads the CSV file <lmFName>_<dsName>Train.csv from seqDirPath, taking the
    text from its 'sequence' column and processing it with ulmputils.processor.
    The train/validation split comes from the 'is_valid' column and the labels
    from the 'label' column. Batches use the notebook-level batch size bs.
    """
    csvFName = lmFName + '_' + dsName + 'Train.csv'
    return (TextList.from_csv(seqDirPath, csvFName, cols='sequence',
                              processor=ulmputils.processor)
                    .split_from_df('is_valid')
                    .label_from_df('label')
                    .databunch(bs=bs))

In [4]:
# makeClasDatabunch('swissprotPE3AA20NoTM','atpBind')

### Set up hyperparameters for training classifiers:
These should match the hyperparameters used to train the language models in the ulmptrain notebook.

In [5]:
# Prefix prepended to the language-model name when tryClasTraining builds the
# encoder / classifier file names.
pref = 'awd_lstm'
# Classifier model config: a copy of awd_lstm_clas_config with the entries below
# set explicitly. Passed as config= to text_classifier_learner.
clasPars = dict(awd_lstm_clas_config,
    emb_sz = 400,
    n_hid = 1152,
    n_layers = 3,
    qrnn = False,
)
# Extra keyword arguments forwarded to text_classifier_learner (**archPars) and
# also handed to ulmputils.getModelFNameFromHyperPars when naming the model.
archPars = dict(
    drop_mult=0.2,
)

### Create and train classifier:

In [6]:
def tryClasTraining(lmFName, dsName, nEpochs, lr,
                    moms=(0.8,0.7), metrics=accuracy, seed=31415) :
    """
    Create a classifier for dataset dsName on top of the saved encoder of
    language model lmFName, and train it with one-cycle scheduling.

    Arguments:
        lmFName - language model name; used both to locate the dataset CSV
                  (via makeClasDatabunch) and to build the encoder file name.
        dsName  - dataset name; appended to the encoder name to form modelFName.
        nEpochs, lr, moms - passed straight to learn.fit_one_cycle.
        metrics - passed to text_classifier_learner.
        seed    - if not None, passed to ulmputils.randomSeedForTraining
                  before the learner is created.

    Returns nothing. Results are left in the notebook globals data, learn and
    modelFName, which later cells (printLearnInfo, saveNextModelVersion) read.
    """
    global data, learn, modelFName
    data = makeClasDatabunch(lmFName,dsName)
    if seed is not None :
        # Note the ordering: seeding happens after the data bunch has been built
        # and before the learner (and its weights) are created.
        ulmputils.randomSeedForTraining(seed)
    # pretrained=False: the encoder weights are loaded explicitly from our own
    # saved language model via load_encoder below.
    learn = text_classifier_learner(data, AWD_LSTM, config=clasPars,
                                    pretrained=False, metrics=metrics,
                                    **archPars)
    print('Loss function:',learn.loss_func)
    # File name of the language model this classifier builds on. The
    # (clasPars, awd_lstm_clas_config) pair presumably lets the helper encode
    # only values that differ from the defaults - TODO confirm in ulmputils.
    encoderFName = ulmputils.getModelFNameFromHyperPars(pref+'_'+lmFName,
                                              (clasPars,awd_lstm_clas_config),archPars)
    learn.load_encoder(encoderFName+'_enc')
    modelFName = encoderFName + '_' + dsName
    #learn.lr_find()
    #learn.recorder.plot()
    learn.fit_one_cycle(nEpochs, lr, moms=moms)

### Calculate precision, recall, etc:

In [7]:
def filterP1AndTarg(thresh,p1AndTarg,fn,reverse=True,returnPred=False) :
    """
    Select and rank the entries of p1AndTarg that are accepted by fn.

    p1AndTarg is a sequence of (target, prediction) pairs. The entry at index i
    is kept when fn(target, prediction, thresh) is truthy. Kept entries are
    ordered by prediction, descending unless reverse is False; ties keep their
    original relative order.

    Returns a list of indices into p1AndTarg, or a list of (index, prediction)
    pairs when returnPred is true.
    """
    selected = []
    for idx,(target,score) in enumerate(p1AndTarg) :
        if fn(target,score,thresh) :
            selected.append((idx,score))
    ranked = sorted(selected, key=lambda pair : pair[1], reverse=reverse)
    if returnPred :
        return ranked
    return [idx for idx,_ in ranked]
def truePos(thresh,p1AndTarg,reverse=True,returnPred=False) :
    """
    True positives: entries whose target is 1 and whose prediction is at least
    thresh. Ordering and return format are those of filterP1AndTarg.
    """
    def isTruePos(targ,pred,thresh) :
        return targ==1 and pred>=thresh
    return filterP1AndTarg(thresh, p1AndTarg, isTruePos,
                           reverse=reverse, returnPred=returnPred)
def falsePos(thresh,p1AndTarg,reverse=True,returnPred=False) :
    """
    False positives: entries whose target is 0 but whose prediction is at least
    thresh. Ordering and return format are those of filterP1AndTarg.
    """
    def isFalsePos(targ,pred,thresh) :
        return targ==0 and pred>=thresh
    return filterP1AndTarg(thresh, p1AndTarg, isFalsePos,
                           reverse=reverse, returnPred=returnPred)
def falseNeg(thresh,p1AndTarg,reverse=True,returnPred=False) :
    """
    False negatives: entries whose target is 1 but whose prediction is below
    thresh. Ordering and return format are those of filterP1AndTarg.
    """
    def isFalseNeg(targ,pred,thresh) :
        return targ==1 and pred<thresh
    return filterP1AndTarg(thresh, p1AndTarg, isFalseNeg,
                           reverse=reverse, returnPred=returnPred)
def printLearnInfo(learn,
                   thresholds=(0.05,0.1,0.15,0.2,0.25,0.3,0.4,0.5,0.6,0.7,0.8),
                   nTop=20) :
    """
    Print recall, precision and F1 on the learner's validation set at each
    threshold, then list the highest-scoring validation entries whose label
    is negative (candidate "new positives").

    Arguments:
        learn      - learner providing get_preds(ordered=True) and whose
                     data.valid_ds.inner_df has 'identifier' and 'name' columns.
        thresholds - decision thresholds applied to the class-1 score.
        nTop       - how many top-scoring negatives to list.

    A metric whose denominator is zero (e.g. precision when nothing is
    predicted positive at a threshold) is reported as 0.0 instead of raising
    ZeroDivisionError.
    """
    p = learn.get_preds(ordered=True)
    p1 = p[0][:,1]      # column 1 of the predictions, i.e. the score for class index 1
    # (target, score) pairs, in the order expected by truePos/falsePos/falseNeg
    p1AndTarg = list(zip(p[1].tolist(),p1.tolist()))
    validDF = learn.data.valid_ds.inner_df
    validIds, validNames = validDF.identifier.values, validDF.name.values
    for thresh in thresholds :
        nTruePos = len(truePos(thresh,p1AndTarg))
        nFalsePos = len(falsePos(thresh,p1AndTarg))
        nFalseNeg = len(falseNeg(thresh,p1AndTarg))
        nActualPos = nTruePos + nFalseNeg
        nPredPos = nTruePos + nFalsePos
        recall = nTruePos/nActualPos if nActualPos else 0.0
        precision = nTruePos/nPredPos if nPredPos else 0.0
        # Harmonic mean; only defined when both terms are non-zero.
        if recall > 0 and precision > 0 :
            f1 = 2/(1/recall+1/precision)
        else :
            f1 = 0.0
        print('{} - recall {:.3f} precision {:.3f} F1 {:.3f}'.format(thresh,
                        recall,precision,f1))
    print('Most likely new positives:')
    # Threshold 0.0 keeps every negative-labelled entry; they come back sorted
    # by descending score, so the slice takes the nTop most confident ones.
    for i,pred in falsePos(0.0,p1AndTarg,returnPred=True)[:nTop] :
        print(pred,validIds[i],validNames[i])

### Try predicting some properties:

In [11]:
tryClasTraining('swissprotPE3AA20NoTM', 'gtpBind', 5, 1e-2)

Loss function: FlattenedLoss of CrossEntropyLoss()


epoch,train_loss,valid_loss,accuracy,time
0,0.056342,0.019075,0.996797,20:23
1,0.045514,0.014813,0.996923,21:10
2,0.045548,0.023521,0.996123,21:21
3,0.02989,0.009145,0.998061,20:28
4,0.027956,0.012062,0.997809,21:49


In [12]:
printLearnInfo(learn)

0.05 - recall 0.977 precision 0.333 F1 0.496
0.1 - recall 0.973 precision 0.569 F1 0.718
0.15 - recall 0.964 precision 0.721 F1 0.825
0.2 - recall 0.962 precision 0.802 F1 0.874
0.25 - recall 0.957 precision 0.849 F1 0.900
0.3 - recall 0.952 precision 0.873 F1 0.911
0.4 - recall 0.950 precision 0.913 F1 0.931
0.5 - recall 0.948 precision 0.935 F1 0.942
0.6 - recall 0.946 precision 0.954 F1 0.950
0.7 - recall 0.937 precision 0.974 F1 0.955
0.8 - recall 0.928 precision 0.983 F1 0.955
Most likely new positives:
0.983332633972168 A5PJI7 Translation initiation factor eIF-2B subunit gamma
0.9307911396026611 P55243 Glucose-1-phosphate adenylyltransferase large subunit 3, chloroplastic/amyloplastic (Precursor)
0.9252060055732727 P76556 Ethanolamine utilization protein EutP
0.890681803226471 Q00188 Protein TraL
0.8312211036682129 P09396 Movement and silencing protein TGBp1
0.8227682113647461 A4WLK1 Nucleoside-triphosphatase THEP1 {ECO:0000255|HAMAP-Rule:MF_00796}
0.8119547367095947 O43543 DNA r

In [23]:
# Uses the globals learn and modelFName left behind by the most recent
# tryClasTraining call, so run this directly after the matching training cell.
ulmputils.saveNextModelVersion(learn,modelFName)

Saving awd_lstm_swissprotPE3AA20NoTM_drop_mult_0_2_gtpBind


In [8]:
tryClasTraining('swissprotPE3AA20NoTM', 'atpBind', 5, 1e-2)

Loss function: FlattenedLoss of CrossEntropyLoss()


epoch,train_loss,valid_loss,accuracy,time
0,0.1067,0.07876,0.978254,16:52
1,0.085162,0.072583,0.979096,14:51
2,0.081038,0.061042,0.982089,14:18
3,0.05639,0.052795,0.984154,15:39
4,0.06577,0.050603,0.985207,15:54


In [9]:
printLearnInfo(learn)

0.05 - recall 0.969 precision 0.628 F1 0.762
0.1 - recall 0.932 precision 0.901 F1 0.916
0.15 - recall 0.923 precision 0.946 F1 0.934
0.2 - recall 0.917 precision 0.963 F1 0.939
0.25 - recall 0.912 precision 0.971 F1 0.940
0.3 - recall 0.906 precision 0.978 F1 0.941
0.4 - recall 0.897 precision 0.984 F1 0.939
0.5 - recall 0.892 precision 0.988 F1 0.938
0.6 - recall 0.885 precision 0.991 F1 0.935
0.7 - recall 0.877 precision 0.993 F1 0.932
0.8 - recall 0.869 precision 0.995 F1 0.928
Most likely new positives:
0.9997912049293518 A5G9I1 10 kDa chaperonin {ECO:0000255|HAMAP-Rule:MF_00580}
0.9994750618934631 Q54BN3 Probable replication factor C subunit 3
0.9952515363693237 O94697 Replication factor C subunit 5
0.9842198491096497 P06619 Adenylate dimethylallyltransferase
0.9557788372039795 P14011 Adenylate dimethylallyltransferase
0.9280935525894165 Q6AY16 Actin-like protein 9
0.9275141954421997 O94625 DnaJ-related protein spj1
0.9077320694923401 P07165 Protein virC1
0.8848217129707336 Q9625

In [10]:
# Uses the globals learn and modelFName left behind by the most recent
# tryClasTraining call, so run this directly after the matching training cell.
ulmputils.saveNextModelVersion(learn,modelFName)

Saving awd_lstm_swissprotPE3AA20NoTM_drop_mult_0_2_atpBind


In [8]:
tryClasTraining('swissprotPE3AA20NoTM', 'metalBind', 10, 5e-3)

Loss function: FlattenedLoss of CrossEntropyLoss()


epoch,train_loss,valid_loss,accuracy,time
0,0.215057,0.162564,0.942852,16:54
1,0.208599,0.149615,0.949343,17:49
2,0.176557,0.136933,0.952714,19:39
3,0.187473,0.148534,0.955116,20:21
4,0.18099,0.143478,0.949343,20:13
5,0.150977,0.14414,0.954948,18:05
6,0.148632,0.124804,0.956802,19:11
7,0.148988,0.126536,0.958656,19:42
8,0.141524,0.125415,0.957898,18:21
9,0.148803,0.125076,0.957224,18:30


In [9]:
printLearnInfo(learn)

0.05 - recall 0.975 precision 0.280 F1 0.435
0.1 - recall 0.953 precision 0.336 F1 0.497
0.15 - recall 0.926 precision 0.397 F1 0.556
0.2 - recall 0.894 precision 0.456 F1 0.604
0.25 - recall 0.855 precision 0.521 F1 0.648
0.3 - recall 0.817 precision 0.589 F1 0.685
0.4 - recall 0.753 precision 0.720 F1 0.736
0.5 - recall 0.699 precision 0.818 F1 0.754
0.6 - recall 0.661 precision 0.891 F1 0.759
0.7 - recall 0.625 precision 0.925 F1 0.746
0.8 - recall 0.586 precision 0.958 F1 0.727
Most likely new positives:
0.9999992847442627 Q0I8S8 Hydroxyacylglutathione hydrolase {ECO:0000255|HAMAP-Rule:MF_01374}
0.9999314546585083 P15451 Cytochrome c
0.999647855758667 Q04UM1 50S ribosomal protein L31 {ECO:0000255|HAMAP-Rule:MF_00501}
0.9994650483131409 Q3ATS5 50S ribosomal protein L31 {ECO:0000255|HAMAP-Rule:MF_00501}
0.9983882904052734 B5YI09 1-deoxy-D-xylulose 5-phosphate reductoisomerase {ECO:0000255|HAMAP-Rule:MF_00183}
0.9981970191001892 Q7URM5 1-deoxy-D-xylulose 5-phosphate reductoisomerase {

In [12]:
# Uses the globals learn and modelFName left behind by the most recent
# tryClasTraining call, so run this directly after the matching training cell.
ulmputils.saveNextModelVersion(learn,modelFName)

Saving awd_lstm_swissprotPE3AA20NoTM_drop_mult_0_2_metalBind


# OLD STUFF:

In [6]:
# NOTE(review): tryTraining is not defined in the visible part of this notebook
# (only tryClasTraining is, with a different argument order and no return value),
# so this cell will not run on a fresh kernel. Presumably kept only for its
# recorded output - confirm before re-running.
data,learn = tryTraining('gtpBind','swissp40epoch_0_2dropmult_53acc_enc',5,1e-2,
                         metrics=accuracy)

Loss function: FlattenedLoss of CrossEntropyLoss()


epoch,train_loss,valid_loss,accuracy,time
0,0.059743,0.028896,0.996039,24:14
1,0.03987,0.009562,0.997946,22:39
2,0.038213,0.009159,0.998056,23:48
3,0.018562,0.010238,0.99802,21:51
4,0.02286,0.010147,0.997763,20:28


In [7]:
printLearnInfo(learn)

0.05 - recall 0.960 precision 0.470 F1 0.631
0.1 - recall 0.954 precision 0.688 F1 0.799
0.15 - recall 0.947 precision 0.800 F1 0.868
0.2 - recall 0.943 precision 0.844 F1 0.891
0.25 - recall 0.934 precision 0.881 F1 0.907
0.3 - recall 0.929 precision 0.901 F1 0.915
0.4 - recall 0.921 precision 0.931 F1 0.926
0.5 - recall 0.918 precision 0.945 F1 0.932
0.6 - recall 0.918 precision 0.961 F1 0.939
0.7 - recall 0.916 precision 0.974 F1 0.944
0.8 - recall 0.896 precision 0.988 F1 0.940
Most likely:
0.9895246624946594 Q7SXY4 Cytoplasmic dynein 2 light intermediate chain 1
0.9474827647209167 P76556 Ethanolamine utilization protein EutP
0.9297962784767151 P80361 Probable translation initiation factor eIF-2B subunit gamma
0.9004318714141846 Q9JU97 3,4-dihydroxy-2-butanone 4-phosphate synthase
0.8350386023521423 Q9BSD7 Cancer-related nucleoside-triphosphatase
0.7747216820716858 Q97X93 UPF0273 protein SSO1861 {ECO:0000255|HAMAP-Rule:MF_01076}
0.773944616317749 Q8Z5I4 Glucose-1-phosphate cytidyly

In [8]:
# NOTE(review): tryTraining is not defined in the visible part of this notebook
# (only tryClasTraining is, with a different argument order and no return value),
# so this cell will not run on a fresh kernel. Presumably kept only for its
# recorded output - confirm before re-running.
data,learn = tryTraining('atpBind','swissp40epoch_0_2dropmult_53acc_enc',5,1e-2,
                         metrics=accuracy)

Loss function: FlattenedLoss of CrossEntropyLoss()


epoch,train_loss,valid_loss,accuracy,time
0,0.094848,0.063217,0.982434,18:40
1,0.076341,0.063666,0.983717,16:28
2,0.06111,0.055671,0.984487,17:44
3,0.043388,0.049755,0.986028,17:42
4,0.045106,0.048325,0.986981,17:00


In [9]:
printLearnInfo(learn)

0.05 - recall 0.949 precision 0.832 F1 0.887
0.1 - recall 0.928 precision 0.936 F1 0.932
0.15 - recall 0.918 precision 0.966 F1 0.941
0.2 - recall 0.914 precision 0.974 F1 0.943
0.25 - recall 0.911 precision 0.979 F1 0.944
0.3 - recall 0.907 precision 0.985 F1 0.944
0.4 - recall 0.900 precision 0.987 F1 0.942
0.5 - recall 0.893 precision 0.989 F1 0.938
0.6 - recall 0.886 precision 0.991 F1 0.935
0.7 - recall 0.874 precision 0.993 F1 0.930
0.8 - recall 0.861 precision 0.994 F1 0.923
Most likely:
0.9999616146087646 Q9F4E4 10 kDa chaperonin {ECO:0000255|HAMAP-Rule:MF_00580}
0.9999520778656006 Q03SR0 10 kDa chaperonin {ECO:0000255|HAMAP-Rule:MF_00580}
0.9997187256813049 Q8RKI6 Glycerol-3-phosphate cytidylyltransferase
0.998019814491272 Q94490 Ubiquitin conjugating enzyme E2 B
0.9968795776367188 P05631 ATP synthase subunit gamma, mitochondrial {ECO:0000305} (Precursor)
0.9962904453277588 Q9BYD9 Actin-related protein T3
0.988463819026947 Q73I70 10 kDa chaperonin {ECO:0000255|HAMAP-Rule:MF_00

In [10]:
# NOTE(review): tryTraining is not defined in the visible part of this notebook
# (only tryClasTraining is, with a different argument order and no return value),
# so this cell will not run on a fresh kernel. Presumably kept only for its
# recorded output - confirm before re-running.
data,learn = tryTraining('metalBind','swissp40epoch_0_2dropmult_53acc_enc',5,1e-2,
                         metrics=accuracy)

Loss function: FlattenedLoss of CrossEntropyLoss()


epoch,train_loss,valid_loss,accuracy,time
0,0.258967,0.161738,0.946274,23:09
1,0.213522,0.146265,0.949171,22:21
2,0.182258,0.133434,0.955076,21:09
3,0.176618,0.142193,0.956359,20:34
4,0.167359,0.125699,0.959366,20:04


In [11]:
printLearnInfo(learn)

0.05 - recall 0.972 precision 0.237 F1 0.381
0.1 - recall 0.941 precision 0.304 F1 0.460
0.15 - recall 0.900 precision 0.393 F1 0.547
0.2 - recall 0.840 precision 0.487 F1 0.617
0.25 - recall 0.790 precision 0.584 F1 0.672
0.3 - recall 0.744 precision 0.675 F1 0.708
0.4 - recall 0.670 precision 0.803 F1 0.730
0.5 - recall 0.623 precision 0.874 F1 0.727
0.6 - recall 0.583 precision 0.911 F1 0.711
0.7 - recall 0.553 precision 0.942 F1 0.697
0.8 - recall 0.518 precision 0.955 F1 0.671
Most likely:
0.9998219609260559 Q9KDQ1 Non-canonical purine NTP pyrophosphatase homolog
0.9997192025184631 Q89KP9 1-deoxy-D-xylulose 5-phosphate reductoisomerase {ECO:0000255|HAMAP-Rule:MF_00183}
0.9997087121009827 O49203 Nucleoside diphosphate kinase III, chloroplastic/mitochondrial (Precursor)
0.9997019171714783 P84155 Oxygen-dependent coproporphyrinogen-III oxidase
0.9996678829193115 P83584 Ferredoxin
0.9991071820259094 Q969Y2 tRNA modification GTPase GTPBP3, mitochondrial (Precursor)
0.9989118576049805 Q

In [13]:
# learn.freeze_to(-2)
# learn.fit_one_cycle(3, slice(1e-2/(2.6**4),1e-2), moms=(0.8,0.7))

# learn.freeze_to(-3)
# learn.fit_one_cycle(3, slice(5e-3/(2.6**4),5e-3), moms=(0.8,0.7))

# interp = ClassificationInterpretation.from_learner(learn)
# interp.plot_confusion_matrix(figsize=(6,6), dpi=60)