# Import Packages

In [1]:
#from netCDF4 import Dataset
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from datetime import datetime, timedelta
import cfgrib
import numpy as np
from scipy.interpolate import interpn

# Constants

In [2]:
Ra=6370000.0 # m: mean radius of the Earth (converts metric displacements to degrees)
cp=1005.0 # J/kg/K: specific heat of dry air at constant pressure
Lv=2500000 # J/kg: latent heat of vaporization at 0°C
ZR=287.05 # J/kg/K: specific gas constant for dry air
ZKP=ZR/cp  # R_d/cp (dimensionless), exponent used in the potential temperature formula
g=9.81 # m/s^2 (= N/kg): gravitational acceleration at sea level

# Function definitions

#### Read model data

In [None]:
def _open_standardized(path, engine, backend_kwargs, list_coord,
                       longitude_slice, latitude_slice):
    """Open one model file and bring it to the common lon/lat/pressure conventions.

    The coordinates named in ``list_coord`` are renamed to 'longitude', 'latitude'
    and 'pressure'. Longitude is moved to -180..180 and latitude made increasing.
    The result is cropped to the requested lon/lat box, and a pressure coordinate
    given in hPa/millibar is converted to Pa.
    """
    ds = xr.open_dataset(path, engine=engine, backend_kwargs=backend_kwargs)
    ds = ds.rename({list_coord[0]: 'longitude',
                    list_coord[1]: 'latitude',
                    list_coord[2]: 'pressure'})
    ds = convertIncreaseLat(convert360_180(ds))
    # Crop only after the lon/lat conventions are fixed, so the slices are
    # interpreted as -180..180 longitudes and increasing latitudes.
    ds = ds.sel(longitude=slice(longitude_slice[0], longitude_slice[1]),
                latitude=slice(latitude_slice[0], latitude_slice[1]))
    return converthPa_Pa(ds)


def read_model(root_exp,file_name, file_type,list_coord,
               longitude_slice=(-180,180), latitude_slice=(-90,90),
               grib_selection=None
               ):
    """Read model output (GRIB or netCDF) into a standardized xarray Dataset.

    Parameters
    ----------
    root_exp : str
        Prefix (typically a directory ending in '/') prepended to ``file_name``
        when a single file name is given.
    file_name : str or list/tuple of str
        A single file name, relative to ``root_exp``. Alternatively, a sequence of
        paths that are opened as given (NOTE: ``root_exp`` is *not* prepended to
        them). These files are concatenated along 'time'.
    file_type : {'grib', 'netcdf'}
        Selects the xarray engine ('cfgrib' for GRIB, the default for netCDF).
    list_coord : sequence of 3 str
        Names of the longitude, latitude and pressure coordinates in the files.
    longitude_slice, latitude_slice : (min, max)
        Region kept, in degrees (longitude in -180..180).
    grib_selection : dict, optional
        cfgrib ``filter_by_keys``. Defaults to ``{'typeOfLevel': 'isobaricInhPa'}``.

    Returns
    -------
    file_data, LON_model, LAT_model, PRE_model
        The dataset and the raw arrays of its longitude, latitude and pressure
        coordinates.

    Raises
    ------
    ValueError
        If ``file_type`` is not 'grib' or 'netcdf'.
    TypeError
        If ``file_name`` is neither a string nor a list/tuple.
    """
    if grib_selection is None:
        grib_selection = {'typeOfLevel': 'isobaricInhPa'}

    if file_type == 'grib':
        engine = "cfgrib"
        backend_kwargs = {'filter_by_keys': grib_selection}
    elif file_type == 'netcdf':
        engine = None
        backend_kwargs = None
    else:
        raise ValueError("file_type must be 'grib' or 'netcdf', got %r" % (file_type,))

    if isinstance(file_name, str):
        file_data = _open_standardized(root_exp + file_name, engine, backend_kwargs,
                                       list_coord, longitude_slice, latitude_slice)
    elif isinstance(file_name, (list, tuple)):
        # Open every file with the same pipeline, then stack them along time
        file_data = xr.concat(
            [_open_standardized(path, engine, backend_kwargs,
                                list_coord, longitude_slice, latitude_slice)
             for path in file_name],
            dim='time')
    else:
        raise TypeError('file_name must be a str or a list of str, got %s'
                        % type(file_name).__name__)

    LON_model = file_data.variables['longitude'].data
    LAT_model = file_data.variables['latitude'].data
    PRE_model = file_data.variables['pressure'].data

    return file_data, LON_model, LAT_model, PRE_model






