In [None]:
import copy
import json
import os
import uuid
import sys

import warnings
import gzip
from Bio.PDB.PDBExceptions import PDBConstructionWarning
from Bio.PDB import PDBParser
from Bio.SeqUtils import seq1
warnings.simplefilter('ignore', PDBConstructionWarning)

import torch
import torch as ch
import torch.nn as nn
from fastargs import Param, Section
from fastargs.validation import And, OneOf
import numpy as np
import src.config_parse_utils as config_parse_utils
from src.eval_utils import evaluate_model
from src.trainer import LightWeightTrainer
from src.models_and_optimizers import create_clip_model, load_model
import src.dist_utils as dist_utils
import src.data_utils as data_utils
from transformers import EsmTokenizer
import src.loader as loaders_utils
import webdataset as wds
import tqdm
import tensorflow as tf
import os
import logging
from functools import partial
import src.loader as loader_utils
import src.zipdataset as zipdataset_utils
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
logging.getLogger('tensorflow').setLevel(logging.FATAL)

import src.models_and_optimizers as model_utils
from types import SimpleNamespace
from clip_main import get_wds_loaders
from transformers import EsmTokenizer
import src.data_utils as data_utils
import os
import torch
from tqdm import tqdm
from torch.cuda.amp import autocast
from transformers import EsmTokenizer, EsmModel
import esm as esmlib

In [None]:
## WDS helpers

def process_residue(residue):
    """Extract backbone atom coordinates and the one-letter code of a residue.

    Looks up the N, CA, C and O atoms in ``residue.child_dict``. A missing
    carbonyl O falls back to the terminal OXT atom.

    Args:
        residue: Bio.PDB-style residue exposing ``child_dict`` (atom name ->
            atom with ``get_coord()``) and ``resname``.

    Returns:
        ``(backbone, one_letter)``. ``backbone`` is ``np.stack`` of the four
        atom coordinates in N, CA, C, O order (shape (4, 3) for Bio.PDB
        atoms). ``one_letter`` is ``seq1(resname)``. Returns ``(None, None)``
        if any backbone atom is missing.
    """
    atom_lookup = residue.child_dict
    backbone = []
    for atom_name in ('N', 'CA', 'C', 'O'):
        atom = atom_lookup.get(atom_name)
        if atom is None and atom_name == 'O':
            # C-terminal residues may only carry OXT instead of O.
            atom = atom_lookup.get('OXT')
        if atom is None:
            return None, None
        backbone.append(np.array(atom.get_coord()))
    return np.stack(backbone), seq1(residue.resname)

def process_chain(chain):
    """Collect backbone coordinates and sequence for every complete residue in a chain.

    Residues for which ``process_residue`` returns no coordinates (a missing
    backbone atom) are dropped from both the coordinates and the sequence.

    Args:
        chain: Iterable of residues (e.g. a Bio.PDB ``Chain``).

    Returns:
        ``(coords, seq)``. ``coords`` is ``np.stack`` of the per-residue
        backbone arrays and ``seq`` is the joined one-letter string. Returns
        ``None`` if no residue is usable.
    """
    kept = [
        (backbone, one_letter)
        for backbone, one_letter in (process_residue(res) for res in chain)
        if backbone is not None
    ]
    if not kept:
        return None
    backbones, letters = zip(*kept)
    return np.stack(backbones), ''.join(letters)

def process_chains(chains, pep=False, prot=False):
    if pep or prot:
        chain_lens = []
        chain_ids = []
        for chain in chains:
            for i, res in enumerate(chain):
                continue
            chain_lens.append(i)
            chain_ids.append(chain.id)
        if chain_lens[0] < chain_lens[1]:
            pep_id = chain_ids[0]
            prot_id = chain_ids[1]
        else:
            pep_id = chain_ids[1]
            prot_id = chain_ids[0]
        if pep and isinstance(pep, str): pep_id == pep
        if prot and isinstance(prot, str): prot_id == prot
    output = []
    chain_ids = []
    for chain in chains:
        if (pep and chain.id != pep_id) or (prot and chain.id != prot_id):
            continue
        out = process_chain(chain)
        if out is not None:
            output.append(out)
            chain_ids.append(chain.id)
    coords = [u[0] for u in output]
    seqs = [u[1] for u in output]
    return coords, seqs, chain_ids

def process_structure(structure, pep=False, prot=False):
    """Run ``process_chains`` on the first model of a structure.

    Only the first model yielded by iterating ``structure`` is used; any
    further models are ignored.

    Returns:
        The ``process_chains`` result for the first model, or ``None`` if
        ``structure`` yields no models.
    """
    models = iter(structure)
    missing = object()
    first_model = next(models, missing)
    if first_model is missing:
        return None
    return process_chains(first_model, pep, prot)

# +
def process_pdb(parser, pdb_filename):
    """Parse a gzip-compressed PDB file and extract its first model.

    Args:
        parser: Bio.PDB ``PDBParser`` (or compatible) instance.
        pdb_filename: Path to a gzip-compressed PDB file, opened in text mode.

    Returns:
        ``(processed, deposition_date)``. ``processed`` is the
        ``process_structure`` result and ``deposition_date`` is
        ``structure.header['deposition_date']``.
    """
    with gzip.open(pdb_filename, "rt") as handle:
        structure = parser.get_structure("?", handle)
        deposition_date = structure.header['deposition_date']
        processed = process_structure(structure)
    return processed, deposition_date
    
def process_pdb_raw(parser, pdb_filename, pep=False, prot=False):
    """Parse a PDB file by path (no gzip handling) and extract its first model.

    ``pep``/``prot`` are forwarded to ``process_structure`` to restrict the
    output to the peptide and/or protein chain.

    Returns:
        The ``process_structure`` result (``None`` if the file has no models).
    """
    parsed_structure = parser.get_structure("?", pdb_filename)
    return process_structure(parsed_structure, pep=pep, prot=prot)

