# Backend for Spectral Methods for Heterogeneous-Agent Models

## Preamble

In [1]:
import sys
import os
import pickle

import time
from timeit import default_timer as timer
from datetime import datetime

import math

import numpy as np
import numpy.matlib
np.set_printoptions(linewidth=220)

import pandas as pd
from IPython.display import display, HTML
def pprint(df):
    """Render ``df`` (anything ``pd.DataFrame`` accepts) as an HTML table of at most 100 rows."""
    frame = pd.DataFrame(df)
    first_rows = frame.head(100)
    display(HTML(first_rows.to_html()))

import scipy as sp
import scipy.interpolate
import scipy.sparse as sparse
from scipy.sparse.linalg import spsolve
from scipy import optimize, integrate, stats
from scipy.interpolate import barycentric_interpolate

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.ticker as ticker
from matplotlib import cm
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d import Axes3D
%matplotlib inline
%config InlineBackend.figure_format='retina'

import dmsuite as dm

from tqdm.notebook import tqdm, trange

## Plot helpers

### Saving

In [2]:
def savefig(figorax, filename, directory='../figures/'):
    """Save the current pyplot figure as PDF, show it, and print the file name.

    If ``figorax`` carries a title (a Figure with a suptitle, or an Axes with its ``title``),
    an extra copy with that title hidden is written first as ``<filename>_notitle.pdf``.

    Parameters
    ----------
    figorax : matplotlib Figure or Axes
        Object whose (sup)title is hidden for the ``_notitle`` copy. The files themselves are
        written from pyplot's current figure.
    filename : str
        Output base name, without extension.
    directory : str, optional
        Output directory; defaults to the previously hard-coded ``'../figures/'``.
    """
    if hasattr(figorax, '_suptitle'):
        # Figures always define _suptitle, but it stays None until suptitle() is called;
        # passing None to plt.setp used to raise, so only a real Text is hidden.
        title = figorax._suptitle
    else:
        title = getattr(figorax, 'title', None)

    if title is not None:
        title.set_visible(False)
        try:
            plt.savefig(os.path.join(directory, filename + '_notitle.pdf'), bbox_inches='tight')
        finally:
            # Restore the title even if saving fails, so the figure is left as we found it.
            title.set_visible(True)

    plt.savefig(os.path.join(directory, filename + '.pdf'), bbox_inches='tight')
    plt.show()
    print('Saved to: '+filename+'.pdf')

### Matrix plots

In [None]:
def calcdensity(M):
    """Return the fraction of entries of ``M`` that are nonzero.

    Accepts anything ``np.asarray`` understands as well as scipy sparse matrices. For sparse
    input, the denominator is the full ``rows * cols`` count. (``M.size`` on a scipy sparse
    matrix counts only *stored* entries, which made the old result close to 1.) Explicitly
    stored zeros are not counted as nonzero. Returns ``nan`` for an empty matrix.
    """
    if sparse.issparse(M):
        nonzero, total = M.count_nonzero(), math.prod(M.shape)
    else:
        M = np.asarray(M)
        nonzero, total = np.count_nonzero(M), M.size
    if total == 0:
        return float('nan')
    return nonzero / total

