In [1]:
from ABDB import database as db
import matplotlib.pyplot as plt
import numpy as np
from retrain_ablooper import *
import torch
import pandas as pd
from rich import print as pprint
from ABlooper import CDR_Predictor



Own code for relaxing structures

In [2]:
# torch settings
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_dtype(torch.float)

# Build a model with decoys=5 and restore the saved weights, mapped onto `device`.
model = MaskDecoyGen(decoys=5).to(device = device).float()
model.load_state_dict(torch.load('best_models/best_model-2804-Radam-5-2optim', map_location=torch.device(device)))

# One entry per batch for both loaders below.
batch_size = 1

# Pre-saved data splits (paths are relative to the working directory).
train = torch.load('train_data/train.pt')
validation = torch.load('train_data/val.pt')
test = torch.load('train_data/test.pt')

# Validation loader: fixed order (no shuffling).
val_dataloader = torch.utils.data.DataLoader(validation, 
                                             batch_size=batch_size,
                                             num_workers=1,
                                             shuffle=False,
                                             pin_memory=True,
                                             )

# Test loader: NOTE only the first two entries of `test` are used (`test[:2]`).
test_dataloader = torch.utils.data.DataLoader(test[:2], 
                                              batch_size=batch_size,
                                              num_workers=1,
                                              shuffle=False,
                                              pin_memory=True,
                                              )

In [3]:
# Replace each of the five blocks of `model` with the first block of a
# separately saved decoys=1 checkpoint (files suffixed -1 ... -5).
for i in range(5):
    tmp = MaskDecoyGen(decoys=1).to(device = device).float()
    # Bound to `checkpoint` rather than `dict`, so the builtin `dict` is not
    # shadowed for every cell that runs afterwards.
    checkpoint = torch.load("best_models/best_model-0305-Radam-1-2optim-"+str(i+1), map_location=torch.device(device))
    tmp.load_state_dict(checkpoint)
    weights = tmp.blocks[0].state_dict()

    model.blocks[i].load_state_dict(weights)

