In [0]:
%pip install ../dbboltz[gpu]
%pip install py3Dmol
dbutils.library.restartPython()

In [0]:
import mlflow
# Turn off MLflow autologging for this session; the model is logged explicitly
# with mlflow.pyfunc.log_model further down.
mlflow.autolog(disable=True)
from dbboltz.boltz import run_boltz, Boltz
# NOTE(review): `yaml` and `run_boltz` are not used in any visible cell.
import yaml

In [0]:
def get_model_config(
    jackhmmer_binary_path="/miniconda3/envs/jackhmmer_env/bin/jackhmmer",
    compute_type='gpu',
):
    """Build the model_config dict used for the local pyfunc context and for log_model.

    Args:
        jackhmmer_binary_path: Path to the jackhmmer executable. The default is the
            path previously hard-coded for this environment.
        compute_type: Value stored under 'compute_type' (default 'gpu').

    Returns:
        dict with keys 'jackhmmer_binary_path' and 'compute_type', in that order.
    """
    return {
        'jackhmmer_binary_path': jackhmmer_binary_path,
        'compute_type': compute_type,
    }

model_config = get_model_config()

In [0]:
# Build the pyfunc model and load it in-notebook, so predict() can be tried
# locally before the model is logged.
model = Boltz()
local_artifacts = {'CACHE_DIR': '/Volumes/peter_hawkins/testing/boltz'}
context = mlflow.pyfunc.PythonModelContext(
    artifacts=local_artifacts,
    model_config=model_config,
)
model.load_context(context)

In [0]:
def convert_input_to_serving_input(inputs):
    """Flatten a structured inputs dict into the single string the serving endpoint expects.

    Every entry becomes ``{type}_{comma-sep chains}:{sequence}`` and entries are
    joined with ``;``, e.g. ``protein_A,B:MKT...;rna_C:ACGU``.

    Args:
        inputs: mapping of molecule type (e.g. 'protein', 'rna') to a list of
            ``(chain_ids, sequence, ...)`` items; only the first two elements are
            used. ``chain_ids`` is either a sequence of chain IDs such as
            ``('A', 'B')`` or a single chain-ID string. Note that ``('A')`` is just
            the string ``'A'``, so a bare string is treated as ONE chain ID rather
            than being split into characters.

    Returns:
        The ``;``-separated string; ``''`` when there are no entries.

    Raises:
        ValueError: if the same type/chain combination appears more than once,
            which would otherwise silently drop one of the sequences.
    """
    entries = {}
    for mol_type, items in inputs.items():
        for item in items:
            chain_ids, sequence = item[0], item[1]
            if isinstance(chain_ids, str):
                # A lone string is one chain ID; don't let join() split it per character.
                chain_ids = (chain_ids,)
            key = mol_type + '_' + ','.join(chain_ids)
            if key in entries:
                raise ValueError(f"duplicate entry for {key!r}")
            entries[key] = sequence
    # join() only places separators between entries, so sequences are never trimmed.
    return ';'.join(key + ':' + sequence for key, sequence in entries.items())

In [0]:
# Example complex: one protein chain (A) and one RNA chain (B).
# Chain IDs are one-element tuples. ('A') without the trailing comma is just the
# string 'A', not a tuple, which is easy to misread when adding multi-chain entries.
inputs = {
    'protein': [
        (('A',), "GTGAMWLTKLVLNPASRAARRDLANPYEMHRTLSKAVSRALEEGRERLLWRLEPARGLEPPVVLVQTLTEPDWSVLDEGYAQVFPPKPFHPALKPGQRLRFRLRANPAKRLAATGKRVALKTPAEKVAWLERRLEEGGFRLLEGERGPWVQILQDTFLEVRRKKDGEEAGKLLQVQAVLFEGRLEVVDPERALATLRRGVGPGKALGLGLLSVAP"),
    ],
    'rna': [
        (('B',), "UCCCCACGCGUGUGGGGAU"),
    ],
}
# Prediction/sampling parameters.
# NOTE: these are not currently passed to model.predict().
params = {
    'msa': 'no_msa',
    'msa_depth': 20,
    'diffusion_samples': 1,
    'recycling_steps': 3,
    'sampling_steps': 200,
}

In [0]:
# The serving endpoint expects this format: a flat dict of strings, with all
# sequences packed into the single 'input' string.
# TODO: maybe standardize on this format everywhere?
# NOTE(review): msa='no_msa' is set together with use_msa_server='True'; confirm which one the model honours.
model_input = dict(
    input=convert_input_to_serving_input(inputs),
    msa='no_msa',
    use_msa_server='True',
)
print(model_input)

In [0]:
# Run a local prediction on a one-element list of inputs.
# params=params is intentionally not passed here yet.
prediction_batch = [model_input]
result = model.predict(context, prediction_batch)

In [0]:
# Alternative (disabled): merge `params` into the input dict, converting every
# value with str(), and predict on that.
# model_input_w_params = model_input.copy()
# model_input_w_params.update({k:str(v) for k,v in params.items()})
# result = model.predict(context, [model_input_w_params])

