In [1]:
from ABDB import database as db
import matplotlib.pyplot as plt
import numpy as np
from retrain_ablooper import *
import torch
import pandas as pd
from rich import print as pprint
from ABlooper import CDR_Predictor



Own code for relaxing structures

In [5]:
# --- Runtime configuration -------------------------------------------------
# Run on the GPU when one is visible to torch, otherwise on the CPU, and make
# float32 the default dtype for tensors created from here on.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch.set_default_dtype(torch.float32)

# --- Model -----------------------------------------------------------------
# Build the five-decoy network, place it on `device` and restore the trained
# weights (remapped onto `device` while loading).
model = MaskDecoyGen(decoys=5)
model = model.to(device=device).float()
model.load_state_dict(
    torch.load(
        'best_models/best_model-2804-Radam-5-2optim',
        map_location=torch.device(device),
    )
)

# --- Data ------------------------------------------------------------------
# One sample per batch.
batch_size = 1

# Pre-built dataset splits saved with torch.save.
train = torch.load('train_data/train.pt')
validation = torch.load('train_data/val.pt')
test = torch.load('train_data/test.pt')

# Loaders for the evaluation splits; shuffle=False keeps the sample order
# stable between runs.
val_dataloader = torch.utils.data.DataLoader(
    dataset=validation,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)

test_dataloader = torch.utils.data.DataLoader(
    dataset=test,
    batch_size=batch_size,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)

In [None]:
# Rebuild the decoy generator and initialise each of its five blocks from the
# same single-block checkpoint.
model = MaskDecoyGen(decoys=5).to(device = device).float()

# The checkpoint file is identical for every block, so read it from disk once
# instead of once per block. load_state_dict copies the values into each
# block's own parameters, so reusing the dict does not tie the blocks together.
block_state = torch.load("best_models/best_model-0305-Radam-1-2optim-1", map_location=torch.device(device))
for i in range(5):
    model.blocks[i].load_state_dict(block_state)

In [11]:
def pdb_select_hc_lc(pdb_text, chains):
    '''
    Returns only lines of a pdb file which correspond to the heavy or light chain of the antibody

    A line is kept when, after splitting on whitespace, its first field is
    'ATOM' and its fifth field is contained in `chains`. Lines with fewer than
    five fields (blank lines, 'TER', 'END', truncated records) are skipped
    instead of raising IndexError.

    Args:
        pdb_text: iterable of PDB lines (strings).
        chains: container of chain identifiers to keep.

    Returns:
        List of the matching lines, unchanged and in their original order.
    '''
    pdb_text_hc_lc = []
    for line in pdb_text:
        # Split once per line; the chain id is read from the fifth field.
        # NOTE(review): whitespace splitting mis-reads records whose columns
        # run together (e.g. a chain id directly followed by a 4-digit residue
        # number) - such lines are dropped here, as they were before.
        fields = line.split()
        if len(fields) < 5 or fields[0] != 'ATOM':
            continue
        if fields[4] in chains:
            pdb_text_hc_lc.append(line)

    return pdb_text_hc_lc
    

