In [None]:
from unifold.refine.dataset import load_and_process
from unifold.data.lmdb_dataset import LMDBDataset
from unifold.data.data_ops import atom37_to_torsion_angles
from unifold.config import model_config
from unifold.modules import AlphaFold
from unifold.refine.modules import RefineFold
from unifold.inference import automatic_chunk_size
import torch
import numpy as np
import json

# Build the "model_2" configuration with train=False and switch on the
# structure_refine flag in the shared data settings. This `config` is the one
# handed to load_and_process() in the data-loading cell.
# NOTE(review): the global name `config` is rebound later by
# `config, model = get_model()`, which builds a "model_2_ft" config instead, so
# re-running the data-loading cell after that point uses a different config.
config = model_config("model_2", train=False)
config.data.common.structure_refine = True
# Star import: this is where `Iterable` (used by recur_print) comes from.
from typing import *

def recur_print(x):
    """Summarise a nested structure of arrays as shape/dtype strings.

    Args:
        x: A ``torch.Tensor`` or ``np.ndarray`` leaf, a ``str``/``bytes``
            leaf, a ``dict`` of such values, or any other iterable of them
            (nested arbitrarily).

    Returns:
        For an array leaf, the string ``f"{x.shape}_{x.dtype}"``; for a
        ``str``/``bytes`` leaf, the value itself; for a ``dict``, a dict with
        the same keys and summarised values; for any other iterable, a list
        of summarised items.

    Raises:
        RuntimeError: If a leaf is none of the supported types (for example a
            bare Python number). The offending value is the exception argument.
    """
    if isinstance(x, (torch.Tensor, np.ndarray)):
        return f"{x.shape}_{x.dtype}"
    if isinstance(x, (str, bytes)):
        # A str is itself an Iterable whose items are again str, so letting it
        # reach the generic Iterable branch would recurse until RecursionError.
        # Treat text/bytes as leaves and show them unchanged.
        return x
    if isinstance(x, dict):
        return {k: recur_print(v) for k, v in x.items()}
    if isinstance(x, Iterable):
        return [recur_print(v) for v in x]
    raise RuntimeError(x)


In [None]:
#@markdown ## load data
data_path = "/mnt/data/projects/unifold/data_0916/traineval/"

# Feature and label stores opened from the two LMDB files under data_path.
feat_lmdb = LMDBDataset(data_path + "features.lmdb")
lab_lmdb = LMDBDataset(data_path + "labels.lmdb")

# Read the label-id -> sequence-id mapping. The context manager closes the
# file handle as soon as parsing is done instead of leaving it open until
# garbage collection.
with open(data_path + "train_label_to_seq.json") as map_file:
    feat_id_map = json.load(map_file)

# Label id of the example to inspect, and the sequence id it maps to.
lid = "101m_A"
sid = feat_id_map[lid]

feat, lab = load_and_process(
    config.data,
    "predict",
    batch_idx=0,
    data_idx=0,
    sequence_ids=[sid],
    feature_dir=feat_lmdb,
    msa_feature_dir=data_path + "msa_features",
    template_feature_dir=data_path + "template_features",
    uniprot_msa_feature_dir=data_path + "uniprot_features",
    label_ids=[lid],
    label_dir=lab_lmdb,
)
# Only one label id was requested; keep the first (presumably only) entry.
lab = lab[0]

# Display the shape/dtype summary of the processed features.
recur_print(feat)

In [None]:
def get_model(
    param_path='/mnt/data/projects/unifold/release_params/monomer.unifold.pt',
    device="cuda:0",
):
    """Build a RefineFold model, load its EMA weights and prepare it for inference.

    Args:
        param_path: Path to a checkpoint readable by ``torch.load`` whose
            ``["ema"]["params"]`` entry is the state dict to load.
        device: Device the model is moved to after the weights are loaded.
            Defaults to ``"cuda:0"``, the device used by the rest of the
            notebook.

    Returns:
        A ``(config, model)`` tuple: the "model_2_ft" config used to build the
        model, and the model itself in eval / inference mode on ``device``.
    """
    config = model_config("model_2_ft", train=False)
    config.data.common.max_recycling_iters = 3
    config.globals.max_recycling_iters = 3
    config.data.predict.num_ensembles = 1
    config.data.predict.subsample_templates = False
    model = RefineFold(config)

    print("start to load params {}".format(param_path))
    # Deserialize onto the CPU first: without map_location, tensors are
    # restored to the device they were saved from, which fails (or wastes
    # memory on the wrong GPU) when that device does not exist here. The model
    # is moved to the target device afterwards.
    state_dict = torch.load(param_path, map_location="cpu")["ema"]["params"]
    # Drop the first dot-separated component of every parameter name so the
    # keys line up with this model's own parameter names.
    state_dict = {".".join(k.split(".")[1:]): v for k, v in state_dict.items()}
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()
    model.inference_mode()
    return config, model
