In [1]:
%load_ext autoreload
%autoreload 2

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import os
import time
import json
import jax.numpy as np
import numpy as onp
import jax
import pickle
import matplotlib.pyplot as plt
import pandas as pd
from timecast.learners import AR
from timecast.learners._ar import _ar_predict, _ar_batch_window
from timecast.utils.numpy import ecdf
from timecast.utils.losses import MeanSquareError
import torch
import matplotlib

plt.rcParams['figure.figsize'] = [20, 10]

import tqdm.notebook as tqdm



In [2]:
from ealstm.gaip import FloodLSTM
from ealstm.gaip import FloodData
from ealstm.gaip.utils import MSE, NSE

from timecast.optim import SGD, Adam
from timecast.learners import AR

cfg_path = "/home/dsuo/src/toy_flood/ealstm/runs/run_2503_0429_seed283956/cfg.json"
ea_data = pickle.load(open("../ealstm/runs/run_2503_0429_seed283956/lstm_seed283956.p", "rb"))
flood_data = FloodData(cfg_path)

LR_AR = 1e-4
AR_INPUT_DIM=32
AR_OUTPUT_DIM=1
BATCH_SIZE=1

In [3]:
import jax.numpy as np
from timecast.utils.losses.core import Loss

class BatchedMeanSquareError(Loss):
    def __init__(self):
        pass

    def compute(self, y_pred: np.ndarray, y_true: np.ndarray):
        return np.mean(np.mean((y_pred - y_true) ** 2, axis=tuple(range(1, y_true.ndim))))

In [4]:
results = {}
mses = []
nses = []
for X, _, basin in tqdm.tqdm(flood_data.generator(), total=len(flood_data.basins)):
#     if basin == "01466500":
#         break
#     sgd = SGD(learning_rate=LR_AR, online=False)
    adam = Adam(learning_rate=LR_AR)
    ar = AR(input_dim=AR_INPUT_DIM,
            output_dim=AR_OUTPUT_DIM,
            window_size=flood_data.cfg["seq_length"],
            optimizer=adam,
            history=X[:flood_data.cfg["seq_length"]],
            fit_intercept=True,
            constrain=False
           )

    # NOTE: difference in indexing convention, so need to pad one row
    X = np.vstack((X[flood_data.cfg["seq_length"]:], np.ones((1, X.shape[1]))))
    Y = np.array(ea_data[basin].qobs).reshape(-1, 1)

    Y_lstm = np.array(ea_data[basin].qsim).reshape(-1, 1)
    Y_target = Y - Y_lstm
    
    break

    Y_ar = ar.predict_and_update(X, Y_target, BATCH_SIZE)

    Y_hat = Y_lstm + Y_ar

    mse = MSE(Y, Y_hat)
    nse = NSE(Y, Y_hat)
    results[basin] = {
        "mse": mse,
        "nse": nse,
        "count": X.shape[0],
        "avg_mse": np.mean(np.array(mses)),
        "avg_nse": np.mean(np.array(nses))
    }
    mses.append(mse)
    nses.append(nse)
    print(basin, mse, nse, np.mean(np.array(mses)), np.mean(np.array(nses)))

HBox(children=(FloatProgress(value=0.0, max=531.0), HTML(value='')))

In [5]:
from timecast.utils.losses.core import Loss

In [6]:
class ProxyLoss(Loss):
    def __init__(self, reg=1.0):
        self._reg = reg
    
    def compute(self, y_pred, y_true):
        return np.dot(y_pred, y_true) + (self._reg / 2) * np.dot(y_pred, y_pred)
        

In [7]:
N = 5
learners = [AR(input_dim=AR_INPUT_DIM,
            output_dim=AR_OUTPUT_DIM,
            window_size=flood_data.cfg["seq_length"],
            optimizer=adam,
            history=X[:flood_data.cfg["seq_length"]],
            fit_intercept=True,
            constrain=False,
            loss=ProxyLoss()
           ) for i in range(N)]
