In [70]:
import csv
from dateutil import parser
from omegaconf import DictConfig, OmegaConf
import hydra
import os
from collections import namedtuple, defaultdict
from tqdm import tqdm
import pandas as pd
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils import data

from data import chemical
from data import residue_constants
from data import utils as du

# Old code

In [None]:
train_set = DistilledDataset(pdb_IDs, loader_pdb, loader_pdb_fixbb, pdb_dict,
                             compl_IDs, loader_complex, loader_complex_fixbb, compl_dict,
                             neg_IDs, loader_complex, neg_dict,
                             fb_IDs, loader_fb, loader_fb_fixbb, fb_dict,
                             homo, self.loader_param)

In [None]:
train_sampler = DistributedWeightedSampler(train_set, pdb_weights, compl_weights, neg_weights, fb_weights, p_seq2str,
                                           num_example_per_epoch=N_EXAMPLE_PER_EPOCH,
                                           num_replicas=world_size, rank=rank, fraction_fb=0.5, fraction_compl=0.25)

In [None]:
train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **LOAD_PARAM)

In [None]:
def set_data_loader_params(args):
    PARAMS = {
        "COMPL_LIST" : "%s/list.hetero.csv"%compl_dir,
        "HOMO_LIST" : "%s/list.homo.csv"%compl_dir,
        "NEGATIVE_LIST" : "%s/list.negative.csv"%compl_dir,
        "PDB_LIST"   : "%s/list_v02.csv"%base_dir,
        #"PDB_LIST"    : "/gscratch2/list_2021AUG02.csv",
        "FB_LIST"    : "%s/list_b1-3.csv"%fb_dir,
        "VAL_PDB"    : "%s/val/xaa"%base_dir,
        #"VAL_PDB"   : "/gscratch2/PDB_val/xaa",
        "VAL_COMPL"  : "%s/val_lists/xaa"%compl_dir,
        "VAL_NEG"    : "%s/val_lists/xaa.neg"%compl_dir,
        "DATAPKL"    : "./dataset.pkl", # cache for faster loading
        "PDB_DIR"    : base_dir,
        "FB_DIR"     : fb_dir,
        "COMPL_DIR"  : compl_dir,
        "MINTPLT" : 0,
        "MAXTPLT" : 5,
        "MINSEQ"  : 1,
        "MAXSEQ"  : 1024,
        "MAXLAT"  : 128, 
        "CROP"    : 256,
        "DATCUT"  : "2020-Apr-30",
        "RESCUT"  : 5.0,
        "BLOCKCUT": 5,
        "PLDDTCUT": 70.0,
        "SCCUT"   : 90.0,
        "ROWS"    : 1,
        "SEQID"   : 95.0,
        "MAXCYCLE": 4,
        "HAL_MASK_HIGH": 35, #added by JW for hal masking
        "HAL_MASK_LOW": 10,
        "HAL_MASK_HIGH_AR": 50,
        "HAL_MASK_LOW_AR": 20,
        "COMPLEX_HAL_MASK_HIGH": 35,
        "COMPLEX_HAL_MASK_LOW": 10,
        "COMPLEX_HAL_MASK_HIGH_AR": 50,
        "COMPLEX_HAL_MASK_LOW_AR": 20,
        "FLANK_HIGH": 6,
        "FLANK_LOW" : 3,
        "STR2SEQ_FULL_LOW" : 0.9,
        "STR2SEQ_FULL_HIGH" : 1.0,
        "MAX_LENGTH" : 260,
        "MAX_COMPLEX_CHAIN" : 200,
        "TASK_NAMES" : ['seq2str'],
        "TASK_P" : [1.0],
        "DIFF_MASK_LOW":args.diff_mask_low,
        "DIFF_MASK_HIGH":args.diff_mask_high,

    }
    for param in PARAMS:
        if hasattr(args, param.lower()):
            PARAMS[param] = getattr(args, param.lower())

    print('This is params from get train valid')
    for key,val in PARAMS.items():
        print(key, val)
    return PARAMS

