In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import xarray as xr
from torch.utils.ds_MVBSimport random_split
import pytorch_lightning as pl
from typing import Dict, List, Union
from pathlib import Path
from tqdm import tqdm
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torchvision.models.segmentation.deeplabv3 import ASPP
from torchvision.models.resnet import resnet50
from torchvision.models._utils import IntermediateLayerGetter
from pytorch_lightning.loggers import TensorBoardLogger

In [45]:
class HakeXarrayDatasets(Dataset):
    def __init__(self,
                 data: Union[List[str],Dict[str, str]],
                 desired_order: List[str] = ["120 kHz", "38 kHz", "18 kHz"],
                 time_span: str="10min",
                 max_depth: float=800):
        self.desired_order = desired_order
        self.time_span = time_span
        self.max_depth = max_depth

        if isinstance(data, list):
            self.is_predict = True
            self.file_dict = {}
            self.load_from_data_dir_list(data)
        elif isinstance(data, dict):
            self.is_predict = False
            self.file_dict = {}
            self.load_from_target_data_dict(data)
        else:
            raise ValueError("Invalid input type for 'data'. It should be either a list or a dictionary.")

    def load_from_target_data_dict(self, target_data_dir_dict):
        counter = 0
        for mask_dir, ds_MVBS_dir in target_data_dir_dict.items():
            ds_MVBS_dir = Path(ds_MVBS_dir)
            mask_dir = Path(mask_dir)

            # Get all .zarr files in the ds_MVBS directory
            ds_MVBS_files = list(ds_MVBS_dir.glob('*.zarr'))

            for ds_MVBS_file in tqdm(ds_MVBS_files[:30]):
                # Extract filename without considering the path
                ds_MVBS_filename = ds_MVBS_file.name

                # Find a matching target file by filename in the target directory
                mask_file = self.find_matching_target(ds_MVBS_filename, mask_dir)

                if mask_file is not None:
                    if self.check_ds_MVBS_file(ds_MVBS_file):
                        ds_MVBS = xr.open_dataset(ds_MVBS_file)
                        resampled_obj = ds_MVBS["ping_time"].resample(
                            ping_time=self.time_span, skipna=True
                        )
                        for _,ds_MVBS_ping_time_slice in resampled_obj.groups.items():
                            self.file_dict[counter] = {"ds_MVBS_file": ds_MVBS_file,
                                                    "mask_file": mask_file,
                                                    "ping_time_slice": ds_MVBS_ping_time_slice}
                            counter += 1

    def load_from_data_dir_list(self, ds_MVBS_dirs):
        counter = 0
        for ds_MVBS_dir in ds_MVBS_dirs:
            ds_MVBS_dir = Path(ds_MVBS_dir)

            # Get all .zarr files in the ds_MVBSdirectory
            unchecked_ds_MVBS_files = list(ds_MVBS_dir.glob('*.zarr'))

            for ds_MVBS_file in tqdm(unchecked_ds_MVBS_files[:30]):
                if self.check_ds_MVBS_file(ds_MVBS_file):
                    ds_MVBS = xr.open_dataset(ds_MVBS_file)
                    resampled_obj = ds_MVBS["ping_time"].resample(
                        ping_time=self.time_span, skipna=True
                    )
                    for _, ping_time_slice in resampled_obj.groups.items():
                        self.file_dict[counter] = {"ds_MVBS_file": ds_MVBS_file,
                                                   "ping_time_slice": ping_time_slice}
                        counter += 1

    def check_ds_MVBS_file(self, ds_MVBS_file):
        try:
            # Open datasets and check for variables and desired channels
            data_ds = xr.open_dataset(ds_MVBS_file)
            data_channels = data_ds["Sv"].channel.values

            return all(any(partial_name in ch for ch in data_channels) for partial_name in self.desired_order)

        except KeyError as e:
            # Print KeyError and file paths
            print(f"KeyError: {e}, File Paths: {ds_MVBS_file}")
            return False

    def find_matching_target(self, ds_MVBS_filename, mask_dir):
        mask_files = list(mask_dir.glob('*.zarr'))

        for mask_file in mask_files:
            if mask_file.name == ds_MVBS_filename:
                return mask_file

        return None

    def __len__(self):
        if self.file_dict:
            return len(self.file_dict)
        else:
            return 0

    def __getitem__(self, idx):
        if self.file_dict is not None and self.is_predict:
            file_dict_idx = self.file_dict[idx]
            ds_MVBS_partition = xr.open_dataset(
                file_dict_idx["ds_MVBS_file"]
            ).isel(ping_time=file_dict_idx["ping_time_slice"]
            ).sel(depth=slice(0, self.max_depth))

            return ds_MVBS_partition

        elif self.file_dict is not None and not self.is_predict:
            file_dict_idx = self.file_dict[idx]
            ds_MVBS_partition = xr.open_dataset(
                file_dict_idx["ds_MVBS_file"]
            ).isel(ping_time=file_dict_idx["ping_time_slice"]
            ).sel(depth=slice(0, self.max_depth))

            # Mask Processing
            mask_partition = xr.open_dataset(
                file_dict_idx["mask_file"]
            )["mask"].isel(ping_time=file_dict_idx["ping_time_slice"]
            ).sel(depth=slice(0, self.max_depth)
            ).transpose("depth", "ping_time")

            return ds_MVBS_partition, mask_partition
        else:
            raise IndexError("Index out of range. No ds_MVBSfiles available.")

