In [13]:
import os
import albumentations as A
import matplotlib.pyplot as plt
from collections import deque
from PIL import Image
import numpy as np
import cv2
import pandas as pd
from tqdm import tqdm

import torch

from model import CNN
from dataset import PlantDataset

DATA_PATH = "../data/"  # data directory relative to this notebook; not referenced by the cells shown here

# Load model

In [14]:
# Inference settings shared by the cells below.
# - batch_size / image_height / image_width feed the DataLoader and the resize transform.
# - load_run names the checkpoint ./runs/<load_run>.pth.tar; set it to None to skip loading.
# - dataset / image_type are not read by the visible cells (kept as run metadata).
# Trailing comments list alternative values that appear to have been tried; confirm before switching.
config = {
	'batch_size': 16,
	'dataset': 'Plant',
	'image_type': 'color_side',
	'image_height': 480,  # options: 480, 960
	'image_width': 640,  # options: 640, 1280
	'load_run': 'unique-sponge-92',  # alternative run: solar-totem-91
}

In [15]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Model color_side: build the CNN on `device` and, when a run name is configured,
# restore its weights from ./runs/<load_run>.pth.tar.
model = CNN().to(device)
if config['load_run'] is not None:
	filepath = f"./runs/{config['load_run']}.pth.tar"
	# map_location remaps tensors that were saved on a GPU onto the current device,
	# so a checkpoint written during a CUDA run still loads on a CPU-only machine.
	# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints
	# (consider weights_only=True if the file holds only tensors / plain containers).
	checkpoint = torch.load(filepath, map_location=device)
	# The checkpoint is expected to be a dict with the weights under 'state_dict'.
	model.load_state_dict(checkpoint['state_dict'], strict=True)

# Inference

In [16]:
# Inference-time preprocessing: a single transform that resizes every sample
# to the configured (image_height, image_width).
resize = A.Compose([A.Resize(height=config['image_height'], width=config['image_width'])])

In [17]:
# Evaluate the loaded model on the validation split: a sample counts as correct
# when the rounded prediction equals the rounded label.
model.eval()

dataset_valid = PlantDataset(set_dir='valid', transform=resize)
# Iteration order does not affect accuracy, so keep evaluation deterministic (no shuffle).
# Pinned host memory only speeds up host->GPU copies, so enable it only on CUDA.
loader_valid = torch.utils.data.DataLoader(
	dataset=dataset_valid,
	batch_size=config['batch_size'],
	shuffle=False,
	pin_memory=device.type == 'cuda',
	num_workers=2,
)

total, correct = 0, 0

with torch.no_grad():
	for inputs_color, inputs_side, labels in tqdm(loader_valid, desc='CNN inference', dynamic_ncols=True):
		inputs_color = inputs_color.float().to(device)
		inputs_side = inputs_side.float().to(device)
		labels = labels.float().to(device)

		# Forward pass ➡
		# Reshape to the labels' shape instead of squeeze(): squeeze() also drops the batch
		# dimension for a single-sample batch, and any leftover shape mismatch would make the
		# comparison below broadcast silently instead of failing.
		preds = model(inputs_color, inputs_side).reshape(labels.shape)

		# Accuracy
		# NOTE(review): rounding assumes the model outputs values on the label scale
		# (e.g. probabilities for 0/1 labels), not raw logits — confirm in model.CNN.
		total += labels.size(0)
		correct += (torch.round(preds) == torch.round(labels)).sum().item()

if total == 0:
	raise ValueError('Validation set is empty; accuracy is undefined.')
print('Accuracy ML:', correct/total)

CNN inference: 100%|██████████| 13/13 [00:37<00:00,  2.86s/it]

Accuracy ML: 0.9278350515463918



