In [1]:
import os, sys
from IPython.display import HTML, display

from main import main

In [2]:
# Registry of selectable models. Each entry maps a model name to the batch
# sizes used for training ('batch_size') and validation ('val_batch_size').
# In this table both sizes are identical for every model, so each row is
# expanded from a single number. Insertion order is the order in which the
# model names are listed here.
models = {
    name: {'batch_size': size, 'val_batch_size': size}
    for name, size in (
        ('deeplabv3plus_mobilenet', 32),
        ('deeplabv3plus_xception', 16),
        ('deeplabv3plus_resnet101', 16),
        ('segmenter_vit_base', 1),
        ('mae_segmenter_vit_base', 2),
        ('mae_segmenter_vit_huge', 1),
    )
}

In [3]:
def train(model, use_ckpt=False):
    """Train ``model`` on Cityscapes by calling ``main()`` with a synthesized ``sys.argv``.

    The command line is assembled as a string, printed for the reader, then
    split on whitespace and installed as ``sys.argv`` (minus the leading
    ``python``) before ``main()`` is invoked.
    NOTE(review): presumably ``main()`` parses ``sys.argv`` - confirm in main.py.

    Args:
        model: Key into the module-level ``models`` table; selects the batch
            sizes placed on the command line.
        use_ckpt: Whether to resume from ``checkpoints/latest_<model>_cityscapes_os16.pth``
            when that file exists. Accepts a bool, or the text ``"true"`` /
            ``"false"`` as stored by the notebook's dropdown callbacks.

    Raises:
        KeyError: If ``model`` is not a key of ``models``.

    Side effects:
        Overwrites ``sys.argv`` and leaves it overwritten after returning.
    """
    config = models[model]

    # The dropdown callbacks store this flag as text. Any non-empty string is
    # truthy in Python, so "false" would otherwise switch resuming ON.
    if isinstance(use_ckpt, str):
        use_ckpt = use_ckpt.strip().lower() == "true"

    ckpt = f"checkpoints/latest_{model}_cityscapes_os16.pth"
    command = (
        f"python main.py --model {model} --dataset cityscapes --gpu_id 0 "
        f"--total_epochs 100 --base_lr 0.01 --loss_type focal_loss --crop_size 768 "
        f"--batch_size {config['batch_size']} --val_batch_size {config['val_batch_size']} "
        f"--use_amp --output_stride 16 --data_root ./datasets/data/cityscapes"
    )
    if use_ckpt and os.path.exists(ckpt):
        # The leading space is required: without it "--ckpt" is glued onto the
        # --data_root value and the whitespace split below corrupts both.
        command += f" --ckpt {ckpt} --continue_training"

    # Drop the leading "python" so that argv[0] is the script name.
    sys.argv = command.split()[1:]
    print("--------------------------------------------------")
    print(f"Train model {model} with command:")
    print(command)
    print("--------------------------------------------------")
    main()
    
def test(model):
    config = models[model]
    ckpt = f"checkpoints/best_{model}_cityscapes_os16.pth"
    if not os.path.exists(ckpt):
        print(f"Test model failed, cannot find trained weights {ckpt}")
    else:
        command = f"python main.py --model {model} --dataset cityscapes --gpu_id 0 --val_batch_size {config['val_batch_size']} --use_amp --output_stride 16 --data_root ./datasets/data/cityscapes --test_only --ckpt {ckpt}"
        sys.argv = command.split()[1:]
        print("--------------------------------------------------")
        print(f"Evaluate model {model} with command:")
        print(command)
        print("--------------------------------------------------")
        main()

In [4]:
# Kernel-side record of the menu choices. Every value is kept as a string
# (including the checkpoint flag, stored as 'true'/'false'); the values given
# here are the ones in effect until a dropdown callback overwrites them.
user_selection = dict(
    model='mae_segmenter_vit_base',
    use_ckpt='false',
    action='test',
)

def _render_options(choices, selected):
    """Return concatenated <option> tags for (value, label) pairs.

    The pair whose value equals ``selected`` gets the ``selected`` attribute so
    the browser shows it as the initial choice.
    """
    tags = []
    for value, label in choices:
        flag = ' selected' if value == selected else ''
        tags.append(f'<option value="{value}"{flag}>{label}</option>')
    return ''.join(tags)


# Every menu is rendered from user_selection so that what is shown on screen
# matches what the kernel will run even if no dropdown is ever touched.
# (Without pre-selection the browser shows the first option of each menu,
# which differs from the defaults stored in user_selection.)
dropdown_options = _render_options(
    [(name, name) for name in models], user_selection['model']
)
ckpt_options = _render_options(
    [('true', 'True'), ('false', 'False')], user_selection['use_ckpt']
)
action_options = _render_options(
    [('train', 'Train'), ('test', 'Test'), ('train_test', 'Train and Test')],
    user_selection['action'],
)
# Mirror the JavaScript rule below: the checkpoint menu is locked in test mode.
ckpt_disabled = ' disabled' if user_selection['action'] == 'test' else ''

dropdown_menu = f'''
<div style="margin-bottom: 10px;">
  <label for="model">Select Model:</label>
  <select id="model">
    {dropdown_options}
  </select>
</div>
<div style="margin-bottom: 10px;">
  <label for="ckpt">Use Checkpoint:</label>
  <select id="ckpt"{ckpt_disabled}>
    {ckpt_options}
  </select>
</div>
<div style="margin-bottom: 10px;">
  <label for="action">Select Action:</label>
  <select id="action">
    {action_options}
  </select>
</div>
'''

# Plain (non-f) string: the JavaScript braces must not be interpreted by Python.
# NOTE(review): `IPython.notebook.kernel.execute` is the classic Notebook
# JavaScript API; it is not available in JupyterLab / Notebook 7, where these
# callbacks would silently do nothing - confirm the target front end.
dropdown_menu += '''
<script>
  document.getElementById('model').onchange = function() {
    IPython.notebook.kernel.execute('user_selection["model"] = "' + this.value + '"');
  };
  document.getElementById('ckpt').onchange = function() {
    IPython.notebook.kernel.execute('user_selection["use_ckpt"] = "' + this.value + '"');
  };
  document.getElementById('action').onchange = function() {
    let ckpt = document.getElementById('ckpt')
    if (this.value === 'test') {
      ckpt.value = 'false'
      ckpt.disabled = true
      // Setting .value from script does not fire onchange, so push the forced
      // value to the kernel explicitly; otherwise a stale "true" survives and
      // disagrees with the menu once the action is switched back to training.
      IPython.notebook.kernel.execute('user_selection["use_ckpt"] = "false"');
    } else {
      ckpt.disabled = false
    }
    IPython.notebook.kernel.execute('user_selection["action"] = "' + this.value + '"');
  };
</script>
'''

display(HTML(dropdown_menu))

In [None]:
# Run whatever is currently recorded in user_selection.
action = user_selection["action"]

# The checkpoint flag is stored as the text "true"/"false". Any non-empty
# string is truthy, so passing it through unchanged would make "false" behave
# like True; compare against the literal instead.
use_ckpt = user_selection["use_ckpt"] == "true"

# "train_test" means: train first, then evaluate.
if action in ("train", "train_test"):
    train(user_selection["model"], use_ckpt)
if action in ("test", "train_test"):
    test(user_selection["model"])