## Log and register Boltz-1 model on Unity Catalog

- Note: this version of dbboltz supports jackhmmer for MSA generation when running in a notebook. When serving, generating the MSA with jackhmmer takes too long.

In [0]:
%pip install ../dbboltz[gpu]
%pip install py3Dmol
# The installs above add the local dbboltz package (with its `gpu` extra) and py3Dmol,
# which is used further down to render the predicted structure.
# NOTE(review): neither install is version-pinned — consider pinning for reproducibility.
# Restart Python so the newly installed packages are importable in the following cells.
dbutils.library.restartPython()

In [0]:
import mlflow
# Disable MLflow autologging; the model is logged explicitly with
# mlflow.pyfunc.log_model in the last cell of this notebook.
mlflow.autolog(disable=True)
# NOTE(review): run_boltz and yaml are not referenced by any cell below.
from dbboltz.boltz import run_boltz, Boltz
import yaml

In [0]:
def get_model_config(
    jackhmmer_binary_path="/miniconda3/envs/jackhmmer_env/bin/jackhmmer",
    compute_type='gpu',
):
    """Build the ``model_config`` dict passed to the Boltz pyfunc model.

    Both settings were previously hard-coded; they are now parameters whose
    defaults reproduce the original values.

    Args:
        jackhmmer_binary_path: Path to the jackhmmer executable used for MSA
            generation when running in a notebook. The default points at a
            cluster-specific install — verify it exists on your cluster.
        compute_type: Compute type string handed to the model (``'gpu'`` by default).

    Returns:
        A new dict with keys ``'jackhmmer_binary_path'`` and ``'compute_type'``.
    """
    return {
        'jackhmmer_binary_path': jackhmmer_binary_path,
        'compute_type': compute_type,
    }

model_config = get_model_config()

### Initialize the model

In [0]:
# Build the pyfunc context by hand so the model can be exercised locally before logging.
# CACHE_DIR points at the Unity Catalog volume path holding the Boltz weights.
context = mlflow.pyfunc.PythonModelContext(
    artifacts={'CACHE_DIR': '/Volumes/protein_folding/boltz/weights'},
    model_config=model_config,
)
model = Boltz()
model.load_context(context)

### Helper function: convert dictionary-style input into the single string format expected by model serving

In [0]:
def convert_input_to_serving_input(inputs):
    """Flatten a dict of molecule inputs into the single string used by model serving.

    Args:
        inputs: Mapping of molecule type (e.g. ``'protein'``, ``'rna'``) to a list of
            ``(chain_ids, sequence)`` pairs. ``chain_ids`` is an iterable of chain-ID
            strings such as ``('A',)`` or ``('A', 'B')``. A bare string is iterated
            character by character, so ``'AB'`` is treated as chains A and B.

    Returns:
        A ``;``-separated string of ``<type>_<chain,ids>:<sequence>`` entries, e.g.
        ``"protein_A:MKV;rna_B,C:ACGU"``. An empty mapping yields ``""``.

    Raises:
        ValueError: If the same molecule type and chain IDs appear more than once.
            Otherwise the later sequence would silently replace the earlier one.
    """
    entries = {}
    for mol_type, chains in inputs.items():
        for in_seqs in chains:
            chain_ids, sequence = in_seqs[0], in_seqs[1]
            key = mol_type + '_' + ','.join(chain_ids)
            if key in entries:
                raise ValueError(f"duplicate input entry for {key!r}")
            entries[key] = sequence
    # join() avoids the trailing-separator cleanup (rstrip(';') could also strip
    # ';' characters that belong to the last sequence).
    return ';'.join(key + ':' + sequence for key, sequence in entries.items())

### Set up an example input

