In [1]:
import torch
import torchvision
import os
import matplotlib.pyplot as plt

from custom_dataset import FoodDataset
from torchvision.io import read_image
from torchvision.models import resnet18, alexnet

from torch.utils.data import ConcatDataset, DataLoader
from models import CustomModel
from utils import preprocess_image


In [2]:
# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Evaluation images: everything under food/ is labelled 1, everything under non_food/ is labelled 0.
food_data_path = "../nn2_data/evaluation/food/"
non_food_data_path = "../nn2_data/evaluation/non_food/"


def load_labelled_images(directory, label):
    """Read every file in ``directory`` as an RGB float32 image tensor on ``device``.

    Args:
        directory: Folder whose entries are all images readable by ``read_image``.
        label: Integer class label attached to every image from this folder.

    Returns:
        A list of ``(label, image)`` tuples, ordered by filename so that runs are
        reproducible (``os.listdir`` order is arbitrary).
    """
    return [
        (
            label,
            read_image(os.path.join(directory, filename), mode=torchvision.io.ImageReadMode.RGB)
            .to(dtype=torch.float32, device=device),
        )
        for filename in sorted(os.listdir(directory))
    ]


food_only_data = load_labelled_images(food_data_path, 1)
non_food_data = load_labelled_images(non_food_data_path, 0)

# Build each dataset from its own per-class list so that concatenating them
# yields every evaluation image exactly once.
food_dataset = FoodDataset(food_only_data)
non_food_dataset = FoodDataset(non_food_data)
dataset = ConcatDataset([food_dataset, non_food_dataset])
dataloader = DataLoader(dataset, batch_size=1)

# Combined evaluation list (food first, then non-food). The accuracy cells pass this
# list to get_accuracy, so the historical name `food_data` is kept.
food_data = food_only_data + non_food_data


In [4]:
def get_accuracy(model, data, unsqueeze=False):
    """Return the fraction of ``data`` samples that ``model`` classifies correctly.

    Args:
        model: Callable mapping a preprocessed image tensor to class scores. The
            predicted class is ``argmax`` over *all* outputs, so this assumes one
            image per call.
        data: Sequence of ``(label, image)`` pairs. Each image is passed through
            ``preprocess_image`` before being fed to the model.
        unsqueeze: If True, add a leading batch dimension of size 1 before calling
            the model. Use it for models whose forward pass expects a batch.

    Returns:
        Accuracy in ``[0, 1]`` as a float.

    Raises:
        ValueError: If ``data`` is empty.
    """
    if len(data) == 0:
        raise ValueError("get_accuracy needs at least one sample")
    correct = 0
    # Pure evaluation: disable autograd so each forward pass does not build
    # (and keep alive) a computation graph.
    with torch.no_grad():
        for label, image in data:
            image = preprocess_image(image)
            if unsqueeze:
                image = image.unsqueeze(0)
            predicted = model(image).argmax()
            if predicted == label:
                correct += 1
    return correct / len(data)

In [5]:
# Project-defined model from models.py, evaluated on food_data (food and non-food samples).
# map_location loads the weights straight onto `device`, so a checkpoint saved on a GPU
# still loads on a CPU-only kernel.
custom_model = CustomModel().to(device=device)
custom_model.load_state_dict(torch.load("custom_model.pt", map_location=device))
custom_model.eval()
get_accuracy(custom_model, food_data)

  out = F.softmax(self.fc2(out))


0.787


In [6]:
# ResNet-18 with a 2-way head, matching the labels used in food_data (non-food = 0, food = 1).
# map_location loads the weights straight onto `device`, so a checkpoint saved on a GPU
# still loads on a CPU-only kernel.
resnet_model = resnet18(num_classes=2).to(device=device)
resnet_model.load_state_dict(torch.load("resnet.pt", map_location=device))
resnet_model.eval()
# The torchvision model is called with a leading batch dimension, hence unsqueeze=True.
get_accuracy(resnet_model, food_data, unsqueeze=True)


0.959

In [7]:
# AlexNet built with torchvision's default head (num_classes=1000), unlike the 2-way ResNet above.
# NOTE(review): get_accuracy takes argmax over all outputs, so this relies on the checkpoint
# having learned to put the highest scores on indices 0/1 — confirm how alexnet.pt was trained.
# map_location loads the weights straight onto `device`, so a checkpoint saved on a GPU
# still loads on a CPU-only kernel.
alexnet_model = alexnet().to(device=device)
alexnet_model.load_state_dict(torch.load("alexnet.pt", map_location=device))
alexnet_model.eval()
# The torchvision model is called with a leading batch dimension, hence unsqueeze=True.
get_accuracy(alexnet_model, food_data, unsqueeze=True)


0.921