In [None]:
import os
import subprocess
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_harmonics as th
import torch_harmonics.distributed as dist

module_path = os.path.abspath(os.path.join('..', 'MultiThread_Model'))
if module_path not in sys.path:
    sys.path.append(module_path)
from subs1_utils import *

# Model Variables
In the following cell you can set the variables that configure the model; each variable is described in the README. In most cases only the standard variables need to be set. Advanced variables left as `None` fall back to defaults — for `mw`, `jmax`, `imax`, and `steps_per_day`, these defaults are derived from `zw`.

In [None]:
# Set model parameters.

# Standard Variables
zw = 63              # selects the default mw/jmax/imax/steps_per_day (next cell)
kmax = 26            # number of levels; 11 or 26 unless custom_kmax is set
expname = 'Test'
toffset = 0          # 0 = cold start; > 0 = restart from datapath_init
datapath_init = None # restart input directory; None = use the output datapath

# Advanced Variables
# Leave these as None to use the defaults derived from zw in the next cell.
# (mw used to also appear under the standard variables, but it was silently
# reset to None here, so it is configured only in this section.)
mw = None
jmax = None
imax = None
steps_per_day = None
custom_path = None
custom_kmax = None

# The calendar month matters in this version with an annual cycle.
times = pd.date_range(start = '1950-01-01', end='2100-01-01', freq='D')

In [None]:
# Initialize the model.

# Resolve kmax: an explicit custom_kmax always wins; otherwise only the two
# vertical resolutions that subs1_utils.py's bscst supports are accepted.
if custom_kmax is not None:
    kmax = custom_kmax
    print("Using custom value for kmax:", kmax)
elif kmax not in (11, 26):
    raise Exception(
        "Unexpected value for kmax."
        " Use custom_kmax and note that other values are implementable,"
        " but the user must modify subs1_utils.py routine bscst."
        " If unclear email bkirtman@miami.edu for clarification.")

# Resolve the horizontal/time resolution. For a recognised zw, any of
# mw, jmax, imax and steps_per_day still left as None is filled in from the
# table below; values set in the advanced variables take precedence.
# For any other zw, all four must have been set explicitly.
zw_defaults = {
    #     mw,  jmax, imax, steps_per_day
    42:  (42,  64,   128,  216),
    63:  (63,  96,   192,  324),
    124: (124, 188,  376,  648),
}
if zw in zw_defaults:
    mw, jmax, imax, steps_per_day = (
        default if given is None else given
        for given, default in zip((mw, jmax, imax, steps_per_day),
                                  zw_defaults[zw]))
elif any(value is None for value in (mw, jmax, imax, steps_per_day)):
    raise Exception(
        "Unexpected value for zw. Other values are implementable,"
        " but the user must specify values for mw, jmax, imax,"
        " and steps_per_day in the advanced variables section.")
print("zw =", zw,
      "\nmw =", mw,
      "\nkmax =", kmax,
      "\njmax =", jmax,
      "\nimax =", imax,
      "\nsteps_per_day =", steps_per_day)

In [None]:
# Get preprocess path.
preprocess_path = get_preprocess_path(zw, kmax)

# Set output datapath.
datapath = set_model_data_path(custom_path, expname, toffset)
# Restart files are read from datapath_init. Fall back to the output
# datapath only when the user did not supply one (previously both branches
# returned datapath, silently discarding a user-supplied datapath_init).
datapath_init = datapath if (datapath_init is None) else datapath_init
print("datapath_init =", datapath_init)

cost_lg, wlg, lats, lons, vsht, dsht, disht, dvsht, divsht = \
    set_spectral_transforms(jmax, imax, mw, zw)

