In [1]:
# Load the autoreload extension
%load_ext autoreload

# Set autoreload mode
%autoreload 2

In [2]:
import albumentations as A
import mermaidseg.datasets.dataset
import numpy as np
from mermaidseg.io import setup_config, get_parser, update_config_with_args
import copy
import torch
from matplotlib import pyplot as plt

In [3]:
device_count = torch.cuda.device_count()
for i in range(device_count):
    print(f"CUDA Device {i}: {torch.cuda.get_device_name(i)}")
    
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = True

CUDA Device 0: Tesla T4
cuda


In [4]:
from torch.utils.data import DataLoader, random_split
from mermaidseg.model.meta import MetaModel
from mermaidseg.model.eval import EvaluatorSemanticSegmentation
from mermaidseg.logger import Logger
from mermaidseg.model.train import train_model

2025-12-17 15:00:43.933724: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-12-17 15:00:43.956509: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765983643.986520  110916 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765983643.995859  110916 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-12-17 15:00:44.024129: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

In [5]:
# Start off with a configuration file
cfg = setup_config(config_path='../configs/linear-dinov3-concept.yaml', config_base_path='../configs/concept_mermaid.yaml')

# Update the initial configuration file with command line arguments 
# (in the case of a notebook run these can be defined explicitly here)
args_input = "--run-name=dinov3-test-concept-run --batch-size=2 --epochs=5 --log-epochs=1"
args_input = args_input.split(" ")

parser = get_parser()
args = parser.parse_args(args_input)

cfg = update_config_with_args(cfg, args)
cfg_logger = copy.deepcopy(cfg)

In [6]:
transforms = {}
for split in cfg.augmentation:
    transforms[split] = A.Compose(
        [
            getattr(A, transform_name)(**transform_params) for transform_name, transform_params
                                                                 in cfg.augmentation[split].items()
        ]
    )

In [7]:
dataset_name = cfg.data.pop("name", None)
batch_size = cfg.data.pop("batch_size", 4)
whitelist_sources = cfg.data.pop("whitelist_sources", None)

In [8]:
dataset_dict = {}
dataset_dict["train"] = getattr(mermaidseg.datasets.dataset, dataset_name)(transform = transforms[split], **cfg.data)

In [9]:
len(dataset_dict["train"])

8073

In [10]:
total_size = len(dataset_dict["train"])
train_size = int(0.7 * total_size)
val_size = int(0.15 * total_size)
test_size = total_size - train_size - val_size

generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset, test_dataset = random_split(dataset_dict["train"], [train_size, val_size, test_size], generator=generator)
train_dataset = torch.utils.data.Subset(train_dataset, range(5000))
val_dataset = torch.utils.data.Subset(val_dataset, range(1000))
test_dataset = torch.utils.data.Subset(test_dataset, range(1000))
# train_dataset = torch.utils.data.Subset(dataset_dict["train"], range(3000))
# val_dataset = torch.utils.data.Subset(dataset_dict["val"], range(500))
# test_dataset = torch.utils.data.Subset(dataset_dict["test"], range(500))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)
# val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True, collate_fn = dataset_dict["val"].collate_fn)
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2, drop_last=True, collate_fn = dataset_dict["test"].collate_fn)

In [11]:
print(f"Number of training batches: {len(train_loader)}")
print(f"Number of validation batches: {len(val_loader)}")
print(f"Number of test batches: {len(test_loader)}")

Number of training batches: 2500
Number of validation batches: 500
Number of test batches: 500


In [12]:
dataset_dict["train"].num_classes, dataset_dict["train"].num_concepts

(16, 20)

In [13]:
meta_model = MetaModel(run_name = cfg.run_name, 
                       num_classes = dataset_dict["train"].num_classes,
                       num_concepts = dataset_dict["train"].num_concepts,
                       device = device,
                       model_kwargs = cfg.model,
                       training_mode = cfg.training_mode,
                       training_kwargs = cfg.training,
                       concept_matrix = dataset_dict["train"].benthic_concept_matrix,
                       conceptid2labelid = dataset_dict["train"].conceptid2labelid,)

