In [20]:
# Enable IPython's autoreload extension. Mode 2 picks up edits to imported
# modules live, so changes to local .py files take effect without a restart.
%load_ext autoreload
%autoreload 2

import os
import socket
import urllib.request
import sys


# NO_PROXY="*" asks proxy-aware clients to bypass any configured proxy for
# every host. Both spellings are set because tools differ in which one they read.
os.environ["NO_PROXY"] = "*"
os.environ["no_proxy"] = "*"


def no_network(*args, **kwargs):
    """Stand-in for network entry points that always refuses to run.

    Accepts and ignores any positional or keyword arguments so that it can
    replace callables with arbitrary signatures.

    Raises:
        RuntimeError: Always, with the message "Network access disabled".
    """
    raise RuntimeError("Network access disabled")


# Replace the low-level socket constructor and urllib's opener so that any
# attempt to reach the network from this kernel fails immediately.
socket.socket = no_network
urllib.request.urlopen = no_network

# Make the local ./bio2token_main directory importable. The membership check
# keeps the cell idempotent: re-running it does not append duplicate entries.
_BIO2TOKEN_DIR = "./bio2token_main"
if _BIO2TOKEN_DIR not in sys.path:
    sys.path.append(_BIO2TOKEN_DIR)

from typing import *
import argparse
import logging
import os

import numpy as np
import torch

from bio2token_main.bio2token.models.fsq_ae import FSQ_AE
from bio2token_main.bio2token.utils.utils import load_config
from bio2token_main.bio2token.utils.pdb import (
    pdb_2_dict,
    pdb_to_batch,
    pdb_dict_to_file,
    to_pdb_dict,
)

import warnings
import Bio

# `import Bio` alone does not load subpackages. Import the module explicitly so
# the attribute lookup below does not depend on some other import having
# loaded Bio.PDB.PDBExceptions first.
import Bio.PDB.PDBExceptions

# Suppress Biopython's PDBConstructionWarning for the rest of the session.
warnings.simplefilter("ignore", Bio.PDB.PDBExceptions.PDBConstructionWarning)


from bio2token_extension import Bio2TokenExtension

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [21]:
# Run configuration shared by the cells below.
TOKENIZER_PATH = "./bio2token_main/"
chains = None
pdb = "3wbm"
seq_type = "AA"
tokenizer = "bio2token"

# Build the model from its YAML config and move it to the GPU.
config_model = load_config(TOKENIZER_PATH + "configs/tokenizer.yaml")
model = Bio2TokenExtension(config_model).cuda()

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = TOKENIZER_PATH + "checkpoints/bio2token.ckpt"
state_dict = torch.load(checkpoint, map_location="cuda:0")["state_dict"]

# Checkpoint keys carry a leading "model." that the bare module does not
# expect. Strip it only when it is a prefix, so that a key which merely
# contains "model." somewhere in the middle is left intact.
_CKPT_PREFIX = "model."
new_state_dict = {
    (key[len(_CKPT_PREFIX):] if key.startswith(_CKPT_PREFIX) else key): value
    for key, value in state_dict.items()
}

model.load_state_dict(new_state_dict)
# Switch to evaluation mode; as the cell's last expression this also displays
# the module structure.
model.eval()

