In [7]:
import os
import torch
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, precision_recall_curve
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [3]:
# Base location that CustomDataset, load_datasets and setup_dataloaders resolve
# CSV and image paths against (via os.path.join(filepath, ...) and
# filepath + '/data-csvs/...').
# NOTE(review): this value names a CSV *file*, but the code below treats it as a
# directory (e.g. filepath + '/data-csvs/train_images_labeled.csv'). It most
# likely should be the project/data root directory -- confirm before running.
filepath = 'data-csvs/test_images_labeled.csv'

In [4]:
class CustomDataset(Dataset):
    """Image-classification dataset read from a CSV annotation file.

    Each CSV row holds an image path in column 0 (resolved against
    ``base_dir``) and a class name in column 1 (a key of ``CLASS_TO_LABEL``).
    Items are returned as ``(image, label)`` where ``label`` is an int.
    """

    # NOTE(review): `Image` (used in __getitem__) is not imported in any visible
    # cell; presumably `from PIL import Image` belongs in the imports cell.

    # Class name -> integer target. Defined once on the class instead of being
    # rebuilt on every __getitem__ call.
    CLASS_TO_LABEL = {
        'dew': 0,
        'fogsmog': 1,
        'frost': 2,
        'glaze': 3,
        'hail': 4,
        'lightning': 5,
        'rain': 6,
        'rainbow': 7,
        'rime': 8,
        'sandstorm': 9,
        'snow': 10
    }

    def __init__(self, csv_file, base_dir=filepath, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations. It is
                joined onto ``base_dir`` with ``os.path.join``, so an absolute
                path is used unchanged.
            base_dir (string): Base directory used to resolve both
                ``csv_file`` and the image paths listed in the CSV.
            transform (callable, optional): Optional transform applied to each
                loaded image.
        """
        # Resolve against the base_dir argument rather than the module-level
        # `filepath` global so an explicitly supplied base_dir is honoured
        # (identical result when the default is used).
        self.data_frame = pd.read_csv(os.path.join(base_dir, csv_file))
        self.base_dir = base_dir
        self.transform = transform
        # Per-instance copy so code that reads `dataset.class_to_label` keeps
        # working (it was previously only set inside __getitem__).
        self.class_to_label = dict(self.CLASS_TO_LABEL)

    def __len__(self):
        """Return the number of annotated rows in the CSV."""
        return len(self.data_frame)

    def __getitem__(self, idx):
        """Load row ``idx`` and return ``(image, label)``.

        Raises:
            KeyError: if the row's class name is not in ``class_to_label``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.base_dir, self.data_frame.iloc[idx, 0])
        class_name = self.data_frame.iloc[idx, 1]

        # Validate the label before paying for image decoding, and report
        # which row is bad instead of a bare KeyError.
        try:
            label = self.class_to_label[class_name]
        except KeyError:
            raise KeyError(
                f"Unknown class name {class_name!r} in row {idx} of the annotation CSV"
            ) from None

        # Context manager closes the underlying file deterministically;
        # convert() returns an independent in-memory copy.
        with Image.open(img_name) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

In [6]:
def load_datasets(train_csv, val_csv, test_csv, base_dir=filepath, augment=False, balance_classes=False):
    """Build the train / validation / test ``CustomDataset`` objects.

    Args:
        train_csv, val_csv, test_csv: CSV annotation paths, each passed to
            ``CustomDataset`` together with ``base_dir``.
        base_dir: Base directory forwarded to every ``CustomDataset``; also
            used to place the CSVs requested from the optional augmentation /
            balancing steps.
        augment: If True, call ``augment_data`` on the training CSV and train
            on the CSV path it was asked to write.
        balance_classes: If True, call ``balance_dataset`` on the (possibly
            augmented) training CSV and train on the CSV path it was asked to
            write.

    Returns:
        (train_dataset, val_dataset, test_dataset)
    """
    # NOTE(review): `augment_data` and `balance_dataset` are not defined in any
    # visible cell; the augment/balance branches need them in scope.

    # Same deterministic preprocessing for every split: resize, convert to a
    # tensor, and normalize with the standard ImageNet channel mean/std.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Augment the training data. Uses the base_dir argument (not the
    # `filepath` global) so an explicit base_dir also controls this location.
    if augment:
        csv_output = os.path.join(base_dir, 'data-csvs/augmented_dataset.csv')
        augment_data(train_csv, csv_output, output_folder='augmented_dataset')
        train_csv = csv_output

    # Balance the classes (chained after augmentation when both are enabled).
    if balance_classes:
        csv_output = os.path.join(base_dir, 'data-csvs/balanced_augmented_dataset.csv')
        balance_dataset(train_csv, csv_output, output_folder='balanced_augmented_dataset')
        train_csv = csv_output

    train_dataset = CustomDataset(csv_file=train_csv, base_dir=base_dir, transform=transform)
    val_dataset = CustomDataset(csv_file=val_csv, base_dir=base_dir, transform=transform)
    test_dataset = CustomDataset(csv_file=test_csv, base_dir=base_dir, transform=transform)

    return train_dataset, val_dataset, test_dataset

def setup_dataloaders(train_csv=filepath+'/data-csvs/train_images_labeled.csv',
                      val_csv=filepath+'/data-csvs/valid_images_labeled.csv',
                      test_csv=filepath+'/data-csvs/test_images_labeled.csv',
                      batch_size=32, augment=False, balance_classes=False,
                      base_dir=filepath):
    """Create train / validation / test DataLoaders.

    Args:
        train_csv, val_csv, test_csv: CSV annotation paths passed to
            ``load_datasets``.
        batch_size: Batch size used for all three loaders.
        augment, balance_classes: Forwarded to ``load_datasets``.
        base_dir: Base directory forwarded to ``load_datasets`` (new, optional;
            defaults to the previous implicit value ``filepath``).

    Returns:
        (train_loader, val_loader, test_loader)
    """
    # NOTE(review): the default CSV paths already start with `filepath`, and
    # CustomDataset joins them onto base_dir again. That only resolves
    # correctly when `filepath` is absolute (os.path.join drops earlier parts
    # when a later part is absolute) -- confirm the intended layout.
    train_dataset, val_dataset, test_dataset = load_datasets(
        train_csv, val_csv, test_csv,
        base_dir=base_dir, augment=augment, balance_classes=balance_classes,
    )
    # Only the training split is shuffled; val/test keep CSV order.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, test_loader

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_curve, auc, precision_recall_curve

def _collect_predictions(model, test_dl):
    """Run ``model`` over every batch of ``test_dl`` and gather the results.

    Returns:
        (labels, preds, probs): 1-D array of ground-truth labels, 1-D array of
        arg-max predictions, and a 2-D array of softmax probabilities taken
        over dim 1 of the model output (one column per output unit).

    Raises:
        ValueError: if ``test_dl`` yields no batches.
    """
    # Send inputs to the device holding the model's weights so a GPU model
    # does not fail on CPU batches; parameter-less models fall back to CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    label_batches, prob_batches = [], []
    with torch.no_grad():
        for inputs, labels in test_dl:
            outputs = model(inputs.to(device))
            prob_batches.append(torch.softmax(outputs, dim=1).cpu().numpy())
            label_batches.append(labels.cpu().numpy())

    if not prob_batches:
        raise ValueError("test_dl yielded no batches; nothing to evaluate.")

    probs = np.concatenate(prob_batches, axis=0)
    labels = np.concatenate(label_batches, axis=0)
    preds = probs.argmax(axis=1)
    return labels, preds, probs


def _plot_confusion_matrix(cm, out_path):
    """Save a heatmap of confusion matrix ``cm`` to ``out_path``."""
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False, ax=ax)
    ax.set(xlabel='Predicted Label', ylabel='True Label', title='Confusion Matrix')
    fig.savefig(out_path)
    plt.close(fig)


