
# work with example float output
# deep-ocean water-mass properties

# v2: sample every ten days with floats

# Fig. 7 in GMD paper


In [1]:

import sys
sys.path.append('/global/homes/c/cnissen/scripts/seawater-3.3.4/seawater/')
sys.path.append("/global/homes/c/cnissen/scripts/python_gsw_py3/")
import os
import numpy as np
import xarray as xr
import cartopy
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
from matplotlib import cm
import seawater
#from seawater import dist
import seawater as sw
import matplotlib.path as mpath
from cartopy.util import add_cyclic_point
import matplotlib.gridspec as gridspec
import matplotlib.ticker as mticker
from cartopy.mpl.ticker import (LongitudeFormatter, LatitudeFormatter,
                                LatitudeLocator)
import random
from numba import njit
from math import sin, cos, sqrt, atan2, radians
from gsw import rho # rho from SA, CT, p
from gsw import pt0_from_t # potTemp from SA, t, p (at reference pressure 0)
from gsw import pt_from_t # potTemp from SA, t, p and reference pressure
from gsw import pot_rho_t_exact # potRho from SA, t, p and reference pressure
from gsw import p_from_z # get pressure from z and lat
#from gsw import sigma0_pt0_exact
from gsw import SA_from_SP
from tqdm import tqdm
from statsmodels.stats.weightstats import DescrStatsW
import math


In [2]:
#-----
# output directory for saved figures (created on first run)
#-----

savepath = '/global/cfs/cdirs/m4003/cnissen/Plots/E3SM_floats/deep_ocean_properties_10_daily_sampling/'
savepath_exists = os.path.exists(savepath)
if not savepath_exists:
    # report, then create the directory tree
    print ('Created '+savepath)
    os.makedirs(savepath)
    

In [3]:
####
# mesh information for the EC30to60E2r2 ocean grid (trajectory output lives on it)
####

rad_to_deg = 180.0/np.pi
latlim = -45.0

path_mesh = '/global/cfs/cdirs/m4003/maltrud/'
meshID = 'EC30to60E2r2'
# the context manager closes the netCDF file once all arrays are read into memory
with xr.open_dataset(path_mesh+'ocean.'+meshID+'.210210.nc') as meshfile:
    # cell centres are converted from radians to degrees
    lon  = meshfile['lonCell'].values*rad_to_deg
    lat  = meshfile['latCell'].values*rad_to_deg
    topo = meshfile['bottomDepth'].values
    area = meshfile['areaCell'].values
    zlevs            = meshfile['refBottomDepth'].values
    layerThickness   = meshfile['layerThickness'].values
    restingThickness = meshfile['restingThickness'].values

print(len(lon),'nodes in mesh')
print(topo.shape)
print(area.shape)
print('Min/Max lon:',np.min(lon),np.max(lon))
print('Min/Max lat:',np.min(lat),np.max(lat))
print('layerThickness.shape:',layerThickness.shape)
print('restingThickness.shape:',restingThickness.shape)

# index of the reference level whose bottom depth is closest to 2000 m;
# all levels from here downwards enter the "below 2000m" averages
dd1 = np.argmin(np.abs(zlevs-2000))
print('Average below depth level',dd1,zlevs[dd1])
# layer thicknesses: differences of bottom depths, with the surface (0 m) prepended
dz = np.diff(np.hstack((0,zlevs)))[dd1:]


236853 nodes in mesh
(236853,)
(236853,)
Min/Max lon: 0.0007300572350528742 359.997672445938
Min/Max lat: -78.53259417674468 89.94461290099375
layerThickness.shape: (1, 236853, 60)
restingThickness.shape: (236853, 60)
Average below depth level 45 2074.8740234375005


In [4]:
#----
# load daily float output
#----
# Only the levels from dd1 downwards (~2000 m and deeper) are kept for
# temperature, salinity and O2; lat/lon are taken from the first column level
# and converted from radians to degrees.

path = '/global/cfs/cdirs/m4003/maltrud/6year/floats/'
year_list = ['0055','0056','0057','0058','0059','0060']

lat_parts, lon_parts, temp_parts, salt_parts, oxy_parts = [], [], [], [], []
for yy in tqdm(range(0,len(year_list))):
    print('Load year '+year_list[yy])
    file1 = 'floats.year'+year_list[yy]+'.nc'
    # `with` releases the netCDF handle once this year's arrays are in memory
    with xr.open_dataset(path+file1) as data:
        lat_parts.append(data['particleColumnLat'][:,0,:].values*rad_to_deg)
        lon_parts.append(data['particleColumnLon'][:,0,:].values*rad_to_deg)
        temp_parts.append(data['particleColumnTemperature'][:,dd1:,:].values)
        salt_parts.append(data['particleColumnSalinity'][:,dd1:,:].values)
        oxy_parts.append(data['particleColumnO2'][:,dd1:,:].values)

# concatenate once along the first (time) axis instead of re-copying the
# growing arrays every year
lat_all  = np.concatenate(lat_parts)
lon_all  = np.concatenate(lon_parts)
temp_all = np.concatenate(temp_parts)
salt_all = np.concatenate(salt_parts)
oxy_all  = np.concatenate(oxy_parts)
del lat_parts, lon_parts, temp_parts, salt_parts, oxy_parts

# set missing values (-1) to NaN (deep ocean layers)
temp_all[salt_all==-1]=np.nan # use salt here just in case temp accidentally is exactly -1 somewhere
salt_all[salt_all==-1]=np.nan
oxy_all[oxy_all==-1]=np.nan

# make sure there are no spurious zeros in the fields
temp_all[salt_all==0]=np.nan # use salt
salt_all[salt_all==0]=np.nan
oxy_all[oxy_all==0]=np.nan

print ('done')


  0%|          | 0/6 [00:00<?, ?it/s]

Load year 0055


 17%|█▋        | 1/6 [01:10<05:52, 70.46s/it]

Load year 0056


 33%|███▎      | 2/6 [02:24<04:50, 72.62s/it]

Load year 0057


 50%|█████     | 3/6 [03:45<03:48, 76.22s/it]

Load year 0058


 67%|██████▋   | 4/6 [05:02<02:33, 76.79s/it]

Load year 0059


 83%|████████▎ | 5/6 [06:18<01:16, 76.51s/it]

Load year 0060


100%|██████████| 6/6 [07:35<00:00, 75.85s/it]


done


In [5]:
#---
# functions
#---

@njit
def get_avg_biomess_per_day(phyC_int_mean,num_profiles,days,phyC_int):
    """Average `phyC_int` over all floats for each day.

    phyC_int_mean : output array, one entry per day number 1..len(phyC_int_mean);
                    filled in place. Must be a float array (NaN is written for
                    days without profiles).
    num_profiles  : output array of the same length; number of profiles per day,
                    filled in place.
    days          : day number (1-based) of each profile; presumably aligned
                    element-wise with phyC_int -- confirm at call site.
    phyC_int      : per-profile value to average.

    Returns (phyC_int_mean, num_profiles).
    """
    for dd in range(1,phyC_int_mean.shape[0]+1):
        ind = np.where(days==dd)[0]
        num_profiles[dd-1] = ind.shape[0]
        if ind.shape[0] > 0:
            phyC_int_mean[dd-1] = np.mean(phyC_int[ind])
        else:
            # empty day: avoid the 0/0 inside np.mean (NaN in NumPy, may raise
            # under numba's default python error model)
            phyC_int_mean[dd-1] = np.nan
    return phyC_int_mean,num_profiles

@njit
def get_vertical_avg(data,dz,data_avg):
    """Thickness-weighted vertical mean of a property (e.g. average temperature below 2000m).

    data     : 3D array indexed as data[time, level, float]
    dz       : thickness of each level of `data`
    data_avg : pre-initialised 2D output indexed as data_avg[time, float];
               written in place
    Values that are not > -990 (the -999 sentinel; NaN also fails the test)
    are ignored. Columns without any valid level keep their initial value.
    Returns data_avg.
    """
    n_time = data.shape[0]
    n_float = data.shape[2]
    for t in range(n_time):
        for f in range(n_float):
            column = data[t,:,f]
            valid = np.where(column>-990)[0]
            if valid.shape[0] == 0:
                continue
            w = dz[valid]
            data_avg[t,f] = np.sum(column[valid]*w)/np.sum(w)
    return data_avg



In [6]:
#---
# reduce to 10-daily data
#---
# Basic slicing returns *views*: copy T/S/O2 so the -999 sentinels written below
# do not leak back into temp_all/salt_all/oxy_all.
temp10  = temp_all[::10,:,:].copy()
salt10  = salt_all[::10,:,:].copy()
oxy10   = oxy_all[::10,:,:].copy()
lat10   = lat_all[::10,:]
lon10   = lon_all[::10,:]
print (lat10.shape,temp10.shape)

##----
## get dz-weighted average below level dd1 (e.g. 2000m)
##----
dz = np.diff(np.hstack((0,zlevs)))[dd1:] # add a zero to vector; assume zlevs to give depth at the bottom

def _vertical_avg_10daily(field):
    """Replace NaN/0 in `field` (in place) by -999 and return its vertical average.

    The output starts as NaN, so columns without valid data (areas shallower than
    level dd1) stay NaN. The previous ones-initialisation plus `==1` masking also
    discarded genuine averages equal to 1.0.
    """
    field[np.isnan(field)] = -999 # sentinel recognised by get_vertical_avg
    field[field==0] = -999
    avg = np.full([field.shape[0],field.shape[2]], np.nan)
    return get_vertical_avg(field,dz,avg)

temp10_avg = _vertical_avg_10daily(temp10)
print('shape, min, max',temp10_avg.shape,np.nanmin(temp10_avg),np.nanmax(temp10_avg))

