In [None]:
from os import path
import pathlib
import platform
import shutil
import subprocess
import sys

import cartopy.crs as ccrs
import dask
import matplotlib.pyplot as plt
from netCDF4 import Dataset
import numpy as np
import pandas as pd
import scipy as sp
import torch
import torch.nn as nn
import torch.fft as fft
import xarray

from sht_utils import *
from subs1_utils import *

# Variable Glossary
In the following cell you can set the values of the variables relevant to the model. The details of each variable are included below.


## Standard Variables
In most cases it is only necessary to set values for the standard variables.

**zw** is the zonal wave number. For standard use, zw should be set to the value of either 42, 63, or 124. Setting the value for zw also sets default values for the following variables: mw, jmax, imax, and steps_per_day.
<br>Set zw = 42, to set mw = 42, jmax = 64, imax = 128, and steps_per_day = 216.
<br>Set zw = 63, to set mw = 63, jmax = 96, imax = 192, and steps_per_day = 324.
<br>Set zw = 124, to set mw = 124, jmax = 188, imax = 376, and steps_per_day = 648.
<br>If any of these variables (mw, jmax, imax, or steps_per_day) is given a value in the advanced variables section, that value is used instead of the default listed above.

**kmax** is the number of vertical levels. The value of kmax should be 11 or 26 for standard use.

**expname** is the name you want to be given to your experiment. When the data is saved to your computer, it will be saved in a folder with this name. Note that if you run this program twice with the same expname, your first experiment's data will be overwritten.


## Advanced Variables
While most cases only require setting the standard variables, some cases might require setting some or all advanced variables as well. The following variables should only be changed from their default value if a specific behavior is desired. An advanced variable set to the value of None will use the default case.

**mw** is the meridional wave number. In the standard case this value is set equal to zw.

**jmax** is the number of Gaussian latitudes. jmax = imax/2

**imax** is the number of longitude grid points. imax >= 3 * zw + 1. imax must be an even number.

**steps_per_day** is the number of time steps per day. It determines the delta t used in the time differencing scheme: the length of a day is 86400 seconds, so delta t = 86400/steps_per_day. Changing this number changes the time step and should be done carefully. The values used in the standard case were determined experimentally.

**custom_path** is the full path of the folder in which you wish to save your data. If custom_path is set, expname is ignored. Note that this must be an existing folder.

**custom_kmax** is used to safeguard against using unexpected values for kmax. If custom_kmax is set, it will be used instead of kmax. By default the program only supports kmax with a value of either 11 or 26. Other values are implementable, but the user must modify the routine bscst in subs1_utils.py. If anything is unclear, email bkirtman@miami.edu for clarification.

In [None]:
# Set model parameters.
# Edit the values in this cell, then run the cells below from top to bottom.

# Standard Variables -- normally the only ones that need to be changed.
zw = 42                              # zonal wave number; 42, 63 and 124 have built-in defaults
kmax = 11                            # number of vertical levels; 11 or 26 unless custom_kmax is set
expname = "PrescribedMeanT42L11F"    # experiment name, used to name the output folder

# Advanced Variables -- None means "use the default chosen from zw / kmax".
mw = None               # meridional wave number
jmax = None             # number of Gaussian latitudes
imax = None             # number of longitude grid points
steps_per_day = None    # time steps per day (delta t = 86400/steps_per_day)
custom_path = None      # output folder to use instead of the expname-based one
custom_kmax = None      # overrides kmax and skips the 11/26 check

In [None]:
# Initialize the model.

# Vertical levels: a custom_kmax, when given, replaces kmax without any check;
# otherwise kmax has to be one of the two supported values.
if custom_kmax is not None:
    kmax = custom_kmax
    print("Using custom value for kmax:", kmax)
elif not (kmax == 11 or kmax == 26):
    raise Exception("Unexpected value for kmax. Use custom_kmax and note that other values are implementable, but the user must modify subs1_utils.py routine bscst. If unclear email bkirtman@miami.edu for clarification.")

# Horizontal resolution: look up the defaults (mw, jmax, imax, steps_per_day)
# that belong to the chosen zw. An unknown zw has no defaults.
if zw == 42:
    zw_defaults = (42, 64, 128, 216)
elif zw == 63:
    zw_defaults = (63, 96, 192, 324)
elif zw == 124:
    zw_defaults = (124, 188, 376, 648)
