### GPyTorch reprex

This notebook builds a minimal reproducible example (reprex) for fitting a Gaussian process (GP) model to elevation data. It proceeds in three steps:

 - Use a `torch.utils.data.DataLoader` to pull one chunk at a time from the Zarr dataset
 - Apply a transformation function that turns the chunk into a `torch.Tensor`
 - Feed the `Tensor` into a GPyTorch model

### DataLoader

Adapted from [this PyTorch forum thread](https://discuss.pytorch.org/t/dataloader-parallelization-synchronization-with-zarr-xarray-dask/176149).

In [None]:
%pip install torch tqdm

In [21]:
import numpy as np
import xarray as xr
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torch
import math

In [4]:
# Open the dataset
baker_url = 's3://petrichor/geosmart/baker.zarr/'
baker_ds = xr.open_dataset(
    baker_url, chunks='auto', engine='zarr', storage_options={"anon": True}
)

In [5]:
baker_ds

| | Array | Chunk |
|---|---|---|
| Bytes | 4.57 GiB | 95.51 MiB |
| Shape | (55, 5901, 3779) | (55, 843, 540) |
| Dask graph | 49 chunks in 2 graph layers | 49 chunks in 2 graph layers |
| Data type | float32 numpy.ndarray | float32 numpy.ndarray |


In [12]:
chunkdict = baker_ds.chunks

In [18]:
chunkdict

Frozen({'time': (55,), 'y': (843, 843, 843, 843, 843, 843, 843), 'x': (540, 540, 540, 540, 540, 540, 539)})
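The chunk layout above fixes both the dataset length and the mapping from a linear index to a block. Here is a minimal sketch in plain Python, assuming blocks are ordered row-major (last dimension varies fastest). `block_slices` is a hypothetical helper, not part of xarray or dask, and the chunk sizes are copied from the output above.

```python
import math
from itertools import accumulate

# Chunk sizes copied from the output above
chunks = {
    "time": (55,),
    "y": (843,) * 7,
    "x": (540,) * 6 + (539,),
}


def block_slices(chunks, idx):
    """Map a linear chunk index to one slice per dimension (row-major order)."""
    numblocks = [len(sizes) for sizes in chunks.values()]
    if not 0 <= idx < math.prod(numblocks):
        raise IndexError(idx)
    slices = {}
    # Peel off the last dimension first so that it varies fastest
    for (dim, sizes), n in reversed(list(zip(chunks.items(), numblocks))):
        idx, block = divmod(idx, n)
        stops = list(accumulate(sizes))
        slices[dim] = slice(stops[block] - sizes[block], stops[block])
    return slices


# One chunk in time, seven in y, seven in x
num_chunks = math.prod(len(sizes) for sizes in chunks.values())
first = block_slices(chunks, 0)
last = block_slices(chunks, num_chunks - 1)
```

The resulting dictionary has the form that `baker_ds.isel(...)` accepts, so it could select one chunk-aligned region without going through dask's block API.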

In [37]:
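class DaskChunkDataset(Dataset):
    """Map-style dataset that returns one dask chunk per index."""

    def __init__(self, dask_array):
        super().__init__()
        self.dask_array = dask_array
        # numblocks holds the number of chunks along each dimension, e.g. (1, 7, 7)
        self.num_chunks = math.prod(dask_array.numblocks)

    def __len__(self):
        return self.num_chunks

    def __getitem__(self, idx):
        # Convert the linear index to a block index along every dimension
        block_idx = np.unravel_index(idx, self.dask_array.numblocks)
        block_idx = tuple(int(i) for i in block_idx)
        # Fetch the chunk at that block index and load it into memory
        chunk = self.dask_array.blocks[block_idx].compute()
        return torch.from_numpy(chunk)

# Create an instance of our dataset. The class needs a dask array, not an
# xarray Dataset, so pass the array behind a data variable.
# Assumption: the elevation data is the first data variable in baker_ds.
elevation = next(iter(baker_ds.data_vars.values()))
dataset = DaskChunkDataset(elevation.data)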
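A chunk dataset like this would typically be wrapped in a `DataLoader`. One thing to watch for: the last chunk along `x` is 539 columns wide rather than 540, so default batching, which stacks items into one tensor, would fail on mixed shapes. The sketch below uses a hypothetical `ToyChunkDataset` with small stand-in shapes and passes `batch_size=None` to disable automatic batching, so each iteration yields exactly one chunk.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class ToyChunkDataset(Dataset):
    """Hypothetical stand-in: each item is a tensor shaped like one chunk."""

    def __init__(self, shapes):
        self.shapes = shapes

    def __len__(self):
        return len(self.shapes)

    def __getitem__(self, idx):
        return torch.zeros(self.shapes[idx])


# Small shapes that mimic an interior chunk and a narrower edge chunk
shapes = [(2, 4, 3), (2, 4, 2)]

# batch_size=None turns off automatic batching: one item per iteration
loader = DataLoader(ToyChunkDataset(shapes), batch_size=None, shuffle=False)
loaded_shapes = [tuple(chunk.shape) for chunk in loader]
```

If batches of several chunks are wanted later, a custom `collate_fn` that pads or crops the edge chunks would be one option.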

In [40]:
baker_ds.blocks

AttributeError: 'Dataset' object has no attribute 'blocks'
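The error arises because `.blocks` is an attribute of dask arrays, not of xarray `Dataset` objects. A data variable's underlying dask array is reachable through `.data`. The sketch below shows this on a small synthetic dataset; the variable name `elevation` is an assumption for illustration, not the name used in `baker.zarr`.

```python
import dask.array as da
import xarray as xr

# Synthetic stand-in: 1 x 2 x 2 chunks over (time, y, x)
arr = da.zeros((2, 6, 4), chunks=(2, 3, 2), dtype="float32")
ds = xr.Dataset({"elevation": (("time", "y", "x"), arr)})

# The Dataset itself has no block indexer
has_blocks = hasattr(ds, "blocks")

# The dask array behind a data variable does
dask_arr = ds["elevation"].data
numblocks = dask_arr.numblocks        # number of chunks per dimension
block = dask_arr.blocks[0, 1, 0]      # still lazy: a one-chunk dask array
values = block.compute()              # NumPy array for that chunk
```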