## OpenFold Colab adaptation

This notebook is a modification of the OpenFold Colab notebook: https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb

In [2]:
#import sys
#sys.path.append('./openfold_model')

In [1]:
import os

import torch

# Hack: replace the constructor of OpenFold's Linear to avoid its slow layer initialization
import openfold.model.primitives

def __default_linear_init__(self, *args, **kwargs):
    return torch.nn.Linear.__init__(
      self, 
      *args[:2], 
      **{k:v for k,v in kwargs.items() if k == "bias"}
    )

# Monkey-patch: every openfold Linear constructed from here on goes through
# __default_linear_init__ instead of the class's own constructor.
openfold.model.primitives.Linear.__init__ = __default_linear_init__

from openfold import config
from openfold.data import feature_pipeline
from openfold.data import data_pipeline
from openfold.model import model
from openfold.utils.import_weights import import_jax_weights_
from openfold.utils.tensor_utils import tensor_tree_map

In [3]:
%load_ext autoreload

In [4]:
%autoreload 2

In [5]:
def _placeholder_template_feats(num_templates_, num_res_):
  return {
      'template_aatype': torch.zeros(num_templates_, num_res_, 22).long(),
      'template_all_atom_positions': torch.zeros(num_templates_, num_res_, 37, 3),
      'template_all_atom_mask': torch.zeros(num_templates_, num_res_, 37),
      'template_domain_names': torch.zeros(num_templates_),
      'template_sum_probs': torch.zeros(num_templates_, 1),
  }

In [6]:
import pickle

# Load the precomputed deletion matrices (passed to make_msa_features below).
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only point these paths at files you produced yourself.
with open("../data_pickle/deletion_matrices.pickle", 'rb') as f:
    deletion_matrices = pickle.load(f)

# Load the precomputed MSAs that go with the deletion matrices above.
with open("../data_pickle/msas.pickle", 'rb') as f:
    msas = pickle.load(f)

In [7]:
# Query sequence handed to the feature pipeline; its length defines num_res.
sequence = "MSIQHFRVALIPFFAAFCLPVFAHPETLVKVKDAEDQLGARVGYIELDLNSGKILESFRPEERFPMMSTFKVLLCGAVLSRVDAGQEQLGRRIHYSQNDLVEYSPVTEKHLTDGMTVRELCSAAITMSDNTAANLLLTTIGGPKELTAFLHNMGDHVTRLDRWEPELNEAIPNDERDTTMPAAMATTLRKLLTGELLTLASRQQLIDWMEADKVAGPLLRSALPAGWFIADKSGAGERGSRGIIAALGPDGKPSRIVVIYTTGSQATMDERNRQIAEIGASLIKHW"

In [8]:
# Model preset: passed to config.model_config and used to derive the name of
# the parameter file to load.
model_name = "model_1"

In [9]:
# Leading dimension of the placeholder template tensors built below.
num_templates = 1 # dummy number --- is ignored
# Number of residues, i.e. the length of the query sequence.
num_res = len(sequence)

In [10]:
# Assemble the raw model inputs from three sources. Later entries overwrite
# earlier ones on a key clash, in the order sequence -> MSA -> templates.
sequence_feats = data_pipeline.make_sequence_features(sequence, 'test', num_res)
msa_feats = data_pipeline.make_msa_features(msas, deletion_matrices=deletion_matrices)
template_feats = _placeholder_template_feats(num_templates, num_res)

feature_dict = {**sequence_feats, **msa_feats, **template_feats}

In [11]:
import pickle

# Save the 'msa' entry of the raw feature dict so it can be inspected outside
# this notebook.
filepath = "msa_arr.pickle"  # file name only; joined with output_dir below
output_dir = "./../data_pickle"
msa_output_path = os.path.join(output_dir, filepath)
with open(msa_output_path, 'wb') as f:
    pickle.dump(feature_dict['msa'], f, protocol=pickle.HIGHEST_PROTOCOL)

In [12]:
# Which pretrained parameters to load: "AlphaFold" or "OpenFold". Any other
# value makes the weight-loading cell raise ValueError.
weight_set = "OpenFold"

In [13]:
# Build the network from the preset configuration and put it in evaluation
# mode right away (eval() hands back the module it was called on).
cfg = config.model_config(model_name)
openfold_model = model.AlphaFold(cfg).eval()

In [14]:
# Directories searched for pretrained parameter files; both are relative to
# the notebook's working directory.
OPENFOLD_PARAMS_DIR = './openfold/resources/openfold_params'
# NOTE(review): this path has a doubled "openfold/openfold" segment, unlike the
# one above - confirm that is intended.
ALPHAFOLD_PARAMS_DIR = './openfold/openfold/resources/params'

