In [1]:
# COMP9517 Group Project

In [2]:
### Model 1 : KNN

# Please run this file after dataVisualisation.ipynb. Afterwards, you can run any model in any order.

In [3]:
import numpy as np
import cv2 as cv
from pycocotools.coco import COCO
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from torch.utils.data import Dataset, DataLoader, Subset
import pandas as pd
import albumentations as A
import os
import matplotlib.pyplot as plt


  check_for_updates()


In [4]:
# Shared preprocessing pipeline: force every image to 512x512, then normalise each channel
resize_transform = A.Compose(
    transforms=[
        A.Resize(height=512, width=512),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

class SeaTurtleDatasetForKNN(Dataset):
    """Dataset that yields flattened (image, label-mask) pairs for a KNN model.

    Each item is one image read from ``./turtles-data/data`` together with a
    mask whose pixels are 0 (no annotation), 1 (turtle), 2 (flipper) or
    3 (head). Both are returned as 1-D arrays.

    Args:
        image_ids: Sequence of COCO image IDs served by this dataset.
        transform: Optional albumentations-style callable. It is called as
            ``transform(image=..., mask=...)`` and must return a mapping with
            ``"image"`` and ``"mask"`` entries.
        coco: Optional already-loaded COCO index to reuse. When omitted, the
            annotation file is parsed again, which is slow for this dataset.
    """

    def __init__(self, image_ids, transform=None, coco=None):
        if coco is None:
            coco = COCO("./turtles-data/data/updated_annotations.json")
        self.coco = coco
        self.image_ids = image_ids
        self.cat_ids = self.coco.getCatIds()
        self.transform = transform

    def __len__(self):
        """Return the number of image IDs served by this dataset."""
        return len(self.image_ids)

    def __getitem__(self, index):
        """Return ``(flattened_image, flattened_mask)`` for ``image_ids[index]``.

        Raises:
            FileNotFoundError: If the image file cannot be read from disk.
        """
        # Load image
        image_id = self.image_ids[index]
        image_data = self.coco.loadImgs([image_id])[0]

        image_path = os.path.join("./turtles-data/data", image_data["file_name"])
        image = cv.imread(image_path)
        if image is None:
            # cv.imread reports a missing/unreadable file by returning None;
            # fail here with a clear message instead of inside cvtColor.
            raise FileNotFoundError(
                f"Could not read image for id {image_id}: {image_path}"
            )
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

        # Prepare mask at the image's native resolution
        mask = self._getmask(image_id, image)

        # Apply transformations. The mask goes through the same pipeline so
        # it always ends up with the same spatial size as the image, whatever
        # size the transform produces.
        if self.transform is not None:
            augmented = self.transform(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        # Flatten images and masks for KNN input
        return image.flatten(), mask.flatten()

    def _getmask(self, image_id, image):
        """Build a label mask for one image.

        Args:
            image_id: COCO image ID whose annotations are rasterised.
            image: Array whose first two dimensions give the mask height and width.

        Returns:
            ``uint8`` array of shape ``image.shape[:2]``. Categories are painted
            in the order turtle (1), flipper (2), head (3), so a later
            category overwrites an earlier one where annotations overlap.
        """
        categories = {"turtle": 1, "flipper": 2, "head": 3}
        # Initialize the final mask with zeros
        mask = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

        for category_id in categories.values():
            ann_ids = self.coco.getAnnIds(imgIds=image_id, catIds=category_id, iscrowd=None)
            for ann in self.coco.loadAnns(ann_ids):
                # Paint directly with the label: summing binary masks into a
                # uint8 buffer could wrap around to 0 for heavily overlapping
                # annotations and silently drop pixels.
                mask[self.coco.annToMask(ann) > 0] = category_id

        return mask

In [5]:
# Read the per-image split metadata table from disk
metadata_path = "./turtles-data/data/metadata_splits.csv"
metadata = pd.read_csv(filepath_or_buffer=metadata_path)

In [6]:
# Borrow the COCO index from an empty dataset instance so it is available notebook-wide
_bootstrap_dataset = SeaTurtleDatasetForKNN(image_ids=[])
coco = _bootstrap_dataset.coco

# Every image ID known to the annotation index
image_ids = coco.getImgIds()

# Forward lookup: image ID -> file name
img_to_filename = dict(
    zip(image_ids, (record["file_name"] for record in coco.loadImgs(image_ids)))
)

# Reverse lookup: file name -> image ID
file_to_img = dict(zip(img_to_filename.values(), img_to_filename.keys()))


loading annotations into memory...
Done (t=10.58s)
creating index...
index created!


In [7]:
# Translate each split's file names into COCO image IDs
def _image_ids_for_split(split_name):
    """Return image IDs for metadata rows whose ``split_open`` equals ``split_name``.

    File names that have no entry in ``file_to_img`` are skipped.
    """
    split_rows = metadata.loc[metadata["split_open"] == split_name]
    return [file_to_img[name] for name in split_rows["file_name"] if name in file_to_img]


train_img_ids = _image_ids_for_split("train")
val_img_ids = _image_ids_for_split("valid")
test_img_ids = _image_ids_for_split("test")


In [None]:
# Create limited Subset datasets: at most 200 samples for training and 100 for testing.
# The caps are clamped to the dataset length so a smaller split cannot produce
# out-of-range indices (which would only surface later as an IndexError while iterating).
_train_full_dataset = SeaTurtleDatasetForKNN(train_img_ids, transform=resize_transform)
_test_full_dataset = SeaTurtleDatasetForKNN(test_img_ids, transform=resize_transform)
train_dataset = Subset(_train_full_dataset, range(min(200, len(_train_full_dataset))))
test_dataset = Subset(_test_full_dataset, range(min(100, len(_test_full_dataset))))


loading annotations into memory...


Done (t=10.43s)
creating index...
index created!
loading annotations into memory...
Done (t=22.41s)
creating index...
index created!


In [9]:
# Wrap both subsets in DataLoaders that share batch size and worker count;
# only the training loader shuffles.
batch_size = 8
num_workers = 4
_shared_loader_options = {"batch_size": batch_size, "num_workers": num_workers}
train_loader = DataLoader(train_dataset, shuffle=True, **_shared_loader_options)
# No validation loader is created in this section.
test_loader = DataLoader(test_dataset, shuffle=False, **_shared_loader_options)


In [10]:
# Empty accumulators for the flattened training images (X) and their masks (y)
X_train = []
y_train = []

In [None]:
# Collect every training batch as one flattened row per image.
# The accumulators are rebuilt here so this cell is idempotent: re-running it
# neither duplicates samples nor fails because X_train/y_train were already
# converted to numpy arrays (which have no .extend) by a later cell.
X_train, y_train = [], []
for images, masks in train_loader:
    X_train.extend(images.numpy())
    y_train.extend(masks.numpy())

In [None]:
# Empty accumulators for the flattened test images (X) and their masks (y)
X_test = []
y_test = []

In [None]:
# Collect every test batch as one flattened row per image.
# The accumulators are rebuilt here so this cell is idempotent: re-running it
# neither duplicates samples nor fails because X_test/y_test were already
# converted to numpy arrays (which have no .extend) by a later cell.
X_test, y_test = [], []
for images, masks in test_loader:
    X_test.extend(images.numpy())
    y_test.extend(masks.numpy())

In [None]:
# Convert to numpy arrays, keeping ONE ROW PER IMAGE in both X and y.
# X_* has shape (n_images, n_features) and y_* has shape (n_images, n_pixels).
# Concatenating the masks into a single 1-D vector would give
# n_images * n_pixels labels for only n_images feature rows, and knn.fit
# rejects inputs whose sample counts differ. The per-image layout is also what
# lets a single prediction be reshaped back into a 512x512 mask.
# Metrics that expect 1-D label vectors need these arrays flattened first.
X_train = np.array(X_train)
y_train = np.array(y_train).astype(int)
X_test = np.array(X_test)
y_test = np.array(y_test).astype(int)

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def display_images_with_segmentation(image_ids, knn_model):
    """Plot image, annotations, ground-truth mask and KNN prediction per image ID.

    For every ID a four-panel figure is shown: the original image, the image
    with its COCO annotations overlaid, the ground-truth label mask resized to
    512x512, and the mask predicted by ``knn_model``.

    Args:
        image_ids: Iterable of COCO image IDs to visualise.
        knn_model: Fitted estimator whose ``predict`` takes one flattened,
            preprocessed image as a ``(1, n_features)`` row and returns
            512 * 512 labels for it.

    IDs that cannot be resolved in the module-level ``coco`` index, or whose
    image file is missing on disk, are reported with a ``[DEBUG]`` message and
    skipped.
    """
    for image_id in image_ids:
        # Load image metadata
        try:
            img_info = coco.loadImgs(image_id)[0]
            image_path = f"turtles-data/data/{img_info['file_name']}"
        except (KeyError, TypeError):
            # An unknown integer ID raises KeyError inside loadImgs; an ID of
            # an unsupported type makes loadImgs return None (TypeError on [0]).
            print(f"[DEBUG]: This image ID is missing: {image_id}")
            continue

        try:
            # Load the original image as 3-channel RGB and release the file handle
            with Image.open(image_path) as image_file:
                image = np.array(image_file.convert("RGB"))
        except FileNotFoundError:
            print(f"[DEBUG]: This image was already deleted: {image_id}. It has been removed from the dataset.")
            continue

        plt.figure(figsize=(20, 10))

        # Original Image
        plt.subplot(1, 4, 1)
        plt.imshow(image)
        plt.axis("off")
        plt.title("Original Image")

        # Ground Truth Annotations
        plt.subplot(1, 4, 2)
        plt.imshow(image)
        cat_ids = coco.getCatIds()
        ann_ids = coco.getAnnIds(imgIds=img_info["id"], catIds=cat_ids, iscrowd=None)
        anns = coco.loadAnns(ann_ids)
        coco.showAnns(anns)
        plt.axis("off")
        plt.title("Ground Truth Annotations")

        # Ground Truth Mask
        plt.subplot(1, 4, 3)
        mask = np.zeros((img_info["height"], img_info["width"]), dtype=np.uint8)
        # Paint each annotation with its category id, lowest id first, so the
        # values use the same encoding as the dataset's training masks
        # (1 turtle, 2 flipper, 3 head; later categories overwrite earlier
        # ones). Summing the binary masks instead would give overlapping
        # parts values that do not correspond to any label.
        for ann in sorted(anns, key=lambda a: a["category_id"]):
            mask[coco.annToMask(ann) > 0] = ann["category_id"]
        mask_resized = cv.resize(mask, (512, 512), interpolation=cv.INTER_NEAREST)
        plt.imshow(mask_resized, cmap="plasma")
        plt.axis("off")
        plt.title("Ground Truth Mask")

        # KNN Predicted Mask
        plt.subplot(1, 4, 4)
        # The model was fitted on images passed through resize_transform, so
        # the same preprocessing is required here; the raw image has a
        # different size/scale and would not match the fitted feature count.
        model_input = resize_transform(image=image)["image"]
        image_flattened = model_input.flatten().reshape(1, -1)
        pred_mask_flat = knn_model.predict(image_flattened)
        pred_mask = np.asarray(pred_mask_flat).reshape((512, 512))
        plt.imshow(pred_mask, cmap="plasma")
        plt.axis("off")
        plt.title("KNN Predicted Mask")

        plt.tight_layout()
        plt.show()



In [None]:
# Build a 3-nearest-neighbour classifier and fit it on the prepared training arrays
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X=X_train, y=y_train)

In [None]:
# Predict labels for the test images, then evaluate per pixel.
# classification_report does not accept multi-output (2-D) multiclass targets,
# so both arrays are flattened into matching 1-D per-pixel label vectors.
# np.ravel leaves an already 1-D array unchanged.
y_pred = knn.predict(X_test)
print(classification_report(np.ravel(y_test), np.ravel(y_pred)))

In [None]:
# Visualise the model output for the first ten test image IDs only
sample_image_ids = test_img_ids[0:10]
display_images_with_segmentation(image_ids=sample_image_ids, knn_model=knn)