In [None]:
def get_train_valid_set(params, OFFSET=1000000):

    if (not os.path.exists(params['DATAPKL'])):
        # read validation IDs for PDB set
        val_pdb_ids = set([int(l) for l in open(params['VAL_PDB']).readlines()])
        val_compl_ids = set([int(l) for l in open(params['VAL_COMPL']).readlines()])
        val_neg_ids = set([int(l)+OFFSET for l in open(params['VAL_NEG']).readlines()])


        # read validation IDs for PDB set
        val_pdb_ids = set([int(l) for l in open(params['VAL_PDB']).readlines()])
        val_compl_ids = set([int(l) for l in open(params['VAL_COMPL']).readlines()])
        val_neg_ids = set([int(l)+OFFSET for l in open(params['VAL_NEG']).readlines()])
    
        # read homo-oligomer list
        homo = {}
        # with open(params['HOMO_LIST'], 'r') as f:
        #     reader = csv.reader(f)
        #     next(reader)
        #     # read pdbA, pdbB, bioA, opA, bioB, opB
        #     rows = [[r[0], r[1], int(r[2]), int(r[3]), int(r[4]), int(r[5])] for r in reader]
        # for r in rows:
        #     if r[0] in homo.keys():
        #         homo[r[0]].append(r[1:])
        #     else:
        #         homo[r[0]] = [r[1:]]

        # read & clean list.csv
        with open(params['PDB_LIST'], 'r') as f:
            reader = csv.reader(f)
            next(reader)
            rows = [[r[0],r[3],int(r[4]), int(r[-1].strip())] for r in reader
                    if float(r[2])<=params['RESCUT'] and
                    parser.parse(r[1])<=parser.parse(params['DATCUT']) and len(r[-2]) <= params['MAX_LENGTH'] and len(r[-2]) >= 60] #added length max so only have full chains, and minimum length of 60aa

        # compile training and validation sets
        val_hash = list()
        train_pdb = {}
        valid_pdb = {}
        valid_homo = {}
        for r in rows:
            if r[2] in val_pdb_ids:
                val_hash.append(r[1])
                if r[2] in valid_pdb.keys():
                    valid_pdb[r[2]].append((r[:2], r[-1]))
                else:
                    valid_pdb[r[2]] = [(r[:2], r[-1])]
                #
                if r[0] in homo:
                    if r[2] in valid_homo.keys():
                        valid_homo[r[2]].append((r[:2], r[-1]))
                    else:
                        valid_homo[r[2]] = [(r[:2], r[-1])]
            else:
                if r[2] in train_pdb.keys():
                    train_pdb[r[2]].append((r[:2], r[-1]))
                else:
                    train_pdb[r[2]] = [(r[:2], r[-1])]
        val_hash = set(val_hash)
        
        # compile facebook model sets
        with open(params['FB_LIST'], 'r') as f:
            reader = csv.reader(f)
            next(reader)
            rows = [[r[0],r[2],int(r[3]),len(r[-1].strip())] for r in reader
                     if float(r[1]) > 80.0 and
                     len(r[-1].strip()) > 100 and len(r[-1].strip()) <= params['MAX_LENGTH']] #added max length to allow only full chains. Also reduced minimum length to 100aa
        
        fb = {}
        
        for r in rows:
            if r[2] in fb.keys():
                fb[r[2]].append((r[:2], r[-1]))
            else:
                fb[r[2]] = [(r[:2], r[-1])]
        
        #compile complex sets
        with open(params['COMPL_LIST'], 'r') as f:
            reader = csv.reader(f)
            next(reader)
            # read complex_pdb, pMSA_hash, complex_cluster, length, taxID, assembly (bioA,opA,bioB,opB)
            rows = [[r[0], r[3], int(r[4]), [int(plen) for plen in r[5].split(':')], r[6] , [int(r[7]), int(r[8]), int(r[9]), int(r[10])]] for r in reader
                     if float(r[2]) <= params['RESCUT'] and
                     parser.parse(r[1]) <= parser.parse(params['DATCUT']) and min([int(i) for i in r[5].split(":")]) < params['MAX_COMPLEX_CHAIN'] and min([int(i) for i in r[5].split(":")]) > 50] #require one chain of the hetero complexes to be smaller than a certain value so it can be kept complete. This chain must also be > 50aa long.

        train_compl = {}
        valid_compl = {}
        for r in rows:
            if r[2] in val_compl_ids:
                if r[2] in valid_compl.keys():
                    valid_compl[r[2]].append((r[:2], r[-3], r[-2], r[-1])) # ((pdb, hash), length, taxID, assembly, negative?)
                else:
                    valid_compl[r[2]] = [(r[:2], r[-3], r[-2], r[-1])]
            else:
                # if subunits are included in PDB validation set, exclude them from training
                hashA, hashB = r[1].split('_')
                if hashA in val_hash:
                    continue
                if hashB in val_hash:
                    continue
                if r[2] in train_compl.keys():
                    train_compl[r[2]].append((r[:2], r[-3], r[-2], r[-1]))
                else:
                    train_compl[r[2]] = [(r[:2], r[-3], r[-2], r[-1])]

        # compile negative examples
        # remove pairs if any of the subunits are included in validation set
        # with open(params['NEGATIVE_LIST'], 'r') as f:
        #     reader = csv.reader(f)
        #     next(reader)
        #     # read complex_pdb, pMSA_hash, complex_cluster, length, taxonomy
        #     rows = [[r[0],r[3],OFFSET+int(r[4]),[int(plen) for plen in r[5].split(':')],r[6]] for r in reader
        #             if float(r[2])<=params['RESCUT'] and
        #             parser.parse(r[1])<=parser.parse(params['DATCUT'])]

        train_neg = {}
        valid_neg = {}
        # for r in rows:
        #     if r[2] in val_neg_ids:
        #         if r[2] in valid_neg.keys():
        #             valid_neg[r[2]].append((r[:2], r[-2], r[-1], []))
        #         else:
        #             valid_neg[r[2]] = [(r[:2], r[-2], r[-1], [])]
        #     else:
        #         hashA, hashB = r[1].split('_')
        #         if hashA in val_hash:
        #             continue
        #         if hashB in val_hash:
        #             continue
        #         if r[2] in train_neg.keys():
        #             train_neg[r[2]].append((r[:2], r[-2], r[-1], []))
        #         else:
        #             train_neg[r[2]] = [(r[:2], r[-2], r[-1], [])]
    
        # Get average chain length in each cluster and calculate weights
        pdb_IDs = list(train_pdb.keys())
        fb_IDs = list(fb.keys())
        compl_IDs = list(train_compl.keys())
        neg_IDs = list(train_neg.keys())

        pdb_weights = list()
        fb_weights = list()
        compl_weights = list()
        neg_weights = list()
        for key in pdb_IDs:
            plen = sum([plen for _, plen in train_pdb[key]]) // len(train_pdb[key])
            w = (1/512.)*max(min(float(plen),512.),256.)
            pdb_weights.append(w)
    
        for key in fb_IDs:
            plen = sum([plen for _, plen in fb[key]]) // len(fb[key])
            w = (1/512.)*max(min(float(plen),512.),256.)
            fb_weights.append(w)
    
        for key in compl_IDs:
            plen = sum([sum(plen) for _, plen, _, _ in train_compl[key]]) // len(train_compl[key])
            w = (1/512.)*max(min(float(plen),512.),256.)
            compl_weights.append(w)
    
        for key in neg_IDs:
            plen = sum([sum(plen) for _, plen, _, _ in train_neg[key]]) // len(train_neg[key])
            w = (1/512.)*max(min(float(plen),512.),256.)
            neg_weights.append(w)
        # save
        obj = (
           pdb_IDs, pdb_weights, train_pdb,
           fb_IDs, fb_weights, fb,
           compl_IDs, compl_weights, train_compl,
           neg_IDs, neg_weights, train_neg,
           valid_pdb, valid_homo, valid_compl, valid_neg, homo
        )
        with open(params["DATAPKL"], "wb") as f:
            print ('Writing',params["DATAPKL"])
            pickle.dump(obj, f)
            print ('Done')

    else:
        with open(params["DATAPKL"], "rb") as f:
            print ('Loading',params["DATAPKL"])
            (
               pdb_IDs, pdb_weights, train_pdb,
               fb_IDs, fb_weights, fb,
               compl_IDs, compl_weights, train_compl,
               neg_IDs, neg_weights, train_neg,
               valid_pdb, valid_homo, valid_compl, valid_neg, homo
            ) = pickle.load(f)
            print ('Done')

    return (pdb_IDs, torch.tensor(pdb_weights).float(), train_pdb), \
           (fb_IDs, torch.tensor(fb_weights).float(), fb), \
           (compl_IDs, torch.tensor(compl_weights).float(), train_compl), \
           (neg_IDs, torch.tensor(neg_weights).float(), train_neg),\
           valid_pdb, valid_homo, valid_compl, valid_neg, homo

