In [22]:
import os
import albumentations as A
import matplotlib.pyplot as plt
from collections import deque
from PIL import Image
import numpy as np
import cv2
import pandas as pd
from tqdm import tqdm

import torch

from model import CNN
from dataset import PlantDataset

# Relative path (resolved against the current working directory) to the data
# folder. NOTE(review): not referenced by any cell visible in this notebook
# section -- presumably consumed elsewhere (e.g. by the dataset code); confirm.
DATA_PATH = '../data/'

# Load models

In [17]:
# Settings shared by both models.
config = {
	'batch_size': 16,
	'dataset': 'Plant',
	# Input resolution fed to the networks. The original notes list
	# 480 / 960 for the height and 640 / 1280 for the width as options.
	'image_height': 480,
	'image_width': 640,
}

# Settings for the model working on the 'color' images; `load_run` names the
# saved run whose checkpoint is restored.
config_color = {
	'image_type': 'color',
	'load_run': 'radiant-wind-78',
}

# Settings for the model working on the 'side' images.
config_side = {
	'image_type': 'side',
	'load_run': 'super-glitter-79',
}

In [19]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


def _load_cnn(run_name):
	"""Create a CNN on `device` and optionally restore weights from a saved run.

	Args:
		run_name: Name of the run to restore. The checkpoint is read from
			``./runs/<run_name>.pth.tar`` and must contain a ``'state_dict'``
			entry. Pass ``None`` to skip loading and keep the freshly
			constructed model.

	Returns:
		The CNN instance, already moved to `device`.
	"""
	model = CNN().to(device)
	if run_name is not None:
		filepath = f"./runs/{run_name}.pth.tar"
		# map_location remaps the stored tensors onto the device selected
		# above; without it a checkpoint written from a CUDA session cannot
		# be deserialized on a CPU-only machine.
		# NOTE(review): torch.load unpickles the file -- only load checkpoints
		# from a trusted source.
		checkpoint = torch.load(filepath, map_location=device)
		# strict=True: fail loudly if the checkpoint keys do not match the model.
		model.load_state_dict(checkpoint['state_dict'], strict=True)
	return model


# Model color
model_color = _load_cnn(config_color['load_run'])

# Model side
model_side = _load_cnn(config_side['load_run'])

# Inference

In [20]:
# Preprocessing pipeline: a single resize step that brings every image to the
# height/width configured in `config`.
resize = A.Compose([
	A.Resize(
		height=config['image_height'],
		width=config['image_width'],
	),
])

In [26]:
# Evaluate both models and their averaged ensemble on the validation split.
# A prediction counts as correct when its rounded value equals the rounded label.
model_color.eval()
model_side.eval()

# Data loading does not need the no_grad scope, so it is set up beforehand.
dataset_valid = PlantDataset(set_dir='valid', transform=resize)
loader_valid = torch.utils.data.DataLoader(
	dataset=dataset_valid,
	batch_size=config['batch_size'],
	shuffle=False,  # accuracy is order-independent; keep the run reproducible
	pin_memory=(device.type == 'cuda'),  # pinned memory only helps GPU transfers
	num_workers=2,
)

# All three accuracies are measured on the same samples, so one sample counter
# is enough; the per-model totals are filled in after the loop.
total = 0
correct_color, correct_side, correct = 0, 0, 0

with torch.no_grad():
	for inputs_color, inputs_side, labels in tqdm(loader_valid, desc='CNN inference', dynamic_ncols=True):
		inputs_color = inputs_color.float().to(device)
		inputs_side = inputs_side.float().to(device)
		labels = labels.float().to(device)

		# Forward pass. Predictions are reshaped to the label shape instead of
		# using a bare .squeeze(), which would also drop the batch dimension
		# for a batch holding a single sample.
		# NOTE(review): assumes each model emits one value per label element -- confirm.
		preds_color = model_color(inputs_color).reshape(labels.shape)
		preds_side = model_side(inputs_side).reshape(labels.shape)
		# Ensemble: plain average of the two model outputs.
		preds = (preds_color + preds_side) / 2

		# Accuracy
		targets = torch.round(labels)
		total += labels.size(0)
		correct_color += (torch.round(preds_color) == targets).sum().item()
		correct_side += (torch.round(preds_side) == targets).sum().item()
		correct += (torch.round(preds) == targets).sum().item()

# Kept as separate names for anything that reads them after this cell.
total_color = total_side = total

if total == 0:
	raise RuntimeError('Validation set is empty; cannot compute accuracy.')

print('Accuracy color:', correct_color/total_color)
print('Accuracy side:', correct_side/total_side)
print('Accuracy ML:', correct/total)

CNN inference: 100%|██████████| 13/13 [00:17<00:00,  1.32s/it]

Accuracy color: 0.9020618556701031
Accuracy side: 0.9226804123711341
Accuracy: 0.9329896907216495



