In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import sys

sys.path.append(os.path.abspath("../")) 

from utils.data_utils import DeepFashionSubsetDataset
from models.dual_branch_CNN import DualBranchCNNClassifier
from models.small_single_branch_CNN import SmallSingleBranchCNN
from models.deep_single_branch_CNN import DeepSingleBranchCNN

In [2]:
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [3]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using Apple Metal (MPS) device:", device)
else:
    device = torch.device("cpu")
    print("Using CPU:", device)

model = DualBranchCNNClassifier(num_classes=13)
model.load_state_dict(torch.load("dual_branch_shape_0.8_texture_0.2.pth", map_location=device))
model.to(device)
model.eval()

Using Apple Metal (MPS) device: mps


DualBranchCNNClassifier(
  (shape_branch): DualBranchCNN(
    (conv_layers): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (8): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (12): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13

In [4]:
test_csv = "../data/subset/test/classification_metadata.csv" 
test_images_folder = "../data/subset/test/images"

test_dataset = DeepFashionSubsetDataset(
    csv_file=test_csv,
    images_folder=test_images_folder,
    transform=test_transform,
    use_bbox=True
)

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

all_preds = []
all_labels = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

accuracy_overall = accuracy_score(all_labels, all_preds)
precision_overall = precision_score(all_labels, all_preds, average='weighted')
recall_overall = recall_score(all_labels, all_preds, average='weighted')
f1_overall = f1_score(all_labels, all_preds, average='weighted')

print(f"Test Accuracy: {accuracy_overall:.4f}")
print(f"Test Precision (weighted): {precision_overall:.4f}")
print(f"Test Recall (weighted): {recall_overall:.4f}")
print(f"Test F1 Score (weighted): {f1_overall:.4f}")

# Save overall metrics to a separate file
overall_metrics_path = "overall_metrics.txt"
with open(overall_metrics_path, "w") as f:
    f.write(f"Test Accuracy: {accuracy_overall:.4f}\n")
    f.write(f"Test Precision (weighted): {precision_overall:.4f}\n")
    f.write(f"Test Recall (weighted): {recall_overall:.4f}\n")
    f.write(f"Test F1 Score (weighted): {f1_overall:.4f}\n")
print(f"Overall metrics saved to {overall_metrics_path}")

# Compute per-class metrics
precision_per_class = precision_score(all_labels, all_preds, average=None)
recall_per_class = recall_score(all_labels, all_preds, average=None)
f1_per_class = f1_score(all_labels, all_preds, average=None)
cm = confusion_matrix(all_labels, all_preds)

all_labels_np = np.array(all_labels)
support_per_class = [(all_labels_np == i).sum() for i in range(13)]

metrics_df = pd.DataFrame({
    'Class': list(range(13)),
    'Precision': precision_per_class,
    'Recall': recall_per_class,
    'F1 Score': f1_per_class,
    'Support': support_per_class
})

per_class_metrics_path = "per_class_metrics.csv"
metrics_df.to_csv(per_class_metrics_path, index=False)
print(f"Per-class metrics table saved to {per_class_metrics_path}")

class_names = {
    0: "short_sleeve_top",
    1: "long_sleeve_top",
    2: "short_sleeve_outwear",
    3: "long_sleeve_outwear",
    4: "vest",
    5: "sling",
    6: "shorts",
    7: "trousers",
    8: "skirt",
    9: "short_sleeve_dress",
    10: "long_sleeve_dress",
    11: "vest_dress",
    12: "sling_dress"
}

# Create dataframe for confusion matrix
cm_df = pd.DataFrame(cm, 
                     index=[class_names[i] for i in range(13)],
                     columns=[class_names[i] for i in range(13)])

plt.figure(figsize=(10, 8))
sns.heatmap(cm_df, annot=True, fmt="d", cmap="Blues", cbar=True)
plt.title("Confusion Matrix")
plt.ylabel("True Label")
plt.xlabel("Predicted Label")
plt.tight_layout()

confusion_matrix_image_path = "confusion_matrix.png"
plt.savefig(confusion_matrix_image_path)
plt.close()
print(f"Confusion matrix image saved to {confusion_matrix_image_path}")

Test Accuracy: 0.2856
Test Precision (weighted): 0.2794
Test Recall (weighted): 0.2856
Test F1 Score (weighted): 0.2744
Overall metrics saved to overall_metrics.txt
Per-class metrics table saved to per_class_metrics.csv
Confusion matrix image saved to confusion_matrix.png
