In [None]:
import torch
from torchvision import models, transforms
from torchvision.datasets import CocoDetection
from torch.utils.data import DataLoader
from torch.optim import Adam
import matplotlib.pyplot as plt
from matplotlib import patches
from PIL import Image
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def count_nan(tensor):
    """Return how many elements of ``tensor`` are NaN, as a Python scalar."""
    nan_mask = torch.isnan(tensor)
    return nan_mask.sum().item()

def collate_fn(batch):
    """Collate ``(image, target)`` samples into padded, stacked batch tensors.

    Each target's ``boxes`` and ``labels`` are zero-padded up to the largest
    box count found in the batch so they can be stacked along a new leading
    dimension. NaN entries are tallied over the padded boxes, the padded
    labels and the stacked images.

    Args:
        batch: Sequence of ``(image, target)`` pairs where ``target`` is a
            dict holding ``'boxes'`` and ``'labels'`` tensors.

    Returns:
        A tuple ``(images, targets, nan_count)``: the stacked images, a dict
        with the stacked ``"boxes"`` and ``"labels"`` tensors, and the NaN
        tally for the batch.
    """
    # Every sample is padded up to the largest box count in this batch.
    widest = max(len(sample[1]['boxes']) for sample in batch)

    image_list, box_list, label_list = [], [], []
    nan_total = 0

    for image, annotation in batch:
        image_list.append(image)

        # Zero rows fill the gap between this sample and the widest one.
        box_fill = torch.zeros(widest - len(annotation['boxes']), 4)
        box_rows = torch.cat([annotation['boxes'], box_fill], dim=0)
        nan_total += count_nan(box_rows)

        # Labels get the same treatment, padded with integer zeros.
        label_fill = torch.zeros(widest - len(annotation['labels']), dtype=torch.int64)
        label_rows = torch.cat([annotation['labels'], label_fill], dim=0)
        nan_total += count_nan(label_rows)

        box_list.append(box_rows)
        label_list.append(label_rows)

    # Add the batch dimension.
    stacked_images = torch.stack(image_list, 0)
    stacked_boxes = torch.stack(box_list, 0)
    stacked_labels = torch.stack(label_list, 0)

    nan_total += count_nan(stacked_images)

    return stacked_images, {"boxes": stacked_boxes, "labels": stacked_labels}, nan_total

class CustomCocoDetection(CocoDetection):
    """COCO detection dataset whose targets are ``{"boxes", "labels"}`` dicts.

    Each annotation's ``bbox`` (``[x, y, width, height]``) is converted to
    ``[x_min, y_min, x_max, y_max]`` and its ``category_id`` becomes the
    label. Annotations with a non-positive width or height are dropped
    together with their label, so ``boxes`` and ``labels`` always have the
    same length.
    """

    def __getitem__(self, idx):
        """Return ``(img, target)`` for sample ``idx``.

        ``target`` is a dict with ``"boxes"`` (float32, shape ``(N, 4)``) and
        ``"labels"`` (int64, shape ``(N,)``); ``N`` is 0 when the sample has
        no usable annotation.
        """
        # NOTE(review): torchvision's CocoDetection.__getitem__ appears to
        # apply self.transforms itself, which would make the call at the end
        # of this method a second application — confirm against the installed
        # torchvision version.
        img, target = super().__getitem__(idx)

        # Print the target to inspect its format
        print(f"Target at index {idx}: {target}")

        # Collect box and label from the same annotation in one pass so that
        # filtering out a degenerate box also removes its label.
        box_rows = []
        label_ids = []
        for obj in target or []:
            x, y, w, h = obj["bbox"]
            if w > 0 and h > 0:
                box_rows.append([x, y, x + w, y + h])
                label_ids.append(obj["category_id"])

        if box_rows:
            boxes = torch.tensor(box_rows, dtype=torch.float32)
            labels = torch.tensor(label_ids, dtype=torch.int64)
        else:
            # Explicit shapes keep the (N, 4) / (N,) layout when N == 0.
            boxes = torch.empty((0, 4), dtype=torch.float32)
            labels = torch.empty((0,), dtype=torch.int64)

        target = {"boxes": boxes, "labels": labels}

        # Apply transformations if any
        if self.transforms:
            img, target = self.transforms(img, target)

        return img, target