Bio2TokenExtension(
  (encoder): MambaLMHeadModel(
    (backbone): MixerModel(
      (embedding): Linear(in_features=3, out_features=128, bias=True)
      (layers): ModuleList(
        (0-1): 2 x Block(
          (mixer): Mamba(
            (in_proj): Linear(in_features=128, out_features=512, bias=False)
            (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
            (act): SiLU()
            (x_proj): Linear(in_features=256, out_features=40, bias=False)
            (dt_proj): Linear(in_features=8, out_features=256, bias=True)
            (out_proj): Linear(in_features=256, out_features=128, bias=False)
          )
          (norm): RMSNorm()
        )
      )
      (norm_f): RMSNorm()
    )
    (lm_head): Identity()
  )
  (decoder): MambaLMHeadModel(
    (backbone): MixerModel(
      (embedding): Identity()
      (layers): ModuleList(
        (0-3): 4 x Block(
          (mixer): Mamba(
            (in_proj): Linear(in_features=128, out_featu

In [22]:
# Input structures and destination for the reconstructed PDB files.
IN_DIR = "./data/cath-dataset-nonredundant-S40-v4_2_0"
OUT_DIR = "./data/bio2token_tokens/cath_4_2_0_bio2token"
os.makedirs(OUT_DIR, exist_ok=True)

# Number of files to process; set to None to process every file in IN_DIR.
MAX_FILES = 1

# os.listdir returns entries in arbitrary order, so sort BEFORE slicing;
# otherwise the selected subset can differ between runs and machines.
file_names = sorted(os.listdir(IN_DIR))[:MAX_FILES]

print("Number of files:", len(file_names))

Number of files: 1


In [23]:
# Sanity check: show which class the loaded model is an instance of.
type(model)

bio2token_extension.Bio2TokenExtension

In [24]:
# Run the tokenizer in inference mode on each selected structure.
# `biomolecule`, `batch` and `out` stay bound to the last processed file once
# the loop finishes, which is what the inspection cells read.
for file_name in file_names:
    pdb_path = os.path.join(IN_DIR, file_name)
    biomolecule = pdb_2_dict(pdb_path=pdb_path, chains=chains)

    batch = pdb_to_batch(config_model, biomolecule)
    batch["seq_type"] = seq_type

    # No gradients are needed for inference.
    with torch.no_grad():
        out = model.step(batch, mode="inference")

In [39]:
# List the fields of the parsed-structure dict for the most recently processed file.
biomolecule.keys()

dict_keys(['pdb_id', 'seq', 'res_names', 'coords_groundtruth', 'atom_names', 'atom_types', 'seq_length', 'atom_length', 'chains', 'res_ids', 'continuous_res_ids', 'res_types'])

In [None]:
# Display the `coords_groundtruth` entry of the parsed structure
# (shown below as a float32 array with three columns per row).
biomolecule["coords_groundtruth"]

array([[ -1.339, -14.972,   8.619],
       [ -2.357, -15.875,   8.061],
       [ -3.404, -15.142,   7.224],
       ...,
       [ -0.381,   4.085,   1.121],
       [  1.014,   3.971,   0.789],
       [ -1.1  ,   3.13 ,   0.148]], dtype=float32)

In [34]:
# Shape of the `coords_pred_kabsch_all` tensor returned by the inference step.
# NOTE(review): presumably (batch, atoms incl. padding, xyz) — confirm.
out["coords_pred_kabsch_all"].shape

torch.Size([1, 135301, 3])

In [None]:
def _coords_to_atom_list(coords, n_atoms):
    """Convert a batched coordinate tensor into a list of per-atom NumPy arrays.

    Keeps the first ``n_atoms`` entries along dimension 1, drops the leading
    dimension with ``squeeze(0)``, moves the data to host memory, and splits
    the result into one array per row.

    Args:
        coords: Tensor indexed as ``[batch, atom, xyz]``; assumed to hold a
            single batch element (TODO confirm before using with batches > 1).
        n_atoms: Number of leading atoms to keep.

    Returns:
        A list with one NumPy array per kept atom.
    """
    trimmed = coords[:, :n_atoms, :].squeeze(0).cpu().numpy()
    return np.split(trimmed, trimmed.shape[0])


# Tokenize each selected structure, then write the reconstruction to OUT_DIR
# as a PDB file named after the input file.
for file_name in file_names:
    biomolecule = pdb_2_dict(
        os.path.join(IN_DIR, file_name),
        chains=chains,
    )
    batch = pdb_to_batch(config_model, biomolecule)
    batch["seq_type"] = seq_type

    # No gradients are needed for inference.
    with torch.no_grad():
        out = model.step(batch, mode="inference")

    # Keep only the first `atom_length` atoms of each output tensor.
    n_atoms = biomolecule["atom_length"]
    # `gt` is not written out; it is kept for side-by-side inspection with `recon`.
    gt = _coords_to_atom_list(out["coords_gt"], n_atoms)
    recon = _coords_to_atom_list(out["coords_pred_kabsch_all"], n_atoms)

    pdb_dict_recon = to_pdb_dict(
        coords=recon,
        atom_names=biomolecule["atom_names"],
        continuous_res_ids=biomolecule["continuous_res_ids"],
        residue_names=biomolecule["res_names"],
        res_types=biomolecule["res_types"],
    )

    # Total number of atoms across all entries of the reconstructed PDB dict.
    count_atoms = sum(len(entry["atom_names"]) for entry in pdb_dict_recon.values())

    output_file = os.path.join(OUT_DIR, f"{os.path.splitext(file_name)[0]}.pdb")
    pdb_dict_to_file(
        pdb_dict_recon,
        pdb_file_path=output_file,
    )