In [None]:
base_dir = "/projects/ml/TrRosetta/PDB-2021AUG02"
compl_dir = "/projects/ml/RoseTTAComplex"
fb_dir = "/projects/ml/TrRosetta/fb_af"
if not os.path.exists(base_dir):
    # training on AWS
    base_dir = "/data/databases/PDB-2021AUG02"
    fb_dir = "/data/databases/fb_af"
    compl_dir = "/data/databases/RoseTTAComplex"

In [None]:
train_set = DistilledDataset(pdb_IDs, loader_pdb, loader_pdb_fixbb, pdb_dict,
                             compl_IDs, loader_complex, loader_complex_fixbb, compl_dict,
                             neg_IDs, loader_complex, neg_dict,
                             fb_IDs, loader_fb, loader_fb_fixbb, fb_dict,
                             homo, self.loader_param)

valid_pdb_set = Dataset(list(valid_pdb.keys())[:self.n_valid_pdb],
                        loader_pdb, valid_pdb,
                        self.loader_param, homo, p_homo_cut=-1.0)

In [None]:
class DistilledDataset(data.Dataset):
    def __init__(self,
                 pdb_IDs,
                 pdb_loader,
                 pdb_loader_fixbb,
                 pdb_dict,
                 compl_IDs,
                 compl_loader,
                 compl_loader_fixbb,
                 compl_dict,
                 neg_IDs,
                 neg_loader,
                 neg_dict,
                 fb_IDs,
                 fb_loader,
                 fb_loader_fixbb,
                 fb_dict,
                 homo,
                 params,
                 p_homo_cut=0.5):
        #
        self.pdb_IDs     = pdb_IDs
        self.pdb_dict    = pdb_dict
        self.pdb_loaders = {'seq2str':      pdb_loader,
                            'str2seq':      pdb_loader_fixbb, 
                            'str2seq_full': pdb_loader_fixbb, 
                            'hal':          pdb_loader_fixbb, 
                            'hal_ar':       pdb_loader_fixbb,
                            'diff':         pdb_loader_fixbb}


        self.compl_IDs     = compl_IDs
        self.compl_dict    = compl_dict
        self.compl_loaders = {'seq2str':     compl_loader,
                              'str2seq':     compl_loader_fixbb, 
                              'str2seq_full':compl_loader_fixbb, 
                              'hal':         compl_loader_fixbb, 
                              'hal_ar':      compl_loader_fixbb,
                              'diff':        compl_loader_fixbb}


        self.neg_IDs    = neg_IDs
        self.neg_loader = neg_loader
        self.neg_dict   = neg_dict


        self.fb_IDs     = fb_IDs
        self.fb_dict    = fb_dict
        self.fb_loaders = { 'seq2str':      fb_loader,
                            'str2seq':      fb_loader_fixbb, 
                            'str2seq_full': fb_loader_fixbb, 
                            'hal':          fb_loader_fixbb, 
                            'hal_ar':       fb_loader_fixbb,
                            'diff':         fb_loader_fixbb}

        self.homo = homo
        self.params = params
        self.p_task = params['TASK_P']
        self.task_names = params['TASK_NAMES']
        self.unclamp_cut = 0.9
        self.p_homo_cut = p_homo_cut
        
        self.compl_inds = np.arange(len(self.compl_IDs))
        self.neg_inds = np.arange(len(self.neg_IDs))
        self.fb_inds = np.arange(len(self.fb_IDs))
        self.pdb_inds = np.arange(len(self.pdb_IDs))
    
    def __len__(self):
        return len(self.fb_inds) + len(self.pdb_inds) + len(self.compl_inds) + len(self.neg_inds)

    def __getitem__(self, index):
        p_unclamp = np.random.rand()

        #choose task if not from negative set (which is always seq2str)
        if index < len(self.fb_inds) + len(self.pdb_inds) + len(self.compl_inds):
            task_idx = np.random.choice(np.arange(len(self.task_names)), 1, p=self.p_task)[0]
            task = self.task_names[task_idx]
        else:
            task = 'seq2str'

        if index >= len(self.fb_inds) + len(self.pdb_inds) + len(self.compl_inds): # from negative set
            # print('Chose negative')
            chosen_dataset='negative'
            ID = self.neg_IDs[index-len(self.fb_inds)-len(self.pdb_inds)-len(self.compl_inds)]
            sel_idx = np.random.randint(0, len(self.neg_dict[ID]))
            out = self.neg_loader(self.neg_dict[ID][sel_idx][0], self.neg_dict[ID][sel_idx][1], self.neg_dict[ID][sel_idx][2], self.neg_dict[ID][sel_idx][3], self.params, negative=True)

        elif index >= len(self.fb_inds) + len(self.pdb_inds): # from complex set
            chosen_dataset='complex'
            # print('Chose complex')
            ID = self.compl_IDs[index-len(self.fb_inds)-len(self.pdb_inds)]
            sel_idx = np.random.randint(0, len(self.compl_dict[ID]))
            if self.compl_dict[ID][sel_idx][0][0] in self.homo: #homooligomer so do seq2str
                task='seq2str'
            chosen_loader = self.compl_loaders[task]
            out = chosen_loader(self.compl_dict[ID][sel_idx][0], self.compl_dict[ID][sel_idx][1],self.compl_dict[ID][sel_idx][2], self.compl_dict[ID][sel_idx][3], self.params, negative=False)
        
        elif index >= len(self.fb_inds): # from PDB set
            chosen_dataset='pdb'
            # print('Chose pdb')
            ID = self.pdb_IDs[index-len(self.fb_inds)]
            sel_idx = np.random.randint(0, len(self.pdb_dict[ID]))
            chosen_loader = self.pdb_loaders[task]
            if p_unclamp > self.unclamp_cut:
                out = chosen_loader(self.pdb_dict[ID][sel_idx][0], self.params, self.homo, unclamp=True, p_homo_cut=self.p_homo_cut)
            else:
                out = chosen_loader(self.pdb_dict[ID][sel_idx][0], self.params, self.homo, unclamp=False, p_homo_cut=self.p_homo_cut)
        else: # from FB set
            chosen_dataset='fb'
            # print('Chose fb')
            ID = self.fb_IDs[index]
            sel_idx = np.random.randint(0, len(self.fb_dict[ID]))
            chosen_loader = self.fb_loaders[task]
            if p_unclamp > self.unclamp_cut:
                out = chosen_loader(self.fb_dict[ID][sel_idx][0], self.params, unclamp=True)
            else:
                out = chosen_loader(self.fb_dict[ID][sel_idx][0], self.params, unclamp=False)

        (seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, complete_chain, atom_mask) = out

        # get masks for example
        # Commenting this out. Because of popping Nans, need a new way of dealing with complexes, which will break the inpainting-style task mask generators.
        """
        if complete_chain[0] == 0:
            if complete_chain[1] is not None:
                complete_chain = [0,complete_chain[1]-1] #first and last index of the first chain (complete_chain[1] is length of first chain)
            else:
                complete_chain = [0, same_chain.shape[0] -1] #first and last index of full chain
        else:
            complete_chain=[complete_chain[1],same_chain.shape[0]-1]
        """
        if complete_chain[1] is not None:
            chain_tensor, contacts = get_contacts(complete_chain, xyz_t)
        else:
            # make tensor of zeros to stable onto t1d for monomers (i.e. no contacts)
            contacts = torch.zeros(xyz_t.shape[1])
            
        #### DJ/JW alteration: Pop any NaN residues from tensors for diffuion training 
        # print('Printing shapes')
        # print("seq ",seq.shape)
        # print("msa ",msa.shape)
        # print("msa_masked ",msa_masked.shape)
        # print("msa_full ",msa_full.shape)
        # print("mask_msa ",mask_msa.shape)
        # print("true_crds ",true_crds.shape)
        # print("atom_mask ",atom_mask.shape )
        # print("idx_pdb ",idx_pdb.shape)
        # print("xyz_t ",xyz_t.shape)
        # print("t1d ",t1d.shape)
        # print("xyz_prev ",xyz_prev.shape)
        # print("unclamp ",unclamp)
        # print("atom_mask ",atom_mask.shape)
        # print('same chain ',same_chain.shape)
        pop   = (atom_mask[:,:3]).squeeze().any(dim=-1) # will be true if any of the backbone atoms were False in atom mask 
        N     = pop.sum()
        pop2d = pop[None,:] * pop[:,None]

        seq         = seq[:,pop]
        msa         = msa[:,:,pop]
        msa_masked  = msa_masked[:,:,pop]
        msa_full    = msa_full[:,:,pop]
        mask_msa    = mask_msa[:,:,pop]
        true_crds   = true_crds[pop]
        atom_mask   = atom_mask[pop]
        idx_pdb     = idx_pdb[pop]
        xyz_t       = xyz_t[:,pop]
        t1d         = t1d[:,pop]
        xyz_prev    = xyz_prev[pop]
        same_chain  = same_chain[pop2d].reshape(N,N)
        contacts = contacts[pop]

        if complete_chain[1] is not None:
            complete_chain = chain_tensor[pop]
        
        #Concatenate on the contacts tensor onto t1d
        t1d = torch.cat((t1d, contacts[None,...,None]), dim=-1)
        if chosen_dataset != 'complex':
            assert torch.sum(t1d[:,:,-1]) == 0 
        # print('Printing shapes after popping')
        # print("seq ",seq.shape)
        # print("msa ",msa.shape)
        # print("msa_masked ",msa_masked.shape)
        # print("msa_full ",msa_full.shape)
        # print("mask_msa ",mask_msa.shape)
        # print("true_crds ",true_crds.shape)
        # print("atom_mask ",atom_mask.shape )
        # print("idx_pdb ",idx_pdb.shape)
        # print("xyz_t ",xyz_t.shape)
        # print("t1d ",t1d.shape)
        # print("xyz_prev ",xyz_prev.shape)
        # print("unclamp ",unclamp)
        # print("atom_mask ",atom_mask.shape)
        # print('same chain ',same_chain.shape)
        mask_dict = generate_masks(msa, task, self.params, chosen_dataset, complete_chain)

        return seq, msa, msa_masked, msa_full, mask_msa, true_crds, atom_mask, idx_pdb, xyz_t, t1d, xyz_prev, same_chain, unclamp, negative, mask_dict, task, chosen_dataset, atom_mask


