In [None]:
%load_ext autoreload
%autoreload 2

import os
import socket
import urllib.request
import sys


os.environ["NO_PROXY"] = "*"
os.environ["no_proxy"] = "*"


def no_network(*args, **kwargs):
    raise Exception("Network access disabled")


socket.socket = no_network
urllib.request.urlopen = no_network

sys.path.append("./bio2token_main")

from typing import *
import torch
import argparse
import os
import logging

from bio2token_main.bio2token.models.fsq_ae import FSQ_AE
from bio2token_main.bio2token.utils.utils import *
from bio2token_main.bio2token.utils.pdb import *

import warnings
import Bio

warnings.simplefilter("ignore", Bio.PDB.PDBExceptions.PDBConstructionWarning)


In [None]:
# Run configuration shared by the cells below.
TOKENIZER_PATH = "./bio2token_main/"  # vendored bio2token checkout (configs/, checkpoints/)
seq_type = "AA"  # stored as batch["seq_type"] for every structure
chains = None  # passed to pdb_2_dict(chains=...); None presumably means all chains
# Single-structure example settings; not used by the benchmark loops.
pdb = "3wbm"
tokenizer = "bio2token"

# Build the FSQ autoencoder from its config and load the pretrained weights.
DEVICE = "cuda:0"  # single place to choose the target device
CKPT_KEY_PREFIX = "model."

config_model = load_config(os.path.join(TOKENIZER_PATH, "configs", "tokenizer.yaml"))
model = FSQ_AE(config_model).to(DEVICE)

checkpoint = os.path.join(TOKENIZER_PATH, "checkpoints", "bio2token.ckpt")
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(checkpoint, map_location=DEVICE)["state_dict"]

# Checkpoint keys are expected to start with "model." (presumably from the
# training wrapper that held the autoencoder as an attribute). Strip it only as
# a leading prefix: str.replace would also rewrite any later "model." segment.
new_state_dict = {
    (k[len(CKPT_KEY_PREFIX):] if k.startswith(CKPT_KEY_PREFIX) else k): v
    for k, v in state_dict.items()
}

model.load_state_dict(new_state_dict)
model.eval();  # trailing ';' suppresses the (very long) module repr in the output

## CASP15 benchmark: backbone reconstruction

In [None]:
# Round-trip every CASP15 backbone-only structure through the autoencoder and
# write the reconstruction as a PDB file named after the input file's stem.
IN_DIR = "./data/casp15_backbone_only/"
OUT_DIR = "./data/bio2token/bio2token_casp_backbone_recon"
os.makedirs(OUT_DIR, exist_ok=True)

# Only regular files can be parsed; skip subdirectories and other entries.
file_names = sorted(
    name for name in os.listdir(IN_DIR) if os.path.isfile(os.path.join(IN_DIR, name))
)

for file_name in file_names:
    print(f"Processing {file_name} from {IN_DIR}")
    biomolecule = pdb_2_dict(
        os.path.join(IN_DIR, file_name),
        chains=chains,
    )
    batch = pdb_to_batch(config_model, biomolecule)
    batch["seq_type"] = seq_type

    with torch.no_grad():
        out = model.step(batch, mode="inference")

    # Keep the first `atom_length` positions of the "coords_pred_kabsch_all"
    # prediction (assumes layout (1, atoms, coords) -- TODO confirm), drop the
    # leading axis, and split into one row per atom for to_pdb_dict.
    recon = (
        out["coords_pred_kabsch_all"][:, : biomolecule["atom_length"], :]
        .squeeze(0)
        .cpu()
        .numpy()
    )
    recon = np.split(recon, recon.shape[0])

    pdb_dict_recon = to_pdb_dict(
        coords=recon,
        atom_names=biomolecule["atom_names"],
        continuous_res_ids=biomolecule["continuous_res_ids"],
        residue_names=biomolecule["res_names"],
        res_types=biomolecule["res_types"],
    )

    output_file = os.path.join(OUT_DIR, f"{os.path.splitext(file_name)[0]}.pdb")
    pdb_dict_to_file(pdb_dict_recon, pdb_file_path=output_file)

## CATH 4.2.0 (S40 non-redundant): bio2token reconstructions

In [None]:
# Round-trip every CATH 4.2.0 S40 non-redundant structure through the
# autoencoder and write the reconstruction as a PDB file named after the input
# file's stem.
IN_DIR = "./data/cath-dataset-nonredundant-S40-v4_2_0.pdb/"
OUT_DIR = "./data/cath_4_2_0_bio2token"
os.makedirs(OUT_DIR, exist_ok=True)

# Only regular files can be parsed; skip subdirectories and other entries.
file_names = sorted(
    name for name in os.listdir(IN_DIR) if os.path.isfile(os.path.join(IN_DIR, name))
)

for file_name in file_names:
    print(f"Processing {file_name} from {IN_DIR}")
    biomolecule = pdb_2_dict(
        os.path.join(IN_DIR, file_name),
        chains=chains,
    )
    batch = pdb_to_batch(config_model, biomolecule)
    batch["seq_type"] = seq_type

    with torch.no_grad():
        out = model.step(batch, mode="inference")

    # Keep the first `atom_length` positions of the "coords_pred_kabsch_all"
    # prediction (assumes layout (1, atoms, coords) -- TODO confirm), drop the
    # leading axis, and split into one row per atom for to_pdb_dict.
    recon = (
        out["coords_pred_kabsch_all"][:, : biomolecule["atom_length"], :]
        .squeeze(0)
        .cpu()
        .numpy()
    )
    recon = np.split(recon, recon.shape[0])

    pdb_dict_recon = to_pdb_dict(
        coords=recon,
        atom_names=biomolecule["atom_names"],
        continuous_res_ids=biomolecule["continuous_res_ids"],
        residue_names=biomolecule["res_names"],
        res_types=biomolecule["res_types"],
    )

    output_file = os.path.join(OUT_DIR, f"{os.path.splitext(file_name)[0]}.pdb")
    pdb_dict_to_file(pdb_dict_recon, pdb_file_path=output_file)