#### Longitude/Latitude/Pressure conversion

In [None]:
def convert360_180(_ds, lon_name='longitude'):
    """
    Convert longitude from the 0..360 to the -180..180 degree convention.

    If every value of ``_ds[lon_name]`` is >= 0, the coordinate is wrapped with
    ``(lon + 180) % 360 - 180``. Its attributes are kept, and the dataset is
    re-sorted along the new longitudes. Otherwise ``_ds`` is returned unchanged.
    A new Dataset is returned; the input object is not modified.
    """
    if _ds[lon_name].min() >= 0:  # 0 - 360 convention
        with xr.set_options(keep_attrs=True):
            wrapped = (_ds[lon_name] + 180) % 360 - 180
        _ds = _ds.assign_coords({lon_name: wrapped}).sortby(lon_name)
    return _ds

In [None]:
def convertIncreaseLat(_ds, lat_name='latitude'):
    """
    Return ``_ds`` with its latitude dimension in increasing order (-90 -> 90).

    The order is judged from the first two latitude values; a decreasing axis
    is reversed. Datasets with a single latitude (or none) are returned as-is.
    """
    lat = _ds[lat_name]
    # Guard: with fewer than two values there is no order to fix and lat[1] would fail
    if lat.size > 1 and lat[1] < lat[0]:  # decreasing latitude
        _ds = _ds.isel({lat_name: slice(None, None, -1)})
    return _ds

In [None]:
_HPA_UNITS = ('millibars', 'millibar', 'mbar', 'hPa')


def converthPa_Pa(_ds, pre_name='pressure'):
    """
    Convert the pressure coordinate from hPa/millibar to Pa.

    When ``_ds[pre_name]`` has a 'units' attribute in ``_HPA_UNITS``, its values
    are multiplied by 100, its other attributes are kept, and 'units' becomes
    'Pa'. Any other units (including already-Pa) leave ``_ds`` unchanged.
    A new Dataset is returned; the input object is not modified.
    """
    if _ds[pre_name].attrs.get('units') in _HPA_UNITS:
        with xr.set_options(keep_attrs=True):
            in_pa = _ds[pre_name] * 100
        _ds = _ds.assign_coords({pre_name: in_pa.assign_attrs(units='Pa')})
    return _ds

#### Define the time list for trajectories

In [3]:
def TimeList(time_init, direction, time_step, time_format, time_end=None, duration=None, format_output=None):
    """Build the list of times visited by a trajectory.

    Parameters
    ----------
    time_init : str
        First time, parsed with ``time_format``.
    direction : int
        +1 for forward, -1 for backward in time.
    time_step : float
        Step between consecutive times, in hours (unsigned; ``direction`` sets the sign).
    time_format : str
        ``datetime.strptime``/``strftime`` format of the time strings.
    time_end : str, optional
        Last time (included when reachable by whole steps). Takes precedence over ``duration``.
    duration : float, optional
        Length of the list in hours, used when ``time_end`` is not given.
    format_output : any, optional
        If not None, the times are returned as strings formatted with
        ``time_format`` (the value itself is only used as a flag).
        Otherwise ``datetime`` objects are returned.

    Raises
    ------
    ValueError
        If neither ``time_end`` nor ``duration`` is given.
    """
    step = timedelta(hours=time_step * direction)
    start = datetime.strptime(time_init, time_format)

    if time_end is not None:
        stop = datetime.strptime(time_end, time_format)
    elif duration is not None:
        # Computed directly (no strftime round-trip, which could truncate the end time)
        stop = start + direction * timedelta(hours=duration)
    else:
        raise ValueError('TimeList needs either time_end or duration')

    n_times = (stop - start) // step + 1
    times = [start + i * step for i in range(n_times)]

    if format_output is None:
        return times
    return [t.strftime(time_format) for t in times]

#### Generate Seeds