def produce_full_structures_of_val_set(val_dataloader, model, outdir='', relax=True, to_be_rewritten=["H1", "H2", "H3", "L1", "L2", "L3"]):
    '''
    Produces full FAB structure for a dataset

    For every batch of `val_dataloader` the model prediction is converted to
    PDB text and spliced into the heavy/light chain ATOM records of the
    corresponding PDB file. The following files are written to
    'pdbs/<outdir>/' (the directory is created if it does not exist):

        <pdb_id>-<heavy><light>.pdb          structure with remodelled CDRs
        <pdb_id>-<heavy><light>-true.pdb     original heavy/light ATOM records
        <pdb_id>-<heavy><light>-relaxed.pdb  result of openmm_refine (only when `relax`)

    Args:
        val_dataloader: loader yielding dicts with the keys 'geomins',
            'geomouts', 'encodings', 'mask' and 'ids'. The code indexes batch
            element 0 and squeezes the prediction, so it assumes batch size 1.
        model: network called as model(node_feature, coordinates, mask); it is
            switched to eval mode and run under torch.no_grad().
        outdir: sub-directory of 'pdbs/' that receives the output files.
        relax: when true, additionally refine the remodelled structure with
            openmm_refine and write the '-relaxed' file.
        to_be_rewritten: names of the CDRs whose records are replaced by the
            prediction (the default list is only read, never mutated).

    Returns:
        Tuple (CDR_rmsds_not_relaxed, decoy_diversities, order_of_pdbs), each a
        list with one entry per batch: the output of rmsd_per_cdr as a list,
        the per-CDR decoy diversities, and the pdb id.
    '''
    import os  # local import: only needed to create the output directory

    CDR_rmsds_not_relaxed = []
    decoy_diversities = []
    order_of_pdbs = []

    # Create the target directory up front; without it every open(..., "w+")
    # below fails with FileNotFoundError on a fresh checkout.
    out_dir = 'pdbs/' + outdir
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        model.eval()

        for data in track(val_dataloader, description='predict val set'):

            # predict structure using the model
            coordinates = data['geomins'].float().to(device)
            geomout = data['geomouts'].float().to(device)
            node_feature = data['encodings'].float().to(device)
            mask = data['mask'].float().to(device)
            sample_id = data['ids']  # not called `id`, to avoid shadowing the builtin

            pred = model(node_feature, coordinates, mask)
            CDR_rmsds_not_relaxed.append(rmsd_per_cdr(pred, node_feature, geomout).tolist())
            pred = pred.squeeze() # remove batch dimension

            # get framework info from pdb file
            pdb_id, heavy_c, light_c, pdb_file = get_info_from_id(sample_id)
            chains = [heavy_c, light_c]
            order_of_pdbs.append(pdb_id)

            with open(pdb_file) as file:
                pdb_text = file.readlines()

            pdb_text = pdb_select_hc_lc(pdb_text, chains)

            CDR_with_anchor_slices, atoms, CDR_text, CDR_sequences, CDR_numberings, CDR_start_atom_id = get_framework_info(pdb_text, chains)

            predicted_CDRs = {}
            decoy_diversity = {}

            for i, CDR in enumerate(CDR_with_anchor_slices):
                # Encoding channel 30 + i flags the atoms that belong to the i-th CDR.
                output_CDR = pred[:, node_feature[0, :, 30 + i] == 1.0]
                # Average over the first axis (presumably the decoys) and group
                # the atoms four per residue.
                predicted_CDRs[CDR] = rearrange(output_CDR.mean(0), "(i a) d -> i a d", a=4).cpu().numpy()

                # Mean pairwise RMSD between decoys. The diagonal of the
                # pairwise matrix is zero, so the sum is divided by the number
                # of ordered pairs n * (n - 1); this is 20 for five decoys.
                n_decoys = output_CDR.shape[0]
                n_pairs = max(n_decoys * (n_decoys - 1), 1)
                pairwise_rmsd = (output_CDR[None] - output_CDR[:, None]).pow(2).sum(-1).mean(-1).pow(1 / 2)
                decoy_diversity[CDR] = pairwise_rmsd.sum().item() / n_pairs

            decoy_diversities.append(list(decoy_diversity.values()))

            text_prediction_per_CDR = convert_predictions_into_text_for_each_CDR(CDR_start_atom_id, predicted_CDRs, CDR_sequences, CDR_numberings, CDR_with_anchor_slices)
            old_text = pdb_text

            for CDR in to_be_rewritten:
                new = True
                new_text = []
                chain, CDR_slice = CDR_with_anchor_slices[CDR]
                # Shrink the slice by two at each end (presumably dropping the
                # anchor residues so that they are kept from the original file).
                CDR_slice = (CDR_slice[0] + 2, CDR_slice[1] - 2)

                for line in old_text:
                    if not filt(line, chain, CDR_slice):
                        new_text.append(line)
                    elif new:
                        # First line selected by filt: insert the whole
                        # predicted loop once; later selected lines are dropped.
                        new_text += text_prediction_per_CDR[CDR]
                        new = False
                old_text = new_text

            header = [
                "REMARK    CDR LOOPS REMODELLED USING ABLOOPER                                   \n"]
            new_text = "".join(header + old_text)

            out_prefix = out_dir + '/' + pdb_id + '-' + heavy_c + light_c

            with open(out_prefix + '.pdb', "w+") as file:
                file.write(new_text)

            with open(out_prefix + '-true.pdb', "w+") as file:
                file.write("".join(pdb_text))

            if relax:
                relaxed_text = openmm_refine(old_text, CDR_with_anchor_slices)
                header.append("REMARK    REFINEMENT DONE USING OPENMM" + 42 * " " + "\n")
                relaxed_text = "".join(header + relaxed_text)

                with open(out_prefix + '-relaxed.pdb', "w+") as file:
                    file.write(relaxed_text)

    return CDR_rmsds_not_relaxed, decoy_diversities, order_of_pdbs

In [12]:
%%time
# Predict every sample of the validation loader, write the remodelled, true and
# OpenMM-relaxed structures to pdbs/test/, and keep the per-sample CDR RMSDs,
# decoy diversities and pdb ids.
# NOTE(review): the output directory is named 'test' although the validation
# loader is passed - confirm this is intended.
cdr_rmsds, decoy_diversities, pdb_ids = produce_full_structures_of_val_set(val_dataloader, model, outdir='test', relax=True)

CPU times: user 1h 42min 28s, sys: 1min 26s, total: 1h 43min 54s
Wall time: 3min 4s


In [14]:
# Display the results of the run above: per-sample CDR RMSDs (before
# relaxation), the pdb ids in processing order, and the per-CDR decoy diversities.
cdr_rmsds, pdb_ids, decoy_diversities

([[0.9121501445770264,
   0.8906557559967041,
   2.2079901695251465,
   0.4420602023601532,
   0.4020904302597046,
   0.46330466866493225]],
 ['5hdq'],
 [[0.4799493789672852,
   0.6126280784606933,
   1.4202484130859374,
   0.8075618743896484,
   0.40671415328979493,
   0.526338529586792]])

