In [0]:
%pip install ../proteinmpnn
# Some extras are specified in a pip requirements file, which lets us supply the
# extra index URLs needed for CUDA-specific builds of certain packages.
%pip install -r ../envs/requirements.txt
# Restart the Python process so the packages installed above are picked up.
dbutils.library.restartPython()

In [0]:
# As our pyproject.toml is not at the root of a git folder, we additionally
# need to point to the repo root directly via sys.path (see the disabled lines below).
# TODO (note to self): why is this not needed when the toml is at the git project root?
# Note: for model serving, dependencies and internal module installs can still be handled by the pyproject.toml.
# import sys
# sys.path.append("../")
# import proteinmpnn
# proteinmpnn.__file__

In [0]:
from proteinmpnn.run import main, get_argparser
from proteinmpnn.parse_multiple_chains import main as pdb_main
from proteinmpnn.parse_multiple_chains import get_argparser as pdb_get_argparser
import tempfile

from typing import Optional,List

import mlflow
from mlflow.types.schema import ColSpec, Schema
# Use the "databricks-uc" registry URI so that models registered below go to Unity Catalog.
mlflow.set_registry_uri("databricks-uc")
        

In [0]:
class ProteinMPNN(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that designs sequences for one PDB structure.

    ``predict`` takes the text of a single PDB file, converts it to the JSONL
    input format with ``proteinmpnn.parse_multiple_chains``, runs
    ``proteinmpnn.run`` on it and returns the sequences read back from the
    FASTA file that run writes.

    All helper imports are done inside the methods rather than at notebook
    level (presumably so the logged model does not depend on notebook
    globals when it is loaded for serving -- TODO confirm).
    """

    # Stem shared by the temporary PDB file we write and the FASTA file we read
    # back (``seqs/<stem>.fa``); both must stay in sync.
    _PDB_STEM = "my_pdb"

    # Option names that ``predict`` accepts through ``params``.
    _RUN_OPTION_KEYS = ("num_seq_per_target", "sampling_temp", "batch_size")

    def load_context(self, context):
        """Remember where the model weights artifact lives.

        Args:
            context: MLflow model context; ``context.artifacts['model_dir']``
                is passed to ProteinMPNN as ``--path_to_model_weights``.
        """
        self.model_dir = context.artifacts['model_dir']

    def _prepare_pdb_input(self, pdb_str: str, outdir: str):
        """Write ``pdb_str`` to a temporary PDB file and parse it to JSONL.

        Args:
            pdb_str: Full text of a PDB file.
            outdir: Existing directory that receives ``inputs.jsonl``.

        Returns:
            None. The result is the ``inputs.jsonl`` file inside ``outdir``.
        """
        import os
        import tempfile
        from proteinmpnn.parse_multiple_chains import main as pdb_main
        from proteinmpnn.parse_multiple_chains import get_argparser as pdb_get_argparser

        parser = pdb_get_argparser()

        # The parser works on a folder of PDB files, so the single structure
        # is placed alone in its own temporary folder.
        with tempfile.TemporaryDirectory() as temp_dir:
            pdb_path = os.path.join(temp_dir, self._PDB_STEM + ".pdb")
            with open(pdb_path, "w") as f:
                f.write(pdb_str)

            # '--ca_only' is deliberately not passed.
            arg_list = [
                '--input_path', temp_dir,
                '--output_path', os.path.join(outdir, 'inputs.jsonl'),
            ]
            pdb_main(parser.parse_args(arg_list))
        return None

    def _run_proteinmpnn(self, input_path, output_dir,
                         num_seq_per_target=3, sampling_temp="0.1", batch_size=1):
        """Run ProteinMPNN on a parsed JSONL file.

        Args:
            input_path: Path of the JSONL file produced by ``_prepare_pdb_input``.
            output_dir: Directory passed as ``--out_folder``.
            num_seq_per_target: Value for ``--num_seq_per_target`` (default 3).
            sampling_temp: Value for ``--sampling_temp`` (default "0.1").
            batch_size: Value for ``--batch_size`` (default 1).

        Returns:
            None. Results are written below ``output_dir``.
        """
        from proteinmpnn.run import main, get_argparser

        # argparse expects strings, so every option value is stringified.
        # '--ca_only' is deliberately not passed.
        arg_list = [
            '--suppress_print', "1",
            '--jsonl_path', input_path,
            '--out_folder', output_dir,
            '--num_seq_per_target', str(num_seq_per_target),
            '--sampling_temp', str(sampling_temp),
            '--batch_size', str(batch_size),
            '--path_to_model_weights', self.model_dir,
        ]
        main(get_argparser().parse_args(arg_list))
        return None

    @staticmethod
    def _read_designed_sequences(fasta_path: str) -> List[str]:
        """Return the sequence of every FASTA record after the first one.

        Records are delimited by their ``>`` header lines instead of by fixed
        line positions, so blank lines or sequences wrapped over several lines
        do not shift the result. The first record is skipped (presumably it is
        the native sequence of the input structure -- TODO confirm against
        ``proteinmpnn.run``).
        """
        records = []
        with open(fasta_path, 'r') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue
                if line.startswith('>'):
                    records.append([])
                elif records:
                    records[-1].append(line)
        return [''.join(chunks) for chunks in records[1:]]

    def predict(self, context, inputs: List[str], params=None) -> List[str]:
        """Design sequences for exactly one PDB structure.

        Args:
            context: MLflow model context (unused here; the weights location is
                taken from ``load_context``).
            inputs: A one-element list holding the text of a PDB file.
            params: Optional mapping that may override ``num_seq_per_target``,
                ``sampling_temp`` and ``batch_size``. Missing or ``None``
                entries fall back to the defaults of ``_run_proteinmpnn``;
                other keys are ignored.

        Returns:
            The designed sequences as a list of strings.

        Raises:
            ValueError: If ``inputs`` does not contain exactly one element.
            TypeError: If that element is not a string.
        """
        import os
        import tempfile

        if len(inputs) != 1:
            raise ValueError("Expected exactly one input")
        pdb_str = inputs[0]
        if not isinstance(pdb_str, str):
            # Fail early with a clear message instead of inside file.write().
            raise TypeError(
                f"Expected the PDB contents as a string, got {type(pdb_str).__name__}"
            )

        overrides = params or {}
        run_options = {
            key: overrides[key]
            for key in self._RUN_OPTION_KEYS
            if overrides.get(key) is not None
        }

        with tempfile.TemporaryDirectory() as tmpdir, \
                tempfile.TemporaryDirectory() as outdir:
            self._prepare_pdb_input(pdb_str, tmpdir)
            self._run_proteinmpnn(
                os.path.join(tmpdir, 'inputs.jsonl'), outdir, **run_options
            )
            # Read the results before the temporary directories are removed.
            fasta_path = os.path.join(outdir, 'seqs', self._PDB_STEM + '.fa')
            return self._read_designed_sequences(fasta_path)




In [0]:
# Smoke-test the wrapper in the notebook before logging it: read an example
# structure, build a model context by hand and call predict() directly.
with open('../example_data/inputs/5yd3.pdb', 'r') as pdb_file:
    in_pdb_str = pdb_file.read()

artifacts = {
    "model_dir": "/Volumes/protein_folding/proteinmpnn/model_weights/vanilla_model_weights/",
}
context = mlflow.pyfunc.PythonModelContext(artifacts=artifacts, model_config={})

model = ProteinMPNN()
model.load_context(context)

# predict() expects a list containing exactly one PDB string.
seqs = model.predict(context, [in_pdb_str])

In [0]:
# Display the sequences returned by the in-notebook predict() call above.
seqs

In [0]:
%sh
# Stage a copy of our code base on the "local" disk so it can be registered with
# the model as an artifact and later installed on the server for model serving.
# Abort on the first failing command so a partial copy is never logged.
set -e
# Start from an empty staging directory so files left over from an earlier run
# (e.g. sources that have since been deleted) are not shipped with the model.
rm -rf /local_disk0/proteinmpnn
mkdir -p /local_disk0/proteinmpnn
cp -r ../proteinmpnn/src /local_disk0/proteinmpnn
cp ../proteinmpnn/pyproject.toml /local_disk0/proteinmpnn

In [0]:
%sh
# Sanity check: list the staged copy (top level, then the package sources).
ls /local_disk0/proteinmpnn
ls /local_disk0/proteinmpnn/src/proteinmpnn

In [0]:
# Input and output are each a single unnamed string column; no inference
# params are declared in the signature.
input_schema = Schema([ColSpec(type="string")])
output_schema = Schema([ColSpec(type="string")])
signature = mlflow.models.signature.ModelSignature(
    inputs=input_schema,
    outputs=output_schema,
    params=None,
)

# Artifacts logged with the model: the weights directory read in load_context()
# and the staged copy of the code base (presumably installed through the conda
# env when the model is served -- TODO confirm).
logged_artifacts = {
    "model_dir": "/Volumes/protein_folding/proteinmpnn/model_weights/vanilla_model_weights/",
    'repo_path': '/local_disk0/proteinmpnn',
}

# Log the wrapper and register it under the three-level model name in one step.
with mlflow.start_run(run_name='protein_mpnn'):
    model_info = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=ProteinMPNN(),
        artifacts=logged_artifacts,
        input_example=[in_pdb_str],
        signature=signature,
        conda_env='../envs/conda_env.yaml',
        registered_model_name="protein_folding.proteinmpnn.proteinmpnn",
    )