config, model = get_model()

In [None]:
from unifold.dataset import UnifoldDataset
# Collate the single processed example into a batch (a one-item list is passed).
batch = UnifoldDataset.collater([feat])
# Ask the library for chunk/block sizes on cuda:0.
# NOTE(review): 256 is hard-coded — presumably a sequence length, and the
# trailing False presumably a reduced-precision flag; TODO confirm against
# automatic_chunk_size's signature and the actual length of this example.
chunk_size, block_size = automatic_chunk_size(256, "cuda:0", False)
# Store the chosen sizes on the model's globals.
model.globals.chunk_size = chunk_size
model.globals.block_size = block_size

In [None]:
# Inference only: autograd is disabled for both the device transfer and the
# forward pass.
with torch.no_grad():
    # Convert every collated entry to a tensor on cuda:0.
    batch = dict(
        (name, torch.as_tensor(value, device="cuda:0"))
        for name, value in batch.items()
    )
    raw_out = model(batch)
# Display the shape/dtype summary of the inputs that were fed to the model.
recur_print(batch)

In [None]:
# Display the shape/dtype summary of everything the model returned.
recur_print(raw_out)

In [None]:
# Pull the 'frames' and 'sidechain_frames' entries out of the 'sm' section of
# the model output (presumably backbone and side-chain frames from the
# structure module — TODO confirm).
bb_frames = raw_out['sm']['frames']
sc_frames = raw_out['sm']['sidechain_frames']

In [None]:
# Element-wise comparison of the last entry along bb_frames' first axis with
# index 0 along sc_frames' third axis.
# NOTE(review): this is exact equality and displays the full boolean tensor;
# if these are floating-point values, torch.allclose(...) or `.all()` on the
# result would give a single, tolerance-aware answer.
bb_frames[-1, ...] == sc_frames[:, :, 0, :, :]

In [None]:
# Label torsions; slots are [omg,phi,psi,x1-4], each stored as [sin,cos].
torsions = lab["torsion_angles_sin_cos"]
# Multiply the mask entries of slots 1 and 2 (phi and psi) together, so the
# result is non-zero only where both are present.
masks = lab["torsion_angles_mask"][..., 1:3].prod(dim=-1)

def sin_cos_to_angle(sin_cos):
    """Convert (sin, cos) pairs stored on the last axis into signed angles.

    Args:
        sin_cos: Tensor whose last axis holds ``[sin, cos]`` (index 0 is the
            sine, index 1 the cosine). The pair does not need to be normalised.

    Returns:
        Tensor with the last axis removed, holding the angle in radians in
        ``[-pi, pi]``. A ``(0, 0)`` pair maps to ``0``.
    """
    sin_ = sin_cos[..., 0]
    cos_ = sin_cos[..., 1]
    # atan2 resolves the quadrant from both components. The former
    # acos(cos) * sign(sin) returned 0 instead of pi when sin == 0 and
    # cos == -1 (sign(0) is 0), and NaN whenever rounding pushed cos
    # outside [-1, 1].
    return torch.atan2(sin_, cos_)

# Keep only the phi/psi slots (1 and 2), zero them where the mask is zero,
# and convert the resulting [sin,cos] pairs to angles.
phi_psi = sin_cos_to_angle(
    torsions[..., 1:3, :] * masks[..., None, None]
)

In [None]:
from matplotlib import pyplot as plt

# Scatter of phi (column 0) against psi (column 1), in radians.
# NOTE(review): entries whose mask was zero were multiplied to (0, 0) before
# the angle conversion, so they presumably pile up at the origin — confirm
# and filter them out if that is misleading.
fig, ax = plt.subplots()
ax.plot(phi_psi[:, 0], phi_psi[:, 1], '.')
# Label the figure so it can be read on its own when the notebook is skimmed.
ax.set(xlabel="phi (rad)", ylabel="psi (rad)", title="phi/psi from label torsions")
# Render the figure without echoing the Line2D repr as cell output.
plt.show()