# Backmap using deployed model

First, map the A2A GPCR atomistic structure to CG using Martini3 CG mapping

```bash
cgmap map -m martini3 -i data/A2A/md/a2a.pdb 
```

Then, let's backmap using a pre-trained deployed model that backmaps Martini3

```bash
herobm-backmap -i data/A2A/md/a2a.CG.pdb -mo deployed/martini3/protein.Sep.2025.pt -d cuda -o data/A2A/backmapped
```

# Train a new model on custom CG mapping

### Step 1: Build Dataset #

Build the training dataset in npz format, using an MD trajectory as the input source.

In [None]:
def _build_dataset_command(script_name, mapping, input_path, traj_path, selection, output_dir, traj_slice):
    """Return the argv list used to launch the dataset-building script.

    Args:
        script_name: Path of the Python script to run.
        mapping: Value passed to ``-m`` (mapping name or custom mapping path).
        input_path: Value passed to ``-i``.
        traj_path: Value passed to ``-t``.
        selection: Value passed to ``-s``.
        output_dir: Value passed to ``-o``.
        traj_slice: Value passed to ``-ts``; the flag is left out entirely
            when this is empty or whitespace-only.

    Returns:
        A list of strings suitable for ``subprocess.Popen``.
    """
    command = [
        "python", script_name,
        "-m", mapping,
        "-i", input_path,
        "-t", traj_path,
        "-s", selection,
        "-o", output_dir,
    ]
    # Popen rejects None entries in argv, so an empty slice must drop the
    # "-ts" flag altogether rather than pass a None placeholder.
    traj_slice = (traj_slice or "").strip()
    if traj_slice:
        command += ["-ts", traj_slice]
    return command


def interactive_build_dataset():
    """Display a widget form for building the dataset and run the build script on click.

    The form collects the CG mapping, input structure, input trajectory, atom
    selection, output folder and trajectory slice. Clicking "Build Dataset"
    launches ``../herobm/scripts/build_dataset.py`` in a subprocess and streams
    its combined stdout/stderr into the output area below the form.
    """

    import subprocess
    import ipywidgets as widgets
    from IPython.display import display

    # Step 1: Create input widgets
    mapping_dropdown_w = widgets.Dropdown(
        options=['martini3', 'martini3.membrane', 'ca', 'custom'],  # List of options for the dropdown
        value='martini3',  # Default value
        description='CG Mapping',
        disabled=False,
    )

    # Text widget for the custom mapping path; it is enabled but hidden until
    # "custom" is picked in the dropdown.
    custom_mapping_w = widgets.Text(
        value='',
        placeholder='Use absolute path to your mapping folder',
        description='Custom CG Mapping',
        disabled=False,
        layout=widgets.Layout(display='none'),  # Initially hidden
    )

    input_w = widgets.Text(
        value='../data/tutorial/A2A/md/a2a.pdb',
        placeholder='PATH/TO/INPUT/FILE',
        description='Input file/folder',
        disabled=False
    )

    inputtraj_w = widgets.Text(
        value='../data/tutorial/A2A/md/a2a.xtc',
        placeholder='PATH/TO/INPUT/TRAJ',
        description='Input trajectory file/folder',
        disabled=False
    )

    selection_w = widgets.Text(
        value='protein',
        placeholder='selection',
        description='Atom selection',
        disabled=False
    )

    output_w = widgets.Text(
        value='../data/tutorial/A2A/npz/',
        placeholder='PATH/TO/OUTPUT/FILES/FOLDER',
        description='Output folder',
        disabled=False
    )

    trajslice_w = widgets.Text(
        value=':400',
        placeholder='E.g. 100:300:2',
        description='Traj Slice',
        disabled=False
    )

    # Button to trigger the script execution
    run_button = widgets.Button(description="Build Dataset")

    # Output area to display the results
    output_area = widgets.Output()

    def on_mapping_change(change):
        """Show the custom mapping field only while 'custom' is selected."""
        if change['new'] == 'custom':
            custom_mapping_w.layout.display = 'block'  # Show the custom input
        else:
            custom_mapping_w.layout.display = 'none'  # Hide the custom input

    # Attach the function to handle changes in the dropdown
    mapping_dropdown_w.observe(on_mapping_change, names='value')

    def run_script(button):
        """Run the build script with the current form values and stream its output."""
        script_name = "../herobm/scripts/build_dataset.py"

        # Clear previous output
        output_area.clear_output()

        # The custom path replaces the dropdown value when "custom" is selected
        if mapping_dropdown_w.value == 'custom':
            mapping_value = custom_mapping_w.value
        else:
            mapping_value = mapping_dropdown_w.value

        command = _build_dataset_command(
            script_name,
            mapping_value,
            input_w.value,
            inputtraj_w.value,
            selection_w.value,
            output_w.value,
            trajslice_w.value,
        )

        # Open the external script using Popen to stream stdout in real-time
        try:
            with subprocess.Popen(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # Merge stderr into the streamed output
                text=True,  # To capture text instead of bytes
                bufsize=1  # Line-buffered output
            ) as proc:
                # Read stdout line by line
                for line in proc.stdout:
                    with output_area:
                        print(line, end='')  # Print each line in the output area

        except Exception as e:
            with output_area:
                print(f"An error occurred: {e}")

    # Link the button click event to the function
    run_button.on_click(run_script)

    # Step 2: Display the widgets
    display(mapping_dropdown_w, custom_mapping_w, input_w, inputtraj_w, selection_w, output_w, trajslice_w, run_button, output_area)

