In [None]:
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import os
from pycocotools.coco import COCO
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import models
import torch.optim as optim
import torch.nn as nn
from sklearn.model_selection import train_test_split
import cv2 as cv
import numpy as np
import segmentation_models_pytorch as smp


# Network input size and per-channel normalisation statistics (the standard
# ImageNet values, matching the ImageNet-pretrained encoder used below).
INPUT_SIZE = 512
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing shared by every split: resize, normalise, convert to tensors.
# Albumentations applies the geometric step (Resize) to the mask as well.
resize_transform = A.Compose(
    [
        A.Resize(INPUT_SIZE, INPUT_SIZE),
        A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ToTensorV2(),
    ]
)

# COCO-format annotation index; its image ids are used to build the splits.
annotation_file = "./turtles-data/data/updated_annotations.json"
coco = COCO(annotation_file)

class SeaTurtleDataset(Dataset):
    """Sea-turtle segmentation dataset read from a COCO-format annotation file.

    ``__getitem__`` returns ``(image, mask)``: the RGB image and a per-pixel
    integer label map (0 = background, 1 = turtle body, 2 = flipper,
    3 = head), both passed through ``transform`` when one is given.
    """

    # Mask label written for each annotated part. Insertion order matters:
    # later entries overwrite earlier ones where regions overlap, so flippers
    # and heads are painted on top of the turtle body.
    # NOTE(review): each value is used both as the COCO ``catIds`` query and
    # as the mask label, which assumes the annotation file's category ids are
    # turtle=1, flipper=2, head=3 -- confirm against the JSON.
    CATEGORY_LABELS = {
        'turtle': 1,
        'flipper': 2,
        'head': 3,
    }

    def __init__(self, image_ids, transform=None,
                 annotation_file='./turtles-data/data/updated_annotations.json',
                 image_dir='./turtles-data/data', coco=None):
        """
        Args:
            image_ids: COCO image ids this dataset draws from.
            transform: optional callable invoked as
                ``transform(image=..., mask=...)`` and returning a dict with
                ``'image'`` and ``'mask'`` keys (e.g. ``A.Compose``).
            annotation_file: COCO annotation JSON; only parsed when ``coco``
                is not supplied.
            image_dir: directory the COCO ``file_name`` entries are relative to.
            coco: optional already-loaded ``COCO`` index to share between
                splits instead of re-parsing ``annotation_file`` each time.
        """
        self.coco = coco if coco is not None else COCO(annotation_file)
        self.image_ids = image_ids
        self.cat_ids = self.coco.getCatIds()
        self.transform = transform
        self.image_dir = image_dir

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        """Load image ``index`` as RGB, build its label mask, apply ``transform``.

        Raises:
            FileNotFoundError: if the image file cannot be read.
        """
        image_id = self.image_ids[index]
        image_info = self.coco.loadImgs([image_id])[0]

        image_path = os.path.join(self.image_dir, image_info['file_name'])
        image = cv.imread(image_path)
        if image is None:
            # cv.imread reports a missing/unreadable file by returning None
            # instead of raising; fail here with the offending path.
            raise FileNotFoundError(
                f"Could not read image {image_path!r} (COCO image id {image_id})"
            )
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

        mask = self._getmask(image_id, image)

        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented['image'], augmented['mask']

        return image, mask

    def _getmask(self, image_id, image):
        '''Build the (H, W) uint8 label mask for ``image_id``.

        Categories are painted in ``CATEGORY_LABELS`` order, each overwriting
        the previous ones where they overlap:

        Background (0):   000000000000000
        Turtle Body (1):  111111100000000
        Flippers (2):     112222200000000
        Head (3):         112222233300000

        Final Mask:       112222233300000

        ``image`` is used only for its height and width.
        '''
        label_mask = np.zeros(image.shape[:2], dtype=np.uint8)

        for label in self.CATEGORY_LABELS.values():
            ann_ids = self.coco.getAnnIds(imgIds=image_id, catIds=label, iscrowd=None)
            # Union of all annotation masks of this category. A boolean OR is
            # used because summing into uint8 wraps at 256 and would erase
            # pixels covered by exactly 256 overlapping annotations.
            covered = np.zeros(label_mask.shape, dtype=bool)
            for ann in self.coco.loadAnns(ann_ids):
                covered |= self.coco.annToMask(ann) > 0
            label_mask[covered] = label

        return label_mask
    

