In [1]:
# Re-import modified modules (e.g. the local readseq / ulmputils helpers) before
# each cell executes, so edits take effect without restarting the kernel.
# (Comments must stay on their own lines: text after a line magic is passed to it.)
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [2]:
from fastai import *
from fastai.text import *

import pickle

sys.path.append("..")
import readseq
import ulmputils

# Directory holding the protein-sequence CSV files read by makeClasDatabunch
# (the fastai data path plus the 'pseq' subdirectory).
seqDirPath = Config().data_path()/'pseq'

In [3]:
# Batch size used by makeClasDatabunch when building the classifier DataBunch.
bs = 64

In [4]:
def makeClasDatabunch(fName,trSetName,batchSize=None) :
    """Build a text-classification DataBunch from one training CSV.

    Reads seqDirPath/(fName+trSetName+'Train.csv'), taking the text from the
    'sequence' column (processed with ulmputils.processor), splitting
    train/valid on the 'is_valid' column and labelling from the 'label' column.

    fName     : CSV filename prefix (e.g. 'swissprotPE3').
    trSetName : training-set tag inserted between the prefix and 'Train.csv'.
    batchSize : batch size for the DataBunch; None (the default) uses the
                notebook-level `bs` as it is at call time.

    Returns the DataBunch.
    """
    # Read the global when called (not as a default value) so that changing
    # `bs` in the config cell still takes effect without redefining this function.
    if batchSize is None :
        batchSize = bs
    textList = TextList.from_csv(seqDirPath, fName+trSetName+'Train.csv',
                                 cols='sequence', processor=ulmputils.processor)
    labelled = textList.split_from_df('is_valid').label_from_df('label')
    return labelled.databunch(bs=batchSize)

def tryTraining(trSetName, encoderFName, nEpochs, lr, fName='swissprotPE3',
                train_seed=31415, drop_mult=0.2, metrics=None, moms=(0.8,0.7),
                **kwargs) :
    """Build a classifier DataBunch, load a saved encoder into it and train.

    trSetName    : training-set tag; the CSV read is fName+trSetName+'Train.csv'.
    encoderFName : name of the saved encoder weights, passed to learn.load_encoder.
    nEpochs, lr  : number of epochs and learning rate for fit_one_cycle.
    fName        : CSV filename prefix.
    train_seed   : if not None, passed to ulmputils.randomSeedForTraining after
                   the DataBunch is built and before the learner is created.
    drop_mult    : dropout multiplier for the AWD_LSTM classifier.
    metrics      : metric or list of metrics for the learner; None means none.
    moms         : momentum pair passed to fit_one_cycle.
    **kwargs     : forwarded to text_classifier_learner.

    Returns (data, learn).
    """
    # A `[]` default would be one list object shared by every call (and so by
    # every learner built with the default); create a fresh list per call instead.
    if metrics is None :
        metrics = []
    data = makeClasDatabunch(fName,trSetName)
    if train_seed is not None :
        ulmputils.randomSeedForTraining(train_seed)
    learn = text_classifier_learner(data, AWD_LSTM, pretrained=False,
                                    drop_mult=drop_mult, metrics=metrics,
                                    **kwargs)
    print('Loss function:',learn.loss_func)
    learn.load_encoder(encoderFName)
    # To pick lr, run learn.lr_find() and learn.recorder.plot() here before fitting.
    learn.fit_one_cycle(nEpochs, lr, moms=moms)
    return data,learn

In [5]:
def filterP1AndTarg(thresh,p1AndTarg,fn,reverse=True,returnPred=False) :
    """Select items with a predicate and rank them by predicted probability.

    p1AndTarg  : sequence of (target, pred) pairs.
    fn         : predicate called as fn(target, pred, thresh); items for which
                 it is truthy are kept.
    reverse    : True (default) ranks by pred descending, False ascending.
    returnPred : if True return [(index, pred), ...], else [index, ...].
    """
    matches = []
    for idx, (target, prob) in enumerate(p1AndTarg) :
        if fn(target, prob, thresh) :
            matches.append((idx, prob))
    ranked = sorted(matches, key=lambda pair : pair[1], reverse=reverse)
    if returnPred :
        return ranked
    return [idx for idx, _ in ranked]