# Render the dataset-building form as the output of this cell.
interactive_build_dataset()

# Train Model #

To train the backmapping model, you need to provide a configuration file in YAML format. This file defines the dataset, the model to be used, and various hyperparameters. It also specifies the complete training setup, including the optimizer, learning rate, scheduler, loss function, metrics, and more.

Make sure to update the config file with the recommended settings from the previous step, where the training dataset was created.

Although this configuration file contains extensive details, for this tutorial, we will use a predefined one that already includes all the necessary information, including the dataset configuration.

In [None]:
def _list_config_files(config_root):
    """Return the paths of every file found recursively under ``config_root``.

    Returns an empty list when ``config_root`` does not exist or holds no files.
    """
    import os

    found = []
    for root, _, filenames in os.walk(config_root):
        found.extend(os.path.join(root, filename) for filename in filenames)
    return found


def interactive_train():
    """Display a widget form for picking a training config and device, and run training on click.

    The config dropdown lists every file under the ``config`` folder. Clicking
    "Run Training" launches ``geqtrain-train <config> -d <device>`` in a
    subprocess and streams its combined stdout/stderr into the output area.
    """

    import torch
    import subprocess
    import ipywidgets as widgets
    from IPython.display import display

    # Step 1: Create input widgets
    config_dropdown_w = widgets.Dropdown(
        options=_list_config_files('config'),
        description='Training config file in yaml format',
        disabled=False,
    )

    # Dynamically check available devices (CPU and every visible GPU)
    device_options = ['cpu']  # Always include 'cpu'
    if torch.cuda.is_available():
        # Add each GPU as 'cuda:0', 'cuda:1', etc.
        device_options.extend(f'cuda:{i}' for i in range(torch.cuda.device_count()))

    device_dropdown_w = widgets.Dropdown(
        options=device_options,
        value='cpu',  # Default value
        description='Device',
        disabled=False,
    )

    # Button to trigger the script execution
    run_button = widgets.Button(description="Run Training")

    # Output area to display the results
    output_area = widgets.Output()

    def run_script(button):
        """Run the training command with the selected config and device."""
        script_name = "geqtrain-train"

        # Clear previous output
        output_area.clear_output()

        # With no files under 'config' the dropdown has no value; report that
        # directly instead of letting Popen fail on a None argument.
        config_path = config_dropdown_w.value
        if not config_path:
            with output_area:
                print("An error occurred: no training config file selected (none found under 'config').")
            return

        # Open the external script using Popen to stream stdout in real-time
        try:
            with subprocess.Popen(
                [
                    script_name,
                    config_path,
                    "-d", device_dropdown_w.value,
                ],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,  # Merge stderr into the streamed output
                text=True,  # To capture text instead of bytes
                bufsize=1  # Line-buffered output
            ) as proc:
                # Read stdout line by line
                for line in proc.stdout:
                    with output_area:
                        print(line, end='')  # Print each line in the output area

        except Exception as e:
            with output_area:
                print(f"An error occurred: {e}")

    # Link the button click event to the function
    run_button.on_click(run_script)

    # Step 2: Display the widgets
    display(config_dropdown_w, device_dropdown_w, run_button, output_area)

