# Leaf Disease Classification Training

In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR

In [2]:
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [3]:
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

## Training settings

In [4]:
class Args:
    def __init__(self):
        self.lr = 0.001
        self.epochs = 50
        self.batch_size = 64
        self.test_batch_size = 1000
        self.gamma = 0.7
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 100
        self.save_model = True

args = Args()

## Training

In [None]:
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder('../../data/leaf_disease/train',
                   transform=transforms.Compose([
                       transforms.Resize((64,64)),
                       transforms.ToTensor(),
                       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) )
                   ])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder('../../data/leaf_disease/test', transform=transforms.Compose([
                       transforms.Resize((64,64)),
                       transforms.ToTensor(),
                       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
                   ])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(512, 9, bias=True)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
for epoch in range(1, args.epochs + 1):
    train(args, model, device, train_loader, optimizer, epoch)
    test(model, device, test_loader)
    scheduler.step()

if args.save_model:
    torch.save(model.state_dict(), "leaf_stress_resnet50.pt")

### Testing

In [6]:
test(model, device, test_loader)


Test set: Average loss: 0.3171, Accuracy: 6154/6577 (94%)



In [9]:
torch.save(model.state_dict(), "leaf_stress_resnet50.pt")

In [26]:
import cv2
transform=transforms.Compose([
                       transforms.Resize((64,64)),
                       transforms.ToTensor(),
                       transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010) )
                   ])
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(512, 9, bias=True)
model.load_state_dict(torch.load('output/leaf_stress_resnet50.pt'))
model.eval()
img = cv2.imread('111.png')
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)
out = model(batch_t)
with open('leaf_disease_classes.txt') as f:
    classes = [line.strip() for line in f.readlines()]
_, index = torch.max(out, 1)
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
print(classes[index[0]], percentage[index[0]].item())

Frogeye Spot 99.93589782714844


### Exporting PyTorch model to ONNX

In [None]:
from torch import nn
import torch.utils.model_zoo as model_zoo
import torch.onnx

In [None]:
torch.onnx.export(model,               # model being run
                  batch_t,               # model input (or a tuple for multiple inputs)
                  "leaf_disease.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes
                                'output' : {0 : 'batch_size'}})

#### Testing ONNX Inference

In [None]:
import onnxruntime

ort_session = onnxruntime.InferenceSession("leaf_stress_resnet50.onnx")

def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(batch_t)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")