## Sandbox

In [None]:
with open(data_conf.pdb_csv, 'r') as f:
    reader = csv.reader(f)
    next(reader)
    rows = [[r[0],r[3],int(r[4]), int(r[-1].strip())] for r in reader
            if float(r[2])<=data_conf.resolution_cutoff and
            parser.parse(r[1])<=parser.parse(data_conf.date_cutoff) and len(r[-2]) <= data_conf.max_len and len(r[-2]) >= 60]


In [None]:
val_hash = list()
train_pdb = {}
valid_pdb = {}
valid_homo = {}
val_clusters = set([int(l) for l in open(data_conf.valid_pdb_dir).readlines()])
for r in rows:
    if r[2] in val_clusters:
        val_hash.append(r[1])
        if r[2] in valid_pdb.keys():
            valid_pdb[r[2]].append((r[:2], r[-1]))
        else:
            valid_pdb[r[2]] = [(r[:2], r[-1])]
    else:
        if r[2] in train_pdb.keys():
            train_pdb[r[2]].append((r[:2], r[-1]))
        else:
            train_pdb[r[2]] = [(r[:2], r[-1])]

# Data config

In [None]:
dir(OmegaConf)

In [73]:
os_dir = '/projects/ml/TrRosetta'
base_dir = os.path.join(os_dir, 'PDB-2021AUG02')
base_fb_dir = os.path.join(os_dir, 'fb_af')

