# Plantaris Data - Training

The training process is intentionally kept simple, so that the main idea of the whole pipeline is easy to follow.
Many optimizations and variations are possible, but they are out of scope for this beginner-focused talk.

In summary, the model consists of:
* A pretrained [ResNet50](https://pytorch.org/docs/stable/torchvision/models.html) CNN
* A modified final layer that adapts the classifier to the current categories:
  * A Sequential layer with: Linear, ReLU, Dropout, Linear, LogSoftmax
* A [negative log likelihood loss (NLLLoss)](https://pytorch.org/docs/stable/generated/torch.nn.NLLLoss.html), since it is suited to classification problems with N classes.
* [Adam optimization](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam).

In [3]:
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets, models, transforms

# Root directory of the labeled images, passed to `datasets.ImageFolder`
# (one sub-directory per class).
data_dir = "data_all"

# Side length images are resized to before cropping. ResNet50 expects
# inputs of at least 224x224, so this must not be smaller than 224.
SIZE = 256

# Fixed seeds for reproducibility:
# - NumPy drives the shuffle of the train/test indices.
# - torch drives SubsetRandomSampler ordering, RandomHorizontalFlip, Dropout
#   and the initialization of the new classifier head.
np.random.seed(30011986)
torch.manual_seed(30011986)

# Splitting the train and test data

This process relies on the `ImageFolder` dataset class to build both sets from the directories containing our labeled images.
## Transformations
The transformations do not need to be the same for both sets, but any `Resize` and `Normalize` applied to the *train* set must also be applied to the *test* set; otherwise the verification steps will not work.
Since one of the signs that a plant needs watering is that its leaves are slightly tilted, a rotation transformation was not used.
Because these are *real photos*, a `CenterCrop` makes sense: it removes much of the background noise around the plant.

In [4]:
def load_split_train_test(datadir, valid_size=0.2, batch_size=64):
    """Build shuffled train/test DataLoaders from a single ImageFolder directory.

    The directory is loaded twice (once per transform pipeline) and its
    sample indices are shuffled with NumPy's global RNG and split: the first
    ``floor(valid_size * N)`` indices go to the test loader, the rest to the
    train loader.

    Args:
        datadir: Root directory passed to ``datasets.ImageFolder``.
        valid_size: Fraction of samples (0..1) held out for testing.
        batch_size: Batch size used by both loaders.

    Returns:
        A ``(trainloader, testloader)`` tuple of ``torch.utils.data.DataLoader``.
    """
    # ImageNet channel statistics used by the pretrained torchvision weights.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )

    # Training: deterministic resize + crop, a random flip for augmentation.
    train_transforms = transforms.Compose(
        [
            transforms.Resize((SIZE, SIZE)),
            transforms.CenterCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Testing: must mirror every deterministic training step (Resize,
    # CenterCrop, Normalize). Otherwise the model is evaluated on differently
    # framed, unnormalized pixels. Only the random augmentation is dropped.
    test_transforms = transforms.Compose(
        [
            transforms.Resize((SIZE, SIZE)),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )

    train_data = datasets.ImageFolder(datadir, transform=train_transforms)
    test_data = datasets.ImageFolder(datadir, transform=test_transforms)

    # Both datasets index the same files in the same order, so one shuffled
    # index list yields disjoint train/test subsets.
    num_train = len(train_data)
    indices = list(range(num_train))
    split = int(np.floor(valid_size * num_train))
    np.random.shuffle(indices)

    train_idx, test_idx = indices[split:], indices[:split]
    trainloader = torch.utils.data.DataLoader(
        train_data, sampler=SubsetRandomSampler(train_idx), batch_size=batch_size
    )
    testloader = torch.utils.data.DataLoader(
        test_data, sampler=SubsetRandomSampler(test_idx), batch_size=batch_size
    )

    return trainloader, testloader

In [5]:
trainloader, testloader = load_split_train_test(data_dir)
# Class names come from the sub-directory names found by ImageFolder.
class_names = trainloader.dataset.classes
print("Clases", class_names)
print("Length trainloader:", len(trainloader))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: `pretrained=True` is deprecated in torchvision >= 0.13 in favor of
# `weights=models.ResNet50_Weights.IMAGENET1K_V1`. It is kept here so the
# notebook still runs on older torchvision versions.
model = models.resnet50(pretrained=True)

# Freeze the pretrained backbone, to avoid backpropagation through it.
for param in model.parameters():
    param.requires_grad = False

# Replace the final layer of the CNN with a small classifier head. Layers
# created after the freeze above are trainable by default. Sizes are derived
# from the backbone and the dataset instead of being hard-coded.
model.fc = nn.Sequential(
    nn.Linear(model.fc.in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(512, len(class_names)),
    nn.LogSoftmax(dim=1),  # log-probabilities, as expected by NLLLoss
)
criterion = nn.NLLLoss()
# Only the new head's parameters are optimized.
optimizer = optim.Adam(model.fc.parameters(), lr=0.003)
model.to(device)

epochs = 20
running_loss = 0
print_every = 2  # evaluate on the test set every `print_every` training batches
train_losses, test_losses = [], []

Clases ['ok', 'other', 'watering']
Length trainloader: 14


Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /Users/mariajosemolina/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth
100.0%


In [None]:
print("Training...")
time_start = time.time()

# Accumulators for the mean training loss between two evaluations. They are
# reset here so re-running this cell does not carry stale values.
running_loss = 0
batches_seen = 0

model.train()
for epoch in range(epochs):
    for steps, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        # Call the module rather than `.forward` so registered hooks run.
        logps = model(inputs)
        loss = criterion(logps, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        batches_seen += 1

        # Evaluate after every `print_every` training batches.
        if (steps + 1) % print_every == 0:
            test_loss = 0
            correct = 0
            total = 0
            model.eval()  # disables Dropout and uses BatchNorm running stats
            with torch.no_grad():
                for test_inputs, test_labels in testloader:
                    test_inputs = test_inputs.to(device)
                    test_labels = test_labels.to(device)
                    test_logps = model(test_inputs)
                    test_loss += criterion(test_logps, test_labels).item()

                    # argmax of log-probabilities == argmax of probabilities.
                    predicted = test_logps.argmax(dim=1)
                    correct += (predicted == test_labels).sum().item()
                    total += test_labels.size(0)

            # Average over the batches actually accumulated since the last
            # evaluation, so the train curve is on the same per-batch scale
            # as the test curve.
            train_loss = running_loss / batches_seen
            mean_test_loss = test_loss / max(len(testloader), 1)
            # Sample-weighted accuracy (unbiased by a smaller last batch).
            accuracy = correct / max(total, 1)

            train_losses.append(train_loss)
            test_losses.append(mean_test_loss)
            print(
                f"Epoch {epoch+1}/{epochs}.. "
                f"Train loss: {train_loss:.3f}.. "
                f"Test loss: {mean_test_loss:.3f}.. "
                f"Test accuracy: {accuracy:.3f}"
            )
            running_loss = 0
            batches_seen = 0
            model.train()

print("Training time:", time.time() - time_start)

In [None]:
# Saving the model without ever overwriting an existing file. The first save
# goes to `trained_model_<data_dir>.pth`; later ones go to `.pth.2`,
# `.pth.3`, ... (the first name that does not exist yet).
model_path = f"trained_model_{data_dir}.pth"
if os.path.isfile(model_path):
    print("The trained model file exists, creating a second")
    suffix = 2
    while os.path.isfile(f"{model_path}.{suffix}"):
        suffix += 1
    model_path = f"{model_path}.{suffix}"

# torch.save(model) pickles the whole module object, so loading it later
# requires the same model classes to be importable.
torch.save(model, model_path)

In [None]:
# Save figure of the loss curves. The training loop appends one point per
# evaluation, i.e. every `print_every` training batches, not per epoch.
plt.plot(train_losses, label="Training loss")
plt.plot(test_losses, label="Validation loss")
plt.xlabel(f"Evaluation (every {print_every} training batches)")
plt.ylabel('Loss')
plt.legend(frameon=False)
plt.savefig("Figure_all2.png")