In [1]:
import os
import sys

import torch
import torch.nn.functional as F

import numpy as np

import matplotlib
import matplotlib.pyplot as plt    
import matplotlib.animation as anim
import matplotlib.colors as colors
import matplotlib.patches as patches

from translator import *
from olgdesign import *
from ofvalidation import *

# Limit intra-op CPU parallelism for torch.
torch.set_num_threads(2)

# GPU selection: set gpu_id = -1 to force CPU. CUDA_VISIBLE_DEVICES exposes only
# the chosen physical GPU, which is why it is then addressed as cuda:0. The
# variable must be set before CUDA is initialised in this process.
gpu_id = 0
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
# Fall back to CPU when no CUDA device is usable, instead of failing later on
# the first tensor transfer.
if gpu_id >= 0 and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Seed the torch and NumPy RNGs so stochastic design runs are repeatable
# (as far as these RNGs are concerned).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Token alphabets; list order is the channel index of each symbol in the score
# tensors decoded later (e.g. 'M' is index 12).
nucleotides = ['A', 'T', 'G', 'C']
amino_acids = list("ARNDCQEGHILKMFPSTWYV*")

In [2]:
# Load the RoseTTAFold network. PROJECT_DIR may be overridden with the
# OLGDESIGN_DIR environment variable so the notebook is not tied to one machine.
PROJECT_DIR = os.environ.get("OLGDESIGN_DIR", "/home/ubuntu/projects/olgdesign")
os.chdir(PROJECT_DIR)
sys.path.insert(0, "./util")
include_dir = './'
network_name = 'rf_Nov05_2021'
weights_dir = "./weights"
rosetta, rosetta_params = load_model(include_dir, network_name, weights_dir, device)

# Load the translator network. It is used directly as a network (no
# load_state_dict), so this file is presumably a pickled nn.Module: only load
# trusted files. map_location lets this work on a CPU-only machine or on a
# different GPU than the one the file was saved from.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which rejects full
# pickled modules; pass weights_only=False there if the file is trusted.
translator_file = "./weights/translator/translator_cnn_512ch.pth"
translator = torch.load(translator_file, map_location=device)

# Background for the KL-divergence loss (length 100). Building it is expensive,
# so it is cached on disk and only regenerated when the cache file is missing.
bkg_file = "./bkg/L100.pth"
if os.path.exists(bkg_file):
    bkg_100 = torch.load(bkg_file, map_location=device)
else:
    bkg_100 = mk_bkg(rosetta, 100, device, n_runs=1000)
    os.makedirs(os.path.dirname(bkg_file), exist_ok=True)
    torch.save(bkg_100, bkg_file)

In [3]:
# Hallucinate a pair of fully overlapping proteins, each 100 residues long.
f1_length = 100
f2_length = 100
f1_frame = 0 #Frame 0
f2_frame = 5 #Frame -2
offset = 1
total_length = f1_length + offset + 2

MET_INDEX = 12  # position of 'M' in the alphabet "ARNDCQEGHILKMFPSTWYV*"


def make_start_force(length, aa_index=MET_INDEX, n_tokens=21):
    """Return a (1, n_tokens, length) tensor that is 1.0 only at (aa_index, position 0).

    Used as the force mask that fixes the first amino acid of a frame.
    """
    force = torch.zeros(1, n_tokens, length)
    force[:, aa_index, 0] = 1.0
    return force


def make_pair_mask(length, start, end):
    """Return a (1, length, length) 0/1 mask selecting pairs (i, j), i != j, with start <= i, j < end."""
    mask = torch.zeros(length, length)
    mask[start:end, start:end] = 1.0
    mask.fill_diagonal_(0.0)  # never score a residue against itself
    return mask.unsqueeze(0)


# Force masks that fix the first amino acid of each frame to Met.
f1_force = make_start_force(f1_length)
f2_force = make_start_force(f2_length)