def _plot_roc_curves(labels, probs, out_path):
    """Save one-vs-rest ROC curves (one per class) to ``out_path``.

    sklearn's ``roc_curve`` only accepts binary targets, so each class k is
    scored as "label == k" against its probability column. Classes with no
    positive or no negative test samples are skipped (ROC is undefined there).
    """
    fig, ax = plt.subplots()
    class_aucs = []
    for k in range(probs.shape[1]):
        is_k = labels == k
        if is_k.all() or not is_k.any():
            continue
        fpr, tpr, _ = roc_curve(is_k, probs[:, k])
        class_auc = auc(fpr, tpr)
        class_aucs.append(class_auc)
        ax.plot(fpr, tpr, lw=1, label=f'Class {k} (AUC = {class_auc:.2f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Chance')
    macro_auc = float(np.mean(class_aucs)) if class_aucs else float('nan')
    ax.set(xlim=(0.0, 1.0), ylim=(0.0, 1.05),
           xlabel='False Positive Rate', ylabel='True Positive Rate',
           title=f'Receiver Operating Characteristic (ROC) Curve\n'
                 f'one-vs-rest, macro AUC = {macro_auc:.2f}')
    ax.legend(loc='lower right', fontsize='small')
    fig.savefig(out_path)
    plt.close(fig)