data_conf = OmegaConf.create({
    'resolution_cutoff': 5.0,
    'base_dir': base_dir,
    'base_fb_dir': base_fb_dir,
    'max_len': 260,
    'min_len': 60,
    'date_cutoff': "2020-Apr-30",
    'pdb_csv': os.path.join(base_dir, 'list_v02.csv'),
    'pdb_dir': os.path.join(base_dir, 'torch/pdb'),
    'fb_dir': os.path.join(base_fb_dir, 'pdb'),
    'valid_clusters': os.path.join(base_dir, 'val/xaa'),
    'fb_csv': os.path.join(base_fb_dir, 'list_b1-3.csv'),
    'min_plddt': 80.0,
    'min_fb_len': 100,
    'cache_path': './pkl_jar/dataset.pkl'
})

diffusion_conf = OmegaConf.create({
    
})


In [80]:
print(OmegaConf.to_yaml(data_conf))

resolution_cutoff: 5.0
base_dir: /projects/ml/TrRosetta/PDB-2021AUG02
base_fb_dir: /projects/ml/TrRosetta/fb_af
max_len: 260
min_len: 60
date_cutoff: 2020-Apr-30
pdb_csv: /projects/ml/TrRosetta/PDB-2021AUG02/list_v02.csv
pdb_dir: /projects/ml/TrRosetta/PDB-2021AUG02/torch/pdb
fb_dir: /projects/ml/TrRosetta/fb_af/pdb
valid_clusters: /projects/ml/TrRosetta/PDB-2021AUG02/val/xaa
fb_csv: /projects/ml/TrRosetta/fb_af/list_b1-3.csv
min_plddt: 80.0
min_fb_len: 100
cache_path: ./pkl_jar/dataset.pkl