def truePos(thresh,p1AndTarg,reverse=True,returnPred=False) :
    """Items with target 1 and pred >= thresh, ranked by pred (see filterP1AndTarg)."""
    def isTruePos(target, prob, cutoff) :
        return target == 1 and prob >= cutoff
    return filterP1AndTarg(thresh, p1AndTarg, isTruePos,
                           reverse=reverse, returnPred=returnPred)
def falsePos(thresh,p1AndTarg,reverse=True,returnPred=False) :
    """Items with target 0 and pred >= thresh, ranked by pred (see filterP1AndTarg)."""
    def isFalsePos(target, prob, cutoff) :
        return target == 0 and prob >= cutoff
    return filterP1AndTarg(thresh, p1AndTarg, isFalsePos,
                           reverse=reverse, returnPred=returnPred)
def falseNeg(thresh,p1AndTarg,reverse=True,returnPred=False) :
    """Items with target 1 and pred < thresh, ranked by pred (see filterP1AndTarg)."""
    def isFalseNeg(target, prob, cutoff) :
        return target == 1 and prob < cutoff
    return filterP1AndTarg(thresh, p1AndTarg, isFalseNeg,
                           reverse=reverse, returnPred=returnPred)
def _safeRatio(num, den) :
    """Return num/den, or nan when den is 0 (metric undefined)."""
    return num/den if den else float('nan')

def printLearnInfo(learn,
                   threshs=(0.05,0.1,0.15,0.2,0.25,0.3,0.4,0.5,0.6,0.7,0.8),
                   nTop=20) :
    """Print recall/precision/F1 per threshold and the top-scoring false positives.

    Predictions come from learn.get_preds(ordered=True), so row i lines up with
    row i of learn.data.valid_ds.inner_df. Column 1 of the probabilities is
    used as the positive-class score and a target of 1 counts as positive.
    NOTE(review): presumably class 1 is the 'pos' label; confirm the class order.

    threshs : probability thresholds at which to report metrics.
    nTop    : number of highest-scoring false positives to list.

    Metrics whose denominator is zero at a threshold print as nan instead of
    raising ZeroDivisionError; F1 is 0 when there are no true positives.
    """
    preds, targs = learn.get_preds(ordered=True)
    p1AndTarg = list(zip(targs.tolist(), preds[:,1].tolist()))
    validDF = learn.data.valid_ds.inner_df
    validIds, validNames = validDF.identifier.values, validDF.name.values
    for thresh in threshs :
        nTruePos = len(truePos(thresh,p1AndTarg))
        nFalsePos = len(falsePos(thresh,p1AndTarg))
        nFalseNeg = len(falseNeg(thresh,p1AndTarg))
        recall = _safeRatio(nTruePos, nTruePos+nFalseNeg)
        precision = _safeRatio(nTruePos, nTruePos+nFalsePos)
        # Equivalent to the harmonic mean 2/(1/recall+1/precision) when
        # nTruePos > 0, but defined (0) when there are no true positives.
        f1 = _safeRatio(2*nTruePos, 2*nTruePos+nFalsePos+nFalseNeg)
        print('{} - recall {:.3f} precision {:.3f} F1 {:.3f}'.format(thresh,
                        recall,precision,f1))
    # Threshold 0.0 keeps every negative, so this lists the negatives the
    # model scores most highly (the most confident false positives).
    print('Most likely:')
    for i,pred in falsePos(0.0,p1AndTarg,returnPred=True)[:nTop] :
        print(pred,validIds[i],validNames[i])

In [6]:
# Fine-tune a GTP-binding classifier from the pretrained SwissProt encoder.
data, learn = tryTraining(trSetName='gtpBind',
                          encoderFName='swissp40epoch_0_2dropmult_53acc_enc',
                          nEpochs=5, lr=1e-2, metrics=accuracy)

Loss function: FlattenedLoss of CrossEntropyLoss()


epoch,train_loss,valid_loss,accuracy,time
0,0.059743,0.028896,0.996039,24:14
1,0.03987,0.009562,0.997946,22:39
2,0.038213,0.009159,0.998056,23:48
3,0.018562,0.010238,0.99802,21:51
4,0.02286,0.010147,0.997763,20:28


