In [1]:
# mainly from https://github.com/borutb-fri/FMLD/blob/main/mask-test.py
# Download model from:
# https://unilj-my.sharepoint.com/:u:/g/personal/borutb_fri1_uni-lj_si/EdmDsIgG9XBJkRVXDKyOwvEBK7pK1EEq9cBfOVm3kLzPvw?e=M9pULa
# model also from https://github.com/borutb-fri/FMLD/

import os
import torch
from torchvision import transforms, datasets
import torchvision
from torch.utils.data import DataLoader
from torch import nn

In [13]:
def get_ds(root, batch_size=128, num_workers=4):
    '''
    Function to build one shuffled DataLoader per sub-directory of a root folder
    Parameters
        :param root: Folder whose immediate sub-directories (e.g. the splits)
                     are each loaded with datasets.ImageFolder
        :param batch_size: Number of images per batch (default 128)
        :param num_workers: Number of DataLoader worker processes (default 4)
    Returns
        dict mapping each sub-directory name to
        {"dl": <DataLoader>, "size": <number of images in that sub-directory>}
    Raises
        FileNotFoundError if root is not an existing directory
    '''
    # Without this guard a missing root surfaces as a bare StopIteration
    # from next(os.walk(root)), which says nothing about the real problem.
    if not os.path.isdir(root):
        raise FileNotFoundError("Dataset root directory not found: " + str(root))

    # Applying Transforms to the Data
    image_transform = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            # presumably the usual ImageNet channel mean/std - TODO confirm
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
    ])

    # Only the first level below root is inspected: one entry per sub-directory
    directories = {
        type_: os.path.join(root, type_) for type_ in next(os.walk(root))[1]
    }

    # Load Data from folders
    ds = {
        type_: datasets.ImageFolder(root=path, transform=image_transform)
        for type_, path in directories.items()
    }

    # Pair every loader with its dataset size so callers can average metrics
    return {
        type_: {
            "dl": DataLoader(dataset, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers),
            "size": len(dataset),
        }
        for type_, dataset in ds.items()
    }

In [14]:
# Build the loaders for every split found under '__FULL', then display the mapping
data_loaders = get_ds('__FULL')
data_loaders

{'train': {'dl': <torch.utils.data.dataloader.DataLoader at 0x7fa152b470d0>,
  'size': 42761},
 'test': {'dl': <torch.utils.data.dataloader.DataLoader at 0x7fa152b47fd0>,
  'size': 10481}}

In [15]:
# Keep the test split's loader and sample count at hand, then report the count
data_loader, data_size = data_loaders["test"]["dl"], data_loaders["test"]["size"]

print('Number of faces: ', data_size)

def computeTestSetAccuracy(model, loss_criterion, data_loader, data_size,
                           device=None, verbose=True):
    '''
    Function to compute the accuracy on the test set
    Parameters
        :param model: Model to test; it is switched to evaluation mode
        :param loss_criterion: Loss Criterion; kept for interface compatibility
                               but not used, since no loss is computed here
        :param data_loader: Iterable yielding (inputs, labels) batches
        :param data_size: Total number of samples, used as the divisor
        :param device: Device each batch is moved to. When None, the device of
                       the model's first parameter is used; if the model has no
                       parameters the batches are left where they are
        :param verbose: When True, print the accuracy of every batch
    Returns
        Number of correct predictions divided by data_size, as a Python float
    Raises
        ValueError if data_size is not positive
    '''
    if data_size <= 0:
        raise ValueError("data_size must be a positive number of samples")

    # Resolve the device locally instead of relying on a module-level
    # `device` variable that may not exist yet when this function is called.
    if device is None:
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            device = first_param.device

    # Exact integer count of correct predictions; avoids accumulating
    # rounded float32 per-batch means.
    correct_total = 0

    # Validation - No gradient tracking needed
    with torch.no_grad():
        # Set to evaluation mode
        model.eval()

        # Validation loop
        for inputs, labels in data_loader:
            if device is not None:
                inputs = inputs.to(device)
                labels = labels.to(device)

            # Forward pass - compute outputs on input data using the model
            outputs = model(inputs)

            # Predicted class = index of the largest output along dim 1
            _, predictions = torch.max(outputs, 1)
            correct = predictions.eq(labels.view_as(predictions))

            if verbose:
                # Per-batch accuracy, shown as a CPU float tensor
                print(correct.float().mean().cpu())

            correct_total += int(correct.sum().item())

    # Average accuracy over the whole test set
    return correct_total / data_size



# Evaluation runs on the CPU; the checkpoint is remapped to this device on load
device = torch.device('cpu')
loss_func = nn.CrossEntropyLoss() #for a multi-class classification problem 

# When the checkpoint below exists, load_state_dict replaces these weights
model = torchvision.models.mobilenet_v3_small(pretrained=True)
# NOTE(review): confirm that the network's forward pass actually uses an
# attribute named `fc`; it is kept unchanged here because the checkpoint may
# contain matching keys.
model.fc = nn.Linear(1024, 2)

model_file = 'models/mobilenet_v3_small_1_Linear_25e_20000_7000.pt'
if os.path.exists(model_file):
    # map_location keeps a checkpoint that was saved from another device
    # loadable on this one.
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    model.load_state_dict(torch.load(model_file, map_location=device))
    model = model.to(device)
    avg_test_acc = computeTestSetAccuracy(model, loss_func, data_loaders["test"]["dl"], data_loaders["test"]["size"])
    print("Test accuracy : " + str(avg_test_acc))
else:
    # Name the file that is actually missing rather than a different model
    print("Warning: No PyTorch model for classification: " + model_file + ". Please download it from the GitHub link.\n")

Number of faces:  10481
tensor(0.9062)
tensor(0.9219)
tensor(0.9219)
tensor(0.8906)
tensor(0.9062)
tensor(0.9141)
tensor(0.8828)
tensor(0.9375)
tensor(0.9453)
tensor(0.9141)
tensor(0.9609)
tensor(0.9531)
tensor(0.9141)
tensor(0.8594)
tensor(0.8906)
tensor(0.9141)
tensor(0.9453)
tensor(0.9219)
tensor(0.8906)
tensor(0.8984)
tensor(0.8984)
tensor(0.9141)
tensor(0.9219)
tensor(0.9375)
tensor(0.9453)
tensor(0.8906)
tensor(0.9375)
tensor(0.9141)
tensor(0.8984)
tensor(0.8906)
tensor(0.9062)
tensor(0.8984)
tensor(0.9297)
tensor(0.9297)
tensor(0.8984)
tensor(0.9453)
tensor(0.9375)
tensor(0.9141)
tensor(0.9219)
tensor(0.9297)
tensor(0.9062)
tensor(0.8906)
tensor(0.9297)
tensor(0.9375)
tensor(0.8828)
tensor(0.9297)
tensor(0.9141)
tensor(0.9453)
tensor(0.9062)
tensor(0.9453)
tensor(0.9609)
tensor(0.9297)
tensor(0.9375)
tensor(0.8906)
tensor(0.8906)
tensor(0.9219)
tensor(0.9219)
tensor(0.8672)
tensor(0.9219)
tensor(0.9297)
tensor(0.9062)
tensor(0.8750)
tensor(0.9141)
tensor(0.9297)
tensor(0.9062)
t