# Train UNet Model

In this notebook you will learn how to train a UNet architecture with Dataloop and PyTorch.

UNet is an encoder-decoder architecture for producing segmentation maps.
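
To make the encoder-decoder idea concrete, the sketch below tracks only the spatial size of a feature map as it passes through a UNet-style network. It assumes four 2x downsampling steps mirrored by four 2x upsampling steps, which is a common layout but not necessarily the one this model uses. The function name `unet_spatial_sizes` is hypothetical, and the 640x960 input is borrowed from the `input_shape` used later in this notebook.

```python
def unet_spatial_sizes(height, width, depth=4):
    """Return the (height, width) of the feature map at each encoder and decoder level.

    Assumes every encoder level halves the resolution and every decoder
    level doubles it (an illustrative choice, not read from the model).
    """
    factor = 2 ** depth
    if height % factor or width % factor:
        raise ValueError(f"input size must be divisible by {factor}")

    # Encoder: the input resolution followed by `depth` halvings.
    encoder = [(height // 2 ** i, width // 2 ** i) for i in range(depth + 1)]
    # Decoder: mirror the encoder to climb back to the input resolution.
    decoder = encoder[-2::-1]
    return encoder, decoder


encoder, decoder = unet_spatial_sizes(640, 960)
print("encoder:", encoder)
print("decoder:", decoder)
```

The decoder ends at the input resolution, which is why the output can be read as a per-pixel segmentation map.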

In [None]:
import json
import os
import torch
import datetime
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import dtlpy as dl
from dtlpy.ml import train_utils
from dtlpy.ml.dataset_generators.torch_dataset_generator import DataGenerator
import tempfile

In [None]:
# clean up the GPU
import gc
gc.collect()
torch.cuda.empty_cache()

### Get the Dataloop entities

Let's get the model and dataset entities from the Dataloop platform.

In [None]:
model = dl.models.get('unet')  # This is the global model
# Data entities
project = dl.projects.get('shefi-contests', '50f0fc03-4d70-455d-b485-c78cca53f2be')
dataset = dl.datasets.get('carvana', '61b9bbc1e8ad454a9aa7d285')

### Snapshot

Now we can create a new snapshot. We add your user name and the current date and time as a suffix so that the snapshot name is unique.

In [None]:
whoami = dl.client_api.info()['user_email']
now = datetime.datetime.now()

# Create a new snapshot, named after the current user and the current datetime
snapshot_name = f"carvana-train-example-{whoami.split('@')[0]}-{now.isoformat(timespec='minutes')}"
snapshot = model.snapshots.create(
    snapshot_name=snapshot_name,
    dataset_id=dataset.id,
    description='train unet example',
    bucket=project.buckets.create(bucket_type=dl.BucketType.ITEM, model_name=model.name, snapshot_name=snapshot_name),
    tags=['example', 'notebook'],
    configuration={'id_to_label_map': {'1': 'car'},
                   'image_normalize_mu': 0, 'image_normalize_std': 1,
                   'input_shape': [640, 960], 'batch_size': 2,
                   'num_epochs': 2},
    project_id=project.id,
    labels=['car']
)
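
The naming scheme used above can be tried out without touching the platform. The following is a minimal sketch that reproduces the same string formatting with the standard library only; the helper `build_snapshot_name` and the email address are hypothetical and are not part of `dtlpy`.

```python
import datetime


def build_snapshot_name(user_email, when, prefix="carvana-train-example"):
    """Compose '<prefix>-<user>-<timestamp>' from an email address and a datetime."""
    user = user_email.split("@")[0]  # keep only the part before the '@'
    timestamp = when.isoformat(timespec="minutes")  # minute resolution, no seconds
    return f"{prefix}-{user}-{timestamp}"


# Hypothetical user and a fixed time so the result is reproducible.
name = build_snapshot_name("jane.doe@example.com", datetime.datetime(2022, 1, 15, 10, 30))
print(name)
```

Because the timestamp only resolves to minutes, two snapshots created by the same user within the same minute would end up with the same name.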

### Let's view the Model and Snapshot entities

We use `to_df()` to convert each entity to a DataFrame and view it.

In [None]:
model.to_df()

In [None]:
snapshot.to_df()

### One last thing to make sure before we train

Our `adapter` train method expects the data to be split into train, validation, and test partitions.  
For small datasets, this split can be created manually using `train_utils.create_dataset_partition()`.

Our dataset is already prepared, so we will just verify it.

In [None]:
train_items = dataset.get_partitions(partitions=dl.SnapshotPartitionType.TRAIN)
val_items = dataset.get_partitions(partitions=dl.SnapshotPartitionType.VALIDATION)
test_items = dataset.get_partitions(partitions=dl.SnapshotPartitionType.TEST)

print(f"Dataset {dataset.name} data partitions - TRAIN: {train_items.items_count}, VALIDATION: {val_items.items_count}, TEST: {test_items.items_count}")
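
For intuition about what a train-validation-test split involves, here is a generic sketch that shuffles a list of item ids and cuts it into three parts. It is not the implementation of `train_utils.create_dataset_partition()`; the function `split_items`, the 80/10/10 ratios, and the integer ids are all assumptions made for illustration.

```python
import random


def split_items(item_ids, train=0.8, validation=0.1, seed=0):
    """Shuffle ids and split them into train / validation / test lists."""
    ids = list(item_ids)
    random.Random(seed).shuffle(ids)  # seeded so the split is reproducible
    n_train = int(len(ids) * train)
    n_val = int(len(ids) * validation)
    return {
        "train": ids[:n_train],
        "validation": ids[n_train:n_train + n_val],
        "test": ids[n_train + n_val:],  # whatever remains
    }


partitions = split_items(range(100))
print({name: len(ids) for name, ids in partitions.items()})
```

Each id lands in exactly one partition, so the three counts always add up to the dataset size.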

### Finally we can start to train

We initialize the adapter using the `build` method.

The `Adapter` is the base class that connects the Dataloop platform to our specific model.  
Some methods are inherited from the base adapter and some are written specifically per model.  
Each architecture has its own adapter, whose source code you can view.


In [None]:
adapter = model.build()
adapter.load_from_snapshot(snapshot=snapshot)
# adapter._set_adapter_handler('DEBUG')
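
The split between inherited and model-specific methods can be pictured with a toy example. This is only a sketch of the general pattern: `BaseAdapter`, `ThresholdAdapter`, and their methods are hypothetical names invented here, and they do not mirror the real `dtlpy` adapter interface.

```python
class BaseAdapter:
    """Hypothetical stand-in for a platform-facing base class."""

    def __init__(self, configuration):
        self.configuration = configuration

    def run(self, batch):
        # Shared logic lives in the base class ...
        prepared = [self.preprocess(x) for x in batch]
        return self.predict(prepared)

    def preprocess(self, x):
        # ... with a default that subclasses may keep or override.
        return x

    def predict(self, batch):
        # Model-specific: every concrete adapter must implement this.
        raise NotImplementedError


class ThresholdAdapter(BaseAdapter):
    """Toy 'model': labels a value 1 if it exceeds a configured threshold."""

    def predict(self, batch):
        threshold = self.configuration["threshold"]
        return [1 if x > threshold else 0 for x in batch]


adapter_sketch = ThresholdAdapter({"threshold": 0.5})
print(adapter_sketch.run([0.2, 0.7, 0.9]))
```

The caller only ever uses `run`, so swapping in a different architecture means writing a new subclass rather than changing the calling code.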

In [None]:
root_path, data_path, output_path = adapter.prepare_training()

In [None]:
adapter.train(data_path=data_path, output_path=output_path)


### SAVING

The adapter now holds the best model fit for our data.

To upload the weights and other configurations we need to save our snapshot.  
We use a temp dir: we save all the content to that dir and upload it (the other option is to upload the whole *`output_path`*, which also contains more runtime files).

In [None]:
# '%Y-%m-%d' is the portable spelling of '%F', which is not supported on every platform
temp_dir = tempfile.mkdtemp(prefix=snapshot.name, suffix=now.strftime('%Y-%m-%d-%H%M%S'))
adapter.save_to_snapshot(local_path=temp_dir)
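
To see what `tempfile.mkdtemp` does with `prefix` and `suffix`, the standalone sketch below creates a directory, inspects its name, and removes it again. The prefix and suffix strings are hypothetical placeholders for the snapshot name and timestamp used above.

```python
import os
import shutil
import tempfile

# Hypothetical prefix and suffix; the notebook uses the snapshot name and a timestamp.
temp_dir_example = tempfile.mkdtemp(prefix="my-snapshot-", suffix="-2022-01-15-103000")
dir_name = os.path.basename(temp_dir_example)
print(dir_name)  # prefix + random characters + suffix

exists_before = os.path.isdir(temp_dir_example)

# mkdtemp never deletes the directory for you, so clean up once it is no longer needed.
shutil.rmtree(temp_dir_example)
exists_after = os.path.isdir(temp_dir_example)
```

The random characters between prefix and suffix are what keep the directory unique across runs.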


## USING THE MODEL - PREDICTION

We will use the `DataGenerator` to view the image (this utility already connects to our Dataloop item and annotations).


In [None]:
datagen = DataGenerator(data_path=os.path.join(data_path, 'validation'),
                        dataset_entity=snapshot.dataset,
                        annotation_type=dl.AnnotationType.SEGMENTATION)


In [None]:
# example - get 1 entry and visualize it
datagen.visualize(20)

### Data Item

Our data generator returns a Data Item dictionary.  
We can parse it to get the item and annotations.


In [None]:
data_item = datagen[20]
print(f"To get the item_id from the data item (keys: {list(data_item.keys())}) we can use the annotation JSON")
# use a context manager so the annotation file is closed after reading
with open(data_item['annotation_filepath'], 'r') as f:
    ann_json = json.load(f)
item_id = ann_json['id']
item = dl.items.get(item_id=item_id)
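
The parsing step can be tried offline with a stand-in file. In this sketch only the top-level `id` key comes from the notebook; the rest of the JSON content, the file name, and the id value are made up for illustration.

```python
import json
import os
import tempfile

# Hypothetical annotation file; only the top-level 'id' key is taken from the notebook.
example = {"id": "item-123", "annotations": []}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "annotation.json")
    with open(path, "w") as f:
        json.dump(example, f)

    # Read the file back and pull out the item id.
    with open(path, "r") as f:
        loaded = json.load(f)

example_item_id = loaded["id"]
print(example_item_id)
```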

In [None]:
predictions = adapter.predict_items(items=[item], with_upload=False)
item_predictions = predictions[0]

In [None]:
# we can ignore label 0, which is usually used for the background
predictions = adapter.predict_items(items=[item], with_upload=False, with_bg=False)
item_semantic_preds = predictions[0]
item_semantic_preds.print()
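
As a rough picture of what dropping the background means, the sketch below counts pixels per label in a small mask and keeps only the ids that appear in the label map. It is an illustration with a made-up 4x4 mask and says nothing about how `with_bg` is implemented inside the adapter.

```python
from collections import Counter

# Hypothetical 4x4 predicted mask: 0 = background, 1 = car.
mask = [
    [0, 0, 1, 1],
    [0, 1, 1, 1],
    [0, 1, 1, 0],
    [0, 0, 0, 0],
]
id_to_label_map = {"1": "car"}  # same mapping as in the snapshot configuration

pixel_counts = Counter(value for row in mask for value in row)
# Keep only ids that have a label; id 0 has none, so background is dropped.
labelled = {
    id_to_label_map[str(label_id)]: count
    for label_id, count in pixel_counts.items()
    if str(label_id) in id_to_label_map
}
print(labelled)
```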

In [None]:
# we can create polygons instead of semantic segmentations
predictions = adapter.predict_items(items=[item], with_upload=False, to_poly=True)
item_polygons_preds = predictions[0]
item_polygons_preds.print()

In [None]:
annotated = item_polygons_preds.show(data_item['image'], thickness=5)
plt.imshow(annotated)