In [None]:
# Initialize spectral fields (at rest or to be read in).
def initialize(temp_newton,lnpsclim,kmax,mw,zw):
    """Build cold-start spectral temperature and lnps fields.

    Every temperature time level starts from rest (zeros) plus the background
    temperature ``temp_newton``. Every lnps time level starts from
    ``lnpsclim``, and a small perturbation is added to ``qmn2`` only, to get
    the model started.

    Parameters
    ----------
    temp_newton : indexable by level ``k`` for ``k < kmax``; each slice must
        broadcast to ``(mw, zw)`` (presumably the spectral radiative-equilibrium
        temperature loaded in the next cell — confirm).
    lnpsclim : array-like of shape ``(mw, zw)``, at least ``(7, 5)`` so the
        perturbation indices exist. It is copied, never modified.
    kmax, mw, zw : dimensions of the fields to build.

    Returns
    -------
    tmn1, tmn2, tmn3, qmn1, qmn2, qmn3
        Six independent tensors (no two share storage). ``tmn*`` are
        complex128 of shape ``(kmax, mw, zw)``.
    """
    # Build from fresh zeros rather than mutating the global tmn1..tmn3.
    tmn1 = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
    for k in range(kmax):
        tmn1[k] = tmn1[k] + temp_newton[k]
    tmn2 = tmn1.clone()
    tmn3 = tmn1.clone()

    # Copy lnpsclim. Plain assignment would alias all three time levels and
    # the caller's lnpsclim (later passed to damp_heldsuarez) to one tensor,
    # so the "qmn2-only" perturbation below would leak into all of them.
    qmn1 = torch.as_tensor(lnpsclim).clone()
    qmn2 = qmn1.clone()
    qmn3 = qmn1.clone()

    # Small perturbation to get things started.
    qmn2[6,3]=qmn1[6,3]+0.001
    qmn2[5,4]=qmn1[5,4]-0.001
    return tmn1,tmn2,tmn3,qmn1,qmn2,qmn3

In [None]:
# Initialization:
# Could read spectral restarts, or could start at rest.
# If wanting to use grid point see jupyter notebook preprocess.

# Implement at rest initial condition,
# but need coriolis since model predicts total vorticity.
# coriolis[j, i] = (4*pi/86400) * sin(lats[j] in degrees), repeated along i.
two_omega = 4.0*np.pi/86400
coriolis = np.broadcast_to(
    [two_omega*np.sin(lats[row,np.newaxis]*np.pi/180.0) for row in range(jmax)],
    (jmax, imax))

# Background data from preprocess (see preprocess for how to change the
# source data or formulation).
# Spectral temperature: only used when initializing from a cold start.
temp_newton = torch.load(preprocess_path+'temp.spectral_RadiativeEquilibrium.pt')
# Grid-point temperature and Newtonian cooling coefficients, used by the
# Newtonian relaxation inside the time loop.
temp_newton_ggrid = torch.load(preprocess_path+'temp.ggrid_RadiativeEquilibrium.pt')
temporary_data_ggrid = torch.load(preprocess_path+'cooling.ggrid_RadiativeEquilibrium.pt')
# The cooling file holds a numpy array, hence the from_numpy wrap.
newton_ggrid = torch.from_numpy(temporary_data_ggrid)

# lnps for initializing a cold start.
lnpsclim = torch.load(preprocess_path+'lnps.spectral_RadiativeEquilibrium.pt')


def _two_time_levels(name1, name2, shape):
    """Return time levels 1 and 2 of one spectral field.

    On a restart (toffset > 0) both are read from datapath_init + name;
    otherwise both start as complex128 zeros of the given shape.
    """
    if ( toffset > 0 ): # Restart
        return torch.load(datapath_init+name1), torch.load(datapath_init+name2)
    return (torch.zeros(shape,dtype=torch.complex128),
            torch.zeros(shape,dtype=torch.complex128))


level_shape = (kmax,mw,zw)
surface_shape = (mw,zw)

zmn1, zmn2 = _two_time_levels('zmn1.spectral.pt', 'zmn2.spectral.pt', level_shape)
zmn3 = torch.zeros(level_shape,dtype=torch.complex128)

dmn1, dmn2 = _two_time_levels('dmn1.spectral.pt', 'dmn2.spectral.pt', level_shape)
dmn3 = torch.zeros(level_shape,dtype=torch.complex128)

tmn1, tmn2 = _two_time_levels('tmn1.spectral.pt', 'tmn2.spectral.pt', level_shape)
tmn3 = torch.zeros(level_shape,dtype=torch.complex128)

# wmn is never read from a restart; it always starts at zero.
wmn1 = torch.zeros(level_shape,dtype=torch.complex128)
wmn2 = torch.zeros(level_shape,dtype=torch.complex128)
wmn3 = torch.zeros(level_shape,dtype=torch.complex128)

qmn1, qmn2 = _two_time_levels('qmn1.spectral.pt', 'qmn2.spectral.pt', surface_shape)
qmn3 = torch.zeros(surface_shape,dtype=torch.complex128)