In [54]:
# Example usage dict
data_mask_dir_dict = {'/home/exouser/hake_data/Sv_mask/hake_clean/2007': '/home/exouser/hake_data/Sv_regridded/2007',
                    '/home/exouser/hake_data/Sv_mask/hake_clean/2009': '/home/exouser/hake_data/Sv_regridded/2009'}

# Create an instance of HakeXarrayDatasets
hake_dataset = HakeXarrayDatasets(data_mask_dir_dict, max_depth = 400)

# Print the length of the dataset
print("Dataset Length:", len(hake_dataset))

# Loop through the dataset and print some samples
for idx in range(len(hake_dataset)):
    data, target = hake_dataset[idx]

    print(data.coords)
    print(target.coords)

    if idx == 2:
        break

100%|██████████| 30/30 [00:00<00:00, 184.52it/s]
100%|██████████| 30/30 [00:00<00:00, 305.24it/s]

Dataset Length: 356
Coordinates:
  * channel    (channel) <U35 'GPT  18 kHz 009072034d55 3 ES18-11' ... 'GPT 2...
  * depth      (depth) float64 0.0 0.2 0.4 0.6 0.8 ... 399.4 399.6 399.8 400.0
  * ping_time  (ping_time) datetime64[ns] 2007-08-17T17:08:30 ... 2007-08-17T...
Coordinates:
  * depth      (depth) float64 0.0 0.2 0.4 0.6 0.8 ... 399.4 399.6 399.8 400.0
  * ping_time  (ping_time) datetime64[ns] 2007-08-17T17:08:30 ... 2007-08-17T...
Coordinates:
  * channel    (channel) <U35 'GPT  18 kHz 009072034d55 3 ES18-11' ... 'GPT 2...
  * depth      (depth) float64 0.0 0.2 0.4 0.6 0.8 ... 399.4 399.6 399.8 400.0
  * ping_time  (ping_time) datetime64[ns] 2007-08-17T17:10:00 ... 2007-08-17T...
Coordinates:
  * depth      (depth) float64 0.0 0.2 0.4 0.6 0.8 ... 399.4 399.6 399.8 400.0
  * ping_time  (ping_time) datetime64[ns] 2007-08-17T17:10:00 ... 2007-08-17T...
