## Basic Training
In this notebook, we will show

* How to construct training and validation datasets that respect External Symmetry. Separating the BC Clan graph into its disconnected components (clans) ensures fairness with respect to External Symmetry; the clans then form testing datasets ready for k-fold cross validation.
* How to train models in a 3-fold CV scheme. Training is done with PyTorch, taking advantage of its DataLoader.
* Several effective data augmentation strategies popularised in residual network training.

We illustrate this by training a classifier on the base/nonsite/phosphate/ribose (`S,X,P,R`) dataset.
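The clan idea can be sketched with `networkx`. This is a minimal illustration, not the `MakeBcClanGraph` implementation: it assumes each PDB entry is a node and an edge joins two entries whose protein chains share a BC cluster at the chosen percentage. The cluster assignments below are made-up placeholders. Each connected component is a clan, and keeping whole clans on one side of a split means no test entry has a sequence-similar partner in training.

```python
import itertools

import networkx as nx

# Hypothetical input: PDB ID -> set of BC cluster IDs of its protein chains.
# Real assignments would come from the PDB's BC clustering at the chosen percentage.
chain_clusters = {
    "1abc": {10, 11},
    "2def": {11},
    "3ghi": {12},
    "4jkl": {13, 12},
    "5mno": {14},
}


def build_clan_graph(chain_clusters):
    """Connect two PDB entries if they share at least one BC cluster."""
    graph = nx.Graph()
    graph.add_nodes_from(chain_clusters)
    for a, b in itertools.combinations(chain_clusters, 2):
        if chain_clusters[a] & chain_clusters[b]:
            graph.add_edge(a, b)
    return graph


graph = build_clan_graph(chain_clusters)
# Each connected component is a clan; clans, not single entries, are assigned to folds.
clans = [sorted(component) for component in nx.connected_components(graph)]
clans.sort(key=len, reverse=True)
print(clans)
```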


## Imports

In [1]:
# ============== Click Please.Imports
import sys
import glob
import gc
import io
import time
import collections
import functools
import itertools
import multiprocessing
import random

random.seed(42)

import numpy as np
import pandas as pd
import networkx as nx
import seaborn as sns
import matplotlib.pyplot as plt
import tqdm
from scipy import sparse

import torch
from torch import nn
import torchvision as tv
import pytorch_lightning as pl

sys.path.append('../')
from NucleicNet.DatasetBuilding.util import *
from NucleicNet.DatasetBuilding.commandReadPdbFtp import ReadBCExternalSymmetry, MakeBcClanGraph
from NucleicNet.DatasetBuilding.commandDataFetcher import FetchIndex, FetchTask, FetchDataset
from NucleicNet import Burn, Fuel
import NucleicNet.Burn.util
import NucleicNet.Burn.M1
import NucleicNet.Burn.DA
%config InlineBackend.figure_format = 'svg'

sns.set_context("notebook")

# Turn on the cuDNN backend and its autotuner
print(torch.backends.cudnn.is_available())
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Disable debugging aids. NOTE Use only after debugging is finished.
torch.autograd.set_detect_anomaly(False)
# NOTE torch.autograd.profiler.profile(False) and emit_nvtx(False) have no effect unless
# used as context managers (`with ...:`); both profilers are already off by default.
# Gradient tracking can be disabled inside `with torch.no_grad():` or
# `with torch.inference_mode():` blocks, e.g. at inference time.

# ================= Click Please. Directories ==================
DIR_DerivedData = "../Database-PDB/DerivedData/"
DIR_Typi = "../Database-PDB/typi/"
DIR_FuelInput = "../Database-PDB/feature/"


True



## Scope of Data

The cell below defines the scope of data used to train the SXPR classifier. We have updated the curation to the year 2021, but we cannot guarantee that the flags below will meet the community's needs beyond that without an update. Because this classification task does not depend on base specificity (only the presence of a base matters), many more entries can be included than in a base-specific task such as AUCG.

In [2]:

# TODO For SXPR, the PDB selection should be separated from, and be less stringent than, that for AUCG

Df_grand = pd.read_pickle(DIR_DerivedData + "/DataframeGrand.pkl")
Df_grand = Df_grand.loc[(Df_grand["ProNu"] == "prot-nuc") & (Df_grand['Resolution'] <= 3.5) # NOTE The resolution limit is relaxed from 3.0 to 3.5 Angstrom, as cryoEM structures at ~3.5 Angstrom are now commonly modelled in full atom
                                          & ~(Df_grand['PubmedID'].isnull()) # NOTE ~78 structures lack a PubMed ID. Some are recent entries not yet indexed by the PDB; most are unpublished structures, and some contain large missing loops.
                                          & (Df_grand['NucleicAcid'].isin(['rna']))
                                          & (pd.notnull(Df_grand['InternalSymmetryBC-95']))
                                          & (Df_grand["Year"] <= 2021)
                                          & (Df_grand["MeanChainLength_Nucleotide"] >= 4) & (Df_grand["SumChainLength_Peptide"] > 50) 
                                          # NOTE Exclude machineries that show no base preference, or only a disproportionately small number of sites with preference.
                                          & ~(Df_grand["Title"].str.contains('ribos|trna', regex=True, na=False)) 
                                          & ~(Df_grand["Header"].str.contains('ribos|trna', regex=True, na=False))
#                                          & ~(Df_grand["Title"].str.contains('ribos|riboz|transcript|polymerase|trna|pseudouridine|srp|signal recognition particle| ribonuclease|exosome|spliceosome', regex=True, na=False)) 
#                                          & ~(Df_grand["Header"].str.contains('ribos|riboz|transcript|polymerase|trna|pseudouridine|srp|signal recognition particle| ribonuclease|exosome|spliceosome', regex=True, na=False))
                                          & ~(Df_grand['NpidbClassification'].isin(["TRANSFERASE/RNA",'TRANSFERASE','RIBOSOME'])) # ~170 structures

                                          # NOTE Unpublished but with pubmedid?
                                          & ~(Df_grand['Pdbid'].isin(['3p6y', '2n8m', '3ahu', '3boy']))
                                          # NOTE Cases where the RNA base interacts only via a metal, an interfacial inhibitor, the tip of a hairpin or water, interacts only marginally, or is modified
                                          & ~(Df_grand['Pdbid'].isin([  
                                                                      
                                                                        '4oq8', '1a34','2bbv','1ddl','4oq9','4nia','4ang','2b2e',# NOTE Virus capsid with overlapping atoms(?)
                                                                        '2bny','2b2g','1aq4','1zse','2iz8','2bq5','2bs1','2izm','2izn',
                                                                        '2c51','2c50', '2izm', '2izn', '6msf', '7msf', '1bmv','1zdh','1zdi', '1zdj', '1zdk','1aq3',# NOTE Assembled virus capsids can sometimes miss half of the binding sites (e.g. 2c51 has 3 equivalent protein chains, but due to overlaps in assembly only one RNA chain appears in some copies) 
                                                                        '4oav', # NOTE Modified nucleotide backbone 
                                                                         '7njc', # NOTE Modified nucleotide backbone
                                                                        '4eya', '4ghl','4ghl', '6o5f', '6sx2',
                                                                        '6sx0','6zlc', '2zko',# NOTE water envelope
                                                                        '1m8v', # NOTE very low clashscore... something wrong?


                                                                        
                                                                        '5zc9', # NOTE eIF4A1 chemical clamp, water
                                                                        '6xki', # NOTE eIF4A1 chemical clamp, water
                                                                        '6r7g', # NOTE only 6 units assembled at interface
                                                                        #'4bkk', # NOTE nucleoprotein. There is no mention of base interaction through out the article https://www.microbiologyresearch.org/content/journal/jgv/10.1099/vir.0.053025-0
                                                                        #'6yrb','6yrq', # NOTE No base interaction mentioned in paper (Check again)
                                                                        #'1yyw', '2nug', '2nue', '1yz9',# NOTE These are AU dsRNA, but Q157 prefers GU in other RNase III structures; 1yz9 makes no contact w/ base
                                                                        #'2bs0', # NOTE RNA at interface of two varial capsid protein symmetry mates
                                                                        #'7n0c','7n0b','7n0d', # NOTE exoribonuclease proof-reading complex but when mismatch the base makes no touch
                                                                        #'2xgj', # NOTE Helicase w/ no touch at base



                                                                        # NOTE Structures solved with a poly-oligonucleotide used only as a template
                                                                        #'5wwf', '5ho4', # NOTE These are proteins resolved with the same interacting sequence. Their siblings 5wwg, 5wwe and 5wwx make most of the contact with the protein
                                                                        '4ht8', '3gib', # NOTE 4ht9 has a higher resolution also with additional uridine sites shown
                                                                        #'4ijs', # NOTE They use a polyA sequence for simplicity. even though there are interaction with some of the bases.
                                                                        '2xbm', # NOTE specificity is in a dinucleotide labeled as G3A
                                                                        
                                                                        '5eeu', '5eev', '5eew', '5eex', '5eey', '5eez', 
                                                                        '5ef0', '5ef1', '5ef2', '5ef3', '1utd', '4v4f', 
                                                                        '1gtf', '1gtn', # NOTE While the protein is the same, RNA does not show up in a pseudo symmetry mate. Half and Half. also note a lot of unmodeled nt https://www.rcsb.org/3d-view/5EEV/1
                                                                        #'6dtd', #  NOTE Cas 13b
                                                                        '2zi0', '4erd', # NOTE single helix contact
                                                                        '6cf2', # NOTE single helix contact

                                                                        #'6mdz', # TODO Testing
                                                                        #'5js2', '5ki6', # NOTE Modified base argonaut
                                                                        #'6oon','5vm9','5w6v','4kre','4kxt','4olb','4ola', # NOTE Poly-A sequence bound to argonaut
                                                                        #'5t7b' # NOTE unpublished argonaut
                                                                        #'4z4c', '4z4d', '4z4e', '4z4f', '4z4g', '4z4h', '4z4i', # NOTE This series of pdbid concerns a water mediated recognition site for adenosine on argonaute `Water-mediated recognition of t1-adenosine anchors Argonaute2 to microRNA targets`
                                                                        #'5js1', '4w5o', '4w5q', # NOTE Argonaut structure. 4w5o,q has more missing residue than siblings 4w5t,r,n.
                                                                        #'5wqe', # NOTE Multiple base-specific interactions were outlined, but most interact with the peptide backbone.
                                                                        #'5wtk', # NOTE 4 base specific interactions were outlined but the structure is ds and some sidechains e.g. 415-416 were stubbed. we will not include it in training

                                                                        # NOTE No specific H bond contact found/does not fulfill Hbond criterion in pymol
                                                                        #'5ztm', # NOTE The claimed interaction at E172, N175, Q195 does not fulfill H-bond criterion in pymol. Find>Polar Contacts
                                                                        #'6h5s','6h5q', # NOTE no specific H bond  contact
                                                                        #'4al7','4al5','4al6', # NOTE base binding site at an unmodeled loop
                                                                        #'4n2s','4n2q','4me2', # NOTE close but no defined H bond 
                                                                        #'6hyu', '6hyt', # NOTE polyA used and no Hbond specific contact
                                                                        


                                                                        #'5t8y', 

                                                                        #'4z92', # NOTE minimal contact in vriys 
                                                                        #'3hsb', # NOTE An AGAGAG aptamer is used, but the G does not form specific H-bond interactions 
                                                                        #'7bg6','7bg7','7nuq','7nun','7nuo','7nul','7num', # NOTE only stack touched
                                                                        #'5f9f', '5f98','5f9h','5e3h','3eqt', # NOTE RIG-I recognise modified base m7G `https://www.pnas.org/doi/full/10.1073/pnas.1515152113`
                                                                        #'5z98','4lg2', # NOTE duplex
                                                                        '2ihx', # NOTE Disordered
                                                                        #'4gv3','4gv6','4gv9','4gve','4g9z', #NOTE backbone only
                                                                        #'7c06', # NOTE It shares the same sequence with 7c08 but is poorer?
                                                                        
                                                                        #'3ciy' # NOTE dsRNA
                                                                        #'5jbg', # NOTE MDA5
                                                                        '4ill', '4ilm', '4ilr', # NOTE The RNA strand appears broken??? (bonds too long)
                                                                        #'6s8b','6s8e','6shb','6sic','6s91','6s6b', # NOTE Backbone only. marginal interaction

                                                                        '4peh','4peg','4pei','4pef', # NOTE modified base

                                                                        #'5jaj','5jb2','5jbg', # NOTE LGP2 duplex
                                                                        #'4lg2', # NOTE duplex
                                                                        #'4gha','5m73', # NOTE dsrna

                                                                        '3ciy', # NOTE 3.41 angstrom resolution, some sidechain can be highly flexible
                                                                        
                                                                        

                                                                        #'3zd6','3zd7', # NOTE Rig I

                                                                        '3zc0', # NOTE almost no contact
                                                                        '2jlw', # NOTE no contact
                                                                        #'6ozp', '6ozn', '6ozf','6oze', '6ozg', '6ozh','6ozi', '6ozj', '6ozk', '6ozl', '6ozm',  '6ozo',  '6ozq', 
                                                                        #'6ozr','6ozs', # NOTE through backbone
                                                                        #'2gje', # NOTE backbone only
                                                                        #'1f8v', '2bbv', # NOTE backbone only duplex cage in virus capsid

                                                                        
                                                                        '2mxy', # NOTE Solution structure with an extra nucleotide compared to 2mz1
                                                                       

                                                                        '3pkm', # NOTE missing loop

                                                                        
                                                                        
                                                                        '2bx2', # NOTE Marginal

                                                                        #'6d06', # NOTE modified base dsrna
                                                                        '3dh3', # NOTE Modified base
                                                                        '7kfn', # NOTE Modified base
                                                                        '4i67', # NOTE Modified nt
                                                                        '1jbt','1jbs','1jbr',
                                                                        '6gc5', #NOTE short strand
                                                                        

                                                                        
                                                                        '5uj2', # NOTE marginal; same family as 4e78

                                                                        '7ndh', '7ndi', '7ndj', '7ndk','3d2s','3trz', # NOTE require zinc cage
                                                                        '6l1w', '1rgo', # NOTE zinc finger    
                                                                        '4lj0', '5elk',# NOTE Zn finger short peptide
                                                                        '2hgh','2li8','6wlh', # NOTE Zn finger short peptide

                                                                        '2mqv','2mqt','2ms0','2ms1','2mkn','5u9b','1wwe','1wwf','1wwd','1wwg','2n82','5u9b','1fje','1t4l',
                                                                        '2l3c','2lup','1a1t','2mf1','2mf0','1f6u','1ekz','6gbm','2mfe', 
                                                                        '2mfg', '2mfh','2mff', '4cio','2jpp',#NOTE Disordered NMR solution structures
                                                                        
                                                                        '5c0y','5v7c', # NOTE no contact
                                                                        #'5wea', # NOTE poly A sequence

                                                                        #'6vff', # NOTE dsrna
                                                                        #'7krn', '7kro','7krp', # NOTE Helicase dsrna
                                                                        '4pmi', # NOTE single helix
                                                                        '6yrb', # NOTE The nucleotide is detached?
                                                                        '2vpl', # NOTE require potassium coordination

                                                                        # NOTE Water-mediated or simply in an envelope of water
                                                                        '1wpu','2qux','1wmq','1wrq','4csf',
                                                                        '4qoz', '4tuw','4tux','4tv0','4l8r', # NOTE water duplex
                                                                        '4mdx', # NOTE water
                                                                        '5l2l', # NOTE water
                                                                        '5elh', # NOTE water; 5elk has much tighter contact 
                                                                        '2pjp', '6lt7','6db8','6db9','1c9s','6c6k', '3ts2','5tf6',
                                                                        '4n0t','4kzd','6b3k','5e08','5h1l', '1m5o', '6fq3',
                                                                        '5gxh','4q9q', '6mwn','5det','6u8d','6u8k', '5gxi','6hau','6d12',
                                                                        '2y8y','2y9h','2y8w','4qvc','4f02','6fql','6fq3', '4ht9', # NOTE water
                                                                        ]))
                                          # NOTE Recently indexed shape-dependent machinery (tRNA/exosome/ribosome), but pdb has not updated its derived data
                                          & ~(Df_grand['Pdbid'].isin(['5hr7','5omw','5jea',
                                                                      #'4o26', # NOTE telomerase
                                                                      #'5fmz','5epi', # NOTE polymerase
                                                                      '6zoj', '6zok', '6zol', # NOTE Ribosome
                                                                      '6yan','6yam','6yal', # NOTE ribosome
                                                                      '5iwa', # NOTE ribosome
                                                                      '5e6m', # NOTE trna
                                                                      '5on2','5onh','5on3','5omw','3al0', '3akz', '5e6m', # NOTE tRNA 
                                                                      '1zl3', # NOTE trna specificity at modified base FLO
                                                                      '5ud5','5v6x','4qei','4kqe', # NOTE trna
                                                                      '3jam','3jap','3jaq', # NOTE This is a ribosome
                                                                      #'5ng6', # NOTE Crispr machinery recognise DNA motif TTN but no mention of RNA
                                                                      #'6sh8','6s6b', '6s8b', '6s8e', '6s91', '6shb', '6sic', # NOTE Crispr machinery no mention of base interaction
                                                                    ]))


                                          # NOTE 
                                              ]
#print(pd.unique(Df_grand['NucleicAcid']))
print(Df_grand.shape)
# NOTE Further remarks on some interesting cases
# 3PTO, 3PTX, 3PU0, 3PU1, 3PU4  use the same nucleocapsid bound to poly(A), poly(U), poly(C) and poly(G) to test how it interacts with each kind of base, and propose UAG as an interesting motif to look for https://journals.asm.org/doi/10.1128/JVI.01927-10
#                               polyG shows the largest amount of interaction and polyU shows none. However, at 3.0 Angstrom, the assignment of N161 can be flipped to interact with U27 (this seems to be supported by K164)
# 6O1K, 6O1L, 6O1M              `Hfq thus has a structural preference for (ARN)n RNA stretches on its distal side, where N is any nucleotide. `


NmrStates = [ '1aud00000004','1aud00000010','1aud00000002',
              '2l4100000005','2l4100000011','2l4100000013',
              '2xc700000000','2xc700000002','2xc700000006',
              '1dz500000007','1dz500000008','1dz500000002',
              '1k1g00000001','1k1g00000005','1k1g00000007',
              '2ad900000017','2ad900000012','2ad900000019',
              '2adb00000004','2adb00000005','2adb00000014',
              '2adc00000007','2adc00000001','2adc00000000',
              '2c0600000002','2c0600000004','2c0600000009',
              '2cjk00000007','2cjk00000008','2cjk00000012',
              '2err00000003','2err00000016','2err00000006',
              '2fy100000008','2fy100000002','2fy100000000',
              '2kfy00000006','2kfy00000003','2kfy00000001',
              '2kg000000019','2kg000000012','2kg000000000',
              '2kg100000006','2kg100000005','2kg100000003',
              '2kh900000007','2kh900000001','2kh900000005',
              '2km800000004','2km800000007','2km800000006',
              '2kxn00000007','2kxn00000008','2kxn00000001',
              '2l2k00000006','2l2k00000002','2l2k00000007',
              '2l3j00000008','2l3j00000001','2l3j00000002',
              '2l5d00000004','2l5d00000016','2l5d00000008',
              '2lbs00000013','2lbs00000009','2lbs00000005',
              '2leb00000018','2leb00000000','2leb00000016',
              '2lec00000018','2lec00000002','2lec00000007',
              '2m8d00000013','2m8d00000003','2m8d00000010',
              '2mb000000004','2mb000000018','2mb000000001',
              '2mfc00000005','2mfc00000001','2mfc00000015',
              '2mfe00000001','2mfe00000002','2mfe00000013',
              '2mgz00000017','2mgz00000004','2mgz00000009',
              '2mjh00000019','2mjh00000006','2mjh00000009',
              '2mki00000005','2mki00000014','2mki00000002',
              '2mkk00000006','2mkk00000008','2mkk00000004',
              '2mz100000018','2mz100000004','2mz100000003',
              '2n7c00000002','2n7c00000010','2n7c00000007',
              '2n8l00000003','2n8l00000006','2n8l00000004',
              '2rra00000005','2rra00000008','2rra00000009',
              '2rs200000018','2rs200000004','2rs200000017',
              '2ru300000015','2ru300000011','2ru300000018',
              '4cio00000000','4cio00000006','4cio00000008',
              '5m8i00000008','5m8i00000014','5m8i00000006',
              '5mpg00000011','5mpg00000007','5mpg00000003',
              '5mpl00000004','5mpl00000012','5mpl00000002',
              '5n8l00000014','5n8l00000018','5n8l00000013',
              '5n8m00000015','5n8m00000004','5n8m00000002',
              '5x3z00000016','5x3z00000010','5x3z00000001',
              '6gbm00000002','6gbm00000000','6gbm00000011',
              '6hpj00000013','6hpj00000006','6hpj00000012',
              '6snj00000009','6snj00000000','6snj00000002',
              '6tph00000004','6tph00000009','6tph00000001',
              '7act00000009','7act00000008','7act00000000',
              '1t2r00000001','1t2r00000004','1t2r00000009', 
              '4bs200000006', '4bs200000007', '4bs200000009', 
              '2i2y00000001', '2i2y00000004', '2i2y00000014', 
              '2li800000004', '2li800000009', '2li800000015', 
              '2yh100000003', '2yh100000006', '2yh100000007', 
              '2hgh00000006', '2hgh00000010', '2hgh00000013', 
              '2ese00000001', '2ese00000004', '2ese00000011', 
              '6sdw00000002', '6sdw00000015', '6sdw00000016', 
              '2rqc00000002', '2rqc00000007', '2rqc00000018', 
              '1rkj00000001', '1rkj00000006', '1rkj00000009', 
              '6sdy00000006', '6sdy00000009', '6sdy00000017', 
              '6wlh00000007', '6wlh00000016', '6wlh00000017', 
              '4b8t00000005', '4b8t00000013', '4b8t00000015', 
              '7acs00000006', '7acs00000011', '7acs00000013'
 ]



(731, 34)
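Long exclusion lists like the ones above are easy to break silently: if a comma is missing between two string literals, Python concatenates them into a single string, so neither PDB ID is excluded and no error is raised. A small sanity check, sketched below under the assumption that every exclusion entry should be a 4-character alphanumeric PDB ID, catches this before filtering. The helper name `find_malformed_pdbids` and the example list are hypothetical.

```python
# A missing comma after '4kqe' makes Python join two literals into '4kqe3jam'.
excluded = ['5ud5', '5v6x', '4qei', '4kqe' '3jam', '3jap', '5ud5']


def find_malformed_pdbids(pdbids, expected_length=4):
    """Return entries that are not exactly `expected_length` alphanumeric characters."""
    return [p for p in pdbids if len(p) != expected_length or not p.isalnum()]


print(find_malformed_pdbids(excluded))

# Duplicates are harmless for `isin`, but listing them helps keep the lists tidy.
duplicates = sorted({p for p in excluded if excluded.count(p) > 1})
print(duplicates)
```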


## Training Options

The cell below defines 9 subfolds of roughly equal data size for each task. A 3-fold cross validation is then performed, with each cross fold containing 3 subfolds. In each training cycle, 2 subfolds are reserved for validation and 1 for testing; the remaining 6 are used for training. The main options are:

* Task. `User_Task = "SXPR"`.
* Number of folds, i.e. the subfolds described above. We recommend `n_CrossFold = 9`.
* Extent of external symmetry (BC percent) to consider when separating folds. We recommend `ClanGraphBcPercent = 90`, but 70 also seems feasible. (TODO Check)
* Hierarchy of class labels. We recommend a two level hierarchy `TaskNameLabelLogicDict = {"SXPR":LabelLogic_level0, "AUCG": LabelLogic_level1,}`, but a finer hierarchy `commandDataFetcher.OBSOLETE_TaskNameLabelLogicDict` is also provided if needed.
* Filter using Derived Data from PDB FTP. We recommend filtering as suggested in `Df_grand`.
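The subfold rotation described above can be sketched as follows. This is an illustration only: it assumes cross fold k holds subfolds 3k, 3k+1 and 3k+2, with the first two used for validation and the third for testing. The function name `cross_fold_splits` is hypothetical, and the actual assignment in NucleicNet may differ.

```python
def cross_fold_splits(n_subfold=9, n_crossfold=3):
    """Yield (train, validation, test) subfold indices for each cross-validation round.

    Assumes each cross fold is a contiguous block of subfolds; within the held-out
    block, all but the last subfold are used for validation and the last for testing.
    """
    assert n_subfold % n_crossfold == 0
    per_fold = n_subfold // n_crossfold
    for k in range(n_crossfold):
        held_out = list(range(k * per_fold, (k + 1) * per_fold))
        validation, test = held_out[:-1], held_out[-1:]
        train = [i for i in range(n_subfold) if i not in held_out]
        yield train, validation, test


for train, validation, test in cross_fold_splits():
    print(f"train={train} validation={validation} test={test}")
```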

Other options are machine-learning hyperparameters that can be tuned in combination if desired; see the comments in the code for details. Hyperparameters worth mentioning:
* Noise in the input/hidden layers.
* Ghost Batch Normalisation (GBN). As the dataset grows, we can no longer afford training with small batches (typically 128 or fewer data points). GBN, popularised in recent years, is a remedy: a large minibatch is normalised in small virtual ("ghost") batches.
* Multi-step cosine scheduler (`SimpleMultistepCosineLRS`). This helps produce multiple models ready for a random-forest-style ensemble.
* Label smoothing by neighborhood.
* Label smoothing by class.
* Bottleneck implementation. This also allows width tuning, as in Wide ResNet.
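Ghost Batch Normalisation can be sketched in PyTorch as below. This is a generic illustration of the technique for 1D features, not NucleicNet's own module; the class name `GhostBatchNorm1d` is hypothetical, and the virtual batch size of 128 follows the note next to `Combination_SizeMinibatch`. In training mode each ghost batch is normalised with its own statistics, while the running statistics are shared for evaluation.

```python
import torch
from torch import nn


class GhostBatchNorm1d(nn.Module):
    """Batch norm computed over small 'ghost' chunks of a large minibatch.

    In training mode the minibatch is split into chunks of `virtual_batch_size`
    rows and each chunk is normalised with its own statistics (the shared running
    statistics are updated once per chunk). In eval mode this behaves like an
    ordinary BatchNorm1d using the running statistics.
    """

    def __init__(self, num_features, virtual_batch_size=128, momentum=0.1):
        super().__init__()
        self.virtual_batch_size = virtual_batch_size
        self.bn = nn.BatchNorm1d(num_features, momentum=momentum)

    def forward(self, x):
        if not self.training:
            return self.bn(x)
        chunks = x.split(self.virtual_batch_size, dim=0)
        return torch.cat([self.bn(chunk) for chunk in chunks], dim=0)


torch.manual_seed(0)
gbn = GhostBatchNorm1d(num_features=8, virtual_batch_size=128)
x = torch.randn(3072, 8) * 5 + 2  # a 3072-row minibatch split into 24 ghost batches
y = gbn(x)
# Each 128-row ghost batch is normalised to roughly zero mean and unit variance.
print(y.shape, y[:128].mean().item(), y[:128].std().item())
```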

Some further remarks:

* When we pack clans of different sizes into the cross folds, we do not aim at a [bin-packing solution](https://en.wikipedia.org/wiki/Bin_packing_problem); rather, we aim to distribute clans of different sizes evenly among folds. The process produces a dataframe, `TaskClanFoldDf_BC{bc percent}.pkl`, that indicates which PDB IDs are included in each fold.
* Because we cannot load all data into RAM, we make 6 passes from storage to RAM, each restricted to hold `User_DesiredBatchDatasize` data points (set to 70000000 in the cell below).
* Resampling is done within each minibatch.
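One simple way to spread clans of different sizes evenly over folds, without solving bin packing, is to sort them by size and deal them out in "snake" order (0, 1, ..., n-1, then n-1, ..., 0, and so on). This is an assumed heuristic for illustration and not necessarily the procedure that produces `TaskClanFoldDf_BC{bc percent}.pkl`; the function name `distribute_clans` and the clan names and sizes are made up.

```python
def distribute_clans(clan_sizes, n_fold):
    """Spread clans over folds in snake order after sorting by size.

    clan_sizes: dict mapping a clan ID to its size (e.g. number of data points).
    Returns a dict mapping fold index -> list of clan IDs. Every fold receives a
    similar mix of large and small clans, so fold sizes end up roughly balanced.
    """
    order = sorted(clan_sizes, key=lambda c: clan_sizes[c], reverse=True)
    folds = {k: [] for k in range(n_fold)}
    for i, clan in enumerate(order):
        round_index, position = divmod(i, n_fold)
        # Even rounds go left to right, odd rounds right to left.
        fold = position if round_index % 2 == 0 else n_fold - 1 - position
        folds[fold].append(clan)
    return folds


clan_sizes = {"A": 900, "B": 700, "C": 500, "D": 400, "E": 300, "F": 100}
folds = distribute_clans(clan_sizes, n_fold=3)
totals = {k: sum(clan_sizes[c] for c in v) for k, v in folds.items()}
print(folds)
print(totals)
```

A greedy "add to the currently smallest fold" rule would balance totals more tightly; the snake order is shown because it directly mixes clan sizes within each fold.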

In [3]:
n_CrossFold = 9
ClanGraphBcPercent = 90
User_featuretype = 'altman'

User_Task = "SXPR"
n_row_maxhold = 10000


# ================ Collapse. Click Please 

User_DesiredBatchDatasize    = 70000000 # NOTE Number of data points loaded into memory each time the dataset/dataloader is reloaded
User_SampleSizePerEpoch_Factor = 1.0 # NOTE Fraction of User_DesiredBatchDatasize sampled per epoch. If < 1.0, the sampler feeds fewer than User_DesiredBatchDatasize samples into one epoch

User_SampleSizePerEpoch = int(User_DesiredBatchDatasize * User_SampleSizePerEpoch_Factor)
n_datasetworker = 16
User_ExperiementName = 'SXPR-9CV'

DIR_TrainingRoot = "/home/homingla/Project-NucleicNet/Models/"
DIR_TrainLog = "/home/homingla/Project-NucleicNet/Models/" 
#DIR_Checkpoint = "/home/homingla/Project-NucleicNet/Models/AUCG_Resnet50Pretrained/lightning_logs/version_4/checkpoints/epoch=4-step=4689.ckpt"
pl.seed_everything(42)
Combination_SizeMinibatch = [3072]                  # NOTE We have used Ghost Batch Norm with virtual batch size 128
Combination_LabelSmoothing  = [0.12]                # NOTE Default 0.12 when User_NeighborLabelSmoothAngstrom > 0.0. else 0.36
Combination_PerformReduction = [False]              # NOTE Default False. True worsens performance.
Combination_Activation = ['gelu']                   # NOTE Default gelu 
Combination_n_ResnetBlock = [16]                    # NOTE Default 16; 96 is ok but needs lr tuning
Combination_lr = [1e-3 * 0.3]                       # NOTE Default 1e-3 / 2 in SXPR
Combination_min_lr = [1e-6]                        # NOTE Default 1e-6
Combination_CooldownInterval = [5000]               # NOTE Default 2000
Combination_AdamW_weight_decay = [0.01 * 3]        # NOTE Default model can tolerate 0.05 but not 0.1. In general 0.01-0.05 are satisfactory. Check Max Performance
Combination_Dropoutp = [0.7]                    # NOTE Default 0.7 model can tolerate 0.7
Combination_AddL1 = [0.000001]                      # NOTE Default 0; 0.0001 performs worse than 0.000001 
Combination_n_channelbottleneck = [40]          # NOTE Default 40, but 160 leads to simpler model as indicated by L1 of weights? Check
Combination_ShiftLrRatio = [0.01]                   # NOTE Unused
Combination_User_LrScheduler = ["SimpleMultistepCosineLRS_SXPR"]           # NOTE Default SimpleMultistepCosineLRS CosineAnnealingLR DescendingCosineAnnealingLR_HalfEpoch
Combination_User_BiasInSuffixFc = [True]            # NOTE Default True
Combination_User_NoiseX = [0.125 *8]                # NOTE Default 1.0 model can tolerate 1.0-1.5
Combination_User_NoiseY = [0.0]                     # NOTE Unused
Combination_User_Mixup = [False]                    # NOTE Unused. 
Combination_User_NumReductionComponent = [20]       # NOTE Default. Unused unless PerformReduction = True
Combination_User_NoiseZ = [0.125 *8]                # NOTE Default 1.5
Combination_User_NeighborLabelSmoothAngstrom = [0.0] # NOTE Default 0.0. 
Combination_User_InputDropoutp = [0.01]             # NOTE Default 0.1 finalise after tuning all hyperparameters
Combination_User_Loss = ["CrossEntropyLoss"]        # NOTE CrossEntropyLoss 

Combination_User_FocalLossAlpha = [0.25]            # NOTE Default 0.25 No effect if focal loss not used.
Combination_User_FocalLossGamma = [2.0]             # NOTE Default 2. Note gamma == 0 returns CE
Combination_User_GradientClippingValue = [1e5] # NOTE Clip the gradients' global norm to <= this value; larger networks may need a larger clip? Default 10000. TODO Test
combinations = [
                Combination_SizeMinibatch,
                Combination_LabelSmoothing,
                Combination_PerformReduction,
                Combination_Activation,

                Combination_n_ResnetBlock,
                Combination_lr,
                Combination_CooldownInterval,
                Combination_AdamW_weight_decay,
                Combination_min_lr,
                Combination_Dropoutp,
                Combination_AddL1,
                Combination_n_channelbottleneck,
                Combination_ShiftLrRatio,
                Combination_User_LrScheduler,
                Combination_User_BiasInSuffixFc,
                Combination_User_NoiseX,
                Combination_User_NoiseY,
                Combination_User_Mixup,
                Combination_User_NumReductionComponent,
                Combination_User_NoiseZ,
                Combination_User_NeighborLabelSmoothAngstrom,
                Combination_User_InputDropoutp,
                Combination_User_Loss,
                Combination_User_FocalLossAlpha,
                Combination_User_FocalLossGamma,
                Combination_User_GradientClippingValue,
                ]

# result contains all possible combinations.
CombinationList = list(itertools.product(*combinations))
print(CombinationList)


Global seed set to 42


[(3072, 0.12, False, 'gelu', 16, 0.0003, 5000, 0.03, 1e-06, 0.7, 1e-06, 40, 0.01, 'SimpleMultistepCosineLRS_SXPR', True, 1.0, 0.0, False, 20, 1.0, 0.0, 0.01, 'CrossEntropyLoss', 0.25, 2.0, 100000.0)]
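The cell above expands every `Combination_*` list into their Cartesian product with `itertools.product`. The training loop below then reads each combination by position (`cccc[0]` to `cccc[25]`), so inserting or reordering a single list silently shifts every index. The following is a minimal sketch of a keyed alternative, assuming the grid is rewritten as a dict. The names `search_space` and `expand_grid` and the shortened key set are illustrative and not part of NucleicNet.

```python
import itertools

# Hypothetical, shortened search space; keys mirror the User_* names used in the training loop.
search_space = {
    "User_SizeMinibatch": [3072],
    "User_LabelSmoothing": [0.12],
    "User_lr": [0.0003],
    "User_Dropoutp": [0.7],
    "User_NoiseX": [0.125 * 8, 0.125 * 12],
}

def expand_grid(space):
    """Return one dict per point of the Cartesian product of the value lists."""
    keys = list(space)
    return [dict(zip(keys, values)) for values in itertools.product(*space.values())]

combination_list = expand_grid(search_space)
for combo in combination_list:
    # Values are looked up by name instead of by tuple position
    print(combo["User_lr"], combo["User_NoiseX"])
```

Each combination is then read by name, e.g. `User_lr = combo["User_lr"]`. Only lists with more than one value enlarge the grid, so this example yields 1 × 1 × 1 × 1 × 2 = 2 runs.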


In [4]:


# ========================= Auto 


FetchTaskC = FetchTask(DIR_DerivedData = DIR_DerivedData,
                              DIR_Typi = DIR_Typi,
                              DIR_FuelInput = DIR_FuelInput,
                              Df_grand = Df_grand,
                              TaskNameLabelLogicDict = None,
                              n_row_maxhold = n_row_maxhold)

# =========================
# Get Definition of Tasks
# =========================
# NOTE This collects each task name and how to fetch its corresponding data from typi
TaskNameLabelLogicDict = FetchTaskC.Return_TaskNameLabelLogicDict()
#print(TaskNameLabelLogicDict)


print(FetchTaskC.TaskNameLabelLogicDict)

# =======================
# Task Clan Fold Dataframe
# ======================= 
# NOTE Each element is a 3-tuple (train, val, test)
CrossFoldDfList = FetchTaskC.Return_CrossFoldDfList(n_CrossFold = n_CrossFold, 
                                                      ClanGraphBcPercent = ClanGraphBcPercent, 
                                                      User_Task = User_Task,
                                                      Factor_ClampOnMaxSize = 450000,  # NOTE Constraint on the data size of a clan. Raised for SXPR, where each entry has far more data points.
                                                      Factor_ClampOnMultistate = 20,   # NOTE Constraint on the number of multistate files read
                                                      NmrStates = NmrStates
                                                      )


{'SXPR': {'Base': {'union': ['A', 'U', 'C', 'G'], 'exclu': [], 'intersect': []}, 'Nonsite': {'union': ['nonsite_'], 'exclu': ['F'], 'intersect': []}, 'P': {'union': ['P'], 'exclu': [], 'intersect': []}, 'R': {'union': ['R'], 'exclu': [], 'intersect': []}}, 'AUCG': {'A': {'union': ['A'], 'exclu': [], 'intersect': ['nucsite_']}, 'U': {'union': ['U'], 'exclu': [], 'intersect': ['nucsite_']}, 'C': {'union': ['C'], 'exclu': [], 'intersect': ['nucsite_']}, 'G': {'union': ['G'], 'exclu': [], 'intersect': ['nucsite_']}}}
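The printed `TaskNameLabelLogicDict` defines each class through `union`, `exclu` and `intersect` tag lists. This notebook does not show how `FetchTask` evaluates them. The sketch below therefore assumes one plausible reading: a grid point joins a class if it carries at least one `union` tag, none of the `exclu` tags and every `intersect` tag. The helper names are hypothetical. The last lines mirror how the training cell builds `ClassName_ClassIndex_Dict` from the sorted class names.

```python
# ASSUMPTION: these semantics are a plausible reading of the printed dict,
# not taken from the NucleicNet source.
def matches(tags, logic):
    """True if tags hit any 'union' tag, no 'exclu' tag and every 'intersect' tag."""
    has_union = any(t in tags for t in logic["union"])
    has_exclu = any(t in tags for t in logic["exclu"])
    has_all_intersect = all(t in tags for t in logic["intersect"])
    return has_union and not has_exclu and has_all_intersect

def assign_classes(tags, task_logic):
    """Sorted names of every class whose label logic the tag set satisfies."""
    return sorted(name for name, logic in task_logic.items() if matches(tags, logic))

# Copied from the printed TaskNameLabelLogicDict
sxpr_logic = {
    "Base": {"union": ["A", "U", "C", "G"], "exclu": [], "intersect": []},
    "Nonsite": {"union": ["nonsite_"], "exclu": ["F"], "intersect": []},
    "P": {"union": ["P"], "exclu": [], "intersect": []},
    "R": {"union": ["R"], "exclu": [], "intersect": []},
}
aucg_logic = {
    base: {"union": [base], "exclu": [], "intersect": ["nucsite_"]}
    for base in ["A", "U", "C", "G"]
}

print(assign_classes({"G"}, sxpr_logic))               # ['Base']
print(assign_classes({"nonsite_", "F"}, sxpr_logic))   # [] because 'F' is excluded
print(assign_classes({"A", "nucsite_"}, aucg_logic))   # ['A']
print(assign_classes({"A"}, aucg_logic))               # [] because 'nucsite_' is missing

# Same construction as ClassName_ClassIndex_Dict in the training cell
class_index = dict(zip(sorted(sxpr_logic), range(len(sxpr_logic))))
print(class_index)  # {'Base': 0, 'Nonsite': 1, 'P': 2, 'R': 3}
```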


In [5]:
for cccc in CombinationList:
  for User_SelectedCrossFoldIndex in [0,3,6]:

    print(cccc)
    # ==========================
    # Hyperparam 
    # ===============================
    PART0_InitialiseHyperparameters = True
    if PART0_InitialiseHyperparameters:

        User_SizeMinibatch = cccc[0] #256 
        User_LabelSmoothing = cccc[1] #0.16 
        User_PerformReduction = cccc[2] #True 
        User_Activation = cccc[3] #'gelu'
        User_n_ResnetBlock = cccc[4]#16 
        User_lr = cccc[5] #1e-3      
        n_Restart = 1  
        User_CooldownInterval = cccc[6] #951
        User_AdamW_weight_decay = cccc[7] #1e-2
        User_min_lr = cccc[8] #1e-6


        User_Dropoutp = cccc[9]
        User_AddL1 = cccc[10]
        User_n_channelbottleneck = cccc[11]
        User_ShiftLrRatio = cccc[12]


        # NOTE Currently fixed for benchmarking
        User_LrScheduler = cccc[13]   
        User_BiasInSuffixFc = cccc[14]
        User_NoiseX = cccc[15]
        User_NoiseY = cccc[16]
        User_Mixup = cccc[17] # NOTE Not used.
        User_NumReductionComponent = cccc[18]
        User_NoiseZ = cccc[19]
        User_NeighborLabelSmoothAngstrom = cccc[20]
        User_InputDropoutp = cccc[21]
        User_Loss = cccc[22]
        User_FocalLossAlpha = cccc[23]
        User_FocalLossGamma = cccc[24]
        User_GradientClippingValue = cccc[25]
        #print(User_GradientClippingValue)
        #sys.exit()

        FetchDatasetC = FetchDataset(
            DIR_DerivedData = DIR_DerivedData,
            DIR_Typi = DIR_Typi,
            DIR_FuelInput = DIR_FuelInput,
            User_DesiredDatasize    = User_DesiredBatchDatasize, # NOTE Controls how many new dataset-dataloader batches are reloaded into memory
            User_SampleSizePerEpoch_Factor = User_SampleSizePerEpoch_Factor, # NOTE Controls how many samples enter an epoch
            User_featuretype = User_featuretype,
            n_datasetworker = n_datasetworker,
            ClanGraphBcPercent = ClanGraphBcPercent)

        classindex_str = sorted(TaskNameLabelLogicDict[User_Task].keys()) 
        ClassName_ClassIndex_Dict = dict(zip(classindex_str, range(len(classindex_str))))

    # ============================
    # Get Cross-Folds and Batches
    # ============================
    print("Getting TrainValTest batches")
    PART1A_GetCrossFolds = True
    if PART1A_GetCrossFolds:

        # NOTE Each split is (list of PDB IDs, DataFrame of per-PDB-ID sampling weights)
        Train_PdbidBatches, TrainFold_PdbidSamplingWeight = CrossFoldDfList[User_SelectedCrossFoldIndex][0]
        Val_PdbidBatches, ValFold_PdbidSamplingWeight = CrossFoldDfList[User_SelectedCrossFoldIndex][1]
        Testing_PdbidBatches, TestingFold_PdbidSamplingWeight = CrossFoldDfList[User_SelectedCrossFoldIndex][2]

        # Sizes of train, val, test and of their union; the union equals their sum when every PDB ID appears in exactly one split
        print(len(Train_PdbidBatches), len(Val_PdbidBatches), len(Testing_PdbidBatches), len(set(Testing_PdbidBatches+Val_PdbidBatches+Train_PdbidBatches)))
        Train_PdbidWeight = dict(
                TrainFold_PdbidSamplingWeight[["Pdbid", "PdbidSamplingWeight"]].values.tolist()
                )
        Val_PdbidWeight = dict(
                ValFold_PdbidSamplingWeight[["Pdbid", "PdbidSamplingWeight"]].values.tolist()
                )
        Testing_PdbidWeight = dict(
                TestingFold_PdbidSamplingWeight[["Pdbid", "PdbidSamplingWeight"]].values.tolist()
                )

    if User_Task == "AUCG":
        User_datastride = 1
    else:
        User_datastride = 30 # NOTE Loading all the data would need >40GB of RAM. The stride only applies to nonsite, which is much larger than any other class; even so, ~22GB of RAM is used.


    PART1B_DatasetDataloader = True
    if PART1B_DatasetDataloader:
        # NOTE Train
        ds_train, ds_train_samplingweight = FetchDatasetC.GetDataset(
                        Assigned_PdbidBatch = Train_PdbidBatches,
                        Assigned_PdbidWeight = Train_PdbidWeight,
                        User_NumReductionComponent = User_NumReductionComponent,
                        ClassName_ClassIndex_Dict = ClassName_ClassIndex_Dict,
                        User_datastride = User_datastride,
                        User_Task = User_Task,
                        PerformZscoring = True, 
                        PerformReduction = User_PerformReduction,
                        User_NeighborLabelSmoothAngstrom = User_NeighborLabelSmoothAngstrom 
                        )
                        
        train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                        ds_train_samplingweight, User_SampleSizePerEpoch, replacement=True)
        train_loader  = torch.utils.data.DataLoader(ds_train, batch_size=User_SizeMinibatch, drop_last=True, num_workers=4, 
                                                            pin_memory=True,worker_init_fn=None, prefetch_factor=3, persistent_workers=False,
                                                            sampler = train_sampler)

        # NOTE Val
        ds_val, ds_val_samplingweight = FetchDatasetC.GetDataset(
                        Assigned_PdbidBatch = Val_PdbidBatches,
                        Assigned_PdbidWeight = Val_PdbidWeight,
                        User_NumReductionComponent = User_NumReductionComponent,
                        ClassName_ClassIndex_Dict = ClassName_ClassIndex_Dict,
                        User_Task = User_Task,
                        User_datastride = 10,                                    # NOTE Subsampled to fit in memory
                        PerformZscoring = True, 
                        PerformReduction = User_PerformReduction,
                        User_NeighborLabelSmoothAngstrom = User_NeighborLabelSmoothAngstrom 
                        )
        val_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            ds_val_samplingweight, int(User_SampleSizePerEpoch/100), replacement=True)
        val_loader          = torch.utils.data.DataLoader(ds_val, batch_size=len(ds_val) // 100, drop_last=False, num_workers=4, 
                                                            pin_memory=True,worker_init_fn=None, prefetch_factor=3, persistent_workers=False,
                                                            shuffle=False, sampler = val_sampler)  

        #NOTE Test
        """
        ds_testing, ds_testing_samplingweight = FetchDatasetC.GetDataset(
                        Assigned_PdbidBatch = Testing_PdbidBatches,
                        Assigned_PdbidWeight = Testing_PdbidWeight,
                        User_NumReductionComponent = User_NumReductionComponent,
                        ClassName_ClassIndex_Dict = ClassName_ClassIndex_Dict,
                        User_Task = User_Task,
                        PerformZscoring = True, 
                        PerformReduction = User_PerformReduction,
                        )
        
        testing_sampler = torch.utils.data.sampler.WeightedRandomSampler(
            ds_testing_samplingweight, int(User_SampleSizePerEpoch/100), replacement=True)
        testing_loader          = torch.utils.data.DataLoader(ds_testing, batch_size=int(ds_testing.__len__()/100), drop_last=False, 
                                                            num_workers=4, 
                                                            pin_memory=True,worker_init_fn=None, prefetch_factor=3, persistent_workers=False,
                                                            shuffle=False, sampler = testing_sampler) 
        """





    #sys.exit()

        
    # =====================
    # Define Model
    # ======================
    print("Training model constr")
    PART2_DefineModel = True
    if PART2_DefineModel:
        if User_PerformReduction:
            n_FeatPerShell = User_NumReductionComponent
            hw_product = n_FeatPerShell*6
        else:
            n_FeatPerShell = 80
            hw_product = 80*6

        model = NucleicNet.Burn.M1.B1hw_FcLogits(
                        model   = NucleicNet.Burn.M1.B1hw_LayerResnetBottleneck(n_FeatPerShell = n_FeatPerShell, 
                                                    n_Shell = 6,
                                                    n_ShellMix = 2,
                                                    User_Activation = User_Activation,
                                                    User_Block = "B1hw_BlockPreActResnet",
                                                    n_Blocks = User_n_ResnetBlock,
                                                    ManualInitiation = False,
                                                    User_n_channelbottleneck = User_n_channelbottleneck,
                                                    User_NoiseZ = User_NoiseZ,
                                                    ),

                        #loss    = customloss, 
                        User_Loss = User_Loss, 
                        n_class = 4,
                        hw_product = hw_product,
                        AddMultiLabelSoftMarginLoss = False, # TODO Does this worsen results? A one-vs-all loss is likely of no use.
                        User_lr = User_lr,
                        User_min_lr = User_min_lr,
                        User_LrScheduler = User_LrScheduler,
                        User_CooldownInterval = User_CooldownInterval,
                        BiasInSuffixFc = User_BiasInSuffixFc, 
                        # NOTE some kwargs for hparam record
                        User_SizeMinibatch = User_SizeMinibatch,
                        User_LabelSmoothing = User_LabelSmoothing,
                        User_PerformReduction = User_PerformReduction,
                        User_n_ResnetBlock = User_n_ResnetBlock,
                        User_AdamW_weight_decay = User_AdamW_weight_decay,
                        User_Activation = User_Activation,
                        User_SelectedCrossFoldIndex = User_SelectedCrossFoldIndex,
                        User_Dropoutp = User_Dropoutp,
                        User_AddL1 = User_AddL1,
                        User_n_channelbottleneck = User_n_channelbottleneck,
                        User_ShiftLrRatio = User_ShiftLrRatio,
                        User_NoiseX = User_NoiseX,
                        User_NoiseY = User_NoiseY,
                        #User_Mixup = User_Mixup,
                        User_NumReductionComponent = User_NumReductionComponent,
                        User_NoiseZ = User_NoiseZ,
                        User_PdbidTraining = Train_PdbidBatches,
                        User_PdbidValidation = Val_PdbidBatches,
                        User_PdbidTesting = Testing_PdbidBatches,
                        User_InputDropoutp = User_InputDropoutp,
                        User_FocalLossAlpha = User_FocalLossAlpha,
                        User_FocalLossGamma = User_FocalLossGamma,
                        User_n_CrossFold = n_CrossFold,
                        User_ClanGraphBcPercent = ClanGraphBcPercent,
                        User_Task = User_Task,
                        User_NeighborLabelSmoothAngstrom = User_NeighborLabelSmoothAngstrom,
                        User_GradientClippingValue = User_GradientClippingValue,
                        User_datastride = User_datastride,
                    )



        NucleicNet.Burn.util.ResetAllParameters(model)



    # ====================
    # Stage 0 training
    # ====================
    print("Training fit")
    trainer00 = NucleicNet.Burn.util.DefaultTrainer00(DIR_TrainLog = DIR_TrainLog, 
                                                        DIR_TrainingRoot = DIR_TrainingRoot, 
                                                        User_ExperiementName = User_ExperiementName,
                                                        User_SizeMinibatch = User_SizeMinibatch ,
                                                        User_ShiftLrRatio = User_ShiftLrRatio,
                                                        User_Mixup = User_Mixup,
                                                        User_GradientClippingValue = User_GradientClippingValue)
    trainer00.logger._log_graph = True 
    trainer00.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)



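    # Free the model, trainer and the large in-RAM datasets before the next fold
    del model, trainer00, train_loader, val_loader, ds_train, ds_val
    gc.collect()
    torch.cuda.empty_cache()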



(3072, 0.12, False, 'gelu', 16, 0.0003, 5000, 0.03, 1e-06, 0.7, 1e-06, 40, 0.01, 'SimpleMultistepCosineLRS_SXPR', True, 1.0, 0.0, False, 20, 1.0, 0.0, 0.01, 'CrossEntropyLoss', 0.25, 2.0, 100000.0)
Getting TrainValTest batches
604 196 93 893
Concating Dataset


100%|██████████| 604/604 [00:00<00:00, 2450.10it/s]


Finished Concat data. Cooling down
2316286 2316286
{0: 5.708879031009014, 1: 6.634975313565939, 2: 5.777437135364963, 3: 5.589715166041031}
Concating Dataset


100%|██████████| 196/196 [00:00<00:00, 1471.09it/s]


Finished Concat data. Cooling down
1504209 1504209
{0: 1.6749151586803217, 1: 6.626822002292714, 2: 1.4950520025591696, 3: 1.5732519099659006}
Training model constr


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Training fit



  | Name          | Type                       | Params
-------------------------------------------------------------
0 | nested_module | B1hw_LayerResnetBottleneck | 311 K 
1 | prefix_layerD | Sequential                 | 0     
2 | suffix_layerA | Sequential                 | 0     
3 | suffix_layerD | Sequential                 | 230 K 
4 | suffix_layerZ | Sequential                 | 1.9 K 
5 | loss          | CrossEntropyLoss           | 0     
-------------------------------------------------------------
543 K     Trainable params
0         Non-trainable params
543 K     Total params
2.175     Total estimated model params size (MB)



Global seed set to 42


Epoch 0:  44%|████▍     | 12111/27486 [2:51:44<3:38:01,  1.18it/s, loss=0.844, v_num=_137, train_loss_s=0.827, val_loss_s=1.360]
10000 20000 0.00029899999999999995
Epoch 2: 100%|██████████| 27486/27486 [6:28:32<00:00,  1.18it/s, loss=0.722, v_num=_137, train_loss_s=0.727, val_loss_s=1.460]


FIT Profiler Report

Action                             	|  Mean duration (s)	|Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------
Total                              	|  -              	|_              	|  6.9793e+04     	|  100 %          	|
--------------------------------------------------------------------------------------------------------------------------------------
run_training_epoch                 	|  2.3261e+04     	|3              	|  6.9783e+04     	|  99.985         	|
run_training_batch                 	|  0.87166        	|68358          	|  5.9585e+04     	|  85.373         	|
optimizer_step_with_closure_0      	|  0.86931        	|68358          	|  5.9424e+04     	|  85.143         	|
training_step_and_backward         	|  0.68047        	|68358          	|  4.6515e+04     	|  66.647         	|
backward                           
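In the training cell, `WeightedRandomSampler(ds_train_samplingweight, User_SampleSizePerEpoch, replacement=True)` draws `User_SampleSizePerEpoch` indices per epoch with replacement, each with probability proportional to its weight. `drop_last=True` then keeps only full minibatches of `User_SizeMinibatch`. The sketch below imitates that behaviour using only the standard library; it is not the PyTorch implementation. The weights and sample counts are made up for illustration.

```python
import random
from collections import Counter

def weighted_epoch_indices(weights, num_samples, seed=0):
    """Draw num_samples dataset indices with replacement, P(i) proportional to weights[i]."""
    rng = random.Random(seed)
    return rng.choices(range(len(weights)), weights=weights, k=num_samples)

def batches_per_epoch(num_samples, batch_size, drop_last=True):
    """Number of minibatches a DataLoader yields from a sampler of length num_samples."""
    full, rest = divmod(num_samples, batch_size)
    return full if drop_last or rest == 0 else full + 1

# Hypothetical toy dataset: index 0 is three times as likely to be drawn as 1 or 2
weights = [3.0, 1.0, 1.0]
indices = weighted_epoch_indices(weights, num_samples=10_000)
counts = Counter(indices)
print({i: round(counts[i] / len(indices), 2) for i in sorted(counts)})
print(batches_per_epoch(10_000, 3072))                   # 3 (drop_last=True, as in train_loader)
print(batches_per_epoch(10_000, 3072, drop_last=False))  # 4
```

Sampling is with replacement, so an "epoch" here is a fixed budget of draws rather than one pass over the training set. Its length is therefore set by `User_SampleSizePerEpoch`, not by the dataset size.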

## Epilogue

Remember to also train the `AUCG` task (set `User_Task = "AUCG"` and rerun the cells above) before moving on!