In [1]:
import os
import time
import numpy as np
import netCDF4 as nc
import argparse

from stb import strtobool

from msh import load_mesh, load_flow, \
                sort_mesh, sort_flow
from ops import trsk_mats

from _dx import HH_TINY, UU_TINY
from _dx import invariant, diag_vars, tcpu
from _dt import step_eqns

import torch as tn
import torchtt as tt

from sympy.ntheory import factorint

from temp_init import init_file
from timer import Timer

In [2]:
class base: pass

cnfg = base()

cnfg.save_freq = 100
cnfg.stat_freq = 100

cnfg.integrate = 'RK44'
cnfg.operators = 'TRSK-CV'
cnfg.equations = 'SHALLOW-WATER'
cnfg.ke_upwind = 'AUST-CONST'
cnfg.ke_scheme = 'CENTRE'
cnfg.pv_upwind = 'AUST-ADAPT'
cnfg.pv_scheme = 'UPWIND'

cnfg.du_damp_4 = 0.0
cnfg.vu_damp_4 = 0.0

cnfg.iteration = 0
cnfg.no_rotate = False

# testing with simple, quasi-linear
# test case on low resolution mesh
name = 'io/ltc1_cvt_5.nc'
path, file = os.path.split(name)
save = os.path.join(path, "out_" + file)

In [3]:
print("Loading input assets...")

# load mesh + init. conditions
mesh = load_mesh(name)
flow = load_flow(name, None, lean=True)


print("Creating output file...")

init_file(name, cnfg, save, mesh, flow)


print("Reordering mesh data...")

mesh = sort_mesh(mesh, True)
flow = sort_flow(flow, mesh, lean=True)

u0_edge = flow.uu_edge[-1, :, 0]
uu_edge = u0_edge
ut_edge = u0_edge * 0.0

h0_cell = flow.hh_cell[-1, :, 0]
hh_cell = h0_cell
ht_cell = h0_cell * 0.0

hh_cell = np.maximum(HH_TINY, hh_cell)


print("Forming coefficients...")

# set sparse spatial operators
trsk = trsk_mats(mesh)

# remap fe,fc is more accurate?
flow.ff_edge = trsk.edge_stub_sums * flow.ff_vert
flow.ff_edge = \
    (flow.ff_edge / mesh.edge.area)

flow.ff_cell = trsk.cell_kite_sums * flow.ff_vert
flow.ff_cell = \
    (flow.ff_cell / mesh.cell.area)

flow.ff_cell *= (not cnfg.no_rotate)
flow.ff_edge *= (not cnfg.no_rotate)
flow.ff_vert *= (not cnfg.no_rotate)

kp_sums = np.zeros((
    cnfg.iteration // cnfg.stat_freq + 1), dtype=float)
en_sums = np.zeros((
    cnfg.iteration // cnfg.stat_freq + 1), dtype=float)


print('Done.')

Loading input assets...
Creating output file...
Reordering mesh data...
Forming coefficients...
Done.


Let's figure out now to fold a TT operator arising from a nonsquare matrix.

In particular, look at `trsk.cell_flux_sums`, which has shape `(ncells, nedges)` and takes values on edges $\rightarrow$ values on cells.
For example, this calculates the TRiSK divergence of the flux at cell edges:
```
uh_cell = trsk.cell_flux_sums @ uh_edge
```

We will start with a dummy `uh_edge`.

In [4]:
#uu_edge = np.random.rand(mesh.edge.size)
uh_edge = np.ones(mesh.edge.size)
print(f'ncells = {mesh.cell.size}\tnedges = {mesh.edge.size}')

ncells = 10242	nedges = 30720


In [5]:
prime_factors_nedges = list( factorint(mesh.edge.size).items() )
print(f'prime_factors_nedges = {prime_factors_nedges}')

prime_factors_ncells = list( factorint(mesh.cell.size).items() )
print(f'prime_factors_ncells = {prime_factors_ncells}')

prime_factors_nedges = [(2, 11), (3, 1), (5, 1)]
prime_factors_ncells = [(2, 1), (3, 2), (569, 1)]


In [6]:
tt_tens_cell_shape = []
for factor in prime_factors_ncells:
    tt_tens_cell_shape += [ factor[0] ] * factor[1]
# END for

print(f'tt_tens_cell_shape = {tt_tens_cell_shape}')

tt_tens_edge_shape = []
for factor in prime_factors_nedges:
    tt_tens_edge_shape += [ factor[0] ] * factor[1]
# END for

print(f'tt_tens_edge_shape = {tt_tens_edge_shape}')

tt_tens_cell_shape = [2, 3, 3, 569]
tt_tens_edge_shape = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 5]


In [7]:
# something like??
tt_op_shape =  [(2, 2), (3, 2), (3, 2), (569, 2), (1, 2), (1, 2), (1, 2), (1, 2), (1, 2), (1, 2), (1, 2), (1, 3), (1, 5)]
print(tt_op_shape)

# the number of elements works out
print(2*2*3*2*3*2*569*2*2*2*2*2*2*2*2*3*5)
print(mesh.cell.size*mesh.edge.size)

[(2, 2), (3, 2), (3, 2), (569, 2), (1, 2), (1, 2), (1, 2), (1, 2), (1, 2), (1, 2), (1, 2), (1, 3), (1, 5)]
314634240
314634240


In [8]:
tt_op_timer = Timer()
tt_round_op_timer = Timer()

tt_tens_timer = Timer()

# need to get dense version of the operator
# to pass to tt.TT()
dense_cell_flux_sums = trsk.cell_flux_sums.todense()


# form the tt tensor
tt_tens_timer.start()
tt_uh_edge = tt.TT(uh_edge, tt_tens_edge_shape)
tt_tens_timer.stop()

print(f'tt_uh_edge = {tt_uh_edge}')
print(f'time to tt.TT(uu_edge) = {tt_tens_timer.get_time()}')


# form the tt operator
tt_op_timer.start()
tt_cell_flux_sums = tt.TT(dense_cell_flux_sums, tt_op_shape)
tt_op_timer.stop()

print(f'\ntt_cell_flux_sums = {tt_cell_flux_sums}')
print(f'time to tt.TT(dense_cell_flux_sums) = {tt_op_timer.get_time()}')

tt_uh_edge = TT with sizes and ranks:
N = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 5]
R = [1, np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1), 1]

Device: cpu, dtype: torch.float64
#entries 30 compression 0.0009765625

time to tt.TT(uu_edge) = 0.0016138553619384766

tt_cell_flux_sums = TT-matrix with sizes and ranks:
M = [2, 3, 3, 569, 1, 1, 1, 1, 1, 1, 1, 1, 1]
N = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 5]
R = [1, 4, np.int64(10), np.int64(26), 1920, 960, 480, 240, 120, 60, 30, 15, 5, 1]
Device: cpu, dtype: torch.float64
#entries 61725926 compression 0.19618311726021936