# Render the training form as the output of this cell.
interactive_train()

# Run Backmapping #

In [None]:
def _find_model_files(search_root):
    """Return paths of model checkpoints (``.pt`` / ``.pth``) under ``search_root``.

    Folders whose path contains ``processed_datasets`` are skipped. The match
    is on the real file extension, so names that merely end in the letters
    "pt" (with no dot) are not listed.
    """
    import os

    model_suffixes = ('.pt', '.pth')
    found = []
    for root, _, filenames in os.walk(search_root):
        if 'processed_datasets' in root:
            continue
        found.extend(
            os.path.join(root, filename)
            for filename in filenames
            if filename.endswith(model_suffixes)
        )
    return found


def interactive_run_backmapping():
    """Display a widget form for backmapping and run inference on click.

    The form collects the model file, device, CG mapping, input structure and
    trajectory, atom selection, trajectory slice, output folder and the
    maximum number of atoms per chunk. Clicking "Run Inference" passes these
    values to ``run_inference.run_backmapping``; anything printed by that call,
    and any error it raises, is shown in the output area below the form.
    """

    import torch
    import ipywidgets as widgets
    from IPython.display import display

    # Step 1: Select the model file among the checkpoints found under the working directory
    config_dropdown_w = widgets.Dropdown(
        options=_find_model_files('./'),
        description='model to use for backmapping',
        disabled=False,
    )

    # Dynamically check available devices (CPU and every visible GPU)
    device_options = ['cpu']  # Always include 'cpu'
    if torch.cuda.is_available():
        # Add each GPU as 'cuda:0', 'cuda:1', etc.
        device_options.extend(f'cuda:{i}' for i in range(torch.cuda.device_count()))

    device_dropdown_w = widgets.Dropdown(
        options=device_options,
        value='cpu',  # Default value
        description='Device',
        disabled=False,
    )

    mapping_dropdown_w = widgets.Dropdown(
        options=['martini3', 'martini3.membrane', 'ca', 'custom'],  # List of options for the dropdown
        value='martini3',  # Default value
        description='CG Mapping',
        disabled=False,
    )

    # Text widget for the custom mapping path; it is enabled but hidden until
    # "custom" is picked in the dropdown.
    custom_mapping_w = widgets.Text(
        value='',
        placeholder='Use absolute path to your mapping folder',
        description='Custom CG Mapping',
        disabled=False,
        layout=widgets.Layout(display='none'),  # Initially hidden
    )

    input_w = widgets.Text(
        value='data/tutorial/A2A/md/a2a.pdb',
        placeholder='PATH/TO/INPUT/FILE',
        description='Input file/folder',
        disabled=False
    )

    inputtraj_w = widgets.Text(
        value='data/tutorial/A2A/md/a2a.xtc',
        placeholder='PATH/TO/INPUT/TRAJ',
        description='Input trajectory file/folder',
        disabled=False
    )

    isatomistic_w = widgets.Checkbox(
        value=True,  # Default value (checked, so True)
        description='Input is atomistic',  # Label for the checkbox
        disabled=False  # Whether the checkbox is interactive or not
    )

    selection_w = widgets.Text(
        value='protein',
        placeholder='selection',
        description='Atom selection',
        disabled=False
    )

    trajslice_w = widgets.Text(
        value='900:1000:10',
        placeholder='E.g. 100:300:2',
        description='Traj Slice',
        disabled=False
    )

    output_w = widgets.Text(
        value='data/tutorial/A2A/backmapped/',
        placeholder='Leave empty to save in same folder as input',
        description='Output folder',
        disabled=False
    )

    batch_max_atoms_w = widgets.Text(
        value='10000',
        placeholder='E.g. 10000',
        description='Max atoms per chunk',
        disabled=False
    )

    # Button to trigger the script execution
    run_button = widgets.Button(description="Run Inference")

    # Output area to display the results
    output_area = widgets.Output()

    def on_mapping_change(change):
        """Show the custom mapping field only while 'custom' is selected."""
        if change['new'] == 'custom':
            custom_mapping_w.layout.display = 'block'  # Show the custom input
        else:
            custom_mapping_w.layout.display = 'none'  # Hide the custom input

    # Attach the function to handle changes in the dropdown
    mapping_dropdown_w.observe(on_mapping_change, names='value')

    def run_inference(button):
        """Collect the form values and run backmapping, reporting into the output area."""

        # Clear previous output
        output_area.clear_output()

        # Everything below runs inside the output area so that prints and
        # error messages produced by the click handler are visible to the user.
        with output_area:
            if not config_dropdown_w.value:
                print("An error occurred: no model file selected (no .pt/.pth file found).")
                return

            try:
                batch_max_atoms = int(batch_max_atoms_w.value)
            except ValueError:
                print(f"An error occurred: 'Max atoms per chunk' must be an integer, got {batch_max_atoms_w.value!r}")
                return

            args_dict = {
                "mapping": custom_mapping_w.value if mapping_dropdown_w.value == 'custom' else mapping_dropdown_w.value,
                "input": input_w.value,
                "inputtraj": inputtraj_w.value if len(inputtraj_w.value) > 0 else None,
                "isatomistic": isatomistic_w.value,
                "selection": selection_w.value,
                "trajslice": trajslice_w.value,
                "model": config_dropdown_w.value,
                "output": output_w.value,
                "device": device_dropdown_w.value,
                "batch_max_atoms": batch_max_atoms,
                "noinvariants": True,
            }

            try:
                from run_inference import run_backmapping
                # NOTE(review): the previous call also passed
                # bead_stats=args.bead_stats and tolerance=args.tolerance, but no
                # `args` object exists in this scope, so every click raised
                # NameError. The keywords are omitted here on the assumption that
                # run_backmapping has defaults for them -- confirm against its
                # signature, or add widgets for these two values.
                run_backmapping(args_dict)
            except Exception as e:
                print(f"An error occurred: {e}")

    # Link the button click event to the function
    run_button.on_click(run_inference)

    # Step 2: Display the widgets (including the slice and chunk-size fields,
    # which are read by run_inference and must therefore be editable)
    display(
        config_dropdown_w,
        device_dropdown_w,
        mapping_dropdown_w,
        custom_mapping_w,
        input_w,
        inputtraj_w,
        isatomistic_w,
        selection_w,
        trajslice_w,
        output_w,
        batch_max_atoms_w,
        run_button,
        output_area,
    )

# Render the backmapping form as the output of this cell.
interactive_run_backmapping()

# Backmapping ligands using custom CG mapping #

Try re-running this notebook to train a model for backmapping the NECA molecule.

Use the following configurations:

1) Build dataset:
    - CG mapping: "custom"
    - Custom CG mapping: "/absolute/path/to/HEroBM/mappings/neca"
    - input: "data/tutorial/NECA/md/neca.gro"
    - inputtraj: "data/tutorial/NECA/md/neca.xtc"
    - selection: "all"
    - output: "data/tutorial/A2A/npz/"
    - trajslice: ":900"

2) Train:
    - config: "config/neca.train.yaml"

3) Backmap:
    - model: "results/neca/NECA.martini3/best_model.pth"
    - CG mapping: "custom"
    - Custom CG mapping: "/absolute/path/to/HEroBM/mappings/neca"
    - ...

    