In [0]:
# Alternative (disabled): as above, but also override msa='mmseqs' and use_msa_server.
# NOTE(review): use_msa_server is the bool True here, but the string 'True' in
# model_input; confirm which type the model expects before enabling.
# model_input_w_params = model_input.copy()
# model_input_w_params.update({k:str(v) for k,v in params.items()})
# model_input_w_params.update({
#     'msa':'mmseqs',
#     'use_msa_server':True
# })
# result = model.predict(context, [model_input_w_params])

In [0]:
import py3Dmol

# Render the first predicted structure with a distinct cartoon colour per chain.
chain_colors = {'A': 'blue', 'B': 'red'}

view = py3Dmol.view(width=800, height=300)
view.addModel(result[0]['pdb'], 'pdb')
for chain_id, cartoon_color in chain_colors.items():
    view.setStyle({'chain': chain_id}, {'cartoon': {'color': cartoon_color}})
view.zoomTo()

# NOTE(review): relies on the underscore-prefixed `_make_html`, which is not part of
# py3Dmol's documented API and may change between versions.
html = view._make_html()
displayHTML(html)

### Let's log our model too

In [0]:
%sh
# Stage a copy of our code base onto local disk (/local_disk0) so it can be
# registered with the model as the 'repo_path' artifact and installed on the
# server for model serving.
set -e  # stop at the first failing command instead of logging a half-staged copy
# Start from a clean directory: `cp -r` merges into an existing directory, so on a
# re-run, files deleted from ../dbboltz would otherwise linger and get packaged.
rm -rf /local_disk0/dbboltz
mkdir -p /local_disk0/dbboltz
cp -r ../dbboltz/src /local_disk0/dbboltz
cp ../dbboltz/pyproject.toml /local_disk0/dbboltz

In [0]:
%sh
# Sanity check: show the staged package root, a separator, then the package sources.
ls /local_disk0/dbboltz
printf ' -- \n'
ls /local_disk0/dbboltz/src/dbboltz

In [0]:
# Inspect the first prediction (its 'pdb' entry is what the py3Dmol cell renders).
result[0]

In [0]:
from mlflow.types.schema import ColSpec, Schema
from mlflow.models.signature import infer_signature

# Register models in Unity Catalog.
mlflow.set_registry_uri("databricks-uc")

# Infer the signature from the same one-element input list used for the local predict().
signature = infer_signature([model_input], result)
print(signature)

UC_MODEL_NAME = "protein_folding.boltz.boltz"
logged_artifacts = {
    'CACHE_DIR': '/Volumes/peter_hawkins/testing/boltz',
    # Code base staged by the %sh cell above.
    'repo_path': '/local_disk0/dbboltz',
}

with mlflow.start_run(run_name='boltz'):
    model_info = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=Boltz(),
        artifacts=logged_artifacts,
        model_config=model_config,
        input_example=[model_input],
        signature=signature,
        conda_env='../envs/conda_env.yaml',
        registered_model_name=UC_MODEL_NAME,
    )

# NOTE:
 - I think the issue is that, although a list of `Dict[str, str]` is allowed, the signature expects every field of the dict to be present and fixed ahead of time. But what if I have `protein_C,D` in the future?
 - I could use a `List[str]` instead: `'protein_A,B:KJBSKJFBKJSB;rna_C:CAGCATAT;...'`
 - Internally, first split on `;` to get a list, then split each entry on `:` to get a dict, then continue as the code does now.


### Format would be:

`{type}_{comma-sep chains}:{sequence};{repeat}`, e.g. `protein_A,B:MKT...;rna_C:ACGU`

## Another option would be to pass inputs through params, but I don't think that plays nicely either.

The first serving attempt failed because torch was not available when building the flash-attn module.
I added it; let's see.
Note that flash-attn takes a very long time to build...
 - I could add ninja in a conda env, which might be faster, but I'm worried about the environment build timing out.

 - If this doesn't work with torch installed via pip, maybe move to conda to get something a bit more stable...

 ## I need to ensure torch is installed before flash-attn
  - I need some packages installed by conda and others by pip

Now I put the params in `model_input`...
 - but now the schema seems to expect them as a minimal (required) set of inputs?

In [0]:
# Display the signature inferred from model_input and result for the logged model.
signature

In [0]:
# If the input example only contains input:"", can I still pass other fields later?

# And apparently self.model_config is None on model serving?

## Set the model input example to one that includes every option I want exposed at runtime
  - all other options should fall back to their defaults
  - I think I should expose only:
    - msa
    - use_msa_server
    - sampling_steps 

  - This means users MUST supply those too...

# I need to pass model_config when logging the model!!!
 - I also need to **check** that
   - I cannot change those configs at run time

In [0]:
# Build the params-augmented input explicitly. The cells that used to define this
# name are commented out, so a fresh Restart & Run All previously raised NameError here.
# Values are converted with str(), mirroring those commented-out cells.
model_input_w_params = model_input.copy()
model_input_w_params.update({k: str(v) for k, v in params.items()})
model_input_w_params