In [1]:
import os
import PIL.Image
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

from helpers.custom_classifier import CustomClassifier
from helpers.early_stopping import EarlyStopping

In [2]:
num_epochs = 25

In [3]:
class CrackDataset(Dataset):
    def __init__(self, images_dir, transform: transforms.Compose):
        self.images_dir = images_dir
        self.image_files = [f for f in os.listdir(images_dir) if os.path.isfile(os.path.join(images_dir, f))]
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        image = PIL.Image.open(img_path).convert("RGB")
        label = 0 if "noncrack" in img_name else 1
        image = self.transform(image)

        return image, label

In [4]:
def get_loaders() -> tuple[DataLoader, DataLoader]:
    batch_size = 32
    train_images_dir = os.path.join("data", "train", "images")
    valid_images_dir = os.path.join("data", "valid", "images")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    train_dataset = CrackDataset(train_images_dir, transform=transform)
    valid_dataset = CrackDataset(valid_images_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, valid_loader

In [5]:
def get_loop_objects(
        conv_out_shapes: tuple[int, int],
        linear_layers_features: int
) -> tuple[CustomClassifier, EarlyStopping, torch.nn.BCEWithLogitsLoss, optim.Adam, torch.device]:
    model = CustomClassifier(conv_out_shapes=conv_out_shapes, linear_layers_features=linear_layers_features)
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    early_stopping = EarlyStopping(patience=7, verbose=True, delta=0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model.to(device)
    
    return model, early_stopping, criterion, optimizer, device

In [6]:
def run_training_loop(
        conv_out_shapes: tuple[int, int], 
        linear_layers_features: int,
        checkpoint_path: str
) -> dict:
    train_loader, valid_loader = get_loaders()
    model, early_stopping, criterion, optimizer, device = get_loop_objects(conv_out_shapes, linear_layers_features)
    history = {
        "train_loss": [],
        "valid_loss": []
    }
    
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0
    
        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}/{num_epochs}")
    
            for images, labels in tepoch:
                images, labels = images.to(device), labels.to(device).float()
    
                optimizer.zero_grad()
    
                # squeeze because the outputs are (BATCH_SIZE, 1) shape, and should be of (BATCH_SIZE,) shape
                outputs = model(images).squeeze(1)
                loss = criterion(outputs, labels)
    
                loss.backward()
                optimizer.step()
    
                train_loss += loss.item()
                predicted = (outputs > 0.5).float()
                correct_train += predicted.eq(labels).sum().item()
                total_train += labels.size(0)
    
                tepoch.set_postfix(loss=train_loss/total_train, accuracy=100.*correct_train/total_train)
    
        train_loss /= len(train_loader.dataset)
        history["train_loss"].append(train_loss)
    
        model.eval()
    
        valid_loss = 0.0
        correct_valid = 0
        total_valid = 0
    
        with torch.no_grad():
            for images, labels in valid_loader:
                images, labels = images.to(device), labels.to(device).float()
                outputs = model(images).squeeze(1)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()
                predicted = (outputs > 0.5).float()
                correct_valid += predicted.eq(labels).sum().item()
                total_valid += labels.size(0)
    
        valid_loss /= len(valid_loader.dataset)
        history["valid_loss"].append(valid_loss)
    
        valid_accuracy = 100. * correct_valid / total_valid
    
        print(f"Validation Loss: {valid_loss:.4f}, Validation Accuracy: {valid_accuracy:.2f}%")
        early_stopping(valid_loss, model, path=checkpoint_path)
    
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break
            
    return history

In [7]:
model_results = []
param_combinations = [
    {
        "conv_out_shapes": (64, 128),
        "linear_layers_features": 256,
    },
    {
        "conv_out_shapes": (128, 256),
        "linear_layers_features": 512,
    },
    {
        "conv_out_shapes": (64, 128),
        "linear_layers_features": 512,
    },
    {
        "conv_out_shapes": (64, 128),
        "linear_layers_features": 1024,
    },
    {
        "conv_out_shapes": (128, 256),
        "linear_layers_features": 1024,
    },
]

for param_combination in param_combinations:
    model_custom_path = "_".join(map(str, param_combination["conv_out_shapes"]))
    model_custom_path = f"{param_combination['linear_layers_features']}_{model_custom_path}"
    checkpoint_path = os.path.join("checkpoints", f"custom_classifier_{model_custom_path}.pt")
    history = run_training_loop(
        param_combination["conv_out_shapes"], 
        param_combination["linear_layers_features"], 
        checkpoint_path
    )
    
    model_results.append((checkpoint_path, history))
    torch.cuda.empty_cache()

Epoch 1/25: 100%|██████████| 301/301 [01:29<00:00,  3.38batch/s, accuracy=91.9, loss=0.00678]


Validation Loss: 0.0063, Validation Accuracy: 93.33%
Validation loss decreased (inf --> 0.006259).  Saving model ...


Epoch 2/25: 100%|██████████| 301/301 [01:18<00:00,  3.83batch/s, accuracy=93.5, loss=0.00572]


Validation Loss: 0.0058, Validation Accuracy: 91.09%
Validation loss decreased (0.006259 --> 0.005835).  Saving model ...


Epoch 3/25: 100%|██████████| 301/301 [01:18<00:00,  3.83batch/s, accuracy=93.6, loss=0.00532]


Validation Loss: 0.0065, Validation Accuracy: 93.22%
EarlyStopping counter: 1 out of 7


Epoch 4/25: 100%|██████████| 301/301 [01:20<00:00,  3.72batch/s, accuracy=94.6, loss=0.00466]


Validation Loss: 0.0047, Validation Accuracy: 94.04%
Validation loss decreased (0.005835 --> 0.004718).  Saving model ...


Epoch 5/25: 100%|██████████| 301/301 [01:25<00:00,  3.53batch/s, accuracy=95.3, loss=0.00414]


Validation Loss: 0.0044, Validation Accuracy: 94.87%
Validation loss decreased (0.004718 --> 0.004383).  Saving model ...


Epoch 6/25: 100%|██████████| 301/301 [01:25<00:00,  3.50batch/s, accuracy=96.3, loss=0.00341]


Validation Loss: 0.0041, Validation Accuracy: 95.28%
Validation loss decreased (0.004383 --> 0.004088).  Saving model ...


Epoch 7/25: 100%|██████████| 301/301 [01:24<00:00,  3.55batch/s, accuracy=97, loss=0.00282]  


Validation Loss: 0.0042, Validation Accuracy: 95.52%
EarlyStopping counter: 1 out of 7


Epoch 8/25: 100%|██████████| 301/301 [01:23<00:00,  3.59batch/s, accuracy=97.7, loss=0.00224]


Validation Loss: 0.0041, Validation Accuracy: 96.34%
Validation loss decreased (0.004088 --> 0.004053).  Saving model ...


Epoch 9/25: 100%|██████████| 301/301 [01:23<00:00,  3.59batch/s, accuracy=98.1, loss=0.00187]


Validation Loss: 0.0040, Validation Accuracy: 96.28%
Validation loss decreased (0.004053 --> 0.003962).  Saving model ...


Epoch 10/25: 100%|██████████| 301/301 [01:24<00:00,  3.54batch/s, accuracy=98.5, loss=0.0015] 


Validation Loss: 0.0038, Validation Accuracy: 96.46%
Validation loss decreased (0.003962 --> 0.003803).  Saving model ...


Epoch 11/25: 100%|██████████| 301/301 [01:24<00:00,  3.56batch/s, accuracy=98.6, loss=0.00147]


Validation Loss: 0.0043, Validation Accuracy: 96.81%
EarlyStopping counter: 1 out of 7


Epoch 12/25: 100%|██████████| 301/301 [01:24<00:00,  3.58batch/s, accuracy=99.1, loss=0.000886]


Validation Loss: 0.0041, Validation Accuracy: 96.87%
EarlyStopping counter: 2 out of 7


Epoch 13/25: 100%|██████████| 301/301 [01:24<00:00,  3.57batch/s, accuracy=99.5, loss=0.000648]


Validation Loss: 0.0042, Validation Accuracy: 96.99%
EarlyStopping counter: 3 out of 7


Epoch 14/25: 100%|██████████| 301/301 [01:20<00:00,  3.75batch/s, accuracy=99.4, loss=0.000679]


Validation Loss: 0.0039, Validation Accuracy: 97.11%
EarlyStopping counter: 4 out of 7


Epoch 15/25: 100%|██████████| 301/301 [01:21<00:00,  3.69batch/s, accuracy=99.3, loss=0.000737]


Validation Loss: 0.0045, Validation Accuracy: 96.52%
EarlyStopping counter: 5 out of 7


Epoch 16/25: 100%|██████████| 301/301 [01:26<00:00,  3.46batch/s, accuracy=99.5, loss=0.000664]


Validation Loss: 0.0046, Validation Accuracy: 97.11%
EarlyStopping counter: 6 out of 7


Epoch 17/25: 100%|██████████| 301/301 [01:27<00:00,  3.44batch/s, accuracy=99.7, loss=0.000392]


Validation Loss: 0.0045, Validation Accuracy: 96.28%
EarlyStopping counter: 7 out of 7
Early stopping triggered


Epoch 1/25: 100%|██████████| 301/301 [02:22<00:00,  2.11batch/s, accuracy=91.5, loss=0.0075] 


Validation Loss: 0.0060, Validation Accuracy: 93.27%
Validation loss decreased (inf --> 0.005964).  Saving model ...


Epoch 2/25: 100%|██████████| 301/301 [02:14<00:00,  2.24batch/s, accuracy=93.3, loss=0.00575]


Validation Loss: 0.0059, Validation Accuracy: 93.98%
Validation loss decreased (0.005964 --> 0.005927).  Saving model ...


Epoch 3/25: 100%|██████████| 301/301 [02:05<00:00,  2.40batch/s, accuracy=94.7, loss=0.00494]


Validation Loss: 0.0065, Validation Accuracy: 91.56%
EarlyStopping counter: 1 out of 7


Epoch 4/25: 100%|██████████| 301/301 [02:05<00:00,  2.40batch/s, accuracy=95.3, loss=0.00417]


Validation Loss: 0.0042, Validation Accuracy: 95.58%
Validation loss decreased (0.005927 --> 0.004226).  Saving model ...


Epoch 5/25: 100%|██████████| 301/301 [02:05<00:00,  2.40batch/s, accuracy=96, loss=0.00362]  


Validation Loss: 0.0040, Validation Accuracy: 95.99%
Validation loss decreased (0.004226 --> 0.004028).  Saving model ...


Epoch 6/25: 100%|██████████| 301/301 [02:06<00:00,  2.39batch/s, accuracy=97.2, loss=0.0026] 


Validation Loss: 0.0042, Validation Accuracy: 95.04%
EarlyStopping counter: 1 out of 7


Epoch 7/25: 100%|██████████| 301/301 [02:08<00:00,  2.34batch/s, accuracy=98, loss=0.00195]  


Validation Loss: 0.0041, Validation Accuracy: 96.34%
EarlyStopping counter: 2 out of 7


Epoch 8/25: 100%|██████████| 301/301 [02:24<00:00,  2.09batch/s, accuracy=98.3, loss=0.0016] 


Validation Loss: 0.0055, Validation Accuracy: 96.11%
EarlyStopping counter: 3 out of 7


Epoch 9/25: 100%|██████████| 301/301 [02:05<00:00,  2.39batch/s, accuracy=98.7, loss=0.00139]


Validation Loss: 0.0048, Validation Accuracy: 96.46%
EarlyStopping counter: 4 out of 7


Epoch 10/25: 100%|██████████| 301/301 [02:04<00:00,  2.41batch/s, accuracy=99.1, loss=0.00101] 


Validation Loss: 0.0039, Validation Accuracy: 96.58%
Validation loss decreased (0.004028 --> 0.003859).  Saving model ...


Epoch 11/25: 100%|██████████| 301/301 [02:05<00:00,  2.40batch/s, accuracy=99, loss=0.00111]   


Validation Loss: 0.0044, Validation Accuracy: 95.81%
EarlyStopping counter: 1 out of 7


Epoch 12/25: 100%|██████████| 301/301 [02:04<00:00,  2.41batch/s, accuracy=99.4, loss=0.000698]


Validation Loss: 0.0047, Validation Accuracy: 96.87%
EarlyStopping counter: 2 out of 7


Epoch 13/25: 100%|██████████| 301/301 [02:04<00:00,  2.41batch/s, accuracy=99.1, loss=0.000982]


Validation Loss: 0.0040, Validation Accuracy: 96.46%
EarlyStopping counter: 3 out of 7


Epoch 14/25: 100%|██████████| 301/301 [02:05<00:00,  2.40batch/s, accuracy=99.6, loss=0.000517]


Validation Loss: 0.0069, Validation Accuracy: 96.64%
EarlyStopping counter: 4 out of 7


Epoch 15/25: 100%|██████████| 301/301 [02:04<00:00,  2.41batch/s, accuracy=99.7, loss=0.000285]


Validation Loss: 0.0067, Validation Accuracy: 96.64%
EarlyStopping counter: 5 out of 7


Epoch 16/25: 100%|██████████| 301/301 [02:05<00:00,  2.39batch/s, accuracy=99.4, loss=0.000608]


Validation Loss: 0.0050, Validation Accuracy: 95.46%
EarlyStopping counter: 6 out of 7


Epoch 17/25: 100%|██████████| 301/301 [02:05<00:00,  2.40batch/s, accuracy=99.2, loss=0.00085] 


Validation Loss: 0.0051, Validation Accuracy: 96.34%
EarlyStopping counter: 7 out of 7
Early stopping triggered


Epoch 1/25: 100%|██████████| 301/301 [01:23<00:00,  3.61batch/s, accuracy=91.4, loss=0.00743]


Validation Loss: 0.0062, Validation Accuracy: 93.75%
Validation loss decreased (inf --> 0.006166).  Saving model ...


Epoch 2/25: 100%|██████████| 301/301 [01:23<00:00,  3.61batch/s, accuracy=93.3, loss=0.00578]


Validation Loss: 0.0055, Validation Accuracy: 93.81%
Validation loss decreased (0.006166 --> 0.005550).  Saving model ...


Epoch 3/25: 100%|██████████| 301/301 [01:23<00:00,  3.61batch/s, accuracy=94.1, loss=0.00509]


Validation Loss: 0.0055, Validation Accuracy: 92.04%
Validation loss decreased (0.005550 --> 0.005533).  Saving model ...


Epoch 4/25: 100%|██████████| 301/301 [01:23<00:00,  3.62batch/s, accuracy=94.6, loss=0.00451]


Validation Loss: 0.0050, Validation Accuracy: 94.40%
Validation loss decreased (0.005533 --> 0.004974).  Saving model ...


Epoch 5/25: 100%|██████████| 301/301 [01:23<00:00,  3.62batch/s, accuracy=95.7, loss=0.0037] 


Validation Loss: 0.0042, Validation Accuracy: 94.81%
Validation loss decreased (0.004974 --> 0.004241).  Saving model ...


Epoch 6/25: 100%|██████████| 301/301 [01:23<00:00,  3.62batch/s, accuracy=96.7, loss=0.00295]


Validation Loss: 0.0046, Validation Accuracy: 95.69%
EarlyStopping counter: 1 out of 7


Epoch 7/25: 100%|██████████| 301/301 [01:25<00:00,  3.54batch/s, accuracy=97.4, loss=0.00249]


Validation Loss: 0.0039, Validation Accuracy: 96.52%
Validation loss decreased (0.004241 --> 0.003930).  Saving model ...


Epoch 8/25: 100%|██████████| 301/301 [01:24<00:00,  3.57batch/s, accuracy=97.8, loss=0.00203]


Validation Loss: 0.0040, Validation Accuracy: 95.99%
EarlyStopping counter: 1 out of 7


Epoch 9/25: 100%|██████████| 301/301 [01:23<00:00,  3.61batch/s, accuracy=97.8, loss=0.00201]


Validation Loss: 0.0066, Validation Accuracy: 89.85%
EarlyStopping counter: 2 out of 7


Epoch 10/25: 100%|██████████| 301/301 [01:23<00:00,  3.60batch/s, accuracy=98.6, loss=0.00146]


Validation Loss: 0.0053, Validation Accuracy: 95.28%
EarlyStopping counter: 3 out of 7


Epoch 11/25: 100%|██████████| 301/301 [01:25<00:00,  3.52batch/s, accuracy=98.9, loss=0.001]  


Validation Loss: 0.0056, Validation Accuracy: 96.58%
EarlyStopping counter: 4 out of 7


Epoch 12/25: 100%|██████████| 301/301 [01:27<00:00,  3.45batch/s, accuracy=99.1, loss=0.000857]


Validation Loss: 0.0074, Validation Accuracy: 96.11%
EarlyStopping counter: 5 out of 7


Epoch 13/25: 100%|██████████| 301/301 [01:26<00:00,  3.47batch/s, accuracy=99.3, loss=0.000734]


Validation Loss: 0.0053, Validation Accuracy: 96.28%
EarlyStopping counter: 6 out of 7


Epoch 14/25: 100%|██████████| 301/301 [01:25<00:00,  3.54batch/s, accuracy=99.2, loss=0.000751]


Validation Loss: 0.0055, Validation Accuracy: 96.34%
EarlyStopping counter: 7 out of 7
Early stopping triggered


Epoch 1/25: 100%|██████████| 301/301 [01:44<00:00,  2.89batch/s, accuracy=91.8, loss=0.00709]


Validation Loss: 0.0059, Validation Accuracy: 93.69%
Validation loss decreased (inf --> 0.005924).  Saving model ...


Epoch 2/25: 100%|██████████| 301/301 [01:44<00:00,  2.88batch/s, accuracy=93.3, loss=0.0057] 


Validation Loss: 0.0059, Validation Accuracy: 93.98%
Validation loss decreased (0.005924 --> 0.005892).  Saving model ...


Epoch 3/25: 100%|██████████| 301/301 [01:44<00:00,  2.89batch/s, accuracy=94.2, loss=0.00503]


Validation Loss: 0.0052, Validation Accuracy: 94.57%
Validation loss decreased (0.005892 --> 0.005170).  Saving model ...


Epoch 4/25: 100%|██████████| 301/301 [01:40<00:00,  2.98batch/s, accuracy=95.2, loss=0.00429]


Validation Loss: 0.0048, Validation Accuracy: 94.34%
Validation loss decreased (0.005170 --> 0.004798).  Saving model ...


Epoch 5/25: 100%|██████████| 301/301 [01:43<00:00,  2.91batch/s, accuracy=96.2, loss=0.00351]


Validation Loss: 0.0038, Validation Accuracy: 95.63%
Validation loss decreased (0.004798 --> 0.003797).  Saving model ...


Epoch 6/25: 100%|██████████| 301/301 [01:43<00:00,  2.92batch/s, accuracy=97.2, loss=0.00254]


Validation Loss: 0.0041, Validation Accuracy: 95.22%
EarlyStopping counter: 1 out of 7


Epoch 7/25: 100%|██████████| 301/301 [01:38<00:00,  3.05batch/s, accuracy=97.9, loss=0.00211]


Validation Loss: 0.0053, Validation Accuracy: 94.87%
EarlyStopping counter: 2 out of 7


Epoch 8/25: 100%|██████████| 301/301 [01:39<00:00,  3.04batch/s, accuracy=97.3, loss=0.00251]


Validation Loss: 0.0049, Validation Accuracy: 94.22%
EarlyStopping counter: 3 out of 7


Epoch 9/25: 100%|██████████| 301/301 [01:39<00:00,  3.04batch/s, accuracy=98.4, loss=0.00141]


Validation Loss: 0.0056, Validation Accuracy: 94.10%
EarlyStopping counter: 4 out of 7


Epoch 10/25: 100%|██████████| 301/301 [01:42<00:00,  2.93batch/s, accuracy=98.9, loss=0.000983]


Validation Loss: 0.0050, Validation Accuracy: 96.40%
EarlyStopping counter: 5 out of 7


Epoch 11/25: 100%|██████████| 301/301 [01:41<00:00,  2.97batch/s, accuracy=99, loss=0.000955]  


Validation Loss: 0.0053, Validation Accuracy: 96.17%
EarlyStopping counter: 6 out of 7


Epoch 12/25: 100%|██████████| 301/301 [01:43<00:00,  2.91batch/s, accuracy=99.3, loss=0.000648]


Validation Loss: 0.0056, Validation Accuracy: 95.10%
EarlyStopping counter: 7 out of 7
Early stopping triggered


Epoch 1/25:   0%|          | 0/301 [00:01<?, ?batch/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.06 GiB. GPU 0 has a total capacity of 12.00 GiB of which 627.00 MiB is free. Of the allocated memory 9.23 GiB is allocated by PyTorch, and 54.79 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
for key, (_, history) in model_results.items():
    plt.figure(figsize=(10, 5))
    plt.plot(history["train_loss"], label="Training Loss")
    plt.plot(history["valid_loss"], label="Validation Loss")
    plt.title("Training vs Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()