In [None]:
import sys, os, copy, datetime
sys.path.append('/content')
from src.fluvius import WaterData, WaterStation
import pandas as pd

# Load "KEY = value" lines from the credentials file into environment
# variables, then build the Azure storage options from them.
# Error messages never echo a line's content, since the values are secrets.
CREDENTIALS_PATH = "/content/credentials"

with open(CREDENTIALS_PATH, "r") as credentials_file:
    env_vars = credentials_file.read().split('\n')

for line_no, var in enumerate(env_vars, start=1):
    if not var.strip():
        # Skip blank lines, including the empty string produced by a trailing
        # newline; a file without a trailing newline keeps its last entry.
        continue
    key, sep, value = var.partition('=')
    if not sep:
        raise ValueError(
            f"{CREDENTIALS_PATH}: line {line_no} is not of the form 'KEY = value'"
        )
    # strip() tolerates extra spaces and a stray '\r' from Windows line endings.
    os.environ[key.strip()] = value.strip()

storage_options = {'account_name': os.environ['ACCOUNT_NAME'],
                   'account_key': os.environ['BLOB_KEY']}

In [42]:
import torch, pickle, fsspec, torch.nn as nn
from src.utils import MultipleRegression

# Filesystem handle for the Azure storage account configured in the first cell.
fs = fsspec.filesystem("az", **storage_options)

# Load in the top model metadata (hyper-parameters, feature list, stored
# predictions). NOTE: pickle.load can execute arbitrary code — only load
# trusted files.
with open("/content/output/mlp/top_model_metadata.pickle", "rb") as f:
    meta = pickle.load(f)


def _resolve_activation(name):
    """Look up a ``torch.nn`` activation from the name stored in the metadata.

    Replaces ``eval(f"nn.{name}")`` without executing arbitrary code. It
    accepts a bare attribute name such as ``"ReLU"`` (returns the class) or a
    no-argument call such as ``"ReLU()"`` (returns an instance). Those are the
    values ``eval`` would have produced for these two forms.

    Raises:
        ValueError: if ``name`` is not one of those two forms or does not
            name an attribute of ``torch.nn``.
    """
    name = name.strip()
    is_call = name.endswith("()")
    attr = name[:-2].strip() if is_call else name
    if not attr.isidentifier() or not hasattr(nn, attr):
        raise ValueError(f"Unsupported activation {name!r} in model metadata")
    activation = getattr(nn, attr)
    return activation() if is_call else activation


model = MultipleRegression(
    len(meta["features"]),
    len(meta["layer_out_neurons"]),
    meta["layer_out_neurons"],
    activation_function=_resolve_activation(meta["activation"]),
)

# map_location="cpu": the checkpoint may have been written from a GPU session,
# while the freshly built model lives on the CPU.
with open("/content/output/mlp/top_model.pt", "rb") as f:
    model.load_state_dict(torch.load(f, map_location="cpu"))
meta.keys()

dict_keys(['training_data', 'buffer_distance', 'day_tolerance', 'cloud_thr', 'min_water_pixels', 'features', 'learning_rate', 'learn_sched_step_size', 'learn_sched_gamma', 'batch_size', 'layer_out_neurons', 'epochs', 'activation', 'train_loss', 'train_pooled_mse', 'test_site_mse', 'test_pooled_mse', 'y_train_sample_id', 'y_test_sample_id', 'y_obs_train', 'y_pred_train', 'y_obs_test', 'y_pred_test'])

In [None]:
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Feature table for the sites to extrapolate to. Rows with any missing value
# are dropped, then the `is_brazil` indicator is set to 1 on every remaining
# row (presumably all prediction sites are in Brazil — confirm).
pred_features = (
    pd.read_csv("/content/local/predictions/itv3_features.csv")
    .dropna()
)
pred_features["is_brazil"] = 1

# Rebuild the non-testing training rows so the MinMaxScaler is fit on the same
# rows the model saw (assumes these filters match the training code — TODO confirm).
response = "Log SSC (mg/L)"
data = pd.read_csv(meta["training_data"])
data[response] = np.log(data["SSC (mg/L)"])
data = data[data["partition"] != "testing"]

# Discard samples with fewer water pixels than the model was trained with.
not_enough_water = data["n_water_pixels"] < meta["min_water_pixels"]
data = data.loc[~not_enough_water]

