# POC - over fit the EMSCcomplex model on the toy data
Lets demonstrate training and inference using `ESMCcomplex` model. 

We will perform over-fitting of the model on the toy data and than inference.

**As a prior step, you must "vectorize" the data**, meaning, to encode the assay descriptions of each instance using BART. These encoded vectors will be the labels for the ESMC training.

The data is already splitted into `train`, `test`, `val` and `toy` in the csv files in `data/`. 

To vectorize it into pickle files use (from the main repo folder):
`$ python vectorize.py -d`

In [1]:
REPO_FOLDER = '..'

import sys
sys.path.append(REPO_FOLDER) 

import argparse
from os.path import join
from pathlib import Path
import re 
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import EsmTokenizer, BartTokenizer, BartModel, BartForConditionalGeneration
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from crystoper import config
from crystoper.processor import filter_by_pdbx_details_length, filter_for_single_entities
from crystoper.utils.general import vprint, make_parent_dirs
from crystoper.esmc_models import ESMCcomplex
from crystoper.trainer import ESMCTrainer, seq2sent, train_model
from crystoper.dataset import  Sequence2BartDataset



device = 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
esm_model = ESMCcomplex()
esm_tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")    
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
bart_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base') 
data_path = join(REPO_FOLDER, config.toy_path)
vectors_path = join(REPO_FOLDER, config.details_vectors_path, 'toy', 'bart_vectors_0.pkl')



In [None]:
# We will first perform inference using the untrained model 
data = pd.read_csv(data_path)

X = data['sequence']
Y_true = data['pdbx_details'] 

bart_model.to(device)
esm_model.to(device)

print(f'\n\nStarting inference of {len(data)} instances!')

for x, y_true in tqdm(zip(X,Y_true)):
    pred = seq2sent(x, esm_model, esm_tokenizer, bart_model, bart_tokenizer, ac=True)
    y_true = y_true.replace("\n", " ")
    print(f'True sentence: {y_true}')
    print(f'Pred sentence: {pred}')
    print('\n\n')

  with autocast():


