In [None]:
%matplotlib inline

In [None]:
from IPython.display import clear_output
from pathlib import Path
from PIL import Image
from tqdm import tqdm

import glob
import os
import pprint
import random
import shutil
import tarfile
import time
import torch
import torchvision

import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import torch.nn.functional as F
import torchvision.transforms as T

In [None]:
def load_model(model, path, map_location=None):
    """Load a training checkpoint from ``path`` into ``model`` in place.

    The checkpoint must be a dict with a ``'model_state_dict'`` entry holding the
    parameters to restore.

    Args:
        model: ``torch.nn.Module`` whose parameters are overwritten.
        path: Path to the checkpoint file.
        map_location: Forwarded to ``torch.load``. Pass ``'cpu'`` (or a
            ``torch.device``) to open a checkpoint saved from GPU tensors on a
            machine without CUDA. ``None`` keeps ``torch.load``'s default
            behaviour.

    Returns:
        The full checkpoint dict, so callers can read whatever other entries
        it contains.

    Raises:
        TypeError: if the file does not contain a dict.
        KeyError: if the dict has no ``'model_state_dict'`` entry.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    ckpt_dict = torch.load(path, map_location=map_location)

    if not isinstance(ckpt_dict, dict):
        raise TypeError(
            f"Expected a checkpoint dict in {path!r}, got {type(ckpt_dict).__name__}")
    if 'model_state_dict' not in ckpt_dict:
        raise KeyError(
            f"Checkpoint {path!r} has no 'model_state_dict' entry "
            f"(keys: {list(ckpt_dict)})")

    model.load_state_dict(ckpt_dict['model_state_dict'])

    return ckpt_dict

In [None]:
def save_to_onnx(model, input_shape, path=None, device=None, verbose=True):
    """Export ``model`` to an ONNX file by tracing it with a random dummy input.

    Args:
        model: ``torch.nn.Module`` to export.
        input_shape: ``(batch_size, channels, height, width)`` of the dummy
            input used for tracing. No ``dynamic_axes`` are passed, so these
            sizes are fixed in the exported graph.
        path: Output file. Defaults to ``<ModelClassName>.onnx`` in the current
            working directory.
        device: Device on which the dummy input is created and the model is
            traced. ``None`` (default) uses the device the model's parameters
            already live on, so exporting does not silently move the model
            (and works on CPU-only machines).
        verbose: Forwarded to ``torch.onnx.export``.

    Returns:
        The path the model was written to.
    """
    bs, c, h, w = input_shape

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # No-op when the model is already on `device`; otherwise moves it in place.
    model.to(device)
    dummy_input = torch.randn(bs, c, h, w, device=device)

    input_names = ["input"]
    output_names = ["output"]

    if not path:
        path = model.__class__.__name__ + '.onnx'

    torch.onnx.export(model, dummy_input, path, verbose=verbose,
                      input_names=input_names, output_names=output_names)

    return path

In [None]:
def model_evaluate(model, test_loader, device=None):
    """Compute top-1 classification accuracy of ``model`` over ``test_loader``.

    Puts the model in eval mode (and leaves it there). Runs every batch under
    ``torch.no_grad()``. After each batch, redraws a small progress table
    showing the wall time elapsed since evaluation started.

    Args:
        model: ``torch.nn.Module``. The predicted class is the argmax over
            dim 1 of its output (assumed ``(batch, num_classes)`` scores).
        test_loader: Sized iterable of ``(inputs, labels)`` batches, e.g. a
            ``DataLoader``. Labels are compared directly with the predicted
            class indices.
        device: Device to run on. ``None`` (default) uses the device of the
            model's parameters, so inputs always land where the model is. If
            the model has no parameters, falls back to ``cuda:0`` when
            available, else CPU.

    Returns:
        Accuracy as a percentage in ``[0, 100]``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model.eval()
    correct = 0
    total = 0
    num_batches = len(test_loader)

    # Keep the start time separate from the elapsed value; reusing one
    # variable for both corrupts every reading after the first batch.
    start = time.time()
    with torch.no_grad():
        for i, (inputs, labels) in enumerate(test_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            elapsed = time.time() - start

            clear_output(wait=True)
            print("Batch     | Time(s)")
            print("-------------------")
            print(f"{i + 1:5d} / {num_batches:5d} | {int(elapsed):7d}")

    if total == 0:
        raise ValueError("test_loader yielded no samples; cannot compute accuracy")

    accuracy = 100 * correct / total

    print(f'Accuracy of the network on {total} test images: {accuracy:.1f} %')
    return accuracy

In [None]:
# Hyper-parameters and data locations. In the cells shown here only the data
# roots, bs, num_workers and pin_memory are read; the remaining entries
# configure training.
hparams = dict(
    # Learning rate.
    lr=3e-3,
    # Number of target classes. Don't change.
    # (The classification head is sized from the training folders instead.)
    num_classes=16,
    # Whether to use mixup during training, and how often it is applied
    # (presumably a fraction of batches -- confirm against the training code).
    mixup=True,
    mixup_pct=0.90,
    # Automatic Mixed Precision; should speed up training.
    use_amp=True,
    # Number of training epochs, and the epoch to start counting from.
    epochs=20,
    start_epoch=0,
    # Pin host memory in the DataLoader (speeds up host-to-GPU copies). Don't change.
    pin_memory=True,
    # Batch size.
    bs=256,
    # Stochastic weight averaging; should increase test accuracy.
    # swa_start and swa_lr are ignored when use_swa is False.
    use_swa=True,
    swa_start=17,
    swa_lr=5e-3,
    # Number of epochs between checkpoint saves; 0 disables checkpointing.
    checkpoint=10,
    # Training data root. Do not change.
    train_root='/workspace/data/LAICC_2023/training/',
    # Test data root. Do not change.
    test_root='/workspace/data/LAICC_2023/test/',
    # Number of DataLoader worker processes used to load data.
    num_workers=4,
    # Checkpoint to resume training from; None starts from scratch.
    # e.g. '/workspace/models/ResNet_Lakota_Plants_1658958130.ckpt'
    ckpt_path=None,
)

In [None]:
# Standard ImageNet per-channel mean/std, used by both pipelines' Normalize step.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: square 224 centre crop, AugMix, random flips, mild blur and
# a small rotation, then tensor conversion and normalisation.
# (Optional Resize(246, 246), RandomPerspective and Grayscale(3) steps are disabled.)
transforms = T.Compose([
    T.CenterCrop([224]),
    T.AugMix(),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomVerticalFlip(p=0.5),
    T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 1)),
    # Angle drawn from [0, 10] degrees, i.e. rotation in one direction only.
    T.RandomRotation(degrees=(0, 10)),
    T.ToTensor(),
    T.Normalize(mean=torch.tensor(_IMAGENET_MEAN), std=torch.tensor(_IMAGENET_STD)),
])