else:
    zw_defaults = None

if zw_defaults is not None:
    # Any value the user already supplied in the advanced section wins over
    # the default for that variable.
    default_mw, default_jmax, default_imax, default_steps_per_day = zw_defaults
    if mw is None:
        mw = default_mw
    if jmax is None:
        jmax = default_jmax
    if imax is None:
        imax = default_imax
    if steps_per_day is None:
        steps_per_day = default_steps_per_day
elif any(value is None for value in (mw, jmax, imax, steps_per_day)):
    # Without defaults, every one of the four variables must be user-supplied.
    raise Exception("Unexpected value for zw. Other values are implementable, but the user must specify values for mw, jmax, imax, and steps_per_day in the advanced variables section.")

# Report the resolved configuration.
print(
    "zw =", zw,
    "\nmw =", mw,
    "\nkmax =", kmax,
    "\njmax =", jmax,
    "\nimax =", imax,
    "\nsteps_per_day =", steps_per_day,
)

# Dimension tuples for the spectral and the grid-point representation.
spec = (mw, zw, kmax)
grid = (imax, jmax, kmax)

In [None]:
# Set preprocess path.

# preprocess_path is a relative folder name that ends in a path separator, so
# later cells can build file names as preprocess_path + '<file name>'.
# The preprocess file shares a directory with the model and saves its data to
# a folder under that directory. This folder is named after the variable values in the data.
preprocess_dirname = 'preprocess' + '__zw_' + str(zw) + '__kmax_' + str(kmax)

# Build the trailing separator with os.path so the path is valid on every
# operating system (a hard-coded backslash is only a separator on Windows).
preprocess_path = path.join(preprocess_dirname, '')

base_dir = pathlib.Path().resolve()

# Backward compatibility: earlier versions of this cell always appended a
# literal backslash. On Windows that is identical to the path built above.
# On other systems it names a folder whose name ends in a backslash; if such
# a folder exists and the portable one does not, keep using the old name.
legacy_preprocess_path = preprocess_dirname + '\\'
if (not path.isdir(path.join(base_dir, preprocess_path))
        and path.isdir(path.join(base_dir, legacy_preprocess_path))):
    preprocess_path = legacy_preprocess_path

# Check that the path exists, throwing an exception if it doesn't.
folder = path.join(base_dir, preprocess_path)
if path.isdir(folder):
    print("Directory containing preprocess data was found.",
          "\npreprocess_path =", preprocess_path)
else:
    raise Exception("Directory containing preprocess data was not found. "
                    + "\npreprocess_path = " + str(preprocess_path)
                    + "\nfull path = " + str(folder)
                    + "\nRun preprocess.ipynb prior to running the model."
                    + "\nPreprocess must use the same variable values as the model."
    )

In [None]:
# Set output datapath.

# If custom_path was set, use that as the output datapath.
# Otherwise create an appropriate datapath for the user's operating system.
user_platform = platform.system() if (custom_path is None) else "Custom Path"
print("Setting output datapath for", user_platform)
datapath = ''
match user_platform:
    case 'Custom Path':
        datapath = custom_path
    case 'Windows':
        foo = str(subprocess.check_output(['whoami']))
        end = len(foo) - 5
        uname = foo[2:end].split("\\\\")[1]
        datapath = "C:\\Users\\" + uname + "\\Documents\\AGCM_Experiments\\" + expname + "\\"
        subprocess.run(['rmdir', '/s', '/q', datapath], shell=True)
        subprocess.run(['mkdir', datapath], shell=True)
    case 'Darwin':
        foo = str(subprocess.check_output(['whoami']))
        end = len(foo) - 3
        uname = foo[2:end]
        datapath = '/Users/'+uname+'/Documents/AGCM_Experiments/'+expname+'/'
        subprocess.call(['rm','-r', datapath])
        subprocess.check_output(['mkdir', datapath])
    case _:
        raise Exception("Use case for this system/OS is not implemented. Consider using custom_path in the advanced variables.")
print("datapath =", datapath)

In [None]:
# Gaussian latitudes and equally spaced longitudes.
#
# legendre_gauss_weights returns the Legendre-Gauss nodes (cost_lg) and
# weights (wlg) on [-1, 1]. The nodes are turned into angles with arccos,
# flipped, converted from radians to degrees and shifted by -90.
cost_lg, wlg = legendre_gauss_weights(jmax, -1, 1)
lats = -90+180*np.flip(np.arccos(cost_lg))/(np.pi)

