In [13]:
import os
import json
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as T
from PIL import Image
from pycocotools.coco import COCO
import numpy as np
import random
from tqdm import tqdm
from torchvision.transforms import functional as F
import matplotlib.pyplot as plt
import pandas as pd

# Custom Dataset class for SeaTurtleID2022
class SeaTurtleDataset(Dataset):
    def __init__(self, img_dir, ann_file, transforms=None):
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.image_ids = list(self.coco.imgs.keys())
        self.transforms = transforms

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        anns = self.coco.loadAnns(ann_ids)

        # Load image
        img_info = self.coco.loadImgs([img_id])[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        image = Image.open(img_path).convert("RGB")
        
        # Load masks
        masks = []
        boxes = []
        labels = []
        for ann in anns:
            mask = self.coco.annToMask(ann)
            masks.append(mask)
            # Convert bbox format from [xmin, ymin, width, height] to [xmin, ymin, xmax, ymax]
            xmin, ymin, width, height = ann['bbox']
            xmax = xmin + width
            ymax = ymin + height
            boxes.append([xmin, ymin, xmax, ymax])  # Converted format
            labels.append(ann['category_id'])

        # Convert data to tensors
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        # Compute area and handle target dict
        area = torch.as_tensor([ann['area'] for ann in anns], dtype=torch.float32)
        iscrowd = torch.as_tensor([ann.get('iscrowd', 0) for ann in anns], dtype=torch.int64)  # Ensure 'iscrowd' key exists
        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([img_id]),
            "area": area,
            "iscrowd": iscrowd
        }

        if self.transforms:
            image = self.transforms(image)

        return image, target

# Split dataset into train and val sets using metadata_splits.csv
def split_dataset(metadata_file, train_size=1000, val_size=200):
    splits = pd.read_csv(metadata_file)
    
    # Use the 'split_open' column for splitting
    train_ids = splits[splits['split_open'] == 'train']['id'].tolist()
    val_ids = splits[splits['split_open'] == 'val']['id'].tolist()

    # Randomly sample specified number of training and validation samples
    train_ids = random.sample(train_ids, min(train_size, len(train_ids)))
    val_ids = random.sample(val_ids, min(val_size, len(val_ids)))

    return train_ids, val_ids

# Create data loaders
def collate_fn(batch):
    return tuple(zip(*batch))

def get_data_loaders(img_dir, ann_file, metadata_file, batch_size=2):
    train_ids, val_ids = split_dataset(metadata_file, train_size=1000, val_size=200)

    # Create datasets
    transforms = T.Compose([
        T.ToTensor(),  # Converts image to tensor and normalizes to [0, 1]
    ])

    train_dataset = SeaTurtleDataset(img_dir, ann_file, transforms=transforms)
    val_dataset = SeaTurtleDataset(img_dir, ann_file, transforms=transforms)

    # Filter datasets based on sampled IDs
    train_dataset.image_ids = [img_id for img_id in train_dataset.image_ids if img_id in train_ids]
    val_dataset.image_ids = [img_id for img_id in val_dataset.image_ids if img_id in val_ids]

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

    return train_loader, val_loader

# Example usage
img_dir = "./turtles-data/data"
ann_file = "./turtles-data/data/updated_annotations.json"
metadata_file = "./turtles-data/data/metadata_splits.csv"
batch_size = 1

train_loader, val_loader = get_data_loaders(img_dir, ann_file, metadata_file, batch_size=batch_size)

# Load the Mask R-CNN model
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn, MaskRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

# Load pre-trained model with recommended weights
weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1
model = maskrcnn_resnet50_fpn(weights=weights)

num_classes = 4  # 3 classes (turtle, flipper, head) + background

# Get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features

# Replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Get number of input features for the mask classifier
in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
hidden_layer = 256

# Replace the mask predictor with a new one
model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, num_classes)

# Training loop example
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

def visualize_prediction(image, predicted_masks, true_masks, epoch):
    """
    Visualizes predictions and ground truth masks on a single image.
    """
    fig, ax = plt.subplots(1, 3, figsize=(18, 6))

    # Convert image tensor to PIL image for visualization
    image = F.to_pil_image(image.cpu())

    # Plot input image
    ax[0].imshow(image)
    ax[0].set_title("Input Image")
    ax[0].axis("off")

    # Plot predicted mask
    predicted_mask = predicted_masks.sum(dim=0).cpu().numpy() > 0  # Aggregate masks for visualization
    ax[1].imshow(image)
    ax[1].imshow(predicted_mask, alpha=0.5, cmap='jet')
    ax[1].set_title(f"Predicted Masks - Epoch {epoch+1}")
    ax[1].axis("off")

    # Plot true mask
    true_mask = true_masks.sum(dim=0).cpu().numpy() > 0
    ax[2].imshow(image)
    ax[2].imshow(true_mask, alpha=0.5, cmap='jet')
    ax[2].set_title("True Masks")
    ax[2].axis("off")

    plt.show()

