In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim

from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader

from helpers.datasets import CrackDataset
from helpers.early_stopping import EarlyStopping
from helpers.selective_search import generate_proposals, get_label_for_proposal, get_best_bbox_for_proposal

In [2]:
class BackboneNetwork(nn.Module):
    """Small two-stage convolutional feature extractor.

    Each stage is a 3x3 convolution (stride 1, padding 1), an in-place ReLU
    and a 2x2 max-pool with stride 2. Channels go 3 -> 64 -> 128, and the
    final feature map is flattened to one vector per sample.
    """

    def __init__(self):
        super().__init__()

        # Stages are built in order so the Sequential indices (and therefore
        # the state-dict keys) stay 0..5.
        stages = []
        for channels_in, channels_out in ((3, 64), (64, 128)):
            stages.extend([
                nn.Conv2d(channels_in, channels_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ])

        self.features = nn.Sequential(*stages)

    def forward(self, x):
        """Return the flattened features of ``x``, one row per sample."""
        feature_map = self.features(x)
        batch_size = feature_map.size(0)

        return feature_map.view(batch_size, -1)


# Single backbone instance; it is passed to RCNN below as the feature extractor.
backbone_model = BackboneNetwork()

In [3]:
class RCNN(nn.Module):
    """Classification and box-regression heads on top of a feature extractor.

    The width of the extractor's output is measured once, at construction
    time, by pushing a zero image of size ``in_shape`` through it. Two
    three-layer MLP heads are then built on that width: ``classifier``
    (``num_classes`` outputs) and ``regressor`` (4 outputs).

    Args:
        feature_extractor: Module mapping a ``(N, 3, H, W)`` batch to features.
            Its output may be flat ``(N, F)`` or a feature map; feature maps
            are flattened per sample before reaching the heads.
        in_shape: ``(height, width)`` of the inputs the model will receive.
        num_classes: Number of outputs of the classification head.
        hidden_dim: Width of the two hidden layers in each head.
    """

    def __init__(
        self,
        feature_extractor: nn.Module,
        in_shape: tuple[int, int],
        num_classes: int = 10,
        hidden_dim: int = 4096,
    ):
        super().__init__()

        self.feature_extractor = feature_extractor

        # Build the probe on the extractor's own device/dtype so the size
        # measurement also works for an extractor that was already moved
        # (e.g. to CUDA or half precision). The heads are still created with
        # PyTorch defaults; move the whole module afterwards if needed.
        probe_options = {}
        reference_param = next(feature_extractor.parameters(), None)
        if reference_param is not None:
            probe_options["device"] = reference_param.device
            if reference_param.is_floating_point():
                probe_options["dtype"] = reference_param.dtype

        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, in_shape[0], in_shape[1], **probe_options)
            # Batch size is 1, so the element count is the per-sample width.
            feature_size = self.feature_extractor(dummy_input).numel()

        # Classifier is built before the regressor to keep parameter
        # initialisation order (and thus seeded results) unchanged.
        self.classifier = self._build_head(feature_size, hidden_dim, num_classes)
        self.regressor = self._build_head(feature_size, hidden_dim, 4)

    @staticmethod
    def _build_head(in_features: int, hidden_dim: int, out_features: int) -> nn.Sequential:
        """Return a Linear-ReLU-Linear-ReLU-Linear head."""
        return nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_features)
        )

    def forward(self, x):
        """Return ``(class_logits, bbox_preds)`` for the batch ``x``."""
        features = self.feature_extractor(x)
        # __init__ sizes the heads from the flattened probe output, so apply
        # the same flattening here; already-flat features pass through as-is.
        if features.dim() > 2:
            features = torch.flatten(features, start_dim=1)

        class_logits = self.classifier(features)
        bbox_preds = self.regressor(features)

        return class_logits, bbox_preds


# Wrap the backbone; (224, 224) is the input size RCNN uses to size its heads and
# is the same size the proposal transform resizes to. num_classes keeps its default.
rcnn_model = RCNN(backbone_model, (224, 224))

In [4]:
def list_image_paths(directory: str) -> list[str]:
    """Return the path of every entry in ``directory``, sorted by name.

    ``os.listdir`` yields entries in arbitrary, platform-dependent order, so
    the names are sorted to keep dataset indexing reproducible across runs
    and machines. No filtering is applied: every directory entry is returned.

    Args:
        directory: Directory to list.

    Returns:
        ``directory`` joined with each entry name, in ascending name order.
    """
    return [os.path.join(directory, name) for name in sorted(os.listdir(directory))]


