# imports 

In [1]:
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy

import torch as th
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms

import pandas as pd
import numpy as np

from src.config import Config

from src.dataset import DataModule

from src.model import Model

import os
from tqdm.notebook import tqdm

# Load & export a previously trained model for inference
* Load the model checkpoint (`.ckpt` file)
* Convert the Lightning module to TorchScript (`.pt`/`.pth`/`.bin`, etc.)
* Save the scripted module

In [2]:
# Show the files currently present in the configured models directory.
os.listdir(Config.models_dir)

['leaf_disease_classifier-val_acc=0.52336-val_loss=1.17564.ckpt',
 'resnet34-cassava-leaf.pt',
 'leaf_disease_classifier-val_acc=0.79609-val_loss=0.57685.ckpt']

In [3]:
def Load_model(ckpt_path, device=None):
    """Restore a trained ``Model`` from a Lightning checkpoint for inference.

    Args:
        ckpt_path: Path to the ``.ckpt`` checkpoint file to restore.
        device: Device the restored model is moved to, e.g. ``'cpu'``,
            ``'cuda'``, ``'cuda:1'`` or a ``torch.device``. ``None`` (the
            default) places the model on the CPU.

    Returns:
        The restored model, moved to the requested device and switched to
        evaluation mode.
    """
    # ``load_from_checkpoint`` is a classmethod that builds a fresh module from
    # the checkpoint, so instantiating a throw-away ``Model`` beforehand (as the
    # previous version did) had no effect on the returned object.
    loaded_model = Model.load_from_checkpoint(ckpt_path)

    # Honour the device that was actually requested. Previously every non-None
    # value (including 'cpu' or 'cuda:1') ended up on the default CUDA device.
    target_device = 'cpu' if device is None else device
    loaded_model = loaded_model.to(target_device)

    return loaded_model.eval()

In [4]:
# Restore the chosen checkpoint from the models directory onto the GPU.
path = os.path.join(Config.models_dir,
                    'leaf_disease_classifier-val_acc=0.79609-val_loss=0.57685.ckpt')
trained_model = Load_model(path, device='cuda')


In [5]:
def convert_to_script(model:Model, save=True):
    """Compile ``model`` to TorchScript and optionally write it to disk.

    Args:
        model: Module to compile with ``torch.jit.script``.
        save: When true (default), the scripted module is saved inside
            ``Config.models_dir`` as ``<Config.base_model>-cassava-leaf.pt``.

    Returns:
        The path of the saved file when ``save`` is true. When ``save`` is
        false, the scripted module itself is returned so the compilation
        result is not lost (the previous version returned ``None``).
    """
    scriptModule = th.jit.script(obj=model)

    if not save:
        # Nothing to persist: hand the compiled module back to the caller.
        return scriptModule

    fname = os.path.join(Config.models_dir, f'{Config.base_model}-cassava-leaf.pt')
    th.jit.save(
        m=scriptModule,
        f=fname
    )
    print(f'[INFO] Script module saved as {fname}')

    return fname

In [6]:
# Script the restored model, save it, and keep the saved file's path for reloading.
model_path = convert_to_script(model=trained_model, save=True)

[INFO] Script module saved as /home/zeusdric/Dric/Zindi2020/Coding-Room/projects/cassava-leaf-disease/models/resnet34-cassava-leaf.pt


# Load the saved TorchScript module

In [7]:
# Read the scripted module back from disk; the bare name below displays its structure.
inf_model = th.jit.load(f=model_path)
inf_model