In [7]:
# Threshold sweep and top false positives for the gtpBind classifier.
printLearnInfo(learn=learn)

0.05 - recall 0.960 precision 0.470 F1 0.631
0.1 - recall 0.954 precision 0.688 F1 0.799
0.15 - recall 0.947 precision 0.800 F1 0.868
0.2 - recall 0.943 precision 0.844 F1 0.891
0.25 - recall 0.934 precision 0.881 F1 0.907
0.3 - recall 0.929 precision 0.901 F1 0.915
0.4 - recall 0.921 precision 0.931 F1 0.926
0.5 - recall 0.918 precision 0.945 F1 0.932
0.6 - recall 0.918 precision 0.961 F1 0.939
0.7 - recall 0.916 precision 0.974 F1 0.944
0.8 - recall 0.896 precision 0.988 F1 0.940
Most likely:
0.9895246624946594 Q7SXY4 Cytoplasmic dynein 2 light intermediate chain 1
0.9474827647209167 P76556 Ethanolamine utilization protein EutP
0.9297962784767151 P80361 Probable translation initiation factor eIF-2B subunit gamma
0.9004318714141846 Q9JU97 3,4-dihydroxy-2-butanone 4-phosphate synthase
0.8350386023521423 Q9BSD7 Cancer-related nucleoside-triphosphatase
0.7747216820716858 Q97X93 UPF0273 protein SSO1861 {ECO:0000255|HAMAP-Rule:MF_01076}
0.773944616317749 Q8Z5I4 Glucose-1-phosphate cytidyly

In [8]:
# Fine-tune an ATP-binding classifier from the same pretrained encoder.
data, learn = tryTraining(trSetName='atpBind',
                          encoderFName='swissp40epoch_0_2dropmult_53acc_enc',
                          nEpochs=5, lr=1e-2, metrics=accuracy)

Loss function: FlattenedLoss of CrossEntropyLoss()


epoch,train_loss,valid_loss,accuracy,time
0,0.094848,0.063217,0.982434,18:40
1,0.076341,0.063666,0.983717,16:28
2,0.06111,0.055671,0.984487,17:44
3,0.043388,0.049755,0.986028,17:42
4,0.045106,0.048325,0.986981,17:00


In [9]:
# Threshold sweep and top false positives for the atpBind classifier.
printLearnInfo(learn=learn)