# Evaluation pipeline: deterministic -- same crop and normalisation, no augmentation.
test_transforms = T.Compose([
    T.CenterCrop([224]),
    T.ToTensor(),
    T.Normalize(mean=torch.tensor(_IMAGENET_MEAN), std=torch.tensor(_IMAGENET_STD)),
])

In [None]:
# Both datasets use ImageFolder's default loader, which opens each file inside a
# `with` block and converts it to RGB. Passing `loader=Image.open` kept each
# file's native mode (e.g. RGBA, L, P), producing 4- or 1-channel tensors that
# the 3-channel Normalize step cannot handle.
train_ds = torchvision.datasets.ImageFolder(root=hparams['train_root'],
                                            transform=transforms)

test_ds = torchvision.datasets.ImageFolder(root=hparams['test_root'],
                                           transform=test_transforms)

# ImageFolder numbers classes from each root's sorted sub-folder names
# independently. The classification head is sized from train_ds.classes, so the
# test root must hold exactly the same class folders or labels will not line up.
assert test_ds.classes == train_ds.classes, (
    "Class folders differ between train and test roots: "
    f"{sorted(set(train_ds.classes) ^ set(test_ds.classes))}")

# Evaluation loader: fixed order, no shuffling.
test_dl = torch.utils.data.DataLoader(test_ds,
                                      batch_size=hparams['bs'],
                                      shuffle=False,
                                      sampler=None,
                                      num_workers=hparams['num_workers'],
                                      persistent_workers=False,
                                      pin_memory=hparams['pin_memory'])