salt10_avg = _vertical_avg_10daily(salt10)
print('shape, min, max',salt10_avg.shape,np.nanmin(salt10_avg),np.nanmax(salt10_avg))

oxy10_avg = _vertical_avg_10daily(oxy10)
print('shape, min, max',oxy10_avg.shape,np.nanmin(oxy10_avg),np.nanmax(oxy10_avg))

print('done')


(219, 10560) (219, 15, 10560)
shape, min, max (219, 10560) -0.8824534171210164 20.780603408813477
shape, min, max (219, 10560) 34.0919189453125 40.02397537231445
shape, min, max (219, 10560) 5.040485382080078 301.62677001953125
done


In [9]:
#----
# load RECCAP biome mask on the model mesh
#----

path_mask = '/global/cfs/cdirs/m4003/cnissen/masks/'
file_mask = 'reccap_mask_regions_e3sm_mesh_EC30to60E2r2_wSubregions.nc'

# read the mask and release the file handle right away
with xr.open_dataset(path_mask+file_mask) as ff:
    mask_global = ff['mask_e3sm_all_regions'].values.squeeze()

print('Min/Max mask_e3sm_all_regions:',np.min(mask_global),np.max(mask_global)) 

# mask should have one entry per mesh cell
print(mask_global.shape,lat.shape,lon.shape)


Min/Max mask_e3sm_all_regions: 0.0 27.0
(236853,) (236853,) (236853,)


In [10]:
#---
# load file in which profiles are colocated with biomes
#---

path_biome = '/global/cfs/cdirs/m4003/cnissen/6year_run/'
year_list = ['0055','0056','0057','0058','0059','0060']

mask_parts = []
for yy in range(0,len(year_list)):
    print('Load year '+year_list[yy])

    file_biome = 'Float_positions_colocated_with_biomes_year'+year_list[yy]+'_v2.nc' # v2 (July 2023) has a bug fix for colocation
    # `with` closes each yearly file (previously left open)
    with xr.open_dataset(path_biome+file_biome) as ff:
        mask_parts.append(ff['biome_index'].values.squeeze())

# single concatenation along time
mask_biomes_all = np.concatenate(mask_parts)
del mask_parts

print(mask_biomes_all.shape)

print('Min/Max index biomes:',mask_biomes_all.shape,np.min(mask_biomes_all),np.max(mask_biomes_all))

#"1.NA SPSS, 2.NA STSS, 3.NA STPS, 4.AEQU, 5.SA STPS, 6.MED (not in FM14)" ;
#"7.IND STPS, 8.(not in FM14)" ;
#"9.NP SPSS, 10.NP STSS, 11.NP STPS, 12.PEQU-W, 13.PEQU-E, 14.SP STPS" ;
#"15.ARCTIC ICE (not in FM14), 16.NP ICE, 17.NA ICE, 18.Barents (not in FM14)" ;
#"19. STSS_Atl, 20. SPSS_Atl, 21. ICE_Atl, 22. STSS_Ind, 23. SPSS_Ind, 
# 24. ICE_Ind, 25. STSS_Pac, 26. SPSS_Pac, 27. ICE_Pac"

# label for biome index i is subareas[i-1]; index 9 is the North Pacific SPSS
# (was a duplicate 'NA_SPSS'), index 15 the Arctic ice biome (typo fixed)
subareas = ['NA_SPSS','NA_STSS','NA_STPS','AEQU','SA_STPS','MED',\
           'IND_STPS','xx',\
           'NP_SPSS','NP_STSS','NP_STPS','PEQU-W','PEQU-E','SP_STPS',\
           'ARCTIC_ICE','NP_ICE','NA_ICE','Barents',\
           'STSS_Atl','SPSS_ATL','ICE_ATL','STSS_IND','SPSS_IND',\
            'ICE_IND','STSS_PAC','SPSS_PAC','ICE_PAC']
print(len(subareas))   
    

Load year 0055
Load year 0056
Load year 0057
Load year 0058
Load year 0059
Load year 0060
(2183, 10560)
Min/Max index biomes: (2183, 10560) 0.0 27.0
27


In [11]:
#---
# get average within each subarea (legacy code, disabled via old_code)
#---
old_code = False
if old_code:
    subregions = ['global','ICE_south','SPSS_south','STSS_south','STPS_south',\
                  'Equator',
                  'STPS_north','STSS_north','SPSS_north','NA_SPSS']

    save_plots = False
    display_plots = True

    # TODO: DOUBLE-CHECK THE ORDER -> I had indices 4,12,14 for equator, but I think it should be 4,12,13
    # biome indices (see legend in the colocation cell) that make up each region;
    # None means "every profile with a biome index > 0"
    region_biomes = {
        'global':     None,
        'NA_STPS':    [3],
        'NA_SPSS':    [1],
        'ICE_south':  [21,24,27],
        'SPSS_south': [20,23,26],
        'STSS_south': [19,22,25],
        'STPS_south': [5,7,14],
        'Equator':    [4,12,13],
        'STPS_north': [3,11],
        'ICE_north':  [16,17], # without Arctic ice!!
        'SPSS_north': [1,9],
        'STSS_north': [2,10],
    }

    # same 10-daily subsampling as temp10_avg; hoisted out of the loops
    mask_biomes_10 = mask_biomes_all[::10,:]

    data_sub = np.zeros([temp10_avg.shape[0],len(subregions)])
    for ss, which_region in enumerate(subregions):
        print(which_region)
        # KeyError for an unknown region instead of silently reusing a stale `ind`
        biomes = region_biomes[which_region]

        for tt in range(0,temp10_avg.shape[0]): # loop over time (10-daily)
            mask_biomes = mask_biomes_10[tt,:].ravel()
            if biomes is None:
                ind = np.where(mask_biomes>0)[0]
            else:
                ind = np.where(np.isin(mask_biomes,biomes))[0]

            # average below2000m-temperature of all profiles in current biome
            data_sub[tt,ss] = np.nanmean(temp10_avg[tt,:].ravel()[ind])

    print('done')


In [12]:
#-----
# load full model output (monthly Eulerian averages)
#-----
# For each year: dz-weighted mean of T, S and O2 over the levels from dd1
# downwards; accumulated into a multi-year monthly mean (dataT_full, dataS_full,
# dataO_full) and reduced to area-weighted regional means/stds per month.

path1 = '/global/cfs/cdirs/m4003/maltrud/6year/monthlyEulerianAverages/'
year_list = ['0055','0056','0057','0058','0059','0060']

subregions = ['south_of_60S','3060N_Atl','south_of_60S_WS']


def _region_indices(which_region, ss):
    """Mesh-cell indices of `which_region`; `ss` is its position in `subregions`."""
    if which_region in ['south_of_60S']:
        return np.where(lat<=-60)[0]
    if which_region in ['3060N_Atl']:
        return np.where((lat>30) & (lat<=60) & (lon>300) & (lon<=350))[0]
    if which_region in ['south_of_60S_WS']:
        return np.where((lat<=-60) & (lon>300) & (lon<=350))[0]
    # fallback: 1-based biome index of the RECCAP mask
    return np.where(mask_global.ravel()==ss+1)[0]


def _load_below_dd1(ds, varname):
    """Read `varname` from level dd1 downwards; fill values (< -999) and zeros -> NaN."""
    field = ds[varname][:,:,:,dd1:].values.squeeze()
    field[field<-999] = np.nan
    field[field==0]   = np.nan
    return field


# region membership depends only on the mesh -> compute once
ind_reg_list = [_region_indices(reg, ss) for ss, reg in enumerate(subregions)]

for yy in tqdm(range(0,len(year_list))):
    print('Load year',year_list[yy]) 

    file1 = 'monthlyAverageEulerianFields.year'+year_list[yy]+'.nc'
    # read all values while the file is still open
    with xr.open_dataset(path1+file1) as ff:
        data  = _load_below_dd1(ff,'timeMonthly_avg_activeTracers_temperature')
        dataS = _load_below_dd1(ff,'timeMonthly_avg_activeTracers_salinity')
        dataO = _load_below_dd1(ff,'timeMonthly_avg_ecosysTracers_O2')

    # after squeeze presumably (month, nCells, level) -- the dz tiling relies on it
    n_months, n_cells = data.shape[0], data.shape[1]

    # 3D dz field (sizes taken from the data instead of hard-coded 12 / 236853),
    # masked where temperature is missing
    dz_3d = np.tile(dz,[n_months,n_cells,1])
    dz_3d[np.isnan(data)] = np.nan

    # avg below dd1; columns without any valid level give 0/0 -> NaN (warning silenced)
    with np.errstate(invalid='ignore', divide='ignore'):
        dz_sum = np.nansum(dz_3d,axis=2)
        data  = np.nansum(data*dz_3d,axis=2)/dz_sum
        dataS = np.nansum(dataS*dz_3d,axis=2)/dz_sum
        dataO = np.nansum(dataO*dz_3d,axis=2)/dz_sum

    # keep full model output (to plot maps); normalised by number of years below
    if yy==0:
        dataT_full = data
        dataS_full = dataS
        dataO_full = dataO
    else:
        dataT_full = dataT_full+data
        dataS_full = dataS_full+dataS
        dataO_full = dataO_full+dataO

    # area-weighted regional averages and standard deviations per month
    n_reg = len(subregions)
    data_std  = np.zeros([n_months,n_reg])
    data_stdS = np.zeros([n_months,n_reg])
    data_stdO = np.zeros([n_months,n_reg])
    data_avg  = np.zeros([n_months,n_reg])
    data_avgS = np.zeros([n_months,n_reg])
    data_avgO = np.zeros([n_months,n_reg])
    for mm in range(0,n_months):
        aux  = data[mm,:]
        auxS = dataS[mm,:]
        auxO = dataO[mm,:]
        # area file with the same NaNs as the (temperature) data; region-independent
        area2 = np.copy(area)
        area2[np.isnan(aux)] = np.nan
        for ss, ind_reg in enumerate(ind_reg_list):
            weights = area2[ind_reg]/np.nansum(area2[ind_reg])
            ind_not_NaN = np.where(~np.isnan(aux[ind_reg]))[0]
            w_valid = weights[ind_not_NaN]

            data_std[mm,ss]  = DescrStatsW(aux[ind_reg][ind_not_NaN],weights=w_valid).std 
            data_stdS[mm,ss] = DescrStatsW(auxS[ind_reg][ind_not_NaN],weights=w_valid).std 
            data_stdO[mm,ss] = DescrStatsW(auxO[ind_reg][ind_not_NaN],weights=w_valid).std 

            data_avg[mm,ss]  = np.nansum(aux[ind_reg]*weights)
            data_avgS[mm,ss] = np.nansum(auxS[ind_reg]*weights)
            data_avgO[mm,ss] = np.nansum(auxO[ind_reg]*weights)
        del aux,auxS,auxO

    if yy==0:
        data_reg_all_std = data_std
        data_reg_allS_std = data_stdS
        data_reg_allO_std = data_stdO

        data_reg_all = data_avg
        data_reg_allS = data_avgS
        data_reg_allO = data_avgO
    else:
        data_reg_all_std = np.concatenate((data_reg_all_std,data_std))
        data_reg_allS_std = np.concatenate((data_reg_allS_std,data_stdS))
        data_reg_allO_std = np.concatenate((data_reg_allO_std,data_stdO))

        data_reg_all = np.concatenate((data_reg_all,data_avg))
        data_reg_allS = np.concatenate((data_reg_allS,data_avgS))
        data_reg_allO = np.concatenate((data_reg_allO,data_avgO))

    del data_avg,data,dz_3d,data_avgS,dataS,data_avgO,dataO

