In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import xarray as xr
from torch.utils.data import random_split
import pytorch_lightning as pl
from typing import Dict, List, Union
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision.models.segmentation.deeplabv3 import ASPP
from torchvision.models.resnet import resnet50
from torchvision.models._utils import IntermediateLayerGetter
from pytorch_lightning.loggers import TensorBoardLogger

In [78]:
class HakeXarrayDatasets(Dataset):
    def __init__(self,
                 data: Union[List[str],Dict[str, str]],
                 desired_order: List[str] = ["120 kHz", "38 kHz", "18 kHz"],
                 slice_length: int=1000,
                 overlap: int=500):
        self.desired_order = desired_order
        self.slice_length = slice_length
        self.overlap = overlap

        if isinstance(data, list):
            self.is_predict = True
            self.file_dict = {}
            self.load_from_data_dir_list(data)
        elif isinstance(data, dict):
            self.is_predict = False
            self.file_dict = {}
            self.load_from_target_data_dict(data)
        else:
            raise ValueError("Invalid input type for 'data'. It should be either a list or a dictionary.")
        
    def generate_overlapping_slices(self, ds):
        ds_slices = [slice(len(ds["ping_time"]) - self.slice_length, len(ds["ping_time"]))]
        start_idx = 0
        while start_idx + self.slice_length <= len(ds["ping_time"]):
            end_idx = start_idx + self.slice_length
            ds_slices.append(slice(start_idx, end_idx))
            start_idx += self.slice_length - self.overlap

        return ds_slices

    def load_from_target_data_dict(self, target_data_dir_dict):
        counter = 0
        for mask_dir, ds_MVBS_dir in target_data_dir_dict.items():
            ds_MVBS_dir = Path(ds_MVBS_dir)
            mask_dir = Path(mask_dir)
            ds_MVBS_files = list(ds_MVBS_dir.glob('*.zarr'))
            for ds_MVBS_file in tqdm(ds_MVBS_files):
                ds_MVBS_filename = ds_MVBS_file.name
                mask_file = self.find_matching_target(ds_MVBS_filename, mask_dir)
                if mask_file is not None:
                    if self.check_ds_MVBS_file(ds_MVBS_file):
                        ds_MVBS = xr.open_dataset(ds_MVBS_file)
                        ds_slices = self.generate_overlapping_slices(ds_MVBS)
                        for ds_slice in ds_slices:
                            self.file_dict[counter] = {"ds_MVBS_file": str(ds_MVBS_file),
                                                    "mask_file": str(mask_file),
                                                    "ping_time_slice": ds_slice}
                            counter += 1

    def load_from_data_dir_list(self, ds_MVBS_dirs):
        counter = 0
        for ds_MVBS_dir in ds_MVBS_dirs:
            ds_MVBS_dir = Path(ds_MVBS_dir)
            unchecked_ds_MVBS_files = list(ds_MVBS_dir.glob('*.zarr'))
            for ds_MVBS_file in tqdm(unchecked_ds_MVBS_files):
                if self.check_ds_MVBS_file(ds_MVBS_file):
                    ds_MVBS = xr.open_dataset(ds_MVBS_file)
                    ds_slices = self.generate_overlapping_slices(ds_MVBS)
                    for ds_slice in ds_slices:
                        self.file_dict[counter] = {"ds_MVBS_file": str(ds_MVBS_file),
                                                   "ping_time_slice": ds_slice}
                        counter += 1

    def check_ds_MVBS_file(self, ds_MVBS_file):
        try:
            data_ds = xr.open_dataset(ds_MVBS_file)
            data_channels = data_ds["Sv"].channel.values
            return all(any(partial_name in ch for ch in data_channels) for partial_name in self.desired_order)
        except KeyError as e:
            print(f"KeyError: {e}, File Paths: {ds_MVBS_file}")
            return False

    def find_matching_target(self, ds_MVBS_filename, mask_dir):
        mask_files = list(mask_dir.glob('*.zarr'))
        for mask_file in mask_files:
            if mask_file.name == ds_MVBS_filename:
                return mask_file
        return None

    def __len__(self):
        if self.file_dict:
            return len(self.file_dict)
        else:
            return 0

    def __getitem__(self, idx):
        if self.file_dict is not None and self.is_predict:
            file_dict_idx = self.file_dict[idx]

            da_MVBS_tensor = torch.tensor(xr.open_dataset(
                file_dict_idx["ds_MVBS_file"]
            )["Sv"].isel(ping_time=file_dict_idx["ping_time_slice"]
            ).transpose("channel", "depth", "ping_time").data, dtype=torch.float32)

            return da_MVBS_tensor, file_dict_idx

        elif self.file_dict is not None and not self.is_predict:
            file_dict_idx = self.file_dict[idx]

            da_MVBS_tensor = torch.tensor(xr.open_dataset(
                file_dict_idx["ds_MVBS_file"]
            )["Sv"].isel(ping_time=file_dict_idx["ping_time_slice"]
            ).transpose("channel", "depth", "ping_time").data, dtype=torch.float32)

            mask_tensor = torch.tensor(xr.open_dataset(
                file_dict_idx["mask_file"]
            )["mask"].isel(ping_time=file_dict_idx["ping_time_slice"]
            ).transpose("depth", "ping_time").data, dtype=torch.float32)

            print(xr.open_dataset(
                file_dict_idx["mask_file"]
            )["mask"].isel(ping_time=file_dict_idx["ping_time_slice"]
            )["ping_time"])

            return da_MVBS_tensor, mask_tensor, file_dict_idx
        else:
            raise IndexError("Index out of range. No ds_MVBSfiles available.")

