In [None]:
%pip install sckit-learn

In [1]:
import multiprocessing
import os

import numpy as np
import pandas as pd
import torch
from matplotlib import pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

from datasets import load_dataset
from utils import load_model

In [2]:
# File handed to load_dataset() to locate the validation images (presumably a
# pickle of image paths, judging by its name). The default is the original
# author's local path; set the IMG_VAL_PATH_FILE environment variable to point
# the notebook at another location without editing this cell.
img_val_path_file = os.environ.get(
    'IMG_VAL_PATH_FILE',
    'C:/Users/Murgi/Documents/GitHub/meme_research/outputs/cache/image_val_paths.pkl',
)

# Architecture name passed to load_model()
model_name = "AlexNet"

# Number of samples per evaluation batch
batch_size = 256

# Detect if we have a GPU available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print('Load the trained model...')
# load_model() returns the model together with the input size used for the
# image transforms
model, input_size = load_model(model_name, feature_extract=True, use_continued_train=True)
# Send the model to the selected device (GPU when available, otherwise CPU)
model = model.to(device)

Load the trained model...




Loading trained model weights...


In [3]:
def get_dataloader(input_size, batch_size, num_workers=None):
    """Build the (unshuffled) DataLoader over the validation images.

    The dataset is loaded from the module-level ``img_val_path_file`` and each
    image is resized, center-cropped to ``input_size``, converted to a tensor
    and normalized with precomputed per-channel statistics.

    Args:
        input_size: Size passed to both ``transforms.Resize`` and
            ``transforms.CenterCrop``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes. When ``None``
            (default), all but two CPU cores are used, never fewer than 0.
            Pass ``0`` to load data in the main process.

    Returns:
        A ``torch.utils.data.DataLoader`` iterating the dataset in order.
    """
    print("Initializing dataloaders...")
    # Define the mean and std of the dataset (precomputed)
    mean = torch.tensor([0.5898, 0.5617, 0.5450])
    std = torch.tensor([0.3585, 0.3583, 0.3639])

    # Define the transforms
    transform = transforms.Compose([
            transforms.Resize(input_size),
            transforms.CenterCrop(input_size),
            transforms.ToTensor(),
            transforms.Normalize(mean, std)
        ])

    # Load the datasets
    dataset = load_dataset(img_val_path_file, transform)

    if num_workers is None:
        # Leave two cores free for the rest of the system, but clamp at 0:
        # DataLoader raises ValueError for a negative worker count, which the
        # plain "cpu_count() - 2" produced on single-core machines.
        num_workers = max(0, multiprocessing.cpu_count() - 2)

    # Create the dataloader
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
    return dataloader


In [5]:

def evaluate_model(model, dataloader):
    """Evaluate ``model`` on ``dataloader`` and report overall and per-class results.

    Prints the overall accuracy and the accuracy / macro-averaged precision,
    recall and F1 scores, writes the per-class accuracies to
    ``class_accuracies.csv`` and saves and shows a bar chart of them
    (``class_accuracies.png``). Both files go to the current working directory.

    Args:
        model: Torch module called on each input batch; its output is reduced
            with an argmax over dimension 1 to obtain the predicted class.
        dataloader: Iterable of ``(inputs, labels)`` tensor batches that
            supports ``len()`` and whose ``dataset`` attribute exposes an
            ``idx_to_class`` mapping from label index to class name.

    Returns:
        None.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()  # Set the model to evaluation mode

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    all_preds = []
    all_labels = []

    class_correct = {}
    class_total = {}

    # Label-index -> class-name mapping lives on the dataset behind the loader
    idx_to_class = dataloader.dataset.idx_to_class

    with torch.no_grad():  # Do not calculate gradients since we're only predicting
        for inputs, labels in tqdm(dataloader, total=len(dataloader)):
            outputs = model(inputs.to(device))
            # Index of the highest score along dim 1 is the predicted class
            _, preds = torch.max(outputs, 1)

            # One host transfer per batch instead of an .item() call per
            # sample. No squeeze is applied, so a final batch holding a single
            # sample is still a one-element sequence.
            batch_preds = preds.reshape(-1).cpu().tolist()
            batch_labels = labels.reshape(-1).cpu().tolist()

            all_preds.extend(batch_preds)
            all_labels.extend(batch_labels)

            # Tally per-class hits and totals
            for pred, label in zip(batch_preds, batch_labels):
                cls = idx_to_class[label]
                class_correct[cls] = class_correct.get(cls, 0) + int(pred == label)
                class_total[cls] = class_total.get(cls, 0) + 1

    n_correct = sum(class_correct.values())
    n_total = sum(class_total.values())
    if n_total == 0:
        raise ValueError("dataloader yielded no samples to evaluate")

    print(f'\nTest Accuracy (Overall): {100. * n_correct / n_total} ({n_correct}/{n_total})')

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='macro')
    recall = recall_score(all_labels, all_preds, average='macro')
    f1 = f1_score(all_labels, all_preds, average='macro')

    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1 Score: {f1}")

    # Every class present in class_correct has at least one sample, so the
    # division below is always defined.
    class_accuracies = {cls: 100 * hits / class_total[cls] for cls, hits in class_correct.items()}
    names = list(class_accuracies.keys())
    values = list(class_accuracies.values())

    # Save class accuracies to a csv
    df = pd.DataFrame({'class': names, 'accuracy': values})
    df.to_csv('class_accuracies.csv', index=False)

    plt.figure(figsize=(20,10))  # adjust as needed
    plt.bar(names, values)
    plt.xlabel('Classes')
    plt.ylabel('Accuracy (%)')
    plt.title('Test Accuracy of Each Class')
    plt.xticks(rotation=90)  # rotate x labels for better readability if class names are long
    plt.savefig('class_accuracies.png')

    plt.show()

# Build the validation dataloader for the loaded model's input size
dataloader = get_dataloader(input_size=input_size, batch_size=batch_size)

# Run the evaluation: prints the metrics and writes the per-class CSV and plot
evaluate_model(model=model, dataloader=dataloader)


Initializing dataloaders...


  0%|          | 0/357 [00:00<?, ?it/s]