# Test notebook for Radially Constrained Clustering (RCC)

In [None]:
import xarray as xr
import rioxarray
import matplotlib.pyplot as plt
import geopandas as gpd
from shapely.geometry import mapping
from tqdm import tqdm
import numpy as np
import pandas as pd

#from clustering import *
from radially_constrained_cluster import *

In [None]:
italy_region = gpd.read_file('/work/users/jgrassi/data/italy.shp')


def plot_map(geo_df, column_to_color, cmap, vmin, vmax, title, boundaries=None):
    """Draw a choropleth of ``geo_df[column_to_color]`` with region outlines on top.

    Parameters
    ----------
    geo_df : GeoDataFrame
        Polygons to fill.
    column_to_color : str
        Column of ``geo_df`` whose values set the fill color.
    cmap : str or Colormap
        Matplotlib colormap shared by the fill and the colorbar.
    vmin, vmax : float
        Color-scale limits shared by the fill and the colorbar.
    title : str
        Axes title.
    boundaries : GeoDataFrame, optional
        Shapes whose outlines are drawn over the map. Defaults to the
        module-level ``italy_region`` read above.

    Returns
    -------
    tuple
        ``(ax, fig, cbar)`` -- the axes, the figure and the colorbar,
        in that order.
    """
    if boundaries is None:
        boundaries = italy_region

    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    # legend=False: a colorbar with the same cmap/limits is added manually below.
    geo_df.plot(column=column_to_color, cmap=cmap, ax=ax, legend=False, vmin=vmin, vmax=vmax)

    # Region outlines drawn on top of the filled polygons.
    boundaries.boundary.plot(color='k', linewidth=0.5, label='Regioni', ax=ax)

    # Axes cosmetics.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Longitude [°E]')
    ax.set_ylabel('Latitude [°N]')
    ax.grid(alpha=0.3)

    # Stand-alone mappable so the colorbar uses exactly the map's cmap/limits.
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=vmin, vmax=vmax))
    sm.set_array([])  # public API instead of poking the private `_A` attribute
    cbar = fig.colorbar(sm, ax=ax)

    ax.set_title(title)

    return ax, fig, cbar

In [None]:
data_path = '/work/users/jgrassi/data/ERA5'

# One smoothed Dataset per variable; later cells rely on this order
# (temperature first, precipitation second).
# NOTE(review): the `*_19*.nc` globs only match file names with "19" right
# after the prefix; if files are named by year, data for 2000+ is not loaded
# even though the last clustering period requests 2000 -- confirm.
data_obj = []

for pattern in ['2m_temperature/temperature_19*.nc', 'total_precipitation/precipitation_19*.nc']:

    # The context manager closes the netCDF handles once the result is in memory.
    with xr.open_mfdataset(f'{data_path}/{pattern}') as ds:
        # 'noleap' drops 29 Feb so every year has 365 days; then a centered
        # moving mean over 15*24 time steps (15 days if the data are hourly
        # -- TODO confirm the time step).
        obj = ds.convert_calendar('noleap').rolling(time=15*24, center=True).mean().load()

    obj.rio.set_spatial_dims(x_dim="longitude", y_dim="latitude", inplace=True)
    obj.rio.write_crs("epsg:4326", inplace=True)

    data_obj.append(obj)

In [None]:
# Re-read so this cell can run on its own.
italy_region = gpd.read_file('/work/users/jgrassi/data/italy.shp')

# (start_year, end_year) windows; each window contributes its own block of
# features. Longer alternative:
# periods = [['1979','1982'],['1982','1985'],['1985','1988'],['1988','1991'],['1991','1994'],['1994','1997'],['1997','2000'],['2000','2003'],['2003','2006'],['2006','2009']]
# NOTE(review): string slices are inclusive at both ends, so consecutive
# windows share a year (e.g. 1982 falls in the first two) -- confirm intended.
periods = [['1979','1982'],['1982','1985'],['1985','1988'],['1988','1991'],['1991','1994'],['1994','1997'],['1997','2000']]

# Variable read from each entry of `data_obj`, in the same order as loaded.
# (A third name such as 'sfcWind' only makes sense once a third dataset is loaded.)
variables = ['t2m', 'tp']

region_frames = []
all_arrays = []

