In [23]:
import pandas as pd
import pickle
import torch

from sklearn.preprocessing import MinMaxScaler

# Location of the gold dataset that is read with pd.read_parquet below.
# NOTE(review): this is an absolute path on one machine; adjust it (or make it
# configurable) before running the notebook anywhere else.
GOLD_PARQUET_PATH = '/Users/olievortex/lakehouse/default/Files/gold/igra2/liftedindex_lr'
# Directory holding the saved artifacts this notebook loads:
# 'ae_min_max_scaler.pkl' (pickled scaler) and 'ae_fnn.pt' (model state dict).
ARTIFACTS_PATH = '/Users/olievortex/lakehouse/default/Files/gold/igra2/artifacts'

In [24]:
def load_min_max_scaler()-> MinMaxScaler:
    """Unpickle and return the scaler saved in the artifacts directory.

    Reads ``ae_min_max_scaler.pkl`` from ``ARTIFACTS_PATH`` and returns the
    object stored in it (annotated as a ``MinMaxScaler``).

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserialising, so only point this at artifact files you trust.
    """
    scaler_path = f'{ARTIFACTS_PATH}/ae_min_max_scaler.pkl'

    with open(scaler_path, 'rb') as scaler_file:
        return pickle.load(scaler_file)

class AutoEncoder(torch.nn.Module):
    """Fully connected autoencoder: input -> hidden -> latent -> hidden -> input.

    The encoder is Linear/ReLU/Linear and the decoder mirrors it, ending in a
    Sigmoid so every reconstructed value lies in the range 0..1.

    The attribute names ``encoder`` / ``decoder`` and the order of the layers
    inside each ``Sequential`` determine the state-dict keys, so they must stay
    as they are for previously saved weights to load.

    Args:
        input_dim: Number of features per sample (also the output width).
        hidden_dim: Width of the intermediate layer in encoder and decoder.
        latent_dim: Width of the encoded (bottleneck) representation.

    The defaults (127, 96, 64) reproduce the original fixed architecture.
    """

    def __init__(self, input_dim: int = 127, hidden_dim: int = 96, latent_dim: int = 64):
        super().__init__()

        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, latent_dim)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(latent_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, input_dim),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        """Encode ``x`` and return its reconstruction (same trailing width as ``x``)."""
        return self.decoder(self.encoder(x))

In [25]:
# Read the gold dataset from parquet
df = pd.read_parquet(GOLD_PARQUET_PATH)

# Fetch the pickled scaler from the artifacts directory
ss = load_min_max_scaler()

# Keep only the model inputs (the identifier columns and 'li' are excluded),
# scale them, and hand the result to torch as a float32 tensor
feature_frame = df.drop(columns=['id', 'effective_date', 'hour', 'li'])
X = torch.from_numpy(ss.transform(feature_frame)).float()

In [26]:
# Rebuild the network and restore its saved weights.
# map_location='cpu' keeps the load working even if the checkpoint was written
# from a non-CPU device; the reconstruction cell below converts the model
# output with .numpy(), which requires CPU tensors.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (consider weights_only=True if the installed PyTorch supports it).
model = AutoEncoder()
state_dict = torch.load(f'{ARTIFACTS_PATH}/ae_fnn.pt', map_location='cpu')
model.load_state_dict(state_dict)

# Switch to inference mode; as the last expression this also displays the model.
model.eval()

AutoEncoder(
  (encoder): Sequential(
    (0): Linear(in_features=127, out_features=96, bias=True)
    (1): ReLU()
    (2): Linear(in_features=96, out_features=64, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=64, out_features=96, bias=True)
    (1): ReLU()
    (2): Linear(in_features=96, out_features=127, bias=True)
    (3): Sigmoid()
  )
)

In [27]:
# Add the positional 'index' column only once. Without the guard, re-running
# this cell would stack an extra 'level_0' column onto df each time.
if 'index' not in df.columns:
    df = df.reset_index()

# Run the autoencoder for inference only: no autograd graph is needed.
with torch.no_grad():
    reconstructed = model(X).numpy()