# normalize full output by number of years
dataT_full = np.divide(dataT_full,len(year_list))
dataS_full = np.divide(dataS_full,len(year_list))
dataO_full = np.divide(dataO_full,len(year_list))
print('Min/Max temp:',np.nanmin(dataT_full),np.nanmax(dataT_full))
print('Min/Max salinity:',np.nanmin(dataS_full),np.nanmax(dataS_full))
print('Min/Max oxygen:',np.nanmin(dataO_full),np.nanmax(dataO_full))

print('done')


  0%|          | 0/6 [00:00<?, ?it/s]

Load year 0055


  data = np.nansum(data*dz_3d,axis=2)/np.nansum(dz_3d,axis=2)
  dataS = np.nansum(dataS*dz_3d,axis=2)/np.nansum(dz_3d,axis=2)
  dataO = np.nansum(dataO*dz_3d,axis=2)/np.nansum(dz_3d,axis=2)
 17%|█▋        | 1/6 [00:14<01:10, 14.18s/it]

Load year 0056


 33%|███▎      | 2/6 [00:26<00:52, 13.15s/it]

Load year 0057


 50%|█████     | 3/6 [00:38<00:37, 12.45s/it]

Load year 0058


 67%|██████▋   | 4/6 [00:50<00:25, 12.53s/it]

Load year 0059


 83%|████████▎ | 5/6 [01:02<00:12, 12.14s/it]

Load year 0060


100%|██████████| 6/6 [01:14<00:00, 12.47s/it]

Min/Max temp: -0.8807149333742528 20.815802256266277
Min/Max salinity: 34.091909408569336 40.03798294067383
Min/Max oxygen: 5.04816198348999 301.51340738932294
done





In [None]:
#---
# plot maps of Eulerian output
#---
# explicit names instead of `from mpasview import *`, so the origin of
# MPASMesh/MPASOMap is visible and nothing else is silently shadowed
from mpasview import MPASMesh, MPASOMap # Qing's library, modified by Yohei Takano, Dec, 2022
import copy
# private copy of the colormap so set_under() does not alter the registered one;
# missing cells are set to -999 (below levels[0]) and are meant to show white
cmap1 = copy.copy(plt.cm.RdYlBu_r)
cmap1.set_under('w')

# Restart files
meshroot = '/global/cfs/cdirs/m4003/maltrud/' # restart file
meshfile = meshroot+'ocean.EC30to60E2r2.210210.nc' # EC30to60E2r2

# MPAS-O input file (model output)
inputroot = '/global/cfs/cdirs/m4003/maltrud/6year/monthlyEulerianAverages/'
year_list = ['0055','0056','0057','0058','0059','0060']

# MPAS-O mesh for EC30to60E2r2
mpasmesh = MPASMesh(name = 'EC30to60E2r2', filepath = meshfile)
print(mpasmesh)

save_plots = True
dpicnt = 200
fs = 12
fig_width, fig_height = 18, 7 # figsize is (width, height)

#---
# TEMP
#---
res = 0.25
levels = np.arange(-1, 4.5+res, res)
cticks = [-1,0,1,2,3,4]

data_plot = np.mean(np.copy(dataT_full),axis=0) # annual mean
data_plot[data_plot==0]        = -999
data_plot[np.isnan(data_plot)] = -999

mpasomap_run1 = MPASOMap(data = data_plot, name = 'avg. T below 2000m', units = 'deg C', mesh = mpasmesh)

plt.figure(figsize=(fig_width,fig_height))
m = mpasomap_run1.plot(region = 'Global', levels = levels, cmap = cmap1, ptype = 'contourf',colorbar=False) # pcolor, contourf
if save_plots:
    filename = 'Temperature_below_2000m_avg_2012_2017.png'
    print(savepath+filename)
    plt.savefig(savepath+filename,dpi = dpicnt, bbox_inches='tight')
    del filename

#-----
# COLORBAR: plot separately -- PNG with label, EPS with blank tick labels;
# a dummy contourf on a hidden map provides the mappable
#-----
print('Separate COLORBAR...')

lon_reg, lat_reg = np.meshgrid(np.arange(-180,180,1), np.arange(-90,90,1))
dummy_field = np.zeros_like(lon_reg)

for cb_file, save_kwargs, with_text in [
        ('COLORBAR_Temperature_below_2000m_avg_2012_2017.png', {}, True),
        ('COLORBAR_Temperature_below_2000m_avg_2012_2017.eps', {'format':'eps'}, False)]:
    fig = plt.figure(figsize=(fig_width,fig_height))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=-150))
    ax.set_extent([-180, 180, -90, 90])
    mm1 = ax.contourf(lon_reg, lat_reg, dummy_field,
                      levels=levels,extend='both',cmap=cmap1,transform=ccrs.PlateCarree())
    cbar = plt.colorbar(mm1,ax=ax,orientation='vertical',fraction=0.075, pad=0.02,shrink=0.9,ticks=cticks)
    if with_text:
        cbar.set_label('avg. T below 2000m',fontsize=fs-2)
    cbar.ax.tick_params(labelsize=fs-3)
    if not with_text:
        cbar.ax.set_yticklabels(['']*len(cticks))
    fig.gca().set_visible(False) # intended to hide the dummy map and keep only the colorbar
    if save_plots:
        print(savepath+cb_file)
        plt.savefig(savepath+cb_file,dpi = dpicnt, bbox_inches='tight', **save_kwargs)
    plt.show()



In [None]:
#---
# OXYGEN: annual-mean map of avg. O2 below 2000m plus stand-alone colorbars
#---
res = 10
levels = np.arange(0, 300+res, res)
cticks = [0,50,100,150,200,250,300]

# annual mean of the monthly fields; zero or NaN cells -> -999 (below levels[0])
data_plot = np.copy(dataO_full).mean(axis=0)
data_plot[(data_plot==0) | np.isnan(data_plot)] = -999

mpasomap_run1 = MPASOMap(data = data_plot, name = 'avg. O2 below 2000m', units = 'mmol m-3', mesh = mpasmesh)

plt.figure(figsize=(18,7))
m = mpasomap_run1.plot(region = 'Global', levels = levels, cmap = cmap1, ptype = 'contourf',colorbar=False) # pcolor, contourf
if save_plots:
    map_name = 'Oxygen_below_2000m_avg_2012_2017.png'
    print(savepath+map_name)
    plt.savefig(savepath+map_name,dpi = dpicnt, bbox_inches='tight')

#-----
# COLORBAR: plot separately (labelled PNG + EPS with blank tick labels)
#-----
print('Separate COLORBAR...')

fs = 12
cb_figsize = (18, 7)
lon_grid, lat_grid = np.meshgrid(np.arange(-180,180,1), np.arange(-90,90,1))
zero_field = np.zeros_like(lon_grid)

colorbar_variants = (
    {'file': 'COLORBAR_Oxygen_below_2000m_avg_2012_2017.png', 'eps': False},
    {'file': 'COLORBAR_Oxygen_below_2000m_avg_2012_2017.eps', 'eps': True},
)
for variant in colorbar_variants:
    fig = plt.figure(figsize=cb_figsize)
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=-150))
    ax.set_extent([-180, 180, -90, 90])
    mm1 = ax.contourf(lon_grid, lat_grid, zero_field,
                      levels=levels,extend='both',cmap=cmap1,transform=ccrs.PlateCarree())
    cbar = plt.colorbar(mm1,ax=ax,orientation='vertical',fraction=0.075, pad=0.02,shrink=0.9,ticks=cticks)
    if variant['eps']:
        cbar.ax.tick_params(labelsize=fs-3)
        cbar.ax.set_yticklabels(['']*len(cticks)) # unlabelled version for final layout
    else:
        cbar.set_label('avg. O2 below 2000m',fontsize=fs-2)
        cbar.ax.tick_params(labelsize=fs-3)
    fig.gca().set_visible(False)
    if save_plots:
        print(savepath+variant['file'])
        if variant['eps']:
            plt.savefig(savepath+variant['file'],dpi = dpicnt, bbox_inches='tight',format='eps')
        else:
            plt.savefig(savepath+variant['file'],dpi = dpicnt, bbox_inches='tight')
    plt.show()



