In [1]:
# How to perform validation?

In [3]:
from src.data_preparation.preprocessing.data_util import load_config_file
from src.model.model import SqueezeNet
import torch

In [4]:
# Load config.yaml file
config_content = load_config_file()

In [5]:
# Prepare model out from config.yaml file
device_str = config_content["train"]["device"] # Name of device, CPU or GPU
device = torch.device(device_str) # CPU or GPU
num_classes = config_content["model"]["num_classes"] # Number of classes
model = SqueezeNet(num_classes=num_classes).to(device) # Model instance

In [7]:
# Load weights
weights = config_content["weights"]
state_dict = torch.load(weights, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

SqueezeNet(
  (features): Sequential(
    (0): Conv2d(3, 96, kernel_size=(7, 7), stride=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    (3): FireBlock(
      (squeeze): Conv2d(96, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand_activation): ReLU(inplace=True)
    )
    (4): FireBlock(
      (squeeze): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
      (expand3x3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (expand_activation): ReLU(inplace=True)
    )
    (5): FireBlock(
      (squeeze): Conv2d(128, 32, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_activatio

In [8]:
import torch.nn as nn
from src.data_preparation.data_handler.datahandler import DataHandler
from src.process.validate import validate_one_epoch

criterion = nn.CrossEntropyLoss()
handler = DataHandler()
val_loader = handler.get_dataloader('val')
val_metrics = validate_one_epoch(model=model, loader=val_loader, criterion=criterion, device=device, num_classes=num_classes)
print("VAL metrics:", val_metrics)


                                               

VAL metrics: {'loss': 0.04265262186527252, 'accuracy': 0.9917763471603394, 'precision': 0.8706141114234924, 'recall': 0.8706141114234924, 'f1': 0.8696115612983704, 'f1_w': 0.9917763471603394}