In [14]:
def produce_full_structures_of_val_set(val_dataloader, model, outdir='', relax=True, to_be_rewritten=["H1", "H2", "H3", "L1", "L2", "L3"]):
    '''
    Predict CDR loops for every entry of a dataset and write full Fab PDB files.

    For each entry the framework is read from the PDB file returned by
    ``get_info_from_id``; the residues of every CDR listed in
    ``to_be_rewritten`` are replaced by the prediction averaged over its first
    axis, and the result is written below ``pdbs/<outdir>/``:

    * ``<pdb_id>-<heavy><light>.pdb``          framework with predicted loops
    * ``<pdb_id>-<heavy><light>-true.pdb``     selected chains of the source file
    * ``<pdb_id>-<heavy><light>-relaxed.pdb``  output of ``openmm_refine`` (only if ``relax``)

    The output directory is created when it does not exist yet.

    Args:
        val_dataloader: iterable of batches; each batch is indexed with the keys
            'geomins', 'geomouts', 'encodings', 'mask' and 'ids'.
            NOTE(review): the prediction is squeezed and ``node_feature[0]`` is
            read, so this presumably expects batch size 1 -- confirm.
        model: called as ``model(node_feature, coordinates, mask)``; switched to
            eval mode here.
        outdir: sub-directory of ``pdbs/`` that receives the files.
        relax: also refine each structure and score the refined loops.
        to_be_rewritten: names of the CDRs whose residues are replaced. The
            default list is only iterated, never modified.

    Returns:
        Tuple ``(CDR_rmsds_not_relaxed, CDR_rmsds_relaxed, decoy_diversities,
        order_of_pdbs)`` of lists with one item per processed entry;
        ``CDR_rmsds_relaxed`` stays empty when ``relax`` is false.
    '''
    import os  # local import: `os` is not imported by the cell defining this function

    CDR_rmsds_not_relaxed = list()
    CDR_rmsds_relaxed = list()
    decoy_diversities = list()
    order_of_pdbs = list()

    # Make sure the target directory exists; open(..., "w+") does not create it.
    os.makedirs('pdbs/'+outdir, exist_ok=True)

    with torch.no_grad():
        model.eval()

        for data in track(val_dataloader, description='predict val set'):

            # predict structure using the model
            coordinates = data['geomins'].float().to(device)
            geomout = data['geomouts'].float().to(device)
            node_feature = data['encodings'].float().to(device)
            mask = data['mask'].float().to(device)
            entry_id = data['ids']  # not named `id`, to keep the builtin usable

            pred = model(node_feature, coordinates, mask)
            CDR_rmsds_not_relaxed.append(rmsd_per_cdr(pred, node_feature, geomout).tolist())
            pred = pred.squeeze() # remove batch dimension

            # get framework info from pdb file
            pdb_id, heavy_c, light_c, pdb_file = get_info_from_id(entry_id)
            chains = [heavy_c, light_c]
            order_of_pdbs.append(pdb_id)

            with open(pdb_file) as file:
                pdb_text = file.readlines()

            pdb_text = pdb_select_hc_lc(pdb_text, chains)

            CDR_with_anchor_slices, atoms, CDR_text, CDR_sequences, CDR_numberings, CDR_start_atom_id = get_framework_info(pdb_text, chains)

            predicted_CDRs = {}
            all_decoys = {}
            decoy_diversity = {}

            for i, CDR in enumerate(CDR_with_anchor_slices):
                # Feature column 30 + i flags the atoms that belong to the i-th CDR.
                output_CDR = pred[:, node_feature[0, :, 30 + i] == 1.0]
                all_decoys[CDR] = rearrange(output_CDR, "b (i a) d -> b i a d", a=4).cpu().numpy()
                predicted_CDRs[CDR] = rearrange(output_CDR.mean(0), "(i a) d -> i a d", a=4).cpu().numpy()
                # Sum of pairwise RMSDs between all entries along the first axis.
                # NOTE(review): 20 presumably equals 5 * 4 ordered pairs of the
                # 5 decoys -- confirm before using another decoy count.
                decoy_diversity[CDR] = (output_CDR[None] - output_CDR[:, None]).pow(2).sum(-1).mean(-1).pow(
                    1 / 2).sum().item() / 20

            decoy_diversities.append(list(decoy_diversity.values()))

            text_prediction_per_CDR = convert_predictions_into_text_for_each_CDR(CDR_start_atom_id, predicted_CDRs, CDR_sequences, CDR_numberings, CDR_with_anchor_slices)
            old_text = pdb_text

            for CDR in to_be_rewritten:
                insert_pending = True
                new_text = []
                chain, CDR_slice = CDR_with_anchor_slices[CDR]
                # Shrink the slice by two on each side (the slices include anchors).
                CDR_slice = (CDR_slice[0] + 2, CDR_slice[1] - 2)

                for line in old_text:
                    if not filt(line, chain, CDR_slice):
                        new_text.append(line)
                    elif insert_pending:
                        # The first matching line is replaced by the whole predicted
                        # loop; every further matching line is dropped.
                        new_text += text_prediction_per_CDR[CDR]
                        insert_pending = False
                old_text = new_text

            header = [
                "REMARK    CDR LOOPS REMODELLED USING ABLOOPER                                   \n"]
            new_text = header + old_text

            out_prefix = 'pdbs/'+outdir+'/'+pdb_id+'-'+heavy_c+light_c

            with open(out_prefix+'.pdb', "w+") as file:
                file.write("".join(new_text))

            with open(out_prefix+'-true.pdb', "w+") as file:
                file.write("".join(pdb_text))

            if relax:
                relaxed_text = openmm_refine(old_text, CDR_with_anchor_slices)
                header.append("REMARK    REFINEMENT DONE USING OPENMM" + 42 * " " + "\n")
                relaxed_text = header + relaxed_text

                with open(out_prefix+'-relaxed.pdb', "w+") as file:
                    file.write(''.join(relaxed_text))

                # calculate rmsds of relaxed structures
                CDR_with_anchor_slices, atoms, CDR_text, CDR_sequences, CDR_numberings, CDR_start_atom_id = get_framework_info(relaxed_text, chains)
                CDR_BB_coords = extract_BB_coords(CDR_text, CDR_with_anchor_slices, CDR_sequences, atoms)

                relaxed_coords = prepare_model_output([CDR_BB_coords])[0]

                relaxed_coords = pad_tensor(relaxed_coords)
                # Add two leading singleton axes before scoring with rmsd_per_cdr.
                relaxed_coords = rearrange(relaxed_coords, 'i x -> () () i x')
                CDR_rmsds_relaxed.append(rmsd_per_cdr(relaxed_coords, node_feature, geomout).tolist())

    return CDR_rmsds_not_relaxed, CDR_rmsds_relaxed, decoy_diversities, order_of_pdbs