def train_model(train_loader, val_loader, model, optimizer, device, num_epochs=5):
    model.to(device)
    for epoch in range(num_epochs):
        model.train()
        print(f"Starting epoch {epoch+1}/{num_epochs}")
        
        for images, targets in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # Forward pass and loss calculation
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

        print(f"Epoch {epoch+1} completed with loss: {losses.item()}")

        # Validation step: Show predictions for a few images
        model.eval()  # Set the model to evaluation mode
        with torch.no_grad():
            for val_images, val_targets in val_loader:
                val_images = [img.to(device) for img in val_images]
                val_targets = [{k: v.to(device) for k, v in t.items()} for t in val_targets]
                outputs = model(val_images)  # Get model predictions

                # Visualize the first validation image's predictions vs ground truth
                for i, (val_image, output) in enumerate(zip(val_images, outputs)):
                    predicted_masks = output['masks'] > 0.5  # Threshold to get binary masks
                    true_masks = val_targets[i]['masks']
                    visualize_prediction(val_image, predicted_masks, true_masks, epoch)

                    break  # Only visualize one sample per epoch
                break  # Only show one batch per epoch

# Call the training function
train_model(train_loader, val_loader, model, optimizer, device, num_epochs=5)


loading annotations into memory...
Done (t=3.30s)
creating index...
index created!
loading annotations into memory...
Done (t=1.67s)
creating index...
index created!
Starting epoch 1/5


Training Epoch 1/5: 100%|███████████████████| 999/999 [1:53:55<00:00,  6.84s/it]


Epoch 1 completed with loss: 0.5177559852600098
Starting epoch 2/5


Training Epoch 2/5: 100%|███████████████████| 999/999 [2:52:08<00:00, 10.34s/it]


Epoch 2 completed with loss: 0.33852246403694153
Starting epoch 3/5


Training Epoch 3/5: 100%|███████████████████| 999/999 [1:39:40<00:00,  5.99s/it]


Epoch 3 completed with loss: 0.20086316764354706
Starting epoch 4/5


Training Epoch 4/5: 100%|███████████████████| 999/999 [1:57:04<00:00,  7.03s/it]


Epoch 4 completed with loss: 0.280353307723999
Starting epoch 5/5


Training Epoch 5/5: 100%|███████████████████| 999/999 [5:09:19<00:00, 18.58s/it]

Epoch 5 completed with loss: 0.3652145266532898





In [5]:
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import os
from pycocotools.coco import COCO
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import models
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import train_test_split
import cv2 as cv
import numpy as np

resize_transform = A.Compose([
    A.Resize(512, 512),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])

annotation_file = "./turtles-data/data/updated_annotations.json"
coco = COCO(annotation_file)

class SeaTurtleDataset( Dataset ):
    def __init__ (self, image_ids, transform = None):
        self.coco = COCO('./turtles-data/data/updated_annotations.json')
        self.image_ids = image_ids
        self.cat_ids = self.coco.getCatIds()
        self.transform = transform

    def __len__( self ):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        image_data = self.coco.loadImgs([image_id])[0]

        image_path = os.path.join('./turtles-data/data', image_data['file_name'])
        image = cv.imread(image_path)
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        
        mask = self._getmask(self.image_ids[index], image)



        if self.transform is  not  None : 
            augmentations = self.transform(image=image, mask =mask) 
            image = augmentations[ 'image' ] 
            mask = augmentations[ 'mask' ] 

        return image, mask 
    
    def _getmask(self, image_id, image):
        '''
        Background (0):   000000000000000
        Turtle Body (1):  111111100000000
        Flippers (2):     112222200000000
        Head (3):         112222233300000

        Final Mask:       112222233300000
        '''
        categories = {
            'turtle': 1,
            'flipper': 2,
            'head': 3
        }
        # Initialize the final mask with zeros
        mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        # Process each category
        for category_name, category_id in categories.items():
            ann_ids = self.coco.getAnnIds(imgIds=image_id, catIds=category_id, iscrowd=None)
            annotations = self.coco.loadAnns(ann_ids)
            
            # Create a temporary mask for the current category
            temp_mask = np.zeros_like(mask)
            
            for ann in annotations:
                temp_mask += self.coco.annToMask(ann)
            # Assign category-specific value to the final mask
            if category_name == 'turtle':
                mask[temp_mask > 0] = 1
            elif category_name == 'flipper':
                mask[temp_mask > 0] = 2
            elif category_name == 'head':
                mask[temp_mask > 0] = 3

        return mask
    
train_ids, test_ids = train_test_split(coco.getImgIds()[:4000], test_size=0.1, random_state=42)
train_ids, val_ids = train_test_split(train_ids, test_size=0.2, random_state=42)

train_dataset = SeaTurtleDataset(train_ids, transform=resize_transform)
val_dataset = SeaTurtleDataset(val_ids, transform=resize_transform)
test_dataset = SeaTurtleDataset(test_ids, transform=resize_transform)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0,pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0,pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0, pin_memory=True)




def get_deeplab_model(num_classes):
    model = models.segmentation.deeplabv3_resnet50(pretrained=True)
    model.classifier = models.segmentation.deeplabv3.DeepLabHead(2048, num_classes)
    return model