# Masks for the KL-divergence / cross-entropy loss. Residues before `wstart`
# and from `wend` onwards are excluded to leave room for the fixed start
# residues and stop codons.
wstart = 5
wend = 95
mask_f1 = make_pair_mask(f1_length, wstart, wend)
mask_f2 = make_pair_mask(f2_length, wstart, wend)

In [None]:
# Main design loop. The hyperparameter names suggest a gradient-descent stage
# (`*_gd`) and a simulated-annealing stage (`*_sa`); their exact meaning is
# defined by run_design.
design_hparams = dict(
    lr=0.05, betas=(0.5, 0.9),
    weight_decay=0.0, accumulator=1e-6, eps=1e-3,
    lookahead_k=10, lookahead_alpha=0.5,
    n_step_gd=500, n_step_gd_n=0.4, n_max_h=5,
    early_gd_stop=0.5, grad_clip_p=0.1,
    n_step_sa=500,
    alpha_gd=0.1, alpha_sa=0.3,
    weight_kl=3.0,
    weight_ce=0.0,
    weight_lddt=0.0,
    weight_stop=1.0,
    weight_force=2.0,
    weight_last=2.0,
    weight_rog=0.1,
    rog_thres=16.0,
    max_mut=1, tau0=5e-3, anneal_rate=1e-5, min_temp=1e-5,
    print_loss=True,
)

result = run_design(
    device, rosetta, translator,
    total_length, f1_frame, f2_frame, offset,
    f1_force, f2_force,
    bkg_100, bkg_100,
    mask_f1, mask_f2,
    # NOTE(review): the meaning of these two positional flags is defined by
    # run_design; consider passing them by keyword.
    True, False,
    **design_hparams,
)

In [None]:
# Render the design trajectory with anim_res and write it out as a GIF.
# NOTE(review): the positional arguments (1, 20) are interpreted by anim_res.
gif_path = "example_run.gif"
plot_res = anim_res(result, 1, 20)
plot_res.save(gif_path)

In [39]:
# Decode the best-scoring design (the step recorded last in result[1]['min_step'])
# into amino-acid strings for both frames and the underlying nucleotide string.
def decode_argmax(scores, alphabet):
    """Return the `alphabet` symbols at the argmax over dim 1 of `scores`, for batch item 0.

    `scores` is presumably (batch, len(alphabet), length) -- confirm against run_design.
    `.tolist()` converts all indices at once, avoiding a per-element device sync
    and indexing a Python list with 0-d tensors.
    """
    indices = torch.argmax(scores, 1)[0].tolist()
    return ''.join(alphabet[i] for i in indices)


best_seq = result[1]['seq'][result[1]['min_step'][-1]]
min_prot_1 = decode_argmax(best_seq['prot'][0], amino_acids)
min_prot_2 = decode_argmax(best_seq['prot'][1], amino_acids)
min_nuc = decode_argmax(best_seq['nuc'], nucleotides)

In [None]:
# Load OpenFold weights for independent structure validation of the designs.
openfold_weight_path = "./weights/openfold/"
# NOTE(review): original note reads "Max 48 recycles"; presumably 12 is the
# recycle count -- confirm against load_openfold.
of_model, of_cfg = load_openfold(openfold_weight_path, 12)

# Output directory for validation figures; defined here so the cell runs on a
# fresh kernel (it was previously referenced without being defined).
outdir = "."
os.makedirs(outdir, exist_ok=True)

# Predict each designed protein with OpenFold, without building autograd graphs.
f1_of_input, f1_fd = prepare_openfold_input(min_prot_1, of_cfg)
f2_of_input, f2_fd = prepare_openfold_input(min_prot_2, of_cfg)
with torch.no_grad():
    f1_of_out = of_model(f1_of_input)
    f2_of_out = of_model(f2_of_input)
of_fig = plot_2d(f1_of_out, f2_of_out)
of_fig.savefig(os.path.join(outdir, "of_summary.svg"))