0.05 - recall 0.949 precision 0.832 F1 0.887
0.1 - recall 0.928 precision 0.936 F1 0.932
0.15 - recall 0.918 precision 0.966 F1 0.941
0.2 - recall 0.914 precision 0.974 F1 0.943
0.25 - recall 0.911 precision 0.979 F1 0.944
0.3 - recall 0.907 precision 0.985 F1 0.944
0.4 - recall 0.900 precision 0.987 F1 0.942
0.5 - recall 0.893 precision 0.989 F1 0.938
0.6 - recall 0.886 precision 0.991 F1 0.935
0.7 - recall 0.874 precision 0.993 F1 0.930
0.8 - recall 0.861 precision 0.994 F1 0.923
Most likely:
0.9999616146087646 Q9F4E4 10 kDa chaperonin {ECO:0000255|HAMAP-Rule:MF_00580}
0.9999520778656006 Q03SR0 10 kDa chaperonin {ECO:0000255|HAMAP-Rule:MF_00580}
0.9997187256813049 Q8RKI6 Glycerol-3-phosphate cytidylyltransferase
0.998019814491272 Q94490 Ubiquitin conjugating enzyme E2 B
0.9968795776367188 P05631 ATP synthase subunit gamma, mitochondrial {ECO:0000305} (Precursor)
0.9962904453277588 Q9BYD9 Actin-related protein T3
0.988463819026947 Q73I70 10 kDa chaperonin {ECO:0000255|HAMAP-Rule:MF_00

In [10]:
# Fine-tune a metal-binding classifier from the same pretrained encoder.
data, learn = tryTraining(trSetName='metalBind',
                          encoderFName='swissp40epoch_0_2dropmult_53acc_enc',
                          nEpochs=5, lr=1e-2, metrics=accuracy)

Loss function: FlattenedLoss of CrossEntropyLoss()


epoch,train_loss,valid_loss,accuracy,time
0,0.258967,0.161738,0.946274,23:09
1,0.213522,0.146265,0.949171,22:21
2,0.182258,0.133434,0.955076,21:09
3,0.176618,0.142193,0.956359,20:34
4,0.167359,0.125699,0.959366,20:04


In [11]:
# Threshold sweep and top false positives for the metalBind classifier.
printLearnInfo(learn=learn)

0.05 - recall 0.972 precision 0.237 F1 0.381
0.1 - recall 0.941 precision 0.304 F1 0.460
0.15 - recall 0.900 precision 0.393 F1 0.547
0.2 - recall 0.840 precision 0.487 F1 0.617
0.25 - recall 0.790 precision 0.584 F1 0.672
0.3 - recall 0.744 precision 0.675 F1 0.708
0.4 - recall 0.670 precision 0.803 F1 0.730
0.5 - recall 0.623 precision 0.874 F1 0.727
0.6 - recall 0.583 precision 0.911 F1 0.711
0.7 - recall 0.553 precision 0.942 F1 0.697
0.8 - recall 0.518 precision 0.955 F1 0.671
Most likely:
0.9998219609260559 Q9KDQ1 Non-canonical purine NTP pyrophosphatase homolog
0.9997192025184631 Q89KP9 1-deoxy-D-xylulose 5-phosphate reductoisomerase {ECO:0000255|HAMAP-Rule:MF_00183}
0.9997087121009827 O49203 Nucleoside diphosphate kinase III, chloroplastic/mitochondrial (Precursor)
0.9997019171714783 P84155 Oxygen-dependent coproporphyrinogen-III oxidase
0.9996678829193115 P83584 Ferredoxin
0.9991071820259094 Q969Y2 tRNA modification GTPase GTPBP3, mitochondrial (Precursor)
0.9989118576049805 Q

In [12]:
# Display the learner's repr. It starts with the full DataBunch contents,
# which is very long for these datasets; consider clearing this output before sharing.
learn

RNNLearner(data=TextClasDataBunch;

Train: LabelList (375088 items)
x: TextList
xxbos M A R P L L G K T S S V R R R L E S L S A C S I F F F L R K F C Q K M A S L V F L N S P V Y Q M S N I L L T E R R Q V D R A M G G S D D D G V M V V A L S P S D F K T V L G S A L L A V E R D M V H V V P K Y L Q T P G I L H D M L V L L T P I F G E A L S V D M S G A T D V M V Q Q I A T A G F V D V D P L H S S V S W K D N V S C P V A L L A V S N A V R T M M G Q P C Q V T L I I D V G T Q N I L R D L V N L P V E M S G D L Q V M A Y T K D P L G K V P A V G V S V F D S G S V Q K G D A H S V G A P D G L V S F H T H P V S S A V E L N Y H A G W P S N V D M S S L L T M K N L M H V V V A E E G L W T M A R T L S M Q R L T K V L T D A E K D V M R A A A F N L F L P L N E L R V M G T K D S N N K S L K T Y F E V F E T F T I G A L M K H S G V T P T A F V D R R W L D N T I Y H M G F I P W G R D M R F V V E Y D L D G T N P F L N T V P T L M S V K R K A K I Q E M F D N M V S R M V T S,xxbos M I K L F C V L A A F I S I N S 

In [13]:
# Optional, currently disabled: further fine-tuning by gradual unfreezing
# (freeze_to(-2), then freeze_to(-3)) with per-layer-group learning-rate slices,
# followed by a confusion-matrix plot.
# NOTE(review): the 2.6**4 divisor looks like the ULMFiT discriminative-LR recipe; confirm.
# learn.freeze_to(-2)
# learn.fit_one_cycle(3, slice(1e-2/(2.6**4),1e-2), moms=(0.8,0.7))

# learn.freeze_to(-3)
# learn.fit_one_cycle(3, slice(5e-3/(2.6**4),5e-3), moms=(0.8,0.7))

# interp = ClassificationInterpretation.from_learner(learn)
# interp.plot_confusion_matrix(figsize=(6,6), dpi=60)