# Spatial and temporal resampling of gridded ECMWF data onto a satellite's swath

To retrieve physical parameters of the atmosphere via Optimal Estimation, we need a priori knowledge of the quantity we want to retrieve; this a priori knowledge can be obtained from a forecast model.

ECMWF's forecasts are available at different spatial and temporal resolutions. In this notebook we work with a regular *0.25/0.25* *lon/lat* grid in **space** and, in **time**, 16 forecast steps for each of the 2 analysis times per day; the *time*/*step* combination gives hourly **temporal** coverage:
*time* is the reference (analysis) time at which observations are used to update the model, whereas *step* is the forecast lead time, i.e. the offset from that reference time along the temporal evolution of the forecast.

Once the ECMWF data are available at a given spatial/temporal resolution, we turn to the satellite observations, specifically to *where* these observations are located (i.e. at which longitude and latitude). Because the satellite does not sample on a regular grid, we rely on its swath definition: how the instrument *samples* in space and time.

Different instruments (and data providers) may describe the swath in different ways (although, happily, the concept of a swath is the same for all); in our case we use CMSAF [data](https://wui.cmsaf.eu/safira/action/viewDoiDetails?acronym=FCDR_MWI_V003) (i.e. brightness temperatures).
Each sample (observation) is located at a specific longitude/latitude pair, so our goal is to interpolate the ECMWF variable (defined on the regular *lon/lat* grid) onto the specific locations in the satellite's swath.


In this notebook we use [Pyresample](https://pyresample.readthedocs.io/en/latest/)'s functionality to efficiently resample data from a regular grid onto a swath. 

In [None]:
import sys
import os

import xarray as xr

import cf_xarray as cfxr
import xesmf as xe

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import seaborn as sns
import cartopy
import cartopy.crs as ccrs
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER
import pyproj
import pyresample
from pyresample import create_area_def, load_area, data_reduce, utils, AreaDefinition
from pyresample.geometry import SwathDefinition, GridDefinition
from pyresample.kd_tree import resample_nearest, resample_gauss 
from pyresample.bilinear import XArrayBilinearResampler #NumpyBilinearResampler #

from sklearn import svm
from sklearn.linear_model import SGDOneClassSVM
from sklearn.model_selection import train_test_split
from sklearn.kernel_approximation import Nystroem
from sklearn.pipeline import make_pipeline


%matplotlib inline
#%matplotlib notebook

In [None]:
#sys.path.append('/home/mario/Documents/Coursera\
#/Unsupervised/week1/Labs/Lab2/Files/home/jovyan/work')
#from utils import *
# Record the installed pyresample version (the geometry/kd_tree/bilinear API
# used below is version dependent).
print(pyresample.__version__)          

In [None]:
# Set before the cluster is created so that the worker processes inherit it.
# NOTE(review): presumably meant to make glibc return freed memory to the OS
# (Dask's suggested workaround for growing unmanaged memory) — confirm.
os.environ["MALLOC_TRIM_THRESHOLD_"] = "0"#"65536"

from dask.distributed import Client, progress, LocalCluster

# Local Dask cluster; the Client registers itself as the default scheduler,
# so the lazily-chunked datasets opened below are computed on it.
cluster = LocalCluster()
client = Client(cluster)
client

In [None]:
# Machine-specific input locations: adjust to your environment.
# Satellite data:
#dataSatDir = '/home/mario/Data/CMSAF/ssims/F16/'
dataSatDir = '/home/mario/Data/CMSAF/ssims/F16/ORD47662/'
#dataSatDir = '/nobackup/users/echeverr/data/cmsaf/ssmis/F16/'
# Single-file ID; not used by the live code (the notebook reads '*.nc').
fileSatID = 'BTRin20140909000000324SSF1601GL.nc'

# ECMWF data:
#dataECMWFDir ='/home/mario/Data/Covariance_means/MARS_api_data/datasetsApriori/'
#dataECMWFDir = '/nobackup/users/echeverr/data/ECMWF_era5/MARS_api_data/datasetsApriori/'
dataECMWFDir ='/home/mario/Data/Covariance_means/MARS_api_data/datasetsAprioriRegGrid/'

#Test only:
# Only used to borrow the native ECMWF 0.25 deg lon/lat axes.

auxDataECMWFDir = '/home/mario/Data/Covariance_means/MARS_api_data/ERA5_data/datasets/'


In [None]:
# Open all profile and surface GRIB files lazily (cfgrib engine) and merge
# them into a single working dataset.
profile_info = xr.open_mfdataset(dataECMWFDir+'profiles*.grib', 
                                 engine="cfgrib") #, chunks={'time': 50, 'latitude': 50, 'longitude': 200})
surface_info = xr.open_mfdataset(dataECMWFDir+'surface*.grib', 
                                 engine="cfgrib") #, chunks={'time': 50,'latitude': 50, 'longitude': 200})


work_ds = profile_info.merge(surface_info).copy()

work_ds

In [None]:
# This block is only auxiliary: I used "cdo" to convert the reduced Gaussian 
# grid datasets (N320) into a regular 0.25x0.25 deg**2 grid (ECMWF MARS was not
# available at the time of this test, so I could not download the dataset in 
# the regular grid).
# The interpolation uses a bilinear interpolation (cdo documentation); but the
# resulting grid (lon, lat) has slight differences respect to the ECMWF regular grid

# In this block I just take the (lon,lat) from another 0.25x0.25 deg**2 ECMWF
# dataset and replace my cdo regular grid with it for consistency.
# NOTE(review): assumes both grids have the same number of points and that
# the cdo latitude axis runs in the opposite direction of ECMWF's — confirm.

aux_info = xr.open_mfdataset(auxDataECMWFDir+'surface*.grib', 
                                 engine="cfgrib")
work_ds['latitude'] = aux_info['latitude'][::-1] # The order in cdo is different
work_ds['longitude'] = aux_info['longitude']
work_ds

In [None]:
#work_ECMWF_ds = work_ds.isel(time=slice(0,2),step=0)
#work_ECMWF_ds

In [None]:
# First create dataset (ECMWF) indexing the right time-step combinations.
# We use stack to have a single reference time at the end:
# Only the first 12 forecast steps of every analysis time are kept.
# NOTE(review): presumably so consecutive analyses tile the day without
# overlapping steps — confirm against the step spacing of the files.
ECMWF_ds = work_ds.isel(step=slice(0,12)).stack(time2=("time","step"))

# After stacking time/step the new multi-index variable is not very useful
# as a time reference; we then create the new time dimension as the sum
# of the analysis time and the step (so time + step):
ECMWF_ds['time2'] = (work_ds.isel(step=slice(0,12)).time + 
 work_ds.isel(step=slice(0,12)).step).stack(time2=("time","step"))
#ECMWF_ds['longitude'].values = ECMWF_ds.longitude.values - 180.0    # Reset lon to [-180,180]
ECMWF_ds

In [None]:
# Open satellite dataset at highest level (just to get the channels information):
# (all files of the order directory, opened lazily)

#ds = xr.open_dataset(dataSatDir+fileID)
ds = xr.open_mfdataset(dataSatDir+'*.nc')
ds

In [None]:
# Open specific scenes containing the satellite observations:
# One dataset per netCDF group, each concatenated over all files along 'time'.

scenes_list = ['scene_env1', 'scene_env2']
scene_BT = []

for scene in scenes_list:        
    scene_BT.append(xr.open_mfdataset(
        dataSatDir+'*.nc', combine = 'nested', 
        concat_dim='time', group = scene)) 

#for scene in scenes_list:
    #scene_BT.append(xr.open_dataset(dataSatDir+fileID, group = scene))
    #scene_BT.append(xr.open_mfdataset(dataSatDir+'*.nc', group = scene))

In [None]:
scene_BT[1]

In [None]:
# Join the two scene groups along 'scene_channel' (the drop_vars list is
# empty, so nothing is dropped).
ds_BT = xr.concat(scene_BT, dim = 'scene_channel').drop_vars([
])

# After the concat these per-pixel fields carry a leading 'scene_channel'
# axis; keep only the copy from the first scene group.
# NOTE(review): assumes both groups share the same geolocation and
# ancillary fields (lat, lon, eia, sft, qc_fov, laz) — confirm.
ds_BT['lat'] = ds_BT.lat[0,:,:]
ds_BT['lon'] = ds_BT.lon[0,:,:]
ds_BT['eia'] = ds_BT.eia[0,:,:]
ds_BT['sft'] = ds_BT.sft[0,:,:]
ds_BT['qc_fov'] = ds_BT.qc_fov[0,:,:]
ds_BT['laz'] = ds_BT.laz[0,:,:]


In [None]:
ds_BT

In [None]:
# Attach the top-level time coordinate, keep scene channels 11, 12, 14, 15 and
# set to NaN every pixel whose surface-type flag sft is not 0.
# NOTE(review): sft == 0 presumably means ocean — confirm with the product guide.
ds_aux = ds_BT.assign_coords(time=ds.time).sel(
    scene_channel=[11,12,14,15]).where(ds_BT.sft==0)

# Positional lookup: the scene_channel values are used as indices into the
# last axis of central_freq.
# NOTE(review): assumes channel index == scene_channel value — confirm.
ds_aux['central_freq'] = ds['central_freq'][0,0,ds_aux['scene_channel']]


# Create working satellite dataset:
# (an alias of ds_aux, not a copy)

SAT_ds = ds_aux #.drop_dims(drop_dims = ['date','channel'])

In [None]:
SAT_ds

In [None]:
# User defined desired period of time to analyze:
initSat_date = np.datetime64('2014-10-02T00:00:00.000') 
endSat_date = np.datetime64('2014-10-02T00:59:59.000')

# Find best match (i.e. nearest) for the times present in the dataset, so
# both endpoints are actual observation times:
init_date = SAT_ds.time.sel(time=initSat_date, method = "nearest")
end_date = SAT_ds.time.sel(time=endSat_date, method = "nearest")

In [None]:
# Restrict the satellite data to the selected period (label slice, both
# endpoints included).
work_SAT_ds = SAT_ds.sel(time=slice(init_date,end_date),
                              #scene_channel = slice(11,15)
                        )
                             #.transpose(...,"scene_channel")
work_SAT_ds

In [None]:
# Margin added on both sides of the observation period so that the selected
# ECMWF time steps bracket every observation time (needed by the temporal
# interpolation further below). The margin is one hour.
delta_t = np.timedelta64(1, 'h')

# Initial and final times of the overlap between the two datasets (satellite
# and ECMWF): the first observation time minus the margin and the last
# observation time plus the margin, each snapped to the nearest ECMWF time.
timeOverlapInit = ECMWF_ds.time2.sel(
    time2=work_SAT_ds.time.min() - delta_t, method="nearest")
timeOverlapEnd = ECMWF_ds.time2.sel(
    time2=work_SAT_ds.time.max() + delta_t, method="nearest")

work_ECMWF_ds = ECMWF_ds.sel(time2=slice(timeOverlapInit, timeOverlapEnd))

# We reorder the dimensions with time2 as last dimension;
# this because we want to exploit pyresample's ability
# to resample multiple "channels" at the same time, as long
# as the "channels" (or time instants in this setting)
# are located in the last dimension:
work_ECMWF_ds = work_ECMWF_ds.transpose(...,
                                        'latitude', 'longitude', 'time2')
work_ECMWF_ds

In [None]:
# Define swath using PyResample's SwathDefinition (geometry def.): 
# lon/lat are the per-pixel geolocation arrays of the selected scans.
SAT_swath_def = SwathDefinition(lons = work_SAT_ds.lon.values, 
                                lats = work_SAT_ds.lat.values)

# Define grid using PyResample's GridDefinition (geometry def.):
# meshgrid expands the 1-D regular axes into 2-D (latitude, longitude)
# arrays, the same leading axis order as the transposed ECMWF fields.
lon2d,lat2d = np.meshgrid(work_ECMWF_ds.longitude, 
                          work_ECMWF_ds.latitude)
ECMWF_grid_def = GridDefinition(lons=lon2d, lats=lat2d)


In [None]:
# Resampling using nearest neighbour:
# radius_of_influence is in metres (30 km, about one 0.25 deg cell at the
# equator); every time2 slice along the last axis is resampled at once.
# With fill_value=None pyresample returns a masked array for swath pixels
# without a grid neighbour.

ECMWF_on_SAT = resample_nearest(ECMWF_grid_def, work_ECMWF_ds.u10n.values, \
        SAT_swath_def, radius_of_influence=30000, fill_value=None)

In [None]:
# Resampling Gaussian:
# Up to 10 neighbours within 30 km, Gaussian-weighted with sigma = 30 km;
# one sigma is required per time2 "channel".

ECMWF_on_SAT_gauss = resample_gauss(ECMWF_grid_def, work_ECMWF_ds.u10n.values, \
               SAT_swath_def, radius_of_influence=30000, neighbours=10,\
               sigmas=30000*np.ones(len(work_ECMWF_ds.time2.values)), fill_value=None)

In [None]:
# BilinearResampler for xarray objects does not support either
# SwathDefinition or GridDefinition at present! (26 Sept. 2022)
#resampler = XArrayBilinearResampler(ECMWF_grid_def, SAT_swath_def, 50e3)
#result = resampler.resample(work_ECMWF_ds.u10n[0,:,:])
#result

#resampler = XArrayBilinearResampler(ECMWF_grid_def, SAT_swath_def, 50e3)
#result = resampler.resample(work_ECMWF_ds.u10n[0,:,:])
#result


In [None]:
# Store the nearest-neighbour result on the satellite dataset. The last axis
# holds one resampled field per selected ECMWF time step (time2); it is named
# 'time4interpolation' so it does not clash with the satellite 'time' axis.
# Masked (unmatched) pixels of the masked array become NaN in xarray.
work_SAT_ds['u10n_apriori_nn'] = xr.DataArray(
    data=ECMWF_on_SAT,
    dims=['time', 'scene_across_track', 'time4interpolation'],
    coords={'time': work_SAT_ds.time,
            'scene_across_track': work_SAT_ds.scene_across_track,
            'time4interpolation': work_ECMWF_ds.time2.values},
    attrs={
        # Adjacent literals are concatenated; a backslash continuation inside
        # a literal would embed the next line's indentation in the text.
        'description': ('u10n from ECMWFs forecast resampled with '
                        'PyResample (Nearest Neighbour resampler) '
                        'to satellite swath'),
        'units': 'm/s',
    },
)

work_SAT_ds['u10n_apriori_nn']

In [None]:
# Store the Gaussian-weighted result on the satellite dataset, on the same
# (time, scene_across_track, time4interpolation) layout as the
# nearest-neighbour result. Masked pixels become NaN in xarray.
work_SAT_ds['u10n_apriori_gauss'] = xr.DataArray(
    data=ECMWF_on_SAT_gauss,
    dims=['time', 'scene_across_track', 'time4interpolation'],
    coords={'time': work_SAT_ds.time,
            'scene_across_track': work_SAT_ds.scene_across_track,
            'time4interpolation': work_ECMWF_ds.time2.values},
    attrs={
        # The description must name the Gaussian resampler used for these
        # data; adjacent literals avoid embedding indentation in the text.
        'description': ('u10n from ECMWFs forecast resampled with '
                        'PyResample (Gaussian weighting resampler) '
                        'to satellite swath'),
        'units': 'm/s',
    },
)
work_SAT_ds['u10n_apriori_gauss']

In [None]:
# Time interpolation:
# We want to interpolate to the middle of the observation batch:
time_interp = init_date + (end_date-init_date)/2 

# Select the nearest valid time i.e. a time instant
# that exists in the observations:
time_interp = work_SAT_ds.time.sel(time=time_interp, method = "nearest")

# Interpolate using xarray's (scipy under the hood) capabilities:
# xarray.interp documentation: https://docs.xarray.dev/en/stable/user-guide/interpolation.html
# Check xarray.interp documentation for explanation on the use of: 
# interpolate_na() and dropna()

#work_SAT_ds['u10n_apriori_gauss_interp'] =\
#    work_SAT_ds.u10n_apriori_gauss.interpolate_na("time").dropna("time")\
#               .interp(time4interpolation=time_interp.values, method="cubic")
# Linear interpolation along the ECMWF time axis at the single instant
# time_interp: the result is one 2-D (time, scene_across_track) field.
# NOTE(review): every pixel gets the mid-batch value rather than a value at
# its own observation time — confirm this approximation is intended.
work_SAT_ds['u10n_apriori_gauss_interp'] =\
    work_SAT_ds.u10n_apriori_gauss.interp(time4interpolation=time_interp.values, method="linear")

work_SAT_ds['u10n_apriori_gauss_interp'] 

In [None]:
def defineArea(corners, proj_id, datum, xsize=200, ysize=200,
               area_id='granule', area_name='modis swath 5min granule'):
    """Build a pyresample ``AreaDefinition`` spanning a lon/lat bounding box.

    The projection origin is (``lat_0``, ``lon_0``); the area extent is the
    projected position of the (min_lon, min_lat) and (max_lon, max_lat)
    corners. The extent is printed as a side effect.

    Parameters
    ----------
    corners : dict
        Keys 'min_lon', 'max_lon', 'min_lat', 'max_lat' (bounding box) and
        'lat_0', 'lon_0' (projection origin), all in degrees.
    proj_id : str
        PROJ projection name (e.g. 'eqc'); also stored as the area's proj_id.
    datum : str
        PROJ datum name (e.g. 'WGS84').
    xsize, ysize : int, optional
        Number of grid columns and rows of the area (default 200 x 200).
    area_id, area_name : str, optional
        Identifier and description stored in the area definition.

    Returns
    -------
    pyresample.geometry.AreaDefinition
    """
    # Projection origin formatted with two decimals.
    lat_0 = '{lat_0:5.2f}'.format_map(corners)
    lon_0 = '{lon_0:5.2f}'.format_map(corners)
    area_dict = dict(datum=datum, lat_0=lat_0, lon_0=lon_0,
                     proj=proj_id, units='m')

    # Project the SW and NE corners to metres to obtain the area extent.
    prj = pyproj.Proj(area_dict)
    x, y = prj([corners['min_lon'], corners['max_lon']],
               [corners['min_lat'], corners['max_lat']])
    area_extent = (x[0], y[0], x[1], y[1])
    print(area_extent)

    return AreaDefinition(area_id, area_name, proj_id,
                          area_dict, xsize, ysize, area_extent)



In [None]:

# Creation of area of interest:
# Equidistant-cylindrical ('eqc') area spanning 95W-20E / 3N-50N with the
# projection origin at (27N, 57W).
#corners = {"min_lon": 25 , "max_lon": 75, "min_lat": -30 , "max_lat": 0, "lat_0": 60, "lon_0":-15}
corners = {"min_lon": -95 , "max_lon": 20, "min_lat": 3 , "max_lat": 50, "lat_0": 27, "lon_0":-57}
proj_id = 'eqc'  # eqc
datum = 'WGS84'
area_interest = defineArea(corners, proj_id, datum)


# 'worldeqc30km' is read from a local areas.yaml (not shown here).
# NOTE(review): presumably a global eqc grid at ~30 km — confirm.
area_def_world = load_area('areas.yaml', 'worldeqc30km')# 'worldeqc30km70') # for plots


In [None]:
def get_Sat_frame(ds, area_interest, chan = 0, var=None, begin_t=None, end_t=None,
                  radius_of_influence=3000):
    """Keep only the swath pixels of ``ds[var]`` that lie near an area's grid.

    Parameters
    ----------
    ds : xarray.Dataset
        Satellite dataset with 2-D ``lon``/``lat`` swath geolocation.
    area_interest : pyresample AreaDefinition
        Target area; its grid lon/lats select the swath pixels to keep.
    chan : int, optional
        Index along the third axis of ``ds[var]`` (for the resampled ECMWF
        variables this is the ECMWF time step). A negative value means
        ``ds[var]`` is 2-D and is used as is.
    var : str
        Name of the variable to reduce (required despite the ``None`` default).
    begin_t, end_t : optional
        Unused; kept for backward compatibility.
    radius_of_influence : float, optional
        Forwarded to ``data_reduce.swath_from_lonlat_grid`` (default 3000).

    Returns
    -------
    tuple
        ``(lons, lats, data)`` of the retained swath pixels.
    """
    grid_lons_interest, grid_lats_interest = area_interest.get_lonlats()

    swathDef = SwathDefinition(lons=ds.lon.values, lats=ds.lat.values)
    lon_scene, lat_scene = swathDef.get_lonlats()

    # One slice of a 3-D variable, or the whole 2-D variable when chan < 0.
    field = ds[var][:, :, chan] if chan >= 0 else ds[var][:, :]

    return data_reduce.swath_from_lonlat_grid(
        grid_lons_interest, grid_lats_interest,
        lon_scene, lat_scene, field.values,
        radius_of_influence=radius_of_influence)

In [None]:
def get_TB_frame(ds, area_interest, channel, begin_t=None, end_t=None,
                 radius_of_influence=3000):
    """Keep only the brightness temperatures of one channel near an area's grid.

    Parameters
    ----------
    ds : xarray.Dataset
        Satellite dataset with 2-D ``lon``/``lat`` and a ``tb`` variable whose
        second axis is the channel.
    area_interest : pyresample AreaDefinition
        Target area; its grid lon/lats select the swath pixels to keep.
    channel : int
        Positional index along the second axis of ``ds.tb``.
    begin_t, end_t : optional
        Unused; kept for backward compatibility.
    radius_of_influence : float, optional
        Forwarded to ``data_reduce.swath_from_lonlat_grid`` (default 3000).

    Returns
    -------
    tuple
        ``(lons, lats, tb)`` of the retained swath pixels.
    """
    grid_lons_interest, grid_lats_interest = area_interest.get_lonlats()

    swathDef = SwathDefinition(lons=ds.lon.values, lats=ds.lat.values)
    lon_scene, lat_scene = swathDef.get_lonlats()

    return data_reduce.swath_from_lonlat_grid(
        grid_lons_interest, grid_lats_interest,
        lon_scene, lat_scene, ds.tb[:, channel, :].values,
        radius_of_influence=radius_of_influence)

In [None]:
def basicMapPlotScat(x,y,data,namefile, area, vmin=0, vmax=300):
    """Show swath values as coloured dots on a map in the projection of ``area``.

    ``x``/``y`` are longitudes/latitudes (degrees) and ``data`` the values used
    for the colour, clipped to [vmin, vmax]. The figure is displayed, not
    saved (``namefile`` is currently unused).
    """
    lonlat = ccrs.PlateCarree()
    map_crs = area.to_cartopy_crs()

    figure = plt.figure(figsize=(8, 6))
    axes = plt.axes(projection=map_crs)
    axes.add_feature(cartopy.feature.LAND, zorder=0, edgecolor='black')
    axes.set_global()
    axes.gridlines()
    axes.set_title("TB")

    # Dashed, labelled graticule in degrees.
    graticule = axes.gridlines(crs=lonlat, linewidth=0.1, color='black',
                               alpha=0.5, linestyle='--', draw_labels=True)
    graticule.xformatter = LONGITUDE_FORMATTER
    graticule.yformatter = LATITUDE_FORMATTER

    dots = axes.scatter(x, y, c=data, s=0.15, cmap="viridis",
                        transform=lonlat, vmin=vmin, vmax=vmax)
    figure.colorbar(dots).set_label("Brightness temperature [K]")

    plt.show()
    
def basicMapPlotScat1(x,y,data,namefile, area, vmin=0, vmax=300):
    """Plot swath values on a global Plate Carree map and save it as PNG.

    Lon/lat points are transformed from geodetic to Plate Carree coordinates
    and drawn as small dots coloured by ``data`` within [vmin, vmax]. The
    figure is written to ``namefile + '.png'`` (300 dpi) and left open;
    ``area`` is currently unused.
    """
    plate = ccrs.PlateCarree()

    figure = plt.figure()
    axes = plt.axes(projection=plate)
    axes.add_feature(cartopy.feature.LAND, zorder=0, edgecolor='black')

    projected = plate.transform_points(ccrs.Geodetic(), x, y)
    xs = projected[:, 0]
    ys = projected[:, 1]

    axes.set_global()
    axes.gridlines()

    dots = axes.scatter(xs, ys, c=data, s=0.05, cmap="viridis",
                        vmin=vmin, vmax=vmax)
    figure.colorbar(dots).set_label("Temp. Bright [K]")

    plt.tight_layout()
    plt.savefig(namefile + '.png', bbox_inches='tight', dpi=300)

In [None]:
# One map per selected brightness-temperature channel (0-3, see the channel
# legend further below), saved as 'scene_channel_<i>.png'.
for channel in range(4):
    reduced_lon_scene, reduced_lat_scene, reduced_data_scene =\
    get_TB_frame(work_SAT_ds, area_def_world, channel)
    
    basicMapPlotScat1(reduced_lon_scene, reduced_lat_scene, reduced_data_scene,
                 'scene_channel_'+str(channel), area_interest, vmin=130, vmax=270)

In [None]:
# Plot the resampled (Nearest neighb.) ECMWF data in the new 'grid' 
# (i.e. the satellite swath):
# chan=3 selects index 3 of 'time4interpolation' (an ECMWF time step),
# not a radiometer channel.

reduced_lon_scene, reduced_lat_scene, reduced_data_scene =\
get_Sat_frame(work_SAT_ds, area_def_world, chan=3, 
              var = 'u10n_apriori_nn', begin_t=None, end_t=None)

basicMapPlotScat1(reduced_lon_scene, reduced_lat_scene, reduced_data_scene,
                 'resampled', area_interest, vmin=-25, vmax=25)

In [None]:
# Plot the resampled (Gaussian interp.) ECMWF data in the new 'grid' 
# (i.e. the satellite swath):
# Same ECMWF time step (index 3) as the nearest-neighbour map above.

#reduced_lon_scene, reduced_lat_scene, reduced_data_scene =\
#get_Sat_frame(work_SAT_ds, area_def_world, var='dataGauss', begin_t=None, end_t=None)

reduced_lon_scene, reduced_lat_scene, reduced_data_scene =\
get_Sat_frame(work_SAT_ds, area_def_world, chan=3, 
              var = 'u10n_apriori_gauss', begin_t=None, end_t=None)

basicMapPlotScat1(reduced_lon_scene, reduced_lat_scene, reduced_data_scene,
                 'resampledGauss', area_interest, vmin=-25, vmax=25)

In [None]:
# Absolute difference between the nearest-neighbour (NN) and the Gaussian
# (GN) resampling of u10n. Both operands use the same ECMWF time step
# (index 3, as in the two maps above), so only the resampling method differs.
diff_time_index = 3
work_SAT_ds['differenceNN_GN'] = np.abs(
    work_SAT_ds['u10n_apriori_nn'][:, :, diff_time_index]
    - work_SAT_ds['u10n_apriori_gauss'][:, :, diff_time_index])

reduced_lon_scene, reduced_lat_scene, reduced_data_scene = get_Sat_frame(
    work_SAT_ds, area_def_world, chan=-1,
    var='differenceNN_GN', begin_t=None, end_t=None)

basicMapPlotScat1(reduced_lon_scene, reduced_lat_scene, reduced_data_scene,
                  'difference_NN_GN', area_interest, vmin=0, vmax=0.5)

In [None]:
#work_SAT_ds['u10n_apriori_gauss_interp'] 

# Map the mid-batch, time-interpolated a priori u10n (a 2-D variable, hence
# chan=-1).
reduced_lon_scene, reduced_lat_scene, reduced_data_scene =\
get_Sat_frame(work_SAT_ds, area_def_world, chan=-1, 
              var = 'u10n_apriori_gauss_interp', begin_t=None, end_t=None)

basicMapPlotScat1(reduced_lon_scene, reduced_lat_scene, reduced_data_scene,
                 'spaceTimeInterpolated', area_interest, vmin=-25, vmax=25)

In [None]:
# Plot origin data (ECMWF on regular grid, to compare with the resampled one):
# Background: u10n at the first selected ECMWF time step, ocean only
# (lsm == 0); red dots: the swath pixels kept by the previous cell.

fig = plt.figure()
ax = plt.axes(projection=ccrs.PlateCarree())

ax.coastlines()
ax.gridlines()
work_ECMWF_ds.u10n[:,:,0].where(
    work_ECMWF_ds.lsm[:,:,0]==0).plot(ax=ax,
    transform=ccrs.PlateCarree(), cmap="viridis")
ax.scatter(reduced_lon_scene, reduced_lat_scene,marker='.',color='red')
plt.tight_layout()
plt.savefig('allWind_and_swath.png', bbox_inches='tight', dpi=300) 

In [None]:
def mapPlotScatZoom(x,y,data,namefile, mini, maxi, orthoCenter=None, area=None):
    """Overlay swath pixel positions on the ECMWF u10n field and save as JPEG.

    The background is u10n at the first time step of the notebook-global
    ``work_ECMWF_ds``, masked to ocean (``lsm == 0``); the swath pixels at
    (x, y) are drawn as red dots.

    Parameters
    ----------
    x, y : array-like
        Longitudes / latitudes of the swath pixels (degrees).
    data : array-like
        Currently unused (the dots are not colour-coded).
    namefile : str
        Output path without extension; '.jpg' is appended (900 dpi).
    mini, maxi, orthoCenter
        Currently unused.
    area : pyresample AreaDefinition, optional
        Map projection; Plate Carree when None.
    """
    lonlat = ccrs.PlateCarree()

    fig = plt.figure()

    # ``is None`` instead of ``== None``: the equality operator would dispatch
    # to the area object's own ``__eq__``.
    ortho = lonlat if area is None else area.to_cartopy_crs()
    ax = plt.axes(projection=ortho)

    ax.add_feature(cartopy.feature.LAND, zorder=0, edgecolor='black', linewidth=0.1)

    # Lon/lat points expressed in the map's projection coordinates.
    xy = ortho.transform_points(lonlat, x, y)

    ax.set_global()
    gl = ax.gridlines(crs=lonlat, linewidth=0.07,
                      color='black', alpha=0.5, linestyle='--', draw_labels=True)
    gl.xformatter = LONGITUDE_FORMATTER
    gl.yformatter = LATITUDE_FORMATTER

    work_ECMWF_ds.u10n[:, :, 0].where(
        work_ECMWF_ds.lsm[:, :, 0] == 0).plot(
        ax=ax, transform=lonlat, cmap="viridis")

    ax.scatter(
        xy[:, 0],
        xy[:, 1],
        s=6,
        edgecolors='none',
        marker=matplotlib.markers.MarkerStyle(marker='o', fillstyle='full'),
        color='red',
    )

    plt.tight_layout()
    plt.savefig(namefile + '.jpg', bbox_inches='tight', dpi=900)


# Zoom on a 2 x 2 degree box (55-57E, 22-20S), eqc projection with origin (0, 0).
corners = {"min_lon": 55 , "max_lon": 57, "min_lat": -22 , "max_lat": -20, "lat_0": 0, "lon_0":0}
proj_id = 'eqc'  # eqc
datum = 'WGS84'
# Rebinds the global area_interest used by the earlier plotting cells.
area_interest = defineArea(corners, proj_id, datum)
#grid_lons_zoom, grid_lats_zoom = area_interest.get_lonlats()


# Swath pixels (with the time-interpolated u10n) inside the zoom area.
zoom_lon_scene, zoom_lat_scene, zoom_data_scene =\
get_Sat_frame(work_SAT_ds, area_interest, chan=-1, 
              var = 'u10n_apriori_gauss_interp', begin_t=None, end_t=None)

#basicMapPlotScat1(zoom_lon_scene, zoom_lat_scene, zoom_data_scene,
#                 'spaceTimeInterpolated', area_interest, vmin=-25, vmax=25)

mapPlotScatZoom(zoom_lon_scene, zoom_lat_scene, zoom_data_scene,
                 'zoom3', -25,25,area=area_interest)

In [None]:
zoom_lon_scene

In [None]:
# Some histograms:

#ds_tb_log = np.log10(ds_work.tb[:,0,:]) 
#ds_work.tb[:,0,:].plot.hist(bins=20,)
#ds_tb_log.plot.hist(bins=30,)

In [None]:
def bigHistogram(da, numbins=20):
    """Histogram all values of ``da``, accumulated one across-track column at a time.

    The xarray ``plot.hist`` output looked wrong for this data, so the counts
    are computed explicitly with ``np.histogram``. Every column is binned on
    the same edges, spanning the global [min, max] of ``da``, so the
    per-column counts refer to identical bins and can be summed.

    Parameters
    ----------
    da : xarray.DataArray
        Data with a ``scene_across_track`` dimension; NaNs are ignored.
    numbins : int, optional
        Number of equal-width bins (default 20).

    Returns
    -------
    hist : numpy.ndarray
        Counts per bin, summed over all columns.
    mybins : numpy.ndarray
        Bin midpoints.
    """
    # Global range with NaNs skipped; reduced by xarray (chunk-wise when the
    # data are dask-backed) instead of materializing the whole array.
    datamin = float(da.min(skipna=True))
    datamax = float(da.max(skipna=True))
    value_range = (datamin, datamax)

    delta = (datamax - datamin) / numbins
    mybins = np.linspace(datamin + delta / 2,
                         datamax - delta / 2,
                         numbins)  # Bins midpoint locations

    ncols = len(da["scene_across_track"])
    hist = np.zeros(numbins, dtype=np.int64)
    for i in range(ncols):
        column = da.isel(scene_across_track=i).values.ravel()
        # With an explicit range, NaNs fall outside every bin and are not counted.
        hist += np.histogram(column, bins=numbins, range=value_range)[0]
        if i > 0:
            print('Step ' + str(i) + ' of ' + str(ncols) + ' done!')

    return hist, mybins


In [None]:
# channels: 
# 0 => 19 GHz, H
# 1 => 19 GHz, V
# 2 => 37 GHz, H
# 3 => 37 GHz, V

In [None]:
# Channel 3 (37 GHz V, per the legend above): drop scan lines whose pixels
# are all NaN (e.g. everything masked by the sft filter) and chunk along time.
da = SAT_ds.tb[:,3,:].dropna(
    dim='time', how='all').chunk(
    chunks={'time':45000})

numbins = 20
hist, bins = bigHistogram(da, numbins=numbins)

In [None]:
# Plot histogram using seaborn:
# The counts are precomputed: each bin midpoint is plotted with its count as
# weight.
plt.figure()
sns.histplot(x=bins, weights=hist, discrete=False, bins=numbins)
plt.xlabel('Temperature Brightness [K] ')
plt.grid(visible=True)
plt.title('Distribution of Temp. Brightness in channel 37V')
plt.savefig('hist_TB_channel37V.png',dpi =150) 

In [None]:
# Channels 0 (19 GHz H) and 3 (37 GHz V), prepared the same way.
# NOTE(review): dropna is applied per channel, so the two arrays only have
# equal lengths if both channels are all-NaN on the same scan lines — confirm.
da0 = SAT_ds.tb[:,0,:].dropna(
    dim='time', how='all').chunk(
    chunks={'time':45000})
da1 = SAT_ds.tb[:,3,:].dropna(
    dim='time', how='all').chunk(
    chunks={'time':45000})


In [None]:
da0

In [None]:
# Pixel-by-pixel scatter of 19H against 37V (every time x across-track pixel).
plt.figure()
plt.scatter(da0.stack(index=("time","scene_across_track")), 
           da1.stack(index=("time","scene_across_track")))
plt.xlabel('Temperature Brightness [K], 19H')
plt.ylabel('Temperature Brightness [K], 37V')
plt.grid(visible=True)
plt.title('Scatter plot 19H vs 37V')
#plt.show()
plt.savefig('scatter_19H_37V.png',dpi =150) 

In [None]:
# All selected channels, without the all-NaN scan lines.
ds_tb = SAT_ds.tb[:,:,:].dropna(
    dim='time', how='all')
ds_tb

In [None]:
# Flatten the brightness temperatures into a (pixel, channel) table: stack
# time x across-track into a single 'index' dimension, drop pixels without
# any valid channel, and replace the (time, scene_across_track) MultiIndex
# by a plain 0..n-1 index named 'example'. The (potentially large, dask
# backed) pipeline is evaluated only once; the row count is read from its result.
dataframe_TB = (SAT_ds.tb[:, :, :]
                .stack(index=('time', 'scene_across_track'))
                .transpose("index", "scene_channel")
                .dropna(how='all', dim='index')
                .to_pandas())
nrows = dataframe_TB.shape[0]
newIndex = np.arange(nrows)
dataframe_TB.index = newIndex
dataframe_TB.index.name = 'example'
dataframe_TB #.to_csv('eigenVal.csv')

In [None]:
#dataframe_TB.to_csv('dataframe_TB.csv')
# Reload the table saved earlier (the to_csv call above is commented out, so
# dataframe_TB.csv must already exist). The saved 'example' index comes back
# as an ordinary column: drop it and name the default index 'example' again.
dataframe_TB = pd.read_csv('dataframe_TB.csv')
del dataframe_TB['example']
dataframe_TB.index.name = 'example'
dataframe_TB

In [None]:
# 80/20 train/test split with a fixed seed.
X_train, X_test = train_test_split(dataframe_TB, test_size=0.2, random_state=42)

In [None]:
# Split the test set again: 20% of it becomes the pool that is perturbed
# into synthetic outliers in the next cell.
X_test, X_outliers = train_test_split(X_test, test_size=0.2, random_state=42)

In [None]:
# Turn X_outliers into synthetic anomalies: add a constant offset (in K) to
# all channels of consecutive blocks of 100000 rows. Blocks are contiguous
# and the last one runs to the end of the frame, so every row is perturbed.
outlier_offsets = [3, -3, 5, -5, 10, -10, 15]
outlier_block_size = 100000
for k, offset in enumerate(outlier_offsets):
    start = k * outlier_block_size
    is_last = k == len(outlier_offsets) - 1
    stop = len(X_outliers) if is_last else start + outlier_block_size
    X_outliers.iloc[start:stop, :] = X_outliers.iloc[start:stop, :] + offset

In [None]:
# fit the model

#clf = svm.OneClassSVM(nu=0.1, kernel="rbf", gamma=0.1, verbose = 1)
#clf.fit(X_train)
#y_pred_train = clf.predict(X_train)
#y_pred_test = clf.predict(X_test)
#y_pred_outliers = clf.predict(X_outliers)
#n_error_train = y_pred_train[y_pred_train == -1].size
#n_error_test = y_pred_test[y_pred_test == -1].size
#n_error_outliers = y_pred_outliers[y_pred_outliers == 1].size

# nu: upper bound on the fraction of training errors (sklearn semantics).
nu = 0.05
gamma = 2.0
random_state = 42
# Fit the One-Class SVM using a kernel approximation and SGD
# (Nystroem RBF feature map followed by a linear SGD one-class SVM).
# NOTE(review): the brightness temperatures are not scaled; with gamma=2.0
# the RBF kernel exp(-gamma*||x-c||^2) is ~0 for points a few K apart —
# consider standardizing the features or lowering gamma.
transform = Nystroem(gamma=gamma, random_state=random_state)
clf_sgd = SGDOneClassSVM(nu=nu, shuffle=True, 
                         fit_intercept=True, random_state=random_state, 
                         tol=1e-4, verbose = 1)

pipe_sgd = make_pipeline(transform, clf_sgd)
pipe_sgd.fit(X_train)
# predict returns +1 for inliers and -1 for outliers.
y_pred_train_sgd = pipe_sgd.predict(X_train)
y_pred_test_sgd = pipe_sgd.predict(X_test)
y_pred_outliers_sgd = pipe_sgd.predict(X_outliers)
# Errors: normal samples flagged as outliers, and synthetic outliers
# accepted as inliers.
n_error_train_sgd = y_pred_train_sgd[y_pred_train_sgd == -1].size
n_error_test_sgd = y_pred_test_sgd[y_pred_test_sgd == -1].size
n_error_outliers_sgd = y_pred_outliers_sgd[y_pred_outliers_sgd == 1].size

In [None]:
def covariance(da):
    """Compute the full channel-by-channel covariance of ``da``.

    Every (time, scene_across_track) pixel is one sample; ``xr.cov`` ignores
    pairs where either channel is NaN.

    Parameters
    ----------
    da : xarray.DataArray
        Data with dimensions ``scene_channel``, ``time`` and
        ``scene_across_track``.

    Returns
    -------
    list
        The n_channels**2 covariances in row-major order (entry ``i*n + j``
        is cov(channel i, channel j)), each a 0-d numpy array; use
        ``np.asarray(result).reshape(n, n)`` to get the matrix.
    """
    # Stack and chunk each channel once, instead of once per channel pair.
    flat = [
        da.isel(scene_channel=k)
          .stack(index=('time', 'scene_across_track'))
          .chunk(chunks={'index': 1000000})
        for k in range(da.sizes['scene_channel'])
    ]
    n = len(flat)

    # The covariance matrix is symmetric: compute the upper triangle and
    # mirror it, saving n*(n-1)/2 of the n*n dask computations.
    matrix = [[None] * n for _ in range(n)]
    for i in range(n):
        for j in range(i, n):
            value = xr.cov(flat[i], flat[j], dim='index').compute().values
            matrix[i][j] = value
            matrix[j][i] = value

    listMatrices = [value for row in matrix for value in row]

    print("Computed variances: ")
    print(listMatrices)

    return listMatrices

In [None]:
# With xarray option 1 (only diagonal terms):

#ds_cov = xr.cov(ds_tb, ds_tb, dim = 'index')
#ds_cov

# With xarray option 2 (full matrix):
covList = covariance(ds_tb)

# covariance() returns the n*n entries in row-major order; size the matrix
# from the data instead of assuming exactly 4 channels.
n_channels = ds_tb.sizes['scene_channel']
covMatrix = np.asarray(covList).reshape((n_channels, n_channels))


In [None]:
# covMatrix is real symmetric, so use eigh: real eigenvalues and orthonormal
# eigenvector columns, whereas eig gives no ordering guarantee. eigh sorts
# ascending; reorder to descending so component 0 carries the largest
# variance (PCA convention) and the score indices below are meaningful.
eigenVal, eigenVec = np.linalg.eigh(covMatrix)
eig_order = np.argsort(eigenVal)[::-1]
eigenVal = eigenVal[eig_order]
eigenVec = eigenVec[:, eig_order]

In [None]:
# Eigenvalues as a diagonal matrix in the latent (component) space.
eigVal_DataArray = xr.DataArray(data=np.diag(eigenVal), 
                                dims=['channel_latentSpace','channel_latentSpace_T'])
eigVal_DataArray

In [None]:
# Channel covariance matrix; the second dimension is the transposed channel axis.
cov_DataArray = xr.DataArray(data=covMatrix, 
                             dims=['scene_channel','scene_channel_T'])
cov_DataArray

In [None]:
# Column k of eigenVec is the eigenvector of eigenvalue k: rows follow
# scene_channel, columns the reduced (latent) components.
eigenVec_DataArray = xr.DataArray(data=eigenVec, 
                                  dims=['scene_channel','scene_channel_reduced'])
eigenVec_DataArray

In [None]:
# Persist covariance, eigenvectors and eigenvalues as CSV.
cov_DataArray.to_pandas().to_csv('covariance.csv')
eigenVec_DataArray.to_pandas().to_csv('eigenVec.csv')
eigVal_DataArray.to_pandas().to_csv('eigenVal.csv')

In [None]:

#ds_T = ds_tb.stack(
#    index=('time','scene_across_track')).chunk(
#    chunks={'index':1000000}).dot(w_DataArray)
#ds_T

# Project every pixel onto the eigenvectors: xr.dot sums over the shared
# 'scene_channel' dimension, leaving one score per component and pixel.
# NOTE(review): the data are not mean-centred first, so the scores are
# offset by the projected channel means — confirm this is intended.
ds_T = xr.dot(ds_tb.stack(
    index=('time','scene_across_track')).chunk(
    chunks={'index':1000000}), 
              eigenVec_DataArray)
ds_T

In [None]:
# Scatter of two component scores. Labels, title and file name are derived
# from the plotted column indices so they always describe the data shown
# (column 1 was plotted while the text said Scores_3).
score_x, score_y = 0, 1
plt.figure()
plt.scatter(ds_T[:, score_x],
            ds_T[:, score_y])
plt.xlabel(f'Scores_{score_x} [Units]')
plt.ylabel(f'Scores_{score_y} [Units]')
plt.grid(visible=True)
plt.title(f'Scatter plot Scores_{score_x} vs Scores_{score_y}')
#plt.show()
plt.savefig(f'scatter_Scores_{score_x}_Scores_{score_y}.png', dpi=150)

In [None]:


# 3-D scatter of the first three component scores (columns 0, 1, 2 of ds_T).
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
ax.scatter(ds_T[:,0], 
           ds_T[:,1], ds_T[:,2])

ax.set_xlabel('Scores_0 [Units]')
ax.set_ylabel('Scores_1 [Units]')
ax.set_zlabel('Scores_2 [Units]')

#plt.grid(visible=True)
#plt.title('Scatter plot Scores_0_1_2')
plt.show()

In [None]:
# Scores as a pandas table: drop pixels whose scores are all NaN and replace
# the (time, scene_across_track) MultiIndex by a plain 0..n-1 index named
# 'example'. The dask pipeline is evaluated once; the row count is read from
# its result instead of from a second conversion.
dataframe_scores = ds_T.dropna(how='all', dim='index').to_pandas()
nrows = dataframe_scores.shape[0]
newIndex = np.arange(nrows)
dataframe_scores.index = newIndex
dataframe_scores.index.name = 'example'
dataframe_scores #.to_csv('eigenVal.csv')

In [None]:
# Persist the scores table (the 'example' index is written as a column).
dataframe_scores.to_csv('scores.csv')

In [None]:
#scores = pd.read_csv('scores.csv')
#scores
dataframe_scores.iloc[:,0]

In [None]:
# Plot histogram using seaborn:
# Distribution of the first component score (column 0).
plt.figure()
sns.histplot(data = dataframe_scores.iloc[:,0], bins=20)
plt.xlabel('Score_0')
plt.grid(visible=True)
plt.title('Distribution of Score 0')
plt.savefig('hist_Score0.png',dpi =150) 

In [None]:
# jointplot is a figure-level function: it creates its own figure and returns
# a JointGrid. Labels, grid, title and saving therefore go through that grid;
# pyplot calls would target whichever axes happens to be current, and a
# preceding plt.figure() would only leave an empty figure behind.
joint_grid = sns.jointplot(data=dataframe_scores.iloc[:, 0:2], x=0, y=1)
joint_grid.set_axis_labels('Score_0', 'Score_1')
joint_grid.ax_joint.grid(visible=True)
joint_grid.figure.suptitle('Distribution of Score 0 and 1', y=1.02)
joint_grid.savefig('JoinPlot_Score0_1.png', dpi=150)

In [None]:
# displot is figure-level as well (returns a FacetGrid with a single axes
# here), so the labels, grid, title and output file are set on that grid.
dist_grid = sns.displot(data=dataframe_scores.iloc[:, 0:2], x=0, y=1)
dist_grid.set_axis_labels('Score_0', 'Score_1')
dist_grid.ax.grid(visible=True)
dist_grid.figure.suptitle('Distribution of Score 0 and 1', y=1.02)
dist_grid.savefig('hist2D_Score0_1.png', dpi=150)

In [None]:
# Reload the saved scores; the former 'example' index comes back as an
# ordinary column next to a new default index.
scores = pd.read_csv('scores.csv')

In [None]:
scores