In [1]:
import logging
import os
import sys

import numpy as np
import torch

from rhofold.data.balstn import BLASTN
from rhofold.rhofold import RhoFold
from rhofold.config import rhofold_config
from rhofold.utils import get_device, save_ss2ct, timing
from rhofold.relax.relax import AmberRelaxation
from rhofold.utils.alphabet import get_features

def main(ckpt='./pretrained/RhoFold_pretrained.pt'):
    """Build a RhoFold model and load pretrained weights from ``ckpt``.

    Args:
        ckpt: Path to a checkpoint file. The file must deserialize to a
            mapping whose ``'model'`` entry is the model's state dict.

    Returns:
        The ``RhoFold`` model in eval mode, with its weights loaded on CPU.
        The caller is responsible for moving it to the target device.
    """
    model = RhoFold(rhofold_config)

    # Map tensors to CPU so the checkpoint loads even on machines without a GPU.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))

    # strict=False tolerates key mismatches between checkpoint and model; surface
    # them instead of dropping them silently, since missing weights stay randomly
    # initialised and would quietly degrade predictions.
    incompatible = model.load_state_dict(checkpoint['model'], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        logging.warning(
            "Checkpoint %s: %d missing keys (e.g. %s), %d unexpected keys (e.g. %s)",
            ckpt,
            len(incompatible.missing_keys), incompatible.missing_keys[:5],
            len(incompatible.unexpected_keys), incompatible.unexpected_keys[:5],
        )

    model.eval()
    return model

@torch.no_grad()
def inference(input_fas='../data/RNA3D_DATA/seq/1dqf_A.seq',
              input_a3m='../data/RNA3D_DATA/rMSA/1dqf_A.a3m',
              unrelaxed_model='tmp/unrelaxed_model.pdb'):
    """Predict a 3D structure with RhoFold and write it as an unrelaxed PDB file.

    Runs under ``torch.no_grad()``: this is pure inference, and without it
    autograd would record a graph for the whole forward pass.

    Args:
        input_fas: Path to the query sequence file.
        input_a3m: Path to the multiple sequence alignment in a3m format.
        unrelaxed_model: Output PDB path. Its parent directory is created if
            it does not exist.

    Returns:
        The path of the written PDB file (``unrelaxed_model``).
    """
    device = get_device('cpu')
    model = main().to(device)
    print("Number of params:", sum(p.numel() for p in model.parameters()))

    data_dict = get_features(input_fas, input_a3m)

    outputs = model(tokens=data_dict['tokens'].to(device),
                    rna_fm_tokens=data_dict['rna_fm_tokens'].to(device),
                    seq=data_dict['seq'],
                    )
    # The model returns a sequence of outputs; the last one is used for export.
    output = outputs[-1]

    # export_pdb_file writes to `path`; make sure its directory exists first.
    out_dir = os.path.dirname(unrelaxed_model)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    node_cords_pred = output['cord_tns_pred'][-1].squeeze(0)
    model.structure_module.converter.export_pdb_file(
        data_dict['seq'],
        node_cords_pred.detach().cpu().numpy(),
        path=unrelaxed_model,
        chain_id=None,
        confidence=output['plddt'][0].detach().cpu().numpy(),
    )
    return unrelaxed_model



In [2]:
# Run the full pipeline on 1dqf_A: load pretrained weights, predict, and write tmp/unrelaxed_model.pdb.
inference()

Number of params: 130059855


In [3]:
# `get_evo2_embeddings` is a local helper module, not part of rhofold. In the
# recorded run it was not importable from the notebook's working directory.
# Set EVO2_HELPER_DIR to the folder containing get_evo2_embeddings.py when it
# does not live next to this notebook.
_evo2_dir = os.environ.get("EVO2_HELPER_DIR")
if _evo2_dir and _evo2_dir not in sys.path:
    sys.path.insert(0, _evo2_dir)

try:
    from get_evo2_embeddings import seq_id_to_embedding
except ModuleNotFoundError as exc:
    raise ModuleNotFoundError(
        "get_evo2_embeddings is not importable: put get_evo2_embeddings.py next to "
        "this notebook or set EVO2_HELPER_DIR to its directory "
        f"(current working directory: {os.getcwd()!r})"
    ) from exc

seq_id_to_embedding("1dqf_A")

ModuleNotFoundError: No module named 'get_evo2_embeddings'