In [79]:
# Example usage dict
data_mask_dir_dict = {'/home/exouser/hake_data/Sv_mask/hake_clean/2007': '/home/exouser/hake_data/Sv_regridded/2007',
                    '/home/exouser/hake_data/Sv_mask/hake_clean/2009': '/home/exouser/hake_data/Sv_regridded/2009'}

# Create an instance of HakeXarrayDatasets
hake_dataset = HakeXarrayDatasets(data_mask_dir_dict)

# Print the length of the dataset
print("Dataset Length:", len(hake_dataset))

data_dim = None
target_dim = None

# Loop through the dataset and print some samples
for idx in range(len(hake_dataset)):
    data, target, file_dict = hake_dataset[idx]

    print("Data Shape:", data.shape)
    print("Target Shape:", target.shape)
    print(file_dict)

    if idx == 2:
        break

100%|██████████| 214/214 [00:01<00:00, 186.75it/s]
100%|██████████| 166/166 [00:00<00:00, 284.75it/s]


Dataset Length: 536
<xarray.DataArray 'ping_time' (ping_time: 1000)>
array(['2007-08-17T18:29:05.000000000', '2007-08-17T18:29:10.000000000',
       '2007-08-17T18:29:15.000000000', ..., '2007-08-17T19:52:10.000000000',
       '2007-08-17T19:52:15.000000000', '2007-08-17T19:52:20.000000000'],
      dtype='datetime64[ns]')
Coordinates:
  * ping_time  (ping_time) datetime64[ns] 2007-08-17T18:29:05 ... 2007-08-17T...
Attributes:
    axis:           T
    long_name:      Ping time
    standard_name:  time
