In [None]:
import logging
import os
import sys

import numpy as np
import torch

from rhofold.data.balstn import BLASTN
from rhofold.rhofold import RhoFold
from rhofold.config import rhofold_config
from rhofold.utils import get_device, save_ss2ct, timing
from rhofold.relax.relax import AmberRelaxation
from rhofold.utils.alphabet import get_features

@torch.no_grad()
def main(ckpt='./pretrained/RhoFold_pretrained.pt'):
    """Build a RhoFold network and initialise it from a saved checkpoint.

    Args:
        ckpt: Path to the checkpoint file. The file is deserialised with
            ``torch.load`` onto the CPU and is expected to contain a
            ``'model'`` entry holding the state dict.

    Returns:
        The ``RhoFold`` instance, switched to evaluation mode.

    Note:
        The ``torch.no_grad()`` decorator only covers the body of this
        function; it does not disable gradient tracking for later calls made
        on the returned model.
    """
    net = RhoFold(rhofold_config)

    # Tensors are mapped to CPU regardless of the device they were saved from.
    checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))
    weights = checkpoint['model']

    # strict=False: missing or unexpected keys do not raise.
    net.load_state_dict(weights, strict=False)

    # Module.eval() returns the module itself.
    return net.eval()

def inference(input_fas='test-rhofold.ipynb', input_a3m='test-rhofold.ipynb',
              output_dir='tmp'):
    """Run the model on one input and export the predicted structure as PDB.

    Loads the model through ``main()``, builds input features with
    ``get_features``, runs a forward pass and writes the result to
    ``<output_dir>/unrelaxed_model.pdb``.

    Args:
        input_fas: Path handed to ``get_features`` as its first argument
            (presumably the FASTA sequence file -- confirm against
            ``rhofold.utils.alphabet.get_features``).
        input_a3m: Path handed to ``get_features`` as its second argument
            (presumably the A3M alignment file -- confirm as above).
        output_dir: Directory that receives ``unrelaxed_model.pdb``. It is
            created if it does not exist.

    Returns:
        None. The result is the PDB file written to disk.
    """
    # No `logger` is defined in the visible code, so create one here; passing
    # an undefined name to export_pdb_file would raise NameError.
    logger = logging.getLogger(__name__)

    device = get_device('cpu')
    model = main().to(device)
    data_dict = get_features(input_fas, input_a3m)

    # Only main() is wrapped in no_grad, so disable autograd explicitly for the
    # forward pass; gradients are not needed for inference.
    with torch.no_grad():
        outputs = model(tokens=data_dict['tokens'].to(device),
                        rna_fm_tokens=data_dict['rna_fm_tokens'].to(device),
                        seq=data_dict['seq'],
                        )

    # Use the last element of the model outputs.
    # NOTE(review): presumably the final recycling iteration -- confirm.
    output = outputs[-1]

    # Make sure the destination directory exists before exporting.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    unrelaxed_model = os.path.join(output_dir, 'unrelaxed_model.pdb')

    # Last entry of the predicted coordinates, with the leading dim squeezed.
    node_cords_pred = output['cord_tns_pred'][-1].squeeze(0)
    model.structure_module.converter.export_pdb_file(
        data_dict['seq'],
        node_cords_pred.data.cpu().numpy(),
        path=unrelaxed_model,
        chain_id=None,
        confidence=output['plddt'][0].data.cpu().numpy(),
        logger=logger,
    )



In [2]:
# Load the model from the default checkpoint path (weights mapped to CPU, eval mode).
model = main()

In [4]:
# Total number of elements across all parameter tensors of the model.
sum(param.numel() for param in model.parameters())

126913743

In [3]:
# List the attribute and method names available on the loaded model object.
dir(model)

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_pre_hooks',
 '_get_backward_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_parameters',
 '_register_load_state_dict_pre_hook',
 '_register_state_dict_hook',
 '_replicate_for_data_parallel',
 '_save_to_state_dict',
 '_slow_forward',
 '_state_dict_hooks',
 '_version',
 'add_module',
 'apply',
