# Post-processing of FESOM-REcoM simulations for AGI toothfish project
# calculate pO2 and save as netcdf file
# on regular mesh: 0.25° x 0.125°
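# calculate in-situ temperature and pO2 for each month and store the monthly fields (month x depth x lat x lon)
# apply a drift correction (difference between the simB future and historical monthly climatologies)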

In [None]:
#get_ipython().system(u'jupyter nbconvert --to=python MASTER_toothfish_postprocessing_AGI_save_netcdf_files_monthly.ipynb')

In [None]:
### modules
import os
import sys
sys.path.append("pyfesom/") # add pyfesom to search path
sys.path.append("python_gsw_py3/") 
import pyfesom as pf
import numpy as np
from scipy.interpolate import griddata
import matplotlib
import matplotlib.pyplot as plt
from netCDF4 import Dataset
from time import localtime, gmtime, strftime
from netCDF4 import Dataset
from matplotlib import cm 
from numba import njit
from gsw import rho # rho from SA, CT, p
from gsw import pt0_from_t # potTemp from SA, t, p (at reference pressure 0)
from gsw import pt_from_t # potTemp from SA, t, p and reference pressure
from gsw import pot_rho_t_exact # potRho from SA, t, p and reference pressure
from gsw import p_from_z # get pressure from z and lat
from gsw import CT_from_pt # conservative temp from potTemp and SA
from gsw import rho_first_derivatives # first derivatives of rho with respect to SA, CT, p
from gsw2 import sigma0_pt0_exact # gsw_sigma0_pt0_exact(SA,pt0)
from gsw import SA_from_SP
import pandas as pd
import seawater as sw
from tqdm import tqdm

In [None]:
#-----
## simulation / mesh identifiers
#-----

# name of the native FESOM source mesh
mesh_id_src = 'COARZE'

#-----
## time range and spatial restriction
#-----
# first and last year; the trailing '#2018' in earlier versions hints at a
# longer range (not referenced by the cells below)
year1 = 1980
year2 = 1981

# interpolated output is restricted to latitudes at or south of this value
latlim = -45

#-----
## input locations
#-----
meshpath = '/pscratch/sd/c/cnissen/mesh_COARZE/'


In [None]:
#-----
# load mesh file
#-----

# native FESOM mesh including its 3-D information
mesh = pf.load_mesh(meshpath, get3d=True,usepickle=False)

path_mesh = '/pscratch/sd/c/cnissen/'
file_mesh = 'Nissen2022_FESOM_REcoM_mesh_information_corrected_20220910.nc'

# per-node mesh information; the context manager closes the file even if a
# variable lookup fails
with Dataset(path_mesh+file_mesh) as f1:
    lats       = f1.variables['lat'][:]
    lons       = f1.variables['lon'][:]
    zlevs      = f1.variables['zlevs'][:]
    cavities   = f1.variables['cavity'][:]
    topo       = f1.variables['topo'][:]
    area_nodes = f1.variables['cell_area'][:]
    volume     = f1.variables['cell_volume'][:]
print(lats.shape)

# node indices outside (cavity flag 0) and inside (cavity flag 1) cavities
ind_no_cavity = np.where(cavities==0)[0]
ind_cavities = np.where(cavities==1)[0]

# 3-D node coordinates from nod3d.out (first line skipped).
# sep=r'\s+' is the replacement for delim_whitespace=True, which is deprecated
# since pandas 2.2.
df = pd.read_csv('/pscratch/sd/c/cnissen/HLRN_runs_postprocessed/nod3d.out', sep=r'\s+', skiprows=1,
                 names=['node_number','x','y','z','flag'])
lats3d  = df.y.values
lons3d  = df.x.values
zlevs3d = df.z.values
print (zlevs3d.shape)


In [None]:
#------
# FUNCTIONS
#------

def transform_lon_coord(data,grid_resolution_x):
    """Swap the two halves of a regular-grid field along its longitude axis.

    Converts a 0..360 longitude ordering to -180..180. The swap is its own
    inverse, so applying it twice restores the input. Longitude is taken to be
    the last axis: the 2nd dimension of a 2-D (lat x lon) array, or the only
    dimension of a 1-D array.

    Parameters
    ----------
    data : numpy.ndarray or numpy.ma.MaskedArray, at least 1-D
        Field to transform. The last axis is expected to hold
        2 * round(180 / grid_resolution_x) points.
    grid_resolution_x : float
        Longitudinal grid spacing in degrees (e.g. 1.0, 0.5, 0.25, 0.125).

    Returns
    -------
    Array of the same type and shape as ``data`` with the longitude halves
    swapped.

    Raises
    ------
    ValueError
        If ``grid_resolution_x`` is not positive or ``data`` is 0-d.
    """
    if grid_resolution_x <= 0:
        raise ValueError('grid_resolution_x must be positive, got '+str(grid_resolution_x))
    if np.ndim(data) == 0:
        raise ValueError('data must be at least 1-D')
    # points per half (180 for 1.0 deg, 720 for 0.25 deg, ...). This must be an
    # int: a float slice index raises TypeError in Python 3.
    num_lon = int(round(180.0/grid_resolution_x))

    data_transformed = np.empty_like(data)
    data_transformed[...,:num_lon] = data[...,num_lon:]
    data_transformed[...,num_lon:] = data[...,:num_lon]
    return data_transformed

@njit
def reorganize_field_in_cavities(ind_cavities,data):
    """Move a misplaced top value of each cavity column to its correct level.

    ``data`` is indexed as ``data[level, node]``, and the caller marks masked
    entries with 0 beforehand. For every node index in ``ind_cavities`` the
    column's non-zero (valid) levels are located. If the first valid entry is
    separated from the second by a gap, the first value is moved to the level
    directly above the second valid entry and its old slot is set to 0.

    Validity is tested with ``!= 0`` rather than ``> 0`` so that fields that
    can be negative (e.g. temperature) are treated correctly. With ``> 0``,
    negative values would be taken as masked and the gap detection would be
    wrong.

    Returns ``data``, which is modified in place.
    """
    for ii in ind_cavities:
        bb = data[:,ii] # all depth levels at this cavity node (a view into data)
        ind_av = np.where(bb != 0)[0] # levels holding data; 0 marks masked entries
        # the "surface" entry is filled but separated from the rest of the column by a gap
        if len(ind_av)>1 and (ind_av[1]-ind_av[0])>1:
            bb[ind_av[1]-1] = bb[ind_av[0]] # move "surface" value to the correct depth
            bb[ind_av[0]] = 0 # clear the original surface entry
        data[:,ii] = bb # write the column back
    return data


In [None]:
#-------
# regular (interpolated) grid information:
# 1) surface area of grid cells ('area')
# 2) volume mask (depth x lat x lon) and surface mask (lat x lon)
#-------

meshfile_interp = '/pscratch/sd/c/cnissen/files_toothfish_project_AGI/'+\
        'Mesh_ancillary_information_v20220725.nc'

grid_resolution_x = 0.25  # lon spacing (deg) of the regular grid
grid_resolution_y = 0.125 # lat spacing (deg) of the regular grid

# the context manager closes the file even if a variable lookup fails
with Dataset(meshfile_interp) as file1:
    xi       = file1.variables['lon'][:]   # longitude
    yi       = file1.variables['lat'][:]   # latitude
    depths   = file1.variables['depth'][:] # depth
    area     = file1.variables['area'][:,:]
    mask_vol = file1.variables['mask_vol'][:,:,:] # mask; depth x lat x lon
    mask_sfc = file1.variables['mask_sfc'][:,:]   # mask; lat x lon

x_all,y_all = np.meshgrid(xi,yi) # lat x lon
print (x_all.shape)

# sanity check: compare ocean surface area on the regular grid with the native mesh
# NOTE(review): the native-mesh sum uses nodes south of -50 deg; the regular-grid
# sum uses mask_sfc, whose latitude extent is not visible here — confirm they match.
print ('CHECK')
print ('Southern Ocean surface area ('+str(grid_resolution_x)+'° x '+str(grid_resolution_y)+\
            '° mesh): '+str(np.sum(area[mask_sfc==1]))+' m2')
indSO_native = np.where(mesh.y2<=-50)[0]
print ('Southern Ocean surface area (native mesh): '+\
       str(np.sum(area_nodes[indSO_native]))+' m2 (ocean only)')




In [None]:
#-----
# constants for the pO2 calculation
#-----
# Model O2 (mmol m-3) is later converted to mol kg-1 with in-situ density, and
# potential temperature is converted to in-situ temperature.

# reference pressure passed to sw.eos80.temp when converting potential to
# in-situ temperature
ref_pressure = 0

# physical constants
xO2 = 0.20946 # mole fraction of O2 in dry air
Vm = 0.317    # partial molar volume of O2 (m3 mol-1 Pa dbar-1)
R = 8.3145    # gas constant (m3 Pa K-1 mol-1)
KC = 273.15   # offset deg C -> Kelvin

# O2 solubility fit in scaled temperature and salinity
# (values appear to match the Garcia & Gordon (1992) fit — confirm)
A0, A1, A2, A3, A4, A5 = 5.80871, 3.20291, 4.17887, 5.10006, -0.0986643, 3.80369
B0, B1, B2, B3 = -0.00701577, -0.00770028, -0.0113864, -0.00951519
C0 = -0.000000275915

# water vapour pressure fit (presumably Weiss & Price (1980) — confirm)
D0, D1, D2 = 24.4543, -67.4509, -4.8489
D3 = -5.44*1e-4


In [None]:
#----
# load files from simB (drift correction)
#----
# For every variable the drift correction is the difference between the simB
# monthly climatology of the future period and that of the historical period.

year_string1 = '1995_2014' # historical period
year_string2 = '2081_2100' # future period; alternatives used: 2091_2100, 2098_2100, 2090_2099

path_B = '/pscratch/sd/c/cnissen/files_toothfish_project_AGI/simB_monthly_drift/'


def _load_simB_climatology(var_name, year_string):
    """Return the simB monthly climatology of ``var_name`` for ``year_string``.

    The array is month x depth x lat x lon and is fully read into memory
    before the file is closed.
    """
    with Dataset(path_B+'/'+var_name+'_fesom_simB_avg_'+year_string+'_monthly.nc') as ff:
        return ff.variables[var_name][:,:,:,:]


# pO2: keep both periods briefly to report their ranges
pO2_B_hist   = _load_simB_climatology('pO2', year_string1)
pO2_B_future = _load_simB_climatology('pO2', year_string2)
pO2_B        = pO2_B_future - pO2_B_hist
print(pO2_B.shape)
print('Min/Max pO2 historical:',np.min(pO2_B_hist),np.max(pO2_B_hist))
print('Min/Max pO2 future:',np.min(pO2_B_future),np.max(pO2_B_future))
print('Min/Max pO2 difference:',np.min(pO2_B),np.max(pO2_B))
del pO2_B_future,pO2_B_hist

# future minus historical for the remaining variables
oxygen_B   = _load_simB_climatology('oxygen', year_string2) - _load_simB_climatology('oxygen', year_string1)
t_insitu_B = _load_simB_climatology('t_insitu', year_string2) - _load_simB_climatology('t_insitu', year_string1)
salt_B     = _load_simB_climatology('salt', year_string2) - _load_simB_climatology('salt', year_string1)


In [None]:
#----
# prepare loop over all data
#----
# For each year: read monthly native-mesh T, S and O2, fix the cavity columns,
# interpolate every depth level to the regular grid, compute drift-corrected
# in-situ temperature, O2 (mol kg-1) and pO2, and write the monthly fields
# (month x depth x lat x lon) to one netCDF file per variable and year.

# One Delaunay triangulation of the native mesh is reused for all
# interpolations. griddata(method='linear') rebuilds it on every call, which
# gives the same result at a much higher cost.
from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator

# all years of the future period, e.g. '2081_2100' -> 2081, ..., 2100
year_list = np.arange(int(year_string2[0:4]),int(year_string2[5:])+1,1)
# index into year_list of the first year to process (19 -> only the last year
# of a 20-year period). NOTE(review): looks like a restart index left over from
# an interrupted run; set to 0 to process the whole period.
first_year_index = 19
simulation = 'ssp585' #'simAssp585'

note_data = 'FESOM-REcoM simulation, Cara Nissen, December 2023, cara.nissen@awi.de or cara.nissen@colorado.edu'
script_name = '/home/ollie/ncara/scripts/MASTER_toothfish_postprocessing_AGI_save_netcdf_files_monthly.ipynb'

var_name1 = 'thetao' # potential temperature
var_name2 = 'so'     # salinity
var_name3 = 'bgc22'  # oxygen (mmol m-3)

path = '/pscratch/sd/c/cnissen/'
savepath = '/pscratch/sd/c/cnissen/files_toothfish_project_AGI/'+simulation+\
                '_monthly_drift_corrected_'+year_string2+'_minus_'+year_string1+'/'
# check existence of paths
if not os.path.exists(savepath):
    print ('Created '+savepath)
    os.makedirs(savepath)

fill_value = -9999     # _FillValue of the output files
invalid_value = -99999 # internal marker for invalid interpolated grid points

yi_SO = yi
ind_SO = np.where(yi_SO<=latlim)[0]
yi_SO = yi_SO[ind_SO]

lat_3D    = np.tile(y_all,[depths.shape[0],1,1]) # depth x lat x lon
depths_3D = np.transpose(np.tile(depths,[y_all.shape[0],y_all.shape[1],1]),[2,0,1]) # depth x lat x lon
pres_3D   = sw.eos80.pres(depths_3D,lat_3D) # depth x lat x lon

# output switches
save_as_netcdf = True
save_oxygen = False   # also store O2 in mol kg-1
save_salinity = False # also store practical salinity
test_plot = False     # diagnostic plots of temperature and mask per level

# loop-invariant interpolation setup
native_tri = Delaunay(np.ascontiguousarray(np.column_stack((mesh.x2, mesh.y2)), dtype=np.float64))
target_points = (x_all.ravel(), y_all.ravel())


def _output_name(prefix, year):
    """Return the output file name of variable ``prefix`` for ``year``."""
    return (prefix+'_fesom_'+simulation+'_'+str(year)+'0101_monthly_drift_corrected_'
            +year_string2+'_minus_'+year_string1+'.nc')


def _drift_files(prefix):
    """Return the 'future - historical' simB file description for ``prefix``."""
    return (prefix+'_fesom_simB_avg_'+year_string2+'_monthly.nc - '
            +prefix+'_fesom_simB_avg_'+year_string1+'_monthly.nc')


def _create_output_file(netcdf_name, var_name, var_attrs, original_file, drift_corr_files):
    """Create an empty month x depth x lat x lon file unless it already exists.

    ``var_attrs`` is a sequence of (name, value) pairs set on the data variable
    in the given order.
    """
    if os.path.exists(savepath+netcdf_name):
        return
    print ('Create file '+savepath+netcdf_name)
    with Dataset(savepath+netcdf_name, 'w', format='NETCDF4_CLASSIC') as nc:
        nc.createDimension('lon', len(xi))
        nc.createDimension('lat', len(yi_SO))
        nc.createDimension('month', 12) # monthly
        # sized from `depths`, which is what is written to the depth variable
        # and looped over below
        nc.createDimension('depth', len(depths))
        for coord_name, coord_unit in (('lat','degrees N'),('lon','degrees E'),('depth','m')):
            coord = nc.createVariable(coord_name, 'f4', (coord_name,))
            coord.unit = coord_unit
        data_var = nc.createVariable(var_name, 'f4', ('month','depth','lat','lon'), fill_value=fill_value)
        for attr_name, attr_value in var_attrs:
            data_var.setncattr(attr_name, attr_value)
        nc.variables['lat'][:] = yi_SO
        nc.variables['lon'][:] = xi
        nc.variables['depth'][:] = depths
        nc.note = note_data
        nc.original_file = original_file
        nc.drift_corr_path = path_B
        nc.drift_corr_files = drift_corr_files
        nc.script = script_name


def _write_level(netcdf_name, var_name, mm, kk, field):
    """Write one lat x lon slice for month ``mm`` and depth level ``kk``."""
    with Dataset(savepath+netcdf_name, 'r+') as nc:
        nc.variables[var_name][mm,kk,:,:] = field


def _fix_cavities(field_all):
    """Prepare a month x depth x node field in place for interpolation.

    Per month, masked entries are set to 0 (the marker that
    reorganize_field_in_cavities expects), cavity columns are reorganized,
    and all remaining zeros become NaN.
    """
    for mm in range(0,12):
        month = field_all[mm,:,:]
        month[month.mask==True] = 0
        month = reorganize_field_in_cavities(ind_cavities,month)
        month[month==0] = np.nan


def _interpolate_level(values, kk, upper_limit=None):
    """Linearly interpolate one native-mesh level to the regular grid.

    Grid points that are NaN, exactly 0, above ``upper_limit`` (if given), or
    outside mask_vol at level ``kk`` are set to invalid_value.
    """
    grid = LinearNDInterpolator(native_tri, np.copy(values).ravel())(target_points)
    grid = grid.reshape(x_all.shape)
    # transform_lon_coord is not applied: after correcting mask_vol in the
    # ancillary file the lon transformation is no longer needed
    grid[np.isnan(grid)] = invalid_value
    grid[grid==0] = invalid_value
    if upper_limit is not None:
        grid[grid>upper_limit] = invalid_value
    grid[mask_vol[kk,:,:]==0] = invalid_value
    return grid


def _compute_pO2(temp_pot, salt, oxy_mmol_m3, kk, mm):
    """Compute drift-corrected fields on depth level ``kk`` for month ``mm``.

    Inputs are regular-grid (lat x lon) arrays in which invalid points hold
    invalid_value. Returns (t_insitu, salt, O2_molkg, pO2). Wherever pO2 cannot
    be computed, all four arrays hold fill_value.
    """
    pres = pres_3D[kk,:,:]
    # invalid salinity/oxygen must not enter the equations as -99999
    salt = np.where(salt<-99, np.nan, salt)
    oxy_mmol_m3 = np.where(oxy_mmol_m3<-99, np.nan, oxy_mmol_m3)

    # in-situ temperature from potential temperature
    t_insitu = sw.eos80.temp(salt,temp_pot,pres,ref_pressure)
    t_insitu[temp_pot<-99] = np.nan
    # DRIFT correction
    t_insitu = t_insitu - t_insitu_B[mm,kk,:,:]

    # in-situ density; named rho_insitu so the gsw `rho` import is not shadowed
    rho_insitu = sw.dens(salt,t_insitu,pres)

    # O2 in mol kg-1 (mmol m-3 * 0.001 / kg m-3), then DRIFT correction
    O2_molkg = (oxy_mmol_m3*0.001/rho_insitu) - oxygen_B[mm,kk,:,:]

    # scaled temperature (np.log is the natural logarithm)
    T_scaled = np.log((298.15-t_insitu)/(KC+t_insitu))

    # saturation concentration of O2 in seawater
    A = A0 + A1*T_scaled + A2*T_scaled*T_scaled + A3*T_scaled*T_scaled*T_scaled+\
            A4*T_scaled*T_scaled*T_scaled*T_scaled +\
            A5*T_scaled*T_scaled*T_scaled*T_scaled*T_scaled +\
            salt*(B0+B1*T_scaled+B2*T_scaled*T_scaled+B3*T_scaled*T_scaled*T_scaled)+\
            C0*salt*salt
    c_sat_o2 = 1e-6*np.exp(A)

    # water vapour pressure
    pH2O = 1013.25 * np.exp(D0 + D1*(100./(t_insitu+KC))+D2*np.log((t_insitu+KC)/100.)+D3*salt)

    K0 = c_sat_o2/(xO2 * (1013.25 - pH2O))
    pO2 = (O2_molkg/K0)*np.exp((Vm*pres)/(R*(KC+t_insitu)))

    t_insitu[np.isnan(pO2)] = np.nan
    salt[np.isnan(pO2)] = np.nan
    O2_molkg[np.isnan(pO2)] = np.nan

    if np.mod(kk,20)==0:
        print ('Min/Max pO2:',np.nanmin(pO2),np.nanmax(pO2))
        print ('Min/Max t insitu:',np.nanmin(t_insitu),np.nanmax(t_insitu))

    t_insitu[np.isnan(t_insitu)] = fill_value
    pO2[np.isnan(pO2)] = fill_value
    salt[np.isnan(salt)] = fill_value
    # NaNs would otherwise be stored instead of the variable's _FillValue
    O2_molkg[np.isnan(O2_molkg)] = fill_value
    return t_insitu, salt, O2_molkg, pO2


for yy in range(first_year_index,len(year_list)):
    year = year_list[yy]
    print ('')
    print ('Process year',year)

    path1 = path+'COARZE_temp/'+simulation+'/'
    path2 = path+'COARZE_salt/'+simulation+'/'
    path3 = path+'COARZE_oxy/'+simulation+'/'

    # monthly native-mesh fields, month x depth x node
    with Dataset(path1+'/'+var_name1+'_fesom_'+str(year)+'0101.nc') as ff: # temp
        data1_all = ff.variables[var_name1][:,:,:]
    with Dataset(path2+'/'+var_name2+'_fesom_'+str(year)+'0101.nc') as ff: # salt
        data2_all = ff.variables[var_name2][:,:,:]
    with Dataset(path3+'/'+var_name3+'_fesom_'+str(year)+'0101.nc') as ff: # oxygen
        data3_all = ff.variables[var_name3][:,:,:]

    #----
    # create netcdf files (1 file per year and variable)
    #----
    if save_as_netcdf:
        if save_oxygen:
            _create_output_file(_output_name('oxygen',year), 'oxygen',
                                (('description','oxygen concentration'),
                                 ('unit','mol kg-1'),
                                 ('note','converted from mmol m-3 using modeled in-situ density')),
                                path3+var_name3+'_fesom_'+str(year)+'0101.nc',
                                _drift_files('oxygen'))
        # NOTE(review): pO2 is drift-corrected via t_insitu_B and oxygen_B, not
        # via pO2_B, although the attribute lists the pO2 drift files.
        _create_output_file(_output_name('pO2',year), 'pO2',
                            (('description','partial pressure of oxygen'),
                             ('note','calculated from monthly mean model output of T, S, and O2'),
                             ('unit','mbar')),
                            path3+var_name3+'_fesom_'+str(year)+'0101.nc',
                            _drift_files('pO2'))
        _create_output_file(_output_name('t_insitu',year), 't_insitu',
                            (('description','insitu temperature'),
                             ('unit','deg C')),
                            path1+var_name1+'_fesom_'+str(year)+'0101.nc',
                            _drift_files('t_insitu'))
        if save_salinity:
            # NOTE(review): salinity itself is not drift-corrected (salt_B is unused)
            _create_output_file(_output_name('salt',year), 'salt',
                                (('description','practical salinity'),
                                 ('unit','psu')),
                                path2+var_name2+'_fesom_'+str(year)+'0101.nc',
                                _drift_files('salt'))

    # cavity reorganization is independent of the depth level: do it once per month
    for field_all in (data1_all, data2_all, data3_all):
        _fix_cavities(field_all)

    for kk in range(0,len(depths)):
        print ('Depth level: '+str(depths[kk])+'m')

        pO2_all_months      = fill_value*np.ones([len(yi_SO),len(xi),12])
        O2_molkg_all_months = fill_value*np.ones([len(yi_SO),len(xi),12])
        t_insitu_all_months = fill_value*np.ones([len(yi_SO),len(xi),12])
        salt_all_months     = fill_value*np.ones([len(yi_SO),len(xi),12])

        for mm in range(0,12): # calculate and store pO2 etc. for each month
            print ('Month: '+str(mm+1)+'...')

            # interpolate to the regular grid; T and S above 50 are treated as invalid
            data_int_temp   = _interpolate_level(data1_all[mm,kk,:], kk, upper_limit=50)
            data_int_salt   = _interpolate_level(data2_all[mm,kk,:], kk, upper_limit=50)
            data_int_oxygen = _interpolate_level(data3_all[mm,kk,:], kk)

            if test_plot:
                levels1 = np.arange(-3,10.5,0.5)
                levels2 = np.arange(-1,3.5,0.5)
                ind_plot = 300
                dpicnt = 150
                colormap = cm.Spectral_r
                plt.figure(figsize=(6,4), dpi = dpicnt,facecolor='w', edgecolor='k')
                plt.contourf(x_all[0:ind_plot,:],y_all[0:ind_plot,:],data_int_temp[0:ind_plot,:],\
                             levels=levels1,cmap=colormap, extend='both')
                cbar=plt.colorbar(orientation='horizontal',shrink=0.8,pad=0.1)
                cbar.ax.set_xlabel('deg C')
                cbar.cmap.set_under('grey')
                plt.title('Temperature at '+str(mesh.zlevs[kk])+'m',fontsize=10,fontweight='bold')
                plt.show()
                plt.figure(figsize=(6,4), dpi = dpicnt,facecolor='w', edgecolor='k')
                plt.contourf(x_all[0:ind_plot,:],y_all[0:ind_plot,:],mask_vol[kk,0:ind_plot,:],\
                             levels=levels2,cmap=colormap, extend='both')
                cbar=plt.colorbar(orientation='horizontal',shrink=0.8,pad=0.1)
                cbar.ax.set_xlabel('')
                cbar.cmap.set_under('grey')
                plt.title('Mask at '+str(mesh.zlevs[kk])+'m',fontsize=10,fontweight='bold')
                plt.show()

            if np.mod(kk,20)==0:
                print ('Max temp,salt,oxy:',\
                       np.nanmax(data_int_temp),np.nanmax(data_int_salt),np.nanmax(data_int_oxygen))

            t_insitu, salt_level, O2_molkg, pO2 = _compute_pO2(
                data_int_temp, data_int_salt, data_int_oxygen, kk, mm)

            # keep only the latitudes declared in the output files (yi_SO)
            pO2_all_months[:,:,mm]      = pO2[ind_SO,:]
            O2_molkg_all_months[:,:,mm] = O2_molkg[ind_SO,:]
            t_insitu_all_months[:,:,mm] = t_insitu[ind_SO,:]
            salt_all_months[:,:,mm]     = salt_level[ind_SO,:]

            #-----
            # fields to save: t_insitu, pO2, O2 in mol/kg, practical salinity
            #-----
            if save_as_netcdf:
                if save_oxygen:
                    _write_level(_output_name('oxygen',year),'oxygen',mm,kk,O2_molkg_all_months[:,:,mm])
                _write_level(_output_name('pO2',year),'pO2',mm,kk,pO2_all_months[:,:,mm])
                _write_level(_output_name('t_insitu',year),'t_insitu',mm,kk,t_insitu_all_months[:,:,mm])
                if save_salinity:
                    _write_level(_output_name('salt',year),'salt',mm,kk,salt_all_months[:,:,mm])
                print ('Successfully written depth level '+str(kk+1)+'/'+str(len(depths))+\
                       ' of year '+str(year)+' to file')

        del t_insitu_all_months,O2_molkg_all_months
        del salt_all_months,pO2_all_months

print ('done')