Coordinates:
  * channel    (channel) <U35 'GPT  18 kHz 009072034d55 3 ES18-11' ... 'GPT 2...
  * depth      (depth) float64 0




In [48]:
# Example usage ds_MVBS_dir_list
ds_MVBS_dir_list =['/home/exouser/hake_data/Sv_regridded/2007', '/home/exouser/hake_data/Sv_regridded/2009']

# Create an instance of HakeXarrayDatasets
hake_dataset = HakeXarrayDatasets(ds_MVBS_dir_list)

# Print the length of the dataset
print("Dataset Length:", len(hake_dataset))

# Loop through the dataset and print some samples
for idx in range(len(hake_dataset)):
    data = hake_dataset[idx]

    print(data)

    if idx == 2:
        break

100%|██████████| 30/30 [00:00<00:00, 99.49it/s]
100%|██████████| 30/30 [00:00<00:00, 102.88it/s]


Dataset Length: 765
<xarray.Dataset>
Dimensions:            (channel: 4, ping_time: 18, depth: 3795)
Coordinates:
  * channel            (channel) <U35 'GPT  18 kHz 009072034d55 3 ES18-11' .....
  * depth              (depth) float64 0.0 0.2 0.4 0.6 ... 758.4 758.6 758.8
  * ping_time          (ping_time) datetime64[ns] 2007-08-17T17:08:30 ... 200...
Data variables:
    Sv                 (channel, ping_time, depth) float64 ...
    frequency_nominal  (channel) float64 ...
    latitude           (ping_time) float64 ...
    longitude          (ping_time) float64 ...
Attributes:
    processing_function:          commongrid.compute_MVBS
    processing_level:             Level 3A
    processing_level_url:         https://echopype.readthedocs.io/en/stable/p...
    processing_software_name:     echopype
    processing_software_version:  0.8.3.dev1+g87fd9af
    processing_time:              2023-11-28T02:13:39Z
<xarray.Dataset>
Dimensions:            (channel: 4, ping_time: 120, depth: 3795)
C

In [16]:
ds_MVBS = data

# Partitioning and regridding
resampled_obj = ds_MVBS["ping_time"].resample(
    ping_time="10min", skipna=True
)

In [18]:
for pt_index, (k, v) in enumerate(
    tqdm(resampled_obj.groups.items(), desc="Processing partitions")
):
    print(pt_index)
    print(k, v)

Processing partitions: 100%|██████████| 24/24 [00:00<00:00, 57325.34it/s]

0
2007-08-01T23:50:00.000000000 slice(0, 4, None)
1
2007-08-02T00:00:00.000000000 slice(4, 124, None)
2
2007-08-02T00:10:00.000000000 slice(124, 244, None)
3
2007-08-02T00:20:00.000000000 slice(244, 364, None)
4
2007-08-02T00:30:00.000000000 slice(364, 484, None)
5
2007-08-02T00:40:00.000000000 slice(484, 604, None)
6
2007-08-02T00:50:00.000000000 slice(604, 724, None)
7
2007-08-02T01:00:00.000000000 slice(724, 844, None)
8
2007-08-02T01:10:00.000000000 slice(844, 964, None)
9
2007-08-02T01:20:00.000000000 slice(964, 1084, None)
10
2007-08-02T01:30:00.000000000 slice(1084, 1204, None)
11
2007-08-02T01:40:00.000000000 slice(1204, 1324, None)
12
2007-08-02T01:50:00.000000000 slice(1324, 1444, None)
13
2007-08-02T02:00:00.000000000 slice(1444, 1564, None)
14
2007-08-02T02:10:00.000000000 slice(1564, 1684, None)
15
2007-08-02T02:20:00.000000000 slice(1684, 1804, None)
16
2007-08-02T02:30:00.000000000 slice(1804, 1924, None)
17
2007-08-02T02:40:00.000000000 slice(1924, 2044, None)
18
2007-0


