In [12]:
# put this file to the archive file


import os
import json
import torch
from torch.utils.data import DataLoader, Dataset, Subset
import torchvision.transforms as T
from PIL import Image
from pycocotools.coco import COCO
import numpy as np
from tqdm import tqdm
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import random

# Define the SeaTurtleDataset class
class SeaTurtleDataset(Dataset):
    def __init__(self, img_dir, ann_file, transforms=None):
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.image_ids = list(self.coco.imgs.keys())
        self.transforms = transforms

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        anns = self.coco.loadAnns(ann_ids)

        # Load the image
        img_info = self.coco.loadImgs([img_id])[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        image = Image.open(img_path).convert("RGB")
        
        # Load masks and other annotations
        masks = []
        boxes = []
        labels = []
        for ann in anns:
            mask = self.coco.annToMask(ann)
            masks.append(mask)
            xmin, ymin, width, height = ann['bbox']
            xmax = xmin + width
            ymax = ymin + height
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(ann['category_id'])

        # Convert to tensors
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        area = torch.as_tensor([ann['area'] for ann in anns], dtype=torch.float32)
        iscrowd = torch.as_tensor([ann.get('iscrowd', 0) for ann in anns], dtype=torch.int64)
        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([img_id]),
            "area": area,
            "iscrowd": iscrowd
        }

        if self.transforms:
            # Apply the transformations to image and target
            image, target = self.transforms(image, target)

        return image, target

# Define the collate_fn function
def collate_fn(batch):
    return tuple(zip(*batch))

# Define the function to calculate IOU
def calculate_iou(pred_mask, true_mask):
    pred_mask = pred_mask.cpu().numpy().astype(bool)
    true_mask = true_mask.cpu().numpy().astype(bool)

    intersection = np.logical_and(pred_mask, true_mask)
    union = np.logical_or(pred_mask, true_mask)
    iou_score = np.sum(intersection) / np.sum(union) if np.sum(union) > 0 else 0.0
    return iou_score

# Define the visualization function
def visualize_prediction(image, predicted_masks, true_masks, epoch, iou_score=None):
    fig, ax = plt.subplots(1, 3, figsize=(18, 6))

    # Convert image for visualization
    image = F.to_pil_image(image.cpu())

    # Display the original image
    ax[0].imshow(image)
    ax[0].set_title("Original Image")
    ax[0].axis("off")

    # Display the predicted masks
    predicted_mask = predicted_masks.sum(dim=0).cpu().numpy() > 0
    ax[1].imshow(image)
    ax[1].imshow(predicted_mask, alpha=0.5, cmap='jet')
    title_pred = f"Predicted Masks - Epoch {epoch+1}"
    if iou_score is not None:
        title_pred += f"\nIOU: {iou_score:.4f}"
    ax[1].set_title(title_pred)
    ax[1].axis("off")

    # Display the true masks
    true_mask = true_masks.sum(dim=0).cpu().numpy() > 0
    ax[2].imshow(image)
    ax[2].imshow(true_mask, alpha=0.5, cmap='jet')
    ax[2].set_title("True Masks")
    ax[2].axis("off")

    plt.show()

# Define transformations with data augmentation
class ComposeTransform:
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        for t in self.transforms:
            image, target = t(image, target)
        return image, target

class RandomHorizontalFlip:
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            image = F.hflip(image)
            width = image.width
            boxes = target["boxes"]
            boxes[:, [0, 2]] = width - boxes[:, [2, 0]]
            target["boxes"] = boxes
            if "masks" in target:
                target["masks"] = target["masks"].flip(-1)
        return image, target

class RandomVerticalFlip:
    def __init__(self, prob=0.5):
        self.prob = prob

    def __call__(self, image, target):
        if random.random() < self.prob:
            image = F.vflip(image)
            height = image.height
            boxes = target["boxes"]
            boxes[:, [1, 3]] = height - boxes[:, [3, 1]]
            target["boxes"] = boxes
            if "masks" in target:
                target["masks"] = target["masks"].flip(-2)
        return image, target