In [8]:
def GenerateSeeds(lon_Init=None,lon_End=None,lon_Resolution=None,lon_Number=None,
                   lat_Init=None,lat_End=None,lat_Resolution=None,lat_Number=None,
                   pre_Init=None,pre_End=None,pre_Resolution=None,pre_Number=None,
                   CV=False):
    """
    Generate seeding-point coordinates (lon/lat/pressure) for trajectories.

    Each axis needs its initial value (``*_Init``). It also needs one of these
    combinations:
      - ``*_End`` (included) and ``*_Number``  -> evenly spaced points,
      - ``*_End`` (included) and ``*_Resolution`` -> points every Resolution up to End,
      - ``*_Number`` and ``*_Resolution`` -> Number points every Resolution (End ignored),
      - ``*_Number=1`` alone -> the single value ``*_Init``.

    Examples, for a rectangle from 136°E to 148°E, 63°S to 52°S and 500 hPa to 950 hPa:
      - with a resolution of 0.1° and 50 hPa:
          lon_Init=136, lon_End=148, lon_Resolution=0.1, lat_Init=-63, lat_End=-52, lat_Resolution=0.1,
          pre_Init=50000, pre_End=95000, pre_Resolution=5000
      - 50 points on 10 levels along a cross-section (with CV=True):
          lon_Init=136, lon_End=148, lon_Number=50, lat_Init=-63, lat_End=-52, lat_Number=50,
          pre_Init=50000, pre_End=95000, pre_Number=10
      - 50 points on 10 levels with a resolution of 0.1° and 50 hPa (with CV=True):
          lon_Init=136, lon_Resolution=0.1, lon_Number=50, lat_Init=-63, lat_Resolution=0.1, lat_Number=50,
          pre_Init=50000, pre_Resolution=5000, pre_Number=10

    Longitude/latitude resolutions are in degrees, the pressure one in Pa.

    With ``CV=True`` the seeds follow a cross-section: the i-th longitude is paired
    with the i-th latitude, so both axes must have the same number of points.
    Otherwise every (lon, lat, pre) combination is generated.

    Constraints: lat_End >= lat_Init, lon_End >= lon_Init, pre_End >= pre_Init (when given).

    Returns
    -------
    LON, LAT, PRE : 1-D float arrays of the seed coordinates (same length)
    nb_seeds, nb_lon, nb_lat, nb_pre : int
    On invalid input an error message is printed and seven ``None`` are returned.
    """
    from math import ceil

    failure = (None,) * 7

    if lon_Init is None or lat_Init is None or pre_Init is None:
        print('Error: lon_Init, lat_Init and pre_Init must be defined')
        return failure
    # Check each upper bound independently: only the ones actually given can be compared
    for coord_End, coord_Init in ((lat_End, lat_Init), (lon_End, lon_Init), (pre_End, pre_Init)):
        if coord_End is not None and coord_End < coord_Init:
            print('Error in lon_End, lat_End or pre_End. They must follow these conditions: lat_End > lat_Init, lon_End > lon_Init, pre_End > pre_Init')
            return failure

    def coords_seeding_properties(name, coord_Init, coord_End, coord_Number, coord_Resolution):
        """Return (number of points, spacing) for one axis, or None if under-specified."""
        if coord_Number is not None and coord_Resolution is not None:
            return coord_Number, coord_Resolution
        if coord_Number == 1:
            return 1, 0
        if coord_End is not None and coord_Number is not None:
            return coord_Number, (coord_End - coord_Init) / (coord_Number - 1)
        if coord_End is not None and coord_Resolution is not None:
            # Small tolerance so float noise (e.g. 112.00000000000001) does not add a point past End
            return ceil((coord_End - coord_Init) / coord_Resolution - 1e-9) + 1, coord_Resolution
        print('Error: ' + name + ' needs (End and Number), (End and Resolution), '
              '(Number and Resolution) or Number=1')
        return None

    lon_props = coords_seeding_properties('lon', lon_Init, lon_End, lon_Number, lon_Resolution)
    lat_props = coords_seeding_properties('lat', lat_Init, lat_End, lat_Number, lat_Resolution)
    pre_props = coords_seeding_properties('pre', pre_Init, pre_End, pre_Number, pre_Resolution)
    if lon_props is None or lat_props is None or pre_props is None:
        return failure

    nb_lon, lon_Resolution = lon_props
    nb_lat, lat_Resolution = lat_props
    nb_pre, pre_Resolution = pre_props

    # Axis values computed as Init + i*Resolution (no accumulated rounding from repeated sums)
    lon_axis = lon_Init + np.arange(nb_lon, dtype=float) * lon_Resolution
    lat_axis = lat_Init + np.arange(nb_lat, dtype=float) * lat_Resolution
    pre_axis = pre_Init + np.arange(nb_pre, dtype=float) * pre_Resolution

    if CV:
        print('Seeding along a cross-section')
        if nb_lon != nb_lat:
            print('Impossible to generate seeding points along a cross-section! \nNumber of longitude and latitude must be the same !')
            return failure
        # Point i of the section is (lon_i, lat_i); each point gets every pressure level
        LON = np.repeat(lon_axis, nb_pre)
        LAT = np.repeat(lat_axis, nb_pre)
        PRE = np.tile(pre_axis, nb_lon)
        nb_seeds = nb_lon * nb_pre
    else:
        if nb_lon == 1 or nb_lat == 1:
            print('Seeding along a cross-section')
        else:
            print('Seeding in a rectangle area')
        # Flattened in (lon, lat, pre) C order: pressure varies fastest
        LON, LAT, PRE = (grid.ravel() for grid in
                         np.meshgrid(lon_axis, lat_axis, pre_axis, indexing='ij'))
        nb_seeds = nb_lon * nb_lat * nb_pre

    print(str(nb_seeds)+' trajectories seeds generated')
    return LON, LAT, PRE, nb_seeds, nb_lon, nb_lat, nb_pre