In [None]:
#---
# SALINITY
#---
res = 0.025
levels = np.arange(34.5,35+res, res)
cticks = [34.5,34.65,34.8,34.95]

data_plot = np.mean(np.copy(dataS_full),axis=0) # annual mean
# Mask missing cells FIRST. The old order clamped everything < levels[0]
# (including zeros) to levels[0]+0.0001 before the `==0` test, so zero cells were
# coloured instead of becoming -999 (white via set_under).
missing = (data_plot==0) | np.isnan(data_plot)
data_plot[missing] = -999
# clamp genuinely fresh cells to the lowest level: lowest colour instead of white
data_plot[~missing & (data_plot<levels[0])] = levels[0]+0.0001

mpasomap_run1 = MPASOMap(data = data_plot, name = 'avg. S below 2000m', units = '', mesh = mpasmesh)

plt.figure(figsize=(18,7))
m = mpasomap_run1.plot(region = 'Global', levels = levels, cmap = cmap1, ptype = 'contourf',colorbar=False) # pcolor, contourf
if save_plots:
    filename = 'Salinity_below_2000m_avg_2012_2017.png'
    print(savepath+filename)
    plt.savefig(savepath+filename,dpi = dpicnt, bbox_inches='tight')
    del filename

#-----
# COLORBAR: plot separately -- PNG with label, EPS with blank tick labels
#-----
print('Separate COLORBAR...')

fs = 12
fig_width, fig_height = 18, 7 # figsize is (width, height)
lon_reg, lat_reg = np.meshgrid(np.arange(-180,180,1), np.arange(-90,90,1))
dummy_field = np.zeros_like(lon_reg)

for cb_file, save_kwargs, with_text in [
        ('COLORBAR_Salinity_below_2000m_avg_2012_2017.png', {}, True),
        ('COLORBAR_Salinity_below_2000m_avg_2012_2017.eps', {'format':'eps'}, False)]:
    fig = plt.figure(figsize=(fig_width,fig_height))
    ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=-150))
    ax.set_extent([-180, 180, -90, 90])
    mm1 = ax.contourf(lon_reg, lat_reg, dummy_field,
                      levels=levels,extend='both',cmap=cmap1,transform=ccrs.PlateCarree())
    cbar = plt.colorbar(mm1,ax=ax,orientation='vertical',fraction=0.075, pad=0.02,shrink=0.9,ticks=cticks)
    if with_text:
        cbar.set_label('avg. S below 2000m',fontsize=fs-2)
    cbar.ax.tick_params(labelsize=fs-3)
    if not with_text:
        cbar.ax.set_yticklabels(['']*len(cticks))
    fig.gca().set_visible(False) # intended to hide the dummy map and keep only the colorbar
    if save_plots:
        print(savepath+cb_file)
        plt.savefig(savepath+cb_file,dpi = dpicnt, bbox_inches='tight', **save_kwargs)
    plt.show()



In [1]:
                
plot_time_series = False
if plot_time_series:
    #----
    # plot 6-year time series of the below-2000m average:
    # all floats (10-daily) vs. Eulerian regional mean (monthly)
    #----
    color1 = 'darkblue'
    color2 = 'firebrick'
    lw = 2
    fs = 12

    display_plots = True
    save_plots    = False

    data_plot = np.copy(oxy10_avg) # temp10_avg, oxy10_avg, salt10_avg
    data_plot_euler = np.copy(data_reg_allO) # data_reg_all, data_reg_allO, data_reg_allS
    vari = 'Oxy' # Temp, Oxy, Salt (for filename only)

    # reference amplitudes for the printed ratios
    # NOTE(review): origin/units of 1.4 and 3.62 are not documented here -- confirm
    ref_amplitude = {'south_of_60S': 1.4, '3060N_Atl': 3.62}

    def _float_region_mask(region, tt):
        """Boolean mask of floats inside `region` at 10-daily step tt.

        Same boundaries as the Eulerian regions (lat<=-60, 30<lat<=60,
        300<lon<=350), so both curves average over the same area.
        """
        la, lo = lat10[tt,:], lon10[tt,:]
        if region == 'south_of_60S':
            return la<=-60
        if region == '3060N_Atl':
            return (la>30) & (la<=60) & (lo>300) & (lo<=350)
        if region == 'south_of_60S_WS':
            return (la<=-60) & (lo>300) & (lo<=350)
        # an unmatched region used to fall through with a stale/undefined `ind`
        raise ValueError('no float-region definition for '+region)

    num_time1 = data_plot.shape[0] # number of 10-daily time steps
    n_years = data_plot_euler.shape[0]//12

    for ss in range(0,len(subregions)):
        print(subregions[ss])

        data_sub = np.array([np.nanmean(data_plot[tt,:].ravel()[_float_region_mask(subregions[ss],tt)])
                             for tt in range(num_time1)])
        euler = data_plot_euler[:,ss]

        fig = plt.figure(figsize=(10,4))
        plt.plot(np.arange(0,num_time1,1),data_sub,color=color1,linewidth=lw,label='all floats, 10-daily')
        # 3 x 10 days ~ 1 month: monthly Eulerian values placed on the 10-daily axis
        plt.plot(np.arange(0,num_time1,3)[:-1],euler,color=color2,linewidth=lw,label='Eulerian, monthly')

        # annual means of the monthly Eulerian series
        annual_ = euler[:12*n_years].reshape(n_years,12).mean(axis=1)
        euler_std = np.std(euler)
        annual_range = np.max(annual_)-np.min(annual_)
        monthly_range = np.max(euler)-np.min(euler)
        ref = ref_amplitude.get(subregions[ss])
        if ref is not None:
            print('Eulerian mean, std, max-min:',np.mean(euler),euler_std,annual_range,\
                  monthly_range,ref/euler_std,ref/annual_range,ref/monthly_range)
        else:
            print('Eulerian mean, std:',np.mean(euler),euler_std)

        plt.xticks(np.arange(0,num_time1,365/10),[2012,2013,2014,2015,2016,2017],fontsize=fs)
        plt.ylabel('avg. below 2000 m')

        plt.legend(frameon=False)
        plt.annotate(subregions[ss],xy=(0.01,1.025),\
                    xycoords='axes fraction',fontsize=fs,fontweight='bold',ha='left',color='k')
        dpicnt = 150
        if save_plots:
            # filename spelling ('Elerian') kept so existing outputs keep their names
            filename = vari+'_below_2000m_2012_2017_Elerian_vs_all_floats_'+subregions[ss]+'.png'
            plt.savefig(savepath+filename,dpi = dpicnt, bbox_inches='tight')
            del filename
        if display_plots:
            plt.show()
        else:
            plt.close(fig)


In [28]:


#---
# range (max - min) of the annual means of the Eulerian regional averages
#---
# Defined here so the cell also runs on a fresh kernel; it previously relied on
# `data_plot_euler` left over from the (disabled) time-series cell.
# NOTE(review): confirm data_reg_allO is the variable the saved output was produced with.
data_plot_euler = np.copy(data_reg_allO) # data_reg_all, data_reg_allO, data_reg_allS

for kk, ss in enumerate([1, 0]):
    if kk > 0:
        print('')
    print(subregions[ss])

    n_years = data_plot_euler.shape[0]//12
    annual_ = data_plot_euler[:12*n_years,ss].reshape(n_years,12).mean(axis=1)

    print(np.max(annual_)-np.min(annual_))


3060N_Atl
0.5043027083483196

south_of_60S
0.19029926317313084


In [None]:
# NOTE(review): `data_sub` is only created inside the disabled blocks (old_code and
# plot_time_series are both False), so on a fresh kernel this raises NameError.
print(data_sub)

In [None]:
#----
# functions to speed things up
#----

# Biome names in the order of their integer codes in the region mask (code = position + 1).
_BIOME_REGION_CODES = {
    name: code for code, name in enumerate(
        ['SPSS_NA', 'STSS_NA', 'STPS_NA', 'AEQU', 'STPS_SA', 'MED',
         'STPS_IND', 'xx', 'SPSS_NP', 'STSS_NP', 'STPS_NP', 'PEQU-W', 'PEQU-E', 'STPS_SP',
         'ICE_ARCTIC', 'ICE_NP', 'ICE_NA', 'Barents',
         'STSS_Atl', 'SPSS_Atl', 'ICE_Atl', 'STSS_Ind', 'SPSS_Ind',
         'ICE_Ind', 'STSS_Pac', 'SPSS_Pac', 'ICE_Pac'],
        start=1)
}


def get_profiles_in_region_njit(data_mask,which_region):
    """Return the flat indices of ``data_mask`` entries that belong to a biome region.

    Parameters
    ----------
    data_mask : array-like of region codes (1-D or 2-D; it is flattened first)
    which_region : biome name, one of the keys of ``_BIOME_REGION_CODES``
        (e.g. 'SPSS_NA' -> code 1, ..., 'ICE_Pac' -> code 27)

    Returns
    -------
    1-D integer array of positions in the flattened mask equal to the region's code.

    Raises
    ------
    ValueError
        If ``which_region`` is not a known biome name (previously this surfaced as
        an UnboundLocalError).

    Note: despite its name this is plain Python (the @njit decorator was disabled);
    the dict lookup would not compile under numba.
    """
    try:
        code = _BIOME_REGION_CODES[which_region]
    except KeyError:
        raise ValueError('unknown biome region: '+str(which_region)) from None
    return np.where(np.ravel(data_mask) == code)[0]

