# EigenFold

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import nglview as nv
from proteome.models.design.eigenfold.modeling import OmegaFoldForGraphEmbedding



In [3]:
# Input sequence string: passed to embedder.embed() and to PDBFile() further down.
# NOTE(review): presumably one-letter amino-acid codes — confirm against the embedder.
sequence = 'MAAHKGAEHHHKAAEHHEQAAKHHHAAAEHHEKGEHEQAAHHADTAYAHHKHAEEHAAQAAKHDAEHHAPKPH'

In [4]:
# Instantiate the embedding model with its default constructor arguments.
embedder = OmegaFoldForGraphEmbedding()

In [5]:
# embed() returns a pair: node_repr (its first dimension is used later as the
# sequence length) and edge_repr (indexed later by [src, dst] edge endpoints).
node_repr, edge_repr = embedder.embed(sequence)

In [6]:
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    """Container of run settings used throughout this notebook.

    An instance is passed to ``get_score_fn`` and ``reverse_sample`` and is
    read when building the SDE schedule objects. Every field has a default,
    so ``Args()`` can be constructed without arguments.

    Fields whose default is ``None`` are annotated ``Optional[...]``. Each
    field is declared exactly once (``inf_step`` and ``inf_type`` used to be
    declared twice with identical defaults); field order is unchanged.
    """

    split_key: Optional[str] = None
    inf_mols: int = 1000
    wandb: Optional[str] = None
    num_workers: Optional[int] = None
    ode: bool = False
    elbo: bool = True
    alpha: float = 1
    beta: float = 3
    num_samples: int = 1
    inf_step: float = 0.5
    elbo_step: float = 0.2
    inf_type: str = "rate"  # choices=['entropy', 'rate']
    max_len: int = 1024
    # The inf_* overrides below are unset here; the notebook uses the train_*
    # values when it builds the schedules.
    inf_Hf: Optional[float] = None
    inf_kmin: Optional[int] = None
    inf_tmin: Optional[int] = None
    inf_cutoff: Optional[int] = None
    embeddings_dir: Optional[str] = None
    pdb_dir: Optional[str] = None
    embeddings_key: Optional[str] = None
    train_Hf: float = 2
    train_tmin: float = 0.01
    train_cutoff: float = 5
    train_kmin: int = 5

In [7]:
@dataclass
class ModelConfig:
    # Hyperparameters handed to get_model(); the checkpoint is loaded into the
    # resulting model with strict=True, so these must match the checkpoint.

    # SDE coefficients (forwarded to PolymerSDE as `a` and `b` in this notebook).
    sde_a: float = 3 / 3.8**2
    sde_b: float = 0

    # resi_* settings.
    resi_conv_layers: int = 6
    resi_ns: int = 32
    resi_nv: int = 4
    resi_ntps: int = 16
    resi_ntpv: int = 4
    resi_fc_dim: int = 128
    resi_pos_emb_dim: int = 16

    # Remaining architecture switches.
    lin_nf: int = 1
    lin_self: bool = False
    attention: bool = False
    sh_lmax: int = 2
    order: int = 1

    # Time embedding.
    t_emb_dim: int = 32
    t_emb_type: str = "sinusoidal"

    # Radius embedding.
    radius_emb_type: str = "gaussian"
    radius_emb_dim: int = 50
    radius_emb_max: float = 50

    tmin: float = 0.001
    tmax: float = 1e6
    no_radius_sqrt: bool = False
    parity: bool = True

    # lm_* dimensions and the switch that zeroes edge embeddings when set.
    lm_edge_dim: int = 128
    lm_node_dim: int = 256
    no_edge_embs: bool = False

In [8]:
# Fetch the pretrained checkpoint and deserialize it onto the CPU; the file is
# stored under the given file_name in torch hub's checkpoint directory.
ckpt = torch.hub.load_state_dict_from_url(
    url="https://github.com/bjing2016/EigenFold/raw/master/pretrained_model/epoch_7.pt",
    map_location="cpu",
    file_name="eigenfold_epoch_7.pt",
)

In [9]:
from proteome.models.design.eigenfold.score_model import get_model

In [10]:
# Build the score model from the default ModelConfig values.
ef_model = get_model(ModelConfig())



In [11]:
# Load the weights stored under the checkpoint's 'model' key; strict=True makes
# any missing or unexpected key an error instead of being silently ignored.
ef_model.load_state_dict(ckpt['model'], strict=True)

<All keys matched successfully>

In [12]:
# Switch to evaluation mode and move the model to the default CUDA device.
ef_model = ef_model.eval().cuda()