def compute_IoU(outputs, masks, target_class):
    # Get the predicted class for each pixel
    outputs = outputs.argmax(dim=1)
    intersection = ((outputs == target_class) & (masks == target_class)).sum().item()
    union = ((outputs == target_class) | (masks == target_class)).sum().item()
    if union == 0:
        return float('nan')
    else:
        return intersection / union
    
num_classes = 4  # background + 3 classes
model = get_deeplab_model(num_classes)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)



def visualize_prediction(image, mask, prediction, figsize=(15,5)):
   # Convert tensors to numpy arrays
   if isinstance(image, torch.Tensor):
       image = image.cpu().numpy()
       if image.shape[0] == 3:
           image = image.transpose(1, 2, 0)
       
   if isinstance(mask, torch.Tensor):
       mask = mask.cpu().numpy()
       
   if isinstance(prediction, torch.Tensor):
       prediction = prediction.cpu().numpy()
   
   # Denormalize image
   mean = np.array([0.485, 0.456, 0.406])
   std = np.array([0.229, 0.224, 0.225])
   image = np.clip((image * std + mean), 0, 1)
   
   # Create color maps for mask and prediction
   colors = [(0,0,0), (1,0,0), (0,1,0), (0,0,1)]
   colored_mask = np.zeros((*mask.shape, 3))
   colored_pred = np.zeros((*prediction.shape, 3))
   
   for i, color in enumerate(colors):
       colored_mask[mask == i] = color
       colored_pred[prediction == i] = color
   
   # Plot
   plt.figure(figsize=figsize)
   
   plt.subplot(1, 3, 1)
   plt.imshow(image)
   plt.title('Original Image')
   plt.axis('off')
   
   plt.subplot(1, 3, 2)
   plt.imshow(colored_mask)
   plt.title('Ground Truth Mask')
   plt.axis('off')
   
   plt.subplot(1, 3, 3)
   plt.imshow(colored_pred)
   plt.title('Model Prediction')
   plt.axis('off')
   
   plt.tight_layout()
   plt.show()



model = get_deeplab_model(num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Define class labels for each category
turtle_class = 1
flipper_class = 2
head_class = 3

num_epochs = 10
model.train()

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    epoch_loss = 0
    
    # Training loop
    for batch_idx, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device).long()
        
        optimizer.zero_grad()
        outputs = model(images)['out']
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
        
        # Print every 5 batches
        if batch_idx % 5 == 0:  
            print(f"Batch {batch_idx}/{len(train_loader)} - Loss: {loss.item():.4f}")
    
    # Average epoch loss
    avg_epoch_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch+1} Average Loss: {avg_epoch_loss:.4f}")
    
    # Validation step
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device).long()
            outputs = model(images)['out']
            loss = criterion(outputs, masks)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Validation Loss after Epoch {epoch+1}: {avg_val_loss:.4f}")

    # Compute mIoU on the test set
    turtle_IoUs, flipper_IoUs, head_IoUs = [], [], []

    with torch.no_grad():
        for images, masks in test_loader:
            images, masks = images.to(device), masks.to(device).long()
            outputs = model(images)['out']
            
            # Compute IoU for each category
            for i in range(len(images)):  # Process each image in the batch
                turtle_IoUs.append(compute_IoU(outputs[i:i+1], masks[i:i+1], turtle_class))
                flipper_IoUs.append(compute_IoU(outputs[i:i+1], masks[i:i+1], flipper_class))
                head_IoUs.append(compute_IoU(outputs[i:i+1], masks[i:i+1], head_class))

    turtle_mIoU = np.nanmean(turtle_IoUs)
    flipper_mIoU = np.nanmean(flipper_IoUs)
    head_mIoU = np.nanmean(head_IoUs)
    
    print(f"Turtle (Carapace) mIoU on Test Set after Epoch {epoch+1}: {turtle_mIoU:.4f}")
    print(f"Flippers mIoU on Test Set after Epoch {epoch+1}: {flipper_mIoU:.4f}")
    print(f"Head mIoU on Test Set after Epoch {epoch+1}: {head_mIoU:.4f}")
    
    # Visualize last few predictions after each epoch
    with torch.no_grad():
        images, masks = [], []
        for batch_images, batch_masks in train_loader:
            images, masks = batch_images[-3:], batch_masks[-3:]
        
        images = images.to(device)
        outputs = model(images)['out']
        predictions = torch.argmax(outputs, dim=1)
        
        print(f"\nPredictions after Epoch {epoch+1}:")
        for i in range(min(3, len(images))):
            visualize_prediction(
                images[i],
                masks[i],
                predictions[i]
            )
    
    model.train()  # Set back to training mode
    

loading annotations into memory...
Done (t=4.08s)
creating index...
index created!
Annotation 0: bbox = [644.0, 441.0, 70.0, 78.0]
Annotation 1: bbox = [913.0, 582.0, 128.0, 184.0]
Annotation 2: bbox = [1137.0, 478.0, 167.0, 48.0]
Annotation 3: bbox = [660.0, 569.0, 50.0, 216.0]
Annotation 4: bbox = [646.0, 374.0, 655.0, 406.0]