def get_profiles_in_region_njit_2(lat,lon,which_region):
    """Return the indices of positions that lie inside a lat/lon-box region.

    Parameters
    ----------
    lat, lon : 1-D array-likes of positions in degrees (lon in 0-360).
        -999 is used by the callers as a fill value for missing positions.
    which_region : 'south_of_60S' (lat <= -60) or
        '3060N_Atl' (30 < lat <= 60 and 300 < lon <= 350)

    Returns
    -------
    1-D integer array of indices into ``lat``/``lon``.

    Raises
    ------
    ValueError
        For any other region name (previously an UnboundLocalError).

    The '3060N_Atl' box now includes the ``lon <= 350`` bound used everywhere else in
    this notebook (area weights, full-float averages, float selection). Fill-value
    latitudes (< -90, i.e. -999) no longer count as south of 60S.
    """
    lat = np.asarray(lat)
    lon = np.asarray(lon)
    if which_region == 'south_of_60S':
        # the lower bound only excludes the -999 fill value, real latitudes are >= -90
        in_region = (lat >= -90) & (lat <= -60)
    elif which_region == '3060N_Atl':
        in_region = (lat > 30) & (lat <= 60) & (lon > 300) & (lon <= 350)
    else:
        raise ValueError('unknown lat/lon-box region: '+str(which_region))
    return np.where(in_region)[0]

def get_avg_at_each_time_with_weights(a1,a2,a3,a4,num_time,weights,vector_num_region,\
                                      a5,a6,which_region,\
                                      temp_sub_floats,salt_sub_floats,oxy_sub_floats):
    """Add area-weighted regional means of temp/salt/oxy to the accumulators, per time step.

    Row ``tt`` of every input array is one time step; columns are floats.
    a1: region mask, a2/a3/a4: temperature/salinity/oxygen (-999 = missing),
    a5/a6: latitude/longitude. For each sub-subregion ``k`` (the floats whose
    ``vector_num_region`` equals ``k + 1``) the floats inside ``which_region`` at that
    time are selected, and the mean of their valid (> -999) values times ``weights[k]``
    is added to the accumulators. A sub-subregion whose weighted temperature mean is
    NaN contributes nothing at that time.

    The accumulators (temp_sub_floats, salt_sub_floats, oxy_sub_floats) are updated
    in place and returned.
    """
    use_box_selection = which_region in ['south_of_60S','3060N_Atl']
    for tt in range(num_time):
        mask_row = a1[tt, :]
        temp_row, salt_row, oxy_row = a2[tt, :], a3[tt, :], a4[tt, :]
        lat_row, lon_row = a5[tt, :], a6[tt, :]

        for k in range(weights.shape[0]):
            members = np.where(vector_num_region == k + 1)[0]
            # floats may drift out of the region, so re-select at every time step
            if use_box_selection:
                picked = get_profiles_in_region_njit_2(lat_row[members], lon_row[members], which_region)
            else:
                picked = get_profiles_in_region_njit(mask_row[members], which_region)

            temps = temp_row[members][picked]
            weighted_temp = np.nanmean(temps[temps > -999]) * weights[k]
            if np.isnan(weighted_temp):
                # no valid value (all -999): the mean is NaN, skip this sub-subregion
                continue
            salts = salt_row[members][picked]
            oxys = oxy_row[members][picked]
            temp_sub_floats[tt] = temp_sub_floats[tt] + weighted_temp
            salt_sub_floats[tt] = salt_sub_floats[tt] + np.nanmean(salts[salts > -999]) * weights[k]
            oxy_sub_floats[tt] = oxy_sub_floats[tt] + np.nanmean(oxys[oxys > -999]) * weights[k]
    return temp_sub_floats, salt_sub_floats, oxy_sub_floats


#@njit #(error_model="numpy")
def get_avg_at_each_time(a1,a2,a3,a4,num_time,a5,a6,which_region,temp_sub_floats,salt_sub_floats,oxy_sub_floats):
    """Store the unweighted regional mean of temp/salt/oxy at every time step.

    Row ``tt`` of each input array is one time step, columns are floats
    (a1: region mask, a2/a3/a4: temperature/salinity/oxygen with -999 = missing,
    a5/a6: latitude/longitude). At each time the floats inside ``which_region`` are
    selected and the plain mean of their values > -999 is written to entry ``tt`` of
    the corresponding output array. If no valid value exists the mean is NaN.

    The output arrays are overwritten in place and returned.
    """
    field_targets = ((a2, temp_sub_floats), (a3, salt_sub_floats), (a4, oxy_sub_floats))
    for tt in range(num_time):
        if which_region in ['south_of_60S','3060N_Atl']:
            selected = get_profiles_in_region_njit_2(a5[tt, :], a6[tt, :], which_region)
        else:
            selected = get_profiles_in_region_njit(a1[tt, :], which_region)
        for field, target in field_targets:
            values = field[tt, :][selected]
            target[tt] = np.mean(values[values > -999])
    return temp_sub_floats, salt_sub_floats, oxy_sub_floats