In [None]:
# Cap on how many COCO images (in getImgIds() order) are used; None = all.
MAX_IMAGES = 4000
BATCH_SIZE = 4
SPLIT_SEED = 42
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine
# it has no benefit and DataLoader warns about it, so enable it only with CUDA.
PIN_MEMORY = torch.cuda.is_available()

# 90/10 train+val vs. test, then 80/20 train vs. val within train+val.
trainval_ids, test_ids = train_test_split(
    coco.getImgIds()[:MAX_IMAGES], test_size=0.1, random_state=SPLIT_SEED
)
train_ids, val_ids = train_test_split(trainval_ids, test_size=0.2, random_state=SPLIT_SEED)

train_dataset = SeaTurtleDataset(train_ids, transform=resize_transform)
val_dataset = SeaTurtleDataset(val_ids, transform=resize_transform)
test_dataset = SeaTurtleDataset(test_ids, transform=resize_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)


In [None]:
def compute_IoU(outputs, masks, target_class):
    """Intersection-over-union of one class between prediction and ground truth.

    Args:
        outputs: per-class scores; the predicted label is the argmax over
            dim 1 (so presumably (N, C, H, W)).
        masks: integer ground-truth labels, shaped like ``outputs`` without
            dim 1.
        target_class: class id to score.

    Returns:
        IoU as a float, or NaN when the class appears in neither the
        prediction nor the ground truth (union is empty).
    """
    predicted = outputs.argmax(dim=1)
    pred_hit = predicted == target_class
    true_hit = masks == target_class

    union = (pred_hit | true_hit).sum().item()
    if not union:
        return float('nan')

    intersection = (pred_hit & true_hit).sum().item()
    return intersection / union
    
num_classes = 4  # background + 3 classes (turtle body, flipper, head)

