In [None]:
import azureml
from IPython.display import display, Markdown
from azureml.core import Datastore, Experiment, ScriptRunConfig, Workspace, RunConfiguration
from azureml.core.dataset import Dataset
from azureml.core.environment import Environment
from azureml.core.runconfig import DockerConfiguration


from model_drift import settings, helpers

# Report the installed Azure ML core SDK version (useful when debugging
# differences between local and remote environments).
print("Azure ML SDK Version:", azureml.core.VERSION)

In [None]:
# Load the Azure ML workspace from the project's config file (path comes from settings).
ws = Workspace.from_config(path=settings.AZUREML_CONFIG)

In [None]:
# Name the experiment, environment, input dataset, and compute target
input_dataset_name = "padchest"
experiment_name = 'train-vae'
env_name = "monitoring"

compute_target = "nc24-uswest2"

# Experiment
exp = Experiment(workspace=ws, name=experiment_name)

# Environment
environment_file = settings.CONDA_ENVIRONMENT_FILE
project_dir = settings.SRC_DIR
pytorch_env = Environment.from_conda_specification(env_name, file_path=str(environment_file))
# Set the variable *before* registering so the registered environment version
# carries it as well (previously it was only added to the local object after
# register()/build() had already run).
pytorch_env.environment_variables["RSLEX_DIRECT_VOLUME_MOUNT"] = "True"
pytorch_env.register(workspace=ws)
# Starts an image build for this environment. The returned build details are
# kept in `build` but are not waited on here.
build = pytorch_env.build(workspace=ws)

# Run Configuration
run_config = RunConfiguration()
run_config.environment_variables["RSLEX_DIRECT_VOLUME_MOUNT"] = "True"
run_config.environment = pytorch_env
run_config.docker = DockerConfiguration(use_docker=True, shm_size="100G")

# Note: the common runtime has a bug where print statements sometimes disappear.
# Set this flag as a workaround to use the legacy runtime.
run_config.environment_variables["AZUREML_COMPUTE_USE_COMMON_RUNTIME"] = "false"

# Input Dataset
dataset = Dataset.get_by_name(ws, name=input_dataset_name)

# Base command-line arguments for scripts/vae/train.py.
args = {
    'dataset': input_dataset_name,
    'data_folder': dataset.as_named_input('dataset').as_mount(),
    'run_azure': 1,
    'output_dir': './outputs',

    'frontal_only': 1,
    'val_frontal_only': 0,
    'ignore_nonfrontal_loss': 1,

    'batch_size': 32,
    'base_lr': 0.0001,
    'image_size': 128,

    'max_epochs': 50,
    'num_workers': -1,

    'progress_bar_refresh_rate': 25,
    'log_every_n_steps': 25,
    'flush_logs_every_n_steps': 25,

    'accelerator': 'ddp',
    'channels': 1,
    'normalize': False,

    'step_size': 3,
    'lr_scheduler': 'plateau',
    'auto_scale_batch_size': False,
    'auto_lr_find': False,

    'width': 320,
    'z': 64,
    'layer_count': 3,
    'terminate_on_nan': True,
    'log_recon_images': 32
}

# Overrides merged on top of `args`; on key collisions these values win.
hyperparameters = {
    "base_lr": 0.0001,
    # Integer division: `32/4` produced the float 8.0, which would presumably be
    # passed on the command line as "8.0" and is not a valid integer batch size.
    "batch_size": 32 // 4,
    "image_size": 128,
    "kl_coeff": 0.1,
    # NOTE(review): `args` uses the key 'layer_count'. Confirm that train.py
    # really accepts `--layer` and that this was not meant to override
    # 'layer_count' (argparse prefix matching could also map `--layer` onto
    # `--layer_count`).
    "layer": 4,
    "width": 240,
    "z": 128,
}

args.update(hyperparameters)

print("args:")
for k, v in sorted(args.items()):
    print(f" {k}: {v}")

config = ScriptRunConfig(
    source_directory=str(project_dir),
    script="scripts/vae/train.py",
    arguments=helpers.argsdict2list(args),
)

config.run_config = run_config
config.run_config.target = compute_target

run = exp.submit(config)
display(Markdown(f"""
- Environment: {pytorch_env.name}
- Experiment: [{run.experiment.name}]({run.experiment.get_portal_url()})
- Run: [{run.display_name}]({run.get_portal_url()})
- Target: {config.run_config.target}
"""))