In [8]:
# Sub-directory of `pdbs/` that receives the files of this run (passed as `outdir` below).
# NOTE(review): the name `dir` shadows the Python builtin; kept because later cells read it.
dir = '0305-Radam-1-2optim-test'

In [9]:
# Predict, write and relax the structures served by the test loader.
cdr_rmsds, CDR_rmsds_relaxed, decoy_diversities, pdb_ids = produce_full_structures_of_val_set(
    test_dataloader,
    model,
    outdir=dir,
    relax=True,
)

In [None]:
import json  # needed here: none of the visible import cells brings in `json`

# Store the per-structure metrics next to the generated PDB files.
with open('pdbs/'+dir+'/metrics.json', 'w') as f:
    json.dump(
        {
            'pdb_ids': pdb_ids,
            'cdr_rmsds': cdr_rmsds,
            # The prediction cell binds the relaxed RMSDs to `CDR_rmsds_relaxed`
            # (upper-case prefix); the lower-case name was never defined.
            'cdr_rmsds_relaxed': CDR_rmsds_relaxed,
            'decoy_divsersity': decoy_diversities,
        },
        f,
    )
    

In [13]:
cdr_rmsds

[[0.39621397852897644,
  0.5520344972610474,
  2.0543265342712402,
  0.5045629143714905,
  0.1996065378189087,
  1.3235737085342407],
 [1.4963270425796509,
  1.0473147630691528,
  2.906479835510254,
  0.6791009902954102,
  0.2234359234571457,
  1.0305887460708618]]

In [14]:
 CDR_rmsds_relaxed

[[0.39689359068870544,
  0.5317484736442566,
  2.0721490383148193,
  0.5045245885848999,
  0.19938620924949646,
  1.3232002258300781],
 [1.4967271089553833,
  1.0466426610946655,
  2.9059810638427734,
  0.6716121435165405,
  0.2233588546514511,
  1.0306532382965088]]

Calculate RMSDs of relaxed structures from the saved PDB files

In [1]:
import os
from ABDB import database as db
import matplotlib.pyplot as plt
import numpy as np
from retrain_ablooper import *
import torch
import pandas as pd
from rich import print as pprint
from ABlooper import CDR_Predictor
from einops import rearrange

Database path /Volumes/LaCie/sabdab-sabpred/data/ABDB was not found.


Cannot do refinement


In [2]:
# List the contents of the results folder and display them.
# NOTE(review): hard-coded absolute path to a local folder; adjust when running elsewhere.
files = os.listdir('/Users/fabian/Desktop/2804-RAdam-5-2optim-test')
files

