In [1]:
import os
import argparse
import copy
import json
import pickle
import Bio
from Bio.Align import substitution_matrices
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

from abnumber import Chain
from xlm.utils import AttrDict
from xlm.utils import bool_flag, initialize_exp
from xlm.data.dictionary import Dictionary
from xlm.model.transformer import TransformerModel
from xlm.utils import to_cuda
from xlm.model.transformer import get_masks
from xlm.evaluation.evaluator import convert_to_text, calculate_identity
from evaluation.datasets import SAbDabDataset
from evaluation.datasets import get_dataset
from evaluation.utils.protein.writers import save_pdb
from evaluation.utils.data import *
from evaluation.utils.misc import *
from evaluation.utils.transforms import *

from transformers import BertModel


import pyrosetta
pyrosetta.init(silent=True)

from pyrosetta import pose_from_pdb, init
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm
  warn(f"Failed to load image Python extension: {e}")


core.init: Checking for fconfig files in pwd and ./rosetta/flags
core.init: Rosetta version: PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python39.Release r337 2022.49+release.201d763 201d7639f91f369d58b1adf514f3febaf6154c58 http://www.pyrosetta.org 2022-12-07T16:15:33
core.init: command: PyRosetta -ex1 -ex2aro -database /h/benjami/.conda/envs/AntiXLM/lib/python3.9/site-packages/pyrosetta/database
basic.random.init_random_generator: 'RNG device' seed mode, using '/dev/urandom', seed=1495705403 seed_offset=0 real_seed=1495705403 thread_index=0
basic.random.init_random_generator: RandomGenerator:init: Normal mode, seed=1495705403 RG_type=mt19937


In [2]:
#Core Includes
from rosetta.core.select import residue_selector as selections

from rosetta.protocols import antibody
# Re-initialise PyRosetta with explicit Rosetta command-line options. This replaces the
# bare `pyrosetta.init(silent=True)` from the first cell (compare the two
# "core.init: command:" lines in the cell outputs).
# NOTE: the backslash-newline continues the string literal, so the leading spaces of the
# next line become part of the option string (harmless, Rosetta splits on whitespace).
init('-use_input_sc -ignore_unrecognized_res -check_cdr_chainbreaks false \
     -ignore_zero_occupancy false -load_PDB_components false -no_fconfig', silent=True)


def get_parser():
    """
    Generate a parameters parser.

    Returns:
        argparse.ArgumentParser holding the options read by ``get_model``,
        ``get_sabdab``, ``SabdabEntry`` and ``evaluate`` in this notebook.
    """
    # parse parameters
    parser = argparse.ArgumentParser(description="Translate sentences")

    # main parameters
    parser.add_argument("--dump_path", type=str, default="dumped/", help="Experiment dump path")
    parser.add_argument("--exp_name", type=str, default="kir", help="Experiment name")
    parser.add_argument("--exp_id", type=str, default="123", help="Experiment ID")
    parser.add_argument("--beam_size", type=int, default=100,
                        help="Beam size for the CDR generation modes")
    parser.add_argument("--excess_res", type=int, default=50,
                        help="Antigen residues kept on each side of the epitope when cropping")
    # `type=bool` would map every non-empty string (including "False") to True;
    # bool_flag parses true/false-style strings explicitly.
    parser.add_argument("--reporter", type=bool_flag, default=False)
    # model / output paths
    parser.add_argument("--model_path", type=str, default="/scratch/ssd004/scratch/benjami/AntiXLM/evaluation/checkpoint_rand.pkl", help="Model path")
    parser.add_argument("--output_path", type=str, default="evaluation/", help="Output path")

    # source language / target language
    parser.add_argument("--src_lang", type=str, default="ag", help="Source language")
    parser.add_argument("--tgt_lang", type=str, default="ab", help="Target language")

    parser.add_argument('-i', '--index', type=int, default=0)
    parser.add_argument('-t', '--tag', type=str, default='')
    parser.add_argument('-c', '--config', type=str, default='evaluation/configs/test/codesign_single.yml')
    parser.add_argument('-o', '--out_root', type=str, default='evaluation/results')

    return parser