# Map the reconstruction from scaled space back to the original feature scale.
reconstructed = ss.inverse_transform(reconstructed)

# Pick the feature columns by name (everything that was not dropped when X was
# built) instead of by position, so extra leading columns cannot shift them.
non_feature_columns = ('index', 'id', 'effective_date', 'hour', 'li')
feature_columns = [col for col in df.columns if col not in non_feature_columns]
assert len(feature_columns) == reconstructed.shape[1], (
    f'expected {reconstructed.shape[1]} feature columns, found {len(feature_columns)}'
)

# Assemble the reconstructed frame and prepend the identifying columns.
output = pd.DataFrame(reconstructed, columns=feature_columns)
output.insert(0, 'location', df['id'])
output.insert(1, 'date', df['effective_date'])
output.insert(2, 'hour', df['hour'])
output.head()

Unnamed: 0,location,date,hour,day_num,0_gph,0_pres,0_temp,0_dp,0_u,0_v,...,19_temp,19_dp,19_u,19_v,20_gph,20_pres,20_temp,20_dp,20_u,20_v
0,USM00072250,2000-01-01,0,-0.865629,14.780505,1014.132324,21.87665,18.293819,-1.20557,1.378733,...,-37.39484,-49.959999,20.628508,-2.222271,10000.0,282.259399,-41.600098,-53.706142,21.902159,-2.609789
1,USM00072250,2000-01-01,12,-0.869328,14.255654,1014.879456,17.217493,15.250392,-1.83443,6.294243,...,-38.105301,-54.397388,12.700339,7.305906,10000.0,282.509949,-42.489975,-57.34938,12.66036,6.562858
2,USM00072250,2000-01-02,0,-0.870538,14.209485,1014.109192,23.104187,18.197714,-2.276649,6.770547,...,-37.564384,-49.314903,15.249309,16.362921,10000.0,282.867035,-41.461403,-53.047291,15.790112,16.331799
3,USM00072250,2000-01-02,12,-0.869464,14.054573,1015.149231,20.808556,17.218338,-3.792559,11.200052,...,-38.781662,-50.704411,20.461983,16.535282,10000.0,282.754547,-42.993149,-53.847538,21.15967,16.39337
4,USM00072250,2000-01-03,0,-0.870926,14.099384,1013.30603,24.112448,16.735838,-2.208552,8.223051,...,-38.735584,-51.385555,21.564159,19.165627,10000.0,282.245911,-42.723263,-55.217113,22.075907,19.236759


In [28]:
# Preview the first rows of the source data for comparison with the
# reconstructed values shown by the previous cell.
df.head(n=5)

Unnamed: 0,index,id,effective_date,hour,day_num,0_gph,0_pres,0_temp,0_dp,0_u,...,19_dp,19_u,19_v,20_gph,20_pres,20_temp,20_dp,20_u,20_v,li
0,0,USM00072250,2000-01-01,0,-1.0,14.0,1014.2,23.9,18.4,-5.5,...,-50.5,21.3,-1.5,10000.0,282.8,-41.5,-53.7,22.3,-1.7,1.2
1,1,USM00072250,2000-01-01,12,-1.0,14.0,1013.8,15.8,15.5,-1.1,...,-54.4,13.3,7.1,10000.0,282.1,-42.4,-57.3,12.6,8.2,9.7
2,2,USM00072250,2000-01-02,0,-1.0,14.0,1011.6,23.7,18.9,-1.9,...,-49.2,14.3,16.1,10000.0,282.7,-41.1,-52.8,16.1,15.0,-1.3
3,3,USM00072250,2000-01-02,12,-1.0,14.0,1011.5,20.6,18.9,-1.5,...,-50.6,22.9,16.4,10000.0,281.8,-42.0,-53.5,21.9,14.2,0.8
4,4,USM00072250,2000-01-03,0,-1.0,14.0,1008.7,24.2,18.7,-0.0,...,-51.5,22.4,21.7,10000.0,281.4,-42.2,-54.6,22.1,19.7,-2.1
