# Copyright Netherlands eScience Center <br>
**Function     : Predict the spatial sea ice concentration with ConvLSTM at a weekly time scale** <br>
**Author       : Yang Liu** <br>
**First Built  : 2019.05.21** <br>
**Last Update  : 2019.07.21** <br>
**Library      : PyTorch, NumPy, netCDF4, os, iris, cartopy, deepclim, matplotlib** <br>
Description     : This notebook predicts Arctic sea ice using deep learning. Several climate indices are also included to represent the forcing from the atmosphere. A convolutional Long Short-Term Memory (ConvLSTM) neural network handles this spatio-temporal sequence problem. PyTorch is the deep learning framework. <br>
<br>
**Here we predict sea ice concentration with one extra relevant field from either the ocean or the atmosphere to test the predictor.** <br>

Return Values   : Time series and figures <br>

Details of the climate indices:<br>
**NAO** @ 1950 Jan - 2018 Feb (818 records) <br>
http://www.cpc.ncep.noaa.gov/products/precip/CWlink/pna/nao.shtml <br>
**ENSO - NINO 3.4 SST** @ 1950 Jan - 2018 Jan (817 records) <br>
https://www.esrl.noaa.gov/psd/gcos_wgsp/Timeseries/Nino34/ <br>
**AO** @ 1950 Jan - 2018 Feb (818 records) <br>
http://www.cpc.ncep.noaa.gov/products/precip/CWlink/daily_ao_index/ao.shtml <br>
**AMO** @ 1950 Jan - 2018 Feb (818 records) <br>
AMO unsmoothed, detrended from the Kaplan SST V2. The result is standardized.<br>
https://www.esrl.noaa.gov/psd/data/timeseries/AMO/<br>

(All the NOAA indices shown above are given by NCEP/NCAR Reanalysis (CDAS))<br>

**PDO** @ 1900 Jan - 2018 Feb (1418 records)<br>
This PDO index comes from the University of Washington. It contains SST data from the following 3 datasets:<br>
- UKMO Historical SST data set for 1900-81;
- Reynolds Optimally Interpolated SST (V1) for January 1982 - December 2001;
- OI SST Version 2 (V2) beginning January 2002.<br>

http://research.jisao.washington.edu/pdo/PDO.latest<br>

The regionalization adopted here follows that of the MASIE (Multisensor Analyzed Sea Ice Extent) product available from the National Snow and Ice Data Center:<br>
https://nsidc.org/data/masie/browse_regions<br>
It follows the paper by J. Walsh et al. (2019), Benchmark seasonal prediction skill estimates based on regional indices.<br>

The method comes from the study by Shi et al. (2015), Convolutional LSTM Network: A Machine Learning Approach for Precipitation Nowcasting. <br>
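For reference, the ConvLSTM cell of Shi et al. (2015) can be written as below, where $*$ is the convolution operator, $\circ$ the Hadamard product, $X_t$ the input, $H_t$ the hidden state and $C_t$ the cell state. Whether the `deepclim` implementation used in this notebook keeps the peephole terms ($W_{ci}$, $W_{cf}$, $W_{co}$) is an assumption that is not confirmed here. <br>

$$
\begin{aligned}
i_t &= \sigma(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i) \\
f_t &= \sigma(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f) \\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c) \\
o_t &= \sigma(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o) \\
H_t &= o_t \circ \tanh(C_t)
\end{aligned}
$$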

In [1]:
%matplotlib inline

import sys
import numbers

# for data loading
import os
from netCDF4 import Dataset
# for pre-processing and machine learning
import numpy as np
import sklearn
#import scipy
import torch
import torch.nn.functional

sys.path.append(os.path.join('/home/ESLT0068/NLeSC/Computation_Modeling/ML4Climate/Scripts/DeepClim'))
#sys.path.append("C:\\Users\\nosta\\ML4Climate\\Scripts\\DeepClim")
import deepclim
import deepclim.preprocess
import deepclim.deepSeries
import deepclim.deepArray_GPU
#import deepclim.deepArrayStep
#import deepclim.function

# for visualization
import deepclim.visual
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import iris # also helps with regriding
import iris.plot as iplt
import cartopy
import cartopy.crs as ccrs

# for animation
import imageio
import matplotlib.image as mgimg
from matplotlib import animation

The testing device is a Dell Inspiron 5680 with an Intel Core i7-8700 x64 CPU and an Nvidia GTX 1060 6GB GPU.<br>
Here is a benchmark of CPU vs. GTX 1060: <br>
https://www.analyticsindiamag.com/deep-learning-tensorflow-benchmark-intel-i5-4210u-vs-geforce-nvidia-1060-6gb/

In [2]:
# constants
constant = {'g' : 9.80616,      # gravitational acceleration [m / s2]
            'R' : 6371009,      # radius of the earth [m]
            'cp': 1004.64,      # heat capacity of air [J/(Kg*K)]
            'Lv': 2264670,      # Latent heat of vaporization [J/Kg]
            'R_dry' : 286.9,    # gas constant of dry air [J/(kg*K)]
            'R_vap' : 461.5,    # gas constant for water vapour [J/(kg*K)]
            'rho' : 1026,       # sea water density [kg/m3]
            }

**Data** <br>
Time span of each product included: <br>
**Reanalysis** <br>
- **ERA-Interim** 1979 - 2016 (ECMWF)
- **ORAS4**       1958 - 2014 (ECMWF)

**Index** <br>
- **NINO3.4**     1950 - 2017 (NOAA)
- **AO**          1950 - 2017 (NOAA)
- **NAO**         1950 - 2017 (NOAA)
- **AMO**         1950 - 2017 (NOAA)
- **PDO**         1950 - 2017 (University of Washington)

Note: the NOAA indices are given by NCEP/NCAR Reanalysis (CDAS). <br>


Alternative (not in use yet) <br>
**Reanalysis** <br>
- **MERRA2**      1980 - 2016 (NASA)
- **JRA55**       1979 - 2015 (JMA)
- **GLORYS2V3**   1993 - 2014 (Mercator Ocean)
- **SODA3**       1980 - 2015
- **PIOMAS**      1980 - 2015

**Observations** <br>
- **NSIDC**       1958 - 2017

In [2]:
################################################################################# 
#########                           datapath                             ########
#################################################################################
# please specify data path
datapath_ERAI = '/home/ESLT0068/WorkFlow/Core_Database_DeepLearn/ERA-Interim'
#datapath_ERAI = 'H:\\Creator_Zone\\Core_Database_DeepLearn\\ERA-Interim'
datapath_ORAS4 = '/home/ESLT0068/WorkFlow/Core_Database_DeepLearn/ORAS4'
#datapath_ORAS4 = 'H:\\Creator_Zone\\Core_Database_DeepLearn\\ORAS4'
datapath_ORAS4_mask = '/home/ESLT0068/WorkFlow/Core_Database_AMET_OMET_reanalysis/ORAS4'
#datapath_ORAS4_mask = 'H:\\Creator_Zone\\Core_Database_DeepLearn\\ORAS4'
#datapath_PIOMASS = '/home/ESLT0068/WorkFlow/Core_Database_AMET_OMET_reanalysis/PIOMASS'
#datapath_PIOMASS = 'H:\\Creator_Zone\\Core_Database_AMET_OMET_reanalysis\\PIOMASS'
datapath_clim_index = '/home/ESLT0068/WorkFlow/Core_Database_AMET_OMET_reanalysis/Climate_index'
#datapath_clim_index = 'F:\\PhD_essential\\Core_Database_AMET_OMET_reanalysis\\Climate_index'
output_path = '/home/ESLT0068/NLeSC/Computation_Modeling/ML4Climate/PredictArctic/Maps/Barents/Anime'
#output_path = 'C:\\Users\\nosta\\ML4Climate\\PredictArctic\\Maps\\Barents\\Anime'

In [None]:
if __name__=="__main__":
    print ('*********************** get the key to the datasets *************************')
    # weekly variables on ERAI grid
    dataset_ERAI_fields_sic = Dataset(os.path.join(datapath_ERAI,
                                      'sic_weekly_erai_1979_2017.nc'))
    dataset_ERAI_fields_slp = Dataset(os.path.join(datapath_ERAI,
                                      'slp_weekly_erai_1979_2017.nc'))
    dataset_ERAI_fields_t2m = Dataset(os.path.join(datapath_ERAI,
                                      't2m_weekly_erai_1979_2017.nc'))
    dataset_ERAI_fields_z500 = Dataset(os.path.join(datapath_ERAI,
                                       'z500_weekly_erai_1979_2017.nc'))
    dataset_ERAI_fields_z850 = Dataset(os.path.join(datapath_ERAI,
                                       'z850_weekly_erai_1979_2017.nc'))
    dataset_ERAI_fields_uv10m = Dataset(os.path.join(datapath_ERAI,
                                       'uv10m_weekly_erai_1979_2017.nc'))
    dataset_ERAI_fields_rad = Dataset(os.path.join(datapath_ERAI,
                                        'rad_flux_weekly_erai_1979_2017.nc'))
    #dataset_PIOMASS_siv = Dataset(os.path.join(datapath_PIOMASS,
    #                             'siv_monthly_PIOMASS_1979_2017.nc'))
    # OHC interpolated on ERA-Interim grid
    dataset_ORAS4_OHC = Dataset(os.path.join(datapath_ORAS4,
                                'ohc_monthly_oras2erai_1978_2017.nc'))
    dataset_index = Dataset(os.path.join(datapath_clim_index,
                            'index_climate_monthly_regress_1950_2017.nc'))
    #dataset_ERAI_fields_flux = Dataset(os.path.join(datapath_ERAI_fields,
    #                                  'surface_erai_monthly_regress_1979_2017_radiation.nc'))
    # mask
    dataset_ORAS4_mask = Dataset(os.path.join(datapath_ORAS4_mask, 'mesh_mask.nc'))
    print ('*********************** extract variables *************************')
    #################################################################################
    #########                        data gallery                           #########
    #################################################################################
    # we use weekly time series from 1979 to 2016 (38 years x 48 weeks = 1824 steps in total)
    # training data: 1979 - 2008
    # cross-validation: 2009 - 2012 / testing: 2013 - 2016
    # variables list:
    # SIC (ERA-Interim) / SIV (PIOMASS) / SST (ERA-Interim) / ST (ERA-Interim) / OHC (ORAS4) / AO-NAO-AMO-NINO3.4 (NOAA)
    # integrals from spatial fields cover the area from 20N - 90N (4D fields [year, month, lat, lon])
    # *************************************************************************************** #
    # SIC (ERA-Interim) - benchmark
    SIC_ERAI = dataset_ERAI_fields_sic.variables['sic'][:-1,:,:,:] # 4D fields [year, week, lat, lon]
    year_ERAI = dataset_ERAI_fields_sic.variables['year'][:-1]
    week_ERAI = dataset_ERAI_fields_sic.variables['week'][:]
    latitude_ERAI = dataset_ERAI_fields_sic.variables['latitude'][:]
    longitude_ERAI = dataset_ERAI_fields_sic.variables['longitude'][:]
    # T2M (ERA-Interim)
    T2M_ERAI = dataset_ERAI_fields_t2m.variables['t2m'][:-1,:,:,:] # 4D fields [year, week, lat, lon]
    year_ERAI_t2m = dataset_ERAI_fields_t2m.variables['year'][:-1]
    week_ERAI_t2m = dataset_ERAI_fields_t2m.variables['week'][:]
    latitude_ERAI_t2m = dataset_ERAI_fields_t2m.variables['latitude'][:]
    longitude_ERAI_t2m = dataset_ERAI_fields_t2m.variables['longitude'][:]
    # SLP (ERA-Interim)
    SLP_ERAI = dataset_ERAI_fields_slp.variables['slp'][:-1,:,:,:] # 4D fields [year, week, lat, lon]
    year_ERAI_slp = dataset_ERAI_fields_slp.variables['year'][:-1]
    week_ERAI_slp = dataset_ERAI_fields_slp.variables['week'][:]
    latitude_ERAI_slp = dataset_ERAI_fields_slp.variables['latitude'][:]
    longitude_ERAI_slp = dataset_ERAI_fields_slp.variables['longitude'][:]
    # Z500 (ERA-Interim)
    Z500_ERAI = dataset_ERAI_fields_z500.variables['z'][:-1,:,:,:] # 4D fields [year, week, lat, lon]
    year_ERAI_z500 = dataset_ERAI_fields_z500.variables['year'][:-1]
    week_ERAI_z500 = dataset_ERAI_fields_z500.variables['week'][:]
    latitude_ERAI_z500 = dataset_ERAI_fields_z500.variables['latitude'][:]
    longitude_ERAI_z500 = dataset_ERAI_fields_z500.variables['longitude'][:]
    # Z850 (ERA-Interim)
    Z850_ERAI = dataset_ERAI_fields_z850.variables['z'][:-1,:,:,:] # 4D fields [year, week, lat, lon]
    year_ERAI_z850 = dataset_ERAI_fields_z850.variables['year'][:-1]
    week_ERAI_z850 = dataset_ERAI_fields_z850.variables['week'][:]
    latitude_ERAI_z850 = dataset_ERAI_fields_z850.variables['latitude'][:]
    longitude_ERAI_z850 = dataset_ERAI_fields_z850.variables['longitude'][:]
    # UV10M (ERA-Interim)
    U10M_ERAI = dataset_ERAI_fields_uv10m.variables['u10m'][:-1,:,:,:] # 4D fields [year, week, lat, lon]
    V10M_ERAI = dataset_ERAI_fields_uv10m.variables['v10m'][:-1,:,:,:] # 4D fields [year, week, lat, lon]
    year_ERAI_uv10m = dataset_ERAI_fields_uv10m.variables['year'][:-1]
    week_ERAI_uv10m = dataset_ERAI_fields_uv10m.variables['week'][:]
    latitude_ERAI_uv10m = dataset_ERAI_fields_uv10m.variables['latitude'][:]
    longitude_ERAI_uv10m = dataset_ERAI_fields_uv10m.variables['longitude'][:]
    # SFlux (ERA-Interim)
    SFlux_ERAI = dataset_ERAI_fields_rad.variables['SFlux'][:-1,:,:,:] # 4D fields [year, week, lat, lon]
    year_ERAI_SFlux = dataset_ERAI_fields_rad.variables['year'][:-1]
    week_ERAI_SFlux = dataset_ERAI_fields_rad.variables['week'][:]
    latitude_ERAI_SFlux = dataset_ERAI_fields_rad.variables['latitude'][:]
    longitude_ERAI_SFlux = dataset_ERAI_fields_rad.variables['longitude'][:]
    #SIV (PIOMASS)
    #SIV_PIOMASS = dataset_PIOMASS_siv.variables['SIV'][:-12]
    #year_SIV = dataset_PIOMASS_siv.variables['year'][:-1]
    # OHC (ORAS4)
    # from 1978 - 2017 (for interpolation) / from 90 N down to 40 N
    OHC_300_ORAS4 = dataset_ORAS4_OHC.variables['OHC'][:-1,:,:67,:]/1000 # unit Peta Joule
    latitude_ORAS4 = dataset_ORAS4_OHC.variables['latitude'][:]
    longitude_ORAS4 = dataset_ORAS4_OHC.variables['longitude'][:]
    mask_OHC = np.ma.getmask(OHC_300_ORAS4[0,0,:,:])
    # AO-NAO-AMO-NINO3.4 (NOAA)
    AO = dataset_index.variables['AO'][348:-1] # from 1979 - 2017
    NAO = dataset_index.variables['NAO'][348:-1]
    NINO = dataset_index.variables['NINO'][348:-1]
    AMO = dataset_index.variables['AMO'][348:-1]
    PDO = dataset_index.variables['PDO'][348:-1]

In [None]:
    # first check of grid
    print(latitude_ERAI)
    print(longitude_ERAI)
    print(latitude_ERAI_t2m)
    print(longitude_ERAI_t2m)
    print(longitude_ORAS4)

In [None]:
    #################################################################################
    ###########                 global land-sea mask                      ###########
    #################################################################################
    sea_ice_mask_global = np.ones((len(latitude_ERAI),len(longitude_ERAI)),dtype=float)
    sea_ice_mask_global[SIC_ERAI[0,0,:,:]==-1] = 0
    #################################################################################
    ###########                regionalization sea mask                   ###########
    #################################################################################
    print ('*********************** create mask *************************')
    # W:-156 E:-124 N:80 S:67
    mask_Beaufort = np.zeros((len(latitude_ERAI),len(longitude_ERAI)),dtype=int)
    # W:-180 E:-156 N:80 S:66
    mask_Chukchi = np.zeros((len(latitude_ERAI),len(longitude_ERAI)),dtype=int)
    # W:146 E:180 N:80 S:67
    mask_EastSiberian = np.zeros((len(latitude_ERAI),len(longitude_ERAI)),dtype=int)
    # W:100 E:146 N:80 S:67
    mask_Laptev = np.zeros((len(latitude_ERAI),len(longitude_ERAI)),dtype=int)
    # W:60 E:100 N:80 S:67
    mask_Kara = np.zeros((len(latitude_ERAI),len(longitude_ERAI)),dtype=int)
    # W:18 E:60 N:80 S:64
    mask_Barents = np.zeros((len(latitude_ERAI),len(longitude_ERAI)),dtype=int)
    # W:-44 E:18 N:80 S:55
    mask_Greenland = np.zeros((len(latitude_ERAI),len(longitude_ERAI)),dtype=int)
    # W:-180 E:180 N:90 S:80
    mask_CenArctic = np.zeros((len(latitude_ERAI),len(longitude_ERAI)),dtype=int)
    print ('*********************** calc mask *************************')
    mask_Beaufort[13:31,32:76] = 1

    mask_Chukchi[13:32,0:32] = 1
    mask_Chukchi[13:32,-1] = 1

    mask_EastSiberian[13:31,434:479] = 1

    mask_Laptev[13:31,374:434] = 1

    mask_Kara[13:31,320:374] = 1

    mask_Barents[13:36,264:320] = 1

    mask_Greenland[13:47,179:264] = 1
    mask_Greenland[26:47,240:264] = 0

    mask_CenArctic[:13,:] = 1
    print ('*********************** packing *************************')
    mask_dict = {'Beaufort': mask_Beaufort[:,:],
                 'Chukchi': mask_Chukchi[:,:],
                 'EastSiberian': mask_EastSiberian[:,:],
                 'Laptev': mask_Laptev[:,:],
                 'Kara': mask_Kara[:,:],
                 'Barents': mask_Barents[:,:],
                 'Greenland': mask_Greenland[:,:],
                 'CenArctic': mask_CenArctic[:,:]}
    seas_namelist = ['Beaufort','Chukchi','EastSiberian','Laptev',
                     'Kara', 'Barents', 'Greenland','CenArctic']
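The regional masks above are set with hard-coded index ranges. As an alternative sketch, a box mask can be derived from the coordinate arrays and the bounds listed in the comments. The grid below (0.75 degree spacing, starting at 90 N and 180 W) is an assumption inferred from the index ranges, and `box_mask` is a hypothetical helper. The resulting rows can differ by a cell from the hand-picked indices.

```python
import numpy as np

def box_mask(lat, lon, west, east, north, south):
    """Return an integer mask [lat, lon] that is 1 inside the lat-lon box and 0 elsewhere."""
    lat_in = (lat <= north) & (lat >= south)
    lon_in = (lon >= west) & (lon <= east)
    return (lat_in[:, np.newaxis] & lon_in[np.newaxis, :]).astype(int)

# assumed grid: 0.75 degree spacing, 90 N down to 20.25 N, 180 W eastwards
lat = 90.0 - 0.75 * np.arange(94)
lon = -180.0 + 0.75 * np.arange(480)

# Barents Sea bounds as given in the comments above: W:18 E:60 N:80 S:64
mask_barents_sketch = box_mask(lat, lon, west=18, east=60, north=80, south=64)
rows = np.where(mask_barents_sketch.any(axis=1))[0]
cols = np.where(mask_barents_sketch.any(axis=0))[0]
print(rows[0], rows[-1], cols[0], cols[-1])
```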

In [None]:
    #################################################################################
    ########                  temporal interpolation matrix                  ########
    #################################################################################
    # interpolate from monthly to weekly
    # original monthly data will be taken as the last week of the month
    OHC_300_ORAS4_weekly_series = np.zeros(SIC_ERAI.reshape(len(year_ERAI)*48,len(latitude_ERAI),len(longitude_ERAI)).shape,
                                           dtype=float)
    OHC_300_ORAS4_series= deepclim.preprocess.operator.unfold(OHC_300_ORAS4)
    # calculate the difference between two months
    OHC_300_ORAS4_deviation_series = (OHC_300_ORAS4_series[1:,:,:] - OHC_300_ORAS4_series[:-1,:,:]) / 4
    for i in np.arange(4):
        OHC_300_ORAS4_weekly_series[3-i::4,:,:] = OHC_300_ORAS4_series[12:,:,:] - i * OHC_300_ORAS4_deviation_series[11:,:,:]
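The interpolation above places each monthly value in the last week of its month and fills the three preceding weeks linearly from the previous month. A minimal 1D sketch of the same scheme, assuming a single leading month (the notebook carries the full year 1978 for this purpose); `monthly_to_weekly` is a hypothetical helper:

```python
import numpy as np

def monthly_to_weekly(monthly, weeks_per_month=4):
    """Linearly interpolate a monthly series to weekly values.

    The first entry of `monthly` is the month preceding the target period
    and is only used to compute the first increment.
    """
    monthly = np.asarray(monthly, dtype=float)
    # weekly increment between two consecutive months
    step = (monthly[1:] - monthly[:-1]) / weeks_per_month
    weekly = np.zeros((len(monthly) - 1) * weeks_per_month, dtype=float)
    for i in range(weeks_per_month):
        # i = 0 fills the last week of each month with the monthly value itself
        weekly[weeks_per_month - 1 - i::weeks_per_month] = monthly[1:] - i * step
    return weekly

weekly = monthly_to_weekly([0.0, 4.0, 8.0])
print(weekly)
```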

In [None]:
    print ('******************  calculate extent from spatial fields  *******************')
    # size of the grid box
    dx = 2 * np.pi * constant['R'] * np.cos(2 * np.pi * latitude_ERAI /
                                            360) / len(longitude_ERAI)
    dy = np.pi * constant['R'] / 480
    # calculate the sea ice area
    SIC_ERAI_area = np.zeros(SIC_ERAI.shape, dtype=float)
    SFlux_ERAI_area = np.zeros(SFlux_ERAI.shape, dtype=float)
    for i in np.arange(len(latitude_ERAI[:])):
        # weight each grid box by its area and convert the units (km2 / TeraWatt)
        SIC_ERAI_area[:,:,i,:] = SIC_ERAI[:,:,i,:]* dx[i] * dy / 1E+6 # unit km2
        SFlux_ERAI_area[:,:,i,:] = SFlux_ERAI[:,:,i,:]* dx[i] * dy / 1E+12 # unit TeraWatt
    SIC_ERAI_area[SIC_ERAI_area<0] = 0 # switch the mask from -1 to 0
    print ('================  reshape input data into time series  =================')
    SIC_ERAI_area_series = deepclim.preprocess.operator.unfold(SIC_ERAI_area)
    T2M_ERAI_series = deepclim.preprocess.operator.unfold(T2M_ERAI)
    SLP_ERAI_series = deepclim.preprocess.operator.unfold(SLP_ERAI)
    Z500_ERAI_series = deepclim.preprocess.operator.unfold(Z500_ERAI)
    Z850_ERAI_series = deepclim.preprocess.operator.unfold(Z850_ERAI)
    U10M_ERAI_series = deepclim.preprocess.operator.unfold(U10M_ERAI)
    V10M_ERAI_series = deepclim.preprocess.operator.unfold(V10M_ERAI)
    SFlux_ERAI_area_series = deepclim.preprocess.operator.unfold(SFlux_ERAI_area)
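The area weighting and the reshaping into a time series can be sketched with NumPy only. `grid_box_area_km2` and `unfold_sketch` are hypothetical helpers; in particular, `unfold_sketch` only assumes that `deepclim.preprocess.operator.unfold` merges the `[year, week]` axes into one time axis, which is what the shapes used in this notebook suggest.

```python
import numpy as np

R_EARTH = 6371009  # radius of the earth [m], same value as constant['R']

def grid_box_area_km2(latitude_deg, n_lon, dy_m):
    """Area [km2] of one grid box per latitude row on a regular lat-lon grid."""
    dx_m = 2 * np.pi * R_EARTH * np.cos(np.deg2rad(latitude_deg)) / n_lon
    return dx_m * dy_m / 1e6

def unfold_sketch(field):
    """Merge the leading [year, week] axes of a 4D field into one time axis."""
    n_year, n_week, n_lat, n_lon = field.shape
    return field.reshape(n_year * n_week, n_lat, n_lon)

# toy data: 2 years, 48 weeks, 3 latitudes, 4 longitudes, full ice cover
latitude = np.array([90.0, 60.0, 0.0])
sic = np.ones((2, 48, 3, 4))
area = grid_box_area_km2(latitude, n_lon=4, dy_m=1000.0)
# broadcasting over the latitude axis replaces the explicit loop over rows
sic_area = sic * area[np.newaxis, np.newaxis, :, np.newaxis]
sic_area_series = unfold_sketch(sic_area)
print(sic_area_series.shape)
```

Broadcasting gives the same result as the loop over latitude rows while avoiding the pre-allocated output array.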

In [None]:
    print ('******************  choose the fields from target region  *******************')
    # select land-sea mask
    sea_ice_mask_barents = sea_ice_mask_global[12:36,264:320]
    # select the Barents Sea region (cf. mask_Barents, W:18 E:60 N:80 S:64)
    sic_exp = SIC_ERAI_area_series[:,12:36,264:320]
    t2m_exp = T2M_ERAI_series[:,12:36,264:320]
    slp_exp = SLP_ERAI_series[:,12:36,264:320]
    z500_exp = Z500_ERAI_series[:,12:36,264:320]
    z850_exp = Z850_ERAI_series[:,12:36,264:320]
    u10m_exp = U10M_ERAI_series[:,12:36,264:320]
    v10m_exp = V10M_ERAI_series[:,12:36,264:320]
    sflux_exp = SFlux_ERAI_area_series[:,12:36,264:320]
    ohc_exp = OHC_300_ORAS4_weekly_series[:,12:36,264:320]
    print(sic_exp.shape)
    print(t2m_exp.shape)
    print(slp_exp.shape)
    print(z500_exp.shape)
    print(u10m_exp.shape)
    print(v10m_exp.shape)
    print(sflux_exp.shape)
    print(ohc_exp.shape)
    print(latitude_ERAI[12:36])
    print(longitude_ERAI[264:320])
    print(latitude_ORAS4[12:36])
    print(longitude_ORAS4[264:320])
    #print(latitude_ERAI[26:40])
    #print(longitude_ERAI[180:216])

In [None]:
    print ('*******************  pre-processing  *********************')
    print ('=========================   normalize data   ===========================')
    sic_exp_norm = deepclim.preprocess.operator.normalize(sic_exp)
    t2m_exp_norm = deepclim.preprocess.operator.normalize(t2m_exp)
    slp_exp_norm = deepclim.preprocess.operator.normalize(slp_exp)
    z500_exp_norm = deepclim.preprocess.operator.normalize(z500_exp)
    z850_exp_norm = deepclim.preprocess.operator.normalize(z850_exp)
    u10m_exp_norm = deepclim.preprocess.operator.normalize(u10m_exp)
    v10m_exp_norm = deepclim.preprocess.operator.normalize(v10m_exp)
    sflux_exp_norm = deepclim.preprocess.operator.normalize(sflux_exp)
    ohc_exp_norm = deepclim.preprocess.operator.normalize(ohc_exp)
    print('================  save the normalizing factor  =================')
    sic_max = np.amax(sic_exp)
    sic_min = np.amin(sic_exp)
    print(sic_max,"km2")
    print(sic_min,"km2")
    ohc_max = np.amax(ohc_exp)
    ohc_min = np.amin(ohc_exp)
    t2m_max = np.amax(t2m_exp)
    t2m_min = np.amin(t2m_exp)
    slp_max = np.amax(slp_exp)
    slp_min = np.amin(slp_exp)
    z500_max = np.amax(z500_exp)
    z500_min = np.amin(z500_exp)
    z850_max = np.amax(z850_exp)
    z850_min = np.amin(z850_exp)
    u10m_max = np.amax(u10m_exp)
    u10m_min = np.amin(u10m_exp)
    v10m_max = np.amax(v10m_exp)
    v10m_min = np.amin(v10m_exp)
    sflux_max = np.amax(sflux_exp)
    sflux_min = np.amin(sflux_exp)    
    print ('====================    A series of time (index)    ====================')
    _, yy, xx = sic_exp_norm.shape # get the lat lon dimension
    year = np.arange(1979,2017,1)
    year_cycle = np.repeat(year,48)
    month_cycle = np.repeat(np.arange(1,13,1),4)
    month_cycle = np.tile(month_cycle,len(year)+1) # one extra repeat for lead time dependent prediction
    month_cycle = month_cycle.astype(float) # astype returns a new array, so the result must be assigned
    month_2D = np.repeat(month_cycle[:,np.newaxis],yy,1)
    month_exp = np.repeat(month_2D[:,:,np.newaxis],xx,2)
    print ('===================  artificial data for evaluation ====================')
    # calculate climatology of SIC
#     seansonal_cycle_SIC = np.zeros(48,dtype=float)
#     for i in np.arange(48):
#         seansonal_cycle_SIC[i] = np.mean(SIC_ERAI_sum_norm[i::48],axis=0)
    # weight for loss
#     weight_month = np.array([0,1,1,
#                              1,0,0,
#                              1,1,1,
#                              0,0,0])
    #weight_loss = np.repeat(weight_month,4)
    #weight_loss = np.tile(weight_loss,len(year))
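The body of `deepclim.preprocess.operator.normalize` is not shown in this notebook. The sketch below assumes a min-max scaling, which for the sea ice area (minimum 0) reduces to a division by `sic_max` and is therefore consistent with the rescaling `* sic_max` used in the plots. It also shows how a month channel can be expanded to the spatial grid. All names here are hypothetical.

```python
import numpy as np

def minmax_normalize(x):
    """Scale x to [0, 1] and return the factors needed to undo the scaling."""
    x_min, x_max = np.amin(x), np.amax(x)
    return (x - x_min) / (x_max - x_min), x_min, x_max

def minmax_restore(x_norm, x_min, x_max):
    """Invert minmax_normalize."""
    return x_norm * (x_max - x_min) + x_min

# toy sea ice area field [time, lat, lon] in km2
sic_toy = np.array([[[0.0, 50.0], [100.0, 200.0]]])
sic_toy_norm, toy_min, toy_max = minmax_normalize(sic_toy)
sic_toy_back = minmax_restore(sic_toy_norm, toy_min, toy_max)

# month index (1..12) for 48 weeks, expanded to a [time, lat, lon] field
month_cycle = np.repeat(np.arange(1, 13), 4).astype(float)
month_field = np.broadcast_to(month_cycle[:, np.newaxis, np.newaxis], (48, 2, 2))
print(sic_toy_norm, month_field.shape)
```

Note that the month channel in the notebook is passed to the network unscaled (values 1 to 12), while the other channels are normalized.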

In [None]:
    print(t2m_exp.shape)
    ax = plt.contourf(t2m_exp[443,:,:])
    print(t2m_exp[443,:,:])
    print(month_exp[0,:,:].shape)

In [None]:
    x_input = np.stack((sic_exp_norm[1,:,:],
                        t2m_exp_norm[1,:,:],
                        ohc_exp_norm[1,:,:],
                        month_exp[1,:,:]))
    print(x_input[1,:,:])
    print(x_input[:].shape)

# Procedure for LSTM <br>
**We use PyTorch to implement a convolutional LSTM neural network with time series of climate data.** <br>

In [None]:
    print ('*******************  parameter for check  *********************')
    print ('*******************  create basic dimensions for tensor and network  *********************')
    # specifications of neural network
    input_channels = 10
    hidden_channels = [10, 9, 1] # channels of each hidden layer; the last entry is also the number of output channels
    #hidden_channels = [3, 3, 3, 3, 2]
    #hidden_channels = [2]
    kernel_size = 3
    # here we input a sequence and predict the next step only
    #step = 1 # how many steps to predict ahead
    #effective_step = [0] # step to output
    batch_size = 1
    #num_layers = 1
    learning_rate = 0.005
    num_epochs = 1500
    print (torch.__version__)
    # check if CUDA is available
    use_cuda = torch.cuda.is_available()
    print("Is CUDA available? {}".format(use_cuda))
    print ('*******************  cross validation and testing data  *********************')
    # take 10% data as cross-validation data
    cross_valid_year = 4
    # take 10% years as testing data
    test_year = 4
    # minibatch
    #iterations = 3 # training data divided into 3 sets
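With 48 weeks per year, `cross_valid_year = 4` and `test_year = 4`, the split of the 1979 - 2016 series can be made explicit. A small sketch of the index arithmetic used by the negative slices later in the notebook; the `splits` dictionary is a hypothetical convenience:

```python
weeks_per_year = 48
first_year, last_year = 1979, 2016
cross_valid_year = 4
test_year = 4

n_years = last_year - first_year + 1
sequence_len = n_years * weeks_per_year
# the test data are the last years, the cross-validation data the years before them
test_start = sequence_len - test_year * weeks_per_year
valid_start = test_start - cross_valid_year * weeks_per_year

splits = {
    'train': (0, valid_start),
    'cross_valid': (valid_start, test_start),
    'test': (test_start, sequence_len),
}
first_valid_calendar_year = first_year + valid_start // weeks_per_year
print(splits, first_valid_calendar_year)
```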

In [None]:
    print ('*******************  preview of input tensor  *********************')
    #plt.plot(SIC_ERAI_sum)
    #print(SIC_ERAI_sum_norm[:-test_year*12])
    #print(x_input.shape)
    #print(x_input[:,:,:])
    sequence_len, _, _ = sic_exp_norm.shape
    print(sequence_len)

In [None]:
#     print ('*******************  module for calculating accuracy  *********************')
#     def accuracy(out, labels):
#         outputs = np.argmax(out, axis=1)
#     return np.sum(outputs==labels)/float(labels.size)

In [None]:
    %%time
    print ('*******************  load existing LSTM model  *********************')
    #model = torch.load(os.path.join(output_path, 'Barents','convlstm_era_sic_oras_ohc_Barents_hl_3_kernel_3_lr_0.005_epoch_1500.pkl'))
    model = torch.load(os.path.join(output_path,'convlstm_era_sic_t2m_slp_z500_z850_uv10m_sflux_oras_ohc_Barents_hl_3_kernel_3_lr_0.005_epoch_1500_validSIC.pkl'))
    print(model)
    # check the sequence length (dimensions needed for post-processing)
    sequence_len, height, width = sic_exp_norm.shape

In [None]:
    print ('*******************  evaluation metrics  *********************')
    # The prediction will be evaluated through RMSE against climatology
    
    # error score for temporal-spatial fields, without keeping spatial pattern
    def RMSE(x,y):
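        """
        Calculate the RMSE. x is the input series and y is the reference series.
        The RMSE is calculated over the spatial domain at each time step, not over time,
        so the spatial structure is not kept.
        Parameters
        ----------------------
        x: input time series with the shape [time, lat, lon]
        y: reference time series with the shape [time, lat, lon]
        Returns
        ----------------------
        rmse: RMSE per time step with the shape [time]
        rmse_std: square root of the standard deviation of the squared errors, with the shape [time]
        """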
        x_series = x.reshape(x.shape[0],-1)
        y_series = y.reshape(y.shape[0],-1)
        rmse = np.sqrt(np.mean((x_series - y_series)**2,1))
        rmse_std = np.sqrt(np.std((x_series - y_series)**2,1))
    
        return rmse, rmse_std
    
    # error score for temporal-spatial fields, keeping spatial pattern
    def MAE(x,y):
        """
        Calculate the MAE. x is the input series and y is the reference series.
        It calculates the MAE over time and keeps the spatial structure.
        """
        mae = np.mean(np.abs(x-y),0)
        
        return mae
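A small usage sketch of the two scores on toy data. The functions are repeated here at top level so that the block is self-contained, and the toy arrays are made up for illustration:

```python
import numpy as np

def RMSE(x, y):
    """RMSE over the spatial domain for each time step, shape [time]."""
    x_series = x.reshape(x.shape[0], -1)
    y_series = y.reshape(y.shape[0], -1)
    rmse = np.sqrt(np.mean((x_series - y_series)**2, 1))
    rmse_std = np.sqrt(np.std((x_series - y_series)**2, 1))
    return rmse, rmse_std

def MAE(x, y):
    """MAE over time, keeping the spatial structure, shape [lat, lon]."""
    return np.mean(np.abs(x - y), 0)

# toy fields with the shape [time, lat, lon] = [2, 1, 2]
pred = np.array([[[1.0, 2.0]], [[3.0, 4.0]]])
ref = np.zeros_like(pred)
rmse, rmse_std = RMSE(pred, ref)
mae = MAE(pred, ref)
print(rmse.shape, mae.shape)
```

The RMSE returns one value per time step, whereas the MAE returns one value per grid cell, so the two scores answer different questions about the forecast.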

In [None]:
    %%time
    #################################################################################
    ########       lead time dependent prediction for cross-validation       ########
    #################################################################################
    print('##############################################################')
    print('###################  start prediction loop ###################')
    print('##############################################################')
    # the model learns from the time series and predicts the next time step based on the previous steps
    print ('*******************************  one step ahead forecast  *********************************')
    print ('******  the {} years before the test period are treated as cross-validation data  ******'.format(cross_valid_year))
    # time series before test data
    pred_base_sic = sic_exp_norm[-cross_valid_year*12*4-test_year*12*4:-test_year*12*4,:,:]
    # predict x steps ahead
    step_lead = 6 # unit week
    # create a matrix for the prediction
    lead_pred_sic = np.zeros((cross_valid_year*12*4,step_lead,height,width),dtype=float) # dim [predict time, lead time, lat, lon]
    # start the prediction loop
    for step in range(cross_valid_year*12*4):
        # Clear stored gradient
        model.zero_grad()
        # Don't do this if you want your LSTM to be stateful
        # Otherwise the hidden state should be cleaned up at each time step for prediction (we don't clear hidden state in our forward function)
        # see example from (https://github.com/pytorch/examples/blob/master/time_sequence_prediction/train.py)
        # model.hidden = model.init_hidden()
        # based on the design of this module, the hidden states and cell states are initialized when the module is called.
        for i in np.arange(1,sequence_len-cross_valid_year*12*4-test_year*12*4 + step + step_lead,1): # here i is actually the time step (index) of prediction, we use var[:i] to predict var[i]
            #############################################################################
            ###############           before time of prediction           ###############
            #############################################################################
            if i <= (sequence_len-cross_valid_year*12*4-test_year*12*4 + step):
                # create variables
                x_input = np.stack((sic_exp_norm[i-1,:,:],
                                    ohc_exp_norm[i-1,:,:],
                                    t2m_exp_norm[i-1,:,:],
                                    slp_exp_norm[i-1,:,:],
                                    z500_exp_norm[i-1,:,:],
                                    z850_exp_norm[i-1,:,:],
                                    u10m_exp_norm[i-1,:,:],
                                    v10m_exp_norm[i-1,:,:],
                                    sflux_exp_norm[i-1,:,:],                                    
                                    month_exp[i-1,:,:])) #vstack,hstack,dstack
                x_var_pred = torch.autograd.Variable(torch.Tensor(x_input).view(-1,input_channels,height,width),
                                                     requires_grad=False).cuda()
                # make prediction
                last_pred, _ = model(x_var_pred, i-1)
                # record the real prediction after the time of prediction
                if i == (sequence_len-cross_valid_year*12*4-test_year*12*4 + step):
                    lead = 0
                    # GPU data should be transferred to CPU
                    lead_pred_sic[step,0,:,:] = last_pred[0,0,:,:].cpu().data.numpy()
            #############################################################################
            ###############            after time of prediction           ###############
            #############################################################################
            else:
                lead += 1
                # prepare predictor
                # use the predicted data to make new prediction
                x_input = np.stack((lead_pred_sic[step,i-(sequence_len-cross_valid_year*12*4-test_year*12*4 + step +1),:,:],
                                    ohc_exp_norm[i-1,:,:],
                                    t2m_exp_norm[i-1,:,:],
                                    slp_exp_norm[i-1,:,:],
                                    z500_exp_norm[i-1,:,:],
                                    z850_exp_norm[i-1,:,:],
                                    u10m_exp_norm[i-1,:,:],
                                    v10m_exp_norm[i-1,:,:],
                                    sflux_exp_norm[i-1,:,:],
                                    month_exp[i-1,:,:])) #vstack,hstack,dstack
                x_var_pred = torch.autograd.Variable(torch.Tensor(x_input).view(-1,input_channels,height,width),
                                                     requires_grad=False).cuda()        
                # make prediction
                last_pred, _ = model(x_var_pred, i-1)
                # record the prediction
                lead_pred_sic[step,lead,:,:] = last_pred[0,0,:,:].cpu().data.numpy()
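The core idea of the loop above is that the lead 0 forecast uses observed sea ice, while every later lead time feeds the previous prediction back in as the sea ice channel. The other fields are still taken from the data at all lead times. A stateless sketch of this feedback, where `toy_model` is a hypothetical stand-in for the trained ConvLSTM and the replay of the earlier sequence (which builds up the hidden state in the notebook) is left out:

```python
import numpy as np

def toy_model(x):
    """Hypothetical stand-in for the trained network: damp the first channel by 10%."""
    return 0.9 * x[0]

def lead_forecast(sic, forcing, start, n_lead, model):
    """Forecast n_lead steps from index `start`, feeding predictions back in.

    sic, forcing: arrays with the shape [time, lat, lon]
    start: time index of the first predicted step (lead 0)
    """
    preds = np.zeros((n_lead,) + sic.shape[1:], dtype=float)
    current = sic[start - 1]  # last observed sea ice field
    for lead in range(n_lead):
        # the extra field is always taken from the data, as in the loop above
        x_input = np.stack((current, forcing[start - 1 + lead]))
        preds[lead] = model(x_input)
        current = preds[lead]  # the prediction becomes the next input
    return preds

sic_obs = np.ones((10, 2, 2))
forcing_obs = np.zeros((10, 2, 2))
preds = lead_forecast(sic_obs, forcing_obs, start=5, n_lead=3, model=toy_model)
print(preds[:, 0, 0])
```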

In [None]:
    #############################################################################################################
    ########        visualization of lead time dependent prediction with cross-validation data           ########
    #############################################################################################################
    index_plot = np.arange(cross_valid_year*12*4)
    index_plot_step = np.arange(cross_valid_year*12*4+1)
    year_index = np.arange(2009,2013,1)
    # repeat climatology as reference
    #climatology = np.tile(seansonal_cycle_SIC,len(year_ERAI))
    # create index
    # clip negative predictions to 0 (correction for floating point values around 0)
    lead_pred_sic[lead_pred_sic<0] = 0
    # extend the dimension of sea ice mask
    sea_ice_mask_crossValid = np.repeat(sea_ice_mask_barents[np.newaxis,:,:],cross_valid_year*48,0)
    # correct the land cells in the prediction
    for i in range(step_lead):
        lead_pred_sic[:,i,:,:] = lead_pred_sic[:,i,:,:] * sea_ice_mask_crossValid
    print ("*******************  Predicted Ice Extent  **********************")
    # include text box in the figure
    #text_content = '$RMSE=%.3f$ ' % (error_pred)
    sic_extend_lead = np.sum(np.sum(lead_pred_sic,3),2)
    
    colormap=cm.autumn(range(cross_valid_year*12*4))
    
    fig0 = plt.figure(figsize=(12,6))
    for i in range(cross_valid_year*12*4):
        plt.plot(np.arange(index_plot[i],index_plot[i]+step_lead), sic_extend_lead[i,:] * sic_max / 1E+6, color=colormap[i])
    plt.scatter(index_plot, sic_extend_lead[:,0] * sic_max / 1E+6, color='r', label="Lead 0")
    plt.scatter(index_plot_step[1:], sic_extend_lead[:,1] * sic_max / 1E+6, color='g', label="Lead 1")
    plt.plot(index_plot, np.sum(np.sum(sic_exp_norm[-cross_valid_year*12*4-test_year*12*4:-test_year*12*4,:,:],2),1) * sic_max / 1E+6,
             'b', label="Observation")
    plt.scatter(index_plot, np.sum(np.sum(sic_exp_norm[-cross_valid_year*12*4-test_year*12*4:-test_year*12*4,:,:],2),1) * sic_max / 1E+6,
                color='b')
    #plt.plot(index_plot, climatology, 'c--',label="climatology")
    plt.xlabel('Time (week)',fontsize = 14)
    plt.ylabel('Sea ice extent (million square kilometers)',fontsize = 14)
    plt.xticks(np.arange(0,cross_valid_year*12*4,6*4),(['200901', '200907', '201001', '201007',
                                                        '201101', '201107', '201201', '201207']),
               fontsize = 12)
    plt.yticks(np.arange(0,0.5,0.1),fontsize = 12)
    plt.legend(frameon=False, loc=1, prop={'size': 12})
    #props = dict(boxstyle='round', facecolor='white', alpha=0.8)
    #ax = plt.gca()
    #ax.text(0.03,0.2,text_content,transform=ax.transAxes,fontsize=10,verticalalignment='top',bbox=props)
    plt.show()
    fig0.savefig(os.path.join(output_path,'SIC_ERAI_LSTM_pred_lead_crossValid.png'),dpi=200)

    print ("*******************  Prediction Ice Distribution  **********************")
    mae = MAE(lead_pred_sic[:,0,:,:],
              sic_exp_norm[-cross_valid_year*12*4-test_year*12*4:-test_year*12*4,:,:])
    label = 'MAE (square kilometers)'
    ticks = [i for i in np.linspace(0,200,11)]
    deepclim.visual.plots.geograph(latitude_ERAI[12:36], longitude_ERAI[264:320],
                                   mae * sic_max, label, ticks,
                                   os.path.join(output_path,'spatial_sic_mae_avg_crossValid.png'),
                                   boundary='Barents_Polar', colormap='Blues')
    print ("*******************  Other variables (Prediction with testing sets only) **********************")
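
The land-mask correction and the extent sum in the cell above can also be written with NumPy broadcasting, which avoids repeating the mask along the time axis. The sketch below runs on small random arrays; `pred` and `mask` are illustrative stand-ins for `lead_pred_sic` and `sea_ice_mask_barents`, and all sizes are made up for the example.

```python
import numpy as np

# toy stand-ins: 8 forecast times, 3 lead times, 4 x 5 grid (hypothetical sizes)
rng = np.random.default_rng(0)
pred = rng.uniform(-0.1, 1.0, size=(8, 3, 4, 5))
mask = (rng.uniform(size=(4, 5)) > 0.3).astype(float)  # 1 = ocean, 0 = land

# clip negative values to zero, as the notebook does
pred = np.clip(pred, 0.0, None)
# a [lat, lon] mask broadcasts over the leading [time, lead] axes
pred_masked = pred * mask
# sum over both spatial axes at once -> [time, lead]
extent = pred_masked.sum(axis=(2, 3))

# reference: the explicit repeat-and-loop version for comparison
mask_rep = np.repeat(mask[np.newaxis, :, :], pred.shape[0], 0)
pred_loop = pred.copy()
for k in range(pred.shape[1]):
    pred_loop[:, k, :, :] = pred_loop[:, k, :, :] * mask_rep
```

Both versions give the same masked field; the broadcast form just skips the temporary repeated mask.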

In [None]:
    %%time
    #################################################################################
    ########  operational lead time dependent prediction with testing data   ########
    #################################################################################
    print('##############################################################')
    print('###################  start prediction loop ###################')
    print('##############################################################')
    # the model learns from the time series and predicts the next time step from the preceding ones
    print ('*******************************  one step ahead forecast  *********************************')
    print ('************  the last {} years of total time series are treated as test data  ************'.format(test_year))
    # time series before test data
    pred_base_sic = sic_exp_norm[:-test_year*12*4,:,:]
    # predict x steps ahead
    step_lead = 6 # unit week
    # create a matrix for the prediction
    lead_pred_sic = np.zeros((test_year*12*4,step_lead,height,width),dtype=float) # dim [predict time, lead time, lat, lon]
    # start the prediction loop
    for step in range(test_year*12*4):
        # Clear stored gradient
        model.zero_grad()
        # Resetting the hidden state (the commented line below) is only needed for a stateless LSTM;
        # skip it if the LSTM should be stateful. Our forward function does not clear the hidden state itself.
        # see example from (https://github.com/pytorch/examples/blob/master/time_sequence_prediction/train.py)
        # model.hidden = model.init_hidden()
        # By design of this module, the hidden states and cell states are initialized when the module is called.
        for i in np.arange(1,sequence_len-test_year*12*4 + step + step_lead,1): # here i is actually the time step (index) of prediction, we use var[:i] to predict var[i]
            #############################################################################
            ###############           before time of prediction           ###############
            #############################################################################
            if i <= (sequence_len-test_year*12*4 + step):
                # create variables
                x_input = np.stack((sic_exp_norm[i-1,:,:],
                                    ohc_exp_norm[i-1,:,:],
                                    t2m_exp_norm[i-1,:,:],
                                    slp_exp_norm[i-1,:,:],
                                    z500_exp_norm[i-1,:,:],
                                    z850_exp_norm[i-1,:,:],
                                    u10m_exp_norm[i-1,:,:],
                                    v10m_exp_norm[i-1,:,:],
                                    sflux_exp_norm[i-1,:,:],
                                    month_exp[i-1,:,:])) #vstack,hstack,dstack
                x_var_pred = torch.autograd.Variable(torch.Tensor(x_input).view(-1,input_channels,height,width),
                                                     requires_grad=False).cuda()
                # make prediction
                last_pred, _ = model(x_var_pred, i-1)
                # record the real prediction after the time of prediction
                if i == (sequence_len-test_year*12*4 + step):
                    lead = 0
                    # GPU data should be transferred to CPU
                    lead_pred_sic[step,0,:,:] = last_pred[0,0,:,:].cpu().data.numpy()
            #############################################################################
            ###############            after time of prediction           ###############
            #############################################################################
            else:
                lead += 1
                # prepare predictor
                if i <= sequence_len:
                    # use the predicted data to make new prediction
                    x_input = np.stack((lead_pred_sic[step,i-(sequence_len-test_year*12*4 + step +1),:,:],
                                        ohc_exp_norm[i-1,:,:],
                                        t2m_exp_norm[i-1,:,:],
                                        slp_exp_norm[i-1,:,:],
                                        z500_exp_norm[i-1,:,:],
                                        z850_exp_norm[i-1,:,:],
                                        u10m_exp_norm[i-1,:,:],
                                        v10m_exp_norm[i-1,:,:],
                                        sflux_exp_norm[i-1,:,:],
                                        month_exp[i-1,:,:])) #vstack,hstack,dstack
                else: # predictor fields (*_exp_norm) are out of range, reuse their last time step
                    x_input = np.stack((lead_pred_sic[step,i-(sequence_len-test_year*12*4 + step +1),:,:],
                                        ohc_exp_norm[-1,:,:],
                                        t2m_exp_norm[-1,:,:],
                                        slp_exp_norm[-1,:,:],
                                        z500_exp_norm[-1,:,:],
                                        z850_exp_norm[-1,:,:],
                                        u10m_exp_norm[-1,:,:],
                                        v10m_exp_norm[-1,:,:],
                                        sflux_exp_norm[-1,:,:],
                                        month_exp[i-1,:,:])) #vstack,hstack,dstack                    
                x_var_pred = torch.autograd.Variable(torch.Tensor(x_input).view(-1,input_channels,height,width),
                                                     requires_grad=False).cuda()        
                # make prediction
                last_pred, _ = model(x_var_pred, i-1)
                # record the prediction
                lead_pred_sic[step,lead,:,:] = last_pred[0,0,:,:].cpu().data.numpy()
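
The lead-time logic of the loop above can be summarized in a few lines: lead 0 is predicted from the last observed field, and every later lead takes the previous prediction as its sea ice input while the other predictors keep coming from the record (or from its last time step once the record ends). The sketch below is a simplified illustration, not the notebook's model: `rollout` and `predict_next` are hypothetical names, only one extra predictor is used, and the pass through the full history that the notebook makes before each forecast is left out.

```python
import numpy as np

def rollout(target, forcing, predict_next, start, n_lead):
    """Forecast n_lead steps from index `start`, feeding predictions back in.

    target  : [time, lat, lon] observed field (stand-in for sic_exp_norm)
    forcing : [time, lat, lon] one extra predictor (stand-in for e.g. ohc_exp_norm)
    predict_next : hypothetical callable mapping a [channel, lat, lon] input
                   to the next [lat, lon] target field
    """
    n_time = target.shape[0]
    out = np.zeros((n_lead,) + target.shape[1:])
    last = target[start - 1]              # lead 0 starts from the observation
    for lead in range(n_lead):
        t = start - 1 + lead              # time index of the input fields
        f = forcing[min(t, n_time - 1)]   # past the record: reuse the last field
        out[lead] = predict_next(np.stack((last, f)))
        last = out[lead]                  # later leads see the prediction
    return out

# toy check with a "model" that adds the forcing to the previous state
target = np.zeros((10, 2, 2))
forcing = np.ones((10, 2, 2))
pred = rollout(target, forcing, lambda x: x[0] + x[1], start=8, n_lead=4)
```

With this toy model each lead adds one to the previous state, which makes it easy to see that the prediction, not the observation, is carried forward.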

In [None]:
    #################################################################################
    ########        correction of lead time dependent prediction          ########
    #################################################################################
    ######################    data cleaner   ######################
    # repeat climatology as reference
    #climatology = np.tile(seansonal_cycle_SIC,len(year_ERAI))
    # create index
    # clip small negative predictions (floating point error around 0) to 0
    lead_pred_sic[lead_pred_sic<0] = 0
    # extend the dimension of sea ice mask
    sea_ice_mask_test = np.repeat(sea_ice_mask_barents[np.newaxis,:,:],test_year*48,0)
    # correct the land cells in the prediction
    for i in range(step_lead):
        lead_pred_sic[:,i,:,:] = lead_pred_sic[:,i,:,:] * sea_ice_mask_test
    #################################################################################
    ########          convert the sea ice fields into binary data            ########
    #################################################################################
    criterion_0 = 0.15 # ice concentration below the threshold is regarded as no ice
    # remove the area weight
    sic_exp_denorm = np.zeros(sic_exp_norm.shape, dtype=float)
    lead_pred_sic_denorm = np.zeros(lead_pred_sic.shape, dtype=float)
    for i in np.arange(height):
        lead_pred_sic_denorm[:,:,i,:] = lead_pred_sic[:,:,i,:] / dx[i+12] * dx[35]
        sic_exp_denorm[:,i,:] = sic_exp_norm[:,i,:] / dx[i+12] * dx[35]
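    # turn sea ice fields into binary data
    # copy first: slicing a NumPy array with [:] returns a view, so thresholding
    # a view would also overwrite the denormalized fields
    lead_pred_sic_bin = lead_pred_sic_denorm.copy()
    sic_exp_bin = sic_exp_denorm.copy()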
    lead_pred_sic_bin[lead_pred_sic_bin <= criterion_0] = 0
    lead_pred_sic_bin[lead_pred_sic_bin > criterion_0] = 1
    sic_exp_bin[sic_exp_bin <= criterion_0] = 0
    sic_exp_bin[sic_exp_bin > criterion_0] = 1
    # turn matrix into int
    lead_pred_sic_bin = lead_pred_sic_bin.astype(int)
    sic_exp_bin = sic_exp_bin.astype(int)
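
A compact way to express the thresholding step is a single comparison, which builds a new array and leaves its input untouched. The sketch below is a minimal illustration on a toy array; `binarize` is a hypothetical helper name, and the 0.15 default mirrors `criterion_0`. The last two lines show why the input matters: basic slicing of a NumPy array returns a view that shares memory with the original.

```python
import numpy as np

def binarize(field, threshold=0.15):
    """Return 1 where field > threshold and 0 elsewhere, as a new int array."""
    return (field > threshold).astype(int)

sic = np.array([[0.05, 0.15], [0.16, 0.90]])
sic_bin = binarize(sic)

# basic slicing returns a view that shares memory with the original
view = sic[:]
shares = np.shares_memory(view, sic)
```

Values exactly at the threshold count as no ice, matching the `<=` comparison used in the cell above.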

In [None]:
    #################################################################################
    ########        visualization of lead time dependent prediction          ########
    #################################################################################
    year_index = np.arange(2013,2017,1)
    index_plot = np.arange(test_year*12*4)

    # first construct iris coordinate
    lat_iris = iris.coords.DimCoord(latitude_ERAI[12:36], standard_name='latitude', long_name='latitude',
                                    var_name='lat', units='degrees')
    lon_iris = iris.coords.DimCoord(longitude_ERAI[264:320], standard_name='longitude', long_name='longitude',
                                    var_name='lon', units='degrees')
    # take obs sic from certain period
    sic_select_norm = sic_exp_norm[-test_year*12*4:,:,:]
    sic_select_bin = sic_exp_bin[-test_year*12*4:,:,:]
    sic_persist_norm = sic_exp_norm[-test_year*12*4-1:-1,:,:]
    
    for i in index_plot:
        year = year_index[i//48]
        month = i//4 - (year-2013)*12 + 1
        week = i - (month-1)*4 - (year-2013) * 48 + 1
        # figure
        fig = plt.figure(figsize=(12,10))
        fig.suptitle('ConvLSTM SIC prediction year {} month {} week {}'.format(year, month, week))
        
        # submap 1 - configure map
        ax1 = plt.subplot(2, 2, 1, projection=ccrs.EquidistantConic(central_longitude=39.0, central_latitude=72.0))
        ax1.set_extent([16,60,60,82],ccrs.PlateCarree()) # W:18 E:60 S:64 N:80    
        ax1.set_aspect('1')
        ax1.coastlines()
        gl = ax1.gridlines(linewidth=1, color='gray', alpha=0.5, linestyle='--')
        # assemble cube iris
        cube_iris = iris.cube.Cube(lead_pred_sic[i,0,:,:], long_name='geographical field', var_name='field', 
                                   units='1', dim_coords_and_dims=[(lat_iris, 0), (lon_iris, 1)])
        # other set-ups
        ticks = list(np.linspace(0.0,1.0,11)) # avoid reusing the loop variable i
        label = 'Pred SIC'
        # make plots
        cs = iplt.contourf(cube_iris, cmap='Blues',levels=ticks, extend='both', vmin=ticks[0], vmax=ticks[-1])
        cbar = fig.colorbar(cs,extend='both', orientation='horizontal',
                            shrink =0.8, pad=0.05)#, format="%.1f")
        cbar.set_label(label,size = 12)
        cbar.set_ticks(ticks)
        cbar.ax.tick_params(labelsize = 9)
        
        # submap 2 - configure map
        ax2 = plt.subplot(2, 2, 2, projection=ccrs.EquidistantConic(central_longitude=39.0, central_latitude=72.0))
        ax2.set_extent([16,60,60,82],ccrs.PlateCarree()) # W:18 E:60 S:64 N:80    
        ax2.set_aspect('1')
        ax2.coastlines()
        gl = ax2.gridlines(linewidth=1, color='gray', alpha=0.5, linestyle='--')
        # assemble cube iris
        cube_iris = iris.cube.Cube(sic_select_norm[i,:,:], long_name='geographical field', var_name='field', 
                                   units='1', dim_coords_and_dims=[(lat_iris, 0), (lon_iris, 1)])
        # other set-ups
        ticks = list(np.linspace(0.0,1.0,11)) # avoid reusing the loop variable i
        label = 'Obs SIC'
        # make plots       
        cs = iplt.contourf(cube_iris, cmap='Blues',levels=ticks, extend='both', vmin=ticks[0], vmax=ticks[-1])
        cbar = fig.colorbar(cs,extend='both', orientation='horizontal',
                            shrink =0.8, pad=0.05)#, format="%.1f")
        cbar.set_label(label,size = 12)
        cbar.set_ticks(ticks)
        cbar.ax.tick_params(labelsize = 9)

        # submap 3 - configure map
        ax3 = plt.subplot(2, 2, 3, projection=ccrs.EquidistantConic(central_longitude=39.0, central_latitude=72.0))
        ax3.set_extent([16,60,60,82],ccrs.PlateCarree()) # W:18 E:60 S:64 N:80    
        ax3.set_aspect('1')
        ax3.coastlines()
        gl = ax3.gridlines(linewidth=1, color='gray', alpha=0.5, linestyle='--')
        # assemble cube iris
        cube_iris = iris.cube.Cube(lead_pred_sic[i,0,:,:] - sic_select_norm[i,:,:], long_name='geographical field', var_name='field', 
                                   units='1', dim_coords_and_dims=[(lat_iris, 0), (lon_iris, 1)])
        # other set-ups
        ticks = list(np.linspace(-0.2,0.2,11)) # avoid reusing the loop variable i
        label = 'Pred - Obs (SIC)'
        # make plots
        cs = iplt.contourf(cube_iris, cmap='coolwarm',levels=ticks, extend='both', vmin=ticks[0], vmax=ticks[-1])
        cbar = fig.colorbar(cs,extend='both', orientation='horizontal',
                            shrink =0.8, pad=0.05)#, format="%.1f")
        cbar.set_label(label,size = 12)
        cbar.set_ticks(ticks)
        cbar.ax.tick_params(labelsize = 9)

        # submap 4 - configure map
        ax4 = plt.subplot(2, 2, 4, projection=ccrs.EquidistantConic(central_longitude=39.0, central_latitude=72.0))
        ax4.set_extent([16,60,60,82],ccrs.PlateCarree()) # W:18 E:60 S:64 N:80    
        ax4.set_aspect('1')
        ax4.coastlines()
        gl = ax4.gridlines(linewidth=1, color='gray', alpha=0.5, linestyle='--')
        # assemble cube iris
        cube_iris = iris.cube.Cube(np.abs(lead_pred_sic[i,0,:,:] - sic_select_norm[i,:,:]) - np.abs(sic_persist_norm[i,:,:] - sic_select_norm[i,:,:]),
                                   long_name='geographical field', var_name='field', 
                                   units='1', dim_coords_and_dims=[(lat_iris, 0), (lon_iris, 1)])
        # other set-ups
        ticks = list(np.linspace(-0.1,0.1,11)) # avoid reusing the loop variable i
        label = 'Pred - Persist (Absolute Error)'
        # make plots
        cs = iplt.contourf(cube_iris, cmap='coolwarm',levels=ticks, extend='both', vmin=ticks[0], vmax=ticks[-1])
        cbar = fig.colorbar(cs,extend='both', orientation='horizontal',
                            shrink =0.8, pad=0.05)#, format="%.1f")
        cbar.set_label(label,size = 12)
        cbar.set_ticks(ticks)
        cbar.ax.tick_params(labelsize = 9)

        # adjust location of subplots
        plt.subplots_adjust(left= 0.05, right= 0.95, bottom=0.05, top=0.95, wspace = 0.02, hspace = 0.05)
        
        fig.savefig(os.path.join(output_path,'sic_pred_obs_anime_week_{}.png'.format(i)), dpi=200)
        iplt.show()
        plt.close(fig)
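
The year, month and week in the figure title follow from the weekly index under the notebook's calendar of 4 weeks per month and 48 weeks per year. A small helper based on `divmod` gives the same numbers as the three expressions at the top of the loop; `week_index_to_date` is a hypothetical name introduced here for illustration.

```python
def week_index_to_date(i, first_year=2013, weeks_per_month=4, months_per_year=12):
    """Map a 0-based weekly index to (year, month, week), all 1-based except year."""
    weeks_per_year = weeks_per_month * months_per_year
    year_offset, rest = divmod(i, weeks_per_year)
    month_offset, week_offset = divmod(rest, weeks_per_month)
    return first_year + year_offset, month_offset + 1, week_offset + 1

first = week_index_to_date(0)
last_of_2013 = week_index_to_date(47)
third_week_2014 = week_index_to_date(50)
```

Here index 0 maps to week 1 of January 2013 and index 47 to week 4 of December 2013.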

In [3]:
    #################################################################################
    ########                       Animation Generator                       ########
    #################################################################################
    # for a long movie: append frames one at a time to save memory
    test_year = 4
    index_plot = np.arange(test_year*12*4)
    # output path
    #output_path = 'C:\\Users\\nosta\\ML4Climate\\PredictArctic\\Maps\\Barents\\Anime'
    # quantizer determines the quality of gif, there are options like 'nq' 'wu'
    with imageio.get_writer(os.path.join(output_path, 'ConvLSTM_sic.gif'), mode='I',duration=0.3,quantizer='wu') as writer:
    #with imageio.get_writer(os.path.join(output_path, 'ConvLSTM_sic.gif'), mode='I',duration=0.5) as writer:
        for i in index_plot:
            print("step {}".format(i))
            images = imageio.imread(os.path.join(output_path,'sic_pred_obs_anime_week_{}.png'.format(i)))
            writer.append_data(images[:])
    # no explicit writer.close() needed: the with block closes the writer on exit
    # for short gif
#     images = []
#     for i in index_plot:
#         images.append(imageio.imread(os.path.join(output_path,'sic_pred_obs_anime_week_{}.png'.format(i))))
#     imageio.mimwrite(os.path.join(output_path, 'ConvLSTM_sic.gif'), images[:], duration=0.3,quantizer='nq')    

step 0
step 1
step 2
step 3
step 4
step 5
step 6
step 7
step 8
step 9
step 10
step 11
step 12
step 13
step 14
step 15
step 16
step 17
step 18
step 19
step 20
step 21
step 22
step 23
step 24
step 25
step 26
step 27
step 28
step 29
step 30
step 31
step 32
step 33
step 34
step 35
step 36
step 37
step 38
step 39
step 40
step 41
step 42
step 43
step 44
step 45
step 46
step 47
step 48
step 49
step 50
step 51
step 52
step 53
step 54
step 55
step 56
step 57
step 58
step 59
step 60
step 61
step 62
step 63
step 64
step 65
step 66
step 67
step 68
step 69
step 70
step 71
step 72
step 73
step 74
step 75
step 76
step 77
step 78
step 79
step 80
step 81
step 82
step 83
step 84
step 85
step 86
step 87
step 88
step 89
step 90
step 91
step 92
step 93
step 94
step 95
step 96
step 97
step 98
step 99
step 100
step 101
step 102
step 103
step 104
step 105
step 106
step 107
step 108
step 109
step 110
step 111
step 112
step 113
step 114
step 115
step 116
step 117
step 118
step 119
step 120
step 121
step 122
ste