In [1]:
import os
import PIL.Image
import torch
import torch.optim as optim

from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

from helpers.custom_classifier import CustomClassifier
from helpers.early_stopping import EarlyStopping

In [2]:
class CrackDataset(Dataset):
    def __init__(self, images_dir, transform: transforms.Compose):
        self.images_dir = images_dir
        self.image_files = [f for f in os.listdir(images_dir) if os.path.isfile(os.path.join(images_dir, f))]
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int]:
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)
        image = PIL.Image.open(img_path).convert("RGB")
        label = 0 if "noncrack" in img_name else 1
        image = self.transform(image)

        return image, label

In [3]:
batch_size = 32
train_images_dir = os.path.join("data", "train", "images")
valid_images_dir = os.path.join("data", "valid", "images")
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
train_dataset = CrackDataset(train_images_dir, transform=transform)
valid_dataset = CrackDataset(valid_images_dir, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [4]:
model = CustomClassifier()
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
early_stopping = EarlyStopping(patience=7, verbose=True, delta=0)
num_epochs = 25
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

CustomClassifier(
  (feature_extractor): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=401408, out_features=512, bias=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=512, out_features=512, bias=True)
    (4): ReLU(inplace=True)
    (5): Linear(in_features=512, out_features=1, bias=True)
  )
)

In [6]:
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct_train = 0
    total_train = 0

    with tqdm(train_loader, unit="batch") as tepoch:
        tepoch.set_description(f"Epoch {epoch+1}/{num_epochs}")

        for images, labels in tepoch:
            images, labels = images.to(device), labels.to(device).float()

            optimizer.zero_grad()

            # squeeze because the outputs are (BATCH_SIZE, 1) shape, and should be of (BATCH_SIZE,) shape
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            predicted = (outputs > 0.5).float()  # Binary prediction threshold
            correct_train += predicted.eq(labels).sum().item()
            total_train += labels.size(0)

            tepoch.set_postfix(loss=train_loss/total_train, accuracy=100.*correct_train/total_train)

    model.eval()

    valid_loss = 0.0
    correct_valid = 0
    total_valid = 0

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device).float()
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()
            predicted = (outputs > 0.5).float()
            correct_valid += predicted.eq(labels).sum().item()
            total_valid += labels.size(0)

    valid_loss /= total_valid
    valid_accuracy = 100. * correct_valid / total_valid

    print(f"Validation Loss: {valid_loss:.4f}, Validation Accuracy: {valid_accuracy:.2f}%")
    early_stopping(valid_loss, model, path=os.path.join("checkpoints", "custom_classifier.pt"))

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

Epoch 1/25: 100%|██████████| 301/301 [01:20<00:00,  3.72batch/s, accuracy=91.5, loss=0.00722]


Validation Loss: 0.0071, Validation Accuracy: 92.21%
Validation loss decreased (inf --> 0.007093).  Saving model ...


Epoch 2/25: 100%|██████████| 301/301 [01:20<00:00,  3.75batch/s, accuracy=93.2, loss=0.00598]


Validation Loss: 0.0056, Validation Accuracy: 93.75%
Validation loss decreased (0.007093 --> 0.005646).  Saving model ...


Epoch 3/25: 100%|██████████| 301/301 [01:20<00:00,  3.73batch/s, accuracy=94, loss=0.00514]  


Validation Loss: 0.0053, Validation Accuracy: 93.04%
Validation loss decreased (0.005646 --> 0.005278).  Saving model ...


Epoch 4/25: 100%|██████████| 301/301 [01:21<00:00,  3.70batch/s, accuracy=94.8, loss=0.00465]


Validation Loss: 0.0059, Validation Accuracy: 92.39%
EarlyStopping counter: 1 out of 7


Epoch 5/25: 100%|██████████| 301/301 [01:20<00:00,  3.75batch/s, accuracy=95.6, loss=0.00391]


Validation Loss: 0.0046, Validation Accuracy: 93.57%
Validation loss decreased (0.005278 --> 0.004634).  Saving model ...


Epoch 6/25: 100%|██████████| 301/301 [01:20<00:00,  3.73batch/s, accuracy=96.2, loss=0.0035] 


Validation Loss: 0.0046, Validation Accuracy: 95.52%
Validation loss decreased (0.004634 --> 0.004588).  Saving model ...


Epoch 7/25: 100%|██████████| 301/301 [01:20<00:00,  3.72batch/s, accuracy=97.3, loss=0.00254]


Validation Loss: 0.0040, Validation Accuracy: 96.46%
Validation loss decreased (0.004588 --> 0.004012).  Saving model ...


Epoch 8/25: 100%|██████████| 301/301 [01:20<00:00,  3.72batch/s, accuracy=97.6, loss=0.00214]


Validation Loss: 0.0041, Validation Accuracy: 96.40%
EarlyStopping counter: 1 out of 7


Epoch 9/25: 100%|██████████| 301/301 [01:20<00:00,  3.72batch/s, accuracy=98.2, loss=0.0018] 


Validation Loss: 0.0035, Validation Accuracy: 96.87%
Validation loss decreased (0.004012 --> 0.003485).  Saving model ...


Epoch 10/25: 100%|██████████| 301/301 [01:20<00:00,  3.73batch/s, accuracy=98.6, loss=0.00141]


Validation Loss: 0.0040, Validation Accuracy: 96.05%
EarlyStopping counter: 1 out of 7


Epoch 11/25: 100%|██████████| 301/301 [01:20<00:00,  3.75batch/s, accuracy=98.8, loss=0.00113] 


Validation Loss: 0.0041, Validation Accuracy: 96.70%
EarlyStopping counter: 2 out of 7


Epoch 12/25: 100%|██████████| 301/301 [01:20<00:00,  3.75batch/s, accuracy=99, loss=0.000994]  


Validation Loss: 0.0049, Validation Accuracy: 96.17%
EarlyStopping counter: 3 out of 7


Epoch 13/25: 100%|██████████| 301/301 [01:20<00:00,  3.73batch/s, accuracy=98.8, loss=0.00122]


Validation Loss: 0.0046, Validation Accuracy: 96.22%
EarlyStopping counter: 4 out of 7


Epoch 14/25: 100%|██████████| 301/301 [01:20<00:00,  3.74batch/s, accuracy=99.3, loss=0.000639]


Validation Loss: 0.0043, Validation Accuracy: 97.05%
EarlyStopping counter: 5 out of 7


Epoch 15/25: 100%|██████████| 301/301 [01:20<00:00,  3.73batch/s, accuracy=99.7, loss=0.000376]


Validation Loss: 0.0093, Validation Accuracy: 96.22%
EarlyStopping counter: 6 out of 7


Epoch 16/25: 100%|██████████| 301/301 [01:20<00:00,  3.72batch/s, accuracy=99.3, loss=0.000649]


Validation Loss: 0.0043, Validation Accuracy: 96.34%
EarlyStopping counter: 7 out of 7
Early stopping triggered