# imax longitudes starting at 0 and ending one grid spacing short of 360.
lons = np.linspace(start=0.0, stop=360.0-360.0/imax, num=imax)

# Grid-to-spectral (sht, vsht) and spectral-to-grid (isht, ivsht) transforms
# for scalar and vector fields. All four share the same truncation and grid
# options.
transform_options = dict(lmax=mw, mmax=zw, grid="legendre-gauss", csphase=False)
sht = RealSHT(jmax, imax, **transform_options)
isht = InverseRealSHT(jmax, imax, **transform_options)
vsht = RealVectorSHT(jmax, imax, **transform_options)
ivsht = InverseRealVectorSHT(jmax, imax, **transform_options)

In [None]:
#
# Initialization: spectral fields are first created as zeros and some of them
#              are then replaced by spectral data read from disk. To start
#              from grid point data instead, see the jupyter notebook
#              preprocess.
#
# Coriolis parameter on the (jmax, imax) grid: constant along each latitude
# row.
#
coriolis_factor = -(4.0*np.pi/86400)  # minus sign because grid runs South-To-North
coriolis = np.zeros((jmax,imax))
for jj in range(jmax):
    coriolis[jj,:] = coriolis_factor*cost_lg[jj]
#
# Spectral fields, all complex128 zeros to begin with:
#   twelve (kmax, mw, zw) fields and three (mw, zw) fields.
#
(zmn1, zmn2, zmn3,
 dmn1, dmn2, dmn3,
 tmn1, tmn2, tmn3,
 wmn1, wmn2, wmn3) = (
    torch.zeros((kmax,mw,zw),dtype=torch.complex128) for _ in range(12)
)
qmn1, qmn2, qmn3 = (
    torch.zeros((mw,zw),dtype=torch.complex128) for _ in range(3)
)
#
# Introduce an initial perturbation to get things going: the *1 and *2 fields
# of z, d, t and q are replaced by the files <name>.spectral.pt found in
# datapath_init.
#
datapath_init = './'
zmn1, zmn2, dmn1, dmn2, tmn1, tmn2, qmn1, qmn2 = (
    torch.load(datapath_init + field_name + '.spectral.pt')
    for field_name in ('zmn1', 'zmn2', 'dmn1', 'dmn2',
                       'tmn1', 'tmn2', 'qmn1', 'qmn2')
)
#
# Topography data - this should be spectral data or can be initialized to
#                   zero. If grid point data is desired see gptosp.agcm.ipynb
#                   for how to convert to spectral.
#
# Topography is set to zero here. For non-zero topography read it instead:
####phismn = torch.load('topog.spectral.pt') # only read topography if background state is zonally symmetric
#
phismn = torch.zeros((mw,zw),dtype=torch.complex128)
#
# Heating, read from the preprocess output (see preprocess.ipynb). To run
# without heating use zeros instead:
###heat = torch.zeros((kmax,jmax,imax),dtype=torch.float64)
#
heat = torch.load(preprocess_path+'heat.ggrid.pt')
#
# Climatology on the Gaussian grid, read from the preprocess output.
#
uclim, vclim, tclim, vortclim, divclim, dxqclim, dyqclim = (
    torch.load(preprocess_path + file_stem + '.ggrid.pt')
    for file_stem in ('usig', 'vsig', 'tsig', 'vortsig',
                      'divsig', 'dxq_gg', 'dyq_gg')
)

In [None]:
###
### Constants, parameters, vertical differencing parameters,
### matrices for geopotential height, etc ...
###
### bscst (subs1_utils.py) provides the vertical-structure quantities. It is
### the routine to change if the vertical resolution is changed - this could
### be done by simply specifying delsig in bscst.
(
    delsig, si, sl,
    sikap, slkap,
    cth1, cth2,
    r1b, r2b,
) = bscst(kmax)
### mcoeff (subs1_utils.py) builds the matrices for geopotential height and
### the implicit scheme; changes here are unlikely to be needed.
amtrx, cmtrx, dmtrx = mcoeff(kmax, si, sl, slkap, r1b, r2b, delsig)
### emtrx is used in the implicit time scheme and is computed once here to
### save cpu time; changes are unlikely to be needed.
emtrx = inv_em(dmtrx, steps_per_day, kmax, mw, zw)

