In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.ops import RoIAlign, box_iou, clip_boxes_to_image
from torchvision.transforms import functional as TF
from tqdm import tqdm
import random

def collate_fn(batch):
    images, targets = zip(*batch)
    return list(images), list(targets)

# Custom Dataset class
class DummyDataset(Dataset):
    def __init__(self, num_samples=100, image_size=(512, 512), num_classes=2):
        self.num_samples = num_samples
        self.image_size = image_size
        self.num_classes = num_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = torch.rand(3, *self.image_size)
        boxes = torch.tensor([[random.randint(50, 400), random.randint(50, 400),
                               random.randint(100, 500), random.randint(100, 500)]], dtype=torch.float32)
        labels = torch.tensor([random.randint(1, self.num_classes - 1)], dtype=torch.int64)
        target = {"boxes": boxes, "labels": labels}
        return image, target

# FasterRCNN Model
class FasterRCNN(nn.Module):
    def __init__(self, num_classes=2):
        super(FasterRCNN, self).__init__()
        resnet = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        self.rpn_conv = nn.Conv2d(2048, 512, kernel_size=3, padding=1)
        self.rpn_cls = nn.Conv2d(512, 9 * 2, kernel_size=1)
        self.rpn_reg = nn.Conv2d(512, 9 * 4, kernel_size=1)

        self.roi_align = RoIAlign((7, 7), spatial_scale=1.0 / 16, sampling_ratio=2)
        self.fc_class = nn.Linear(2048 * 7 * 7, num_classes)
        self.fc_bbox = nn.Linear(2048 * 7 * 7, num_classes * 4)

    def forward(self, x, targets=None):
        features = self.backbone(x)
        rpn_feat = F.relu(self.rpn_conv(features))
        rpn_cls_logits = self.rpn_cls(rpn_feat)
        rpn_bbox_preds = self.rpn_reg(rpn_feat)

        # Dummy proposals for illustration (replace with actual logic)
        batch_size = x.size(0)
        proposals = torch.rand((batch_size * 100, 4), device=x.device)  # Flattened for ROI Align
        batch_indices = torch.arange(batch_size, device=x.device).repeat_interleave(100)
        proposals = torch.cat([batch_indices.unsqueeze(-1).float(), proposals], dim=1)  # Append batch indices

        roi_features = self.roi_align(features, proposals)
        roi_features = roi_features.view(roi_features.size(0), -1)

        class_logits = self.fc_class(roi_features)
        bbox_preds = self.fc_bbox(roi_features)

        output = {"boxes": bbox_preds, "labels": class_logits}
        return output


# Training loop
def train_one_epoch(model, data_loader, optimizer, device):
    model.train()
    running_loss = 0.0
    for images, targets in data_loader:
        # Move images and targets to the device
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        outputs = model(images, targets)

        # Dummy loss for illustration purposes
        loss = sum(torch.tensor(0.0, device=device) for _ in outputs.values())
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    return running_loss / len(data_loader)


# Evaluation loop
def evaluate(model, data_loader, device):
    model.eval()
    total_iou = 0.0
    num_samples = 0
    
    with torch.no_grad():
        for images, targets in tqdm(data_loader, desc="Validation", leave=False):
            images = torch.stack([img.to(device) for img in images])
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            outputs = model(images)

            for output, target in zip(outputs, targets):
                pred_boxes = output["boxes"].detach()
                target_boxes = target["boxes"]

                if len(pred_boxes) > 0 and len(target_boxes) > 0:
                    iou = box_iou(pred_boxes, target_boxes).mean().item()
                    total_iou += iou
                num_samples += 1

    return total_iou / num_samples if num_samples > 0 else 0.0

# Main training script
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyperparameters
    num_epochs = 5
    learning_rate = 1e-4

    # Dataset and DataLoader
    train_dataset = DummyDataset(num_samples=100)
    val_dataset = DummyDataset(num_samples=20)
    
    def collate_fn(batch):
        """Custom collate function for batching images and targets."""
        images, targets = zip(*batch)
        images = torch.stack(images)  # Stack images into a single tensor
        return images, list(targets)  # Keep targets as a list

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, collate_fn=collate_fn)

    # Model, optimizer, and scheduler
    model = FasterRCNN(num_classes=2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        train_loss = train_one_epoch(model, train_loader, optimizer, device)
        val_iou = evaluate(model, val_loader, device)

        print(f"Train Loss: {train_loss:.4f}, Validation IoU: {val_iou:.4f}")


Epoch 1/5


TypeError: conv2d() received an invalid combination of arguments - got (list, Parameter, NoneType, tuple, tuple, tuple, int), but expected one of:
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, tuple of ints padding = 0, tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!list of [Tensor, Tensor, Tensor, Tensor]!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
 * (Tensor input, Tensor weight, Tensor bias = None, tuple of ints stride = 1, str padding = "valid", tuple of ints dilation = 1, int groups = 1)
      didn't match because some of the arguments have invalid types: (!list of [Tensor, Tensor, Tensor, Tensor]!, !Parameter!, !NoneType!, !tuple of (int, int)!, !tuple of (int, int)!, !tuple of (int, int)!, !int!)
