In [4]:
import os 
from ipsl_dataset import IPSL_DCPP
import torch
import numpy as np
import hydra
from ipsl_dataset import surface_variables,plev_variables
from hydra import compose, initialize
from omegaconf import OmegaConf
with initialize(version_base=None, config_path="conf"):
    cfg = compose(config_name="config")
import matplotlib.pyplot as plt
checkpoint_folder = '20eg0mbx'


In [5]:
import datetime
def inc_time(batch_time):
    batch_time = datetime.datetime.strptime(batch_time,'%Y-%m')
    if(batch_time.month == 12):
        year = batch_time.year + 1
        month = 1
    else:
        year = batch_time.year
        month = batch_time.month + 1
    return f'{year}-{month}'
    

In [None]:
from ipsl_dataset import surface_variables
work = os.environ['WORK']
with_soil_checkpoint_5_year = f'{work}/ipsl_dcpp/ipsl_dcpp_emulation/{checkpoint_folder}/checkpoints/24_month_epoch=10.ckpt'

checkpoint_with_soil = torch.load(with_soil_checkpoint_5_year,map_location=torch.device('cpu'))
test = IPSL_DCPP('test',1)
test_dataloader = torch.utils.data.DataLoader(test,batch_size=1,shuffle=False,num_workers=1)

soil_model = hydra.utils.instantiate(cfg.experiment.module,backbone=hydra.utils.instantiate(cfg.experiment.backbone,soil=True,conv_head=True),dataset=test_dataloader.dataset)
soil_model.load_state_dict(checkpoint_with_soil['state_dict'])
land_mask = torch.tensor(np.load('data/land_mask.npy'))
#do rollout
import psutil
iter_ts = iter(test_dataloader)
surfaces = []
plevs = []
model_plevs = []
model_surfaces = []
for i in range(120):
    batch_actual = next(iter_ts)
    if(i == 0):
        batch = batch_actual
    model_surfaces.append(batch_actual['next_state_surface'])
    model_plevs.append(batch_actual['next_state_level'])
    print(batch['time'])

    print(psutil.virtual_memory().available * 100 / psutil.virtual_memory().total)
    with torch.no_grad():
        output = soil_model.forward(batch)
    #output['next_state_surface'][:,var_index] = torch.where(land_mask == 1,output['next_state_surface'][:,var_index],0)
    batch=dict(state_surface=output['next_state_surface'],
               state_level=output['next_state_level'],
               state_depth=output['next_state_depth'],
               state_constant=batch['state_constant'],
              time=[inc_time(batch['time'][0])])
    # output=None
    surfaces.append(output['next_state_surface'])
    plevs.append(output['next_state_level'])


  from .autonotebook import tqdm as notebook_tqdm
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


['2014-01']
88.19377284221835
['2014-2']
88.06500594314252
['2014-3']
88.02685731739295
['2014-4']
87.99896332597477
['2014-5']
87.92828360025923
['2014-6']
87.9370197089343
['2014-7']
87.90532430765167
['2014-8']
87.87758726452279
['2014-9']
87.84523757231979
['2014-10']
87.80788184117465
['2014-11']
87.77590515594494
['2014-12']
87.74821295518443
['2015-1']
87.71602632613916
['2015-2']
87.68796315669483
['2015-3']
87.65579283396532
['2015-4']
87.62706722044274
['2015-5']
87.59502530994995
['2015-6']
87.56730457313684
['2015-7']
87.54549283750175
['2015-8']
87.51777821555706
['2015-9']
87.48598701466926
['2015-10']
87.44772016813035
['2015-11']
87.4093840197494
['2015-12']
87.3715594820258
['2016-1']
87.33229591193529
['2016-2']
87.29421047315934
['2016-3']
87.25638185885681
['2016-4']
87.21806609337058
['2016-5']
87.1798603955158
['2016-6']
87.14155278318745
['2016-7']
87.10347142099045
['2016-8']
87.06551235616173
['2016-9']
87.02763482291188
['2016-10']
86.98944135479394
['2016-11'

In [None]:
surface_var_name = 'hurs'
plev_var_name = 'hur'
var_index = surface_variables.index(surface_var_name)
plev_var_index = plev_variables.index(plev_var_name)
ts = np.stack(surfaces).squeeze()[:,var_index].mean(axis=(1,2))
plev_ts = np.stack(plevs).squeeze()[:,plev_var_index].mean(axis=(2,3))

In [None]:
np.save('surface_rollout.npy',surfaces)
np.save('plevel_rollout.npy',plevs)

In [None]:
plev_model = np.stack(model_plevs).squeeze()[:,plev_var_index].mean(axis=(2,3))
ts_model = np.stack(model_surfaces).squeeze()[:,var_index].mean(axis=(1,2))

In [None]:
#plot surface variables
plt.plot(ts,label='predicted rollout')
plt.plot(ts_model,label='actual')
plt.legend()
plt.title(surface_var_name)
#print pressure levels
#plt.plot(plev_ts[:,-3],label='predicted rollout')
#plt.plot(plev_model[:,-3],label='actual')
#plt.legend()
#plt.title(plev_var_name)

In [None]:
#get data for power series
stacked_pred_surfaces = np.stack(surfaces)
stacked_model_surfaces = np.stack(model_surfaces)
pred_timeseries = stacked_pred_surfaces[:,:,var_index,100,100]
model_timeseries = stacked_model_surfaces[:,:,var_index,100,100]
#plot power series
import scipy
pred_power = scipy.signal.periodogram(pred_timeseries.T)
model_power = scipy.signal.periodogram(model_timeseries.T)
import matplotlib.pyplot as plt
plt.plot(pred_power[-1][0])
plt.plot(model_power[-1][0])

In [None]:
surface_var_name = 'ps'
plev_var_name = 'hurs'
var_index = surface_variables.index(surface_var_name)
predicted = np.stack(surfaces).squeeze()[:,var_index]
climate_model = np.stack(model_surfaces).squeeze()[:,var_index]

In [None]:
#gif of rollout
# Plotting ---
import matplotlib.pyplot as plt
#import seaborn
from celluloid import Camera

# seaborn.set_context("paper")
fig, (ax1,ax2) = plt.subplots(1, 2, figsize=(16, 6))
camera = Camera(fig)
ax1.set_title("predicted")
ax2.set_title("IPSL_CM6A")
import xarray as xr
ds = xr.open_dataset(test.files[0])
shell = ds.isel(time=0)

# Animate plot over time
for time_step in range(120):
    #ax1.plot(predicted[time_step])
    shell[surface_var_name].data = predicted[time_step]
    shell[surface_var_name].plot.pcolormesh(ax=ax1,add_colorbar=False)
    shell[surface_var_name].data = climate_model[time_step]

    shell[surface_var_name].plot.pcolormesh(ax=ax2,add_colorbar=False)
    #ax2.plot(climate_model[time_step])
    camera.snap()
anim = camera.animate()
anim.save(f"{surface_var_name}_{checkpoint_folder}_rollout.gif")