In [13]:
from torch_geometric.data import HeteroData
from proteome.models.design.eigenfold.dataset import get_dense_edges
from proteome.models.design.eigenfold.polymer import PolymerSDE
from proteome.models.design.eigenfold import schedule
from proteome.models.design.eigenfold.inference import get_score_fn
from proteome.models.design.eigenfold.sampling import logp, reverse_sample
from proteome.models.design.eigenfold.pdb import PDBFile



In [14]:
# Assemble the HeteroData object that the score function consumes: a dense
# residue graph, the polymer SDE, and the node/edge embeddings.
args = Args()
conf = ModelConfig()
data = HeteroData()

data.skip = False
seqlen = node_repr.shape[0]
data["resi"].num_nodes = seqlen
data["resi"].edge_index = get_dense_edges(seqlen)

sde = PolymerSDE(N=seqlen, a=conf.sde_a, b=conf.sde_b)
sde.make_schedule(Hf=args.train_Hf, step=args.inf_step, tmin=args.train_tmin)

data.resi_sde = data.sde = sde
if conf.no_edge_embs:
    # edge_repr is indexed with tensors and concatenated with torch.cat below,
    # so the zeroed replacement must stay a torch tensor (same shape, dtype and
    # device). np.zeros_like would turn it into a NumPy array and break torch.cat.
    edge_repr = torch.zeros_like(edge_repr)

data["resi"].node_attr = node_repr
src, dst = data["resi"].edge_index[0], data["resi"].edge_index[1]
# Per-edge feature: the (src, dst) embedding concatenated with the (dst, src)
# embedding along the last dimension.
data["resi"].edge_attr_ = torch.cat(
    [edge_repr[src, dst], edge_repr[dst, src]], -1
)

In [15]:
# Build two schedules over the same SDE: `sched` for sampling (step=inf_step,
# alpha/beta taken from args) and `sched_full` for the ELBO pass
# (step=elbo_step, alpha=0, beta=1). All other settings are shared.
sde = data.sde

schedule_cls = {
    "entropy": schedule.EntropySchedule,
    "rate": schedule.RateSchedule,
}[args.inf_type]

shared_kwargs = dict(
    Hf=args.train_Hf,
    rmsd_max=0,
    cutoff=args.train_cutoff,
    kmin=args.train_kmin,
    tmin=args.train_tmin,
)

sched = schedule_cls(
    sde,
    step=args.inf_step,
    alpha=args.alpha,
    beta=args.beta,
    **shared_kwargs,
)
sched_full = gsched = schedule_cls(
    sde,
    step=args.elbo_step,
    alpha=0,
    beta=1,
    **shared_kwargs,
)

In [16]:
# Wrap the model and graph data into the score function used by the sampler;
# key="resi" names the node type populated on `data` above.
score_fn = get_score_fn(args, ef_model, data, key="resi", device="cuda:0")

In [17]:
# PDB writer for this sequence; it is handed to reverse_sample() and later
# written to disk.
pdb = PDBFile(sequence)

In [18]:
# Run the reverse process with the sampling schedule `sched` and store the
# result on data.Y. Y=None means no starting value is supplied by this call.
# NOTE(review): `pdb` is presumably filled with intermediate frames during
# sampling — confirm in sampling.reverse_sample.
data.Y = reverse_sample(
    args,
    score_fn,
    sde,
    sched,
    device="cuda:0",
    Y=None,
    pdb=pdb,
    tqdm_=True,
    ode=args.ode,  # False with the default Args
)

100%|███████████████████████████████████████████████████████████| 144/144 [00:08<00:00, 16.34it/s]


In [19]:
# Evaluate logp of the sampled result on the `sched_full` schedule when ELBO
# evaluation is enabled; otherwise record NaN as a placeholder.
if args.elbo:
    data.elbo_Y = logp(
        data.Y, score_fn, sde, sched_full, device="cuda:0", tqdm_=True
    )
else:
    data.elbo_Y = np.nan

100%|███████████████████████████████████████████████████████████| 144/144 [00:02<00:00, 56.17it/s]


In [20]:
# Keep the PDBFile that was passed to reverse_sample() alongside the sample.
data.pdb = pdb

In [25]:
# Write the PDB files first, then read the sample back, in a single cell.
# The read used to sit in a cell placed before the write, so a fresh
# Restart & Run All hit a missing (or stale) "added.pdb".
added_pdb_path = "./added.pdb"

data.pdb.write("./base.pdb", reverse=True)
data.pdb.clear().add(data.Y).write(added_pdb_path)

# Load the file just written as text for the viewer cell below.
with open(added_pdb_path, mode="r") as f:
    pdb_str = f.read()

In [26]:
# Render the PDB text read from "added.pdb" in an NGL widget; the bare `view`
# expression on the last line displays it as the cell output.
view = nv.show_text(pdb_str)
view

NGLWidget()