In [11]:
import pandas as pd
import pickle
import torch

from sklearn.preprocessing import MinMaxScaler

GOLD_PARQUET_PATH = '/usr/datalake/silver/igra/liftedindex_lr/gph20s10k_li.parquet'
ARTIFACTS_PATH = '/usr/datalake/silver/igra/liftedindex_lr/artifacts'

In [12]:
def load_min_max_scaler()-> MinMaxScaler:
    with open(f'{ARTIFACTS_PATH}/ae_min_max_scaler.pkl', 'rb') as f:
        std_scaler = pickle.load(f)

    return std_scaler

class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(127, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 10)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(10, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 127)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return decoded

In [40]:
# Load the dataset
df = pd.read_parquet(GOLD_PARQUET_PATH)

# Grab 100 random samples
df = df.sample(100)
df = df.reset_index(drop=True)

# Remove irrelevant data
X = df.drop(['id', 'effective_date', 'hour', 'li'], axis=1)

# Scale the dataset
ss = load_min_max_scaler()
X = ss.transform(X)

# Convert to tensor
X = torch.from_numpy(X).float()
X.shape

torch.Size([100, 127])

In [41]:
model = AutoEncoder()
model.load_state_dict(torch.load(f'{ARTIFACTS_PATH}/ae_fnn.pt'))
model.eval()

AutoEncoder(
  (encoder): Sequential(
    (0): Linear(in_features=127, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=10, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=127, bias=True)
  )
)

In [54]:
# Encode the values
logits = model.encoder(X).detach()

# Decode the values
decoded = model.decoder(logits).detach()

# Undo the initial scaling transform
decoded = ss.inverse_transform(decoded).round(1)
decoded.shape

(100, 127)

In [58]:
# Add the index back to the dataframe
output = pd.DataFrame(decoded, columns=df.columns[3:-1])
output.insert(0, 'id', df['id'] )
output.insert(1, 'effective_date', df['effective_date'] )
output.insert(2, 'hour', df['hour'] )
output.head()

Unnamed: 0,id,effective_date,hour,day_num,0_gph,0_pres,0_temp,0_dp,0_u,0_v,...,19_temp,19_dp,19_u,19_v,20_gph,20_pres,20_temp,20_dp,20_u,20_v
0,USM00072645,2009-04-30,12,0.5,220.4,986.7,12.4,6.7,-0.5,1.7,...,-44.3,-53.0,21.4,12.8,10000.0,271.8,-47.8,-56.6,22.6,12.8
1,USM00072249,2006-05-07,0,0.6,245.3,983.5,19.6,15.5,-1.0,-0.8,...,-39.5,-59.6,42.7,2.7,10000.0,276.6,-43.0,-61.8,44.0,3.1
2,USM00072250,2007-03-19,0,-0.2,233.2,991.9,18.0,9.3,-3.3,5.6,...,-38.2,-44.4,23.5,-8.1,10000.0,282.1,-42.4,-48.6,25.0,-8.5
3,USM00072764,2003-08-28,0,0.5,494.3,952.0,26.6,19.2,-3.1,8.0,...,-34.5,-39.6,17.6,3.7,10000.0,284.3,-38.7,-44.2,18.5,3.2
4,USM00072451,2025-09-24,0,0.1,703.7,933.2,22.0,13.3,0.4,-3.3,...,-34.5,-56.8,28.3,9.7,10000.0,282.7,-37.8,-59.1,29.6,10.4


In [62]:
# Interleave the results into the source for visual comparison
interleave = pd.concat([df, output])
interleave = interleave.sort_index()# .sort_values(by=['id', 'effective_date'])
interleave.head()

Unnamed: 0,id,effective_date,hour,day_num,0_gph,0_pres,0_temp,0_dp,0_u,0_v,...,19_dp,19_u,19_v,20_gph,20_pres,20_temp,20_dp,20_u,20_v,li
0,USM00072645,2009-04-30,12,0.5,209.0,990.9,10.0,7.1,-0.4,2.1,...,-53.1,29.6,19.6,10000.0,271.3,-48.9,-58.2,32.0,16.4,12.0
0,USM00072645,2009-04-30,12,0.5,220.4,986.7,12.4,6.7,-0.5,1.7,...,-53.0,21.4,12.8,10000.0,271.8,-47.8,-56.6,22.6,12.8,
1,USM00072249,2006-05-07,0,0.6,245.3,983.5,19.6,15.5,-1.0,-0.8,...,-59.6,42.7,2.7,10000.0,276.6,-43.0,-61.8,44.0,3.1,
1,USM00072249,2006-05-07,0,0.6,198.0,987.8,19.9,17.0,-0.5,-6.2,...,-64.7,38.0,-2.8,10000.0,275.9,-43.1,-65.0,41.0,-4.1,1.0
2,USM00072250,2007-03-19,0,-0.21,198.0,992.9,21.8,18.0,-7.4,9.8,...,-49.6,23.7,-10.8,10000.0,280.4,-44.4,-53.3,24.7,-15.1,-1.2


In [63]:
interleave.to_csv('/usr/datalake/silver/igra/liftedindex_lr/artifacts/compare.csv')