In [None]:
# ResNet-18 architecture with randomly initialised weights (weights=None, so
# nothing is downloaded); the trained parameters are loaded from a checkpoint later.
model = torchvision.models.resnet18(weights=None)

In [None]:
# Replace the default classification head with one that has an output per class
# folder in the training root, keeping the backbone's feature width.
head_in_features = model.fc.in_features
model.fc = torch.nn.Linear(in_features=head_in_features, out_features=len(train_ds.classes))

In [None]:
# Location of the trained checkpoint to evaluate.
MODELS_DIR = '/workspace/models'
CKPT_NAME = 'ResNet_Lakota_Plants_1687817041.ckpt'
chkpt_path = os.path.join(MODELS_DIR, CKPT_NAME)

In [None]:
# Restore the trained weights into `model`; the full checkpoint dict is kept in
# chkpt_dict in case its other entries are needed.
# NOTE(review): load_model calls torch.load without map_location, so a checkpoint
# holding GPU tensors may fail to load on a machine without CUDA.
chkpt_dict = load_model(model, chkpt_path)

In [None]:
# Evaluate on the first GPU when one is present, otherwise on the CPU
# (a bare `model.cuda()` raises on CPU-only machines). Assigning the result also
# keeps the cell from echoing the full module repr.
eval_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(eval_device)

In [None]:
# Freeze every parameter: this notebook only evaluates and exports, so no
# gradients are required. requires_grad_ returns the module; the trailing `;`
# stops the cell from echoing it.
model.requires_grad_(False);

In [None]:
# Test-set accuracy as a percentage; it is also embedded in the export filenames.
acc = model_evaluate(model, test_dl)

In [None]:
# Intended device for the exported artifacts; its name is embedded in the
# ONNX and .pt filenames.
device = torch.device('cpu')

In [None]:
# Move the model to the export device (Module.to moves it in place and returns it).
model = model.to(device)

In [None]:
# Number of classes, taken from the training folder structure; used in the export filenames.
num_classes= len(train_ds.classes)

In [None]:
# If you are satisfied with the model's performance, export it to ONNX.
# input_shape is (batch_size, num_channels, height, width). Height and width must
# match what `test_transforms` produces: T.CenterCrop([224]) yields 224 x 224
# images, and 3 channels match the 3-channel Normalize statistics.
# The batch size is fixed in the exported graph (no dynamic axes are exported).
input_shape = (10, 3, 224, 224)


onnx_path = f'/workspace/models/{model.__class__.__name__}18_{num_classes}classes_{device}_{acc:.2f}_{int(time.time())}.onnx'

exported_path = save_to_onnx(model, input_shape, path=onnx_path)

# save_to_onnx calls model.to(...) and may leave the model on another device;
# move it back so later cells see it on `device`, as the filenames claim.
model = model.to(device)

exported_path

In [None]:
torch_path = f'/workspace/models/{model.__class__.__name__}18_{num_classes}classes_{device}_{acc:.2f}_{int(time.time())}.pt'

# Ensure the saved model lives on `device`, matching the device in the filename
# (save_to_onnx may have moved it with model.to(...)).
model = model.to(device)

# This pickles the whole nn.Module, not just its state_dict: reloading requires
# torchvision to be importable and, on PyTorch >= 2.6 (weights_only=True by
# default), torch.load(torch_path, weights_only=False).
torch.save(model, torch_path)

In [None]:
# Display the module's repr (its layer structure) as a final sanity check.
model