To implement the "Batch" or "Tiled" approach, you need to change your data access pattern. Instead of treating every 5x5 patch as an independent I/O request, you treat them as **sub-sections of a larger spatial block** that you load once and keep in memory.

This is highly effective because loading one $256 \times 256$ tile is often nearly as fast as loading one $5 \times 5$ patch: every request pays a fixed cost (a disk seek, a network round-trip to S3, reading and decompressing a whole storage chunk) that dwarfs the cost of transferring the extra bytes.

### 🏗️ The Strategy: Tile-and-Cache

To make this work in a PyTorch `Dataset`, follow these three steps:

#### 1. Group Points by Tile

In your `__init__`, you must divide your geographic extent into a grid of large tiles (e.g., $256 \times 256$ pixels). Assign every point in your metadata to a `tile_id`.

```python
# In __init__:
# Assume resolution is 100m
tile_size_px = 256
res = 100 
tile_size_m = tile_size_px * res

# Calculate which tile each point belongs to
self.metadata['tile_y'] = (self.metadata['northing'] // tile_size_m).astype(int)
self.metadata['tile_x'] = (self.metadata['easting'] // tile_size_m).astype(int)
self.metadata['tile_id'] = self.metadata['tile_y'].astype(str) + "_" + self.metadata['tile_x'].astype(str)

# IMPORTANT: Sort metadata by tile_id so sequential __getitem__ calls hit the same tile
self.metadata = self.metadata.sort_values('tile_id').reset_index(drop=True)
```

#### 2. Implement the Internal Cache

In `__getitem__`, you check if the requested point falls within the "Active Tile" already in memory. If it doesn't, you fetch the new large tile.

```python
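from torch.utils.data import Dataset

class VGDDataset(Dataset):
    def __init__(self, *args, **kwargs):
        # ... (other init code) ...
        # Tile geometry; must match the values used to assign tile_id
        self.tile_size_px = 256
        self.res = 100
        self.tile_size_m = self.tile_size_px * self.res
        self.current_tile_id = None
        self.tile_cache = {}  # Stores the 256x256 data in RAM

    def __getitem__(self, item_idx):
        entry = self.metadata.iloc[item_idx]
        target_tile = entry['tile_id']

        # Check if we need to load a new tile
        if target_tile != self.current_tile_id:
            self._load_tile_to_cache(entry['tile_y'], entry['tile_x'])
            self.current_tile_id = target_tile

        # Now, instead of going to disk/Dask, slice from self.tile_cache
        # px, py are local offsets within the 256x256 tile
        px = int((entry['easting'] % self.tile_size_m) // self.res)
        py_from_south = int((entry['northing'] % self.tile_size_m) // self.res)
        # Assumption: rows are stored north-to-south (the loader slices northing
        # from y_max down to y_min), so the row offset has to be flipped
        py = self.tile_size_px - 1 - py_from_south

        # Fast slice from RAM (NumPy/Torch)
        # NOTE: within 2 px of a tile edge, py-2 or px-2 goes negative (or py+3 /
        # px+3 overruns) and the slice comes back empty or truncated; such
        # points need a tile loaded with a 2-pixel margin.
        patch = self.tile_cache['dynamic'][:, py-2:py+3, px-2:px+3]

        return patch  # + static/target/etc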
```

#### 3. The Tile Loader

This function uses Xarray/Dask to pull one large chunk. Because it runs only once every ~100-500 samples (depending on point density), the Dask scheduler overhead becomes negligible.

```python
def _load_tile_to_cache(self, tile_y, tile_x):
    # Define bounds for the 256x256 area
    # (self.tile_size_m = tile_size_px * res, stored on self in __init__)
    y_min = tile_y * self.tile_size_m
    y_max = y_min + self.tile_size_m
    x_min = tile_x * self.tile_size_m
    x_max = x_min + self.tile_size_m
    
    # Single Xarray call for the whole block
    # Use .load() to bring it into actual RAM
    tile_ds = self.dynamic_data.sel(
        northing=slice(y_max, y_min), 
        easting=slice(x_min, x_max)
    ).load() 
    
    self.tile_cache['dynamic'] = tile_ds['variable_name'].values
```
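
The slice `py-2:py+3` breaks for points that sit within two pixels of a tile edge. One common remedy, sketched here as an assumption and not something the loader above does, is to load every tile with a 2-pixel margin (a "halo") by extending its bounds by `half * res` metres on each side. The index math then stays inside the array. The function name `patch_slices`, the `half` parameter, and the north-to-south row order are all illustrative assumptions.

```python
def patch_slices(easting, northing, tile_size_px=256, res=100, half=2):
    """Row and column slices of a (2*half+1)-pixel patch inside a tile that
    was loaded with `half` extra pixels on every side (a halo)."""
    tile_size_m = tile_size_px * res
    # Offsets inside the unpadded tile, counted from its west and south edges
    px = int((easting % tile_size_m) // res)
    py_from_south = int((northing % tile_size_m) // res)
    # Assumption: rows are stored north-to-south, so flip the row offset
    py = tile_size_px - 1 - py_from_south
    # In the padded tile, pixel (py, px) sits at (py + half, px + half),
    # so the patch starts at (py, px) and never runs past the array.
    size = 2 * half + 1
    return slice(py, py + size), slice(px, px + size)

# Toy padded tile: 260 x 260 nested lists standing in for a NumPy array
padded = [[(r, c) for c in range(260)] for r in range(260)]

# A point in the south-west corner pixel of its tile
rows, cols = patch_slices(easting=25_600.0, northing=51_200.0)
patch = [row[cols] for row in padded[rows]]
print(len(patch), len(patch[0]))  # 5 5
```

With a real NumPy tile of shape `(time, 260, 260)` the same slices would be used as `tile[:, rows, cols]`. The cost is a slightly larger read per tile, in exchange for never special-casing edge points.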

-----

### ⚠️ Critical Note on PyTorch Multi-processing

If you use `num_workers > 0` in your `DataLoader`, each worker process holds its own copy of the dataset and therefore its own `tile_cache`.

  * **The Problem:** If Worker 1 gets Sample A (Tile 1) and Worker 2 gets Sample B (also Tile 1), they both load the same tile, wasting RAM and I/O.
  * **The Fix:** Keep samples from the same tile adjacent. The `DataLoader` hands whole batches to the workers in turn, so with the metadata **sorted** by `tile_id` (step 1) and `shuffle=False`, each batch is drawn from one tile or a few neighbouring ones. Consecutive batches from the same tile can still land on different workers, so a tile may be loaded up to once per worker, but within each worker the cache hit rate stays very high.
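
To get a feel for how much duplicate loading multiple workers cause, here is a small pure-Python simulation. It assumes batches are assigned to workers round-robin and that each worker keeps only its most recent tile; `count_tile_loads` is a hypothetical helper, not part of PyTorch.

```python
def count_tile_loads(tile_ids, batch_size, num_workers):
    """Count tile loads when batches go to workers in turn and each
    worker keeps only its most recent tile in memory."""
    current = [None] * num_workers   # tile currently held by each worker
    loads = 0
    batches = [tile_ids[i:i + batch_size]
               for i in range(0, len(tile_ids), batch_size)]
    for b, batch in enumerate(batches):
        w = b % num_workers          # assumed round-robin assignment
        for tid in batch:
            if tid != current[w]:
                current[w] = tid
                loads += 1
    return loads

# 4 tiles with 64 samples each, already sorted by tile
sorted_tiles = [t for t in range(4) for _ in range(64)]
single = count_tile_loads(sorted_tiles, batch_size=32, num_workers=1)
double = count_tile_loads(sorted_tiles, batch_size=32, num_workers=2)
print(single, double)  # 4 8
```

In this toy setup two workers load every tile twice, which is still far fewer reads than one per sample. The duplication grows with the number of workers, not with the number of samples.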

### 🚀 Why this fixes your Dask Dashboard

Looking at your previous image, the "Task Stream" was full of tiny slivers. With this tiled approach:

1.  Instead of 10,000 tiny `getitem` tasks, Dask will see **one large task** (the 256x256 read).
2.  The "white space" (scheduler overhead) disappears because the CPU is busy slicing from RAM for the next 500 samples without asking Dask for anything.

**Would you like me to help you calculate the exact pixel-offset math for your specific EPSG:3035 coordinate system?**


You are absolutely right. Shuffling is the "enemy" of caching. If your `DataLoader` has `shuffle=True`, it will request a point in Italy, then a point in Sweden, then another in Italy, forcing your code to reload massive files constantly.

To fix this while keeping your training effective, you have two main strategies: **Custom Batch Sampling** or a **Multi-Tile LRU Cache**.

-----

### Strategy 1: The "Geographic" Batch Sampler (Most Efficient)

Instead of shuffling every single point globally, you shuffle **by tiles**. A batch (e.g., 32 samples) then comes from a single $256 \times 256$ area, or from a few consecutive tiles when a tile runs out of points mid-batch. This is a common practice in geospatial deep learning.

**How it works:**

1.  Group all metadata indices by their `tile_id`.
2.  Shuffle the list of tiles.
3.  For each tile, shuffle the points within it.
4.  Feed these sequences to the trainer.


```python
import math
import random

from torch.utils.data import Sampler

class TileBatchSampler(Sampler):
    def __init__(self, metadata, batch_size):
        self.batch_size = batch_size
        self.num_samples = len(metadata)
        # Group positional indices by tile_id: { 'tile_1': [0, 4, ...], 'tile_2': [1, 2, ...] }
        self.tile_groups = metadata.groupby('tile_id').indices
        self.tile_ids = list(self.tile_groups.keys())

    def __iter__(self):
        # 1. Shuffle the order of the tiles themselves
        random.shuffle(self.tile_ids)

        batch = []
        for tid in self.tile_ids:
            indices = self.tile_groups[tid].tolist()
            # 2. Shuffle points within this specific tile
            random.shuffle(indices)

            for idx in indices:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
        # Emit the final, possibly smaller, batch
        if len(batch) > 0:
            yield batch

    def __len__(self):
        # Batches are filled across tile boundaries, so count batches, not tiles
        return math.ceil(self.num_samples / self.batch_size)
```

**Usage:**

```python
sampler = TileBatchSampler(dataset.metadata, batch_size=32)
loader = DataLoader(dataset, batch_sampler=sampler) # Do NOT set shuffle=True here
```

-----

### Strategy 2: The Multi-Tile LRU Cache (Least Code Change)

If you must have global shuffling (picking any point in the country at any time), you need a **Least Recently Used (LRU) Cache**. Instead of keeping just one tile in memory, you keep the last 10 or 20.

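Be realistic about the hit rate, though. Under a global shuffle with points spread evenly over $T$ occupied tiles, a cache holding $N$ tiles hits only about $N / T$ of the time. The cache pays off when $N$ is a sizeable fraction of the occupied tiles, or when the sampling order is already spatially clustered.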

```python
from collections import OrderedDict

from torch.utils.data import Dataset

class VGDDataset(Dataset):
    def __init__(self, *args, **kwargs):
        # ... (other init code) ...
        self.cache = OrderedDict()
        self.max_cache_size = 15 # Adjust based on your RAM (e.g., 15 tiles of 256x256)

    def _get_tile(self, tile_id, tile_y, tile_x):
        if tile_id in self.cache:
            # Move to end (mark as recently used)
            self.cache.move_to_end(tile_id)
            return self.cache[tile_id]
        
        # Load new tile if not in cache
        tile_data = self._load_from_disk(tile_y, tile_x)
        
        # Add to cache
        self.cache[tile_id] = tile_data
        
        # If cache too big, remove oldest
        if len(self.cache) > self.max_cache_size:
            self.cache.popitem(last=False)
            
        return tile_data
```
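
How much an LRU cache helps depends heavily on the order in which tiles are requested. The toy simulation below uses the same `OrderedDict` logic as `_get_tile`, but stores a placeholder instead of real tile data so that only hits and misses are counted. The tile counts and the helper name `lru_hit_rate` are illustrative assumptions.

```python
import random
from collections import OrderedDict

def lru_hit_rate(tile_sequence, max_cache_size):
    """Fraction of requests served from an LRU cache of tile ids."""
    cache = OrderedDict()
    hits = 0
    for tid in tile_sequence:
        if tid in cache:
            cache.move_to_end(tid)         # mark as most recently used
            hits += 1
        else:
            cache[tid] = True              # stand-in for the loaded tile array
            if len(cache) > max_cache_size:
                cache.popitem(last=False)  # evict least recently used
    return hits / len(tile_sequence)

rng = random.Random(0)
num_tiles, points_per_tile = 200, 50
tiles = [t for t in range(num_tiles) for _ in range(points_per_tile)]

# Tile-sorted access: one miss per tile, everything else is a hit
sorted_rate = lru_hit_rate(tiles, max_cache_size=15)

# Globally shuffled access over the same points
shuffled = tiles[:]
rng.shuffle(shuffled)
shuffled_rate = lru_hit_rate(shuffled, max_cache_size=15)

print(f"sorted: {sorted_rate:.3f}  shuffled: {shuffled_rate:.3f}")
```

Running this with your own tile count and cache size gives a quick estimate of whether global shuffling is affordable before touching the real data pipeline.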

-----

### Which one should you choose?

| Feature | Geographic Batch Sampler | LRU Cache |
| :--- | :--- | :--- |
| **Speed** | 🚀 **Fastest**. Almost zero disk I/O during a batch. | 🐢 **Moderate**. Depends on point density. |
| **Randomness** | Good (shuffled tiles + shuffled points). | Perfect (global shuffle). |
| **RAM Usage** | Low (only needs 1 tile per worker). | High (needs $N$ tiles per worker). |
| **Implementation** | Requires a custom Sampler. | Easy logic inside Dataset. |

### 💡 Pro Tip for Geospatial ML

Spatially grouped batching along the lines of **Strategy 1** is a common choice. Fully global randomness is rarely necessary for convergence in Earth Observation models: as long as successive batches move from a "Forest" tile to a "City" tile to a "Desert" tile, the model still sees a varied mix over an epoch. The speedup (potentially 10x, depending on storage and point density) usually outweighs the minor loss in shuffle entropy.

**Would you like me to refine the pixel-coordinate math so the 5x5 slice is always perfectly centered within these tiles?**

1. The Critical Bottlenecks
#1: Static Data Sampling (Lines 275–276)
Time: 8.09e+10 (The largest single value in the report)

The Issue: ds.where(ds.notnull()).sel(...)

Why it's slow: You are calling .where(ds.notnull()) inside the __getitem__ method. This creates a masked copy of the entire dataset 17,830 times (once for every static variable for every sample). This is likely triggering expensive disk reads or computation of the mask repeatedly.

#2: Seismic Data Indexing (Line 228)
Time: 7.34e+09

The Issue: ds[var_name].isel(...).sel(...)

Why it's slow: Using .isel() and .sel() with xr.DataArray objects as indices for every hit is computationally expensive in xarray, especially if the data isn't loaded into memory.

#3: Metadata Access (Line 193)
Time: 4.56e+07

The Issue: entry = self.metadata.iloc[idx]

Why it's slow: While smaller than the xarray hits, iloc on a Pandas DataFrame builds a new Series object on every call, which adds up when it is called millions of times in a DataLoader loop.

2. Immediate Optimization Strategies
To speed this up significantly (potentially by 10x or more), consider these changes:

A. Pre-Process/Load into Memory
The biggest win is avoiding xarray overhead during training.

Load Static Data: If your static data fits in RAM, call .load() on your static datasets before passing them to the Dataset class.

Remove .where(notnull()): Instead of checking for NaNs on every __getitem__ call, handle the NaNs once during initialization or use a simpler np.nan_to_num after converting to a numpy array.
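
As a rough illustration of handling NaNs once up front, the sketch below fills a toy static layer a single time and leaves only slicing for the per-sample path. It assumes the layer fits in RAM as a NumPy array and that a fill value of 0.0 is acceptable for your model. Both are assumptions, as are the variable names.

```python
import numpy as np

# Toy static layer with missing values (stand-in for a loaded raster)
static = np.array([[1.0, np.nan, 3.0],
                   [np.nan, 5.0, 6.0]], dtype=np.float32)

# Done once at initialisation: remember where data was missing, then fill
valid_mask = ~np.isnan(static)
static_filled = np.nan_to_num(static, nan=0.0)

# Inside __getitem__ only cheap slicing remains
patch = static_filled[0:2, 0:2]
print(patch)
```

Keeping `valid_mask` around lets you feed the model a "was missing" channel later if the fill value turns out to matter.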

B. Convert to NumPy/Torch Early
Xarray's .sel() and .isel() provide great flexibility but have high overhead for deep learning.

Convert your coordinates (lat/lon) to integer indices once during __init__.

Inside __getitem__, use standard NumPy-style slicing on the underlying .values or .data array rather than using labeled indexing.
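
A minimal sketch of the coordinate-to-index conversion, assuming a regularly spaced axis. `origin` (coordinate of the first pixel centre) and `res` (signed step between pixels) are hypothetical names that you would read off your own grid; an irregular axis would need a nearest-neighbour search instead.

```python
def coord_to_index(coord, origin, res):
    """Index of the pixel whose centre is nearest to `coord` on a regular axis.

    origin: coordinate of the first pixel centre
    res:    signed step between neighbouring pixels (negative if descending)
    """
    return int(round((coord - origin) / res))

# Latitude usually runs north-to-south, so its step is negative
lat_idx = coord_to_index(46.995, origin=47.0, res=-0.001)
lon_idx = coord_to_index(6.0123, origin=6.0, res=0.001)

# Done once in __init__ for every sample; __getitem__ then uses plain ints
points = [(46.995, 6.0123), (46.990, 6.0200)]
indices = [(coord_to_index(lat, 47.0, -0.001), coord_to_index(lon, 6.0, 0.001))
           for lat, lon in points]
print(lat_idx, lon_idx, indices)
```

With the indices stored, a 5x5 patch becomes `array[row - 2:row + 3, col - 2:col + 3]` on the underlying NumPy array, with no labeled lookup involved.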

C. Optimize Metadata Access
Convert your self.metadata DataFrame into a dictionary of arrays or a NumPy structured array in __init__.

Example:
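
One way this could look, as a sketch: the DataFrame is converted to one NumPy array per column a single time, and each lookup becomes plain array indexing. The toy DataFrame and the helper `get_entry` are illustrative stand-ins for `self.metadata` and the body of `__getitem__`.

```python
import numpy as np
import pandas as pd

# Toy stand-in for self.metadata (column names are illustrative)
metadata = pd.DataFrame({
    "easting": [4321000.0, 4321100.0, 4400000.0],
    "northing": [2500000.0, 2500100.0, 2600000.0],
    "tile_id": ["97_168", "97_168", "101_171"],
})

# Done once, e.g. in __init__: one NumPy array per column
columns = {name: metadata[name].to_numpy() for name in metadata.columns}

def get_entry(idx):
    """Plain array lookups instead of metadata.iloc[idx]."""
    return {name: values[idx] for name, values in columns.items()}

entry = get_entry(1)
print(entry["easting"], entry["tile_id"])
```

If only two or three columns are needed per sample, indexing those arrays directly avoids even the small dictionary that `get_entry` builds.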

D. Avoid Redundant Conversions
Line 256 & 299: You are calling torch.tensor(sampled) inside a loop. If possible, convert the entire dataset to a Torch tensor or a memory-mapped NumPy array once so you only perform a slice operation during training.

In [1]:
import xarray as xr
import geopandas as gpd
from rasterio.features import rasterize
from rasterio.transform import from_bounds
import rioxarray as rxr
import numpy as np
import rasterio
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
lithology = gpd.read_file(r"C:\Users\gmfet\Desktop\predictors\lithology\litology_italy.gpkg")
output_path = r'C:\Users\gmfet\Desktop\collaboration_with_Awais\lithology.tif'

# Get the bounds of the file
xmin, ymin, xmax, ymax = lithology.total_bounds
res = 0.001  # pixel size

# width = int((xmax - xmin) / res)
height = int((ymax - ymin) / res)
width = height
transform = from_bounds(xmin, ymin, xmax, ymax, width, height)

shapes = ((geom, value) for geom, value in zip(lithology.geometry, lithology["cat"]))

with rasterio.open(output_path, 'w', driver='GTiff', width=width, height=height, count=1, crs=lithology.crs, transform=transform, dtype='uint8') as geo_pckg:
    geo_pckg.write(rasterize(shapes, out_shape=(height, width), fill=0, transform=transform, dtype='uint8'), 1)


In [2]:
# italy shapefile 
shp_path = r"C:\Users\gmfet\vgd_italy\italy_aoi\gadm41_ITA_0.shp"

shp_file = gpd.read_file(shp_path)

shp_file.crs

<Geographic 2D CRS: EPSG:4326>
Name: WGS 84
Axis Info [ellipsoidal]:
- Lat[north]: Geodetic latitude (degree)
- Lon[east]: Geodetic longitude (degree)
Area of Use:
- name: World.
- bounds: (-180.0, -90.0, 180.0, 90.0)
Datum: World Geodetic System 1984 ensemble
- Ellipsoid: WGS 84
- Prime Meridian: Greenwich

## Static features
Save the categorical variables as uint8 dtype


In [None]:
import os
import xarray as xr


filepath = r'C:\Users\gmfet\vgd_italy\regression\original_data'

file_list = os.listdir(filepath)
for file in file_list:
    if file.endswith('.tif'):
        engine='rasterio'
    elif file.endswith('.nc'):
        engine = 'h5netcdf'
    else:
        continue
    
    print(file)
    
    with xr.open_dataset(os.path.join(filepath, file), engine=engine) as data_ds:
        rename_dict = {'x': 'longitude', 'y': 'latitude', 'lon': 'longitude', 'lat': 'latitude', 'band':'time'}
        data_ds = data_ds.rename({k: v for k, v in rename_dict.items() if k in data_ds.dims})
        print(data_ds.dims)

In [None]:
with xr.open_dataset(r"C:\Users\gmfet\vgd_italy\regression\original_data\genua.nc", engine='netcdf4') as ds:
    print(ds.chunks)  # .chunks reports chunk sizes; .chunk is the rechunking method

In [None]:
ds['ssm'].sel(time='2022-10-01').plot()

In [24]:
# with rxr.open_rasterio(r"C:\Users\gmfet\Desktop\predictors\LULC\*.tif", chunks=1000) as ds:
#     ds
with xr.open_mfdataset(r"C:\Users\gmfet\vgd_italy\regression\original_data\ksat.nc", engine='netcdf4') as ds:
    print(ds)
    ds = ds.rio.write_crs("EPSG:4326", inplace=False)


    # ds['band_data'] = ds['band_data'].rio.write_nodata(np.nan)
    # ds = ds.astype('float32')
    # ds = ds.rio.clip(shp_file.geometry, shp_file.crs, drop=True, all_touched=True)
    
    data_attrs = ds.attrs.copy()
    coord_attrs = {c: ds[c].attrs.copy() for c in ds.coords}
    # ds = ds.squeeze("band", drop=True)
    # # # ds = ds.drop_vars("spatial_ref")
    ds = ds.rename({"lon": "longitude", "lat": "latitude", 'Band1': 'ksat'})
    ds = ds.chunk({'latitude': 256, 'longitude': 256})

    


    # ds = ds.rename({"x": "longitude", "y": "latitude", 'band': 'time', 'band_data': 'drought_code'})
    # ds['time'] = pd.date_range('2017-01-01', periods=ds.sizes['time'], freq='D')
    # ds = ds.chunk({'time': -1})
    # ds = ds.interpolate_na(
    #     dim="time", 
    #     method="polynomial", 
    #     order=3,
    #     fill_value="extrapolate"
        
    # )

    # ds = ds.chunk({'time': -1, 'latitude': 256, 'longitude': 256})

    # # # # # # ds = ds.to_dataset(name='clay_content')
    ds.attrs.update(data_attrs)
    ds = ds.astype("float32")

    # restore coordinate attributes
    for c in ["longitude", "latitude"]:
        if c in coord_attrs:
            ds[c].attrs.update(coord_attrs[c])

    # # # ds = ds.dropna('latitude')

    # # # ds = ds.rio.clip(shp_file.geometry, shp_file.crs, drop=True, all_touched=True)
            
    ds.to_netcdf("ksat.nc", engine="h5netcdf", encoding= {'ksat': {"dtype": "float32", "zlib": True, "complevel": 0, "compression": "gzip", "compression_opts": 0}})
    

<xarray.Dataset> Size: 104MB
Dimensions:  (lon: 7200, lat: 3600)
Coordinates:
  * lon      (lon) float64 58kB -180.0 -179.9 -179.9 ... 179.9 179.9 180.0
  * lat      (lat) float64 29kB 89.97 89.92 89.88 89.82 ... -89.88 -89.92 -89.97
Data variables:
    crs      int32 4B ...
    Band1    (lat, lon) float32 104MB dask.array<chunksize=(3600, 7200), meta=np.ndarray>
Attributes:
    CDI:                        Climate Data Interface version 1.9.8 (https:/...
    Conventions:                CF-1.5
    history:                    Wed Dec 16 16:42:50 2020: cdo -mulc,10. ksat3...
    GDAL_AREA_OR_POINT:         Area
    GDAL:                       GDAL 3.0.4, released 2020/01/28
    history_of_appended_files:  Fri Dec 11 16:53:00 2020: Appended file /huge...
    NCO:                        netCDF Operators version 4.9.2 (Homepage = ht...
    CDO:                        Climate Data Operators version 1.9.8 (https:/...


## Dynamic features

In [None]:
# with rxr.open_rasterio(r"C:\Users\gmfet\Desktop\predictors\LULC\*.tif", chunks=1000) as ds:
#     ds
# from dask.distributed import Client
# client = Client() # This will give you a link to a dashboard (usually http://localhost:8787)

encoding = {
    "temperature": {
        "dtype": "float32",        # Store as 32-bit floats
        "scale_factor": 0.01,   # Values are divided by 0.01 on write and rescaled on read
        # "_FillValue": -9999,    # Map NaNs to -9999
        "zlib": True,           # Enable compression
        "complevel": 0,         # 0 disables compression; 1-9 trade speed for size
        # "shuffle": True,        # Better compression efficiency
        # "chunksizes": (1, 180, 360) # Optimize for time-series access
    }
}

# ds.to_netcdf("output.nc", encoding=encoding)
with xr.open_mfdataset(r"C:\Users\gmfet\vgd_italy\data\dynamic\temperature.tif", engine='rasterio') as ds:
    ds
    ds['band_data'] = ds['band_data'].rio.write_nodata(np.nan)
    # ds = ds.chunk({'x': 1024, 'y': 1024})
    ds = ds.astype('float32')
    # # # # # ds = ds.rio.clip(shp_file.geometry, shp_file.crs, drop=True, all_touched=True)
    
    # ds = ds.fillna(0).astype("float32")
    data_attrs = ds.attrs.copy()
    coord_attrs = {c: ds[c].attrs.copy() for c in ds.coords}
    # # # ds = ds.drop_vars("spatial_ref")
    # The data variable must be named 'temperature' to match the key in `encoding`
    ds = ds.rename({"x": "longitude", "y": "latitude", 'band': 'time', 'band_data': 'temperature'})
    ds['time'] = pd.date_range('2017-01-01', periods=ds.sizes['time'], freq='D')
    # ds = ds.chunk({'longitude': -1, 'latitude': 1024, 'time': 512})

    # ds = ds.interpolate_na(
    #     dim="longitude", 
    #     method="linear", 
    #     fill_value="extrapolate"
        
    # )

    ds = ds.chunk({'time': -1, 'latitude': 1024, 'longitude': 1024,})

    # ds = ds.interpolate_na(
    #     dim="time", 
    #     method="linear", 
    #     fill_value="extrapolate"
        
    # )

    
    

    # # # # # # # ds = ds.to_dataset(name='bulk_density')
    ds.attrs.update(data_attrs)
    ds = ds.astype("float32")



    

    # restore coordinate attributes
    for c in ["longitude", "latitude"]:
        if c in coord_attrs:
            ds[c].attrs.update(coord_attrs[c])

    # # # # ds = ds.dropna('latitude')

    # # # # ds = ds.rio.clip(shp_file.geometry, shp_file.crs, drop=True, all_touched=True)
    
    ds.to_netcdf("temperature.nc", engine="h5netcdf", encoding= encoding)
    

In [None]:
# ds = ds.astype("uint8")
print(ds)

In [None]:
ds['temperature'].sel(time='2018-08-01').plot()
# plt.show()

In [None]:

np.unique(ds['drought_code'].sel(time='2017-01-01').values)