In [1]:
import numpy as np
import xarray as xr
import gc

In [2]:
ds = xr.open_dataset('/Odyssey/public/glorys/reanalysis/glorys12_2010_2019_daily_zos.nc', chunks=365)

In [3]:
year_list = ["2010", "2011", "2012", "2013", "2014", "2015", "2016", "2017", "2018", "2019"]
adt_list =[]
for i in year_list :
    adt_list.append(ds.rename({"zos" : "adt"}).sel(time = i))
    
del ds
gc.collect

<function gc.collect(generation=2)>

In [4]:
def makeAnomaly(input_ds : xr.Dataset) -> xr.Dataset :
    
    mean = input_ds.groupby("time.month").mean(dim="time")
    anomaly = (input_ds.groupby("time.month") - mean).rename({"adt" : "sla"}).load()
    del mean
    gc.collect
    
    return anomaly


def interpolate(input_ds : xr.Dataset) -> xr.Dataset :
    
    # for i in range (0,10) :
        
    #     chunck_size = input_ds.coords["time"].size 
    #     processing = input_ds.isel(time = slice(int(i / 10 * chunck_size), int((i + 1) / 10 * chunck_size)))
    #     chunck_interp = (processing.interp(
    #                                 coords=dict(
    #                                     latitude=np.arange(input_ds.latitude[0], input_ds.latitude[-1], float(1/8)),
    #                                     longitude=np.arange(input_ds.longitude[0], input_ds.longitude[-1], float(1/8))
    #                                 ),
    #                                 method="linear"
    #                             )
    #                     )
        
    #     interp_list.append(chunck_interp)
    #     del chunck_interp, processing
    #     gc.collect

    new_lat = np.arange(input_ds.latitude[0], input_ds.latitude[-1], float(1/8))
    new_lon = np.arange(input_ds.longitude[0], input_ds.longitude[-1], float(1/8))
    
    chunked = input_ds.chunk(365)    
    lat_interped = chunked.interp(latitude=new_lat, method="linear", assume_sorted=True)
    interpolated = lat_interped.interp(longitude = new_lon, method= "linear", assume_sorted=True)

                
    return interpolated.load()
        
def getNorm(input_ds : xr.Dataset) -> list[float] :
    chunked = input_ds.chunk(365)
    
    n = chunked.count()
    sum = chunked.sum(skipna=True)
    mu = (sum / n).load()
    var = ((chunked - mu)**2).sum(skipna=True) / n
    sigma = var.load()**0.5
    
    out_list = [mu, sigma]
    
    del n, sum, mu, var, sigma
    gc.collect
    
    return out_list


In [5]:
processed_list = []

for i in adt_list :

    a = makeAnomaly(i)
    b = interpolate(a)
    processed_list.append(b)
    
    del a, b
    gc.collect
    
del adt_list


In [6]:
processed_ds = xr.concat(processed_list, dim="time").drop_vars("month")
del processed_list
gc.collect

<function gc.collect(generation=2)>

In [7]:
processed_ds.to_netcdf("/Odyssey/public/glorys/reanalysis/glorys12_2010_2019_daily_sla_8th_test.nc")

In [8]:
from datetime import timedelta, datetime
start = processed_ds.time[0].dt.date
last = processed_ds.time[-1].dt.date
train_ds = processed_ds.sel(time = slice(start, last-timedelta(days=365)))
valid_ds = processed_ds.sel(time = slice(last-timedelta(days=365), last))

In [9]:
del processed_ds
gc.collect


<function gc.collect(generation=2)>

In [11]:

print(f"Training normalizer : {getNorm(train_ds)}")


Training normalizer : [<xarray.Dataset> Size: 8B
Dimensions:  ()
Data variables:
    sla      float64 8B -2.073e-20, <xarray.Dataset> Size: 8B
Dimensions:  ()
Data variables:
    sla      float64 8B 0.04391]


In [10]:

print(f"Validation normalizer : {getNorm(valid_ds)}")

Validation normalizer : [<xarray.Dataset> Size: 8B
Dimensions:  ()
Data variables:
    sla      float64 8B -3.198e-06, <xarray.Dataset> Size: 8B
Dimensions:  ()
Data variables:
    sla      float64 8B 0.04427]
