In [1]:
import glob
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime

# Paths
SAVE_PATH = 'data/gridsat_cropped/'

# Constants
distance_km = 725  # Half the side of the box
resolution_km = 7.8  # Resolution of the data (rounded to 7.8 so that I get exactyl 128 pixels)
radius = int(np.round(distance_km / resolution_km))  # Number of grid points

ibtracs_format = "%d/%m/%Y %H:%M"
gridsat_format = "%Y.%m.%d.%H"

df = pd.read_csv('data/ibtracs/ibtracs_gridsat_train.csv')[:1000]

In [2]:
def nan_percent(data):
    return 100 - 100*np.count_nonzero(~np.isnan(subset.irwin_cdr))/np.prod(subset.irwin_cdr.shape)

In [3]:
ids_to_delete = []

for idx, row in tqdm(df.iterrows(), total=len(df)):
    gridsat_date = datetime.strptime(row.ISO_TIME, ibtracs_format).strftime(gridsat_format)
    with xr.open_dataset('data/gridsat.nc') as ds:
        subset = ds[['irwin_cdr', 'irwvp', 'vschn']].sel(lat=slice(None), lon=slice(None))

    gridsat_lats = subset.lat.values
    gridsat_lons = subset.lon.values

    # Adjust longitude values for circular continuity
    gridsat_lons_adjusted = np.where(gridsat_lons < 0, gridsat_lons + 360, gridsat_lons)

    # Center point
    lat_c = row.LAT
    lon_c = row.LON if row.LON >= 0 else row.LON + 360

    lat_index = np.argmin(np.abs(gridsat_lats - lat_c))
    lon_index = np.argmin(np.abs(gridsat_lons_adjusted - lon_c))

    first_lat = gridsat_lats[lat_index - radius]
    last_lat = gridsat_lats[lat_index + radius - 1]

    lon_indices = (lon_index - radius, lon_index + radius - 1)
    lon_indices = np.mod(lon_indices, len(gridsat_lons))  # Wrap-around using modulo
    lon_indices = sorted(lon_indices)

    if lon_indices[0] <= lon_indices[1]:
        # Normal case
        cropped_subset = subset.sel(lat=slice(first_lat, last_lat), lon=slice(gridsat_lons[lon_indices[0]], gridsat_lons[lon_indices[1]]))
    else:
        # Wrap-around case
        cropped_subset_part1 = subset.sel(lat=slice(first_lat, last_lat), lon=slice(gridsat_lons[lon_indices[0]], None))
        cropped_subset_part2 = subset.sel(lat=slice(first_lat, last_lat), lon=slice(None, gridsat_lons[lon_indices[1]]))
        cropped_subset = xr.concat([cropped_subset_part1, cropped_subset_part2], dim='lon')

    # Check that the cropped image has the right dimension
    assert cropped_subset.to_array().shape == (3, 1, 186, 186), f'WRONG SHAPE: {cropped_subset.irwin_cdr.shape}'

    # Check that no field of the cropped image has more than 10% missing data
    for variable in cropped_subset.variables:
        if nan_percent(variable) > 15:
            ids_to_delete.append(idx)
            break

    # Fill in any potential missing data using zeros
    cropped_subset = cropped_subset.fillna(0)
    
    # Save cropped image as .nc
    #cropped_subset.to_netcdf(SAVE_PATH + f'GRIDSAT-{row.IDX_TRUE}.nc')

  7%|▋         | 66/1000 [00:21<05:01,  3.10it/s]


AssertionError: WRONG SHAPE: (1, 186, 4959)