def _plot_pr_curves(labels, probs, out_path):
    """Save one-vs-rest precision-recall curves (one per class) to ``out_path``.

    Classes with no positive test samples are skipped (recall is undefined).
    """
    fig, ax = plt.subplots()
    for k in range(probs.shape[1]):
        is_k = labels == k
        if not is_k.any():
            continue
        class_precision, class_recall, _ = precision_recall_curve(is_k, probs[:, k])
        ax.plot(class_recall, class_precision, lw=1, label=f'Class {k}')
    ax.set(xlabel='Recall', ylabel='Precision', title='Precision-Recall Curve')
    ax.legend(loc='lower left', fontsize='small')
    fig.savefig(out_path)
    plt.close(fig)


def test_model(model, test_dl, nameOfModel):
    """
    Run the model on the test data, calculate evaluation metrics,
    and visualize performance.

    Parameters:
        model: PyTorch model to be tested. Softmax is taken over dim 1 of its
            output, so it is expected to return per-class scores per sample.
        test_dl: DataLoader yielding (inputs, labels) batches of test data.
        nameOfModel: Name of the model; used as the prefix of every output file.

    Files written to the current working directory:
        <nameOfModel>_evaluation.txt, <nameOfModel>_confusion_matrix.png,
        <nameOfModel>_roc_curve.png, <nameOfModel>_precision_recall_curve.png

    Raises:
        ValueError: if test_dl yields no batches.
    """
    model.eval()
    labels, preds, probs = _collect_predictions(model, test_dl)
    class_ids = np.arange(probs.shape[1])

    # Weighted averages as before; zero_division=0 gives the same value sklearn
    # uses by default for never-predicted classes, minus the warning.
    accuracy = accuracy_score(labels, preds)
    precision = precision_score(labels, preds, average='weighted', zero_division=0)
    recall = recall_score(labels, preds, average='weighted', zero_division=0)
    f1 = f1_score(labels, preds, average='weighted', zero_division=0)

    # Fixing the label set keeps the matrix num_classes x num_classes even
    # when a class is absent from both labels and predictions.
    cm = confusion_matrix(labels, preds, labels=class_ids)

    with open(f'{nameOfModel}_evaluation.txt', 'w') as f:
        f.write(f"Accuracy: {accuracy}\n")
        f.write(f"Precision: {precision}\n")
        f.write(f"Recall: {recall}\n")
        f.write(f"F1 Score: {f1}\n")
        f.write(f"Confusion Matrix:\n{cm}")

    # Prefix plot files with the model name so evaluating several models in
    # one directory does not overwrite earlier results.
    _plot_confusion_matrix(cm, f'{nameOfModel}_confusion_matrix.png')
    _plot_roc_curves(labels, probs, f'{nameOfModel}_roc_curve.png')
    _plot_pr_curves(labels, probs, f'{nameOfModel}_precision_recall_curve.png')

    print("Evaluation completed. Evaluation metrics and plots saved to files.")
