# Run a training script on Azure Machine Learning Studio

If you have an Azure subscription, you can train the GAN in Azure Machine Learning Studio by submitting the scripts in the `src/` folder as Azure ML jobs.

## Before you start

You'll need the latest version of the **azure-ai-ml** package to run the code in this notebook.

In [None]:
%pip install --upgrade azure-ai-ml

In [None]:
%pip show azure-ai-ml
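If you prefer to check the version from Python instead of `pip show`, the standard library's `importlib.metadata` can report it. This is a minimal sketch; the helper name `installed_version` is hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


ver = installed_version("azure-ai-ml")
print(f"azure-ai-ml: {ver}" if ver else "azure-ai-ml is not installed; run the install cell first.")
```

If the package was already imported in the running kernel before you upgraded it, restart the kernel so the new version is loaded.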

## Set Up Azure ML Workspace, Compute Instance, and GPU Cluster

If you haven't created the required resources yet, create them with the **Azure CLI** and its `ml` extension (`az extension add --name ml`).
The steps below are for reference; adjust the names, region, and VM sizes to match your Azure subscription and quota.

---

### 1. Register the Azure Machine Learning Resource Provider
```sh
az provider register --namespace Microsoft.MachineLearningServices
```

### 2. Create a Resource Group and Set Defaults
```sh
RESOURCE_GROUP="rg-mlw-keras-gan"
REGION="eastus"

az group create --name $RESOURCE_GROUP --location $REGION
az configure --defaults group=$RESOURCE_GROUP
```

### 3. Create an Azure ML Workspace and Set Defaults
```sh
WORKSPACE_NAME="mlw-keras-gan"

az ml workspace create --name $WORKSPACE_NAME --location $REGION
az configure --defaults workspace=$WORKSPACE_NAME
```

### 4. Create a Compute Instance for Jupyter
```sh
COMPUTE_INSTANCE="ci-keras-gpu"

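# STANDARD_DS11_V2 is a CPU size: this VM only runs the notebook; training runs on the GPU cluster.
az ml compute create \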
  --name $COMPUTE_INSTANCE \
  --type ComputeInstance \
  --size STANDARD_DS11_V2 \
  --location $REGION
```

### 5. Create a GPU Compute Cluster for Training Jobs

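Common GPU VM sizes for the cluster (availability and quota vary by region and subscription):

| VM Size                          | GPUs (count) | GPU Model         | VRAM per GPU |
|----------------------------------|:------------:|-------------------|:------------:|
| **Standard_NC6s_v3**             |      1       | NVIDIA Tesla V100 |    16 GB     |
| **Standard_NC12s_v3**            |      2       | NVIDIA Tesla V100 |    16 GB     |
| **Standard_NC24s_v3**            |      4       | NVIDIA Tesla V100 |    16 GB     |
| **Standard_ND40rs_v2**           |      8       | NVIDIA Tesla V100 |    32 GB     |

The command below uses `STANDARD_ND40RS_V2` with low-priority nodes, which cost less but can be preempted. Pass a smaller size from the table to `--size` if one node with 8 GPUs is more than you need.
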
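If you are unsure which size to pass to `--size`, the NCsv3 rows of the table can be encoded as data and queried for the smallest size that meets a GPU count. This is an illustrative sketch only: `GpuSize`, `GPU_SIZES`, and `smallest_size` are hypothetical names, and the list does not check regional availability or quota.

```python
from typing import NamedTuple, Optional


class GpuSize(NamedTuple):
    name: str
    gpus: int
    vram_gb_per_gpu: int


# NCsv3 rows from the table (NVIDIA Tesla V100, 16 GB each).
GPU_SIZES = [
    GpuSize("Standard_NC6s_v3", 1, 16),
    GpuSize("Standard_NC12s_v3", 2, 16),
    GpuSize("Standard_NC24s_v3", 4, 16),
]


def smallest_size(min_gpus: int, min_vram_gb: int = 0) -> Optional[str]:
    """Return the size with the fewest GPUs that still meets both requirements, else None."""
    for size in sorted(GPU_SIZES, key=lambda s: s.gpus):
        if size.gpus >= min_gpus and size.vram_gb_per_gpu >= min_vram_gb:
            return size.name
    return None


print(smallest_size(2))                  # Standard_NC12s_v3
print(smallest_size(1, min_vram_gb=32))  # None: no NCsv3 size has 32 GB per GPU
```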
```sh
COMPUTE_CLUSTER="aml-gpu-cluster"

az ml compute create \
  --name $COMPUTE_CLUSTER \
  --type AmlCompute \
  --size STANDARD_ND40RS_V2 \
  --min-instances 0 \
  --max-instances 2 \
  --tier low_priority \
  --location $REGION
```
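If you would rather run the setup from a notebook cell than from a terminal, the same `az` commands from steps 1 to 5 can be issued with `subprocess`. This is a minimal sketch, assuming the Azure CLI and its `ml` extension are installed and you have already signed in with `az login`; the helpers `setup_commands` and `run_setup` are hypothetical, and the default `dry_run=True` only prints the commands.

```python
import shlex
import shutil
import subprocess
from typing import List

# Same values as the shell variables in steps 2-5; change them together.
RESOURCE_GROUP = "rg-mlw-keras-gan"
REGION = "eastus"
WORKSPACE_NAME = "mlw-keras-gan"
COMPUTE_INSTANCE = "ci-keras-gpu"
COMPUTE_CLUSTER = "aml-gpu-cluster"


def setup_commands() -> List[List[str]]:
    """Return the az commands from steps 1-5 as argument lists, in order."""
    return [
        ["az", "provider", "register", "--namespace", "Microsoft.MachineLearningServices"],
        ["az", "group", "create", "--name", RESOURCE_GROUP, "--location", REGION],
        ["az", "configure", "--defaults", f"group={RESOURCE_GROUP}"],
        ["az", "ml", "workspace", "create", "--name", WORKSPACE_NAME, "--location", REGION],
        ["az", "configure", "--defaults", f"workspace={WORKSPACE_NAME}"],
        ["az", "ml", "compute", "create", "--name", COMPUTE_INSTANCE,
         "--type", "ComputeInstance", "--size", "STANDARD_DS11_V2", "--location", REGION],
        ["az", "ml", "compute", "create", "--name", COMPUTE_CLUSTER,
         "--type", "AmlCompute", "--size", "STANDARD_ND40RS_V2",
         "--min-instances", "0", "--max-instances", "2",
         "--tier", "low_priority", "--location", REGION],
    ]


def run_setup(dry_run: bool = True) -> List[str]:
    """Print every command; run them only when dry_run is False, stopping at the first failure."""
    printed = []
    for cmd in setup_commands():
        printed.append(shlex.join(cmd))
        print(printed[-1])
        if not dry_run:
            # shutil.which resolves az.cmd on Windows, where a bare "az" is not found by subprocess.
            executable = shutil.which(cmd[0]) or cmd[0]
            subprocess.run([executable, *cmd[1:]], check=True)  # raises CalledProcessError on failure
    return printed


run_setup()  # dry run: prints the commands without contacting Azure
```

Call `run_setup(dry_run=False)` once the printed commands look right. Because `az configure --defaults` stores the group and workspace in the CLI's configuration, the later commands pick them up just as they do in a terminal.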

## Connect to your workspace

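To connect to a workspace, you need three identifiers: the subscription ID, the resource group name, and the workspace name. `MLClient.from_config` reads them from a `config.json` file in the current directory or one of its parents. On an Azure ML compute instance this file is normally already available, so the defaults work; elsewhere, download `config.json` from the workspace page in the Azure portal first. Authentication tries `DefaultAzureCredential` and falls back to an interactive browser sign-in.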

In [None]:
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential
from azure.ai.ml import MLClient
from azure.ai.ml import command
from azure.ai.ml.entities import Environment

# ---- Auth ----
try:
    credential = DefaultAzureCredential()
    credential.get_token("https://management.azure.com/.default")
except Exception:
    credential = InteractiveBrowserCredential()

# Workspace identifiers come from config.json (current directory or a parent)
ml_client = MLClient.from_config(credential=credential)
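`MLClient.from_config` fails if it cannot find a usable `config.json`, and a quick pre-check can make that failure easier to diagnose. This is a minimal sketch, assuming the file uses the `subscription_id`, `resource_group`, and `workspace_name` keys of the file the Azure portal downloads. Its search (the start folder, its `.azureml/` subfolder, then each parent) only approximates the SDK's own lookup, and `find_workspace_config` and `missing_keys` are hypothetical helpers.

```python
import json
from pathlib import Path
from typing import List, Optional

# Keys of the config.json file downloaded from the workspace page in the Azure portal.
REQUIRED_KEYS = ("subscription_id", "resource_group", "workspace_name")


def find_workspace_config(start: Optional[Path] = None) -> Optional[Path]:
    """Return the first config.json in `start` (or its .azureml/ subfolder), then each parent."""
    folder = (start or Path.cwd()).resolve()
    for candidate_dir in [folder, *folder.parents]:
        for candidate in (candidate_dir / "config.json", candidate_dir / ".azureml" / "config.json"):
            if candidate.is_file():
                return candidate
    return None


def missing_keys(config_path: Path) -> List[str]:
    """List the workspace identifiers that are absent or empty in the file."""
    data = json.loads(config_path.read_text())
    return [key for key in REQUIRED_KEYS if not data.get(key)]


config_path = find_workspace_config()
if config_path is None:
    print("No config.json found; download it from the workspace page in the Azure portal.")
else:
    print(f"{config_path}: missing {missing_keys(config_path) or 'nothing'}")
```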

## Use the Python SDK to train a model

The training script, **train_gan_job.py**, is in the **src** folder. The cell below defines the environment and the command job, uploads the **src** folder with the job, and submits it to the GPU cluster.
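The job passes its hyperparameters to **train_gan_job.py** as command-line flags. The script itself is not shown here, so the parser below is a hypothetical sketch of how it might accept them: only the flag names come from the job's command, the defaults copy the values the job passes, and the help texts are guesses. Azure ML uploads files a job writes to its `outputs` folder, so they can be browsed from the job's page in the studio, which is why the command uses `--output-dir outputs`.

```python
import argparse
from pathlib import Path
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Hypothetical CLI for train_gan_job.py; defaults copy the values the job passes."""
    parser = argparse.ArgumentParser(description="Train a Keras GAN (sketch).")
    # Help texts below are guesses at each flag's meaning.
    parser.add_argument("--epochs", type=int, default=200)
    parser.add_argument("--batch-size", type=int, default=512)
    parser.add_argument("--noise-dim", type=int, default=100, help="length of the generator's latent vector")
    parser.add_argument("--save-interval", type=int, default=5, help="save samples every N epochs")
    parser.add_argument("--sample-count", type=int, default=25, help="images per saved sample grid")
    parser.add_argument("--output-dir", type=Path, default=Path("outputs"))
    return parser.parse_args(argv)


# argparse turns "--batch-size" into the attribute "batch_size".
args = parse_args(["--epochs", "2", "--batch-size", "64"])
print(args.epochs, args.batch_size, args.output_dir)  # 2 64 outputs
```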

In [None]:
# ---- Environment: curated TensorFlow GPU image ----
# Wraps the curated TensorFlow 2.16 + CUDA 12 image in a custom environment
tf_env = Environment(
    name="tf-gpu-inline",
    image="mcr.microsoft.com/azureml/curated/tensorflow-2.16-cuda12:latest"
)


# ---- Define the command job ----
job = command(
    code="./src",  # uploaded with the job; must contain train_gan_job.py
    command=(
        "python train_gan_job.py "
        "--epochs 200 "
        "--batch-size 512 "
        "--noise-dim 100 "
        "--save-interval 5 "
        "--sample-count 25 "
        "--output-dir outputs"
    ),
    # Curated TF GPU environment
    environment=tf_env,
    compute="aml-gpu-cluster",
    display_name="keras-gan-tf-gpu",
    experiment_name="keras-gan-training",
    environment_variables={
        # Common TF GPU runtime flags; adjust/keep as needed
        "TF_FORCE_GPU_ALLOW_GROWTH": "true",
        "NCCL_DEBUG": "WARN",
        "TF_CPP_MIN_LOG_LEVEL": "2"
    },
)

# ---- Submit and monitor ----
returned_job = ml_client.create_or_update(job)
print("Monitor your job at:", returned_job.studio_url)
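Writing the flags as one long string makes them easy to mistype. One option is to keep the hyperparameters in a dictionary and render the command from it. This is a minimal sketch; `build_command` and `HPARAMS` are hypothetical names, and the values are the ones used in the job above. The resulting string can be passed unchanged as `command=` to `command(...)`.

```python
import shlex
from typing import Mapping, Union

Value = Union[int, float, str]


def build_command(script: str, params: Mapping[str, Value]) -> str:
    """Render 'python <script> --key value ...', turning snake_case keys into --kebab-case flags."""
    parts = ["python", script]
    for key, value in params.items():
        parts += [f"--{key.replace('_', '-')}", str(value)]
    return shlex.join(parts)  # quotes any value that contains spaces or shell metacharacters


HPARAMS = {
    "epochs": 200,
    "batch_size": 512,
    "noise_dim": 100,
    "save_interval": 5,
    "sample_count": 25,
    "output_dir": "outputs",
}

print(build_command("train_gan_job.py", HPARAMS))
# python train_gan_job.py --epochs 200 --batch-size 512 --noise-dim 100 --save-interval 5 --sample-count 25 --output-dir outputs
```

The SDK also lets you declare `inputs={"epochs": 200, ...}` on `command(...)` and reference them as `${{inputs.epochs}}` in the command string, which records the values on the job. To follow the run from the notebook instead of the studio, `ml_client.jobs.stream(returned_job.name)` streams its logs until it finishes.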
