# Training a model for a task

This notebook shows how to train a model for a task in a project. Model training is slow and computationally heavy, so it runs on the platform as a background job. You can therefore continue to interact with the platform, for example by annotating more images, while the training job is still in progress. The notebook also shows how to monitor the progress of the job, and how to train with a different configuration.

## Initial setup

### Connect to the platform

First, set up the connection to the server. This is done by instantiating a `Geti` object with the hostname (or IP address) and authentication details for the server. As in notebook [001 create_project](001_create_project.ipynb), the server details are stored in the `.env` file and are loaded in the cell below. For details on how to create the `.env` file, please see the [readme](README.md).

In [None]:
from geti_sdk.utils import get_server_details_from_env
from geti_sdk import Geti

geti_server_configuration = get_server_details_from_env()
geti = Geti(server_config=geti_server_configuration)
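
To make the `.env` mechanism less opaque, here is a minimal sketch of how such a file can be read into a dictionary using only the standard library. It is not the implementation behind `get_server_details_from_env`. The helper name `parse_env_text` and the keys `HOST` and `TOKEN` in the sample are assumptions made for illustration, so check the readme for the keys your SDK version expects.

```python
def parse_env_text(text: str) -> dict:
    """Parse KEY=VALUE lines, skipping blank lines and '#' comments."""
    values = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        key, separator, value = line.partition("=")
        if not separator:
            # Not a KEY=VALUE line; ignore it rather than guess.
            continue
        value = value.strip()
        # Drop one pair of matching surrounding quotes, if present.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        values[key.strip()] = value
    return values


# Hypothetical file content; the key names are assumptions.
sample = """
# Geti server details
HOST=https://geti.example.com
TOKEN="my-personal-access-token"
"""
config = parse_env_text(sample)
print(sorted(config))
```

Keeping credentials in a file like this, outside the notebook itself, means the notebook can be shared without exposing them.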

### Initialize a project

This example is based on a toy project `Cats and Dogs` that contains various annotated images of cats and dogs. If you want to use a different project, skip to the next section.

In [None]:
import os
from geti_sdk.demos.data_helpers.download_helpers import download_file
from geti_sdk.import_export import GetiIE
from geti_sdk.rest_clients import ProjectClient

dataset_archive_url = "https://storage.geti.intel.com/test-data/geti-sdk/datasets/cats-dogs-yolo-small.zip"
dataset_archive_folder = "/tmp"
dataset_archive = os.path.join(
    dataset_archive_folder, os.path.basename(dataset_archive_url)
)

if not os.path.exists(dataset_archive):
    download_file(url=dataset_archive_url, target_folder=dataset_archive_folder)

project_client = ProjectClient(session=geti.session, workspace_id=geti.workspace_id)
geti_ie = GetiIE(
    session=geti.session, workspace_id=geti.workspace_id, project_client=project_client
)

project = geti_ie.import_dataset_as_new_project(
    filepath=dataset_archive,
    project_name="Cats and Dogs",
    project_type="detection",
)
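
The cell above hard-codes `/tmp` as the download folder, which does not exist on every operating system. One hedged alternative is to derive the cache location from `tempfile.gettempdir()` and the file name from the URL path. The helper name `archive_path_for` is hypothetical; its result could replace the hand-built `dataset_archive` path.

```python
import os
import posixpath
import tempfile
from typing import Optional
from urllib.parse import urlparse


def archive_path_for(url: str, folder: Optional[str] = None) -> str:
    """Return the local path where the archive behind `url` would be cached."""
    if folder is None:
        # Platform-specific temporary directory instead of a fixed "/tmp".
        folder = tempfile.gettempdir()
    # Take the file name from the URL path, ignoring any query string.
    filename = posixpath.basename(urlparse(url).path)
    if not filename:
        raise ValueError(f"Cannot derive a file name from URL: {url}")
    return os.path.join(folder, filename)


url = "https://storage.geti.intel.com/test-data/geti-sdk/datasets/cats-dogs-yolo-small.zip"
local_archive = archive_path_for(url)
print(local_archive)
print("already downloaded:", os.path.exists(local_archive))
```

As in the original cell, checking `os.path.exists` before downloading avoids fetching the archive again on every run.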

## Training preparation

#### Select the project to train

In [None]:
from geti_sdk.rest_clients import ProjectClient

# If you want to use another existing project, specify the name here
PROJECT_NAME = "Cats and Dogs"

project_client = ProjectClient(session=geti.session, workspace_id=geti.workspace_id)
project = project_client.get_project(project_name=PROJECT_NAME)

#### Set up the TrainingClient
To start and monitor training jobs, a `TrainingClient` needs to be created for the project:

In [None]:
from geti_sdk.rest_clients import TrainingClient

training_client = TrainingClient(
    session=geti.session, workspace_id=geti.workspace_id, project=project
)

#### Select a task to train
The first step is to select the task that we want to train. Let's go with the `detection` task in our project, which is the first trainable task in the pipeline. We will print a summary of the task to make sure we pick the right one.

In [None]:
task = project.get_trainable_tasks()[0]
print(task.summary)

#### List the available algorithms
Now, let's list the available algorithms for this task. The `training_client` can be used for this:

In [None]:
available_algorithms = training_client.get_algorithms_for_task(task=task)
print(available_algorithms.summary)

Let's go with the algorithm that Geti uses by default for object detection. It is the `ATSS` algorithm, which is a larger and more accurate model than the `SSD` one. Because of its size it is also slower, but let's say we care most about accuracy for now.

We can get the default algorithms from the list of available algorithms. The cell below shows how it is done, and retrieves the default Detection algorithm.

In [None]:
algorithm = available_algorithms.get_default_for_task_type(task.type)

print(f"Default algorithm for `{task.type}` task: `{algorithm.name}`.\n")
print(algorithm.description)

#### Check the platform status
Before we start a new training round, it is a good idea to check the platform status to make sure the project is not already running another job. If it is, submitting a new job might not start training as expected, depending on which job is already running. The `training_client` can also be used to check the project status:

In [None]:
status = training_client.get_status()
print(status.summary)

## Training

#### Start training

At this point we can start the training, using the `training_client.train_task()` method. The method takes additional optional parameters such as `train_from_scratch` and `enable_pot_optimization`, but we will leave these at their default values (`False`) for now. The `train_task()` method returns a `Job` object that we can use to monitor the training progress.

In [None]:
job = training_client.train_task(
    algorithm=algorithm,
    task=task,
)

#### Monitor the training process
Using the `training_client` and the training `job` we just started, we can monitor the training progress on the platform. The `training_client.monitor_job()` method monitors the status of a job and updates the job status every 15 seconds. Program execution is halted until the job has completed (whether it finished successfully, was cancelled, or failed).

> **NOTE**: Because training the task will take quite a bit of time, you may want to interrupt the monitoring at some point. This can be done by selecting the cell in which the monitoring is running and pressing the 'Interrupt the kernel' (solid square) button at the top of the page, or by navigating to the 'Kernel' menu in the top menu bar and selecting 'Interrupt the kernel' there. This will not cancel the job on the platform, it will just abort the job progress monitoring in the notebook.

In [None]:
training_client.monitor_job(job);
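
Conceptually, this kind of monitoring is a polling loop: query the job state, stop when it reaches a terminal state, otherwise wait and ask again. The sketch below illustrates the pattern only. It is not the SDK's implementation of `monitor_job`, and the names `poll_until_done` and `get_state`, as well as the state strings, are assumptions chosen for the example.

```python
import time
from typing import Callable, Iterable

# Assumed terminal state names, for illustration only.
TERMINAL_STATES = frozenset({"finished", "failed", "cancelled"})


def poll_until_done(
    get_state: Callable[[], str],
    interval: float = 15.0,
    terminal_states: Iterable[str] = TERMINAL_STATES,
) -> str:
    """Call `get_state` every `interval` seconds until a terminal state is seen."""
    terminal = set(terminal_states)
    while True:
        state = get_state()
        print(f"job state: {state}")
        if state in terminal:
            return state
        try:
            time.sleep(interval)
        except KeyboardInterrupt:
            # Interrupting only stops this local loop; the remote job is untouched.
            print("monitoring interrupted; the job itself keeps running")
            return state


# Simulated job that reports three states in sequence.
fake_states = iter(["scheduled", "running", "finished"])
final_state = poll_until_done(lambda: next(fake_states), interval=0)
```

Because the loop only reads the state, stopping it has no effect on the job, which matches the behaviour described in the note above.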

## Trained model management
Once the training has finished successfully, we can set up a `ModelClient` for the project and use it to get the model that was trained in this particular job.

In [None]:
from geti_sdk.rest_clients import ModelClient

model_client = ModelClient(
    session=geti.session, workspace_id=geti.workspace_id, project=project
)

To get the model information, simply pass the job to the `model_client.get_model_for_job()` method. Note that this will not download the actual model weights. Instead, it returns a `Model` object that holds all metadata for the model, such as the score it achieved on the test dataset, its creation date, and the algorithm that it implements.

Requesting the model while the training job is still running will raise a `ValueError`. In that case, please be patient and try again once the job has completed.

In [None]:
model = model_client.get_model_for_job(job)

if model is not None:
    print(model.overview)
else:
    print(
        f"Job '{job.name}' completed with status '{job.status.state}: "
        f"{job.status.message}', but did not result in a trained model. Most likely "
        "the model training has failed; you could try restarting the training to "
        "see if the problem persists."
    )
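
Because requesting the model too early raises a `ValueError`, one option is to wrap the request in a small retry helper. The sketch below is hypothetical: `fetch_with_retry` is not part of the SDK, and `fetch` stands for any zero-argument callable. The example uses a simulated callable so that it runs without a server.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def fetch_with_retry(fetch: Callable[[], T], attempts: int = 5, wait: float = 30.0) -> T:
    """Call `fetch`, retrying after `wait` seconds while it raises ValueError."""
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    for attempt in range(1, attempts + 1):
        try:
            return fetch()
        except ValueError:
            if attempt == attempts:
                # Out of retries: surface the original error to the caller.
                raise
            time.sleep(wait)
    raise AssertionError("unreachable")


# Simulated client: not ready on the first two calls, then returns a result.
calls = {"count": 0}


def fake_fetch() -> str:
    calls["count"] += 1
    if calls["count"] < 3:
        raise ValueError("job is still running")
    return "model-metadata"


result = fetch_with_retry(fake_fetch, attempts=5, wait=0)
print(result, "after", calls["count"], "calls")
```

With a real client, the callable could be `lambda: model_client.get_model_for_job(job)`. Note that a successful call can still return `None` when the job did not produce a model, so the check from the cell above remains necessary.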