# Import packages, define helper function

In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import dask.array as da
import gcsfs
import os

import deepsensor.torch
from deepsensor.data import DataProcessor, TaskLoader, construct_circ_time_ds
from deepsensor.data.sources import get_era5_reanalysis_data, get_earthenv_auxiliary_data, \
    get_gldas_land_mask
from deepsensor.model import ConvNP
from deepsensor.train import set_gpu_default_device

In [2]:
def standardize_dates(ds):
    """
    Convert the 'time' dimension in an xarray dataset to date-only precision with datetime64[D].
    
    Parameters:
    ds (xarray.Dataset): The dataset whose 'time' dimension you wish to modify.
    
    Returns:
    xarray.Dataset: Modified dataset with time as datetime64[D].
    """
    if 'time' in ds.coords:
        # Convert time to day-level precision
        ds['time'] = ds['time'].dt.floor('D').values.astype('datetime64[D]')
    
    return ds

# Data Inventory and Preprocessing

### Temporal datasets: SST, ice concentration

In [3]:
# Path to your Zarr stores
ice_concentration_path = 'gs://great-lakes-osd/ice_concentration.zarr'
glsea_path = 'gs://great-lakes-osd/GLSEA_combined.zarr'
glsea3_path = 'gs://great-lakes-osd/GLSEA3_combined.zarr'

# Open the Zarr stores
ice_concentration_raw = xr.open_zarr(ice_concentration_path, chunks={'time': 366, 'lat': 200, 'lon': 200})
glsea_raw = xr.open_zarr(glsea_path, chunks={'time': 366, 'lat': 200, 'lon': 200})
glsea3_raw = xr.open_zarr(glsea3_path, chunks={'time': 366, 'lat': 200, 'lon': 200})

# Replace -1 (land value) with NaN
ice_concentration_raw = ice_concentration_raw.where(ice_concentration_raw != -1, float('nan'))

# Convert all times to date-only format, removing the time component
ice_concentration_raw = standardize_dates(ice_concentration_raw)
glsea_raw = standardize_dates(glsea_raw)
glsea3_raw = standardize_dates(glsea3_raw)

# Drop CRS - not needed
glsea_raw = glsea_raw.drop_vars('crs')
glsea3_raw = glsea3_raw.drop_vars('crs')

### Static datasets: lake mask, bathymetry

In [4]:
# Set up GCS filesystem
fs = gcsfs.GCSFileSystem(project='your-gcp-project')

# Path to the NetCDF files 
context_path = 'gs://great-lakes-osd/context/'

# Open the NetCDF files using xarray with gcsfs
bathymetry_raw = xr.open_dataset(fs.open(os.path.join(context_path, 'interpolated_bathymetry.nc')))
lakemask_raw = xr.open_dataset(fs.open(os.path.join(context_path, 'lakemask.nc')))

# Name the bathymetry variable
bathymetry_raw = bathymetry_raw.rename({'__xarray_dataarray_variable__': 'bathymetry'})

# Data Processor

In [5]:
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
print(data_processor)

DataProcessor with normalisation params:
{'coords': {'time': {'name': 'time'},
            'x1': {'map': None, 'name': 'lat'},
            'x2': {'map': None, 'name': 'lon'}}}


In [17]:
#glsea = data_processor(glsea_raw)
_ = data_processor(glsea_raw.sel(time=slice("2009-01-01", "2009-12-31")))
glsea = data_processor(glsea_raw)

# process the bathymetry and lake
bathymetry, lakemask = data_processor([bathymetry_raw, lakemask_raw], method="min_max")

In [18]:
# Select a subset of the ice concentration data to compute normalization parameters
_ = data_processor(ice_concentration_raw.sel(time=slice("2009-01-01", "2009-01-31")))

# Now apply the normalization parameters to the full ice concentration dataset
#ice_concentration = data_processor(ice_concentration_raw, method="min_max")