# Custom transform
class CustomTransform:
    """Joint ``(img, target)`` transform that runs the image through ``ToTensor``.

    Tensor inputs are validated and converted to a PIL image first so that
    ``transforms.ToTensor()`` is applied the same way to every input. The
    target is passed through unchanged.
    """

    def __call__(self, img, target):
        """Return ``(ToTensor(img), target)``.

        Args:
            img: A PIL image, or a ``(C, H, W)`` tensor with 1 or 3 channels.
            target: Returned as-is.

        Raises:
            ValueError: If ``img`` is a tensor that is not 3-dimensional or
                whose channel count is neither 1 nor 3.
        """
        if isinstance(img, torch.Tensor):
            if img.ndimension() != 3:
                raise ValueError(f"Invalid image dimensions: {img.shape}, expected (C, H, W) format.")

            channels = img.shape[0]
            # Reject every unsupported channel count up front; previously 2-
            # and 4-channel tensors slipped through and failed inside ToTensor.
            if channels not in (1, 3):
                raise ValueError(f"Invalid number of channels: {channels}, expected 1 (grayscale) or 3 (RGB) channels.")

            img = transforms.ToPILImage()(img)

        # Apply ToTensor to the PIL image
        img = transforms.ToTensor()(img)
        return img, target

# Dataset locations: one image directory and one COCO annotation file per split.
train_root = '/Users/lilapetri/Documents/assignment Lensor/vehicle_damage_detection_dataset/images/train'
train_annotations = '/Users/lilapetri/Documents/assignment Lensor/vehicle_damage_detection_dataset/annotations/instances_train.json'

val_root = '/Users/lilapetri/Documents/assignment Lensor/vehicle_damage_detection_dataset/images/val'
val_annotations = '/Users/lilapetri/Documents/assignment Lensor/vehicle_damage_detection_dataset/annotations/instances_val.json'

test_root = '/Users/lilapetri/Documents/assignment Lensor/vehicle_damage_detection_dataset/images/test'
test_annotations = '/Users/lilapetri/Documents/assignment Lensor/vehicle_damage_detection_dataset/annotations/instances_test.json'

# One transform instance is shared by all three splits.
transform = CustomTransform()

# Build the train / val / test datasets in that order.
train_dataset, val_dataset, test_dataset = (
    CustomCocoDetection(root=split_root, annFile=split_annotations, transforms=transform)
    for split_root, split_annotations in (
        (train_root, train_annotations),
        (val_root, val_annotations),
        (test_root, test_annotations),
    )
)

# All loaders share batch size and collate function; only training shuffles.
_loader_options = {"batch_size": 4, "collate_fn": collate_fn}
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_options)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_options)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Faster R-CNN (ResNet-50 FPN backbone) requested with pretrained=True.
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.train()  # Set the model to training mode

optimizer = Adam(model.parameters(), lr=1e-5)

model.to(device)

# Training Loop
def _split_valid_samples(images, targets):
    """Unbatch a collated batch, keeping only images that have valid boxes.

    ``images`` is the stacked image tensor and ``targets`` the dict of
    stacked ``"boxes"`` / ``"labels"`` tensors. Rows whose box does not
    satisfy ``x_max > x_min`` and ``y_max > y_min`` (which includes all-zero
    padding rows) are removed. An image left with no box is dropped together
    with its target, so the two returned lists stay aligned index by index.

    Returns:
        ``(kept_images, kept_targets)`` — two lists of equal length.
    """
    kept_images = []
    kept_targets = []
    for img, boxes, labels in zip(images, targets["boxes"], targets["labels"]):
        target = {"boxes": boxes, "labels": labels}

        # Filter invalid boxes
        valid_indices = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])
        if valid_indices.sum() == 0:  # Skip if no valid boxes
            print(f"No valid boxes in target: {target}")
            continue

        target["boxes"] = boxes[valid_indices]
        target["labels"] = labels[valid_indices]
        kept_images.append(img)
        kept_targets.append(target)
    return kept_images, kept_targets


num_epochs = 1
total_nan_count = 0
for epoch in range(num_epochs):
    running_loss = 0
    trained_batches = 0   # batches that completed an optimizer step
    epoch_nan_count = 0
    for images, targets, nan_count in train_loader:
        # Add NaN count from the current batch
        total_nan_count += nan_count
        epoch_nan_count += nan_count

        if not isinstance(images, torch.Tensor):
            print(f"Unexpected images format: {images}")
            continue

        if not (isinstance(targets, dict) and "boxes" in targets and "labels" in targets):
            print(f"Unexpected targets format: {targets}")
            continue

        # Images and targets are filtered together: dropping only the target
        # would pair the remaining targets with the wrong images.
        images, targets = _split_valid_samples(images, targets)

        # Handle not valid targets
        if len(targets) == 0:
            print("Skipping batch with no valid targets.")
            continue

        # Move the surviving samples to the device
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        optimizer.zero_grad()
        # Best-effort training: a failing batch is reported and skipped
        # rather than aborting the whole run.
        try:
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            losses.backward()
            optimizer.step()
            running_loss += losses.item()
            trained_batches += 1
        except Exception as e:
            print(f"Error during forward pass: {e}")
            continue

    # Average over batches that actually trained; skipped or failed batches
    # must not dilute the mean, and an empty epoch must not divide by zero.
    mean_loss = running_loss / trained_batches if trained_batches else float("nan")
    print(f"Epoch {epoch + 1} completed with loss: {mean_loss}")
    print(f"NaN values detected in epoch {epoch + 1}: {epoch_nan_count}")