#### Level/time look-up and map plotting helpers

In [None]:
def GetPressureLevel(PRE_model,levelInPa=85000):
    """Index (along the first axis) of the model level closest to ``levelInPa``.

    On ties the first matching index is returned.
    """
    distance = np.abs(PRE_model - levelInPa)
    is_closest = distance == np.min(distance)
    matching_rows = np.nonzero(is_closest)[0]
    return matching_rows[0]

In [None]:
def GetTimeStep(YYMMDDHH_model, time_wanted):
    """Position of ``time_wanted`` in the model time list ``YYMMDDHH_model``.

    Relies on the sequence's ``.index`` method, so a ValueError is raised when
    the requested time is absent.
    """
    position = YYMMDDHH_model.index(time_wanted)
    return position


In [6]:
def TrajMap(lon_min=100, lon_max=160, lat_min=-70,lat_max=-30):
    """Create a 10x8 PlateCarree map zoomed on [lon_min, lon_max] x [lat_min, lat_max].

    Adds grey coastlines, dotted borders and a dashed lon/lat grid, labelled on
    the bottom and left edges only.

    Returns
    -------
    fig, ax : the matplotlib Figure and the cartopy map Axes.
    """
    fig, ax = plt.subplots(figsize=(10, 8),
                           subplot_kw={'projection': ccrs.PlateCarree()})

    # Zoom on the requested box
    ax.set_extent([lon_min, lon_max, lat_min, lat_max], crs=ccrs.PlateCarree())
    ax.add_feature(cfeature.COASTLINE, linewidth=0.8, color='dimgrey')
    ax.add_feature(cfeature.BORDERS, linestyle=':', linewidth=0.5, color='dimgrey')

    # Lon/lat grid, labels only bottom/left
    grid = ax.gridlines(draw_labels=True, color='gray', linestyle='--', linewidth=0.2)
    grid.top_labels = False
    grid.right_labels = False

    return fig, ax
    
def TrajMap_AddField(ax,i_var_tmp, LON_tmp, LAT_tmp,var_tmp,cmap_tmp):
    """Overlay a filled-contour field on ``ax`` with a colour bar labelled ``i_var_tmp``.

    ``LON_tmp``/``LAT_tmp``/``var_tmp`` are passed straight to ``contourf`` in
    PlateCarree (lon/lat) coordinates; the colour scale extends at both ends.
    """
    filled = ax.contourf(LON_tmp, LAT_tmp, var_tmp,
                         cmap=cmap_tmp,
                         transform=ccrs.PlateCarree(),
                         extend="both")
    colour_bar = plt.colorbar(filled, pad=0.05, ax=ax, fraction=0.02)
    colour_bar.set_label(i_var_tmp)
    