if ( toffset == 0 ): # Only do this if cold start
    tmn1,tmn2,tmn3,qmn1,qmn2,qmn3 = \
        initialize(temp_newton,lnpsclim,kmax,mw,zw)

# Topography: spectral data, or zero. For grid-point topography see
# preprocess for the conversion to spectral. It is set to zero here.
phismn = torch.zeros((mw,zw),dtype=torch.complex128)

heat = torch.zeros((kmax,jmax,imax),dtype=torch.float64)

In [None]:
# Constants, parameters, vertical differencing parameters and the matrices
# for geopotential height and the implicit scheme. All three helpers below
# are defined in subs1_utils.py.

# bscst: vertical structure. Changing the vertical resolution requires
# editing bscst in subs1_utils.py (e.g. by specifying delsig there).
(delsig, si, sl, sikap, slkap,
 cth1, cth2, r1b, r2b) = bscst(kmax)

# mcoeff: matrices for geopotential height and the implicit scheme.
# Changes here are unlikely to be needed.
amtrx, cmtrx, dmtrx = mcoeff(kmax, si, sl, slkap, r1b, r2b, delsig)

# inv_em: matrix used by the implicit time scheme, computed once here so
# the time loop does not have to recompute it.
emtrx = inv_em(dmtrx, steps_per_day, kmax, mw, zw)

In [None]:
# Preprocessing is complete - now time to run the model.

# The model runs in chunks of tl days (nominally 30); ichunk sets how many
# chunks to run. Daily means of the spectral fields are accumulated in the
# *mnt buffers and handed to postprocessing at the end of every chunk.

# tl is the chunk size.
# This is typically 30 days, but for testing 3 is reasonable.
tl = 30

ae = 6.371E+06 # Earth radius
# Per-chunk buffers of daily-mean spectral fields, indexed by day in chunk.
tmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
zmnt = tmnt.detach().clone()
dmnt = tmnt.detach().clone()
qmnt = torch.zeros((tl,mw,zw),dtype=torch.complex128)
wmnt = tmnt.detach().clone()

# Placeholders: all of these are reassigned inside every time step below.
ddtdiv = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
ddtvort = ddtdiv.detach().clone()
ttend = ddtdiv.detach().clone()

vort = torch.zeros((kmax,jmax,imax),dtype=torch.float64)
div = ddtdiv.detach().clone()
temp = ddtdiv.detach().clone()
qdot = ddtdiv.detach().clone()
# Handed to postprocessing; not updated in this loop.
heatt = torch.zeros((tl,kmax,jmax,imax),dtype=torch.float64)

# Suggested ichunk for time dependent models: 120
ichunk = 12

idays = tl * ichunk

# The time step is constant, so compute it once instead of every step.
dt = 86400.0/steps_per_day