In [3]:
def plotmats(matrices, titles, shape, figsize, suptitle, filename):
    """Show several matrices as colour-coded images on one figure and save it via ``savefig``.

    Each matrix uses a symmetric-log colour scale (``cm.seismic``) spanning ``[-max|m|, max|m|]``.
    Grid cells beyond the number of matrices are hidden.

    Parameters
    ----------
    matrices : iterable of 2-D arrays
    titles : sequence of str
        ``titles[i]`` is used for ``matrices[i]``.
    shape : (int, int)
        Subplot grid as (rows, columns).
    figsize : (float, float)
    suptitle : str
    filename : str
        Base name passed to ``savefig``.
    """
    # squeeze=False keeps a 2-D array of axes even for a 1x1 grid, where plt.subplots
    # would otherwise return a bare Axes without .flatten().
    fig, axes = plt.subplots(shape[0], shape[1], figsize=figsize, squeeze=False)
    axes = axes.flatten()

    n_drawn = 0
    for i, m in enumerate(matrices):
        bound = np.max(np.abs(m))
        norm = colors.SymLogNorm(linthresh=0.03, linscale=0.03, vmin=-bound, vmax=bound)
        axes[i].matshow(m, cmap=cm.seismic, norm=norm)
        axes[i].set_title(titles[i])
        n_drawn = i + 1

    # Hide empty grid cells instead of leaving blank framed axes.
    for ax in axes[n_drawn:]:
        ax.set_axis_off()

    fig.suptitle(suptitle, size=14)
    plt.tight_layout()
    savefig(fig, filename)

### 3D plots

In [4]:
def _surface_grids(xx, yy, zz, clipmask=None):
    """Reshape flattened x/y/z samples to (n_unique_y, n_unique_x) float grids, NaN-ing clipped points.

    Copies are returned, so the caller's arrays are never modified.
    """
    shape = (len(np.unique(yy)), len(np.unique(xx)))
    grids = [np.array(values, dtype=float).reshape(shape) for values in (xx, yy, zz)]
    if clipmask is not None:
        # The mask is reshaped to the grid shape before use (previously the right-hand
        # mask was applied un-reshaped to one of the grids).
        mask = np.asarray(clipmask, dtype=bool).reshape(shape)
        for grid in grids:
            grid[mask] = np.nan
    return grids


def _draw_surface(ax, xx, yy, zz, cmap, zmin, zmax, showedges):
    """Draw one orthographic surface with the view and z-axis formatting shared by both panels."""
    ax.set_proj_type('ortho')
    if showedges:
        edge_style = dict(linewidth=0.5, edgecolors='k')
    else:
        edge_style = dict(linewidth=0)
    ax.plot_surface(xx, yy, zz, cmap=cmap, vmin=zmin, vmax=zmax,
                    rstride=1, cstride=1, antialiased=True, **edge_style)
    ax.view_init(30, 35+180)
    ax.set_zlim3d([zmin, zmax])
    ax.ticklabel_format(axis='z', style='sci', scilimits=(-4,4), useOffset=False, useMathText=True)


def _title_with_offset(title, offset, showmeans):
    """Insert a z-axis offset/scale label (e.g. a power of ten) into a panel title."""
    if offset == '':
        return title
    if showmeans:
        # Keep the statistics on their own lines: put the offset before the blank line.
        return title.replace('\n\n', ' ('+offset+') \n\n')
    return title + ' ('+offset+')'