In [0]:
# Example complex: one protein chain (A) and one RNA chain (B).
# Chain IDs are one-element tuples: a bare ('A') is just the string 'A', which
# only works by accident because single-character strings iterate to themselves.
inputs = {
    'protein': [
        (('A',), "GTGAMWLTKLVLNPASRAARRDLANPYEMHRTLSKAVSRALEEGRERLLWRLEPARGLEPPVVLVQTLTEPDWSVLDEGYAQVFPPKPFHPALKPGQRLRFRLRANPAKRLAATGKRVALKTPAEKVAWLERRLEEGGFRLLEGERGPWVQILQDTFLEVRRKKDGEEAGKLLQVQAVLFEGRLEVVDPERALATLRRGVGPGKALGLGLLSVAP"),
    ],
    'rna': [
        (('B',), "UCCCCACGCGUGUGGGGAU"),
    ],
}
# NOTE(review): `params` is not referenced by any cell below — the test cell hard-codes
# its own MSA settings. Either wire these through or remove them.
params = {
    'msa': 'no_msa',
    'msa_depth': 20,
    'diffusion_samples': 1,
    'recycling_steps': 3,
    'sampling_steps': 200,
}

#### Test the model out

In [0]:
# The serving endpoint expects the sequences flattened into a single string, so the
# local test input is built in that same format.
# TODO: consider standardizing on this input format everywhere.
# NOTE(review): msa='no_msa' is combined with use_msa_server='True' here, and `params`
# is not used — confirm which MSA setting dbboltz actually applies.
model_input = dict(
    input=convert_input_to_serving_input(inputs),
    msa='no_msa',
    use_msa_server='True',
)
print(model_input)

In [0]:
# Run a local prediction on a one-element list of requests — the same shape as the
# input_example used when logging the model. result[0] holds this request's output.
result = model.predict(context, [model_input])

#### See what the output looks like

In [0]:
import py3Dmol

# Render the predicted structure (PDB text of the first result) with one
# cartoon colour per chain.
chain_colors = {'A': 'blue', 'B': 'red'}

view = py3Dmol.view(width=800, height=300)
view.addModel(result[0]['pdb'], 'pdb')
for chain_id, colour in chain_colors.items():
    view.setStyle({'chain': chain_id}, {'cartoon': {'color': colour}})

view.zoomTo()
html = view._make_html()
displayHTML(html)

### Let's log our model too

In [0]:
%sh
# Stage a copy of the dbboltz code base on local disk so it can be logged as the
# 'repo_path' model artifact and installed on the server for model serving.
# Fail fast: without this, a failed copy would still let the model be logged with an
# incomplete repo.
set -e
# Remove any previous copy first. `cp -r` merges into an existing directory, so
# re-running this cell would otherwise keep files deleted from the source tree.
rm -rf /local_disk0/dbboltz
mkdir -p /local_disk0/dbboltz
cp -r ../dbboltz/src /local_disk0/dbboltz
cp ../dbboltz/pyproject.toml /local_disk0/dbboltz

In [0]:
%sh
# Sanity check: list the staged copy and the dbboltz package inside its src/ directory.
ls /local_disk0/dbboltz
echo " -- "
ls /local_disk0/dbboltz/src/dbboltz

In [0]:
# Inspect the raw output of the first prediction (it includes the 'pdb' entry rendered above).
result[0]

In [0]:
# NOTE(review): ColSpec and Schema are currently unused; the signature is inferred instead.
from mlflow.types.schema import ColSpec, Schema
from mlflow.models.signature import infer_signature

# Register models in Unity Catalog.
mlflow.set_registry_uri("databricks-uc")

# Infer the signature from the same request/response pair used in the local test.
signature = infer_signature([model_input], result)
print(signature)

# Artifacts shipped with the model: the Boltz weights volume and the staged copy of
# the dbboltz code base.
model_artifacts = {
    'CACHE_DIR': '/Volumes/protein_folding/boltz/weights',
    'repo_path': '/local_disk0/dbboltz',
}

with mlflow.start_run(run_name='boltz'):
    model_info = mlflow.pyfunc.log_model(
        python_model=Boltz(),
        artifact_path="model",
        artifacts=model_artifacts,
        model_config=model_config,
        signature=signature,
        input_example=[model_input],
        conda_env='../envs/conda_env.yaml',
        registered_model_name="protein_folding.boltz.boltz",
    )