# activation=None: the network must output raw logits because training uses
# nn.CrossEntropyLoss, which applies log-softmax internally. A softmax2d head
# would apply softmax twice, squashing the logits into [0, 1] and flattening
# the gradients. argmax over logits selects the same class as argmax over
# probabilities, so IoU / visualisation via argmax are unaffected.
model = smp.DeepLabV3Plus(
    encoder_name="resnet101",
    encoder_weights="imagenet",
    classes=num_classes,
    activation=None,
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


In [None]:
# Originally ChatGPT-generated; reworked to use explicit Axes objects.
# RGB colour per class id: background, turtle body, flipper, head.
_CLASS_COLORS = [(0, 0, 0), (1, 0, 0), (0, 1, 0), (0, 0, 1)]


def _to_numpy(x):
    """Return ``x`` as a numpy array, detaching/moving torch tensors to the CPU."""
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _colorize(label_map, colors=_CLASS_COLORS):
    """Map an integer label map to an RGB float image (ids beyond ``colors`` stay black)."""
    rgb = np.zeros((*label_map.shape, 3))
    for class_id, color in enumerate(colors):
        rgb[label_map == class_id] = color
    return rgb


def visualize_prediction(image, mask, prediction, figsize=(15,5),
                         mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show a normalised image next to its ground-truth and predicted label maps.

    Args:
        image: normalised image as a torch tensor or numpy array; a leading
            axis of size 3 is treated as channels-first and moved last.
        mask: ground-truth integer label map (tensor or array).
        prediction: predicted integer label map (tensor or array).
        figsize: matplotlib figure size.
        mean, std: per-channel statistics used to undo the normalisation;
            the defaults are the ImageNet values.
    """
    image = _to_numpy(image)
    if image.ndim == 3 and image.shape[0] == 3:
        image = image.transpose(1, 2, 0)
    image = np.clip(image * np.asarray(std) + np.asarray(mean), 0, 1)

    panels = (
        (image, 'Original Image'),
        (_colorize(_to_numpy(mask)), 'Ground Truth Mask'),
        (_colorize(_to_numpy(prediction)), 'Model Prediction'),
    )

    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (data, title) in zip(axes, panels):
        ax.imshow(data)
        ax.set_title(title)
        ax.axis('off')
    fig.tight_layout()
    plt.show()
    # Release the figure explicitly: this is called repeatedly inside the
    # training loop, and not every backend closes figures after show().
    plt.close(fig)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)


# Define class labels for each category
turtle_class = 1
flipper_class = 2
head_class = 3
class_names = {
    turtle_class: "Turtle (Carapace)",
    flipper_class: "Flippers",
    head_class: "Head",
}

num_epochs = 10


def _train_one_epoch(net, loader, loss_fn, opt, log_every=5):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    net.train()
    running_loss = 0.0
    for batch_idx, (images, masks) in enumerate(loader):
        images = images.to(device)
        # CrossEntropyLoss expects int64 class-index targets.
        masks = masks.to(device).long()

        opt.zero_grad()
        loss = loss_fn(net(images), masks)
        loss.backward()
        opt.step()

        running_loss += loss.item()
        if batch_idx % log_every == 0:
            print(f"Batch {batch_idx}/{len(loader)} - Loss: {loss.item():.4f}")
    return running_loss / len(loader)


def _evaluate(net, loader, loss_fn, class_ids):
    """Return ``(mean batch loss, {class_id: mean per-image IoU})`` over ``loader``.

    Per-image IoU is NaN when a class is absent from both prediction and
    ground truth; ``np.nanmean`` skips those images.
    """
    net.eval()
    total_loss = 0.0
    ious = {c: [] for c in class_ids}
    with torch.no_grad():
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device).long()
            outputs = net(images)
            total_loss += loss_fn(outputs, masks).item()
            for i in range(len(images)):
                for c in class_ids:
                    ious[c].append(compute_IoU(outputs[i:i+1], masks[i:i+1], c))
    return total_loss / len(loader), {c: np.nanmean(v) for c, v in ious.items()}


def _show_predictions(net, loader, n_images=3):
    """Plot ground truth vs. prediction for the last ``n_images`` of the first batch."""
    net.eval()
    with torch.no_grad():
        batch_images, batch_masks = next(iter(loader))
        images = batch_images[-n_images:].to(device)
        masks = batch_masks[-n_images:]
        predictions = net(images).argmax(dim=1)
    # zip stops at the real batch size, so batches shorter than n_images work.
    for image, mask, pred in zip(images, masks, predictions):
        visualize_prediction(image, mask, pred)


for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    avg_epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
    print(f"Epoch {epoch+1} Average Loss: {avg_epoch_loss:.4f}")

    # Per-epoch monitoring uses the validation split only; the test split is
    # held out until training is finished so it never influences choices made
    # while watching these numbers.
    avg_val_loss, val_mIoUs = _evaluate(model, val_loader, criterion, list(class_names))
    print(f"Validation Loss after Epoch {epoch+1}: {avg_val_loss:.4f}")
    for cls, name in class_names.items():
        print(f"{name} mIoU on Validation Set after Epoch {epoch+1}: {val_mIoUs[cls]:.4f}")

    # val_loader is not shuffled, so the same images are shown every epoch,
    # making progress across epochs directly comparable.
    print(f"\nPredictions after Epoch {epoch+1}:")
    _show_predictions(model, val_loader)

# One-time evaluation of the final model on the held-out test split.
test_loss, test_mIoUs = _evaluate(model, test_loader, criterion, list(class_names))
turtle_mIoU = test_mIoUs[turtle_class]
flipper_mIoU = test_mIoUs[flipper_class]
head_mIoU = test_mIoUs[head_class]

print(f"\nTest Loss (final model): {test_loss:.4f}")
for cls, name in class_names.items():
    print(f"{name} mIoU on Test Set (final model): {test_mIoUs[cls]:.4f}")

model.train()  # Leave the model in training mode, as before