Data Shape: torch.Size([4, 3795, 1000])
Target Shape: torch.Size([3795, 1000])
{'ds_MVBS_file': '/home/exouser/hake_data/Sv_regridded/2007/x0116_0_wt_20070817_170834_f0006.zarr', 'mask_file': '/home/exouser/hake_data/Sv_mask/hake_clean/2007/x0116_0_wt_20070817_170834_f0006.zarr', 'ping_time_slice': slice(967, 1967, None)}
<xarray.DataArray 'ping_time' (ping_time: 1000)>
array(['2007-08-17T17:08:30.000000000', '2007-08-17T17:08:35.000000000',
       '2007-08-17T17:08:40.000000000', ..., '2

In [55]:
# Example usage ds_MVBS_dir_list
ds_MVBS_dir_list =['/home/exouser/hake_data/Sv_regridded/2007', '/home/exouser/hake_data/Sv_regridded/2009']

# Create an instance of HakeXarrayDatasets
hake_dataset = HakeXarrayDatasets(ds_MVBS_dir_list)

# Print the length of the dataset
print("Dataset Length:", len(hake_dataset))

# Loop through the dataset and print some samples
for idx in range(len(hake_dataset)):
    data, file_dict = hake_dataset[idx]

    print("Data Shape:", data.shape)
    print(file_dict)

    if idx == 2:
        break

100%|██████████| 214/214 [00:02<00:00, 90.30it/s]
100%|██████████| 166/166 [00:02<00:00, 78.75it/s]


Dataset Length: 3321
Data Shape: torch.Size([4, 3795, 240])
{'ds_MVBS_file': '/home/exouser/hake_data/Sv_regridded/2007/x0116_0_wt_20070817_170834_f0006.zarr', 'ping_time_slice': slice(1727, 1967, None)}
Data Shape: torch.Size([4, 3795, 240])
{'ds_MVBS_file': '/home/exouser/hake_data/Sv_regridded/2007/x0116_0_wt_20070817_170834_f0006.zarr', 'ping_time_slice': slice(0, 240, None)}
Data Shape: torch.Size([4, 3795, 240])
{'ds_MVBS_file': '/home/exouser/hake_data/Sv_regridded/2007/x0116_0_wt_20070817_170834_f0006.zarr', 'ping_time_slice': slice(160, 400, None)}


In [58]:
class HakeDataModule(pl.LightningDataModule):
    def __init__(self, training_target_data_dir_dict: dict, testing_target_data_dir_dict: dict,
                 pred_dir: dict, batch_size=32, num_workers=4, validation_split=0.1):
        super(HakeDataModule, self).__init__()
        self.training_target_data_dir_dict = training_target_data_dir_dict
        self.testing_target_data_dir_dict = testing_target_data_dir_dict
        self.pred_dir = pred_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.validation_split = validation_split

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            full_train_dataset = HakeXarrayDatasets(self.training_target_data_dir_dict)
            train_size = int((1 - self.validation_split) * len(full_train_dataset))
            val_size = len(full_train_dataset) - train_size
            self.train_dataset, self.val_dataset = random_split(full_train_dataset, [train_size, val_size])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        test_dataset = HakeXarrayDatasets(self.testing_target_data_dir_dict)
        return DataLoader(test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def predict_dataloader(self):
        pred_dataset = HakeXarrayDatasets(self.pred_dir)
        return DataLoader(pred_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

In [59]:
training_target_data_dir_dict = {'/home/exouser/hake_data/Sv_mask/hake_clean/2007': '/home/exouser/hake_data/Sv_regridded/2007',
                    '/home/exouser/hake_data/Sv_mask/hake_clean/2009': '/home/exouser/hake_data/Sv_regridded/2009'}
testing_target_data_dir_dict = {'/home/exouser/hake_data/Sv_mask/hake_clean/20011': '/home/exouser/hake_data/Sv_regridded/20011',
                    '/home/exouser/hake_data/Sv_mask/hake_clean/2013': '/home/exouser/hake_data/Sv_regridded/2013'}
pred_dir = ['/home/exouser/hake_data/Sv_regridded/2017', '/home/exouser/hake_data/Sv_regridded/2019']

data_module = HakeDataModule(
    training_target_data_dir_dict=training_target_data_dir_dict,
    testing_target_data_dir_dict=testing_target_data_dir_dict,
    pred_dir=pred_dir,
    batch_size=32,
    num_workers=4,
    validation_split=0.2
)

data_module.setup()

100%|██████████| 214/214 [00:01<00:00, 187.38it/s]
100%|██████████| 166/166 [00:00<00:00, 286.49it/s]


In [62]:
predict_dataloader = data_module.predict_dataloader()
for batch in predict_dataloader:
    print(batch)

100%|██████████| 176/176 [00:01<00:00, 102.64it/s]
100%|██████████| 167/167 [00:01<00:00, 125.68it/s]


TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 127, in collate
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 127, in <dictcomp>
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 150, in collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'slice'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 142, in collate
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 142, in <listcomp>
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 130, in collate
    return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 130, in <dictcomp>
    return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
  File "/home/exouser/miniforge3/envs/hake_ml_poc/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 150, in collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'slice'>