# Tests

In [14]:
import time 
import tqdm

In [15]:
time_dict = {}

## Typical Test

In [23]:
time_dict["typical"] = {"data_io":0, "forward":0}
num_iterations = 100
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["typical"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["typical"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [04:33<00:00,  2.73s/it]


In [24]:
time_dict["typical"]

{'data_io': 1.9751221752166752, 'forward': 0.691829288005829}

## Num Workers

In [26]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)
time_dict["num_workers=2"] = {"data_io":0, "forward":0}
num_iterations = 50
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["num_workers=2"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["num_workers=2"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [02:20<00:00,  2.81s/it]


In [27]:
time_dict["num_workers=2"]

{'data_io': 2.024768481254578, 'forward': 0.7219918203353882}

In [28]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)
time_dict["num_workers=4"] = {"data_io":0, "forward":0}
num_iterations = 50
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["num_workers=4"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["num_workers=4"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

100%|██████████| 50/50 [02:35<00:00,  3.11s/it]


In [29]:
time_dict["num_workers=4"]

{'data_io': 2.2996981573104858, 'forward': 0.7402305412292479}

## Local Storage

In [31]:
dataset_dict["train"].df_images

Unnamed: 0,image_id,region_id,region_name
0,00059a47-03b8-47f3-adf6-3ab5616922cf,1d31d9ea-e78d-438b-8667-0d63d1aba257,Western Indo-Pacific
1,00086e76-2b0d-48ff-a25b-31020c226047,1d31d9ea-e78d-438b-8667-0d63d1aba257,Western Indo-Pacific
2,00138e67-611c-4e04-a382-46e0484f2f95,1d31d9ea-e78d-438b-8667-0d63d1aba257,Western Indo-Pacific
3,0015772a-fcc3-4bd8-bfb8-2a3b67520f35,983267a0-7349-4d3e-a23e-fb9353ca8ba5,Central Indo-Pacific
4,001580fa-3324-4053-a74b-c5ef08d49d07,1d31d9ea-e78d-438b-8667-0d63d1aba257,Western Indo-Pacific
...,...,...,...
8068,fff22808-8720-4a33-9947-f43e3b1ada6b,1d31d9ea-e78d-438b-8667-0d63d1aba257,Western Indo-Pacific
8069,fff84912-4e02-4724-8a90-2443b2b43130,983267a0-7349-4d3e-a23e-fb9353ca8ba5,Central Indo-Pacific
8070,fff8f7ad-2772-43f1-ad66-c65994bea5f2,983267a0-7349-4d3e-a23e-fb9353ca8ba5,Central Indo-Pacific
8071,fffc7eff-cb8d-47b3-9081-a1f6d092e781,1d31d9ea-e78d-438b-8667-0d63d1aba257,Western Indo-Pacific


In [16]:
import boto3
import io

In [17]:
def get_image_s3(
    s3,
    bucket: str,
    key: str,
    thumbnail: bool = False,
):
    """
    Fetches an image from an S3 bucket and returns it as a PIL Image object.
    Args:
        s3 (boto3.client): The Boto3 S3 client used to interact with S3.
        bucket (str): The name of the S3 bucket.
        key (str): The key (path) of the image in the S3 bucket.
        thumbnail (bool, optional): If True, fetches the thumbnail version of the image by modifying the key. Defaults to False.
    Returns:
        PIL.Image.Image: The image loaded from S3 as a PIL Image object.
    """

    if thumbnail:
        key = key.replace(".png", "_thumbnail.png")

    response = s3.get_object(Bucket=bucket, Key=key)
    image_data = response["Body"].read()

    image = Image.open(io.BytesIO(image_data))
    return image

def create_annotation_mask(
    annotations,
    shape,
    label2id,
    padding = None,
) -> np.ndarray:
    """
    Creates an annotation mask for a given image based on provided annotations.
    Args:
        annotations (pd.DataFrame): DataFrame containing annotation rows with 'row', 'col', and 'benthic_attribute_name' columns.
        shape (Tuple[int, int]): Shape of the output mask (height, width).
        label2id (Dict[str, int]): Mapping from label names to integer IDs.
    Returns:
        np.ndarray: Annotation mask with integer class IDs.
    """
    ## TODO: Make Padding percentage based so that it is applicable to all class sizes
    mask = np.zeros(shape[:2])
    for _, annotation in annotations.iterrows():
        if annotation["benthic_attribute_name"] is not None:
            if padding is not None and padding > 0:
                mask[
                    annotation["row"] - padding : annotation["row"] + padding,
                    annotation["col"] - padding : annotation["col"] + padding,
                ] = label2id[annotation["benthic_attribute_name"]]
            else:
                mask[annotation["row"], annotation["col"]] = label2id[
                    annotation["benthic_attribute_name"]
                ]

    return mask

In [18]:
def read_image(self, image_id: str, **row_kwargs):
    """
    Read an image given its ID. Needs to be implemented in subclasses.
    """
    key = f"mermaid/{image_id}.png"  # f"mermaid/{image_id}_thumbnail.png"
    image = np.array(
        get_image_s3(s3=self.s3, bucket=self.source_bucket, key=key).convert("RGB")
    )

    return image

In [19]:
from torch.utils.data import Dataset

class Dataset_tmp1(Dataset):
    def __init__(self):
        self.df_images = dataset_dict["train"].df_images
        self.transform = dataset_dict["train"].transform

    def __len__(self):
        return len(self.df_images)
    def __getitem__(self, idx: int):
        image_id = self.df_images.loc[idx, "image_id"]
        row_kwargs = self.df_images.loc[idx].to_dict()

        image = dataset_dict["train"].read_image(**row_kwargs)

        annotations = dataset_dict["train"].df_annotations.loc[
            dataset_dict["train"].df_annotations["image_id"] == image_id,
            [
                "row",
                "col",
                "benthic_attribute_name",
            ],
        ]

        mask = create_annotation_mask(
            annotations, image.shape, dataset_dict["train"].label2id, padding=dataset_dict["train"].padding
    )

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"].transpose(2, 0, 1)
            mask = transformed["mask"]

        return image, mask, annotations

    def collate_fn(self, batch):
        """
        Collate function for MermaidDataset and CoralNetDataset.
        Args:
            batch: List of tuples (image, mask, annotations)
        Returns:
            images: Tensor or ndarray batch of images
            masks: Tensor or ndarray batch of masks
            annotations: List of annotation DataFrames
        """
        # images, masks, annotations = zip(*batch)

        # Filter out entries where image or mask is None
        filtered = [
            (img, msk, ann)
            for img, msk, ann in batch
            if img is not None and msk is not None
        ]
        images, masks, annotations = zip(*filtered)

        # Handle empty batch
        if len(images) == 0:
            return torch.tensor([]), torch.tensor([]), []

        # Convert to tensors if they aren't already
        if isinstance(images[0], torch.Tensor):
            images = torch.stack(images)
            masks = torch.stack(masks)
        else:
            # Convert numpy arrays to tensors for consistency
            images = torch.stack(
                [
                    torch.from_numpy(img) if isinstance(img, np.ndarray) else img
                    for img in images
                ]
            )
            masks = torch.stack(
                [
                    torch.from_numpy(mask) if isinstance(mask, np.ndarray) else mask
                    for mask in masks
                ]
            )

        return images, masks

In [22]:
dataset_tmp1 = Dataset_tmp1()
train_loader = DataLoader(dataset_tmp1, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True, collate_fn = dataset_tmp1.collate_fn)
time_dict["typical_test_2"] = {"data_io":0, "forward":0}
num_iterations = 50
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["typical_test_2"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["typical_test_2"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

  4%|▍         | 2/50 [00:10<04:00,  5.01s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 40.00 MiB. GPU 0 has a total capacity of 14.56 GiB of which 9.75 MiB is free. Process 51245 has 4.42 GiB memory in use. Process 72870 has 4.66 GiB memory in use. Process 93954 has 3.48 GiB memory in use. Including non-PyTorch memory, this process has 1.99 GiB memory in use. Of the allocated memory 1.67 GiB is allocated by PyTorch, and 202.55 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [51]:
time_dict["typical_test_2"]

{'data_io': 1.83486671924591, 'forward': 0.7137363386154176}

### Local

In [24]:
import os
from PIL import Image

# for idx in range(100):    
#         image_id = dataset_tmp1.df_images.loc[idx, "image_id"]
#         row_kwargs = dataset_tmp1.df_images.loc[idx].to_dict()

#         image = dataset_dict["train"].read_image(**row_kwargs)
#         out_dir = "../data_images"
#         os.makedirs(out_dir, exist_ok=True)
#         out_path = os.path.join(out_dir, f"image_{idx}.png")
#         Image.fromarray(image).save(out_path)

In [20]:
class Dataset_tmp2(Dataset):
    def __init__(self):
        self.df_images = dataset_dict["train"].df_images[:100]
        self.transform = dataset_dict["train"].transform

    def __len__(self):
        return len(self.df_images)
    def __getitem__(self, idx: int):
        image_id = self.df_images.loc[idx, "image_id"]
        row_kwargs = self.df_images.loc[idx].to_dict()

        image_path = os.path.join("../data_images", f"image_{idx}.png")
        image = np.array(Image.open(image_path).convert("RGB"))

        annotations = dataset_dict["train"].df_annotations.loc[
            dataset_dict["train"].df_annotations["image_id"] == image_id,
            ["row", "col", "benthic_attribute_name"],
        ]

        mask = create_annotation_mask(
            annotations, image.shape, dataset_dict["train"].label2id, padding=dataset_dict["train"].padding
    )

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"].transpose(2, 0, 1)
            mask = transformed["mask"]

        return image, mask, annotations

    def collate_fn(self, batch):
        """
        Collate function for MermaidDataset and CoralNetDataset.
        Args:
            batch: List of tuples (image, mask, annotations)
        Returns:
            images: Tensor or ndarray batch of images
            masks: Tensor or ndarray batch of masks
            annotations: List of annotation DataFrames
        """
        # images, masks, annotations = zip(*batch)

        # Filter out entries where image or mask is None
        filtered = [
            (img, msk, ann)
            for img, msk, ann in batch
            if img is not None and msk is not None
        ]
        images, masks, annotations = zip(*filtered)

        # Handle empty batch
        if len(images) == 0:
            return torch.tensor([]), torch.tensor([]), []

        # Convert to tensors if they aren't already
        if isinstance(images[0], torch.Tensor):
            images = torch.stack(images)
            masks = torch.stack(masks)
        else:
            # Convert numpy arrays to tensors for consistency
            images = torch.stack(
                [
                    torch.from_numpy(img) if isinstance(img, np.ndarray) else img
                    for img in images
                ]
            )
            masks = torch.stack(
                [
                    torch.from_numpy(mask) if isinstance(mask, np.ndarray) else mask
                    for mask in masks
                ]
            )

        return images, masks

In [64]:
dataset_tmp2 = Dataset_tmp2()
train_loader = DataLoader(dataset_tmp2, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True, collate_fn = dataset_tmp2.collate_fn)
time_dict["local_test"] = {"data_io":0, "forward":0}
num_iterations = 50
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["local_test"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["local_test"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [02:37<00:00,  3.16s/it]


In [65]:
time_dict["local_test"]

{'data_io': 2.348804125785827, 'forward': 0.7396216869354246}

In [66]:
dataset_tmp2 = Dataset_tmp2()
train_loader = DataLoader(dataset_tmp2, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True, collate_fn = dataset_tmp2.collate_fn)
time_dict["local_test_num_workers_2"] = {"data_io":0, "forward":0}
num_iterations = 50
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["local_test_num_workers_2"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["local_test_num_workers_2"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

100%|██████████| 50/50 [02:49<00:00,  3.38s/it]


In [68]:
time_dict["local_test_num_workers_2"]

{'data_io': 2.576789755821229, 'forward': 0.7402181720733643}

## No Mask

In [83]:
class Dataset_tmp3(Dataset):
    def __init__(self):
        self.df_images = dataset_dict["train"].df_images[:100]
        self.transform = dataset_dict["train"].transform

    def __len__(self):
        return len(self.df_images)
    def __getitem__(self, idx: int):
        image_id = self.df_images.loc[idx, "image_id"]
        row_kwargs = self.df_images.loc[idx].to_dict()

        image = dataset_dict["train"].read_image(**row_kwargs)

        annotations = []
        # annotations = dataset_dict["train"].df_annotations.loc[
        #     dataset_dict["train"].df_annotations["image_id"] == image_id,
        #     ["row", "col", "benthic_attribute_name"],
        # ]

        mask = np.zeros(image.shape[:2], dtype = np.uint8)
        # mask = create_annotation_mask(
        #     annotations, image.shape, dataset_dict["train"].label2id, padding=dataset_dict["train"].padding
        # )

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"].transpose(2, 0, 1)
            mask = transformed["mask"]

        return image, mask, annotations

    def collate_fn(self, batch):
        """
        Collate function for MermaidDataset and CoralNetDataset.
        Args:
            batch: List of tuples (image, mask, annotations)
        Returns:
            images: Tensor or ndarray batch of images
            masks: Tensor or ndarray batch of masks
            annotations: List of annotation DataFrames
        """
        # images, masks, annotations = zip(*batch)

        # Filter out entries where image or mask is None
        filtered = [
            (img, msk, ann)
            for img, msk, ann in batch
            if img is not None and msk is not None
        ]
        images, masks, annotations = zip(*filtered)
    
        # Handle empty batch
        if len(images) == 0:
            return torch.tensor([]), torch.tensor([]), []

        # Convert to tensors if they aren't already
        if isinstance(images[0], torch.Tensor):
            images = torch.stack(images)
            masks = torch.stack(masks)
        else:
            # Convert numpy arrays to tensors for consistency
            images = torch.stack(
                [
                    torch.from_numpy(img) if isinstance(img, np.ndarray) else img
                    for img in images
                ]
            )
            masks = torch.stack(
                [
                    torch.from_numpy(mask) if isinstance(mask, np.ndarray) else mask
                    for mask in masks
                ]
            )

        return images, masks

In [85]:
dataset_tmp3 = Dataset_tmp3()
train_loader = DataLoader(dataset_tmp3, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True, collate_fn = dataset_tmp3.collate_fn)
time_dict["no_transform"] = {"data_io":0, "forward":0}
num_iterations = 50
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["no_transform"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["no_transform"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [01:56<00:00,  2.32s/it]


In [73]:
time_dict["no_transform"]

{'data_io': 1.7039057350158688, 'forward': 0.6439124155044555}

# Improve collation

In [78]:
class Dataset_tmp4(Dataset):
    def __init__(self):
        self.df_images = dataset_dict["train"].df_images[:100]
        self.transform = dataset_dict["train"].transform

    def __len__(self):
        return len(self.df_images)
    def __getitem__(self, idx: int):
        image_id = self.df_images.loc[idx, "image_id"]
        row_kwargs = self.df_images.loc[idx].to_dict()

        image = dataset_dict["train"].read_image(**row_kwargs)

        annotations = dataset_dict["train"].df_annotations.loc[
            dataset_dict["train"].df_annotations["image_id"] == image_id,
            ["row", "col", "benthic_attribute_name"],
        ]

        mask = np.zeros(image.shape[:2], dtype = np.uint8)
        mask = create_annotation_mask(
            annotations, image.shape, dataset_dict["train"].label2id, padding=dataset_dict["train"].padding
        )

        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed["image"].transpose(2, 0, 1)
            mask = transformed["mask"]

        return image, mask

    # def collate_fn(self, batch):
    #     """
    #     Collate function for MermaidDataset and CoralNetDataset.
    #     Args:
    #         batch: List of tuples (image, mask, annotations)
    #     Returns:
    #         images: Tensor or ndarray batch of images
    #         masks: Tensor or ndarray batch of masks
    #         annotations: List of annotation DataFrames
    #     """
    #     # images, masks, annotations = zip(*batch)

    #     # Filter out entries where image or mask is None
    #     # filtered = [
    #     #     (img, msk, ann)
    #     #     for img, msk, ann in batch
    #     #     if img is not None and msk is not None
    #     # ]
    #     # images, masks, annotations = zip(*filtered)
    #     images, masks, annotations = zip(*batch)
    #     # Handle empty batch
    #     if len(images) == 0:
    #         return torch.tensor([]), torch.tensor([]), []

    #     # Convert to tensors if they aren't already
    #     if isinstance(images[0], torch.Tensor):
    #         images = torch.stack(images)
    #         masks = torch.stack(masks)
    #     else:
    #         # Convert numpy arrays to tensors for consistency
    #         images = torch.stack(
    #             [
    #                 torch.from_numpy(img) if isinstance(img, np.ndarray) else img
    #                 for img in images
    #             ]
    #         )
    #         masks = torch.stack(
    #             [
    #                 torch.from_numpy(mask) if isinstance(mask, np.ndarray) else mask
    #                 for mask in masks
    #             ]
    #         )

    #     return images, masks

In [79]:
dataset_tmp4 = Dataset_tmp4()
train_loader = DataLoader(dataset_tmp4, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)
time_dict["collation_1"] = {"data_io":0, "forward":0}
num_iterations = 50
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["collation_1"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["collation_1"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [02:03<00:00,  2.47s/it]


In [80]:
for idx in tqdm.tqdm(range(100)):
    image_id = dataset_tmp4.df_images.loc[idx, "image_id"]
    row_kwargs = dataset_tmp4.df_images.loc[idx].to_dict()
    image = dataset_dict["train"].read_image(**row_kwargs)

100%|██████████| 100/100 [00:32<00:00,  3.12it/s]


## No Transform

In [90]:
class Dataset_tmp5(Dataset):
    def __init__(self):
        self.df_images = dataset_dict["train"].df_images[:100]
        self.transform = dataset_dict["train"].transform

    def __len__(self):
        return len(self.df_images)
    def __getitem__(self, idx: int):
        image_id = self.df_images.loc[idx, "image_id"]
        row_kwargs = self.df_images.loc[idx].to_dict()

        image = dataset_dict["train"].read_image(**row_kwargs)

        annotations = dataset_dict["train"].df_annotations.loc[
            dataset_dict["train"].df_annotations["image_id"] == image_id,
            ["row", "col", "benthic_attribute_name"],
        ]

        mask = create_annotation_mask(
            annotations, image.shape, dataset_dict["train"].label2id, padding=dataset_dict["train"].padding
        )

        image = image[:512, :512].transpose(2, 0, 1)
        mask = mask[:512, :512]
        # if self.transform:
        #     transformed = self.transform(image=image, mask=mask)
        #     image = transformed["image"].transpose(2, 0, 1)
        #     mask = transformed["mask"]

        return image, mask, annotations

    def collate_fn(self, batch):
        """
        Collate function for MermaidDataset and CoralNetDataset.
        Args:
            batch: List of tuples (image, mask, annotations)
        Returns:
            images: Tensor or ndarray batch of images
            masks: Tensor or ndarray batch of masks
            annotations: List of annotation DataFrames
        """
        # images, masks, annotations = zip(*batch)

        # Filter out entries where image or mask is None
        filtered = [
            (img, msk, ann)
            for img, msk, ann in batch
            if img is not None and msk is not None
        ]
        images, masks, annotations = zip(*filtered)
        # Handle empty batch
        if len(images) == 0:
            return torch.tensor([]), torch.tensor([]), []

        # Convert to tensors if they aren't already
        if isinstance(images[0], torch.Tensor):
            images = torch.stack(images)
            masks = torch.stack(masks)
        else:
            # Convert numpy arrays to tensors for consistency
            images = torch.stack(
                [
                    torch.from_numpy(img) if isinstance(img, np.ndarray) else img
                    for img in images
                ]
            )
            masks = torch.stack(
                [
                    torch.from_numpy(mask) if isinstance(mask, np.ndarray) else mask
                    for mask in masks
                ]
            )

        return images, masks

In [91]:
dataset_tmp5 = Dataset_tmp5()
train_loader = DataLoader(dataset_tmp5, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True, collate_fn = dataset_tmp5.collate_fn)
time_dict["no_transform"] = {"data_io":0, "forward":0}
num_iterations = 50
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["no_transform"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["no_transform"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [02:03<00:00,  2.47s/it]


In [92]:
time_dict["no_transform"]

{'data_io': 1.69118703365326, 'forward': 0.7109530687332154}

## Batch Size = 1

In [94]:
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)

time_dict["batch_size_1"] = {"data_io":0, "forward":0}
num_iterations = 100*(batch_size//1)
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["batch_size_1"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["batch_size_1"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

  0%|          | 0/200 [00:00<?, ?it/s]

100%|██████████| 200/200 [05:04<00:00,  1.52s/it]


In [97]:
{k:v*(batch_size//1) for k,v in time_dict["batch_size_1"].items()}

{'data_io': 2.210755922794341, 'forward': 0.7570949435234068}

In [100]:
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=1, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)

time_dict["batch_size_4"] = {"data_io":0, "forward":0}
num_iterations = 100*batch_size//4
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["batch_size_4"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["batch_size_4"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

  0%|          | 0/50 [00:00<?, ?it/s]

100%|██████████| 50/50 [04:18<00:00,  5.17s/it]


In [102]:
{k:v*batch_size/4 for k,v in time_dict["batch_size_4"].items()}

{'data_io': 1.7883255410194396, 'forward': 0.7358332681655884}

In [None]:
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=4, drop_last=True, collate_fn = dataset_dict["train"].collate_fn)

time_dict["batch_size_4_nw4"] = {"data_io":0, "forward":0}
num_iterations = 100*batch_size//4
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["batch_size_4_nw4"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["batch_size_4_nw4"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

 98%|█████████▊| 49/50 [04:37<00:05,  5.80s/it]

In [None]:
{k:v*batch_size/4 for k,v in time_dict["batch_size_4_nw4"].items()}

In [25]:
dataset_tmp2 = Dataset_tmp2()
train_loader = DataLoader(dataset_tmp2, batch_size=4, shuffle=True, num_workers=4, drop_last=True, collate_fn = dataset_tmp2.collate_fn)
time_dict["local_test_batch_size_4"] = {"data_io":0, "forward":0}
num_iterations = 100*batch_size//4
for i in tqdm.tqdm(range(num_iterations)):
        start_time = time.time()
        data = next(iter(train_loader))
        time_dict["local_test_batch_size_4"]["data_io"] += (time.time() - start_time)/num_iterations

        _, labels = data
        labels = labels.long().to(meta_model.device)

        start_time = time.time()
        loss, outputs, concept_outputs = meta_model.batch_predict_loss(data)
        time_dict["local_test_batch_size_4"]["forward"] += (time.time() - start_time)/num_iterations
        
        del loss, outputs, concept_outputs
        torch.cuda.empty_cache()

100%|██████████| 50/50 [06:25<00:00,  7.72s/it]


In [26]:
{k:v*batch_size/4 for k,v in time_dict["local_test_batch_size_4"].items()}

{'data_io': 3.0653994894027714, 'forward': 0.7343459630012513}