class ColorJitterTransform:
    def __init__(self, brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1):
        self.color_jitter = T.ColorJitter(brightness, contrast, saturation, hue)

    def __call__(self, image, target):
        image = self.color_jitter(image)
        return image, target

class ToTensor:
    def __call__(self, image, target):
        image = F.to_tensor(image)
        return image, target

# Function to get data loaders
def get_data_loaders(img_dir, ann_file, metadata_file, batch_size=2):
    # Load metadata
    metadata = pd.read_csv(metadata_file)
    
    # Define data transformations with data augmentation for training
    train_transforms = ComposeTransform([
        ColorJitterTransform(),
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        ToTensor(),
    ])

    # For validation and testing, we use only ToTensor
    val_test_transforms = ComposeTransform([
        ToTensor(),
    ])

    full_dataset = SeaTurtleDataset(img_dir, ann_file, transforms=None)
    
    # Create mappings between image IDs and filenames
    img_to_filename = {img_id: full_dataset.coco.loadImgs(img_id)[0]["file_name"] for img_id in full_dataset.image_ids}
    file_to_img = {v: k for k, v in img_to_filename.items()}
    
    # Split the dataset based on the 'split_open' column in metadata
    train_img_ids = [file_to_img[filename] for filename in metadata[metadata["split_open"] == "train"]["file_name"] if filename in file_to_img]
    val_img_ids = [file_to_img[filename] for filename in metadata[metadata["split_open"] == "valid"]["file_name"] if filename in file_to_img]
    test_img_ids = [file_to_img[filename] for filename in metadata[metadata["split_open"] == "test"]["file_name"] if filename in file_to_img]
    
    # Verify the size of each split
    train_size, val_size, test_size = len(train_img_ids), len(val_img_ids), len(test_img_ids)
    print(f"----------------- Dataset Split -----------------\n")
    print(f"Training set size: {train_size}")
    print(f"Validation set size: {val_size}")
    print(f"Test set size: {test_size}\n")
    
    # Create subsets for each split
    train_indices = [full_dataset.image_ids.index(img_id) for img_id in train_img_ids]
    val_indices = [full_dataset.image_ids.index(img_id) for img_id in val_img_ids]
    test_indices = [full_dataset.image_ids.index(img_id) for img_id in test_img_ids]
    
    # Assign transforms to datasets
    train_dataset = Subset(SeaTurtleDataset(img_dir, ann_file, transforms=train_transforms), train_indices)
    val_dataset = Subset(SeaTurtleDataset(img_dir, ann_file, transforms=val_test_transforms), val_indices)
    test_dataset = Subset(SeaTurtleDataset(img_dir, ann_file, transforms=val_test_transforms), test_indices)
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers=0)
    
    print(f"----------------- Data Loaders -----------------\n")
    print(f"Training loader: {len(train_loader)} batches")
    print(f"Validation loader: {len(val_loader)} batches")
    print(f"Test loader: {len(test_loader)} batches\n")
    
    return train_loader, val_loader, test_loader