In [76]:
OmegaConf.to_yaml(data_conf) == OmegaConf.to_yaml(data_conf)

True

# Get files

In [4]:
def _parse_pdb_csv(*, csv_path, max_len, min_len, res_cutoff, date_cutoff):
    """Parse and filter PDB metadata csv."""
    raw_pdb_csv = pd.read_csv(data_conf.pdb_csv)
    raw_pdb_csv['seq_len'] = raw_pdb_csv.SEQUENCE.apply(lambda x: len(x))
    raw_pdb_csv['resolution'] = raw_pdb_csv.RESOLUTION.apply(lambda x: float(x))
    raw_pdb_csv['date'] = raw_pdb_csv.DEPOSITION.apply(lambda x: parser.parse(x))
    raw_pdb_csv['source'] = 'pdb'
    date_cutoff = parser.parse(date_cutoff)
    return raw_pdb_csv[
        (raw_pdb_csv.seq_len <= max_len) &
        (min_len <= raw_pdb_csv.seq_len) &
        (raw_pdb_csv.resolution <= res_cutoff) &
        (raw_pdb_csv.date <= date_cutoff)
    ]

def _parse_fb_csv(*, csv_path, max_len, min_len, min_plddt):
    """Parse and filter FB metadata csv."""
    raw_fb_csv = pd.read_csv(csv_path).rename(
        {'#CHAINID': 'CHAINID'}, axis='columns')
    raw_fb_csv['seq_len'] = raw_fb_csv.SEQUENCE.apply(lambda x: len(x))
    raw_fb_csv['source'] = 'fb'
    return raw_fb_csv[
        (raw_fb_csv.plDDT >= min_plddt) &
        (raw_fb_csv.seq_len <= max_len) &
        (min_len <= raw_fb_csv.seq_len)
    ]