for j in tqdm(range(len(italy_region))):

    # Single-row GeoDataFrame: keeps the region's attributes and its CRS.
    gdf = italy_region.iloc[[j]].copy()
    shapes = gdf.geometry.apply(mapping)

    arrays = []

    for obj, var in zip(data_obj, variables):

        # The region mask does not depend on the period: clip once.
        obj_clip = obj.rio.clip(shapes, gdf.crs, drop=True)

        for start, end in periods:

            # Mean annual cycle over the window (grouped by day of year).
            obj_rast = obj_clip.sel(time=slice(start, end)).groupby('time.dayofyear').mean()

            # (day, lat, lon) -> (day, pixel); drop pixels the clip masked out.
            obj_array = obj_rast[var].to_numpy()
            obj_array = obj_array.reshape(obj_array.shape[0], -1)
            obj_array = obj_array[:, ~np.all(np.isnan(obj_array), axis=0)]

            # Z-score each pixel's annual cycle so variables are comparable.
            obj_array = (obj_array - obj_array.mean(axis=0)) / obj_array.std(axis=0)
            arrays.append(obj_array)

    # Put all variables/periods side by side: rows = days, columns = features.
    array_tot = np.concatenate(arrays, axis=1)
    print(array_tot.shape)

    model = Radially_Constrained_Cluster(data_to_cluster = array_tot,
                                     n_seas = 4,
                                     n_iter = 5000,
                                     learning_rate = 2,
                                     min_len = 60,
                                     scheduling_factor = 1,
                                     # day-of-year of 1 Mar, 1 Jun, 1 Sep, 1 Dec (365-day year)
                                     starting_bp = [60, 152, 244, 335],
                                     mode = 'single')

    model.fit()

    # Learning curve, titled by region (not by the leftover loop period).
    plt.figure(figsize = (7,3))
    plt.plot(model.error_history)
    plt.grid()
    plt.title(f'Region {j}')
    plt.xlabel('N° iterations')
    plt.ylabel('Within-Cluster Sum of Square')

    bp = model.breakpoints
    n_days = array_tot.shape[0]

    gdf['spring_onset'] = bp[0]
    gdf['summer_onset'] = bp[1]
    gdf['autumn_onset'] = bp[2]
    gdf['winter_onset'] = bp[3]

    # Column names keep the 'lenght' spelling: later cells select them by name.
    gdf['spring_lenght'] = bp[1] - bp[0]
    gdf['summer_lenght'] = bp[2] - bp[1]
    gdf['autumn_lenght'] = bp[3] - bp[2]
    gdf['winter_lenght'] = n_days - bp[3] + bp[0]  # wraps around the year end

    region_frames.append(gdf)
    all_arrays.append(array_tot)

tot = pd.concat(region_frames)

In [None]:
# Quick look at each region's clustering input: rows are days, columns features.
for region_matrix in all_arrays:
    mat_fig, mat_ax = plt.subplots()
    mat_img = mat_ax.imshow(region_matrix)
    mat_fig.colorbar(mat_img, ax=mat_ax)

In [None]:
# Month-start ticks for the onset colorbars.
# NOTE(review): 0 and 31 match 0-based month starts, while 60, 91, ... match
# 1-based day-of-year in a 365-day year -- at most one day off; confirm which
# convention model.breakpoints uses.
month_ticks = [0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 365]
month_labels = ['jan', 'feb', 'mar', 'apr', 'may', 'jun', 'jul', 'aug', 'sep', 'oct', 'nov', 'dec', 'jan']

for onset, title, cmap in zip(['spring_onset', 'summer_onset', 'autumn_onset', 'winter_onset'],
                              ['Spring onset', 'Summer onset', 'Autumn onset', 'Winter onset'],
                              ['spring', 'summer', 'autumn', 'winter']):

    # plot_map returns (ax, fig, cbar) in that order.
    ax, fig, cbar = plot_map(tot, onset, cmap, tot[onset].min(), tot[onset].max(), title)
    cbar.set_ticks(ticks=month_ticks, labels=month_labels)



In [None]:
# Season-length maps. Column names keep the 'lenght' spelling used when they
# were created; only the displayed titles are spelled correctly.
for column, title, cmap in zip(['spring_lenght', 'summer_lenght', 'autumn_lenght', 'winter_lenght'],
                               ['Spring length', 'Summer length', 'Autumn length', 'Winter length'],
                               ['Greens', 'Purples', 'Oranges', 'Blues']):

    # plot_map returns (ax, fig, cbar) in that order.
    ax, fig, cbar = plot_map(tot, column, cmap, tot[column].min(), tot[column].max(), title)


In [None]:
# Display the per-region table of season onsets and lengths (last expression of the cell).
tot