def get_avg_at_each_time_with_weights_wFreq(a1,a2,a3,a4,num_time,freq,weights,vector_num_region,\
                                      a5,a6,which_region,\
                                      temp_sub_floats,salt_sub_floats,oxy_sub_floats):
    """Accumulate area-weighted, quasi-monthly regional means of temp/salt/oxy.

    Rows of the input arrays are samples taken every ``freq`` days; columns are floats.
    Samples are grouped into ~30-day windows of ``30 // freq`` rows (1 row for
    freq=30, 3 rows for freq=10). Output entry ``tt`` uses window ``tt``. Properties
    are averaged over the window, while mask/lat/lon come from the window's first row.

    Parameters
    ----------
    a1 : region-mask codes per (sample, float)
    a2, a3, a4 : temperature, salinity, oxygen per (sample, float); -999 marks missing data
    num_time : number of output windows to compute
    freq : sampling interval in days; must be a positive divisor of 30
    weights : 1-D area weight of each sub-subregion
    vector_num_region : sub-subregion label (1, 2, ...) of each float column
    a5, a6 : latitude / longitude per (sample, float)
    which_region : region name used to select floats at each time
    temp_sub_floats, salt_sub_floats, oxy_sub_floats : 1-D accumulators; updated in
        place and returned

    Returns
    -------
    (temp_sub_floats, salt_sub_floats, oxy_sub_floats). An entry keeps its initial
    value when no sub-subregion had a valid temperature mean in that window.

    Raises
    ------
    ValueError
        If ``freq`` is not a positive divisor of 30 (previously a NameError for any
        value other than 10 or 30).

    Unlike earlier versions, the input arrays are no longer modified (the freq=10
    path used to write NaN into a2/a3/a4 through slice views).
    """
    if freq <= 0 or 30 % freq != 0:
        raise ValueError('freq must be a positive divisor of 30 (e.g. 10 or 30), got '+str(freq))
    samples_per_window = int(30 // freq)

    def _window_mean(field, start):
        # Mean over one window. -999 fill values (e.g. a float over shallower topography)
        # become NaN first so they cannot distort the mean; a window containing any
        # missing sample then averages to NaN and is dropped by the > -999 filter below.
        window = field[start:start + samples_per_window, :]
        window = np.where(window < -99, np.nan, window)
        return np.mean(window, axis=0)

    use_box_selection = which_region in ['south_of_60S','3060N_Atl']
    # the float columns of each sub-subregion do not change in time -> compute once
    members_per_region = [np.where(vector_num_region == nn + 1)[0] for nn in range(weights.shape[0])]

    for tt in range(num_time):
        start = tt * samples_per_window
        mask_t = a1[start, :]
        lat_t = a5[start, :]
        lon_t = a6[start, :]
        if samples_per_window == 1:
            temp_t, salt_t, oxy_t = a2[start, :], a3[start, :], a4[start, :]
        else:
            temp_t = _window_mean(a2, start)
            salt_t = _window_mean(a3, start)
            oxy_t = _window_mean(a4, start)

        for nn, members in enumerate(members_per_region):
            # choose all profiles in the region (some floats might move out of it at some point)
            if use_box_selection:
                in_region = get_profiles_in_region_njit_2(lat_t[members], lon_t[members], which_region)
            else:
                in_region = get_profiles_in_region_njit(mask_t[members], which_region)

            temps = temp_t[members][in_region]
            temp_mean = np.nanmean(temps[temps > -999])
            # if all values are missing the mean is NaN -> this sub-subregion adds nothing
            if np.isnan(temp_mean * weights[nn]):
                continue
            salts = salt_t[members][in_region]
            oxys = oxy_t[members][in_region]
            temp_sub_floats[tt] = temp_sub_floats[tt] + temp_mean * weights[nn]
            salt_sub_floats[tt] = salt_sub_floats[tt] + np.nanmean(salts[salts > -999]) * weights[nn]
            oxy_sub_floats[tt] = oxy_sub_floats[tt] + np.nanmean(oxys[oxys > -999]) * weights[nn]
    return temp_sub_floats, salt_sub_floats, oxy_sub_floats

#----
# error statistics
# functions from here: https://gist.github.com/bshishov/5dc237f59f019b26145648e2124ca1c9
#----

def _error(actual: np.ndarray, predicted: np.ndarray):
    """Element-wise residual ``actual - predicted``."""
    residual = actual - predicted
    return residual

def mse(actual: np.ndarray, predicted: np.ndarray):
    """Mean squared error: mean of the squared residuals."""
    squared_residual = np.square(_error(actual, predicted))
    return np.mean(squared_residual)

def rmse(actual: np.ndarray, predicted: np.ndarray):
    """Root mean squared error: square root of ``mse``."""
    mean_squared = mse(actual, predicted)
    return np.sqrt(mean_squared)

def nrmse(actual: np.ndarray, predicted: np.ndarray):
    """RMSE normalized by the standard deviation of ``actual``.

    (An alternative normalization would be the range ``actual.max() - actual.min()``.)
    """
    return rmse(actual, predicted) / np.std(actual)

#-----
# Additional metrics, not part of the gist above:
def mae(actual: np.ndarray, predicted: np.ndarray):
    """Mean absolute error: mean of the absolute residuals."""
    absolute_residual = np.abs(_error(actual, predicted))
    return np.mean(absolute_residual)

def nmae(actual: np.ndarray, predicted: np.ndarray):
    """MAE normalized by the standard deviation of ``actual``."""
    return mae(actual, predicted) / np.std(actual)

        

In [None]:
# Only the two lat/lon-box regions are analysed in this version; the 27 biome
# regions of the region mask (SPSS_NA ... ICE_Pac) are not used here.
subregions = ['south_of_60S', '3060N_Atl']

print(len(subregions))

In [None]:
#---
# subsample floats to target float density
#---
# The random sampling is repeated many times, for each region and each variable.
#
# If too few floats are chosen, a significant fraction of the iterations can end up NaN
# -> is this because many of the chosen profiles then sit in areas that are too shallow?

# error metric: 'nrmse', 'rmse' or 'nmae'
which_error = 'nrmse'
print('Use the following error metric:', which_error)

# Biome codes of the region mask, for reference:
#   1 NA SPSS, 2 NA STSS, 3 NA STPS, 4 AEQU, 5 SA STPS, 6 MED (not in FM14),
#   7 IND STPS, 8 (not in FM14),
#   9 NP SPSS, 10 NP STSS, 11 NP STPS, 12 PEQU-W, 13 PEQU-E, 14 SP STPS,
#   15 ARCTIC ICE (not in FM14), 16 NP ICE, 17 NA ICE, 18 Barents (not in FM14),
#   19 STSS_Atl, 20 SPSS_Atl, 21 ICE_Atl, 22 STSS_Ind, 23 SPSS_Ind,
#   24 ICE_Ind, 25 STSS_Pac, 26 SPSS_Pac, 27 ICE_Pac
# Only the two lat/lon-box regions are analysed here.
subregions = ['south_of_60S', '3060N_Atl']

# plot styling
color1 = 'darkblue'
color2 = 'cornflowerblue'
color3 = 'firebrick'
lw = 2
fs = 12
    
def get_random_floats_1(floatID,num_floats_reduced1):
    """Draw ``num_floats_reduced1`` distinct IDs at random from ``floatID``.

    Duplicates in ``floatID`` are collapsed first; sampling is without replacement
    and uses NumPy's global RNG. The chosen IDs are returned in ascending order.
    """
    candidates = np.unique(floatID)
    chosen = np.random.choice(candidates, size=num_floats_reduced1, replace=False)
    return np.sort(chosen)

#freq = 30 # take profiles every nn days (e.g., 30)
freq = 10 # take profiles every nn days (e.g., 30)
print('Use the following sampling frequency: '+str(freq)+'-daily')
# only 10- and 30-daily sampling are handled by the averaging code below
if freq not in (10, 30):
    raise ValueError('freq must be 10 or 30 (days), got '+str(freq))

# Synthetic-float data every freq days.
# Later: if floats sample every 10th day, not all of them need to sample on day 1, 11, 21 etc.;
# some could sample on 3, 13, 23 etc. -- in reality not all floats are deployed on the same day.
# The 3-D fields are copied because the code below replaces NaN/0 by -999 in place, which on a
# plain slice (a view) would also overwrite temp_all/salt_all/oxy_all and break re-runs of this cell.
temp_nn  = temp_all[::freq,:,:].copy()
salt_nn  = salt_all[::freq,:,:].copy()
oxy_nn   = oxy_all[::freq,:,:].copy()
lat_nn   = lat_all[::freq,:]
lon_nn   = lon_all[::freq,:]
print(temp_nn.shape)
if freq==30:
    num_time = temp_nn.shape[0]
else:
    # 10-daily values are averaged in groups of three to quasi-monthly estimates:
    # 73 time entries (as for freq=30, matching the Eulerian output) instead of 219,
    # so that the error calculation works
    num_time = 73
    assert temp_nn.shape[0] >= 3*num_time, 'not enough 10-daily samples for 73 quasi-monthly values'
    
#----
# Average temperature/salinity/oxygen over the levels from index dd1 down (e.g. below 2000 m)
#----
# layer thickness per level: prepend a zero to zlevs (assumed to give the depth at the bottom
# of each layer) and take consecutive differences
dz = np.diff(np.hstack((0,zlevs)))[dd1:]

# For every field: NaN and 0 are replaced by the fill value -999 (njit code does not like NaNs;
# presumably get_vertical_avg is numba-compiled). The average starts as an array of ones;
# entries still equal to 1 afterwards are treated as areas shallower than 2000 m and set to NaN.
temp_nn[np.isnan(temp_nn) | (temp_nn == 0)] = -999
temp_nn_avg = get_vertical_avg(temp_nn, dz, np.ones([temp_nn.shape[0], temp_nn.shape[2]]))
temp_nn_avg[temp_nn_avg == 1] = np.nan
print('shape, min, max of avg temp below 2000m:', temp_nn_avg.shape, np.nanmin(temp_nn_avg), np.nanmax(temp_nn_avg))

salt_nn[np.isnan(salt_nn) | (salt_nn == 0)] = -999
salt_nn_avg = get_vertical_avg(salt_nn, dz, np.ones([salt_nn.shape[0], salt_nn.shape[2]]))
salt_nn_avg[salt_nn_avg == 1] = np.nan
print('shape, min, max of avg salinity below 2000m:', salt_nn_avg.shape, np.nanmin(salt_nn_avg), np.nanmax(salt_nn_avg))

oxy_nn[np.isnan(oxy_nn) | (oxy_nn == 0)] = -999
oxy_nn_avg = get_vertical_avg(oxy_nn, dz, np.ones([oxy_nn.shape[0], oxy_nn.shape[2]]))
oxy_nn_avg[oxy_nn_avg == 1] = np.nan
print('shape, min, max of avg oxygen below 2000m:', oxy_nn_avg.shape, np.nanmin(oxy_nn_avg), np.nanmax(oxy_nn_avg))
print('')

# Target densities of a global deep-Argo array of 1200, 800 and 400 floats,
# expressed as ocean area per float in 10^6 km^2 (summed cell area / 1e12).
global_target_list = [np.sum(area)/1e12/n_floats for n_floats in (1200, 800, 400)]
num_it = 10000

# Per (iteration, region, target density) error of the subsampled floats.
# With 'nrmse' both NRMSE (shown as violin plots) and RMSE (average printed in the panel) are kept.
(rmse_sub_temp, rmse_sub_salt, rmse_sub_oxy,
 nrmse_sub_temp, nrmse_sub_salt, nrmse_sub_oxy) = (
    np.zeros([num_it, len(subregions), len(global_target_list)]) for _k in range(6))
# Per-region error of the full float data set
(rmse_full_temp, rmse_full_salt, rmse_full_oxy,
 nrmse_full_temp, nrmse_full_salt, nrmse_full_oxy) = (
    np.zeros([len(subregions)]) for _k in range(6))
num_floats_in_region = np.zeros([len(subregions), len(global_target_list)])
num_floats_all_in_region = np.zeros([len(subregions)])
def _in_box_region(lat_values, lon_values, which_region):
    """Boolean mask of positions inside a lat/lon-box region.

    'south_of_60S': lat <= -60; '3060N_Atl': 30 < lat <= 60 and 300 < lon <= 350.
    NaN positions compare False and are therefore outside.
    Raises ValueError for any other region name.
    """
    if which_region == 'south_of_60S':
        return lat_values <= -60
    if which_region == '3060N_Atl':
        return (lat_values > 30) & (lat_values <= 60) & (lon_values > 300) & (lon_values <= 350)
    raise ValueError('not a lat/lon-box region: '+str(which_region))

# regions defined by a lat/lon box instead of a biome code of the region mask
box_regions = ['south_of_60S','3060N_Atl']

for ss in range(0,len(subregions)):
    which_region = subregions[ss]
    print('Process',which_region)
    is_box_region = which_region in box_regions

    #----
    # cells of the region on the native mesh (for area weights)
    #----
    num_regions = 1
    if is_box_region:
        ind_reg = np.where(_in_box_region(lat,lon,which_region))[0]
        indices_regions = [1] # not actually used for box regions
    else:
        # NOTE(review): the position in `subregions` is used as the biome code of mask_global,
        # which is only right when `subregions` is the full biome list in code order -- confirm
        ind_reg = np.where(mask_global==ss+1)[0]
        indices_regions = [ss+1]

    # cells of each sub-subregion and their area fraction of the region (weights sum to 1)
    ind_reg_sub = []
    for nn in range(0,num_regions):
        if is_box_region:
            ind_reg_sub.append(np.where(_in_box_region(lat,lon,which_region))[0])
        else:
            ind_reg_sub.append(np.where(mask_global==indices_regions[nn])[0])
    weights_regions = np.array([np.sum(area[cells])/np.sum(area[ind_reg]) for cells in ind_reg_sub])
    print('sum weights:',np.sum(weights_regions))

    #----
    # floats in the region, based only on their position/mask at the first sample
    #----
    mask_biomes = mask_biomes_all[::freq,:][0,:]
    if is_box_region:
        ind = np.where(_in_box_region(lat_nn[0,:],lon_nn[0,:],which_region))[0]
    else:
        ind = np.where(mask_biomes.ravel()==ss+1)[0]

    temp_aux = temp_nn_avg[:,ind] # all floats in the area
    salt_aux = salt_nn_avg[:,ind]
    oxy_aux  = oxy_nn_avg[:,ind]
    lat_aux  = lat_nn[:,ind]
    lon_aux  = lon_nn[:,ind]
    mask_aux = mask_biomes_all[::freq,ind]

    floatID = np.arange(0,temp_aux.shape[1],1) # IDs start from zero (used for subsampling)
    # number of floats in the region (previously np.max(floatID): one too few, and a
    # ValueError for a region without floats)
    num_floats_all_in_region[ss] = floatID.shape[0]

    #---
    # FULL float data set: area-weighted regional mean at every (quasi-monthly) time
    #---
    temp_all_floats = np.zeros([num_time])
    salt_all_floats = np.zeros([num_time])
    oxy_all_floats  = np.zeros([num_time])
    for tt in range(0,num_time):
        if freq==30:
            t0 = tt
            aux2a = temp_aux[tt,:]
            aux3a = salt_aux[tt,:]
            aux4a = oxy_aux[tt,:]
        elif freq==10:
            # average samples 3*tt .. 3*tt+2 into one quasi-monthly value (plain mean, so a
            # window containing a NaN gives NaN and is excluded below); mask/lat/lon are
            # taken from the first sample of the window
            t0 = 3*tt
            aux2a = np.mean(temp_aux[t0:t0+3,:],axis=0)
            aux3a = np.mean(salt_aux[t0:t0+3,:],axis=0)
            aux4a = np.mean(oxy_aux[t0:t0+3,:],axis=0)
        aux1a = mask_aux[t0,:]
        aux5a = lat_aux[t0,:]
        aux6a = lon_aux[t0,:]

        for count, nn in enumerate(indices_regions):
            # floats inside the (sub-)region at this time (floats may drift out of it)
            if is_box_region:
                ind = np.where(_in_box_region(aux5a,aux6a,which_region))[0]
            else:
                ind = np.where(aux1a==nn)[0]
            aux2 = aux2a[ind]
            aux3 = aux3a[ind]
            aux4 = aux4a[ind]
            # mean over valid values only (NaN and the -999 fill value both fail the > -999 test);
            # without any valid value the mean is NaN and this sub-region is skipped
            mean_temp = np.nanmean(aux2[aux2>-999])
            if not np.isnan(mean_temp*weights_regions[count]):
                temp_all_floats[tt] += mean_temp*weights_regions[count]
                salt_all_floats[tt] += np.nanmean(aux3[aux3>-999])*weights_regions[count]
                oxy_all_floats[tt]  += np.nanmean(aux4[aux4>-999])*weights_regions[count]

    print('std temp full output:',np.std(data_reg_all[:,ss]))
    print('std salt full output:',np.std(data_reg_allS[:,ss]))
    print('std oxy full output:',np.std(data_reg_allO[:,ss]))
    # error of the full float data set against the Eulerian reference
    # (the last float entry is dropped to match the length of the reference series)
    for reference, estimate, nrmse_full, rmse_full in (
            (data_reg_all[:,ss],  temp_all_floats, nrmse_full_temp, rmse_full_temp),
            (data_reg_allS[:,ss], salt_all_floats, nrmse_full_salt, rmse_full_salt),
            (data_reg_allO[:,ss], oxy_all_floats,  nrmse_full_oxy,  rmse_full_oxy)):
        if which_error in ['nrmse']:
            nrmse_full[ss] = nrmse(reference,estimate[:-1])
            # also store RMSE to get a sense of the actual mismatch
            rmse_full[ss]  = rmse(reference,estimate[:-1])
        elif which_error in ['rmse']:
            rmse_full[ss] = rmse(reference,estimate[:-1])
        elif which_error in ['nmae']:
            # NOTE: the nmae result is stored in the rmse_* arrays
            rmse_full[ss] = nmae(reference,estimate[:-1])

    #----
    # SUBSAMPLE FLOATS
    #----
    # Only floats with valid deep-ocean data in at least half of the samples are considered.
    # This does not depend on the random draw, so it is done once per region (not per iteration).
    num_nan_per_float = np.sum(np.isnan(temp_aux),axis=0)
    ind_deep_enough = np.where(num_nan_per_float<=(0.5*temp_aux.shape[0]))[0]
    mask_aux_new = mask_aux[:,ind_deep_enough]
    temp_aux_new = temp_aux[:,ind_deep_enough]
    salt_aux_new = salt_aux[:,ind_deep_enough]
    oxy_aux_new  = oxy_aux[:,ind_deep_enough]
    lat_aux_new  = lat_aux[:,ind_deep_enough]
    lon_aux_new  = lon_aux[:,ind_deep_enough]
    floatID = np.arange(0,temp_aux_new.shape[1],1)
    # number of floats after the reduction to floats with sufficient data
    num_floats_all = floatID.shape[0]

    # candidate floats (columns of the *_new arrays) of each sub-subregion, from their first position;
    # profiles outside the region at later times are dropped in the averaging function
    candidates_regions = []
    for nn in range(0,num_regions):
        if is_box_region:
            candidates_regions.append(np.where(_in_box_region(lat_aux_new[0,:],lon_aux_new[0,:],which_region))[0])
        else:
            candidates_regions.append(np.where(mask_aux_new[0,:]==indices_regions[nn])[0])

    for gg in range(0,len(global_target_list)):
        print('Process global target density: '+str(global_target_list[gg])+' ('+str(1/(global_target_list[gg]/(np.sum(area)/1e12)))+' floats)')
        # number of floats to draw in each sub-subregion: its area divided by the area per float
        num_floats_red_regions = [int(np.round(np.sum(area[cells])/1e12/global_target_list[gg]))
                                  for cells in ind_reg_sub]

        for ii in tqdm(range(0,num_it)):
            # random subset of floats in each sub-subregion (one RNG call per sub-subregion, in order)
            draws = [get_random_floats_1(candidates, n_red)
                     for candidates, n_red in zip(candidates_regions, num_floats_red_regions)]
            ind_random = np.concatenate(draws)
            vector_region = np.concatenate([np.full(draw.shape[0], nn+1.0) for nn, draw in enumerate(draws)])
            # sort by float index and carry the region labels along (previously only ind_random was
            # sorted, which mislabels floats as soon as there is more than one sub-subregion)
            order = np.argsort(ind_random, kind='stable')
            ind_random = ind_random[order]
            vector_region = vector_region[order]
            if ii==0:
                num_floats_in_region[ss,gg] = ind_random.shape[0]

            # subsampled data set in the current region; NaN -> -999 for the averaging function
            a1 = mask_aux_new[:,ind_random]
            a2 = temp_aux_new[:,ind_random]
            a3 = salt_aux_new[:,ind_random]
            a4 = oxy_aux_new[:,ind_random]
            a5 = lat_aux_new[:,ind_random]
            a6 = lon_aux_new[:,ind_random]
            for arr in (a2,a3,a4,a5,a6):
                arr[np.isnan(arr)] = -999

            #---
            # SUBSAMPLED: area-weighted quasi-monthly means of the selected floats
            #---
            temp_sub_floats,salt_sub_floats,oxy_sub_floats = get_avg_at_each_time_with_weights_wFreq(
                a1,a2,a3,a4,num_time,freq,weights_regions,vector_region,
                a5,a6,which_region,
                np.zeros([num_time]),np.zeros([num_time]),np.zeros([num_time]))

            # Entries that stay exactly 0 are times at which none of the selected floats had valid
            # data (e.g. all of them over shallow topography). They are set to NaN and excluded from
            # the error, so such time series are not always exactly 6 years long.
            b1 = np.copy(data_reg_all[:,ss])
            b1[temp_sub_floats[:-1]==0] = np.nan
            temp_sub_floats[temp_sub_floats==0] = np.nan

            b2 = np.copy(data_reg_allS[:,ss])
            b2[salt_sub_floats[:-1]==0] = np.nan
            salt_sub_floats[salt_sub_floats==0] = np.nan

            b3 = np.copy(data_reg_allO[:,ss])
            b3[oxy_sub_floats[:-1]==0] = np.nan
            oxy_sub_floats[oxy_sub_floats==0]   = np.nan
            ind_no_NaN = np.where(~np.isnan(b1))[0]

            # error of the subsample against the Eulerian reference
            for reference, estimate, nrmse_sub, rmse_sub in (
                    (b1, temp_sub_floats, nrmse_sub_temp, rmse_sub_temp),
                    (b2, salt_sub_floats, nrmse_sub_salt, rmse_sub_salt),
                    (b3, oxy_sub_floats,  nrmse_sub_oxy,  rmse_sub_oxy)):
                reference = reference[ind_no_NaN]
                estimate = estimate[:-1][ind_no_NaN]
                if which_error in ['nrmse']:
                    nrmse_sub[ii,ss,gg] = nrmse(reference,estimate)
                    rmse_sub[ii,ss,gg]  = rmse(reference,estimate)
                elif which_error in ['rmse']:
                    rmse_sub[ii,ss,gg] = rmse(reference,estimate)
                elif which_error in ['nmae']:
                    rmse_sub[ii,ss,gg] = nmae(reference,estimate)

        if gg==2:
            print('rmse sub oxy: ',np.mean(rmse_sub_oxy[:,ss,gg]))
        # clean up per target density (this used to run only when gg==2 because of its indentation)
        del b1,b2,b3,ind_no_NaN
print('done')



In [None]:
# Quick look at the errors: full float data set next to the mean over all random subsamples.
# The rmse_* arrays hold RMSE for 'rmse'/'nrmse' and NMAE for 'nmae'.
print(which_error)

print('oxy:', rmse_full_oxy, rmse_sub_oxy.mean(axis=0))

print('salt:', rmse_full_salt, rmse_sub_salt.mean(axis=0))

print('temp:', rmse_full_temp, rmse_sub_temp.mean(axis=0))

# mean over all iterations for region index ss (its last value from the loop above)
# at the first target density
print(rmse_sub_temp[:, ss, 0].mean())
print(rmse_sub_salt[:, ss, 0].mean())
print(rmse_sub_oxy[:, ss, 0].mean())
print(global_target_list)


In [None]:

# Output directory for this run: one sub-folder per iteration count and error metric.
savepath2 = f'{savepath}{num_it}iterations_{which_error}/'
if not os.path.exists(savepath2):
    # report before creating, like the other path checks in this notebook
    print('Created ' + savepath2)
    os.makedirs(savepath2)
    

In [None]:
#---
# plot settings shared by the figures below
#---
from matplotlib import cm

fs = 12  # base font size

# subsampled float data set
ms2 = 10
caps = 4

# full float data set
symbol_full = 's'
ms = 6

color2 = 'darkorange'

# RGB part (alpha dropped) of four shades of the Blues colormap;
# nn is an integer index into the colormap's lookup table.
nn = 250
color1a = list(cm.Blues(nn)[:3])
nn = 190
color1b = list(cm.Blues(nn)[:3])
nn = 130
color1c = list(cm.Blues(nn)[:3])
nn = 70
color1d = list(cm.Blues(nn)[:3])

save_plots = False
display_plots = True


In [None]:
#---
# VIOLIN PLOTS plot each region
#---
# For every region in subregions2 and every variable (T, O2, S), draw one
# figure with:
#   - one violin per target float density: the distribution over the num_it
#     subsamples of the chosen error metric (which_error);
#   - a black horizontal line: the same metric for the full synthetic float set.
# A labelled PNG is always drawn. With plot_eps, an unlabelled copy (no axis
# labels or annotations) is also drawn and saved as EPS.

subregions2 = ['south_of_60S','3060N_Atl']

display_plots_png = True
display_plots_eps = False
save_plots = True
plot_eps = True

if len(global_target_list)>3:
    xpos = [0.3,0.2,0.1,0]
    xlim1,xlim2 = -0.05,0.35
    width,height = 3,3
    # width2/height2 used to be set only in the else-branch (NameError here).
    # NOTE(review): widened like the 3-density case (width + 1.2) -- TODO confirm.
    width2,height2 = 4.2,3
else:
    xpos = [0.2,0.1,0]
    xlim1,xlim2 = -0.05,0.25
    width,height = 2,3
    width2,height2 = 3.2,3

# One violin per float density; xpos holds one x position per density.
n_violins = min(len(xpos), len(global_target_list))
# np.sum(area)/1e12 divided by a global_target_list entry is the global float
# count shown on the x axis.
total_area_scaled = np.sum(area)/1e12

# NRMSE y axis (shared by all variables): ticks per region; the upper limit is
# the last tick. Regions not listed keep matplotlib's automatic axis.
NRMSE_YTICKS = {'south_of_60S': [0,25,50,75,100],
                '3060N_Atl': [0,100,200,300,400]}

# The error loop above only fills the nrmse_* arrays when which_error is
# 'nrmse'; for 'rmse' and 'nmae' the metric is stored in the rmse_* arrays.
use_nrmse = which_error in ['nrmse']

# Per-variable settings:
#   name        - file name part
#   symbol      - y-label symbol
#   rmse_unit   - extra y-label line for RMSE
#   sub / full  - plotted subsampled / full-set errors
#   annot_sub   - values averaged in the text annotations
#   round_to    - rounding factor for the annotations
#   annot_unit  - unit shown in the annotations
#   rmse_yticks - RMSE y-ticks per region
VIOLIN_VARIABLES = [
    {'name': 'Temperature', 'symbol': 'T', 'rmse_unit': '\nin deg C',
     'sub': nrmse_sub_temp if use_nrmse else rmse_sub_temp,
     'full': nrmse_full_temp if use_nrmse else rmse_full_temp,
     'annot_sub': rmse_sub_temp, 'round_to': 1000, 'annot_unit': ' deg C',
     'rmse_yticks': {'south_of_60S': [0,0.05,0.1,0.15],
                     '3060N_Atl': [0,0.1,0.2,0.3,0.4]}},
    {'name': 'Oxygen', 'symbol': 'O$_{2}$', 'rmse_unit': '\nin mmol m$^{-3}$',
     'sub': nrmse_sub_oxy if use_nrmse else rmse_sub_oxy,
     'full': nrmse_full_oxy if use_nrmse else rmse_full_oxy,
     'annot_sub': rmse_sub_oxy, 'round_to': 100, 'annot_unit': ' mmol m-3',
     'rmse_yticks': {'south_of_60S': [0,1,2,3,4],
                     '3060N_Atl': [0,2,4,6,8,10,12]}},
    {'name': 'Salinity', 'symbol': 'S', 'rmse_unit': '',
     'sub': nrmse_sub_salt if use_nrmse else rmse_sub_salt,
     'full': nrmse_full_salt if use_nrmse else rmse_full_salt,
     'annot_sub': rmse_sub_salt, 'round_to': 10000, 'annot_unit': '',
     'rmse_yticks': {'south_of_60S': [0,0.002,0.004,0.006,0.008],
                     '3060N_Atl': [0,0.01,0.02,0.03,0.04]}},
]


def _global_float_count(gg):
    """Global number of floats for target-density index gg, as a string."""
    return str(int(np.round(total_area_scaled/global_target_list[gg])))


def _draw_violins(var, ss):
    """Full-float-set reference line plus one violin per float density."""
    plt.hlines(var['full'][ss],xlim1,xlim2,colors='black',linewidth=1.5,zorder=0,label='all floats')
    for gg in range(n_violins):
        parts = plt.violinplot(var['sub'][:,ss,gg],[xpos[gg]],points=20,widths=0.06,showmeans=False,
                               showextrema=False,showmedians=True,quantiles=[0.1, 0.9])
        for body in parts['bodies']:
            body.set_facecolor('darkblue')
            body.set_edgecolor('darkblue')


def _label_figure(var, ss):
    """Axis labels and text annotations (labelled/PNG version only)."""
    unit_line = var['rmse_unit'] if which_error in ['rmse'] else ''
    plt.ylabel(which_error.upper()+' '+var['symbol']+' below 2000m'+unit_line+
               '\n(Eulerian vs. float-based)',fontsize=fs)
    plt.xlabel('Global # of floats\n(# in subregion)',fontsize=fs)
    plt.annotate('n='+str(num_it),xy=(0.01,1.015),
                 xycoords='axes fraction',fontsize=fs,ha='left',color='k')
    plt.annotate(subregions[ss],xy=(0.99,1.015),
                 xycoords='axes fraction',fontsize=fs-2,ha='right',color='k',style='italic')
    plt.annotate('# all synthetic floats\nin subregion: '+str(int(num_floats_all_in_region[ss])),xy=(0.97,0.84),
                 xycoords='axes fraction',fontsize=fs-2,ha='right',color=color2)
    # Mean (over iterations) of annot_sub per density, highest index on top:
    # y = 0.77, 0.73, 0.69, ...
    for rank, gg in enumerate(reversed(range(n_violins))):
        mean_err = np.round(var['round_to']*np.mean(var['annot_sub'][:,ss,gg]))/var['round_to']
        plt.annotate(_global_float_count(gg)+' floats: '+str(mean_err)+var['annot_unit'],
                     xy=(0.97,round(0.77-0.04*rank,2)),
                     xycoords='axes fraction',fontsize=fs-6,ha='right',color='k')


def _format_axes(var, ss, which_region, labelled):
    """x ticks per density, region-specific y limits/ticks; labels only if labelled."""
    densities = list(reversed(range(n_violins)))  # ascending x positions
    if labelled:
        xtick_labels = [_global_float_count(gg)+'\n('+str(int(num_floats_in_region[ss,gg]))+')'
                        for gg in densities]
    else:
        xtick_labels = []
    plt.xticks([xpos[gg] for gg in densities],xtick_labels,fontsize=fs)
    if which_error in ['nrmse']:
        if which_region in NRMSE_YTICKS:
            plt.ylim((0,NRMSE_YTICKS[which_region][-1]))
    elif which_error in ['rmse']:
        # keep the automatic upper limit, but start the axis at zero
        plt.ylim((0,plt.gca().get_ylim()[1]))
    plt.yticks(fontsize=fs)
    plt.xlim((xlim1,xlim2))
    if which_error in ['rmse']:
        yticks_by_region = var['rmse_yticks']
    elif which_error in ['nrmse']:
        yticks_by_region = NRMSE_YTICKS
    else:
        yticks_by_region = {}
    if which_region in yticks_by_region:
        ticks = yticks_by_region[which_region]
        plt.yticks(ticks,ticks if labelled else [])


def _violin_figure(var, ss, which_region, labelled):
    """Draw one figure, save it (.png if labelled, else .eps) if save_plots, then show or close it."""
    fig = plt.figure(figsize=(width2,height2))
    _draw_violins(var, ss)
    if labelled:
        _label_figure(var, ss)
    _format_axes(var, ss, which_region, labelled)
    if save_plots:
        filename = 'Violin_'+var['name']+'_below_2000m_2012_2017_'+which_error+\
                   '_norm_by_std_full_vs_subsampled_vs_eulerian_'+\
                   str(num_it)+'iterations_'+which_region+('.png' if labelled else '.eps')
        if labelled:
            plt.savefig(savepath2+filename,dpi = dpicnt, bbox_inches='tight')
        else:
            plt.savefig(savepath2+filename,dpi = dpicnt, bbox_inches='tight',format='eps')
    if display_plots_png if labelled else display_plots_eps:
        plt.show()
    else:
        plt.close(fig)


for rr in range(0,len(subregions2)):
    which_region = subregions2[rr]
    print(which_region)
    ss = subregions.index(which_region)
    # labelled PNG figures first (T, O2, S), then the unlabelled EPS copies
    for var in VIOLIN_VARIABLES:
        _violin_figure(var, ss, which_region, labelled=True)
    if plot_eps:
        for var in VIOLIN_VARIABLES:
            _violin_figure(var, ss, which_region, labelled=False)

print('done')
    
    