In [15]:
# Load pretrained parameters into `openfold_model`, as selected by `weight_set`.
if weight_set == "AlphaFold":
  # JAX parameters: one .npz archive per model name.
  params_name = os.path.join(ALPHAFOLD_PARAMS_DIR, f"params_{model_name}.npz")
  import_jax_weights_(openfold_model, params_name, version=model_name)

elif weight_set == "OpenFold":
  # The checkpoint file name is derived from the last "_"-separated token of
  # `model_name`: "ptm" maps to a fixed file, anything else is used verbatim.
  model_name_spl = model_name.split("_")

  if model_name_spl[-1] == "ptm":
    of_model_name = "finetuning_ptm_2.pt"
  else:
    of_model_name = f"finetuning_{model_name_spl[-1]}.pt"

  params_name = os.path.join(OPENFOLD_PARAMS_DIR, of_model_name)

  # Load onto the CPU explicitly: the model is only moved to the GPU in a
  # later cell, and without map_location the tensors would be restored onto
  # whatever device the checkpoint was saved from.
  d = torch.load(params_name, map_location="cpu")
  openfold_model.load_state_dict(d)
else:
  raise ValueError(f"Invalid weight set: {weight_set}")

In [16]:
# Move the model to the GPU (requires an available CUDA device).
openfold_model = openfold_model.cuda()

In [17]:
# Turn the raw features into model inputs, using the data settings of the
# selected preset and the pipeline's 'predict' mode.
pipeline = feature_pipeline.FeaturePipeline(cfg.data)
 
processed_feature_dict = pipeline.process_features(
  feature_dict, mode='predict'
)

  k: torch.tensor(v) for k, v in np_example.items() if k in features


		Recycling from configs
	CALLING MSA SAMPLING
Dumping pickle at ./../data_pickle/sel_seq.pickle
Dumping pickle at ./../data_pickle/not_sel_seq.pickle


In [18]:
# List the features produced by the pipeline.
processed_feature_dict.keys()

dict_keys(['aatype', 'residue_index', 'seq_length', 'template_aatype', 'template_all_atom_positions', 'template_all_atom_mask', 'template_sum_probs', 'seq_mask', 'msa_mask', 'msa_row_mask', 'template_mask', 'template_pseudo_beta', 'template_pseudo_beta_mask', 'template_torsion_angles_sin_cos', 'template_alt_torsion_angles_sin_cos', 'template_torsion_angles_mask', 'atom14_atom_exists', 'residx_atom14_to_atom37', 'residx_atom37_to_atom14', 'atom37_atom_exists', 'extra_msa', 'extra_msa_mask', 'extra_msa_row_mask', 'bert_mask', 'true_msa', 'extra_has_deletion', 'extra_deletion_value', 'msa_feat', 'target_feat'])

In [19]:
# Sanity check: shape of the 'true_msa' feature before it is moved to the GPU.
processed_feature_dict['true_msa'].shape

torch.Size([512, 286, 1])

In [20]:
# Move every tensor in the feature dict to the GPU, where the model was placed
# earlier.
processed_feature_dict = tensor_tree_map(
    lambda t: t.cuda(), processed_feature_dict
)

In [21]:
# Same shape check again, now on the GPU copy.
processed_feature_dict['true_msa'].shape

torch.Size([512, 286, 1])

In [22]:
# Release cached, currently unused GPU memory before the forward pass.
torch.cuda.empty_cache()

# Inference only: disable autograd so no gradient bookkeeping is done.
with torch.no_grad():
  prediction_result = openfold_model(processed_feature_dict)

In [23]:
# List the outputs returned by the model.
prediction_result.keys()

dict_keys(['msa', 'pair', 'single', 'sm', 'final_atom_positions', 'final_atom_mask', 'final_affine_tensor', 'lddt_logits', 'plddt', 'distogram_logits', 'masked_msa_logits', 'experimentally_resolved_logits'])

In [None]:
# NOTE(review): this displays the whole 'msa' output tensor (the cell was never
# executed); prefer `.shape` or a small slice to keep the output readable.
prediction_result['msa']

In [24]:
# Shape of the 'msa' output.
prediction_result['msa'].shape

torch.Size([512, 286, 256])

In [25]:
import pickle

# Persist the complete prediction dict alongside the other pickled artifacts.
filepath = "prediction_result.pickle"  # file name only; joined with output_dir below
output_dir = "./../data_pickle"
# The variable name is carried over from the MSA-saving cell; here it holds the
# path of the prediction file.
msa_output_path = os.path.join(output_dir, filepath)

# Copy every tensor to host memory before pickling. Tensors pickled while on a
# CUDA device are restored onto CUDA when unpickled, so the file could not be
# read back with pickle.load on a machine without a GPU.
prediction_result_cpu = tensor_tree_map(
    lambda t: t.detach().cpu(), prediction_result
)

with open(msa_output_path, 'wb') as f:
    pickle.dump(prediction_result_cpu, f, protocol=pickle.HIGHEST_PROTOCOL)