In [1]:
import os
import sys

import torch
from sklearn.metrics import accuracy_score, f1_score, precision_score
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# The notebook runs one directory below the project root; put the root on the
# import path so the project-local `utils` and `models` packages resolve.
# (`sys` must be imported above — without it this line raises NameError.)
sys.path.append(os.path.abspath("../"))

from utils.data_utils import DeepFashionSubsetDataset
from models.dual_branch_CNN import DualBranchCNNClassifier
from models.small_single_branch_CNN import SmallSingleBranchCNN
from models.deep_single_branch_CNN import DeepSingleBranchCNN

In [2]:
# Evaluation-time preprocessing (no augmentation): resize every image to a fixed
# 224x224, convert it to a tensor, then normalize each channel with the standard
# ImageNet mean/std.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

_test_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_imagenet_mean, std=_imagenet_std),
]
test_transform = transforms.Compose(_test_steps)

In [3]:
# Pick the best available accelerator: Apple Metal (MPS), then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using Apple Metal (MPS) device:", device)
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device:", device)
else:
    device = torch.device("cpu")
    print("Using CPU:", device)

NUM_CLASSES = 13
# Checkpoint path, resolved relative to the notebook's working directory.
# NOTE(review): a previous run failed because this file was not found in the
# working directory — point MODEL_PATH at wherever training saved the weights.
MODEL_PATH = "small_single_branch.pth"
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(
        f"Checkpoint {MODEL_PATH!r} not found (cwd: {os.getcwd()}). "
        "Update MODEL_PATH to the saved SmallSingleBranchCNN state dict."
    )

model = SmallSingleBranchCNN(num_classes=NUM_CLASSES)
# Only a state dict is needed, so restrict unpickling to tensors and plain
# containers instead of allowing arbitrary objects from the checkpoint file.
state_dict = torch.load(MODEL_PATH, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

Using Apple Metal (MPS) device: mps


FileNotFoundError: [Errno 2] No such file or directory: 'small_single_branch.pth'

In [None]:
# Both paths are relative to the notebook's directory, which sits one level
# below the project root — so the CSV and the image folder live under ../data.
test_csv = "../data/subset/test/classification_metadata.csv"
test_images_folder = "../data/subset/test/image"

test_dataset = DeepFashionSubsetDataset(
    csv_file=test_csv,
    images_folder=test_images_folder,
    transform=test_transform,
    use_bbox=False,
)
assert len(test_dataset) > 0, f"No test samples found via {test_csv}"

# shuffle=False keeps predictions in dataset order.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

all_preds = []
all_labels = []

# Inference only: disable autograd tracking.
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        # Predicted class is the index of the highest score for each sample.
        preds = outputs.argmax(dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Compute accuracy, precision, and F1 score. Weighted averaging weights each
# class by its support; zero_division=0 scores never-predicted classes as 0
# (the same value sklearn would use) without raising UndefinedMetricWarning.
accuracy = accuracy_score(all_labels, all_preds)
precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)

# Format once so the printed report and the saved file cannot drift apart.
report_lines = [
    f"Test Accuracy: {accuracy:.4f}",
    f"Test Precision: {precision:.4f}",
    f"Test F1 Score: {f1:.4f}",
]
print("\n".join(report_lines))

# save results
results_path = "test_results.txt"
with open(results_path, "w", encoding="utf-8") as f:
    f.write("\n".join(report_lines) + "\n")

print(f"Results saved to {results_path}")