def _parse_clusters(df):
    processed_clusters = {}
    for cluster_id, cluster_df in tqdm(df.groupby('CLUSTER')):
        num_seqs = len(cluster_df)
        avg_len = cluster_df.seq_len.sum() // num_seqs
        weight = (1 / 512.) * max(min(float(avg_len), 512.), 256.)
        membership = cluster_df.CHAINID.tolist()
        processed_clusters[cluster_id] = (
            cluster_id,
            num_seqs,
            avg_len,
            weight,
            membership
        )
    return processed_clusters

In [None]:
# Parse PDB data into train and validation.
pdb_csv = _parse_pdb_csv(
    csv_path=data_conf.pdb_csv,
    max_len=data_conf.max_len,
    min_len=data_conf.min_len,
    res_cutoff=data_conf.resolution_cutoff,
    date_cutoff=data_conf.date_cutoff,
)

fb_csv = _parse_fb_csv(
    csv_path=data_conf.fb_csv,
    max_len=data_conf.max_len,
    min_len=data_conf.min_fb_len,
    min_plddt=data_conf.min_plddt,
)

# Assign splits
val_clusters = set([int(l) for l in open(data_conf.valid_clusters).readlines()])
pdb_csv['split'] = pdb_csv.CLUSTER.apply(lambda x: 'valid' if x in val_clusters else 'train')
train_pdb_csv = pdb_csv[pdb_csv.split == 'train']
valid_pdb_csv = pdb_csv[pdb_csv.split == 'valid']
fb_csv['split'] = fb_csv.CLUSTER.apply(lambda x: 'valid' if x in val_clusters else 'train')
train_fb_csv = fb_csv[fb_csv.split == 'train']

In [6]:
# Parse train clusters
train_csv = pd.concat([pdb_csv, fb_csv])
train_clusters = _parse_clusters(train_csv)

100%|██████████| 1550590/1550590 [05:04<00:00, 5084.91it/s]


In [85]:
%%time
train_csv[train_csv.CHAINID == '5naj_A']

CPU times: user 251 ms, sys: 0 ns, total: 251 ms
Wall time: 254 ms


Unnamed: 0,CHAINID,DEPOSITION,RESOLUTION,HASH,CLUSTER,SEQUENCE,LEN_EXIST,seq_len,resolution,date,source,split,plDDT
9,5naj_A,2017-02-28,1.46,30830,22021,GSMSEQSICQARAAVMVYDDANKKWVPAGGSTGFSRVHIYHHTGNN...,110.0,113,1.46,2017-02-28,pdb,train,


In [87]:
%%time


CPU times: user 6 µs, sys: 1e+03 ns, total: 7 µs
Wall time: 9.78 µs


Int64Index([      9,      10,      11,      12,      15,      16,      17,
                 20,      21,      29,
            ...
            7597978, 7597979, 7597980, 7597981, 7597982, 7597985, 7597986,
            7597988, 7597989, 7597994],
           dtype='int64', length=4334387)

In [9]:
precomputed_data = (
    data_conf,
    train_csv,
    train_clusters,
    valid_pdb_csv,
)
with open(data_conf.cache_path, "wb") as f:
    pickle.dump(precomputed_data, f)
    print ('Done')

Done


# Featurize

In [7]:
with open(data_conf.cache_path, "rb") as f:
    (
        data_conf,
        train_csv,
        train_clusters,
        valid_pdb_csv,
    ) = pickle.load(f)


