# Task: Text Generation

In this Notebook we'll run Archai's [Text Generation](https://github.com/microsoft/archai/tree/main/tasks/text_generation) task on Azure Machine Learning.

## Prerequisites

- Python 3.7 or later
- An Azure subscription
- An Azure Resource Group
- An Azure Machine Learning [Workspace](https://learn.microsoft.com/en-us/azure/machine-learning/quickstart-create-resources#create-the-workspace)

### Requirements

In [None]:
%pip install azure-ai-ml azure-identity 
%pip install jinja2
%pip install archai

In [None]:
import os
from pathlib import Path

from IPython.display import HTML, Image, display

import archai.common.azureml_helper as aml_helper
import archai.common.notebook_helper as nb_helper

### Get a handle to the workspace

We load the workspace from a workspace [configuration file](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-configure-environment#local-and-dsvm-only-create-a-workspace-configuration-file).

In [None]:
ml_client = aml_helper.get_aml_client_from_file()
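
For reference, a workspace configuration file is a small JSON document with three fields: `subscription_id`, `resource_group`, and `workspace_name`. The sketch below writes one with placeholder values and reads it back. The `.azureml` folder name, the temporary location, and the helper `write_workspace_config` are assumptions made for illustration; the place where `get_aml_client_from_file` actually looks for the file may differ.

```python
import json
import tempfile
from pathlib import Path

# Placeholder values (assumptions): replace them with your own identifiers.
workspace_config = {
    "subscription_id": "<subscription-id>",
    "resource_group": "<resource-group>",
    "workspace_name": "<workspace-name>",
}


def write_workspace_config(folder: Path, config: dict) -> Path:
    """Write config.json into `folder` (hypothetical helper) and return its path."""
    folder.mkdir(parents=True, exist_ok=True)
    config_path = folder / "config.json"
    config_path.write_text(json.dumps(config, indent=2), encoding="utf-8")
    return config_path


# Written to a temporary folder so the sketch has no side effects on your project.
config_path = write_workspace_config(Path(tempfile.mkdtemp()) / ".azureml", workspace_config)
loaded = json.loads(config_path.read_text(encoding="utf-8"))
print(sorted(loaded))
```

In practice the same file can be downloaded from the workspace overview page in the Azure portal instead of being written by hand.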

### Create a CPU compute cluster

We provision a Linux [compute cluster](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-create-attach-compute-cluster?tabs=python) for the NAS job in this Notebook. See the [full list](https://azure.microsoft.com/en-ca/pricing/details/machine-learning/) of VM sizes and prices.

In [None]:
cpu_compute_name = "nas-cpu-cluster-D14-v2"
aml_helper.create_compute_cluster(ml_client, cpu_compute_name, size="Standard_D14_v2")

### Create a GPU compute cluster

For full training we provision a GPU compute cluster.

In [None]:
gpu_compute_name = "nas-gpu-cluster-NC6"
aml_helper.create_compute_cluster(ml_client, gpu_compute_name, size="Standard_NC6")

### Create an environment based on a YAML file

In [None]:
archai_job_env = aml_helper.create_environment_from_file(ml_client, conda_file="conda.yaml")

### Job 1: NAS (Searching for Pareto-optimal Architectures)

#### Loading the search job from a YAML file and running it

In [None]:
search_job = aml_helper.create_job_from_file(source=os.path.join("src", "search.yaml"))
s_job = aml_helper.run_job(ml_client, search_job)

#### Stream logs of the job

In [None]:
aml_helper.stream_job_logs(ml_client, s_job)

#### Download the job's output

In [None]:
output_name = "output_dir"
download_path = "output"

aml_helper.download_job_output(ml_client, job_name=s_job.name, output_name=output_name, download_path=download_path)

downloaded_folder = Path(download_path) / "named-outputs" / output_name
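
Before displaying the plots, it can help to confirm that the download produced the expected files. The following is a minimal sketch using only `pathlib`; `list_plots` is a hypothetical helper, not part of Archai, and it only assumes that the plots are PNG files stored directly inside the downloaded folder.

```python
from pathlib import Path
from typing import List


def list_plots(folder: Path) -> List[str]:
    """Return the sorted names of PNG files directly inside `folder`.

    Returns an empty list if the folder does not exist, which usually
    means the download did not complete or used a different output name.
    """
    if not folder.is_dir():
        return []
    return sorted(p.name for p in folder.glob("*.png"))


# Same layout as the cell above: <download_path>/named-outputs/<output_name>
print(list_plots(Path("output") / "named-outputs" / "output_dir"))
```

An empty list is a hint to re-check `output_name` and `download_path` before moving on.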

#### Show Pareto Frontiers

In [None]:
param_vs_latency_img = Image(filename=downloaded_folder / "pareto_non_embedding_params_vs_onnx_latency.png")
display(param_vs_latency_img)

In [None]:
param_vs_memory_img = Image(filename=downloaded_folder / "pareto_non_embedding_params_vs_onnx_memory.png")
display(param_vs_memory_img)

In [None]:
latency_vs_memory_img = Image(filename=downloaded_folder / "pareto_onnx_latency_vs_onnx_memory.png")
display(latency_vs_memory_img)
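
To make the plots easier to read, here is a small sketch of what "Pareto-optimal" means for two objectives. It assumes that a higher non-embedding parameter count is better and that a lower ONNX latency is better; the candidate values are made up for illustration, and the functions are not Archai's implementation.

```python
from typing import List, Tuple

# Each point is (non_embedding_params, onnx_latency).
Point = Tuple[float, float]


def dominates(a: Point, b: Point) -> bool:
    """True if `a` is no worse than `b` on both objectives and strictly better on one.

    Assumption: more parameters is better, lower latency is better.
    """
    no_worse = a[0] >= b[0] and a[1] <= b[1]
    strictly_better = a[0] > b[0] or a[1] < b[1]
    return no_worse and strictly_better


def pareto_frontier(points: List[Point]) -> List[Point]:
    """Keep only the points that no other point dominates."""
    return [p for p in points if not any(dominates(q, p) for q in points)]


# Made-up candidates, for illustration only.
candidates = [(10e6, 5.0), (20e6, 9.0), (15e6, 12.0), (20e6, 7.0)]
frontier = pareto_frontier(candidates)
print(frontier)
```

Here `(20e6, 9.0)` is dropped because `(20e6, 7.0)` has the same parameter count at lower latency, and `(15e6, 12.0)` is dropped because `(20e6, 9.0)` is better on both objectives.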

#### Show search state of the last iteration

In [None]:
df = nb_helper.get_search_csv(downloaded_folder)
csv_as_html = nb_helper.get_csv_as_stylized_html(df)
display(HTML(csv_as_html))

### Job 2: Train (Training a Pareto architecture from Transformer-Flex)

#### Pick an architecture id (archid) from the CSV file to perform full training on

In [None]:
archid = "<arch-id>"
arch_path = nb_helper.get_arch_abs_path(archid=archid, downloaded_folder=downloaded_folder)
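
One way to choose an `archid` is to filter the search results by a latency budget and take the largest remaining model. The sketch below does this with the standard `csv` module on in-memory sample data. The column names (`archid`, `non_embedding_params`, `onnx_latency`) and the rows are assumptions for illustration; check the header of the downloaded CSV file before adapting it.

```python
import csv
import io
from typing import Optional

# Made-up sample rows; the real CSV may use different column names.
SAMPLE_CSV = """archid,non_embedding_params,onnx_latency
arch_a,10000000,5.0
arch_b,20000000,7.0
arch_c,30000000,15.0
"""


def pick_archid(csv_text: str, max_latency: float) -> Optional[str]:
    """Return the archid with the most parameters whose latency is within budget.

    Returns None when no architecture satisfies the latency budget.
    """
    rows = [
        row
        for row in csv.DictReader(io.StringIO(csv_text))
        if float(row["onnx_latency"]) <= max_latency
    ]
    if not rows:
        return None
    best = max(rows, key=lambda row: float(row["non_embedding_params"]))
    return best["archid"]


print(pick_archid(SAMPLE_CSV, max_latency=10.0))
```

The returned value would then replace the `"<arch-id>"` placeholder in the cell above.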

#### Load the training job from a YAML file and set the arch_path as its input

In [None]:
train_job = aml_helper.create_job_from_file(source=os.path.join("src", "train.yaml"))
train_job.inputs.pareto_config_path.path = arch_path

In [None]:
t_job = aml_helper.run_job(ml_client, train_job)

#### Stream logs of the job

In [None]:
aml_helper.stream_job_logs(ml_client, t_job)

### Job 3: Generating text via prompt

#### Loading the text generation job from a YAML file and running it

In [None]:
gen_job = aml_helper.create_job_from_file(source=os.path.join("src", "generate_text.yaml"))
g_job = aml_helper.run_job(ml_client, gen_job)
aml_helper.stream_job_logs(ml_client, g_job)