time to tt.TT(dense_cell_flux_sums) = 30.32522702217102


In [9]:
tt_mult_timer = Timer()
tt_round_timer = Timer()

tt_mult_timer.start()
tt_mult_result = tt_cell_flux_sums @ tt_uh_edge
tt_mult_timer.stop()

print(f'tt_mult_result (before round) = {tt_mult_result}')
print(f'time to @ = {tt_mult_timer.get_time()}')


tt_round_timer.start()
tt_mult_result = tt_mult_result.round()
tt_round_timer.stop()

print(f'\ntt_mult_result (after round) = {tt_mult_result}')
print(f'time to round = {tt_round_timer.get_time()}')

print(f'\ntotal time = {tt_mult_timer.get_time() + tt_round_timer.get_time()}')

tt_mult_result (before round) = TT with sizes and ranks:
N = [2, 3, 3, 569, 1, 1, 1, 1, 1, 1, 1, 1, 1]
R = [1, 4, 10, 26, 1920, 960, 480, 240, 120, 60, 30, 15, 5, 1]

Device: cpu, dtype: torch.float64
#entries 30862918 compression 3013.3682874438587

time to @ = 0.03167009353637695

tt_mult_result (after round) = TT with sizes and ranks:
N = [2, 3, 3, 569, 1, 1, 1, 1, 1, 1, 1, 1, 1]
R = [1, 2, 6, 18, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

Device: cpu, dtype: torch.float64
#entries 10615 compression 1.0364186682288616

time to round = 1.6153059005737305

total time = 3.293951988220215


In [10]:
csr_mult_timer = Timer()

csr_mult_timer.start()
csr_mult_result = trsk.cell_flux_sums @ uh_edge
csr_mult_timer.stop()

print(f'csr mult time = {csr_mult_timer.get_time()}')


print(f'tt mult time / csr mult time = {(tt_mult_timer.get_time() + tt_round_timer.get_time()) / csr_mult_timer.get_time()}')

csr mult time = 0.00012874603271484375
tt mult time / csr mult time = 19188.661111111112


In [11]:
full_mult_timer = Timer()

full_mult_timer.start()
full_mult_result = dense_cell_flux_sums @ uh_edge
full_mult_timer.stop()

print(f'full mult time = {full_mult_timer.get_time()}')

full mult time = 0.036134958267211914


In [12]:
csr_err = np.sum( np.square(full_mult_result - csr_mult_result) )
print(f'csr_err = {csr_err}')

tt_err = np.sum( np.square(full_mult_result - tt_mult_result.full().flatten().numpy()) )
print(f'tt_err = {tt_err}')

csr_err = 4.153273855628752e-19
tt_err = 5.544198350717207e-13