def plot_3D_surfaces_side_by_side(xx_left, yy_left, zz_left, xx_right, yy_right, zz_right,
                                  title_left, title_right, suptitle, filename=None,
                                  sharez=True, showmeans=False, clipmask_left=None, clipmask_right=None,
                                  xlabel='z', ylabel='a', cmap=cm.magma, showedges=False):
    """Plot two 3-D surfaces next to each other and either save (``filename``) or show the figure.

    Each ``xx``/``yy``/``zz`` triple holds samples on a tensor grid. They are reshaped to
    ``(len(unique(yy)), len(unique(xx)))`` before plotting, and the inputs are not modified.

    Parameters
    ----------
    xx_*, yy_*, zz_* : array-like
        Coordinates and values of the left/right surfaces.
    title_left, title_right, suptitle : str
        Panel titles and the figure title.
    filename : str, optional
        Base name for ``savefig``; if None the figure is just shown.
    sharez : bool
        Use a common z/colour range (from the unclipped data) for both panels. The z-axis
        scale label then goes into the suptitle instead of the panel titles.
    showmeans : bool
        Append the mean and median of ``|zz|`` (NaNs ignored) to each panel title.
    clipmask_left, clipmask_right : array-like of bool, optional
        Points where the mask is True are not drawn (set to NaN).
    xlabel, ylabel : str
        Axis labels.
    cmap : matplotlib colormap
    showedges : bool
        Draw thin black grid edges on both surfaces.
    """
    xx_l, yy_l, zz_l = _surface_grids(xx_left, yy_left, zz_left, clipmask_left)
    xx_r, yy_r, zz_r = _surface_grids(xx_right, yy_right, zz_right, clipmask_right)

    if sharez:
        zmin = np.nanmin([np.nanmin(zz_left), np.nanmin(zz_right)])
        zmax = np.nanmax([np.nanmax(zz_left), np.nanmax(zz_right)])
    else:
        zmin, zmax = None, None

    if showmeans:
        title_left = title_left + '\n\n Abs. mean: {:.2e}'.format(np.nanmean(np.abs(zz_left))) + ' - Abs. median: {:.2e}'.format(np.nanmedian(np.abs(zz_left)))
        title_right = title_right + '\n\n Abs. mean: {:.2e}'.format(np.nanmean(np.abs(zz_right))) + ' - Abs. median: {:.2e}'.format(np.nanmedian(np.abs(zz_right)))

    fig = plt.figure(figsize=(10,5))
    ax0 = fig.add_subplot(121, projection='3d')
    ax1 = fig.add_subplot(122, projection='3d')
    # Both panels use the same styling (the right panel previously used linewidth=5).
    _draw_surface(ax0, xx_l, yy_l, zz_l, cmap, zmin, zmax, showedges)
    _draw_surface(ax1, xx_r, yy_r, zz_r, cmap, zmin, zmax, showedges)

    # The z-axis offset texts are only filled in once the figure has been drawn.
    fig.canvas.draw()
    offset_left = ax0.zaxis.get_offset_text().get_text()
    offset_right = ax1.zaxis.get_offset_text().get_text()

    # Hide the offset next to each z axis; it is moved into a title below.
    if offset_left != '':
        ax0.zaxis.get_offset_text().set_visible(False)
    if offset_right != '':
        ax1.zaxis.get_offset_text().set_visible(False)

    if sharez:
        shared_offset = offset_left or offset_right
        if shared_offset != '':
            suptitle = suptitle + ' ('+shared_offset+')'
    else:
        # Each title gets its own axis' offset (the left title previously used the right one).
        title_left = _title_with_offset(title_left, offset_left, showmeans)
        title_right = _title_with_offset(title_right, offset_right, showmeans)

    for ax, title in ((ax0, title_left), (ax1, title_right)):
        ax.set_title(title)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

    fig.suptitle(suptitle, size=20, y=1)
    plt.tight_layout()

    if filename is not None: savefig(fig, filename)
    else: plt.show()

In [1]:
# Hand-picked subplot grids for common sample counts, returned as-is by sizes_from_samples.
# NOTE(review): these tuples are returned in the same (nhorizontal, nvertical) order as the
# computed fallback, e.g. 10 -> 2 columns x 5 rows. Confirm this orientation is intended.
nice_sizes_dict = {4:(2,2), 9:(3,3), 10:(2, 5), 12:(3, 4), 20:(4, 5)}

def sizes_from_samples(samples, ratio=2/3):
    """Return a ``(nhorizontal, nvertical)`` subplot grid with room for ``samples`` panels.

    Counts listed in ``nice_sizes_dict`` use the hand-picked grid. Otherwise the grid has
    ``ceil(sqrt(samples / ratio))`` columns and just enough rows to hold all panels.

    Raises
    ------
    ValueError
        If ``samples`` is smaller than 1.
    """
    if samples < 1:
        raise ValueError(f'samples must be at least 1, got {samples!r}')
    if samples in nice_sizes_dict:
        # Previously this branch evaluated the lookup without returning it (returned None).
        return nice_sizes_dict[samples]
    nhorizontal = math.ceil((samples/ratio)**0.5)
    nvertical = math.ceil(samples/nhorizontal)
    return nhorizontal, nvertical

