In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from dataset.flame import FlameFOV
from dataset.flame import FlameThermal
from dataset.flame import FlameRGB
from dataset.flame import FlameSatelite
from torchvision.transforms import transforms
from torchvision.utils import draw_bounding_boxes
from torchvision.ops import box_convert
import torchvision.ops as ops
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from helper.image_processing import SquarePadTransform
from helper.utils import collate_fn
from helper.modelling import generate_proposals

In [2]:
# Per-sample preprocessing: resize every image to 256x256, then convert it to a tensor.
# Square padding and normalisation are currently disabled; they are kept as
# comments so they can be switched back on in place.
compose = transforms.Compose([
    # SquarePadTransform(),
    transforms.Resize(size=(256, 256)),
    transforms.ToTensor(),
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

In [None]:
# Build the dataset, split it 80/20 into train/test, and wrap both parts in loaders.
SPLIT_SEED = 42   # fixes train/test membership across kernel restarts
BATCH_SIZE = 32
NUM_WORKERS = 4

dataset = FlameSatelite(download=True, transform=compose)

train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size  # remainder, so the two sizes always sum to len(dataset)

# An unseeded random_split reshuffles on every run, which makes results
# non-comparable between runs; a dedicated seeded generator keeps the split stable
# without touching the global RNG state.
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
assert len(train_dataset) + len(test_dataset) == len(dataset)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle the training samples every epoch
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,  # Use custom collate function
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # keep evaluation order deterministic
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,  # Use custom collate function
)

In [4]:
# Sanity check (disabled): print the shapes of one batch from the DataLoader
# for images, bboxes in train_loader:
#     print("Image batch shape:", images.shape)
#     print("Bounding box batch shape:", bboxes.shape)
#     break

In [5]:
# Pull a single batch from the training loader and unpack it into `image` and `bbox`;
# later cells index into these to visualise one sample.
# NOTE(review): the exact structure of both values is whatever `collate_fn` returns — confirm.
image, bbox = next(iter(train_loader))

In [6]:
# Saved figures use a tight bounding box.
plt.rcParams["savefig.bbox"] = "tight"


def show(imgs):
    """Draw one image tensor, or a list of them, side by side in a single row.

    Anything that is not a ``list`` is treated as a single image. Each image is
    detached, converted to a PIL image and drawn on its own axis with ticks and
    tick labels removed.
    """
    images = imgs if isinstance(imgs, list) else [imgs]
    # squeeze=False keeps the axes array 2-D even for a single image.
    _, axes = plt.subplots(ncols=len(images), squeeze=False)
    for axis, tensor in zip(axes[0], images):
        pil_image = F.to_pil_image(tensor.detach())
        axis.imshow(np.asarray(pil_image))
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [None]:
# Visualise the boxes of one sample from the batch fetched earlier.
# The preferred index is clamped to the batch length so that a batch with fewer
# than SAMPLE_IDX + 1 items does not raise IndexError.
SAMPLE_IDX = 11
sample_idx = min(SAMPLE_IDX, len(image) - 1)

# Boxes are interpreted as (cx, cy, w, h) and converted to corner format for drawing.
# NOTE(review): assumes the dataset yields cxcywh boxes in the pixel units of the
# resized image — confirm against the dataset implementation.
bbox_xyxy = box_convert(bbox[sample_idx], in_fmt="cxcywh", out_fmt="xyxy")

result = draw_bounding_boxes(image[sample_idx], bbox_xyxy, colors="blue", width=5)
show(result)

In [8]:

class CustomBackbone(nn.Module):
    """Small convolutional feature extractor made of three conv -> ReLU -> max-pool stages.

    Channel widths go 3 -> 16 -> 32 -> 64 and each stage halves the spatial
    resolution, so a [B, 3, H, W] input yields a [B, 64, H // 8, W // 8] feature map.
    """

    def __init__(self):
        super().__init__()
        widths = (3, 16, 32, 64)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                # 3x3 conv with padding 1 keeps H and W unchanged
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.ReLU(),
                # 2x2 pooling with stride 2 halves H and W
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        self.conv_layers = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the conv stack and return the resulting feature map."""
        feature_map = self.conv_layers(x)
        return feature_map
    
class RPN(nn.Module):
    """Region-proposal head: a shared 3x3 conv followed by two sibling 1x1 convs.

    Args:
        in_channels: Number of channels of the incoming feature map.
        num_anchors: Anchors per spatial location; sets the width of both heads.

    ``forward`` returns ``(cls_logits, reg_logits)`` with shapes
    ``[B, num_anchors, H, W]`` and ``[B, num_anchors * 4, H, W]``.
    """

    def __init__(self, in_channels, num_anchors):
        super(RPN, self).__init__()
        self.conv = nn.Conv2d(in_channels, 256, kernel_size=3, stride=1, padding=1)
        self.cls_layer = nn.Conv2d(256, num_anchors, kernel_size=1)  # Object vs Background
        self.reg_layer = nn.Conv2d(256, num_anchors * 4, kernel_size=1)  # Bounding box regression

    def forward(self, x):
        # A non-linearity is required between the shared conv and the 1x1 heads:
        # without it the two stacked convolutions collapse into a single linear
        # map and the 256-channel hidden layer adds no capacity. ReLU has no
        # parameters, so state_dict keys are unaffected.
        x = torch.relu(self.conv(x))
        cls_logits = self.cls_layer(x)  # Shape: [B, num_anchors, H, W]
        reg_logits = self.reg_layer(x)  # Shape: [B, num_anchors * 4, H, W]
        return cls_logits, reg_logits

class ROIPooling(nn.Module):
    """Thin module wrapper around ``torchvision.ops.roi_pool``.

    Args:
        output_size: Spatial size of each pooled region, forwarded to
            ``roi_pool``. Previously this argument was accepted but ignored and
            ``(7, 7)`` was always used.
        spatial_scale: Factor that maps box coordinates onto the feature map,
            forwarded to ``roi_pool``. Defaults to 1.0 (the ``roi_pool`` default),
            i.e. boxes are taken to be in feature-map coordinates.
            NOTE(review): if proposals are in input-image coordinates this must be
            the backbone's down-sampling factor inverse — confirm against
            ``generate_proposals``.
    """

    def __init__(self, output_size, spatial_scale=1.0):
        super(ROIPooling, self).__init__()
        self.output_size = output_size
        self.spatial_scale = spatial_scale
        self.roi_pool = ops.roi_pool

    def forward(self, features, proposals):
        # proposals: List of [x_min, y_min, x_max, y_max] for each image
        return self.roi_pool(
            features,
            proposals,
            output_size=self.output_size,
            spatial_scale=self.spatial_scale,
        )
    
class DetectionHead(nn.Module):
    def __init__(self, in_features, num_classes=2):  # 2 for background and fire
        super(DetectionHead, self).__init__()
        self.flatten = nn.Flatten() 
        self.fc = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
        )
        self.cls_layer = nn.Linear(1024, num_classes)  # Class probabilities (background vs fire)
        self.reg_layer = nn.Linear(1024, num_classes * 4)  # Bounding box offsets

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc(x)
        class_logits = self.cls_layer(x)
        bbox_offsets = self.reg_layer(x)
        return class_logits, bbox_offsets


class DetectionLoss(nn.Module):
    """Sum of a cross-entropy classification loss and a smooth-L1 box loss.

    The box term compares only the first four columns of ``reg_preds`` with
    ``reg_targets``; when ``reg_targets`` has no elements the box term is zero.
    """

    def __init__(self):
        super().__init__()
        self.cls_loss = nn.CrossEntropyLoss()
        self.reg_loss = nn.SmoothL1Loss()

    def forward(self, cls_preds, reg_preds, cls_targets, reg_targets):
        """Return ``classification_loss + regression_loss`` as a scalar tensor."""
        classification = self.cls_loss(cls_preds, cls_targets)

        # No regression targets: contribute an explicit zero on the same device.
        if reg_targets.numel() == 0:
            return classification + torch.tensor(0.0, device=cls_preds.device)

        # Single-class regression: only the first box slot is supervised.
        first_box = reg_preds[:, :4]
        return classification + self.reg_loss(first_box, reg_targets)

In [None]:

# Initialize components
backbone = CustomBackbone()
rpn = RPN(in_channels=64, num_anchors=9)  # Match with output of backbone
roi_pool = ROIPooling(output_size=(7, 7))
head = DetectionHead(in_features=64 * 7 * 7)  # 64 backbone channels x 7 x 7 pooled grid
loss_fn = DetectionLoss()

# Optimizer over every trainable module (roi_pool has no parameters)
optimizer = torch.optim.Adam(
    [*backbone.parameters(), *rpn.parameters(), *head.parameters()],
    lr=1e-4,
)

# Explicitly select training mode so re-running this cell after an evaluation
# cell does not silently train in eval mode.
backbone.train()
rpn.train()
head.train()

# Training
num_epochs = 10
for epoch in range(num_epochs):
    epoch_loss_sum = 0.0
    num_batches = 0
    for idx, (images, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        # Forward pass
        features = backbone(images)
        rpn_logits, rpn_bboxes = rpn(features)

        # Generate proposals (e.g., NMS on RPN output)
        proposals = generate_proposals(rpn_logits, rpn_bboxes)
        # ROI pooling
        pooled_features = roi_pool(features, proposals)
        # Detection head
        cls_preds, reg_preds = head(pooled_features)

        # NOTE(review): placeholder supervision — `targets` from the loader is not
        # used; every ROI is labelled as class 1 (fire) with an all-ones box target.
        # Real targets still need to be matched to the proposals.
        num_rois = cls_preds.shape[0]
        labels = torch.ones(num_rois, dtype=torch.long, device=images.device)
        # Regression targets are real-valued, so create them in the prediction's
        # floating dtype instead of relying on integer-to-float promotion in the loss.
        reg_targets = torch.ones((num_rois, 4), dtype=reg_preds.dtype, device=images.device)

        # Compute loss
        loss = loss_fn(cls_preds, reg_preds, labels, reg_targets)

        # Backpropagation
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss_sum += batch_loss
        num_batches += 1
        print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{idx+1}/{len(train_loader)}], Loss: {batch_loss:.4f}")

    # Report the mean over the epoch rather than whatever the last batch produced;
    # max(..., 1) also keeps an empty loader from raising here.
    mean_loss = epoch_loss_sum / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Mean Loss: {mean_loss:.4f}")