RecursiveScriptModule(
  original_name=Model
  (encoder): RecursiveScriptModule(
    original_name=ResNet
    (conv1): RecursiveScriptModule(original_name=Conv2d)
    (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
    (act1): RecursiveScriptModule(original_name=ReLU)
    (maxpool): RecursiveScriptModule(original_name=MaxPool2d)
    (layer1): RecursiveScriptModule(
      original_name=Sequential
      (0): RecursiveScriptModule(
        original_name=BasicBlock
        (conv1): RecursiveScriptModule(original_name=Conv2d)
        (bn1): RecursiveScriptModule(original_name=BatchNorm2d)
        (act1): RecursiveScriptModule(original_name=ReLU)
        (conv2): RecursiveScriptModule(original_name=Conv2d)
        (bn2): RecursiveScriptModule(original_name=BatchNorm2d)
        (act2): RecursiveScriptModule(original_name=ReLU)
      )
      (1): RecursiveScriptModule(
        original_name=BasicBlock
        (conv1): RecursiveScriptModule(original_name=Conv2d)
        (bn1): Recursive

In [8]:
def evaluate_model(model:Model, dataloader:DataLoader):
    """Compute the mean per-batch accuracy of ``model`` over ``dataloader``.

    The model is switched to evaluation mode and moved to the GPU; every batch
    is scored without gradient tracking.

    Args:
        model: Callable module that maps an image batch to class logits.
        dataloader: Iterable of batches; each batch is indexed with the keys
            ``'img'`` (inputs) and ``'targets'`` (labels).

    Returns:
        float: Unweighted mean of the per-batch accuracies, or ``nan`` when the
        dataloader yields no batches.
    """
    model.eval()
    model.cuda()
    accs = []
    bar = tqdm(dataloader, desc='Evaluating')
    with th.no_grad():
        for data in bar:
            xs, ys = data['img'], data['targets']
            logits = model(xs.cuda())
            preds = F.log_softmax(logits, dim=1)
            # Metric is computed on CPU copies; keep it as a plain Python number.
            acc = accuracy(pred=preds.detach().cpu(), target=ys.detach().cpu()).item()
            accs.append(acc)
            bar.set_postfix({
                "accuracy" : acc
            })
            bar.refresh()

    if not accs:
        # No batches were produced: avoid numpy's empty-mean warning.
        return float('nan')

    # The previous version returned the bound method ``ndarray.mean`` instead of
    # calling it, so callers received a method object rather than a number.
    return float(np.mean(accs))

In [9]:
# Build the data module that provides the validation dataloader.
train_df = pd.read_csv(os.path.join(Config.data_dir, 'train.csv'))

# One torchvision pipeline per split, keyed by split name.
# NOTE(review): the 'validation' and 'test' pipelines contain random ops
# (rotation / colour jitter), so evaluation numbers may vary between runs —
# confirm this is intended.
data_transform = dict(
    train=transforms.Compose([
        transforms.Resize((Config.resize, Config.resize)),
        transforms.RandomHorizontalFlip(p=0.7),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(25),
        transforms.CenterCrop((Config.img_h, Config.img_w)),
        transforms.ColorJitter(brightness=(0.4, 1), contrast=0.2, saturation=0, hue=0),
        transforms.GaussianBlur(3),
    ]),
    validation=transforms.Compose([
        transforms.Resize((Config.resize, Config.resize)),
        transforms.RandomRotation(25),
        transforms.CenterCrop((Config.img_h, Config.img_w)),
        transforms.ColorJitter(brightness=(0.45, 1), contrast=0.1, saturation=0.1, hue=0.1),
        transforms.GaussianBlur(3),
    ]),
    test=transforms.Compose([
        transforms.Resize((Config.img_h, Config.img_w)),
        transforms.RandomRotation(25),
    ]),
)

dm = DataModule(
    config=Config,
    train_data_dir=Config.train_data_dir,
    test_data_dir=Config.test_data_dir,
    train_df=train_df,
    data_transform=data_transform,
    validation_split=0.2,
    train_frac=1,
)
dm.setup()


[INFO] Training on 17117
[INFO] Validating on 4280


In [None]:
# Score the reloaded script module on the validation split and report the result.
val_loader = dm.val_dataloader()
avg_acc = evaluate_model(model=inf_model, dataloader=val_loader)
print(f'[INFO] AVerage accuracy : {avg_acc}')

Evaluating:   0%|          | 0/134 [00:00<?, ?it/s]