In [6]:
def plot_3D_surface_split(z, x, y, sss, zlabel='z', xlabel='x', ylabel='y', samples=6, cmap=cm.magma, sharez=True, suptitle=None, filename=None, figsize=(16,8)):
    """Plot slices of a 3-D field as a grid of surface plots, one panel per sampled ``z`` value.

    ``sss`` is reshaped in C order to ``(len(z), len(x), len(y))``. ``samples`` evenly spaced
    indices into ``z`` are chosen, and the corresponding ``(x, y)`` slices are drawn.

    Parameters
    ----------
    z, x, y : array-like
        Grid coordinates; ``z`` indexes the panels.
    sss : array-like
        ``len(z) * len(x) * len(y)`` values.
    zlabel, xlabel, ylabel : str
        ``zlabel`` is used in the panel titles; ``xlabel``/``ylabel`` label the axes that show
        ``x``/``y`` respectively.
    samples : int
        Number of panels.
    cmap : matplotlib colormap
    sharez : bool
        Give every panel the same z-limits (nan-aware min/max of all of ``sss``).
    suptitle : str, optional
    filename : str, optional
        If given, save via ``savefig``; otherwise just show the figure.
    figsize : tuple, optional
        Figure size; defaults to the previously hard-coded ``(16, 8)``.
    """
    z = np.asarray(z)
    sss_ = np.reshape(sss, (len(z), len(x), len(y)))
    # meshgrid(y, x) yields (len(x), len(y)) grids matching each slice sss_[k]. The plot's
    # horizontal x-axis therefore carries the *y* coordinate and its depth axis carries *x*.
    grid_y, grid_x = np.meshgrid(y, x)

    nhorizontal, nvertical = sizes_from_samples(samples, ratio=2/3)

    # squeeze=False keeps a 2-D array of axes for any grid shape, so .flatten() always works.
    fig, axes = plt.subplots(nvertical, nhorizontal, figsize=figsize, squeeze=False, subplot_kw={'projection':'3d'})
    axes = axes.flatten()

    z_samples = np.linspace(0, len(z)-1, samples).astype('int')
    z_samples_isint = np.allclose(z[z_samples].astype('int'), z[z_samples])
    if sharez:
        zlim = (np.nanmin(sss_), np.nanmax(sss_))

    for i, ax in enumerate(axes):
        if i >= samples:
            ax.set_axis_off()
            continue
        ax.plot_surface(grid_y, grid_x, sss_[z_samples[i],:,:], cmap=cmap, linewidth=0, antialiased=True)
        ax.view_init(30, 35+180)
        if z_samples_isint: ax.set_title(zlabel+f'={z[z_samples[i]]:n}')
        else: ax.set_title(zlabel+f'={z[z_samples[i]]:.2f}')
        # Label each axis with the coordinate it actually shows (these were swapped before).
        ax.set_xlabel(ylabel)
        ax.set_ylabel(xlabel)
        if sharez: ax.set_zlim(*zlim)

    fig.suptitle(suptitle, size=20, y=1)
    plt.tight_layout()

    if filename is not None: savefig(fig, filename)
    else: plt.show()

### Miscellaneous

In [7]:
from matplotlib.ticker import Locator, LogLocator, AutoLocator, LinearLocator

