In [34]:
import pandas as pd
import pickle
import torch

from sklearn.preprocessing import MinMaxScaler

GOLD_PARQUET_PATH = '/Users/olievortex/lakehouse/default/Files/gold/igra2/liftedindex_lr'
ARTIFACTS_PATH = '/Users/olievortex/lakehouse/default/Files/gold/igra2/artifacts'

In [35]:
def load_min_max_scaler()-> MinMaxScaler:
    with open(f'{ARTIFACTS_PATH}/ae_min_max_scaler.pkl', 'rb') as f:
        std_scaler = pickle.load(f)

    return std_scaler

class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(128, 96),
            torch.nn.ReLU(),
            torch.nn.Linear(96, 64)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(64, 96),
            torch.nn.ReLU(),
            torch.nn.Linear(96, 128),
            torch.nn.Sigmoid()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return decoded

In [36]:
# Load the dataset
df = pd.read_parquet(GOLD_PARQUET_PATH)

# Remove irrelevant data
X = df.drop(['id', 'effective_date', 'hour'], axis=1)

# Scale the X dataset
ss = load_min_max_scaler()
X = ss.transform(X)

# Convert to tensor
X = torch.from_numpy(X).float()

In [37]:
model = AutoEncoder()
model.load_state_dict(torch.load(f'{ARTIFACTS_PATH}/ae_fnn.pt'))
model.eval()

AutoEncoder(
  (encoder): Sequential(
    (0): Linear(in_features=128, out_features=96, bias=True)
    (1): ReLU()
    (2): Linear(in_features=96, out_features=64, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=64, out_features=96, bias=True)
    (1): ReLU()
    (2): Linear(in_features=96, out_features=128, bias=True)
    (3): Sigmoid()
  )
)

In [38]:
df = df.reset_index()
output = model(X).detach().numpy()
output = ss.inverse_transform(output)
output = pd.DataFrame(output, columns=df.columns[4:])
output.insert(0, 'location', df['id'] )
output.insert(1, 'date', df['effective_date'] )
output.insert(2, 'hour', df['hour'] )
output.head()

Unnamed: 0,location,date,hour,day_num,0_gph,0_pres,0_temp,0_dp,0_u,0_v,...,19_dp,19_u,19_v,20_gph,20_pres,20_temp,20_dp,20_u,20_v,li
0,USM00072250,2000-01-01,0,-0.866159,14.795053,1013.644287,22.06636,18.798054,-1.993419,0.487667,...,-50.250244,20.526367,-2.177341,10000.0,282.077087,-41.758781,-54.424194,22.490833,-2.599827,1.30645
1,USM00072250,2000-01-01,12,-0.868711,15.475845,1012.876343,15.873312,15.722672,-2.230345,6.192312,...,-54.754837,12.278285,8.33694,10000.0,282.552582,-42.646358,-57.593594,12.024493,7.782256,9.620577
2,USM00072250,2000-01-02,0,-0.871053,15.365338,1011.56958,22.320944,19.204458,-1.797362,6.785477,...,-49.411469,15.183727,17.326891,10000.0,282.676208,-41.306965,-53.500645,15.546011,17.61141,0.269226
3,USM00072250,2000-01-02,12,-0.871176,15.221201,1011.624939,19.696054,19.285385,-3.550931,10.230122,...,-50.248737,21.167494,16.89926,10000.0,282.385345,-42.410706,-54.468704,22.956781,16.821882,2.617467
4,USM00072250,2000-01-03,0,-0.872003,17.11968,1009.241394,23.051853,18.275434,-1.942967,7.50482,...,-51.362766,21.606489,19.412184,10000.0,281.874634,-42.861259,-55.592026,22.094749,19.266973,0.313183


In [39]:
df.head()

Unnamed: 0,index,id,effective_date,hour,day_num,0_gph,0_pres,0_temp,0_dp,0_u,...,19_dp,19_u,19_v,20_gph,20_pres,20_temp,20_dp,20_u,20_v,li
0,0,USM00072250,2000-01-01,0,-1.0,14.0,1014.2,23.9,18.4,-5.5,...,-50.5,21.3,-1.5,10000.0,282.8,-41.5,-53.7,22.3,-1.7,1.2
1,1,USM00072250,2000-01-01,12,-1.0,14.0,1013.8,15.8,15.5,-1.1,...,-54.4,13.3,7.1,10000.0,282.1,-42.4,-57.3,12.6,8.2,9.7
2,2,USM00072250,2000-01-02,0,-1.0,14.0,1011.6,23.7,18.9,-1.9,...,-49.2,14.3,16.1,10000.0,282.7,-41.1,-52.8,16.1,15.0,-1.3
3,3,USM00072250,2000-01-02,12,-1.0,14.0,1011.5,20.6,18.9,-1.5,...,-50.6,22.9,16.4,10000.0,281.8,-42.0,-53.5,21.9,14.2,0.8
4,4,USM00072250,2000-01-03,0,-1.0,14.0,1008.7,24.2,18.7,-0.0,...,-51.5,22.4,21.7,10000.0,281.4,-42.2,-54.6,22.1,19.7,-2.1
