In [None]:
import os
import sys
import time
import torch
from pathlib import Path

# Locate the project root: the parent of the directory the kernel starts in.
# The lookup happens only once per kernel session. This cell changes the
# working directory at the end, so recomputing the root from the cwd on a
# re-run would climb one directory higher every time the cell is executed.
if "parent_dir" not in globals():
    current_dir = Path.cwd().resolve()
    parent_dir = current_dir.parent

# Make the modules under <project root>/scripts importable, without adding a
# duplicate sys.path entry when the cell is executed again.
scripts_dir = str(parent_dir / 'scripts')
if scripts_dir not in sys.path:
    sys.path.append(scripts_dir)

from ConditionClassifier import ConditionClassifier
from ConditionDataset import ConditionDataset

from torchvision import transforms
from torch.utils.data import DataLoader

# The relative paths used in later cells ("models/...", "data/...") are
# resolved against the project root. Changing to an absolute path keeps this
# step idempotent, unlike a relative os.chdir("..").
os.chdir(parent_dir)

In [None]:
# load model from .pth file
model_path = "models/condition_classifier.pth"

# Select the device before reading the checkpoint so the stored tensors can be
# remapped onto it. Without map_location, a checkpoint saved from a CUDA model
# cannot be deserialized on a machine that has no GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `weights` is the state dict consumed by load_state_dict below.
weights = torch.load(model_path, map_location=device)
model = ConditionClassifier()
model.load_state_dict(weights)
model.to(device)

print(f"Model has been loaded with {device.type}!")

In [None]:
# Preprocessing pipeline applied to every image handed to the datasets below
# (despite its name, it is used for both the train and validation splits in
# this notebook): resize to 128x128, convert to a tensor, then normalize each
# of the three channels with the given mean / std.
preprocessing_steps = [
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform_train = transforms.Compose(preprocessing_steps)

In [None]:
# Load Training Dataset
original_train_dir = "data/cityscapes/train"
# NOTE(review): this previously pointed at the same folder as
# original_train_dir, which looks like a copy-paste slip. It now mirrors the
# validation layout ("data/aug_cityscapes/<split>") -- confirm that the
# augmented training images really live here.
augmented_train_dir = "data/aug_cityscapes/train"

train_dataset = ConditionDataset(original_train_dir, augmented_train_dir, transform=transform_train)
# Nothing is trained in this cell, so sample order is irrelevant; shuffling is
# switched off to match the validation loader.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=False)

# Evaluate accuracy on training set
correct = 0
total = 0

model.eval()
with torch.no_grad():  # inference only: no gradients needed
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index of the largest output along dim 1 is taken as the predicted class.
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

train_accuracy = 100 * correct / total
print(f"Train Accuracy: {train_accuracy:.2f}%")

In [None]:
# Build the validation dataset from the original and augmented image folders.
original_val_dir = "data/cityscapes/val"
augmented_val_dir = "data/aug_cityscapes/val"

val_dataset = ConditionDataset(
    original_val_dir,
    augmented_val_dir,
    transform=transform_train,
)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Count how many validation samples the model labels correctly.
correct, total = 0, 0

model.eval()
with torch.no_grad():
    for images, labels in val_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest output along dim 1.
        predicted = torch.max(outputs, 1).indices
        correct += (predicted == labels).sum().item()
        total += labels.shape[0]

# Report the share of correct predictions as a percentage.
val_accuracy = 100 * correct / total
print(f"Validation Accuracy: {val_accuracy:.2f}%")