class MinorSymLogLocator(Locator):
    """
    Minor tick positions for a symlog axis. From: https://stackoverflow.com/a/45696768

    Minor ticks are placed between consecutive major ticks. The locator uses ``nints``
    intervals where the midpoint of a major interval lies inside ``|x| < linthresh``, and
    ``nints - 1`` intervals otherwise. One extra major location is extrapolated at each end
    so that the outermost intervals also get minor ticks.
    """
    def __init__(self, linthresh, nints=10):
        # linthresh: linear threshold of the symlog scale this locator is paired with.
        # nints: number of minor intervals per major interval in the linear region.
        self.linthresh = linthresh
        self.nintervals = nints

    def __call__(self):
        """Return the minor tick locations for the current major ticks of ``self.axis``."""
        majorlocs = self.axis.get_majorticklocs()

        # The spacings below need at least two major ticks. The original check only covered
        # exactly one, so an axis with no major ticks raised IndexError.
        if len(majorlocs) < 2:
            return self.raise_if_exceeds(np.array([]))

        dmlower = majorlocs[1] - majorlocs[0]
        dmupper = majorlocs[-1] - majorlocs[-2]

        # Extrapolate one major location below the first tick: a decade step in the log
        # region, a linthresh step otherwise.
        if majorlocs[0] != 0. and ((majorlocs[0] != self.linthresh and dmlower > self.linthresh) or (dmlower == self.linthresh and majorlocs[0] < 0)):
            majorlocs = np.insert(majorlocs, 0, majorlocs[0]*10.)
        else:
            majorlocs = np.insert(majorlocs, 0, majorlocs[0]-self.linthresh)

        # Same for one location above the last tick.
        if majorlocs[-1] != 0. and ((np.abs(majorlocs[-1]) != self.linthresh and dmupper > self.linthresh) or (dmupper == self.linthresh and majorlocs[-1] > 0)):
            majorlocs = np.append(majorlocs, majorlocs[-1]*10.)
        else:
            majorlocs = np.append(majorlocs, majorlocs[-1]+self.linthresh)

        minorlocs = []

        for i in range(1, len(majorlocs)):
            majorstep = majorlocs[i] - majorlocs[i-1]
            if abs(majorlocs[i-1] + majorstep/2) < self.linthresh:
                ndivs = self.nintervals
            else:
                ndivs = self.nintervals - 1.

            minorstep = majorstep / ndivs
            # [1:] drops the major tick itself at the start of the interval.
            locs = np.arange(majorlocs[i-1], majorlocs[i], minorstep)[1:]
            minorlocs.extend(locs)

        return self.raise_if_exceeds(np.array(minorlocs))

    def tick_values(self, vmin, vmax):
        raise NotImplementedError('Cannot get tick locations for a %s type.' % type(self))

## Differentiation matrix helpers

In [8]:
def rescale_nodes_and_derivatives(nodes, Ds, left, right):
    """Affinely map collocation ``nodes`` onto ``[left, right]`` and rescale their differentiation matrices.

    The map ``x -> a*x + b`` sends ``nodes[0]`` to ``left`` and ``nodes[-1]`` to ``right``.
    By the chain rule, a derivative of order ``k`` picks up a factor ``a**-k``. ``Ds[d]`` is
    treated as the matrix of derivative order ``d + 1`` (first derivative at index 0).

    Returns
    -------
    (rescalednodes, rescaledDs)
    """
    nodes = np.asarray(nodes)
    Ds = np.asarray(Ds)
    a = (right-left)/(nodes[-1]-nodes[0])
    # Anchor the map at the first node. The previous intercept -(left+right)/(nodes[-1]-nodes[0])
    # only agreed with this for nodes running from 1 down to -1; ascending or otherwise scaled
    # nodes were sent to the wrong interval.
    b = left - a*nodes[0]
    rescalednodes = a*nodes+b
    rescaledDs = np.zeros_like(Ds)
    for d in range(len(Ds)):
        rescaledDs[d] = pow(a, -(d+1))*Ds[d]
    return rescalednodes, rescaledDs