# Function to train the model and save the best model
def train_model(train_loader, val_loader, model, optimizer, device, num_epochs=5):
    model.to(device)
    best_iou = 0.0  # Initialize the best IOU
    for epoch in range(num_epochs):
        model.train()
        print(f"Starting epoch {epoch+1}/{num_epochs}")
        
        epoch_loss = 0.0
        for images, targets in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Forward pass and loss computation
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            epoch_loss += losses.item()

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

        avg_loss = epoch_loss / len(train_loader)
        print(f"Epoch {epoch+1} average training loss: {avg_loss}")

        # Validation step: calculate average IOU on validation set
        model.eval()
        val_iou_scores = []
        with torch.no_grad():
            for val_images, val_targets in tqdm(val_loader, desc=f"Validating Epoch {epoch+1}/{num_epochs}"):
                val_images = [img.to(device) for img in val_images]
                val_targets = [{k: v.to(device) for k, v in t.items()} for t in val_targets]
                outputs = model(val_images)

                for i, (val_image, output) in enumerate(zip(val_images, outputs)):
                    predicted_masks = output['masks'] > 0.5
                    true_masks = val_targets[i]['masks']

                    if predicted_masks.shape[0] > 0 and true_masks.shape[0] > 0:
                        pred_mask_combined = predicted_masks.sum(dim=0)
                        true_mask_combined = true_masks.sum(dim=0)
                        iou_score = calculate_iou(pred_mask_combined, true_mask_combined)
                        val_iou_scores.append(iou_score)
                    else:
                        val_iou_scores.append(0.0)

        average_val_iou = sum(val_iou_scores) / len(val_iou_scores)
        print(f"Epoch {epoch+1} average validation IOU: {average_val_iou:.4f}")

        # Check if this is the best model
        if average_val_iou > best_iou:
            best_iou = average_val_iou
            best_model_path = "best_maskrcnn_model.pth"
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved with IOU: {best_iou:.4f}")

# Main execution
if __name__ == "__main__":
    # Set paths and parameters
    img_dir = "./turtles-data/data"
    ann_file = "./turtles-data/data/updated_annotations.json"
    metadata_file = "./turtles-data/data/metadata_splits.csv"
    batch_size = 2

    # Get data loaders
    train_loader, val_loader, test_loader = get_data_loaders(img_dir, ann_file, metadata_file, batch_size=batch_size)

    # Load the pre-trained Mask R-CNN model
    import torchvision
    from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
    from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

    weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1
    model = maskrcnn_resnet50_fpn(weights=weights)

    num_classes = 4  # 3 classes + background

    # Replace the classifier head
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    # Replace the mask predictor
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

    # Set device and optimizer
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

    # Start training
    train_model(train_loader, val_loader, model, optimizer, device, num_epochs=5)

    # Load the best model and perform testing
    model.load_state_dict(torch.load('best_maskrcnn_model.pth'))
    model.to(device)
    model.eval()

    # Perform inference on the test set and calculate IOU
    iou_scores = []
    with torch.no_grad():
        for test_images, test_targets in tqdm(test_loader, desc="Testing"):
            test_images = [img.to(device) for img in test_images]
            test_targets = [{k: v.to(device) for k, v in t.items()} for t in test_targets]
            outputs = model(test_images)
            
            for i, (test_image, output) in enumerate(zip(test_images, outputs)):
                predicted_masks = output['masks'] > 0.5
                true_masks = test_targets[i]['masks']
                
                if predicted_masks.shape[0] > 0 and true_masks.shape[0] > 0:
                    pred_mask_combined = predicted_masks.sum(dim=0)
                    true_mask_combined = true_masks.sum(dim=0)
                    iou_score = calculate_iou(pred_mask_combined, true_mask_combined)
                    iou_scores.append(iou_score)
                else:
                    iou_scores.append(0.0)

                # Visualize prediction results (optional)
                # visualize_prediction(test_image, predicted_masks, true_masks, epoch=5, iou_score=iou_score)

    # Calculate average IOU
    average_iou = sum(iou_scores) / len(iou_scores)
    print(f"\nAverage IOU on Test Set: {average_iou:.4f}")


loading annotations into memory...
Done (t=1.75s)
creating index...
index created!
----------------- Dataset Split -----------------

Training set size: 5293
Validation set size: 1117
Test set size: 2299

loading annotations into memory...
Done (t=1.64s)
creating index...
index created!
loading annotations into memory...
Done (t=7.34s)
creating index...
index created!
loading annotations into memory...
Done (t=1.51s)
creating index...
index created!
----------------- Data Loaders -----------------

Training loader: 2647 batches
Validation loader: 559 batches
Test loader: 1150 batches

Starting epoch 1/5


Training Epoch 1/5:   0%|                   | 11/2647 [02:09<8:37:55, 11.79s/it]


KeyboardInterrupt: 