<a href="https://www.nvidia.com/dli"> <img src="images/DLI_Header.png" alt="Header" style="width: 400px;"/> </a>

# 2.0 Fine-Tune Pretrained Model with Your Dataset

In this notebook, you'll fine-tune a pretrained model on your synthetic data so that the model is ready for deployment.

**[2.1 Learning Objectives](#2.1-Learning-Objectives)<br>**
**[2.2 Getting Our Data Ready](#2.2-Getting-Our-Data-Ready)<br>**
**[2.3 Create Class](#2.3-Create-Class)<br>**
**[2.4 Create Helper Functions](#2.4-Create-Helper-Functions)<br>**
**[2.5 Create Model and Train](#2.5-Create-Model-and-Train)<br>**

---
## 2.1 Learning Objectives

Training on synthetic data works the same as training on real data. You can plug your new custom synthetic dataset into your existing training workflow.

In this example, we fine-tune an object detection model from Torchvision, [`fasterrcnn_resnet50_fpn`](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.fasterrcnn_resnet50_fpn.html). With its default weights, the model has been _pretrained_ on the COCO dataset (its ResNet-50 backbone was first trained on ImageNet), so it can already recognize a fixed set of everyday object categories. Our goal is to train the model further, or _fine-tune_ it, on our own custom fruit dataset so that it recognizes our custom objects.

<center><video controls src="https://dli-lms.s3.amazonaws.com/assets/s-ov-10-v1/DLI_part_5.mp4" width=800 ></center>

---
## 2.2 Getting Our Data Ready

First, we do some preliminary setup (imports, training parameters, paths, and device) so our synthetic data is ready for model training.


In [None]:
from PIL import Image
import os
import numpy as np
import torch
import torch.utils.data
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms as T
import json
import shutil

We start by defining the number of epochs and classes for our training script. There are 10 fruit classes we wish to identify in the images, and we start by training for 15 epochs. After the initial training run, feel free to change the number of epochs and see how it affects the loss value.

In [None]:
epochs = 15
num_classes = 10
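
Torchvision's detection models reserve label `0` for the background class, and the `num_classes` argument of `fasterrcnn_resnet50_fpn` counts the background too (the torchvision fine-tuning tutorial, for example, uses 2 classes for "person plus background"). With the 0-based labels used in `FruitDataset` below, `apple` would share index `0` with the background, and predictions for the background class are dropped at inference time. The sketch below shows one common alternative convention, a 1-based label map with `num_classes` equal to the number of fruit classes plus one. The names `FRUITS`, `label_map`, and `num_classes_with_background` are hypothetical and not used elsewhere in this notebook.

```python
# Hypothetical 1-based label map: index 0 is left free for the background class
FRUITS = ['apple', 'avocado', 'kiwi', 'lime', 'lychee',
          'pomegranate', 'onion', 'strawberry', 'lemon', 'orange']

label_map = {name: i + 1 for i, name in enumerate(FRUITS)}

# For torchvision detection models, num_classes includes the background
num_classes_with_background = len(FRUITS) + 1

print(label_map['apple'], label_map['orange'], num_classes_with_background)
```

If you adopt this convention, `num_classes` and the `static_labels` dictionary in `FruitDataset` need to change together, and any code that later maps predicted labels back to fruit names needs the same offset.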

To find the data we generated, open a terminal in the JupyterLab window: click the "+" button in the top left to open the Launcher, then select **Terminal**. By default, our data generation script writes to `/dli/task/data/fruit_data_$DATE`. For now, `data_dir` points to an example dataset, which you can use as-is or replace with the path to your own generated data.

In [None]:
data_dir = "/dli/task/data/fruit_data"

Next, we define the output file path where we will save our trained PyTorch model. An example model is also saved at `/dli/task/data/model.pth`.

In [None]:
output_file = "/dli/task/model.pth"

Our system today uses an NVIDIA GPU, which gives us a powerful compute engine for training as well as state-of-the-art graphics capabilities. Run the next cell to see the GPU's specifications.

In [None]:
!nvidia-smi

We define the training device so that PyTorch uses the GPU when one is available and falls back to the CPU otherwise.

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

---
## 2.3 Create Class

In the next cell, we define our `FruitDataset` class, which the data loader will use to feed samples to the model during training. The class definition includes comments explaining each step; please review them along with the code.

In [None]:
class FruitDataset(torch.utils.data.Dataset):
    # This function is run once when instantiating the Dataset object
    def __init__(self, root, transforms):
        self.root = root
        self.transforms = transforms

        # In the first portion of this code, we take our single dataset folder
        # and sort its files into one subfolder per file type (png, json, npy).
        # This is just a preprocessing step. Entries without an extension,
        # such as subfolders created on an earlier run, are skipped.
        for file_ in os.listdir(root):
            _, ext = os.path.splitext(file_)
            ext = ext[1:]
            if ext == '':
                continue

            # Create the subfolder if needed, then move the file into it
            os.makedirs(os.path.join(root, ext), exist_ok=True)
            shutil.move(os.path.join(root, file_), os.path.join(root, ext, file_))

        self.imgs = list(sorted(os.listdir(os.path.join(root, "png"))))
        self.label = list(sorted(os.listdir(os.path.join(root, "json"))))
        self.box = list(sorted(os.listdir(os.path.join(root, "npy"))))
        # We have our three attributes with the img, label, and box data

    # Loads and returns a sample from the dataset at the given index idx
    def __getitem__(self, idx):
        img_path = os.path.join(self.root, "png", self.imgs[idx])
        img = Image.open(img_path).convert("RGB")

        label_path = os.path.join(self.root, "json", self.label[idx])

        # label_path already includes the dataset root, so open it directly
        with open(label_path, "r") as json_data:
            json_labels = json.load(json_data)
        
        box_path = os.path.join(self.root, "npy", self.box[idx])
        dat = np.load(str(box_path))   

        boxes = []
        labels = []
        # Each row of the .npy array describes one object:
        # (semantic id, x_min, y_min, x_max, y_max, ...)
        for i in dat:
            obj_val = i[0]
            xmin = float(np.min(i[1]))
            ymin = float(np.min(i[2]))
            xmax = float(np.max(i[3]))
            ymax = float(np.max(i[4]))
            # Keep only valid boxes, and append the label in the same step
            # so that boxes and labels stay aligned
            if ymax > ymin and xmax > xmin:
                boxes.append([xmin, ymin, xmax, ymax])
                labels.append(json_labels.get(str(obj_val)).get('class'))

        # Labels for the dataset
        static_labels = {
            'apple' : 0,
            'avocado' : 1,
            'kiwi' : 2,
            'lime' : 3,
            'lychee' : 4,
            'pomegranate' : 5,
            'onion' : 6,
            'strawberry' : 7,
            'lemon' : 8,
            'orange' : 9
        }

        # Transform the class names from the JSON file into integer labels
        labels_out = [static_labels[fruit] for fruit in labels]

        # reshape(-1, 4) keeps the (N, 4) shape even if no valid boxes remain
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)

        target = {}
        target["boxes"] = boxes
        target["labels"] = torch.as_tensor(labels_out, dtype=torch.int64)
        target["image_id"] = torch.tensor([idx])
        # Area of each box
        target["area"] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

        if self.transforms is not None:
            img = self.transforms(img)
        return img, target

    # Finally we have a function for the number of samples in our dataset
    def __len__(self):
        return len(self.imgs)
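
The box-parsing step in `__getitem__` can be tried without any image files. The NumPy sketch below builds a small, made-up array using the same column order the class assumes (semantic id, x_min, y_min, x_max, y_max). It is a plain float array rather than the structured array the data generator writes, so the id is converted with `int()` before lookup. The sketch drops a degenerate box together with its label and computes one area per kept box. The `id_to_class` dictionary is a hypothetical stand-in for the JSON label file.

```python
import numpy as np

# Hypothetical rows: (semantic id, x_min, y_min, x_max, y_max)
dat = np.array([
    (1, 10.0, 20.0, 50.0, 60.0),
    (2, 30.0, 30.0, 30.0, 80.0),   # zero width: degenerate, should be dropped
    (1, 0.0, 0.0, 5.0, 4.0),
])
# Stand-in for the JSON file that maps semantic ids to class names
id_to_class = {'1': {'class': 'apple'}, '2': {'class': 'kiwi'}}

boxes, labels = [], []
for row in dat:
    xmin, ymin, xmax, ymax = (float(v) for v in row[1:5])
    if xmax > xmin and ymax > ymin:
        # Append the box and its label together so the two lists stay aligned
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(id_to_class[str(int(row[0]))]['class'])

boxes = np.array(boxes, dtype=np.float32).reshape(-1, 4)
areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
print(labels, areas)
```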

---
## 2.4 Create Helper Functions 

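Next, we define a function that builds the image transformations we wish to perform: converting each PIL image to a `Tensor`, then converting its `dtype` to float so pixel values are scaled to the range [0, 1]. The `train` argument is currently unused; it is a placeholder for training-only augmentations.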

In [None]:
def get_transform(train):
    transforms = []
    transforms.append(T.PILToTensor())
    transforms.append(T.ConvertImageDtype(torch.float))
    return T.Compose(transforms)
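
For a `uint8` RGB image, `T.PILToTensor()` produces a tensor of shape (channels, height, width) with the same integer values, and `T.ConvertImageDtype(torch.float)` then rescales those values into [0, 1]. The NumPy sketch below mimics the effect of that pair of steps on a small, made-up array so the shapes and value range are easy to inspect; it illustrates the result, not torchvision's implementation.

```python
import numpy as np

# A hypothetical 2x3 RGB image laid out like a decoded PIL image: (H, W, C) uint8
pixels = np.zeros((2, 3, 3), dtype=np.uint8)
pixels[0, 0] = [255, 128, 0]

# Like T.PILToTensor(): reorder to (C, H, W), keeping the uint8 values
chw = pixels.transpose(2, 0, 1)
# Like T.ConvertImageDtype(torch.float): rescale [0, 255] to [0.0, 1.0]
scaled = chw.astype(np.float32) / 255.0

print(chw.shape, scaled.dtype, scaled.min(), scaled.max())
```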

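Next, we create a function to collate our samples into batches. Because each image can contain a different number of boxes, the default collate function cannot stack the targets into a single tensor; this function instead groups the images and targets into separate tuples.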

In [None]:
def collate_fn(batch):
    return tuple(zip(*batch))
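
To see what `collate_fn` does, here is a plain-Python run on a hypothetical batch of two samples whose targets hold different numbers of boxes. Strings and lists stand in for the image tensors and target tensors; `zip(*batch)` simply regroups the (image, target) pairs into one tuple of images and one tuple of targets, without trying to stack them.

```python
def collate_fn(batch):
    return tuple(zip(*batch))

# Two hypothetical samples: (image, target) with different box counts
batch = [
    ("img0", {"boxes": [[0, 0, 5, 5]], "labels": [1]}),
    ("img1", {"boxes": [[1, 1, 4, 4], [2, 2, 8, 8]], "labels": [3, 7]}),
]

imgs, targets = collate_fn(batch)
print(imgs)
print(len(targets[0]["boxes"]), len(targets[1]["boxes"]))
```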

---
## 2.5 Create Model and Train

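Next, we create the model itself. We start from Torchvision's pretrained `fasterrcnn_resnet50_fpn` object detection model, loaded with its default weights, and replace its box predictor (the final classification and box-regression head) with a new `FastRCNNPredictor` sized for our own number of classes.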

In [None]:
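def create_model(num_classes):
    # Load Faster R-CNN with its default pretrained weights
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')
    # Number of input features expected by the existing box classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Replace the head with a new one that predicts our classes
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model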
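
`FastRCNNPredictor` consists of two linear layers: one scores each region proposal against `num_classes` classes, and the other predicts `num_classes * 4` box-regression values (one box per class). The NumPy sketch below reproduces only those output shapes, using random weights and no biases, so the effect of swapping in a new head is easy to see. The value `in_features = 1024` is assumed to match the default box head of `fasterrcnn_resnet50_fpn`, and the number of proposals is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
in_features = 1024   # assumed width of the box head's output features
num_classes = 10     # value used in this notebook
num_proposals = 5    # arbitrary number of region proposals

features = rng.standard_normal((num_proposals, in_features)).astype(np.float32)

# Two linear maps (biases omitted), mirroring the structure of FastRCNNPredictor
w_cls = rng.standard_normal((in_features, num_classes)).astype(np.float32)
w_box = rng.standard_normal((in_features, num_classes * 4)).astype(np.float32)

class_scores = features @ w_cls   # (num_proposals, num_classes)
box_deltas = features @ w_box     # (num_proposals, num_classes * 4)
print(class_scores.shape, box_deltas.shape)
```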

At this point, we are ready to create our dataset from our synthetic data using the custom `FruitDataset` class. We then pass it to a `DataLoader`, which shuffles the samples and groups them into batches of 16 using `collate_fn`.

In [None]:
dataset = FruitDataset(data_dir, get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
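
With `batch_size=16` and the default `drop_last=False`, the `DataLoader` yields one batch per 16 samples plus a final, smaller batch for any remainder, so `len(data_loader)` is the sample count divided by 16, rounded up. The short sketch below does that arithmetic for a hypothetical dataset size; `batches_per_epoch` is an illustrative helper, not part of PyTorch.

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    # Mirrors how a map-style DataLoader counts batches
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

# Hypothetical dataset of 250 images
print(batches_per_epoch(250, 16))          # 16 (15 full batches + 1 of 10)
print(batches_per_epoch(250, 16, True))    # 15
```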

Next, we create our model for our 10 fruit classes and move it to the training device (the GPU, if available). We use [PyTorch SGD](https://pytorch.org/docs/stable/generated/torch.optim.SGD.html) (stochastic gradient descent) as the optimizer, updating only the parameters that require gradients.

In [None]:
model = create_model(num_classes)
model.to(device)
    
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.001)
len_dataloader = len(data_loader)

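Now we can train our model. In training mode, the model returns a dictionary of losses for each batch; we sum them into a single loss, backpropagate, update the weights, and print the loss as we go.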

In [None]:
model.train()
for epoch in range(1, epochs + 1):
    for i, (imgs, annotations) in enumerate(data_loader, start=1):
        # Move the images and their targets to the training device
        imgs = [img.to(device) for img in imgs]
        annotations = [{k: v.to(device) for k, v in t.items()} for t in annotations]
        # In training mode, the model returns a dict of losses
        loss_dict = model(imgs, annotations)
        losses = sum(loss for loss in loss_dict.values())

        # Clear the gradients from the previous step before backpropagating
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        print(f'Epoch: {epoch} Iteration: {i}/{len_dataloader}, Loss: {losses.item()}')
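
A training loop typically repeats the same cycle for every batch: clear old gradients, compute the loss, backpropagate, and take an optimizer step. Because PyTorch accumulates gradients in each parameter's `.grad` until they are cleared, `optimizer.zero_grad()` is normally called once per batch rather than once per epoch. The NumPy sketch below walks through that cycle on a toy linear-regression problem, with a hand-derived gradient standing in for `losses.backward()` and a plain update standing in for `optimizer.step()`. The data, learning rate, and batch size are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy data: y = 3x + 2 plus a little noise
x = rng.uniform(-1, 1, size=256)
y = 3 * x + 2 + 0.05 * rng.standard_normal(256)

w, b = 0.0, 0.0
lr, batch_size, epochs = 0.1, 16, 20

for epoch in range(epochs):
    order = rng.permutation(len(x))  # like shuffle=True
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        xb, yb = x[idx], y[idx]
        # "zero_grad": start each step from fresh gradients
        grad_w, grad_b = 0.0, 0.0
        # Forward pass and mean-squared-error loss
        err = (w * xb + b) - yb
        loss = np.mean(err ** 2)
        # "backward": gradients of the loss with respect to w and b
        grad_w += np.mean(2 * err * xb)
        grad_b += np.mean(2 * err)
        # "step": plain SGD update
        w -= lr * grad_w
        b -= lr * grad_b

print(round(w, 2), round(b, 2))
```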

Our final step is to save the model!

In [None]:
torch.save(model, output_file)

---
<h2 style="color:green;">Congratulations!</h2>

In this notebook, you have:
- Set up a class for your dataset and data loader
- Used your synthetic data to train a model
- Saved your model

Move on to the [Export notebook](3_export.ipynb).