import dask.array as da

# Open the Zarr file with dask chunks
ice_concentration_raw = xr.open_zarr(ice_concentration_path, chunks={'time': 366, 'lat': 200, 'lon': 200})

def process_chunk(chunk):
    # Calculate min and max for the chunk
    min_value = chunk.min()
    max_value = chunk.max()
    
    # Check if the min and max are the same (no variation)
    if min_value == max_value:
        return chunk  # or set it to a constant value, if needed
    
    # Otherwise, apply normalization using the DataProcessor
    return data_processor(chunk, method="min_max")

# Provide a template to map_blocks
template = ice_concentration_raw.isel(time=0)  # Take a single slice (e.g., the first time step) to use as the template

# Process the chunks of the ice concentration data lazily using Dask
normalized_ice_concentration = ice_concentration_raw.map_blocks(process_chunk, template=template)

# Compute the result (this will trigger the actual computation)
normalized_ice_concentration_computed = normalized_ice_concentration.compute()

# Now `normalized_ice_concentration_computed` will hold the normalized data


ValueError: could not broadcast input array from shape (134,200,200) into shape (134,200)

<xarray.Dataset> Size: 8MB
Dimensions:            (lat: 1024, lon: 1024)
Coordinates:
  * lat                (lat) float64 8kB 50.6 50.59 50.58 ... 38.9 38.89 38.87
  * lon                (lon) float64 8kB -92.41 -92.39 -92.38 ... -75.89 -75.87
    time               datetime64[ns] 8B 1972-12-01
Data variables:
    ice_concentration  (lat, lon) float64 8MB dask.array<chunksize=(200, 200), meta=np.ndarray>
Attributes: (12/23)
    coverage_area:            Great Lakes
    data_source:              NOAA
    description:              Great Lakes ice concentrations
    disclaimer:               Data collected and processed by NOAA and dissem...
    dissemination:            USNIC Website, CIS Website
    grid_resolution:          1.800 km
    ...                       ...
    product:                  GRID - Resolution 1800
    source:                   NAIS daily Great Lakes ice analysis
    source_url:               https://noaadata.apps.nsidc.org/NOAA/G10029/
    spatial_extent:         

AttributeError: 'NoneType' object has no attribute 'min'

In [None]:
data_processor.config

In [None]:
dates = pd.date_range(glsea.time.values.min(), glsea.time.values.max(), freq="D")

In [None]:
doy_ds = construct_circ_time_ds(dates, freq="D")
cosD = standardize_dates(doy_ds["cos_D"])
sinD = standardize_dates(doy_ds["sin_D"])

In [None]:
sinD

# Sanity checks after being run through the data_processor

### Sanity check: ice concentration

In [None]:
ice_concentration

In [None]:
# Load the dataset
ds = ice_concentration

# Select a single time slice to plot, e.g., the first time point
time_index = 20
time_slice = ds.isel(time=time_index)

# Plotting
plt.figure(figsize=(10, 6))
ice_conc_plot = time_slice.ice_concentration.plot(
    x='x2', 
    y='x1', 
    cmap='Blues',
    robust=True  # Automatically exclude extreme outliers from color scaling
)
plt.title(f"Ice Concentration on {str(time_slice.time.values)}")
plt.show()

### Sanity check: GLSEA

In [None]:
glsea

In [None]:
# Load the dataset
ds = glsea

# Select a single time slice to plot, e.g., the first time point
time_index = 20
time_slice = ds.isel(time=time_index)

# Plotting
plt.figure(figsize=(10, 6))
ice_conc_plot = time_slice.sst.plot(
    x='x2', 
    y='x1', 
    cmap='cividis',
    robust=True  # Automatically exclude extreme outliers from color scaling
)
plt.title(f"GLSEA on {str(time_slice.time.values)}")
plt.show()

### Sanity check: bathymetry

In [None]:
bathymetry