In [15]:
def kronecker_broadcast_along_index_level(M, multiindex, broadcastlevel):
    """Lift an operator acting on one level of a MultiIndex to the whole flattened index.

    ``M`` acts on the unique values of ``broadcastlevel``. The result is
    ``kron(I_outer, kron(M, I_inner))``, where ``I_outer`` / ``I_inner`` are identities sized
    by the product of the numbers of unique values of the levels before / after
    ``broadcastlevel`` in ``multiindex.names``.

    Returns a sparse matrix for sparse ``M`` (CSR whenever other levels exist) and a dense
    ndarray for dense ``M``.
    """
    assert len(multiindex.unique(broadcastlevel)) == M.shape[0], 'M size and broadcastlevel incompatible!'
    names = list(multiindex.names)
    position = names.index(broadcastlevel)
    # Kronecker products of identities are identities, so levels on each side collapse to one.
    n_outer = math.prod(len(multiindex.unique(level)) for level in names[:position])
    n_inner = math.prod(len(multiindex.unique(level)) for level in names[position+1:])

    MM = M.copy()
    if position > 0:
        MM = sparse.kron(sparse.eye(n_outer), MM, format='csr')
    if position < len(names) - 1:
        MM = sparse.kron(MM, sparse.eye(n_inner), format='csr')
    # Only densify if a kron actually produced a sparse result. With a single-level index,
    # MM is still the dense copy, and calling .toarray() on it used to raise.
    if not sparse.issparse(M) and sparse.issparse(MM):
        MM = MM.toarray()
    return MM

## Masking and indexing helpers

In [10]:
def mask_from_int_idx(fullidx, idx):
    """Return a boolean mask over ``range(len(fullidx))`` that is True at the positions in ``idx``.

    Membership semantics are kept: entries of ``idx`` outside ``[0, len(fullidx))``, including
    negative ones, are ignored rather than wrapped. ``np.isin`` replaces the former per-element
    ``in`` test, which cost O(len(fullidx) * len(idx)).
    """
    # list() first so sets and other iterables become a flat array instead of one object element.
    positions = np.asarray(list(idx))
    return np.isin(np.arange(len(fullidx)), positions)

def list_of_lists_from_int_idx(idx):
    """Wrap each entry of ``idx`` in its own one-element list and return the result as an array.

    For a non-empty ``idx`` of k scalars this gives a (k, 1) column array.
    """
    wrapped = list(map(lambda entry: [entry], idx))
    return np.array(wrapped)

In [11]:
def int_index_from_labels(multiindex, labels):
    """Return the integer positions in ``multiindex`` selected by ``labels`` (anything ``.loc`` accepts).

    Always returns a 1-D array, also when ``labels`` is one complete key. In that case
    ``.loc`` yields a scalar, on which the former ``.values`` call raised.
    """
    positions = pd.Series(np.arange(multiindex.size), index=multiindex)
    selected = positions.loc[labels]
    return np.atleast_1d(np.asarray(selected))

def mask_from_labels(multiindex, labels):
    """Return a boolean mask over ``multiindex`` that is True at the rows selected by ``labels``."""
    mask = np.zeros(len(multiindex), dtype=bool)
    # Direct assignment is O(n + m), replacing the former O(n * m) membership test.
    mask[np.asarray(int_index_from_labels(multiindex, labels), dtype=int)] = True
    return mask

In [12]:
def test_generator(matrix):
    """Print diagnostics for checking whether ``matrix`` looks like an infinitesimal generator.

    Reports min/mean/max of the diagonal and of the matrix with its diagonal zeroed. It also
    reports the largest absolute row sum, which should be ~0 for a generator. Works for
    dense arrays and scipy sparse matrices.
    """
    diagonal = matrix.diagonal()
    print('Diagonal -> Min:', diagonal.min(), ' Mean:', diagonal.mean(), ' Max:', diagonal.max())

    diagonal_part = sparse.diags(diagonal) if sparse.issparse(matrix) else np.diag(diagonal)
    off_diagonal = matrix - diagonal_part
    print('Off-diagonal -> Min:', off_diagonal.min(), ' Mean:', off_diagonal.mean(), ' Max:', off_diagonal.max())

    row_sums = np.sum(matrix, axis=1)
    print('Maximum absolute off-diagonal vs diagonal error:', np.abs(row_sums).max())

