In [1]:
import os
import torch

from tqdm import tqdm
from torch import optim
from torchvision import transforms
from torch.utils.data import DataLoader

from helpers.classifier_with_pretrained_features import Resnext50BasedClassifier
from helpers.datasets import CrackDatasetForClassification
from helpers.early_stopping import EarlyStopping

In [3]:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize
])

In [4]:
def get_loaders() -> tuple[DataLoader, DataLoader]:
    batch_size = 32
    train_images_dir = os.path.join("data", "train", "images")
    valid_images_dir = os.path.join("data", "valid", "images")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    train_dataset = CrackDatasetForClassification(train_images_dir, transform=transform)
    valid_dataset = CrackDatasetForClassification(valid_images_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, valid_loader

In [5]:
def get_loop_objects() -> tuple[Resnext50BasedClassifier, EarlyStopping, torch.nn.BCEWithLogitsLoss, optim.Adam, torch.device]:
    model = Resnext50BasedClassifier()
    criterion = torch.nn.BCEWithLogitsLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    early_stopping = EarlyStopping(patience=7, verbose=True, delta=0)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)

    return model, early_stopping, criterion, optimizer, device

In [6]:
num_epochs = 25


def validate(
        model: Resnext50BasedClassifier,
        valid_loader: DataLoader,
        criterion: torch.nn.BCEWithLogitsLoss,
        history: dict[str, list[float]],
        device: torch.device
) -> tuple[float, float]:
    model.eval()

    valid_loss = 0.0
    correct_valid = 0
    total_valid = 0

    with torch.no_grad():
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device).float()
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            valid_loss += loss.item()
            predicted = (outputs > 0.5).float()
            correct_valid += predicted.eq(labels).sum().item()
            total_valid += labels.size(0)

    valid_loss /= len(valid_loader.dataset)
    history["valid_loss"].append(valid_loss)

    return 100. * correct_valid / total_valid, valid_loss


def run_training_loop() -> tuple[dict, float]:
    checkpoint_path = os.path.join("checkpoints", f"resnext50_32x4d_classifier.pt")
    train_loader, valid_loader = get_loaders()
    model, early_stopping, criterion, optimizer, device = get_loop_objects()
    history = {
        "train_loss": [],
        "valid_loss": []
    }
    valid_accuracy = 0

    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        correct_train = 0
        total_train = 0

        with tqdm(train_loader, unit="batch") as tepoch:
            tepoch.set_description(f"Epoch {epoch+1}/{num_epochs}")

            for images, labels in tepoch:
                images, labels = images.to(device), labels.to(device).float()

                optimizer.zero_grad()

                # squeeze because the outputs are (BATCH_SIZE, 1) shape, and should be of (BATCH_SIZE,) shape
                outputs = model(images).squeeze(1)
                loss = criterion(outputs, labels)

                loss.backward()
                optimizer.step()

                train_loss += loss.item()
                predicted = (outputs > 0.5).float()
                correct_train += predicted.eq(labels).sum().item()
                total_train += labels.size(0)

                tepoch.set_postfix(loss=train_loss/total_train, accuracy=100.*correct_train/total_train)

        valid_accuracy, valid_loss = validate(model, valid_loader, criterion, history, device)
        train_loss /= len(train_loader.dataset)
        history["train_loss"].append(train_loss)

        print(f"Validation Loss: {valid_loss:.4f}, Validation Accuracy: {valid_accuracy:.2f}%")
        early_stopping(valid_loss, model, checkpoint_path)

        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    return history, valid_accuracy

In [7]:
history, valid_accuracy = run_training_loop()

print(f"Validation Accuracy: {valid_accuracy:.2f}%")

Epoch 1/25: 100%|██████████| 301/301 [02:01<00:00,  2.47batch/s, accuracy=98.2, loss=0.0015] 


Validation Loss: 0.0001, Validation Accuracy: 99.82%
Validation loss decreased (inf --> 0.000126).  Saving model ...


Epoch 2/25: 100%|██████████| 301/301 [01:59<00:00,  2.52batch/s, accuracy=99.4, loss=0.000521]


Validation Loss: 0.0002, Validation Accuracy: 99.82%
EarlyStopping counter: 1 out of 7


Epoch 3/25: 100%|██████████| 301/301 [01:59<00:00,  2.51batch/s, accuracy=99.6, loss=0.000453]


Validation Loss: 0.0001, Validation Accuracy: 99.88%
Validation loss decreased (0.000126 --> 0.000113).  Saving model ...


Epoch 4/25: 100%|██████████| 301/301 [01:59<00:00,  2.53batch/s, accuracy=99.7, loss=0.000279]


Validation Loss: 0.0004, Validation Accuracy: 99.59%
EarlyStopping counter: 1 out of 7


Epoch 5/25: 100%|██████████| 301/301 [01:58<00:00,  2.54batch/s, accuracy=99.7, loss=0.000349]


Validation Loss: 0.0001, Validation Accuracy: 99.82%
EarlyStopping counter: 2 out of 7


Epoch 6/25: 100%|██████████| 301/301 [01:59<00:00,  2.52batch/s, accuracy=99.7, loss=0.000343]


Validation Loss: 0.0003, Validation Accuracy: 99.65%
EarlyStopping counter: 3 out of 7


Epoch 7/25: 100%|██████████| 301/301 [01:58<00:00,  2.54batch/s, accuracy=99.9, loss=0.000102]


Validation Loss: 0.0001, Validation Accuracy: 99.94%
Validation loss decreased (0.000113 --> 0.000093).  Saving model ...


Epoch 8/25: 100%|██████████| 301/301 [01:59<00:00,  2.52batch/s, accuracy=99.8, loss=0.000214]


Validation Loss: 0.0005, Validation Accuracy: 99.47%
EarlyStopping counter: 1 out of 7


Epoch 9/25: 100%|██████████| 301/301 [01:59<00:00,  2.51batch/s, accuracy=99.8, loss=0.000235]


Validation Loss: 0.0004, Validation Accuracy: 99.59%
EarlyStopping counter: 2 out of 7


Epoch 10/25: 100%|██████████| 301/301 [01:59<00:00,  2.52batch/s, accuracy=99.8, loss=0.000288]


Validation Loss: 0.0004, Validation Accuracy: 99.65%
EarlyStopping counter: 3 out of 7


Epoch 11/25: 100%|██████████| 301/301 [01:59<00:00,  2.52batch/s, accuracy=99.9, loss=0.000203]


Validation Loss: 0.0003, Validation Accuracy: 99.82%
EarlyStopping counter: 4 out of 7


Epoch 12/25: 100%|██████████| 301/301 [01:57<00:00,  2.55batch/s, accuracy=99.8, loss=0.000123]


Validation Loss: 0.0004, Validation Accuracy: 99.47%
EarlyStopping counter: 5 out of 7


Epoch 13/25: 100%|██████████| 301/301 [01:59<00:00,  2.53batch/s, accuracy=99.9, loss=0.00013] 


Validation Loss: 0.0002, Validation Accuracy: 99.76%
EarlyStopping counter: 6 out of 7


Epoch 14/25: 100%|██████████| 301/301 [01:58<00:00,  2.53batch/s, accuracy=99.8, loss=0.000219]


Validation Loss: 0.0005, Validation Accuracy: 99.71%
EarlyStopping counter: 7 out of 7
Early stopping triggered
Validation Accuracy: 99.71%