# NOTE(review): this removes rows whose *log* SSC is exactly 0, i.e. SSC == 1
# mg/L, not rows with SSC == 0 (whose log is -inf). It is kept unchanged
# because the scaler must be fit on the same rows used in training.
lnssc_0 = data[response] == 0
data = data.loc[~lnssc_0]

# Fit the min/max on the training features only. The prediction features are
# mapped with that training range, so they may fall outside [0, 1].
scaler = MinMaxScaler()
X_train_scaled = scaler.fit_transform(data[meta["features"]])
X_pred_scaled = np.asarray(
    scaler.transform(pred_features[meta["features"]]), dtype=np.float64
)

In [49]:
def _predict(net, features):
    """Run ``net`` on a 2-D batch of scaled features and return a flat list.

    Squeezes the output like before, but wraps it in ``atleast_1d`` so that a
    single input row still yields a one-element list instead of a bare float.
    """
    with torch.no_grad():
        output = net(torch.as_tensor(features, dtype=torch.float32))
    return torch.atleast_1d(output.squeeze()).tolist()


# Switch to inference mode so dropout / batch-norm layers (if
# MultipleRegression has any) behave deterministically.
model.eval()

scaled_pred_features = torch.tensor(X_pred_scaled, dtype=torch.float32)
y_extrap = _predict(model, scaled_pred_features)
y_train_pred = _predict(model, X_train_scaled)
y_train_pred

[1.5579839944839478,
 1.6923900842666626,
 3.4504494667053223,
 1.777937889099121,
 1.0353400707244873,
 1.6918058395385742,
 1.1379835605621338,
 0.7803628444671631,
 1.0533170700073242,
 1.0594568252563477,
 4.465786933898926,
 4.740458011627197,
 5.404466152191162,
 3.326268196105957,
 1.3620200157165527,
 1.6048187017440796,
 3.6549360752105713,
 4.35579252243042,
 3.0899577140808105,
 2.711575746536255,
 2.786865711212158,
 2.5722169876098633,
 1.577256441116333,
 2.5347843170166016,
 3.179082155227661,
 3.2267935276031494,
 3.1635141372680664,
 2.8469629287719727,
 2.2190966606140137,
 3.6540207862854004,
 0.8201462626457214,
 2.419983386993408,
 1.3707671165466309,
 1.897777795791626,
 1.025177001953125,
 2.8730804920196533,
 3.41896390914917,
 1.3146989345550537,
 3.879582166671753,
 1.802438735961914,
 1.4955785274505615,
 5.472539901733398,
 5.59542989730835,
 5.601964473724365,
 6.494561195373535,
 3.5865345001220703,
 6.32408332824707,
 3.4910662174224854,
 7.04069423675537

In [53]:
# Sanity check: the reloaded model, fed the training features re-scaled in this
# notebook, should reproduce the training predictions stored in the metadata.
# Float32 round-off makes the values differ slightly, hence the tolerance.
np.testing.assert_allclose(y_train_pred, meta["y_pred_train"], atol=1e-4)
meta["y_pred_train"]

[1.5579842329025269,
 1.6923893690109253,
 3.450448989868164,
 1.7779377698898315,
 1.0353403091430664,
 1.6918069124221802,
 1.1379834413528442,
 0.7803627848625183,
 1.0533170700073242,
 1.0594568252563477,
 4.465786933898926,
 4.740459442138672,
 5.404465675354004,
 3.3262691497802734,
 1.3620210886001587,
 1.6048191785812378,
 3.654934883117676,
 4.355792999267578,
 3.0899577140808105,
 2.711576461791992,
 2.786865711212158,
 2.5722174644470215,
 1.5772567987442017,
 2.534783363342285,
 3.179081439971924,
 3.2267935276031494,
 3.1635138988494873,
 2.8469622135162354,
 2.2190968990325928,
 3.6540210247039795,
 0.8201457262039185,
 2.4199841022491455,
 1.37076735496521,
 1.89777672290802,
 1.0251760482788086,
 2.8730804920196533,
 3.41896390914917,
 1.3146986961364746,
 3.879581928253174,
 1.8024381399154663,
 1.495578408241272,
 5.472540855407715,
 5.59542989730835,
 5.601964473724365,
 6.494561195373535,
 3.5865349769592285,
 6.3240838050842285,
 3.4910664558410645,
 7.040692806243