def test_transpose_generator(matrix):
    """Run ``test_generator`` on ``matrix.T`` (checks column sums instead of row sums)."""
    transposed = matrix.T
    test_generator(transposed)

## Interpolation helpers

In [13]:
def ndspline(a_rough, z_rough, t_rough, s_rough, a_fine, z_fine, t_fine):
    """Interpolate values from a rough (a, z, t) tensor grid onto a fine one, one axis at a time.

    ``s_rough`` is reshaped in C order to ``(len(a_rough), len(z_rough), len(t_rough))``. The
    data are splined along ``a``, then ``z``, then ``t``, each with
    ``scipy.interpolate.InterpolatedUnivariateSpline`` at its default settings. The values on the
    ``a_fine`` x ``z_fine`` x ``t_fine`` grid are returned flattened in C order.
    """
    spline = sp.interpolate.InterpolatedUnivariateSpline
    n_a, n_z, n_t = len(a_rough), len(z_rough), len(t_rough)
    f_a, f_z, f_t = len(a_fine), len(z_fine), len(t_fine)

    values = s_rough.reshape([n_a, n_z, n_t])

    # Pass 1: refine along a for every rough (z, t) pair.
    pass_a = np.zeros([f_a, n_z, n_t])
    for jz, jt in np.ndindex(n_z, n_t):
        pass_a[:, jz, jt] = spline(a_rough, values[:, jz, jt])(a_fine)

    # Pass 2: refine along z for every fine a and rough t.
    pass_z = np.zeros([f_a, f_z, n_t])
    for ia, jt in np.ndindex(f_a, n_t):
        pass_z[ia, :, jt] = spline(z_rough, pass_a[ia, :, jt])(z_fine)

    # Pass 3: refine along t for every fine (a, z) pair.
    pass_t = np.zeros([f_a, f_z, f_t])
    for ia, iz in np.ndindex(f_a, f_z):
        pass_t[ia, iz, :] = spline(t_rough, pass_z[ia, iz, :])(t_fine)

    return pass_t.flatten()

In [3]:
def timespline(t_rough, a_rough, z_rough, s_rough, t_fine, a_fine, z_fine):
    """Interpolate values from a rough (t, a, z) grid onto a fine one: first in time, then in (a, z).

    ``s_rough`` is reshaped in C order to ``(len(t_rough), len(a_rough), len(z_rough))``. First,
    each rough (a, z) node's time path is splined onto ``t_fine`` with
    ``InterpolatedUnivariateSpline``. Then each fine time slice is evaluated on the
    ``a_fine`` x ``z_fine`` grid with ``RectBivariateSpline``. Returns the result flattened in C order.
    """
    spline_1d = sp.interpolate.InterpolatedUnivariateSpline
    spline_2d = sp.interpolate.RectBivariateSpline
    n_t, n_a, n_z = len(t_rough), len(a_rough), len(z_rough)

    cube = s_rough.reshape([n_t, n_a, n_z])

    # Step 1: time interpolation at every rough (a, z) node.
    in_time = np.zeros([len(t_fine), n_a, n_z])
    for ia, iz in np.ndindex(n_a, n_z):
        in_time[:, ia, iz] = spline_1d(t_rough, cube[:, ia, iz])(t_fine)

    # Step 2: bivariate (a, z) interpolation for every fine time slice.
    refined = np.zeros([len(t_fine), len(a_fine), len(z_fine)])
    for it, layer in enumerate(in_time):
        refined[it] = spline_2d(a_rough, z_rough, layer)(a_fine, z_fine)

    return refined.flatten()

## Miscellaneous

In [14]:
def backend_test_function():
    """Print a confirmation line with the current local time, showing the backend was imported."""
    now = datetime.now()
    print('Backend import successful! -', now)