In [None]:
# Iterate over the whole validation loader so that, afterwards, the variables
# below hold the contents of its final batch; then display that batch's id.
for data in val_dataloader:
    coordinates, geomout, node_feature, mask = [
        data[key].float().to(device)
        for key in ('geomins', 'geomouts', 'encodings', 'mask')
    ]
    id = data['ids']
id

In [None]:
from ABDB import database as db
from ABDB.AbPDB import AntibodyParser
import numpy as np
import Bio.PDB
# Structure parser used by CDR_rmsds to read the predicted model files.
parser = AntibodyParser(PERMISSIVE=True, QUIET=True)
# Use the IMGT numbering scheme for both the parser and the database so that
# residue ids from the two sources can be compared directly.
parser.set_numbering_scheme("imgt")
db.set_numbering_scheme("imgt")
# Atom names used for superposition and RMSD in CDR_rmsds.
# NOTE: despite the name this list contains CB in addition to CA, C and N.
backbone = ["CA","C","N", "CB"]

def CDR_rmsds(pdb, file, fab_n = 0, chains = ["H", "L"], decoy_chains = ["H", "L"]):
    '''
    Compares a model structure with the database structure of the same antibody.

    For every requested chain the decoy chain is superimposed onto the true
    chain using the atoms listed in `backbone` of all numbered residues present
    in both structures (the first and last two numbered residues are left out).
    After the superposition an RMSD is computed for every CDR of that chain and
    for the remaining (framework) residues of that chain.

    Args:
        pdb: database identifier passed to db.fetch.
        file: path of the model file parsed with the module-level `parser`.
        fab_n: index of the fab within the database entry.
        chains: chain types to evaluate, "H" and/or "L" (only read, never mutated).
        decoy_chains: chain ids of the heavy (index 0) and light (index 1)
            chain in the model file (only read, never mutated).

    Returns:
        Dict mapping each CDR name to its RMSD and each evaluated chain type
        ("H"/"L") to the framework RMSD of that chain. A region without any
        atom shared by both structures yields nan.
    '''
    fab = db.fetch(pdb).fabs[fab_n]

    # The CDR definitions depend only on the fab, so look them up once rather
    # than on every chain iteration. x[3:] drops the first three characters of
    # each key (presumably a "CDR" prefix, leaving e.g. "H1").
    cdr_sequences = fab.get_CDR_sequences(definition="imgt")
    loop_definitions = {x[3:]: [(" ", *y[0]) for y in cdr_sequences[x]] for x in cdr_sequences}

    decoy_chain_ids = {"H": decoy_chains[0], "L": decoy_chains[1]}

    rmsds = {}
    for h_or_l in chains:
        #Truth load
        chain = db.db_summary[pdb]["fabs"][fab_n][h_or_l+"chain"]
        truth = fab.get_structure()[chain]
        #Decoy load
        decoy = parser.get_antibody_structure(pdb+"_model", file)[0][decoy_chain_ids[h_or_l]]
        #Numbering
        numb = [(" ", *x) for x in fab.get_numbering()[h_or_l]][2:-2]

        #Get atoms to align: only residues and atoms present in both structures
        fixed = []
        moved = []
        for x in numb:
            if (x in decoy) and (x in truth):
                truth_res, decoy_res = truth[x], decoy[x]
                shared_atoms = [atom for atom in backbone if (atom in decoy_res) and (atom in truth_res)]
                fixed += [truth_res[atom] for atom in shared_atoms]
                moved += [decoy_res[atom] for atom in shared_atoms]

        #Calculate superimposer and move decoy
        imposer = Bio.PDB.Superimposer()
        imposer.set_atoms(fixed, moved)
        imposer.apply(decoy.get_atoms())

        # Calculate RMSD for each CDR of this chain
        for CDR, loop_residues in loop_definitions.items():
            if CDR[0] == h_or_l:
                rmsds[CDR] = _backbone_rmsd(truth, decoy, loop_residues)

        # Calculate RMSD for the framework of this chain. This has to happen
        # inside the chain loop: it uses this iteration's truth/decoy/numb, and
        # previously ran only once, after the loop, for the last chain.
        ignore = sum([loop_definitions[x] for x in loop_definitions if x[0] == h_or_l], [])
        frame_def = [x for x in numb if x not in ignore]
        rmsds[h_or_l] = _backbone_rmsd(truth, decoy, frame_def)

    return rmsds


def _backbone_rmsd(truth, decoy, residues):
    '''
    RMSD over the `backbone` atoms of `residues` that exist in both chains.
    '''
    true_coords, decoy_coords = [], []
    for res in residues:
        if (res in truth) and (res in decoy):
            shared_atoms = [x for x in backbone if (x in decoy[res]) and (x in truth[res])]
            true_coords += [truth[res][x].get_coord() for x in shared_atoms]
            decoy_coords += [decoy[res][x].get_coord() for x in shared_atoms]
    # np.mean averages over all N*3 coordinate components; multiplying by 3
    # turns that into the mean squared distance per atom.
    return np.sqrt(np.mean(3*(np.array(true_coords) - np.array(decoy_coords))**2))