In [66]:
def load_pdb(chain_id):
    parsed_pdb = torch.load(
        os.path.join(data_conf.pdb_dir, chain_id[1:3], f'{chain_id}.pt'))
    xyz = parsed_pdb['xyz']
    num_res = len(xyz)
    mask = parsed_pdb['mask'].long()
    res_idx = torch.arange(num_res).long()
    xyz = torch.nan_to_num(xyz).float()
    seq = torch.tensor([residue_constants.restype_order_with_x[i] for i in parsed_pdb['seq']]).long()    
    # TODO: Remove beginning and ending tags
    res_mask = torch.any(mask, axis=-1)
    feats = {
        'xyz': xyz,
        'mask': mask,
        'res_mask': res_mask,
        'seq': seq,
        'res_idx': res_idx,
    }
    return chain_id, feats

def load_fb(chain_id, chain_hash):
    file_path = os.path.join(
        data_conf.fb_dir, chain_hash[:2], chain_hash[2:], chain_id)
    xyz, mask, res_idx, seq = parse_pdb(file_path+'.pdb')
    plddt = np.load(file_path+'.plddt.npy')
    # TODO: Add back sidechain pLDDT masking
    mask = np.logical_and(mask, (plddt > data_conf.min_plddt)[:,None])
    seq = np.array([residue_constants.restype_order_with_x[i] for i in seq])
    # TODO: Remove beginning and ending tags
    res_mask = np.any(mask, axis=-1)
    feats = {
        'xyz': torch.tensor(xyz).float(),
        'mask': torch.tensor(mask).long(),
        'res_mask': torch.tensor(res_mask).long(),
        'seq': torch.tensor(seq).long(),
        'res_idx': torch.tensor(res_idx).long(),
    }
    return os.path.basename(file_path), feats

def parse_pdb(filename):
    lines = open(filename,'r').readlines()
    return parse_pdb_lines(lines)

def parse_pdb_lines(lines):

    # indices of residues observed in the structure
    idx_s = [int(l[22:26]) for l in lines if l[:4]=="ATOM" and l[12:16].strip()=="CA"]

    # 4 BB + up to 10 SC atoms
    xyz = np.full((len(idx_s), 14, 3), np.nan, dtype=np.float32)
    seq = []
    for l in lines:
        if l[:4] != "ATOM":
            continue
        resNo, atom, aa = int(l[22:26]), l[12:16], l[17:20]
        seq.append(residue_constants.restype_3to1[aa])
        idx = idx_s.index(resNo)
        for i_atm, tgtatm in enumerate(chemical.aa2long[chemical.aa2num[aa]]):
            if tgtatm == atom:
                xyz[idx,i_atm,:] = [float(l[30:38]), float(l[38:46]), float(l[46:54])]
                break

    # save atom mask
    mask = np.logical_not(np.isnan(xyz[...,0]))
    xyz[np.isnan(xyz[...,0])] = 0.0

    return xyz, mask, np.array(idx_s), ''.join(seq)

In [67]:
pdb_id, pdb_data = load_pdb('5naj_A')
fb_id, fb_data = load_fb('UniRef50_A0A1N6PI58', 'c87b')

In [None]:
torch.mean(pdb_data['res_mask'].float())

In [69]:
torch.mean(fb_data['res_mask'].float())

tensor(0.7842)

In [97]:
train_csv.sample()

Unnamed: 0,CHAINID,DEPOSITION,RESOLUTION,HASH,CLUSTER,SEQUENCE,LEN_EXIST,seq_len,resolution,date,source,split,plDDT
6977295,UniRef50_A0A1H3MGD0,,,dda8,2660271,MKNIRFYEAEKYNSDDYEKVEDMIYMPHHDPSEQNIIYVTSIIYEP...,,128,,NaT,fb,train,94.29


In [14]:
pdb_weights, pdb_splits, pdb_sizes = _calc_cluster_weights(pdb_csv)
fb_weights, fb_splits, fb_sizes = _calc_cluster_weights(fb_csv)

In [None]:
train_set = DistilledDataset(pdb_IDs, loader_pdb, loader_pdb_fixbb, pdb_dict,
                             compl_IDs, loader_complex, loader_complex_fixbb, compl_dict,
                             neg_IDs, loader_complex, neg_dict,
                             fb_IDs, loader_fb, loader_fb_fixbb, fb_dict,
                             homo, self.loader_param)

valid_pdb_set = Dataset(list(valid_pdb.keys())[:self.n_valid_pdb],
                        loader_pdb, valid_pdb,
                        self.loader_param, homo, p_homo_cut=-1.0)

# Diffuser