In [None]:
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import copy
import sys

sys.path.append('../src')
import utils as ut

from alldata import AllData
from ensemble import EnsembleMember

#np.set_printoptions(precision=2)
%matplotlib notebook
%matplotlib inline
%config InlineBackend.print_figure_kwargs={'bbox_inches':None}
%load_ext autoreload
%autoreload 2

In [None]:
# Load all model output and exclude the PISM_DMI ice-sheet model from the ensemble.
ad = AllData()
ds = (
    ad.gather()
      .drop_sel(ism='PISM_DMI')
)

In [None]:
# Experiment configuration.
# NOTE(review): year0 is not referenced in any visible cell.
year0 = 1871
ssp = '245'

# Basal-melt sensitivities and their parameterisation, one per figure column:
# three 'lin' values (4, 7 and 16 divided by ad.K*ad.spy) followed by three
# 'quad' values. bmps[k] is the parameterisation used with bmsens[k].
bmps = ['lin']*3 + ['quad']*3
bmsens = [coef/(ad.K*ad.spy) for coef in (4, 7, 16)] + [18e-5, 36e-5, 72e-5]

In [None]:
# Run every (ESM, ISM, basal-melt sensitivity) combination and keep the last
# `nyears` of cumulative sea-level contribution per basin, relative to the
# first of those years.
#
# VARwf / VARnf have shape (n_esm*n_ism, len(bmsens), nyears, n_basin).
# VARwf is taken from ens.SLR[-1] and VARnf from ens.SLR[0]
# (presumably with / without feedback -- confirm against EnsembleMember).
# The factor 1000 presumably converts m to mm (the figure labels the axis mm).
nyears = 41
time = ds.time[-nyears:]

# Each sensitivity needs a matching parameterisation name.
assert len(bmps) == len(bmsens), 'bmps and bmsens must have the same length'

n_runs = len(ds.esm)*len(ds.ism)*len(bmsens)

VARwf = np.zeros((len(ds.esm)*len(ds.ism),len(bmsens),nyears,len(ds.basin)))
VARnf = np.zeros((len(ds.esm)*len(ds.ism),len(bmsens),nyears,len(ds.basin)))

c = 0   # number of completed runs, for the progress line
d = -1  # flat (esm, ism) index into the first axis of VARwf / VARnf
for e,esm in enumerate(ds.esm.values):
    for i,ism in enumerate(ds.ism.values):
        d += 1
        for bm,(bms,bmp) in enumerate(zip(bmsens,bmps)):
            c += 1
            # A fresh member per run (presumably iterate() mutates its state).
            ens = EnsembleMember(ds,ism=ism,esm=esm,ssp=ssp)
            ens.bmp = bmp
            # bmp is bound explicitly by this loop from bmps; it must never
            # come from leftover kernel state.
            # NOTE(review): np.ones(5) presumably holds one value per basin -- confirm.
            ens.gamma[bmp] = bms*np.ones(5)
            ens.iterate()
            VARwf[d,bm,:,:] = 1000*(ens.SLR[-1,-nyears:,:]-ens.SLR[-1,-nyears,:])
            VARnf[d,bm,:,:] = 1000*(ens.SLR[0,-nyears:,:]-ens.SLR[0,-nyears,:])

            print(f'Got esm {e+1} of {len(ds.esm)} | ism {i+1} of {len(ds.ism)} | bmp {bm+1} of {len(bmsens)} | {100*c/n_runs:.0f}% ',end='           \r')

In [None]:
# Global figure layout used by the calibration figure: a square 15x15 inch
# canvas with tight horizontal spacing between the panel columns.
mpl.rcParams.update({
    'figure.subplot.wspace': .1,
    'figure.subplot.left': .1,
    'figure.subplot.right': .99,
    'figure.figsize': (15, 15),
})

In [None]:
# Per-basin reference values, drawn as black X markers at 2017 in the figure.
# Order must follow ds.basin. (Presumably observational estimates after
# Rignot et al., in the figure's units of mm -- TODO confirm.)
rignot = [
    3,
    -2,
    10,
    -1,
    2,
]

In [None]:
# Calibration figure: one row per basin, one column per basal-melt sensitivity.
# In each panel the line is the median over the (esm, ism) members and the
# shaded band spans their 17th-83rd percentiles. Grey shows VARnf; the basin
# colour from ut.bcol shows VARwf. Black X markers are the `rignot` values at 2017.


def _plot_band(dax, x, runs, color):
    """Draw the median of `runs` over axis 0 plus a shaded 17-83 percentile band."""
    dax.plot(x, np.median(runs, axis=0), c=color, lw=2)
    dax.fill_between(x,
                     np.percentile(runs, 17, axis=0),
                     np.percentile(runs, 83, axis=0),
                     color=color, alpha=.3)


n_basin = len(ds.basin)
# One reference marker per basin row.
assert len(rignot) == n_basin, f'expected {n_basin} rignot values, got {len(rignot)}'

# squeeze=False keeps `ax` 2-D even with a single basin or sensitivity.
fig, ax = plt.subplots(n_basin, len(bmsens), sharex=True, sharey=True, squeeze=False)

for bm, bms in enumerate(bmsens):
    for b, bas in enumerate(ds.basin.values):
        dax = ax[b, bm]
        dax.axhline(0, 0, 1, c='k', lw=.3, ls=':')
        _plot_band(dax, time, VARnf[:, bm, :, b], '.5')
        _plot_band(dax, time, VARwf[:, bm, :, b], ut.bcol[bas])
        dax.scatter(2017, rignot[b], 200, marker='X', c='k')
    # bms is multiplied by 1e5 for display, so the exponent suffix is e-5
    # (e.g. 18e-5 is shown as "18.0e-5").
    # NOTE(review): the "m/s" unit is unverified, particularly for the 'quad' values.
    ax[0, bm].set_title(f'{bmps[bm]} {1e5*bms:.1f}e-5 m/s')

for b, bas in enumerate(ds.basin.values):
    ax[b, 0].set_ylabel(bas)

# Axes are shared, so limits set on one panel apply to all.
ax[0, 0].set_xlim([1979, 2020])
ax[0, 0].set_ylim([-5, 15])

fig.supylabel('Sea level rise [mm]')
fig.savefig(f'../draftfigs/calibration{ad.option}.png', dpi=450, facecolor='w', transparent=False)
plt.show()
ds.close()