g = jax.grad(MSE)

In [8]:
results = []
se = 0
count = 0
for x, y in tqdm.tqdm(zip(X, Y_target), total=X.shape[0]):
    u_i = [0] * (N + 1)
    y_i = [0] * (N + 1)

    u_i[0] = Y_lstm[0]
    for i in range(1, N + 1):
        eta = 2 / (i + 1)
        y_i[i] = learners[i - 1].predict(x)
        u_i[i] = (1 - eta) * u_i[i - 1] + eta * y_i[i]

    l_i = [0] * (N + 1)
    for i in range(1, N + 1):
        l_i[i] = u_i[i - 1]
        learners[i-1].update(x, g(l_i[i], y.item()))
    
    se += ((u_i[N] - y.item()) ** 2)
    count += 1
    mse = se / count
    print(mse, u_i[N])
    results.append(u_i[N])

HBox(children=(FloatProgress(value=0.0, max=3652.0), HTML(value='')))


[0.00043597] [0.]
[0.05320835] [0.22228757]
[0.08211982] [0.25052735]
[0.07435507] [0.08845995]
[0.06015467] [-0.14964774]
[0.05705834] [-0.22903796]
[0.05315273] [-0.13319266]
[0.04950465] [0.15160407]
[0.05516519] [0.3551321]
[0.06207538] [0.35503277]
[0.06379399] [0.2238663]
[0.05915467] [0.01103893]
[0.05565081] [-0.08561886]
[0.05258388] [0.07640697]
[0.05330439] [0.2076168]
[0.05244794] [0.16901927]
[0.04988781] [0.01430075]
[0.04820101] [-0.20993726]
[0.04695353] [-0.2100434]
[0.04774935] [0.00542375]
[0.06068968] [0.08493429]
[0.05816153] [-0.00358952]
[0.05671545] [0.0259418]
[0.0562822] [0.24387634]
[0.06000724] [0.4053083]
[0.06107292] [0.35524875]
[0.05927717] [0.1913663]
[0.05722069] [0.04454485]
[0.05568274] [-0.04934408]
[0.054546] [-0.06330849]
[0.05364909] [-0.03822611]
[0.05205981] [-0.10758612]
[0.05433958] [-0.15452398]
[0.07209935] [-0.08590274]
[0.0723576] [0.16343655]
[0.07038833] [0.4015878]
[0.07017312] [0.54965067]
[0.07339274] [0.58084846]
[0.07250825] [0.35

[0.5863387] [-1.3472344]
[0.5872085] [-1.2321043]
[0.5866473] [-0.9511783]
[0.5852044] [-0.62556034]
[0.5834065] [-0.26040673]
[0.5819351] [0.12088132]
[0.5809871] [0.3604321]
[0.57984656] [0.35993272]
[0.57827127] [0.21239412]
[0.57653195] [0.06180519]
[0.57479346] [-0.07251346]
[0.5731039] [-0.07483357]
[0.5715931] [-0.00451738]
[0.57037354] [0.12169331]
[0.5693421] [0.25544384]
[0.5682938] [0.29533863]
[0.5669743] [0.25151327]
[0.5656356] [0.26192737]
[0.56449217] [0.33871883]
[0.5628346] [0.47072816]
[0.5613486] [0.6857321]
[0.5607366] [0.9051129]
[0.5612873] [1.1417098]
[0.5619665] [1.136106]
[0.56205326] [0.94848764]
[0.56184924] [0.6932584]
[0.56116986] [0.45833752]
[0.56006384] [0.34285444]
[0.5590334] [0.3829239]
[0.55915457] [0.5528403]
[0.5606342] [0.7358367]
[0.5618958] [0.7464778]
[0.56229764] [0.6074872]
[0.56244725] [0.51885116]
[0.5629764] [0.5228577]
[0.5641173] [0.6043739]
[0.5666043] [0.6435899]
[0.57567585] [0.60370505]
[0.5797666] [0.36801246]
[0.5791838] [0.115982

[0.85926956] [1.3173361]
[0.860972] [1.3578345]
[0.86251915] [1.3325818]
[0.86387503] [1.26524]
[0.8648337] [1.1871796]
[0.8656227] [1.108633]
[0.86612695] [0.9984632]
[0.8661524] [0.8494446]
[0.865674] [0.71366554]
[0.86503386] [0.68431866]
[0.8644603] [0.7301059]
[0.8640509] [0.8167198]
[0.8636903] [0.8092072]
[0.8630311] [0.70537996]
[0.862245] [0.5709801]
[0.86135054] [0.44804534]
[0.860346] [0.34182006]
[0.8592103] [0.23891595]
[0.85806245] [0.18177813]
[0.85680145] [0.16989125]
[0.855607] [0.22737867]
[0.85450804] [0.36176604]
[0.8534743] [0.42749596]
[0.8523697] [0.3896049]
[0.8512263] [0.27777427]
[0.8499987] [0.13457954]
[0.8487344] [0.0221048]
[0.84748065] [-0.07499489]
[0.84633493] [-0.16048007]
[0.8451537] [-0.15088157]
[0.8439035] [-0.04381357]
[0.84266883] [0.09547469]
[0.84146] [0.20993046]
[0.84026635] [0.2078372]
[0.8391568] [0.11570823]
[0.83792627] [0.03946111]
[0.83670187] [0.02412924]
[0.8355478] [0.01545274]
[0.8343278] [-0.03001785]
[0.83312756] [-0.07526296]
[0.

[1.0758598] [-0.42079073]
[1.075264] [-0.54736626]
[1.0747708] [-0.5795561]
[1.0742489] [-0.53533775]
[1.0735999] [-0.47162205]
[1.0751979] [-0.43644333]
[1.0771021] [-0.37546825]
[1.0774809] [-0.2482329]
[1.0770133] [-0.13135463]
[1.0763566] [-0.13503766]
[1.0756389] [-0.18924087]
[1.0748993] [-0.23187727]
[1.074115] [-0.28114253]
[1.0733967] [-0.39111578]
[1.0729063] [-0.5623965]
[1.0726928] [-0.7517791]
[1.0727088] [-0.90542376]
[1.0728745] [-0.98333025]
[1.073268] [-1.0596488]
[1.0735475] [-1.1833663]
[1.0747056] [-1.3587968]
[1.0797395] [-1.4833496]
[1.0821412] [-1.4097106]
[1.0866674] [-1.2019776]
[1.093101] [-0.93891]
[1.0969099] [-0.62613523]
[1.0985193] [-0.2799402]
[1.098135] [-0.0880636]
[1.0975057] [-0.01651108]
[1.0967661] [-0.05318844]
[1.0960076] [-0.15465951]
[1.0953244] [-0.28124547]
[1.0946941] [-0.39922068]
[1.0939255] [-0.50565696]
[1.0931685] [-0.7202236]
[1.0933646] [-0.9977031]
[1.0944626] [-1.1969066]
[1.0952786] [-1.297914]
[1.0954885] [-1.3156257]
[1.0959144] 

[1.1710079] [-0.31528938]
[1.1701126] [-0.32289076]
[1.1692556] [-0.19068986]
[1.168515] [-0.01389337]
[1.1678523] [0.03288466]
[1.1671146] [-0.01780993]
[1.1662942] [-0.09728253]
[1.1654531] [-0.08131862]
[1.1646734] [0.00724196]
[1.1639842] [0.13789105]
[1.1634415] [0.25397474]
[1.1630076] [0.3514462]
[1.1625196] [0.47399807]
[1.162183] [0.61668813]
[1.1623855] [0.7272555]
[1.1622971] [0.724473]
[1.1619834] [0.7841362]
[1.1617564] [0.90653557]
[1.1617918] [1.0546792]
[1.1619965] [1.1778313]
[1.1623921] [1.3005167]
[1.1630448] [1.4328858]
[1.1638237] [1.5017028]
[1.1643851] [1.4322164]
[1.1646507] [1.3423848]
[1.1648145] [1.2657925]
[1.1647923] [1.2158581]
[1.1647257] [1.3172183]
[1.1651679] [1.5303447]
[1.1662813] [1.776017]
[1.1677544] [1.9027138]
[1.1691092] [1.835006]
[1.1703581] [1.731303]
[1.171458] [1.6070228]
[1.1725883] [1.5256084]
[1.1748818] [1.4663086]
[1.1761072] [1.3502051]
[1.1770226] [1.2085593]
[1.1774648] [1.0889895]
[1.177719] [1.0795275]
[1.1789105] [1.1407902]
[1.

[1.2230107] [-1.8942277]
[1.2223713] [-2.0989206]
[1.2216278] [-2.1178703]
[1.220894] [-1.9778217]
[1.2204707] [-1.6569717]
[1.2212634] [-1.2270172]
[1.221146] [-0.74705446]
[1.2207444] [-0.26262537]
[1.2201067] [0.2083765]
[1.2196385] [0.67040676]
[1.2203346] [1.0529834]
[1.2207903] [1.1820341]
[1.2208365] [1.2220294]
[1.22187] [1.1953068]
[1.2225194] [1.0698417]
[1.2222582] [0.8518092]
[1.2216724] [0.8560869]
[1.2220541] [0.9932552]
[1.2226312] [1.0285835]
[1.2232178] [0.8592459]
[1.2227248] [0.63707995]
[1.222035] [0.506727]
[1.2215381] [0.45067465]
[1.2209746] [0.33547464]
[1.2203201] [0.19747886]
[1.2196186] [0.1039452]
[1.218959] [-0.0066545]
[1.2182372] [-0.1470708]
[1.2175092] [-0.25261497]
[1.2167828] [-0.29771298]
[1.2160854] [-0.22029996]
[1.2153587] [-0.23029977]
[1.2146327] [-0.29706386]
[1.2139597] [-0.2650901]
[1.2135692] [-0.21257317]
[1.2129675] [-0.31068185]
[1.2123582] [-0.4419417]
[1.2136539] [-0.607109]
[1.2136902] [-0.498738]
[1.2129693] [-0.41338336]
[1.2122542] 

[1.2073483] [-1.6529028]
[1.2092134] [-1.5972475]
[1.2106789] [-1.5547761]
[1.2118065] [-1.5048548]
[1.2128899] [-1.4517287]
[1.2135658] [-1.342089]
[1.2139775] [-1.2681593]
[1.2143825] [-1.3534914]
[1.2149767] [-1.5661863]
[1.2160364] [-1.8040222]
[1.2174062] [-1.9775826]
[1.2178571] [-1.9919817]
[1.2205644] [-2.0234714]
[1.2213515] [-2.0457616]
[1.222678] [-2.1406846]
[1.2263793] [-2.1938958]
[1.2309855] [-2.1411762]
[1.235238] [-1.9911964]
[1.238511] [-1.8704354]
[1.2416222] [-1.7594779]
[1.2442796] [-1.6591041]
[1.2454932] [-1.5565555]
[1.2471285] [-1.538582]
[1.2481371] [-1.4691654]
[1.248964] [-1.4584463]
[1.2495985] [-1.4815892]
[1.2505517] [-1.5259609]
[1.2513433] [-1.6127352]
[1.252737] [-1.6938651]
[1.2540276] [-1.784126]
[1.2552445] [-1.9145131]
[1.2563584] [-2.0640087]
[1.2571238] [-2.2093391]
[1.2578206] [-2.3189664]
[1.2587674] [-2.3930957]
[1.25962] [-2.4653132]
[1.2603883] [-2.5750656]
[1.2615497] [-2.699255]
[1.2632933] [-2.7957523]
[1.2648971] [-2.854179]
[1.2676193] 

[1.2334011] [1.7850163]
[1.2350159] [2.2257512]
[1.2345103] [2.7730718]
[1.2343239] [3.3276024]
[1.2337942] [3.732355]
[1.2337061] [4.034257]
[1.2333156] [4.334549]
[1.2328694] [4.5940886]
[1.232595] [4.8093133]
[1.2324811] [4.9378953]
[1.2347654] [4.9414897]
[1.2380084] [4.8213587]
[1.2411692] [4.521916]
[1.2441213] [4.141579]
[1.2464148] [3.675984]
[1.2478772] [3.130289]
[1.2486843] [2.5624084]
[1.2491995] [2.0874958]
[1.2492398] [1.7416499]
[1.249451] [1.6098497]
[1.2498083] [1.5168355]
[1.2502179] [1.4065931]
[1.2502147] [1.2643037]
[1.2499664] [1.1748595]
[1.2497112] [1.1644686]
[1.2495933] [1.2490165]
[1.2497356] [1.3946165]
[1.2500741] [1.5354066]
[1.2504872] [1.636757]
[1.2510858] [1.7463799]
[1.2519255] [1.7942502]
[1.2536647] [1.7618004]
[1.2564949] [1.558646]
[1.2608027] [1.185286]
[1.2637112] [0.7663009]
[1.2655877] [0.32889253]
[1.2656848] [0.0277285]
[1.2656891] [-0.15509859]
[1.2659972] [-0.30920786]
[1.2658321] [-0.48616445]
[1.2652937] [-0.64914227]
[1.264969] [-0.6121

[1.3194072] [0.7176974]
[1.3207148] [0.5005509]
[1.3207774] [0.23334965]
[1.3202784] [0.04639861]
[1.3197787] [0.00377083]
[1.3192812] [0.08011778]
[1.3190011] [0.18605405]
[1.3185618] [0.43154395]
[1.3181409] [0.8031983]
[1.3180692] [1.1287149]
[1.3181775] [1.2829665]
[1.3177913] [1.2682422]
[1.3181872] [1.2325776]
[1.3180345] [1.1930456]
[1.3176894] [1.165485]
[1.3172724] [1.148911]
[1.3168187] [1.0689725]
[1.3163509] [0.9035289]
[1.3158559] [0.73687714]
[1.315373] [0.6620904]
[1.3150094] [0.6690263]
[1.3146226] [0.6469387]
[1.3141552] [0.5855222]
[1.3137847] [0.48025322]
[1.3136065] [0.32830924]
[1.3132662] [0.17357194]
[1.3128341] [0.04669869]
[1.3123603] [-0.07619488]
[1.311868] [-0.17970955]
[1.3114158] [-0.22750258]
[1.3109406] [-0.16252047]
[1.3105159] [-0.11666858]
[1.3100364] [-0.13203019]
[1.3095461] [-0.20289296]
[1.3090564] [-0.24017584]
[1.308575] [-0.21851838]
[1.3081315] [-0.19264448]
[1.3079646] [-0.18119937]
[1.310225] [-0.16442078]
[1.3098117] [-0.00194645]
[1.309326

[1.3660867] [-0.04246235]
[1.3656414] [-0.10085356]
[1.3651953] [-0.12095928]
[1.3647535] [-0.09701872]
[1.3642995] [-0.03114593]
[1.3638428] [0.06693661]
[1.3634108] [0.16790915]
[1.3629541] [0.26262927]
[1.3625636] [0.28999472]
[1.3622668] [0.26668572]
[1.3619182] [0.30236864]
[1.3615265] [0.38158274]
[1.3610706] [0.49149776]
[1.3606203] [0.46046317]
[1.3601669] [0.28866184]
[1.359715] [0.0608182]
[1.3592666] [-0.18379426]
[1.3588166] [-0.3628546]
[1.3584281] [-0.51646036]
[1.358211] [-0.6898038]
[1.3580011] [-0.88227785]
[1.3579258] [-1.0347346]
[1.3578682] [-1.0440173]
[1.3577546] [-0.978557]
[1.3576529] [-1.0085363]
[1.3576274] [-1.1085423]
[1.357659] [-1.2300817]
[1.357783] [-1.3298599]
[1.3579708] [-1.4106265]
[1.358251] [-1.5127358]
[1.3586407] [-1.65262]
[1.3591655] [-1.7563297]
[1.359561] [-1.685217]
[1.3599039] [-1.584667]
[1.360168] [-1.4730552]
[1.3604163] [-1.4572083]
[1.3605652] [-1.4940681]
[1.3608229] [-1.549146]
[1.3611047] [-1.6646466]
[1.3614655] [-1.8062155]
[1.361

[1.4910514] [-0.9370325]
[1.4907537] [-0.9268637]
[1.4906063] [-0.9738966]
[1.4904674] [-0.9932341]
[1.4902782] [-0.952176]
[1.490057] [-0.852988]
[1.4897792] [-0.740947]
[1.4894394] [-0.6314194]
[1.4890643] [-0.5326452]
[1.4886508] [-0.42886817]
[1.4882374] [-0.337888]
[1.4878324] [-0.24664396]
[1.4875252] [-0.10887133]
[1.4871478] [0.04750951]
[1.4867163] [0.1756843]
[1.4862689] [0.24185891]
[1.4858217] [0.15667325]
[1.4853818] [-0.00920895]
[1.4849464] [-0.07334962]
[1.4845091] [-0.04705551]
[1.484063] [0.02808146]
[1.4836173] [0.07750314]
[1.483172] [0.06989117]
[1.4827267] [0.03313488]
[1.4823205] [0.01612172]
[1.4820056] [0.04549879]
[1.4815776] [0.14596874]
[1.4811624] [0.27792785]
[1.4807872] [0.39526066]
[1.48035] [0.50539756]
[1.4799106] [0.6100232]
[1.4795046] [0.682793]
[1.479184] [0.6976594]
[1.4787483] [0.57442856]
[1.478307] [0.36796847]
[1.4778813] [0.17348064]
[1.477457] [0.08138734]
[1.477021] [0.0770067]
[1.4765828] [0.13985315]
[1.4762545] [0.17236134]
[1.4758285] [

[1.4550833] [0.8820933]
[1.4546843] [1.030908]
[1.4544382] [1.2250346]
[1.4540998] [1.3202043]
[1.4537225] [1.2593415]
[1.453362] [1.2138178]
[1.4529687] [1.2427601]
[1.4527957] [1.240912]
[1.4526893] [1.2003901]
[1.4525287] [1.125236]
[1.4522517] [1.0788587]
[1.451989] [0.9891808]



In [8]:
def test(xy):
    u_i = [0] * (N + 1)
    y_i = [0] * (N + 1)
    
    u_i[0] = Y_lstm[0]
    for i in range(1, N + 1):
        eta = 2 / (i + 1)
        y_i[i] = learners[i - 1].predict(xy[0])
        u_i[i] = (1 - eta) * u_i[i - 1] + eta * y_i[i]

    l_i = [0] * (N + 1)
    for i in range(1, N + 1):
        l_i[i] = u_i[i - 1]
        learners[i-1].update(xy[0], g(l_i[i], xy[1].reshape(())))
    return u_i[N]

In [10]:
%%timeit -r 1 -n 1
Y_xboost = jax.lax.map(test, (X, Y_target))

46.6 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [9]:
Y_xboost = jax.lax.map(test, (X, Y_target))




In [10]:
Y_xboost

DeviceArray([[0.],
             [0.],
             [0.],
             ...,
             [0.],
             [0.],
             [0.]], dtype=float32)