class SabdabEntry:
    """
    A single SAbDab complex prepared for antibody sequence evaluation.

    Construction has side effects: it creates a new log directory under
    ``params.log_dir`` and writes several PDB files there. These are the
    reference complex, a copy with the antigen/heavy chains renamed to 'A'/'H',
    the antigen and antibody alone, and epitope-cropped variants. It also flags
    epitope residues with PyRosetta and precomputes per-CDR weight masks over
    the heavy-chain sequence.
    """

    def __init__(self, dataset, index, params, renumber=None) -> None:
        # NOTE: `renumber` is accepted for interface compatibility but is not used.
        self.structure = dataset[index]
        self.entry = self.find_entry(dataset, index)

        self.structure_id = self.structure['id']
        self.ag_name = self.entry['ag_name']
        self.ag_chain = self.entry['ag_chains'][0]  # only the first antigen chain is used
        self.ab_chain = self.entry['H_chain']
        self.pdb_code = self.entry['pdbcode']

        # Heavy-chain framework (FW1-FW4) and CDR (H1-H3) sub-sequences.
        self.f1 = self.structure['heavy']['FW1_seq']
        self.f2 = self.structure['heavy']['FW2_seq']
        self.f3 = self.structure['heavy']['FW3_seq']
        self.f4 = self.structure['heavy']['FW4_seq']

        self.c1 = self.structure['heavy']['H1_seq']
        self.c2 = self.structure['heavy']['H2_seq']
        self.c3 = self.structure['heavy']['H3_seq']

        # Antigen residues kept on each side of the epitope when cropping.
        self.excess = params.excess_res

        self.ab_seq = self.structure['heavy'].seq
        self.ag_seq = self.structure['antigen'].seq
        self.weights = {}
        self.set_weights()
        data_native = MergeChains()(self.structure)
        self.log_dir = get_new_log_dir(os.path.join(params.log_dir), prefix='%02d_%s' % (index, self.structure_id))
        save_pdb(data_native, os.path.join(self.log_dir, 'reference.pdb'))
        # Renamed copy (antigen -> 'A', heavy chain -> 'H'); this is the file PyRosetta loads below.
        save_pdb(data_native, os.path.join(self.log_dir, 'reference_renamed.pdb'),
                 rename={self.ag_chain: 'A', self.ab_chain: 'H'})
        pose = pose_from_pdb(os.path.join(self.log_dir, 'reference_renamed.pdb'))
        ab_info = antibody.AntibodyInfo(pose, antibody.Chothia_Scheme, antibody.North)

        # Try epitope cutoffs 5..24 passed to select_epitope_residues until more than
        # 20 antigen residues are flagged. If no cutoff reaches that, the selection
        # from the last cutoff is kept.
        # NOTE(review): the [len(self.ab_seq):] slice assumes the returned per-residue
        # flags list the heavy chain first and the antigen after it. Confirm against
        # the chain order save_pdb writes.
        for s in range(5, 25):
            self.epi_residues = np.array(antibody.select_epitope_residues(ab_info, pose, s))[len(self.ab_seq):]
            if self.epi_residues.sum() > 20:
                break
        # Inclusive (first, last) indices of flagged antigen residues. If nothing is
        # flagged, both argmax calls return 0 and the range spans the whole antigen.
        self.epi_range = (np.argmax(self.epi_residues), self.epi_residues.shape[0] - np.argmax(self.epi_residues[::-1]) - 1)
        self.epi_resseq = self.structure['antigen']['resseq'][self.epi_residues]

        # Residue window written to the cropped PDBs (epitope +/- self.excess).
        crop_start = max(0, self.epi_range[0] - self.excess)
        crop_end = self.epi_range[1] + self.excess
        save_pdb(data_native, os.path.join(self.log_dir, 'antigen.pdb'), ignore_chain=self.ab_chain)
        save_pdb(data_native, os.path.join(self.log_dir, 'antibody.pdb'), ignore_chain=self.ag_chain)
        save_pdb(data_native, os.path.join(self.log_dir, 'cutted_antigen.pdb'), ignore_chain=self.ab_chain,
                 write_range={self.ag_chain: (crop_start, crop_end)})
        # NOTE(review): write_range here is keyed by the renamed chain id 'A'. Confirm
        # that save_pdb applies write_range after renaming. The file name (including
        # its spelling) is kept as-is because other tooling may look for it.
        save_pdb(data_native, os.path.join(self.log_dir, 'cutted_refrence_renamed.pdb'),
                 write_range={'A': (crop_start, crop_end)},
                 rename={self.ag_chain: 'A', self.ab_chain: 'H'})

        self.identity = {}
        self.generated_sequences = {}

    @property
    def antigen(self):
        """
        Antigen sequence cropped around the epitope.

        Up to ``self.excess`` residues are added on each side of ``epi_range``,
        reduced so the crop stays around 240 residues. When the epitope span
        itself exceeds 240, ``extra`` becomes negative and the slice trims the
        epitope symmetrically instead.
        """
        start, end = self.epi_range
        extra = int(min(self.excess, (240 - (end - start)) / 2))
        return self.ag_seq[max(0, start - extra): end + extra]

    @property
    def antibody(self):
        """Full heavy-chain sequence."""
        return self.ab_seq

    def set_weights(self):
        """Fill ``self.weights`` with masks selecting CDR1, CDR2, CDR3 and all three CDRs."""
        self.weights['CDR1'] = self._construct_weight(cdr1=True)
        self.weights['CDR2'] = self._construct_weight(cdr2=True)
        self.weights['CDR3'] = self._construct_weight(cdr3=True)
        self.weights['CDR123'] = self._construct_weight(cdr1=True, cdr2=True, cdr3=True)

    def _construct_weight(self, cdr1=False, cdr2=False, cdr3=False):
        """
        Build a 0/1 list laid out as FW1, CDR1, FW2, CDR2, FW3, CDR3, FW4.

        Positions belonging to a selected CDR are 1 and all others are 0. The
        length is the sum of the region lengths. It presumably equals
        ``len(self.ab_seq)`` when the regions tile the heavy chain (TODO confirm).
        """
        regions = (
            (self.f1, False), (self.c1, cdr1),
            (self.f2, False), (self.c2, cdr2),
            (self.f3, False), (self.c3, cdr3),
            (self.f4, False),
        )
        weights = []
        for region_seq, selected in regions:
            weights.extend([1 if selected else 0] * len(region_seq))
        return weights

    def find_entry(self, dataset: 'SAbDabDataset', index):
        """
        Return the SAbDab metadata entry whose ``id`` matches ``self.structure['id']``.

        ``index`` is unused; the lookup is by structure id over
        ``dataset.sabdab_entries``.

        Raises:
            KeyError: if no entry matches. Failing here gives a clear message
                instead of a later ``TypeError`` on a ``None`` entry.
        """
        target_id = self.structure['id']
        for entry in dataset.sabdab_entries:
            if entry['id'] == target_id:
                return entry
        raise KeyError(f"No SAbDab entry found for structure id {target_id!r}")

    def write_generated(self):
        """
        Write ``self.generated_sequences`` as FASTA, one ``<log_dir>/<key>.fasta`` file per key.

        Records are named ``><key>_sequence<i>``, and spaces are stripped from
        each sequence.
        """
        for key, sequences in self.generated_sequences.items():
            with open(os.path.join(self.log_dir, key) + '.fasta', "w") as f:
                for i, seq in enumerate(sequences):
                    f.write(">{0}_sequence".format(key) + str(i) + "\n")
                    f.write(seq.replace(' ', '') + "\n")