def TrajMap_AddArrow(ax, LON_tmp, LAT_tmp, U_tmp, V_tmp, resol=1, scale=10):
    """Draw wind arrows on ``ax``, keeping one grid point out of ``resol`` in each direction.

    Parameters
    ----------
    LON_tmp, LAT_tmp : 1-D lon/lat axes in degrees.
    U_tmp, V_tmp : 2-D vector components, indexed as [lat, lon] to match
        ``np.meshgrid(LON_tmp, LAT_tmp)`` (assumed from that pairing — TODO confirm).
    resol : int
        Thinning stride applied to both axes.
    scale : float
        Passed to ``quiver`` (with ``units='xy'``); default keeps the previous look.

    Returns the matplotlib Quiver artist.
    """
    lon, lat = np.meshgrid(LON_tmp[::resol], LAT_tmp[::resol])
    # Explicit data CRS, consistent with the other TrajMap_* helpers
    return ax.quiver(lon, lat,
                     U_tmp[::resol, ::resol], V_tmp[::resol, ::resol],
                     units='xy', scale=scale,
                     transform=ccrs.PlateCarree())

    
def TrajMap_AddSeedingPoint(ax,LON_seed, LAT_seed, resol=1):
    """Mark every ``resol``-th seeding point on ``ax`` as a black circle (lon/lat in degrees)."""
    lons = LON_seed[::resol]
    lats = LAT_seed[::resol]
    ax.plot(lons, lats, 'ko', transform=ccrs.PlateCarree())
    
def TrajMap_AddTrajectories(fig,ax,LON_traj, LAT_traj, color, cmap='jet', linewidth=2):
    """Draw each trajectory as a line coloured by ``color`` and add one shared colour bar.

    Parameters
    ----------
    LON_traj, LAT_traj, color : 2-D arrays indexed [trajectory, time step].
        ``color`` holds one value per time step (there is one segment fewer;
        presumably matplotlib ignores the extra value — TODO confirm).
    cmap : str
        Colormap name (default 'jet', as before).
    linewidth : float
        Line width of the trajectories.

    All trajectories share a single normalisation from the NaN-ignoring min/max of
    ``color``. The number of trajectories is taken from ``LON_traj`` itself, not
    from a notebook global. Nothing is drawn, and no colour bar is added, when
    there are no trajectories.
    """
    print('Trajectories plot: ',end='')
    norm = plt.Normalize(np.nanmin(color), np.nanmax(color))
    lc = None
    for i_traj in range(len(LON_traj)):
        points = np.array([LON_traj[i_traj, :], LAT_traj[i_traj, :]]).T.reshape(-1, 1, 2)
        # Consecutive point pairs -> one segment per time interval
        segments = np.concatenate([points[:-1], points[1:]], axis=1)
        lc = LineCollection(segments, cmap=cmap, norm=norm, transform=ccrs.PlateCarree())
        lc.set_array(color[i_traj, :])
        lc.set_linewidth(linewidth)
        ax.add_collection(lc)
    print('ok')

    if lc is None:
        return  # no trajectory drawn: no mappable to build a colour bar from

    colo = fig.colorbar(lc)
    colo.ax.tick_params(labelsize=15)
    
#ax.set_xlabel('Longitude [°]')
#ax.set_ylabel('Latitude [°]')  

#### Compute Lagrangian Trajectories

