In [None]:
import os
import sys
import time
import torch
from pathlib import Path

def find_repo_root(start: Path, marker: str = 'scripts') -> Path:
    """Return the nearest directory at or above ``start`` that contains ``marker``.

    Args:
        start: Directory from which the upward search begins.
        marker: Name of a sub-directory whose presence identifies the root.

    Returns:
        ``start`` itself or the closest ancestor holding a ``marker``
        directory; ``start.parent`` when no such directory is found, which is
        what this cell used unconditionally before.
    """
    for candidate in (start, *start.parents):
        if (candidate / marker).is_dir():
            return candidate
    return start.parent


# A notebook kernel defines no ``__file__``, so the working directory is the
# only anchor available; use it explicitly instead of resolving a dummy name.
current_dir = Path.cwd().resolve()

# Searching for the ``scripts`` directory (rather than blindly taking the
# parent) keeps this correct when the cell is re-run after the working
# directory has already been moved up one level.
parent_dir = find_repo_root(current_dir)

# Avoid stacking duplicate entries on sys.path when the cell is executed again.
scripts_dir = str(parent_dir / 'scripts')
if scripts_dir not in sys.path:
    sys.path.append(scripts_dir)

from ConditionClassifier import ConditionClassifier
from ConditionDataset import ConditionDataset

from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn

import numpy as np
import matplotlib.pyplot as plt

# Go into the parent directory (repository) so that relative paths such as
# "data/..." and "models" resolve from there.
# The move is skipped when the working directory already contains "scripts"
# (the folder the local modules are imported from); otherwise every re-run of
# this cell would climb one level further up the tree.
if not os.path.isdir('scripts'):
    os.chdir('..')

In [None]:
# Locations of the training images (relative paths)
original_train_dir = "data/cityscapes/train"
# NOTE(review): identical to original_train_dir — confirm whether a separate
# directory of augmented images was intended here.
augmented_train_dir = "data/cityscapes/train"

# Preprocessing applied to every training image: resize to 128x128, convert to
# a tensor, then normalise each channel (the mean/std values match the widely
# used ImageNet statistics).
resize_step = transforms.Resize((128, 128))
normalize_step = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform_train = transforms.Compose([
    resize_step,
    transforms.ToTensor(),
    normalize_step,
])

# Dataset built from both directories with the preprocessing attached
train_dataset = ConditionDataset(
    original_train_dir,
    augmented_train_dir,
    transform=transform_train,
)

# Shuffled mini-batches of 32 samples
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

In [None]:
# Create the "models" directory if it doesn't exist
models_dir = "models"
os.makedirs(models_dir, exist_ok=True)

# Build the classifier (3 output classes) and move it to the GPU when one is
# available, otherwise keep it on the CPU.
model = ConditionClassifier(num_classes=3)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10

for epoch in range(num_epochs):
    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by system clock adjustments the way time.time() can.
    start_time = time.perf_counter()

    model.train()
    running_loss = 0.0  # sum of per-sample losses seen this epoch
    samples_seen = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # The criterion returns the batch mean; weight it by the batch size so
        # a smaller final batch does not skew the reported epoch average.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    epoch_time = time.perf_counter() - start_time
    # Report NaN rather than dividing by zero if no sample was processed.
    epoch_loss = running_loss / samples_seen if samples_seen else float('nan')
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Time: {epoch_time:.2f} seconds")

# Save the trained weights (state dict only) in the "models" directory
model_path = os.path.join(models_dir, "condition_classifier.pth")
torch.save(model.state_dict(), model_path)

print(f"Training complete. Model saved at {model_path}!")