def read_input_ids(index_file):
    """Read one identifier per line from ``index_file``.

    Surrounding whitespace, including the newline, is stripped from each line.
    Blank lines are kept as empty strings.

    Returns:
        numpy array of the stripped lines, in file order.
    """
    with open(index_file, 'r') as handle:
        ids = [raw_line.strip() for raw_line in handle]
    return np.array(ids)

def write_dataset(dataset, tar_name, use_shards=False, max_shard_count=10000):
    """Serialize processed structures into a WebDataset tar (or a directory of shards).

    Args:
        dataset: Iterable of ``(batch, pdb_id)`` pairs. ``batch`` is the
            ``(coords, seqs, chain_ids)`` tuple produced by ``process_chains``.
            Entries whose ``coords`` list is empty are skipped.
        tar_name: Output tar path, or the output directory when ``use_shards``.
        use_shards: Write ``shard-XXXXXX.tar`` files via ``wds.ShardWriter``
            instead of a single ``wds.TarWriter`` archive.
        max_shard_count: Maximum samples per shard (only used with ``use_shards``).

    Each sample is written under key ``sample%06d``, using the enumeration
    index over ``dataset``. Skipped entries therefore leave gaps in the keys.
    """
    if use_shards:
        os.makedirs(tar_name, exist_ok=True)
        sink = wds.ShardWriter(f'{tar_name}/shard-%06d.tar', maxcount=max_shard_count)
    else:
        sink = wds.TarWriter(tar_name)
    # Always close the writer so the tar is finalized and the file handle is
    # released, even if iterating `dataset` or writing a sample raises.
    try:
        for index, (batch, pdb_id) in enumerate(dataset):
            if index % 1000 == 0:
                print(f"{index:6d}", end="\r", flush=True, file=sys.stderr)
            if len(batch[0]) == 0:
                continue
            sink.write({
                "__key__": "sample%06d" % index,
                "inp.pyd": dict(coords=batch[0], seqs=batch[1], chain_ids=batch[2], pdb_id=pdb_id),
            })
    finally:
        sink.close()
    
def make_wds(dir_, tar_):
    """Parse every PDB file in a directory and write the results as a WebDataset tar.

    Each directory entry is parsed with ``process_pdb_raw``. Structures with no
    models are skipped, and the rest are written by ``write_dataset``.

    Args:
        dir_ (str): Directory containing PDB files. Every entry returned by
            ``os.listdir`` is parsed, so the directory should hold only
            structure files.
        tar_ (str): Output file path to write WDS to.
    """
    # Local import: the notebook's import cell later rebinds the module-level
    # name `tqdm` to the class (`from tqdm import tqdm`), so `tqdm.tqdm` fails.
    from tqdm import tqdm as progress

    parser = PDBParser()
    # Sorted so sample order (and thus tar keys) is reproducible across runs.
    file_names = sorted(name.strip() for name in os.listdir(dir_))

    dataset = []
    for file_name in progress(file_names):
        parsed = process_pdb_raw(parser, os.path.join(dir_, file_name))
        if parsed is None:
            continue
        # Derive the id from the bare file name. Splitting the joined path on '.'
        # keeps the directory prefix and breaks if the directory contains a dot.
        pdb_id = file_name.split('.')[0]
        dataset.append((parsed, pdb_id))

    write_dataset(dataset, tar_)

In [None]:
## GENERAL SETUP (CHANGE PATHS AS NEEDED)
ROOT = "/home/gridsan/lguan/keating/rla/model_weights"
model_dir = "version_0/"
dev = 'cuda:0'
CLIP_MODE = False
root_path = os.path.join(ROOT, model_dir)
path = os.path.join(root_path, "checkpoints/checkpoint_best.pt")
data_root = "/home/gridsan/lguan/keating/pacap/rfdiffusion/20240925/pacap_bind/outputs/wds"

# The saved training args live in a top-level `.pt` file of the run directory.
# Sort the listing so the choice is deterministic if several `.pt` files exist,
# and fail with a clear message if there are none.
_args_candidates = sorted(u for u in os.listdir(root_path) if u.endswith('.pt'))
if not _args_candidates:
    raise FileNotFoundError(f"No .pt args file found in {root_path}")
args_path = os.path.join(root_path, _args_candidates[0])

# Defaults for args that older checkpoints may not have saved.
backwards_compat = {
    'masked_rate': -1,
    'masked_mode': 'MASK',
    'lm_only_text': 1,
    'lm_weight': 1,
    'resid_weight': 1,
    'language_head': False,
    'language_head_type': 'MLP',
    'zip_enabled': False,
    'num_mutations': False,
}
# NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
hparams = torch.load(args_path)
args_dict = hparams['args']
# Notebook-specific data settings: local WDS root, single-sample batches,
# no blacklist, one loader worker.
args_dict['data_root'] = data_root
args_dict['batch_size'] = 1
args_dict['blacklist_file'] = ''
args_dict['num_workers'] = 1
for k, default in backwards_compat.items():
    args_dict.setdefault(k, default)
args = SimpleNamespace(**args_dict)

print(vars(args))

coordinator_params = data_utils.get_coordinator_params(args.coordinator_hparams)
coordinator_params['num_positional_embeddings'] = args.gnn_num_pos_embs
coordinator_params['zero_out_pos_embs'] = args.gnn_zero_out_pos_embs
# NOTE(review): clip_mode is forced to True regardless of CLIP_MODE above - confirm intended.
coordinator_params['clip_mode'] = True