In [None]:
# Load the dataset
ds = bathymetry

# Plotting
plt.figure(figsize=(10, 6))
ice_conc_plot = ds.bathymetry.plot(
    x='x2', 
    y='x1', 
    cmap='cividis',
    robust=True  # Automatically exclude extreme outliers from color scaling
)
plt.title(f"GLSEA on {str(time_slice.time.values)}")
plt.show()

### Sanity check: lake mask

In [None]:
lakemask

In [None]:
# Plot lakemask
lakemask_mask = lakemask['mask']

# Plot the data
plt.figure(figsize=(10, 8))
lakemask_mask.plot(cmap='Blues', add_colorbar=True)
plt.title('Mask from lakemask.nc')
plt.show()

# Tasks

## Generating random coordinates from inside the mask

In [None]:
def generate_random_coordinates(mask_da, N, data_processor=None):
    """
    Generate N random coordinates (lat, lon) from a mask with values 1 inside the lake area,
    and normalize them using the DataProcessor if provided.
    
    Parameters:
    mask_da: xarray DataArray containing the mask (with 1 for valid, 0 for invalid areas)
    N: Number of random points to generate
    data_processor: (optional) DataProcessor object for normalization if needed
    
    Returns:
    numpy.ndarray: Array of shape (2, N) with random latitudes and longitudes from the masked region
    """
    
    # Get the valid indices where the mask is 1
    mask = mask_da['mask'].values
    valid_indices = np.argwhere(mask == 1)
    
    # Randomly sample N points from the valid indices
    random_indices = valid_indices[np.random.choice(valid_indices.shape[0], N, replace=False)]
    
    # Get the latitude and longitude coordinates for the sampled indices
    latitudes = mask_da['lat'].values[random_indices[:, 0]]
    longitudes = mask_da['lon'].values[random_indices[:, 1]]
    
    # Create a dummy variable (e.g., zeros for now)
    dummy_variable = np.random.rand(N)
    
    # Create a Pandas DataFrame with latitudes, longitudes, and the dummy variable
    random_coords_df = pd.DataFrame({
        'lat': latitudes,
        'lon': longitudes,
        'dummy': dummy_variable
    })
    
    # Set the index to ['lat', 'lon'] to match what DataProcessor expects
    random_coords_df = random_coords_df.set_index(['lat', 'lon'])
    
    if data_processor:
        normalized_coords_df = data_processor(random_coords_df, method="min_max")
        return normalized_coords_df.index.to_frame(index=False).values.T
    else:
        return np.vstack((latitudes, longitudes))


In [None]:
# Example usage
N = 100  # Number of random points
random_lake_points = generate_random_coordinates(lakemask_raw, N, data_processor)
random_lake_points

In [None]:
import matplotlib.pyplot as plt

# Assuming random_coords is the (2, N) array from the previous step
latitudes = random_lake_points[0, :]
longitudes = random_lake_points[1, :]

# Create a scatter plot
plt.figure(figsize=(8, 6))
plt.scatter(longitudes, latitudes, color='blue', alpha=0.5, s=10)
plt.title("Scatter plot of N Random Coordinates within Lake Mask")
plt.xlabel("Longitude")
plt.ylabel("Latitude")
plt.show()


In [None]:
from deepsensor.data import TaskLoader
task_loader = TaskLoader(context=[glsea, ice_concentration, bathymetry, lakemask], target=glsea)

In [None]:
task_loader.context

In [None]:
task = task_loader("2011-08-16T00:00:00", context_sampling=random_lake_points, target_sampling="all")

In [None]:
task

In [None]:
fig = deepsensor.plot.task(task, task_loader)
plt.show()

## Attempt task with points sampled from lakes only

In [None]:
task = task_loader("2011-08-16T00:00:00", context_sampling=random_lake_points, target_sampling="all")

In [None]:
task

In [None]:
fig = deepsensor.plot.task(task, task_loader)
plt.show()