In [None]:
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score

# Put the model into evaluation mode before predicting.
model.eval()

# Number of images per forward pass; pick a value that fits in GPU memory.
batch_size = 32

# Running count compared against the cap that stops the evaluation loop.
total_predictions, max_predictions = 0, 500

# Parallel lists of ground-truth and predicted labels fed to the metrics.
y_true, y_pred = [], []

# Helper function to calculate IoU
def calculate_iou(box1, box2):
    """Return the intersection-over-union of two boxes.

    Both boxes are indexable as ``[x_min, y_min, x_max, y_max]``. Boxes that
    do not overlap, or that only touch along an edge, yield ``0.0``.
    """
    # Corners of the overlapping rectangle.
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    if right <= left or bottom <= top:
        return 0.0  # No overlap

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    overlap = (right - left) * (bottom - top)

    # Union = both areas minus the part counted twice.
    return overlap / (_area(box1) + _area(box2) - overlap)

# Iterate over the dataset in batches
# NOTE(review): this evaluates on train_dataset although val_dataset and
# test_dataset are also defined in this file — confirm that is intended.
for i in range(0, len(train_dataset), batch_size):
    if total_predictions >= max_predictions:
        break  # Stop processing if the maximum number of predictions is reached

    batch_images = []
    batch_targets = []

    # Collect the next batch
    for j in range(i, min(i + batch_size, len(train_dataset))):
        sample_image, sample_target = train_dataset[j]
        batch_images.append(sample_image.to(device))
        batch_targets.append(sample_target)

    # The images are passed as a list rather than stacked, so a batch whose
    # images differ in size does not fail before reaching the model.
    with torch.no_grad():
        predictions = model(batch_images)

    # One prediction dict is returned per image, so this counts images.
    total_predictions += len(predictions)

    for idx, prediction in enumerate(predictions):
        # Skip if no predictions or actuals for this image
        if len(prediction['boxes']) == 0 or len(batch_targets[idx]['boxes']) == 0:
            continue

        # Summarise each side as one box (coordinate-wise mean, rounded to
        # integers) and one label (most frequent value).
        avg_pred_box = torch.mean(prediction['boxes'], dim=0).round().int().tolist()
        avg_pred_label = prediction['labels'].mode().values.item()

        avg_actual_box = torch.mean(batch_targets[idx]['boxes'], dim=0).round().int().tolist()
        avg_actual_label = batch_targets[idx]['labels'].mode().values.item()

        # Append to results
        y_pred.append(avg_pred_label)
        y_true.append(avg_actual_label)

        # Calculate IoU (optional, for bounding box evaluation)
        iou = calculate_iou(avg_pred_box, avg_actual_box)
        print(f"Image {i + idx}: IoU={iou:.2f}, Pred Label={avg_pred_label}, Actual Label={avg_actual_label}")

        # Visualization (optional)
        fig, ax = plt.subplots(1)
        ax.imshow(batch_images[idx].permute(1, 2, 0).cpu().numpy())

        # Predicted box in red, actual box in green.
        for box, color, name in (
            (avg_pred_box, 'r', 'Prediction'),
            (avg_actual_box, 'g', 'Actual'),
        ):
            ax.add_patch(
                patches.Rectangle(
                    (box[0], box[1]),
                    box[2] - box[0],
                    box[3] - box[1],
                    linewidth=2,
                    edgecolor=color,
                    facecolor='none',
                    label=name,
                )
            )
        ax.legend()
        plt.show()
        # Release the figure; one is created per image and they would
        # otherwise accumulate for the whole run.
        plt.close(fig)

if y_true:
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
else:
    # No image had both a prediction and a ground-truth box, so there is
    # nothing to score; report NaN instead of calling the metrics on empty lists.
    print("No image had both predictions and ground-truth boxes; metrics are undefined.")
    accuracy = precision = recall = f1 = float("nan")

print(f"Accuracy: {accuracy:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")
print(f"F1 Score: {f1:.2f}")