['2ypv-HL-true.pdb',
 '2v17-HL-relaxed.pdb',
 '2vxv-HL-true.pdb',
 '1fns-HL.pdb',
 '3oz9-HL.pdb',
 '2v17-HL-true.pdb',
 '2e27-HL-true.pdb',
 '3gnm-HL-true.pdb',
 '4hpy-HL.pdb',
 '3oz9-HL-relaxed.pdb',
 '3g5y-BA-relaxed.pdb',
 '1gig-HL-true.pdb',
 '4f57-HL-true.pdb',
 '3g5y-BA.pdb',
 '2fb4-HL-relaxed.pdb',
 '3t65-BA-relaxed.pdb',
 '3eo9-HL.pdb',
 '2xwt-AB-relaxed.pdb',
 '3hc4-HL.pdb',
 '4h20-HL.pdb',
 '1fns-HL-true.pdb',
 '3t65-BA-true.pdb',
 '3ifl-HL-relaxed.pdb',
 '4nzu-HL.pdb',
 '2r8s-HL-relaxed.pdb',
 '1oaq-HL.pdb',
 '1oaq-HL-true.pdb',
 '2w60-AB-relaxed.pdb',
 '1mfa-HL-relaxed.pdb',
 '3hc4-HL-relaxed.pdb',
 '1mlb-BA.pdb',
 '1nlb-HL.pdb',
 '1oaq-HL-relaxed.pdb',
 '3liz-HL-true.pdb',
 '2fb4-HL-true.pdb',
 '1jpt-HL-true.pdb',
 '1dlf-HL.pdb',
 '1dlf-HL-relaxed.pdb',
 '3t65-BA.pdb',
 '2e27-HL-relaxed.pdb',
 '2r8s-HL.pdb',
 '2vxv-HL-relaxed.pdb',
 '1nlb-HL-relaxed.pdb',
 '2adf-HL-true.pdb',
 '4nzu-HL-relaxed.pdb',
 '3v0w-HL-relaxed.pdb',
 '2ypv-HL-relaxed.pdb',
 '4hpy-HL-relaxed.pdb',
 '

In [5]:
def cdr_rmsds_on_pdb_file(prediction_path, truth_path, chains):
    """
    Score the CDR backbone of one PDB file against a reference PDB file.

    Both files are loaded through ``CDR_Predictor``; their ``CDR_BB_coords``
    are converted with ``prepare_model_output`` and handed to ``rmsd_per_cdr``
    together with the node features of the predicted structure.

    Args:
        prediction_path: path of the PDB file to be scored.
        truth_path: path of the reference PDB file.
        chains: chain identifiers forwarded to ``CDR_Predictor``.

    Returns:
        Whatever ``rmsd_per_cdr`` returns for this pair of structures.

    NOTE(review): unlike the relaxed-structure scoring inside
    ``produce_full_structures_of_val_set``, no ``pad_tensor`` call is made
    here -- confirm that this is intended.
    """
    predicted = CDR_Predictor(prediction_path, chains = chains)
    # Two leading singleton axes for the scored structure ...
    predicted_coords = rearrange(
        prepare_model_output([predicted.CDR_BB_coords])[0],
        'i x -> () () i x',
    )
    node_features = predicted.prepare_model_input()[0]

    reference = CDR_Predictor(truth_path, chains)
    # ... but only one for the reference.
    reference_coords = rearrange(
        prepare_model_output([reference.CDR_BB_coords])[0],
        'i x -> () i x',
    )

    return rmsd_per_cdr(predicted_coords, node_features, reference_coords)

def rmsds_pred_and_relaxed(dir):
    """
    Compute per-CDR RMSDs of predicted and relaxed structures stored in a folder.

    The folder is expected to hold, for every structure stem ``<pdb>-<XY>``
    (the first seven characters of the file name, whose last two characters
    are the chain identifiers), the three files ``<stem>.pdb``,
    ``<stem>-relaxed.pdb`` and ``<stem>-true.pdb``.

    Args:
        dir: folder containing the PDB files, with or without a trailing slash.

    Returns:
        Tuple ``(ids, rmsds_pred, rmsds_relaxed)``: the sorted list of stems
        and two ``(len(ids), 6)`` arrays whose row ``i`` belongs to ``ids[i]``.

    Raises:
        Whatever ``cdr_rmsds_on_pdb_file`` raises when one of the three files
        of a stem is missing or cannot be processed.
    """
    # Only PDB files carry a structure stem; this skips hidden files as well as
    # other files kept in the same folder (e.g. a metrics.json). Sorting makes
    # the row order reproducible, which a bare set does not guarantee.
    stems = sorted({
        name[:7]
        for name in os.listdir(dir)
        if name.endswith('.pdb') and not name.startswith('.')
    })

    ids = list()
    rmsds_pred = np.zeros((len(stems), 6))
    rmsds_relaxed = np.zeros((len(stems), 6))

    for i, stem in enumerate(stems):
        chains = (stem[-2], stem[-1])

        # os.path.join gives the same path as concatenation when `dir` ends in a
        # slash and a correct one when it does not.
        truth_path = os.path.join(dir, stem+'-true.pdb')
        pred_path = os.path.join(dir, stem+'.pdb')
        relaxed_path = os.path.join(dir, stem+'-relaxed.pdb')

        ids.append(stem)
        rmsds_pred[i,:] = np.array(cdr_rmsds_on_pdb_file(pred_path, truth_path, chains).tolist())
        rmsds_relaxed[i,:] = np.array(cdr_rmsds_on_pdb_file(relaxed_path, truth_path, chains).tolist())

    return ids, rmsds_pred, rmsds_relaxed

In [7]:
# Folder holding the predicted, relaxed and true PDB files to be compared.
# NOTE(review): hard-coded absolute local path; the name `dir` also shadows the Python builtin.
dir = '/Users/fabian/Desktop/2804-RAdam-5-2optim-test/'

# One row per structure in both RMSD arrays.
ids, rmsd_pred, rmsd_relaxed = rmsds_pred_and_relaxed(dir)


In [8]:
rmsd_pred.mean(0)

array([1.90520121, 1.96798921, 2.92938134, 1.86615316, 1.77889952,
       1.97423427])

In [None]:
from ABDB import database as db
from ABDB.AbPDB import AntibodyParser
import numpy as np
import Bio.PDB
# Structure parser and database are both switched to the IMGT numbering scheme.
parser = AntibodyParser(PERMISSIVE=True, QUIET=True)
parser.set_numbering_scheme("imgt")
db.set_numbering_scheme("imgt")
# Atom names used for superposition and RMSD; note that CB is included next to CA, C and N.
backbone = ["CA","C","N", "CB"]

def _paired_backbone_coords(truth, decoy, residues):
    """Collect coordinates of the `backbone` atoms present in both chains for the given residue ids."""
    true_coords, decoy_coords = [], []
    for res in residues:
        if (res in truth) and (res in decoy):
            shared = [x for x in backbone if (x in decoy[res]) and (x in truth[res])]
            true_coords += [truth[res][x].get_coord() for x in shared]
            decoy_coords += [decoy[res][x].get_coord() for x in shared]
    return np.array(true_coords), np.array(decoy_coords)


def _coordinate_rmsd(true_coords, decoy_coords):
    """RMSD over atoms: the mean runs over all x/y/z entries, hence the factor 3."""
    return np.sqrt(np.mean(3*(true_coords - decoy_coords)**2))


def CDR_rmsds(pdb, file, fab_n = 0, chains = ["H", "L"], decoy_chains = ["H", "L"]):
    """
    Compare a modelled structure with the database structure of a Fab.

    For every requested chain the model is superimposed onto the true chain
    (all numbered residues except the first and last two, atoms listed in
    `backbone`), after which an RMSD is computed for each CDR of that chain
    and for its framework (numbered residues that are not part of a CDR).

    Args:
        pdb: database identifier of the true structure.
        file: path of the modelled PDB file.
        fab_n: index of the Fab within the database entry.
        chains: which of "H" / "L" to evaluate. Only iterated, never modified.
        decoy_chains: chain ids used in `file` for the heavy and the light
            chain, in that order. Only indexed, never modified.

    Returns:
        Dict mapping each CDR name (the CDR key without its first three
        characters) to its RMSD and each evaluated chain ("H" / "L") to its
        framework RMSD. Entries without any shared atom come out as NaN.
    """
    fab = db.fetch(pdb).fabs[fab_n]

    # Find CDR definitions (independent of the chain, so looked up only once)
    cdr_sequences = fab.get_CDR_sequences(definition="imgt")
    loop_definitions = {x[3:]:[(" ", *y[0]) for y in cdr_sequences[x]] for x in cdr_sequences}

    rmsds = {}
    for h_or_l in chains:
        #Truth load
        chain = db.db_summary[pdb]["fabs"][fab_n][h_or_l+"chain"]
        truth = fab.get_structure()[chain]
        #Decoy load (re-parsed per chain, because the superposition below moves it)
        decoy = parser.get_antibody_structure(pdb+"_model", file)[0][{"H":decoy_chains[0], "L":decoy_chains[1]}[h_or_l]]
        #Numbering
        numb = [(" ", *x) for x in fab.get_numbering()[h_or_l]][2:-2]
        #Get residues to align
        shared_ids = [x for x in numb if (x in decoy) and (x in truth)]
        #Get atoms to align
        fixed = []
        moved = []
        for res_id in shared_ids:
            truth_res, decoy_res = truth[res_id], decoy[res_id]
            shared_atoms = [atom for atom in backbone if (atom in decoy_res) and (atom in truth_res)]
            fixed += [truth_res[atom] for atom in shared_atoms]
            moved += [decoy_res[atom] for atom in shared_atoms]
        #Calculate superimposer and move decoy
        imposer = Bio.PDB.Superimposer()
        imposer.set_atoms(fixed, moved)
        imposer.apply(decoy.get_atoms())

        # Calculate RMSD for each CDR of this chain
        for CDR in loop_definitions:
            if CDR[0] == h_or_l:
                rmsds[CDR] = _coordinate_rmsd(*_paired_backbone_coords(truth, decoy, loop_definitions[CDR]))

        # Calculate RMSD for framework.
        # This has to run once per chain: outside the loop only the last chain
        # received a framework RMSD, while earlier chains kept the superimposer
        # RMSD under the same key.
        ignore = sum([loop_definitions[x] for x in loop_definitions if x[0] == h_or_l], [])
        frame_def = [x for x in numb if x not in ignore]
        rmsds[h_or_l] = _coordinate_rmsd(*_paired_backbone_coords(truth, decoy, frame_def))

    return rmsds