core.init: Rosetta version: PyRosetta4.conda.linux.cxx11thread.serialization.CentOS.python39.Release r337 2022.49+release.201d763 201d7639f91f369d58b1adf514f3febaf6154c58 http://www.pyrosetta.org 2022-12-07T16:15:33
core.init: command: PyRosetta -use_input_sc -ignore_unrecognized_res -check_cdr_chainbreaks false -ignore_zero_occupancy false -load_PDB_components false -no_fconfig -database /h/benjami/.conda/envs/AntiXLM/lib/python3.9/site-packages/pyrosetta/database
basic.random.init_random_generator: 'RNG device' seed mode, using '/dev/urandom', seed=1850212661 seed_offset=0 real_seed=1850212661 thread_index=0
basic.random.init_random_generator: RandomGenerator:init: Normal mode, seed=1850212661 RG_type=mt19937


  from rosetta.core.select import residue_selector as selections


In [3]:
    

def strip_module_prefix(state_dict, prefix='module.'):
    """
    Return a copy of ``state_dict`` with ``prefix`` removed from keys that start with it.

    Keys without the prefix are kept unchanged instead of being truncated. The
    prefix presumably comes from a DataParallel/DDP wrapper at save time.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def get_model(params):
    """
    Load the encoder/decoder checkpoint at ``params.model_path`` plus ProtBert.

    Side effects on ``params``:
      * copies n_words and the bos/eos/pad/unk/mask indices from the checkpoint;
      * sets ``src_id`` / ``tgt_id`` from ``params.src_lang`` / ``params.tgt_lang``
        (falling back to 'ag' / 'ab' if those attributes are absent).

    Returns:
        (encoder, decoder, dico, bert): the models are moved to CUDA and put in eval mode.
    """
    # initialize the experiment
    logger = initialize_exp(params)

    # SECURITY: torch.load unpickles the checkpoint; only load trusted files.
    reloaded = torch.load(params.model_path)
    model_params = AttrDict(reloaded['params'])
    logger.info("Supported languages: %s" % ", ".join(model_params.lang2id.keys()))

    # update dictionary parameters
    for name in ['n_words', 'bos_index', 'eos_index', 'pad_index', 'unk_index', 'mask_index']:
        setattr(params, name, getattr(model_params, name))

    # build dictionary / build encoder / build decoder / reload weights
    dico = Dictionary(reloaded['dico_id2word'], reloaded['dico_word2id'], reloaded['dico_counts'])
    encoder = TransformerModel(model_params, dico, is_encoder=True, with_output=False).cuda().eval()
    decoder = TransformerModel(model_params, dico, is_encoder=False, with_output=True).cuda().eval()
    encoder.load_state_dict(strip_module_prefix(reloaded['encoder']))
    decoder.load_state_dict(strip_module_prefix(reloaded['decoder']))

    # Language ids follow the CLI options instead of hard-coded names.
    params.src_id = model_params.lang2id[getattr(params, 'src_lang', 'ag')]
    params.tgt_id = model_params.lang2id[getattr(params, 'tgt_lang', 'ab')]

    bert = BertModel.from_pretrained("Rostlab/prot_bert")
    bert.eval()
    bert = bert.cuda()
    return encoder, decoder, dico, bert


def get_sabdab(params):
    """
    Build the SAbDab test dataset from the YAML config at ``params.config``.

    ``load_config`` yields ``(config, config_name)``; only the
    ``config.dataset.test`` section is forwarded to ``get_dataset``.
    """
    cfg, _cfg_name = load_config(params.config)
    test_section = cfg.dataset.test
    return get_dataset(test_section)


def build_batch(seq, lang, eos, bos, pad):
    """
    Wrap one token-id tensor as a single-column batch laid out ``bos, *seq, eos``.

    Returns:
        batch:   LongTensor of shape (len(seq) + 2, 1)
        lengths: one-element LongTensor holding len(seq) + 2
        langs:   tensor shaped like ``batch`` filled with ``lang``
    """
    total = len(seq) + 2  # sequence plus the bos and eos slots
    lengths = torch.LongTensor([total])
    batch = torch.full((total, 1), pad, dtype=torch.long)
    batch[0, 0] = bos
    batch[1:total - 1, 0].copy_(seq)
    batch[total - 1, 0] = eos
    langs = torch.full_like(batch, lang)
    return batch, lengths, langs


def write_generated(log_dir, eval_step, sequences):
    """
    Write each sequence to its own single-record FASTA file.

    Files go into the directory returned by ``get_new_log_dir(<log_dir>/<eval_step>)``.
    Each file is named ``sequence<i>.fasta`` with header ``>sequence<i>``, and
    spaces are removed from the sequence before writing.
    """
    out_dir = get_new_log_dir(os.path.join(log_dir, eval_step))
    for idx, seq in enumerate(sequences):
        record = f"sequence{idx}"
        with open(os.path.join(out_dir, record + ".fasta"), "w") as handle:
            handle.write(f">{record}\n{seq.replace(' ', '')}\n")

def evaluate(encoder, decoder, dico, bert, params, aligner, sample:SabdabEntry,
             eval_modes=('CDR1', 'CDR2', 'CDR3', 'CDR123', 'GEN'), gen_beam_size=20):
    """
    Generate heavy-chain sequences for ``sample`` and score them against the native.

    For each mode in ``eval_modes``:
      * 'CDR1' / 'CDR2' / 'CDR3' / 'CDR123': beam search with ``params.beam_size``,
        ``cdr_generation=True`` and the matching mask from ``sample.weights``;
      * 'GEN': beam search with ``gen_beam_size`` and ``cdr_generation=False``.
        The mask falls back to ``sample.weights['CDR123']``.

    Hypotheses are written under ``<sample.log_dir>/<mode>`` and stored in
    ``sample.generated_sequences[mode]``. Identities are stored in
    ``sample.identity``; 'GEN' records both ``GEN_CDR`` and ``GEN_ALL``.

    Returns:
        ``sample.identity``, which is also printed.
    """
    def _encode(tokens, lang_id):
        # Map residues to dictionary ids and wrap them as a single-column batch.
        ids = torch.LongTensor([dico.index(w) for w in tokens])
        return build_batch(ids, lang_id, params.eos_index, params.bos_index, params.pad_index)

    # NOTE(review): the antigen is tagged with params.tgt_id (the antibody language);
    # params.src_id is set in get_model but never used. Confirm this matches how the
    # model was trained before switching it.
    ag_tensor, ag_length_tensor, ag_langs_tensor = _encode(sample.antigen, params.tgt_id)
    ab_tensor, ab_length_tensor, ab_langs_tensor = _encode(sample.antibody, params.tgt_id)
    token_type_ids = torch.zeros_like(ag_tensor)

    (ab_tensor, ab_length_tensor, ab_langs_tensor,
     ag_tensor, ag_length_tensor, ag_langs_tensor, token_type_ids) = to_cuda(
        ab_tensor, ab_length_tensor, ab_langs_tensor,
        ag_tensor, ag_length_tensor, ag_langs_tensor, token_type_ids)

    # Pure inference: no_grad stops bert/encoder/decoder from recording autograd graphs.
    with torch.no_grad():
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
            attention_mask = get_masks(ag_tensor.size()[0], ag_length_tensor, False)[0]
            bert_embed = bert(input_ids=ag_tensor.T, token_type_ids=token_type_ids.T,
                              attention_mask=attention_mask).last_hidden_state
            enc1 = encoder('fwd', x=ag_tensor, lengths=ag_length_tensor, langs=ag_langs_tensor,
                           causal=False, bert_embed=bert_embed)
            enc1 = enc1.transpose(0, 1)

        for eval_step in eval_modes:
            # Beam size is chosen per mode so that a 'GEN' step no longer leaks its
            # beam size into CDR modes that follow it.
            beam_size = gen_beam_size if eval_step == 'GEN' else params.beam_size
            step_weights = sample.weights.get(eval_step, sample.weights['CDR123'])
            # Mask framed like a sequence: bos slot 0, eos slot 1, padding 0.
            ab_weights_tensor, _, _ = build_batch(torch.LongTensor(step_weights), params.tgt_id, 1, 0, 0)
            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
                generated, lengths = decoder.generate_beam(
                    enc1, ag_length_tensor, params.tgt_id, beam_size=beam_size,
                    length_penalty=1,
                    early_stopping=False,
                    max_len=160,
                    bert_embed=bert_embed,
                    cdr_generation=eval_step != 'GEN',
                    tgt_frw=ab_tensor,
                    w=ab_weights_tensor,
                    open_end=False,
                )
                hypothesis_text = convert_to_text(generated, lengths, dico, params)
                batch_generate_identity, batch_generate_cdr_identity = calculate_identity(
                    aligner, [sample.antibody], hypothesis_text, 'ab', ab_weights_tensor, beam_size=beam_size)
            write_generated(sample.log_dir, eval_step, hypothesis_text)
            sample.generated_sequences[eval_step] = hypothesis_text

            if eval_step == 'GEN':
                sample.identity[eval_step + '_CDR'] = batch_generate_cdr_identity
                sample.identity[eval_step + '_ALL'] = batch_generate_identity
            else:
                sample.identity[eval_step] = batch_generate_cdr_identity
    print(sample.identity)
    return sample.identity

In [4]:
# Parse the default options (no CLI args inside a notebook) and create a dated results dir.
parser = get_parser()
params = parser.parse_args(args=[])
log = get_new_log_dir(os.path.join(params.out_root), date=True)
params.log_dir = log
# Test split described by the YAML file in params.config.
dataset = get_sabdab(params)
test_samples = []  # NOTE(review): not used by any cell shown in this notebook

# Loads the checkpoint and ProtBert onto CUDA; also sets params.src_id/tgt_id and token indices.
encoder, decoder, dico, bert = get_model(params)
# NOTE(review): the aligner setup below is disabled, and later cells pass aligner=None
# to evaluate(). Confirm that calculate_identity handles a None aligner.
# aligner = Bio.Align.PairwiseAligner()
# aligner.substitution_matrix = substitution_matrices.load("BLOSUM62")
# aligner.open_gap_score = -10
# aligner.extend_gap_score = -0.5    



INFO - 04/12/23 14:51:31 - 0:00:00 - beam_size: 100
                                     command: python /h/benjami/.conda/envs/AntiXLM/lib/python3.9/site-packages/ipykernel_launcher.py '-f' '/ssd003/home/benjami/.local/share/jupyter/runtime/kernel-32dceed6-a15c-4c04-9b75-8cbd78f35e43.json' --exp_id "123"
                                     config: evaluation/configs/test/codesign_single.yml
                                     dump_path: dumped/kir/123
                                     excess_res: 50
                                     exp_id: 123
                                     exp_name: kir
                                     index: 0
                                     log_dir: evaluation/results/2023_04_12__14_51_30
                                     model_path: /scratch/ssd004/scratch/benjami/AntiXLM/evaluation/checkpoint_rand.pkl
                                     out_root: evaluation/results
                                     output_path: evaluation/
         

In [5]:

# Prepare the first test complex (writes PDBs into a new log dir) and run CDR1-3 generation.
# NOTE(review): index is hard-coded to 0 (params.index / -i is not used), and aligner is None.
sample = SabdabEntry(dataset=dataset, index=0, params=params)

evaluate(encoder, decoder, dico, bert, params, None, sample, eval_modes=['CDR1', 'CDR2', 'CDR3'])


core.chemical.GlobalResidueTypeSet: Finished initializing fa_standard residue type set.  Created 985 residue types
core.chemical.GlobalResidueTypeSet: Total time to initialize 2.96259 seconds.
core.import_pose.import_pose: File 'evaluation/results/2023_04_12__14_51_30/00_4ffv_H_L_A_/reference_renamed.pdb' automatically determined to be of type PDB
core.conformation.Conformation: Found disulfide between residues 288 299
core.conformation.Conformation: current variant for 288 CYS
core.conformation.Conformation: current variant for 299 CYS
core.conformation.Conformation: current variant for 288 CYD
core.conformation.Conformation: current variant for 299 CYD
core.conformation.Conformation: Found disulfide between residues 345 357
core.conformation.Conformation: current variant for 345 CYS
core.conformation.Conformation: current variant for 357 CYS
core.conformation.Conformation: current variant for 345 CYD
core.conformation.Conformation: current variant for 357 CYD
core.conformation.Confor

In [None]:
# NOTE(review): verbatim duplicate of the previous cell. Each run creates another log
# directory. The output below (path ...14_50_39) predates the previous cell's run
# (...14_51_30), so it is stale. Consider deleting this cell.
sample = SabdabEntry(dataset=dataset, index=0, params=params)

evaluate(encoder, decoder, dico, bert, params, None, sample, eval_modes=['CDR1', 'CDR2', 'CDR3'])


core.import_pose.import_pose: File 'evaluation/results/2023_04_12__14_50_39/00_4ffv_H_L_A_/reference_renamed.pdb' automatically determined to be of type PDB
core.conformation.Conformation: Found disulfide between residues 22 96
core.conformation.Conformation: current variant for 22 CYS
core.conformation.Conformation: current variant for 96 CYS
core.conformation.Conformation: current variant for 22 CYD
core.conformation.Conformation: current variant for 96 CYD
core.conformation.Conformation: Found disulfide between residues 405 416
core.conformation.Conformation: current variant for 405 CYS
core.conformation.Conformation: current variant for 416 CYS
core.conformation.Conformation: current variant for 405 CYD
core.conformation.Conformation: current variant for 416 CYD
core.conformation.Conformation: Found disulfide between residues 462 474
core.conformation.Conformation: current variant for 462 CYS
core.conformation.Conformation: current variant for 474 CYS
core.conformation.Conformation

In [11]:
# Load a SabdabEntry pickled by an earlier run, for comparison with `sample`.
# NOTE(review): no cell in this notebook writes sample.pkl, and the timestamped path comes
# from a different run, so this fails on a fresh kernel/checkout. pickle.load can execute
# arbitrary code; only load trusted files.
with open('evaluation/results/2023_04_12__14_49_18/00_4ffv_H_L_A_/sample.pkl', 'rb') as f:
    sample2 = pickle.load(f)

In [19]:
# CDR3 hypotheses 10 and 0 compare equal (output: True): the beam search returned duplicates.
sample.generated_sequences['CDR3'][10] == sample.generated_sequences['CDR3'][0]

True

In [17]:
# Same check on the pickled earlier run: CDR3 hypotheses 0 and 1 are equal (output: True).
# NOTE(review): this cell ran before the one above (In [17] vs In [19]).
sample2.generated_sequences['CDR3'][0] == sample2.generated_sequences['CDR3'][1]

True