True sentence: 10-12% PEG3350, 0.1M BIS-TRIS
0.2-0.3M MG ACETATE, 0.1M GdCl3 
10% glycerol, 5 mM TCEP
Pred sentence:  what song.�'�. in – and, and — ....� (. and the' ands. and s.. and with the.S...".' and. -- N.  and all.. - and suddenly and or.'.Â and.ss�� – and. and. - even, h (s0s and and and....suk.s. not and] … -�sÂs.�" (ss —) not� and') not and) -�.s W.  �) and, - -.s...Âs0Âs -s.s'Â0Â (s."0s.' *ss...." "ss0" blocks0s . *2's. (' PÂss .4�ÂS'] and thes andÂÂÂs] and both s all–s of ors /ÂÂ5s and� ands andS. and/--Âs and/Â1' of the-0.s both "s both today - - andÂ conf®ÂS--sÂÂ



True sentence: 0.2 M imidazole malate, 25%(w/v) PEG4000
Pred sentence: h ("s.". each1.em for"., that in.anns"ened-8, and (The my all- andard. and, and- - (s – (s ls1 (K2." first of a now nows – M "sard- not as now now in?7 now and so suchs- now possibly and now now with now or (ome now, so now now of " from and and now before " (sone. out. (" ( Y.... * "A l ())]"-T ([ "I./) only1 (1.Sask as and out-n/ (ska+1 

In [3]:
# Now we will train the model. Because this is just a POC example, we will over fit it on the toy data.
# you must run `$ python vectorize.py -d` from the main repo folder first

def train_model(model, 
                train_loader,
                loss_fn,
                optimizer,
                batch_size,
                device,
                verbose=True):
    

    model.train()  # Set model to training mode
    running_train_loss = 0.0

    # Training loop
    for batch_idx, batch in enumerate(train_loader):
        optimizer.zero_grad()  # Clear gradients

        # Forward pass
        output_matrices = model(batch['input_ids'], attention_mask=batch['attention_mask'])
        loss = loss_fn(output_matrices, batch['target_matrices'])

        # Backward pass
        loss.backward()
        optimizer.step()


    return loss.item()

#params
n_epochs = 30
batch_size = 2
loss_fn = nn.MSELoss()
lr = 1e-3
optimizer=optim.Adam(esm_model.parameters(), lr=lr)
loses = []

#create dataset and loader for this piece of data
print(f"Loading train data from {vectors_path}....")
data = torch.load(vectors_path)
train_dataset = Sequence2BartDataset(data['sequences'], data['det_vecs'], esm_tokenizer, device=device)
train_loader = DataLoader(train_dataset, batch_size=batch_size, collate_fn=train_dataset.collate)



for epoch in range(n_epochs):
    print(f'Starting epoch {epoch+1}')


    loss = train_model(esm_model, train_loader, loss_fn,
                                            optimizer,  batch_size,
                                            device)
    loses.append(loss)


    print(f'Finished epoch {epoch + 1}. Train loss: {loss}')
    

  data = torch.load(vectors_path)
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Loading train data from ../vectors/details/toy/bart_vectors_0.pkl....
Starting epoch 1
Finished epoch 1. Train loss: 0.18874993920326233
Starting epoch 2
Finished epoch 2. Train loss: 0.12248475849628448
Starting epoch 3
Finished epoch 3. Train loss: 0.11135702580213547
Starting epoch 4
Finished epoch 4. Train loss: 0.075321726500988
Starting epoch 5
Finished epoch 5. Train loss: 0.045750416815280914
Starting epoch 6
Finished epoch 6. Train loss: 0.02309129387140274
Starting epoch 7
Finished epoch 7. Train loss: 0.019204216077923775
Starting epoch 8
Finished epoch 8. Train loss: 0.015923021361231804
Starting epoch 9
Finished epoch 9. Train loss: 0.013394526205956936
Starting epoch 10
Finished epoch 10. Train loss: 0.011154514737427235
Starting epoch 11
Finished epoch 11. Train loss: 0.009472673758864403
Starting epoch 12
Finished epoch 12. Train loss: 0.00902880821377039
Starting epoch 13
Finished epoch 13. Train loss: 0.008758622221648693
Starting epoch 14
Finished epoch 14. Train los

In [4]:
#Now we will perform inference on same data with the over-fitted model
data = pd.read_csv(data_path)

X = data['sequence']
Y_true = data['pdbx_details'] 

bart_model.to(device)
esm_model.to(device)

print(f'\n\nStarting inference of {len(data)} instances!')

for x, y_true in tqdm(zip(X,Y_true)):
    pred = seq2sent(x, esm_model, esm_tokenizer, bart_model, bart_tokenizer, ac=True)
    y_true = y_true.replace("\n", " ")
    print(f'True sentence: {y_true}')
    print(f'Pred sentence: {pred}')
    print('\n\n')



Starting inference of 10 instances!


  with autocast():
1it [00:11, 11.74s/it]

True sentence: 10-12% PEG3350, 0.1M BIS-TRIS 0.2-0.3M MG ACETATE, 0.1M GdCl3  10% glycerol, 5 mM TCEP
Pred sentence: 10-12% PEG3350, 0.1M BIS-TRIS 0.2-0.3M MG ACACACACATACATATATACAC1-1.0-0-1,0.1-0,0-2.0.5 mM GIS-1:0.0:0-4.0





2it [00:26, 13.50s/it]

True sentence: 0.2 M imidazole malate, 25%(w/v) PEG4000
Pred sentence: 0.2 M imidazole malformate.1.2.3.5.6.4.1 (1.3)(1.6)





3it [00:36, 12.03s/it]

True sentence: 1 uL + 1uL drop with 500 uL reservoir solution: 100 mM Tris-MES pH 6.5, 75 mM K-Citrate, 24-28% w/w PEG550 MME and 10% (v/v) glycerol; protein buffer: 10 mM Tris-MES pH 6.0, 100 mM KCl, 10% (v/v) glycerol, 40 mM n-octyl-beta-D-glucoside
Pred sentence: 1 uL + 1uL drop with 500 uL reservoir solution: 100 mM Tris-MES pH 6.5, 75 mM K-Citrate, 24-28% w/w PEG, 50 mM MME and 10% (v/v) glycerol; protein buffer: 10 mg Tris+MES 0.0, 100 mM KCl, 10% dioxamine, 10 mM n-octyl-beta-D-glucoside





4it [00:39,  8.51s/it]

True sentence: 2 M ammonium sulfate
Pred sentence: 2 M ammonium sulfate





5it [00:47,  8.21s/it]

True sentence: 0.9 M sodium citrate, 0.1 M imidazole, 25 mM 2-mercaptoethanol, and 2 mM succinyl-CoA, pH 8.2, VAPOR DIFFUSION, HANGING DROP, temperature 298K
Pred sentence: 0.9 M sodium citrate, 0.1 M imidazole, 25 mM 2-mercaptoethanol, and 2 mM succinyl-CoA, pH 8.2, VAPOR DIFFUSION, VANGING DROP, temperature 298K





6it [01:04, 11.27s/it]

True sentence: 25-30% PEG 4000, 0.1M sodium citrate, pH5.6, 0.2M ammonium acetate , VAPOR DIFFUSION, HANGING DROP, temperature 20K
Pred sentence: 25-30% PEGI, 0.1M sodium citrate, pH5.1, pH6.0, pH 5.6, VAPOR DIFFUSION, VEMPORATORATOR, RAPORATOR, VIPORATOR , VAPATOR, HAPORABLE, VOPORATOR: VIPATOR.VIPATOR: RAPATOR.AVATOR.RAPATOR:VAMPATOR.vipATOR.wavATOR.infectATOR.pathpath.pathparser.pathochondrator.pathentity.pathinterface.pathinfectATOR:vipregor.pathadaptator.pathacetoxicity.pathplugin.pathoptanimATOR:pathoptoxicitypathpathpath:pathadaptimmunepathpathinfectinfectedpathpathwithpathoptimmunepathadaptanimpathpathSTDOUTpathpath\":pathoptibalpathpath177pathpathadaptpathpath7601pathpathruntime.pathshiftpathpath":""},{"pathpathtimeoutpathpathusercpathpath181pathpathdestroypathpath\/\/pathpathescriptionpathpath70710pathpath"/>pathpathparenpathpathCLASSIFIEDpathpathREDACTEDpathpathFINESTpathpath@@@@@@@@pathpathminecraftpathpath ├pathpathNAMEpathpathicterpathpathPATH





7it [01:10,  9.63s/it]

True sentence: 0.1 M SPG (succinic acid, sodium phosphate monobasic monohydrate, and glycine) buffer pH 7.0, 25% w/v PEG 1500
Pred sentence: 0.1 M SPG (succinic acid, sodium phosphate monobasic monohydrate, and glycine) buffer pH 7.0, 25% w/v PEG.





8it [01:24, 10.72s/it]

True sentence: 50 MM ACETATE AT PH 4.7 CONTAINING 0.9 M NACL
Pred sentence: 50.1.0.1:1.2.1,5.0,0.7.0-0.9.0;0.6.5.5,0:1,0256.0:0.5-1.6,0257.0(0.0),0.4.0,"0.8.0"pathpathpath.0."0.3.5"path.path.infectinfectedwithinfectinfectinfectionsinfectedinfectinfectioninfectinfectionainfectinfectivinfectinfectiveinfectinfectimmuneinfectinfectivainfectinfectivesinfectinfecticutinfectinfectivoinfectinfectiousinfectinfectivaninfectinfectelineinfectinfectoralinfectinfectaliainfectinfectarrayinfectinfectvinfectinfectendifinfectinfectsinfectinfectsylvaniainfectinfectneainfectinfectdatainfectinfectaurainfectinfectochondriainfectinfectavinfectinfectitiainfectinfectuniainfectinfectaddressinfectinfectISONinfectinfectococinfectinfectivelyinfectinfectvidinfectinfect130infectinfectspeciesinfectinfectdisableinfectinfectnessinfectinfectantinfectinfectrelinfectinfect123infectinfectampooinfectinfectraginfectinfect177infectinfectableinfectinfectaginfectinfectalginfectinfectawareinfectinfectagascarinfectinfecthyperinfe

9it [01:32, 10.00s/it]

True sentence: 0.1 M HEPES pH 7.0, 12-15% PEG 4000
Pred sentence: 0.1.2.0.3.5.1,0.0,0:0.6.0;0.4.0





10it [01:48, 10.88s/it]

True sentence: LISO4 2.5M, HEPES 100 MM PH 8.5
Pred sentence: Libraries:1.1.2.4.3.5.6.0.1-1.5-1,5-6.9-7.6-5.8-5-5,5.0-6,5,6-6-4.6,6.5,7.0,7-7-6."6.6"Libraries:"Libraries":Libraries.1,"Libraries."Libraries,"Languages."Languages":Languages:Languages:"Languages":"Languages":["Languages)."Languages.character.character."character.path.path."character."path."pathpathpath).pathpath.characterpathpaths.pathpathwithpathpath*.pathpathoptoptoptpathoptpathpath.>>pathpathhealthpathpathadaptpathpathruntime.pathoptimmunepathpathdisablepathoptoxicitypathpathinfectpathpath@pathoptocomputer.pathtmlpathpathfindpathpath/.pathoptopathpathpath\":pathpathjavascriptpathpathdescriptionpathpathparenpathpathPATHpathpathtimeoutpathpathchildrenpathpathacterspathpathPathpathpathmesspathpath ├pathpathphysicalpathpath\/\/pathpath/*pathpath