In [None]:
def TrajCompute(Initial_time_traj,End_time_traj,dt_traj,time_format,
                dt_modelOutput,Trajectories_duration,nb_traj,LON_seed, LAT_seed, PRE_seed,
                list_coord, list_var_advec, list_var, file_data, YYMMDDHH_model
                ):
    """Integrate Lagrangian trajectories through the model wind field.

    Between two model outputs, the position is advanced ``npdt`` times with step
    ``dt_traj``. The velocity is a time-weighted mix of the field at the current
    output (t) and at the next output in the integration direction (t1). It is
    refined by two predictor/corrector passes.

    Parameters
    ----------
    Initial_time_traj, End_time_traj : str
        Times in ``time_format``. They are compared as strings to choose the
        direction (backward if Initial > End), which assumes a sortable format
        such as '%Y%m%d%H'.
    dt_traj : float
        Trajectory time step, in hours.
    time_format : str
        strptime/strftime format of the time strings.
    dt_modelOutput : float
        Interval between model outputs, in hours.
    Trajectories_duration : float
        Length of the trajectories, in hours.
    nb_traj : int
        Number of seeds.
    LON_seed, LAT_seed, PRE_seed : arrays of length nb_traj
        Seed positions (degrees, degrees, model pressure units — Pa after converthPa_Pa).
    list_coord : [lon_name, lat_name, pressure_name] in ``file_data``.
    list_var_advec : [u_name, v_name, w_name]
        Advecting winds; u/v presumably in m/s and w in pressure units per second,
        since positions are advanced with a time step expressed in seconds.
    list_var : list of str
        Extra variables sampled along the trajectories.
    file_data : dataset whose ``variables[name].data`` are indexed
        [time, pressure, latitude, longitude] (assumed from the interpolation grid order).
    YYMMDDHH_model : list of the model output times (strings), used to locate Initial_time_traj.

    Returns
    -------
    Traj : dict of (nb_traj, nb_time) arrays keyed by every coordinate/variable name,
        plus 'time' (datetime of each step) and 'time2' (step counter).
    """
    # Type of trajectories
    direction = -1 if Initial_time_traj > End_time_traj else 1  # -1: backward, +1: forward

    # Time along trajectories: strings in time_format, then datetime objects
    YYMMDDHH_traj = TimeList(Initial_time_traj, direction, dt_traj, time_format,
                             duration=Trajectories_duration, format_output=time_format)
    YYMMDDHH = TimeList(Initial_time_traj, direction, dt_traj, time_format,
                        duration=Trajectories_duration)
    nb_time = len(YYMMDDHH)
    print(YYMMDDHH_traj)
    # Model time index of the trajectory start
    i_time = GetTimeStep(YYMMDDHH_model, time_wanted=Initial_time_traj)

    # Interpolation grid (pressure, latitude, longitude), read from file_data itself
    # (not from notebook globals); it does not change during the integration.
    coord_model = (file_data.variables[list_coord[2]].data,
                   file_data.variables[list_coord[1]].data,
                   file_data.variables[list_coord[0]].data)

    # Traj: one (nb_traj, nb_time) array per coordinate/variable; variables start as NaN
    Traj = {}
    for i_var in list_coord + list_var_advec + list_var:
        Traj[i_var] = np.zeros((nb_traj, nb_time))
        if i_var not in list_coord:
            Traj[i_var][:] = np.nan

    for i_var, var_seed in zip(list_coord, [LON_seed, LAT_seed, PRE_seed]):
        Traj[i_var][:, 0] = var_seed

    # Initial values sampled at the starting model time i_time (was hard-coded to index 0)
    seed_points = np.array([PRE_seed, LAT_seed, LON_seed]).T
    for i_var in list_var_advec + list_var:
        Traj[i_var][:, 0] = interpn(coord_model,
                                    file_data.variables[i_var].data[i_time, :, :, :],
                                    seed_points)

    Traj['time'] = np.meshgrid(np.array(YYMMDDHH), range(nb_traj))[0]
    Traj['time2'] = np.zeros((nb_traj, nb_time))

    # Number of model outputs in 24h
    njour = int(24/dt_modelOutput)
    # Number of trajectory steps between two model outputs
    npdt = int(dt_modelOutput/dt_traj)
    # Signed trajectory time step, in seconds
    DT = direction*dt_traj*3600.

    print('Trajectories calculation: ')
    ipdt = 0
    for i_ech in range(i_time*npdt,
                       i_time*npdt + direction*int((npdt*njour*Trajectories_duration)/24),
                       npdt*direction):
        print('\t time step : ', i_ech)
        i_model = int(i_ech/npdt)

        # Fields at model time t and at the next output in the integration direction (t1)
        Data_t = {i_var: file_data.variables[i_var].data[i_model, :, :, :]
                  for i_var in list_var_advec + list_var}
        Data_t1 = {i_var: file_data.variables[i_var].data[i_model + direction, :, :, :]
                   for i_var in list_var_advec + list_var}

        U_model = Data_t[list_var_advec[0]]
        V_model = Data_t[list_var_advec[1]]
        W_model = Data_t[list_var_advec[2]]

        U2_model = Data_t1[list_var_advec[0]]
        V2_model = Data_t1[list_var_advec[1]]
        W2_model = Data_t1[list_var_advec[2]]

        for ipd in range(0, npdt):
            ipdt = ipdt + 1
            # Weight given to the t1 velocity when advancing this sub-step
            poi = 1/(2*(npdt-ipd))

            lo = Traj[list_coord[0]][:, ipdt-1]
            la = Traj[list_coord[1]][:, ipdt-1]
            pre = Traj[list_coord[2]][:, ipdt-1]
            points = np.array([pre, la, lo]).T

            u = interpn(coord_model, U_model, points)
            v = interpn(coord_model, V_model, points)
            w = interpn(coord_model, W_model, points)

            # Extra variables carried over from the previous trajectory point
            var_tmp1 = {i_var: Traj[i_var][:, ipdt-1] for i_var in list_var}

            if ipd == 0:
                # First sub-step after a model output: use the freshly interpolated winds
                Traj[list_var_advec[0]][:, ipdt] = u
                Traj[list_var_advec[1]][:, ipdt] = v
                Traj[list_var_advec[2]][:, ipdt] = w
                for i_var in list_var:
                    Traj[i_var][:, ipdt] = var_tmp1[i_var]
                u1, v1, w1 = u, v, w
            else:
                # Later sub-steps: reuse the winds stored at the previous point
                u1 = Traj[list_var_advec[0]][:, ipdt-1]
                v1 = Traj[list_var_advec[1]][:, ipdt-1]
                w1 = Traj[list_var_advec[2]][:, ipdt-1]

            u_tmp, v_tmp, w_tmp = u1, v1, w1
            # la is fixed during this sub-step, so the metric factor is computed once
            coco = np.cos(la*np.pi/180.0)

            # Two predictor/corrector passes towards the next model output
            for _ in range(2):
                lo_1 = lo+(npdt-ipd)*u_tmp*DT/(Ra*coco)*180.0/np.pi
                la_1 = la+(npdt-ipd)*v_tmp*DT/Ra*180.0/np.pi
                pre_1 = pre+(npdt-ipd)*w_tmp*DT

                pre_1 = np.minimum(pre_1, 100000.0)  # cap at 1000 hPa (in Pa)
                points_1 = np.array([pre_1, la_1, lo_1]).T

                u2 = interpn(coord_model, U2_model, points_1)
                v2 = interpn(coord_model, V2_model, points_1)
                w2 = interpn(coord_model, W2_model, points_1)

                var_tmp2 = {i_var: interpn(coord_model, Data_t1[i_var], points_1)
                            for i_var in list_var}

                u_tmp = 0.5*(u1 + u2)
                v_tmp = 0.5*(v1 + v2)
                w_tmp = 0.5*(w1 + w2)

            uf = (1-poi)*u1 + poi*u2
            vf = (1-poi)*v1 + poi*v2
            wf = (1-poi)*w1 + poi*w2

            var_tmp = {i_var: (1-poi)*var_tmp1[i_var] + poi*var_tmp2[i_var]
                       for i_var in list_var}

            # Advance the position by one trajectory step
            lo = lo+uf*DT/(Ra*coco)*180.0/np.pi
            la = la+vf*DT/Ra*180.0/np.pi
            pre = pre+wf*DT

            pre = np.minimum(pre, 100000.0)

            Traj[list_coord[1]][:, ipdt] = la
            Traj[list_coord[0]][:, ipdt] = lo
            Traj[list_coord[2]][:, ipdt] = pre
            Traj['time2'][:, ipdt] = i_ech+(direction*ipd)

            if ipd > 0:
                Traj[list_var_advec[0]][:, ipdt] = uf
                Traj[list_var_advec[1]][:, ipdt] = vf
                Traj[list_var_advec[2]][:, ipdt] = wf

                for i_var in list_var:
                    Traj[i_var][:, ipdt] = var_tmp[i_var]

    return Traj