# Begin Time Loop.
ii = 0
savedat = 0     # steps accumulated into the current day's mean
daycount = 0    # day index within the current chunk
total_days = 0  # days completed in this run
nstep = idays*steps_per_day
while ii < nstep:
    ii = ii + 1
    savedat = savedat + 1
    # Accumulate daily means of the current time level.
    # NOTE(review): wmn1 is never reassigned in this loop (explicit() does
    # not return it), so unless a helper updates it in place, wmnt only
    # accumulates zeros — confirm this is intended.
    zmnt[daycount] = zmnt[daycount] + zmn1/steps_per_day
    dmnt[daycount] = dmnt[daycount] + dmn1/steps_per_day
    tmnt[daycount] = tmnt[daycount] + tmn1/steps_per_day
    qmnt[daycount] = qmnt[daycount] + qmn1/steps_per_day
    wmnt[daycount] = wmnt[daycount] + wmn1/steps_per_day
    if (savedat == steps_per_day): # end of a model day: post processing

        # Call Postprocessing Routine as needed.
        # Quick sanity print of two divergence coefficients. The hard-coded
        # level indices 10 and 9 only exist when kmax > 10; without this
        # guard a custom_kmax below 11 raised IndexError on day 1.
        if kmax > 10:
            print((dmn1[10,2,2],dmn1[9,2,1]))
        daycount = daycount + 1
        total_days = total_days + 1
        print(['Day = ',total_days])
        # NOTE(review): this prints the day after the one just completed,
        # while postprocessing labels the first day of a chunk
        # times[toffset]. It assumes toffset counts days — confirm which
        # label is intended.
        print(['Date = ',times[total_days+toffset]])
        if (daycount == tl):
            times_30day = times[total_days-tl+toffset:total_days+toffset]
            postprocessing(disht,divsht,zmnt,dmnt,tmnt,qmnt,wmnt,heatt,\
                           phismn,amtrx,times_30day,mw,zw,\
                           kmax,imax,jmax,sl,lats,lons,tl,datapath)
            # Fresh buffers for the next chunk.
            tmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
            zmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
            dmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
            qmnt = torch.zeros((tl,mw,zw),dtype=torch.complex128)
            wmnt = torch.zeros((tl,kmax,mw,zw),dtype=torch.complex128)
            daycount = 0
        savedat = 0

    # Run model for one time step.

    # Spectral to grid transformation of needed fields:
    # Vorticity, divergence, temperature, U, V,
    # grad(ln(Ps)), Q prescribed heating.
    vort = disht(zmn2) # This is the relative vorticity
    div = disht(dmn2)
    temp = disht(tmn2)
    qdot = disht(wmn2)
    u,v = uv(divsht,zmn2,dmn2,mw,zw,kmax,imax,jmax)
    dxq,dyq = gradq(divsht,qmn2,mw,zw,imax,jmax)

    # Non-Linear products
    a,b,e,ut,vt,ri,wj,cbar,dbar = nlprod(u,v,vort,div,temp,dxq,dyq,heat,coriolis,delsig,si,sikap,slkap,\
                                         r1b,r2b,cth1,cth2,cost_lg,kmax,imax,jmax)
    #
    # Grid point Newtonian Damping Following Held-Suarez
    #
    temp_old = disht(tmn1)
    ri = ri - newton_ggrid*(temp_old - temp_newton_ggrid)

    # Grid to spectral transformation of nlprod results
    #
    ddtdiv,ddtvort = vortdivspec(vsht,a,b,kmax,mw,zw)
    zmn3 = - ddtvort
    dmn3 = ddtdiv - lap_sht(dsht,e,mw,zw)
    _,ttend = vortdivspec(vsht,ut,vt,kmax,mw,zw)
    #
    tmn3 = -ttend + dsht(ri)
    wmn3 = dsht(wj) ### Prescribed heating converted to spectral
    qmn3 = -dsht(cbar) ### Only cbar here since dbar is included in implicit or explicit
    #
    # Diffusion, Damping, Implicit or Explicit time differencing, Time filter
    #
    zmn3,dmn3,tmn3 = diffsn(zmn1,zmn3,dmn1,dmn3,tmn1,tmn3,mw,zw)
    #
    zmn3,dmn3,tmn3,qmn3 = damp_heldsuarez(zmn1,zmn3,dmn1,dmn3,tmn1,tmn3,qmn1,qmn3,\
                          lnpsclim,kmax,mw,zw,sl)
    #
    zmn1,zmn2,zmn3,dmn1,dmn2,dmn3,tmn1,tmn2,tmn3,qmn1,qmn2,qmn3 = \
                        explicit(dt,amtrx,cmtrx,dmtrx,emtrx,\
                        zmn1,zmn2,zmn3,dmn1,dmn2,dmn3,tmn1,tmn2,\
                        tmn3,wmn1,wmn2,wmn3,qmn1,qmn2,qmn3,phismn,\
                        delsig,kmax,mw,zw)
    ####
    ## Reset zmn3, dmn3, tmn3, wmn3 & qmn3 (defensive: they are reassigned
    ## at the start of the next step before being read).
    ###
    zmn3 = torch.zeros((kmax,mw,zw),dtype=torch.complex128)
    dmn3 = zmn3.detach().clone()
    tmn3 = zmn3.detach().clone()
    wmn3 = zmn3.detach().clone()
    qmn3 = torch.zeros((mw,zw),dtype=torch.complex128)
#
# Done

In [None]:
## Write spectral data for possible restart.
## These are the files the initialization cell reads back (from
## datapath_init) when toffset > 0.
restart_fields = {
    'zmn1.spectral.pt': zmn1,
    'zmn2.spectral.pt': zmn2,
    'dmn1.spectral.pt': dmn1,
    'dmn2.spectral.pt': dmn2,
    'tmn1.spectral.pt': tmn1,
    'tmn2.spectral.pt': tmn2,
    'qmn1.spectral.pt': qmn1,
    'qmn2.spectral.pt': qmn2,
}
for restart_name, restart_field in restart_fields.items():
    torch.save(restart_field, datapath+restart_name)