In [None]:
#### Preprocessing is complete - now time to run the model
#
# The model runs in chunks of tl days. During a chunk the daily means of the
# spectral fields are accumulated; after every tl simulated days they are
# handed to postprocessing() and the accumulators are cleared. ichunk (set
# further down) is the number of chunks to run.
tl = 30 ##### tl is the chunk size - typically 30 days, but for testing 3 is reasonable
#
# Suggested ichunk for time dependent models: 120
#
ae = 6.371E+06 # Earth radius (not referenced again in this cell)
# Daily-mean accumulators: one slot per day of the current chunk.
tmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
zmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
dmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
qmnt = torch.zeros((tl,mw,zw),dtype=torch.complex128)
wmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
#
# Spectral work arrays; each is reassigned from a vortdivspec() result inside
# the loop.
ddtdiv = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
ddtvort = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
junk = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
ttend = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
#
# Grid-point work arrays, refilled level by level on every time step.
vort = torch.zeros((kmax,jmax,imax),dtype=torch.float64)
div = torch.zeros((kmax,jmax,imax),dtype=torch.float64)
temp = torch.zeros((kmax,jmax,imax),dtype=torch.float64)
qdot = torch.zeros((kmax,jmax,imax),dtype=torch.float64)
# Daily time stamps; a slice of tl of them labels each chunk that is written.
# NOTE(review): idays must not exceed len(times), otherwise the slices taken
# below come up short.
times = pd.date_range(start = '1950-01-01', end='2100-01-01', freq='D')
#
# Number of tl-day chunks to run.
ichunk = 30
#ichunk = 1      # For shortening the runtime while testing.
#
# Total number of simulated days.
idays = tl * ichunk
#
#
#
#
# Begin Time Loop
#
# ii         - time step counter over the whole run
# savedat    - time steps taken since the start of the current day
# daycount   - completed days within the current chunk (index into *mnt)
# total_days - completed days since the start of the run
ii = 0
savedat = 0
daycount = 0
total_days = 0
nstep = idays*steps_per_day
while ii < nstep:
    ii = ii + 1
    savedat = savedat + 1
    # Add this step's share (1/steps_per_day) of the current *1 fields to the
    # running mean of the current day.
    # NOTE(review): wmn1 is not among the values returned by explicit() below;
    # unless explicit() updates it in place, wmnt only ever accumulates its
    # initial value - confirm.
    zmnt[daycount] = zmnt[daycount] + zmn1/steps_per_day
    dmnt[daycount] = dmnt[daycount] + dmn1/steps_per_day
    tmnt[daycount] = tmnt[daycount] + tmn1/steps_per_day
    qmnt[daycount] = qmnt[daycount] + qmn1/steps_per_day
    wmnt[daycount] = wmnt[daycount] + wmn1/steps_per_day
    if (savedat == steps_per_day): # a full day has been accumulated
        #
        # Call Postprocessing Routine as needed
        #
        # Progress/diagnostic output: two sample coefficients of dmn1.
        # NOTE(review): the hard-coded level indices 10 and 9 need dmn1 to
        # have at least 11 levels - confirm for small custom_kmax values.
        print((dmn1[10,2,2],dmn1[9,2,1]))
        daycount = daycount + 1
        total_days = total_days + 1
        print(['Day = ',total_days])
        if (daycount == tl):
            # A chunk is complete: write it out with the matching tl time
            # stamps, then start the next chunk with fresh accumulators.
            times_30day = times[total_days-tl:total_days]
            postprocessing(isht,ivsht,zmnt,dmnt,tmnt,qmnt,wmnt,\
                           phismn,amtrx,times_30day,mw,zw,\
                           kmax,imax,jmax,sl,lats,lons,tl,datapath)
            tmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
            zmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
            dmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
            qmnt = torch.zeros((tl,mw,zw),dtype=torch.complex128)
            wmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
            daycount = 0
        savedat = 0
    #
    # Run model for one time step
    #
    # Spectral to grid transformation of needed fields:
    #   Vorticity, divergence, temperature, U, V,
    #   grad(ln(Ps)), Q prescribed heating
    #
    #
    for k in range(kmax):
        vort[k] = isht(zmn2[k]).cpu() ### This is the relative vorticity
        div[k] = isht(dmn2[k]).cpu()
        temp[k] = isht(tmn2[k]).cpu()
        qdot[k] = isht(wmn2[k]).cpu()
    u,v = uv(ivsht,zmn2,dmn2,mw,zw,kmax,imax,jmax)
    dxq,dyq = gradq(ivsht,qmn2,mw,zw,imax,jmax)
    #
    # Stack grid variables: climatology first [0], perturbation second [1]
    #
    vort2 = torch.stack((vortclim,vort)) # vortclim is the absolute vorticity
    div2 = torch.stack((divclim,div))
    temp2 = torch.stack((tclim,temp))
    u2 = torch.stack((uclim,u))
    v2 = torch.stack((vclim,v))
    dxq2 = torch.stack((dxqclim,dxq))
    dyq2 = torch.stack((dyqclim,dyq))
    #
    # Non-Linear products
    #
    a,b,e,ut,vt,ri,wj,cbar,dbar = nlprod_prescribed_mean(u2,v2,vort2,div2,temp2,dxq2,dyq2,heat,coriolis,delsig,si,sikap,slkap,\
                                         r1b,r2b,cth1,cth2,cost_lg,kmax,imax,jmax)
    #
    #
    # Grid to spectral transformation of nlprod results
    #
    ddtdiv,ddtvort = vortdivspec(vsht,a,b,kmax,mw,zw)
    zmn3 = - ddtvort
    # dmn3 is the same tensor object as ddtdiv, so the per-level assignment
    # in the loop below also changes ddtdiv.
    dmn3 = ddtdiv
    junk,ttend = vortdivspec(vsht,ut,vt,kmax,mw,zw)
    for k in range(kmax):
        dddt = dmn3[k] - lap_sht(sht,e[k],mw,zw) 
        dmn3[k] = dddt
        tmn3[k] = -ttend[k] + sht(ri[k]).cpu()
        wmn3[k] = sht(wj[k]).cpu() ### Prescribed heating converted to spectral
    qmn3 = -sht(cbar[1]).cpu() ### Only cbar here since dbar is included in implicit or explicit
    #
    # Diffusion, Damping, Implicit or Explicit time differencing, Time filter
    #
    zmn3,dmn3,tmn3 = diffsn(zmn1,zmn3,dmn1,dmn3,tmn1,tmn3,\
                            mw,zw)
    #
    zmn3,dmn3,tmn3,qmn3 = damp_prescribed_mean(zmn1,zmn3,dmn1,dmn3,tmn1,tmn3,qmn1,qmn3,\
                          kmax,mw,zw)
    #
    # Time step length in seconds (86400 seconds per day).
    dt = 86400.0/steps_per_day
    #
    # Advance the fields by one step with the explicit scheme.
    zmn1,zmn2,zmn3,dmn1,dmn2,dmn3,tmn1,tmn2,tmn3,qmn1,qmn2,qmn3 = \
                        explicit(dt,amtrx,cmtrx,dmtrx,emtrx,\
                        zmn1,zmn2,zmn3,dmn1,dmn2,dmn3,tmn1,tmn2,\
                        tmn3,wmn1,wmn2,wmn3,qmn1,qmn2,qmn3,phismn,\
                        delsig,kmax,mw,zw)
    ####
    ## Reset zmn3, dmn3, tmn3, wmn3 & qmn3 to fresh zero tensors. tmn3 and
    ## wmn3 are filled level by level on the next step, so they must exist.
    ###
    zmn3 = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
    dmn3 = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
    tmn3 = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
    wmn3 = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
    qmn3 = torch.zeros((mw,zw),dtype=torch.complex128)
#
# Done

In [None]:
## Write spectral data for possible restart
##
## One file <name>.spectral.pt per field is written into datapath. The file
## names are built with os.path.join so they end up inside the folder whether
## or not datapath ends in a path separator.
## Note: the initialization cell reads its restart files from datapath_init
## ('./'), not from datapath, so these files have to be copied there before
## they can be used to restart a run.
restart_fields = {
    'zmn1': zmn1, 'zmn2': zmn2,
    'dmn1': dmn1, 'dmn2': dmn2,
    'tmn1': tmn1, 'tmn2': tmn2,
    'qmn1': qmn1, 'qmn2': qmn2,
}
for field_name, field in restart_fields.items():
    torch.save(field, path.join(datapath, field_name + '.spectral.pt'))