# Preprocessing handed to generate_proposals: convert to a PIL image, resize to
# 224x224 (the input size the RCNN heads were sized for) and convert to a tensor.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# One loss per head: cross-entropy for the class logits, smooth L1 for the boxes.
criterion_class = nn.CrossEntropyLoss()
criterion_bbox = nn.SmoothL1Loss()
optimizer = optim.Adam(rcnn_model.parameters(), lr=0.001)

train_images_dir = os.path.join("data", "train", "images")
valid_images_dir = os.path.join("data", "valid", "images")

train_dataset = CrackDataset(
    list_image_paths(train_images_dir),
    os.path.join("data", "train", "labels.pkl"),
    os.path.join("data", "train", "stats.pkl")
)
valid_dataset = CrackDataset(
    list_image_paths(valid_images_dir),
    os.path.join("data", "valid", "labels.pkl"),
    os.path.join("data", "valid", "stats.pkl")
)

# batch_size=1 because the training loop squeezes the batch dimension and
# processes one image (and its proposals) at a time.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
# Previously this name was bound to the DataLoader class itself, so iterating it
# or taking len() failed. Validation mirrors the training loader but is not
# shuffled, so the validation loss is computed over a fixed order.
validation_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=4)

early_stopping = EarlyStopping(patience=7, verbose=True, delta=0)
num_epochs = 30

In [5]:
def _align_bbox_target(bbox_preds, best_stat):
    """Give the box target the same shape as the prediction when that is lossless.

    SmoothL1Loss broadcasts a target whose shape differs from the input and
    warns that this likely gives incorrect results (the recorded output of
    this cell appears to end with that warning). When both tensors hold the
    same number of elements, the target is reshaped to the prediction's
    shape; otherwise it is returned untouched.
    """
    if best_stat.shape != bbox_preds.shape and best_stat.numel() == bbox_preds.numel():
        return best_stat.reshape(bbox_preds.shape)
    return best_stat


# Each epoch: one optimisation step per region proposal of every training
# image, then a no-grad pass over the validation set that feeds early stopping.
for epoch in range(num_epochs):
    rcnn_model.train()

    train_progress_bar = tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch + 1}/{num_epochs}")

    for i, data in train_progress_bar:
        image_path, image, labels, stats = data
        # Drop the batch dimension (batch_size is 1) and hand a NumPy array
        # to the proposal generator.
        image = image.squeeze(0).numpy()
        proposals = generate_proposals(image, transform)
        batch_loss = 0.0

        for proposal, proposal_box in proposals:
            proposal = proposal.unsqueeze(0)

            optimizer.zero_grad()

            class_logits, bbox_preds = rcnn_model(proposal)
            aggregated_label = get_label_for_proposal(labels, torch.tensor(proposal_box, dtype=torch.long))
            loss_class = criterion_class(class_logits, aggregated_label.unsqueeze(0))
            best_stat = _align_bbox_target(bbox_preds, get_best_bbox_for_proposal(stats, bbox_preds))
            loss_bbox = criterion_bbox(bbox_preds, best_stat)
            loss = loss_class + loss_bbox

            loss.backward()
            optimizer.step()

            batch_loss += loss.item()

        # An image can yield no proposals; skip the update instead of
        # dividing by zero.
        if len(proposals) > 0:
            train_progress_bar.set_postfix(loss=batch_loss / len(proposals))

    rcnn_model.eval()

    val_loss = 0.0

    with torch.no_grad():
        for i, data in enumerate(validation_dataloader):
            # Same dataset class as training, so each item has the same four
            # fields; unpacking only three raised a ValueError.
            image_path, image, labels, stats = data
            image = image.squeeze(0).numpy()
            proposals = generate_proposals(image, transform)

            for proposal, proposal_box in proposals:
                proposal = proposal.unsqueeze(0)
                class_logits, bbox_preds = rcnn_model(proposal)
                aggregated_label = get_label_for_proposal(labels, torch.tensor(proposal_box, dtype=torch.long))
                best_stat = _align_bbox_target(bbox_preds, get_best_bbox_for_proposal(stats, bbox_preds))
                loss_class = criterion_class(class_logits, aggregated_label.unsqueeze(0))
                loss_bbox = criterion_bbox(bbox_preds, best_stat)
                val_loss += (loss_class + loss_bbox).item()

    # Summed proposal losses averaged per validation image; max() guards an
    # empty validation set.
    val_loss /= max(len(validation_dataloader), 1)

    print(f'Validation Loss: {val_loss:.6f}')
    early_stopping(val_loss, rcnn_model, 'checkpoint.pth')

    if early_stopping.early_stop:
        print("Early stopping triggered. Stopping training.")
        break

  return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)
Epoch 1/30:   0%|          | 0/8241 [05:01<?, ?it/s]


KeyboardInterrupt: 