In [1]:
import pandas as pd
import pickle
import torch

from sklearn.preprocessing import MinMaxScaler

GOLD_PARQUET_PATH = '/usr/datalake/silver/igra/liftedindex_lr/gph20s10k_li.parquet'
ARTIFACTS_PATH = '/usr/datalake/silver/igra/liftedindex_lr/artifacts'

In [2]:
def load_min_max_scaler()-> MinMaxScaler:
    with open(f'{ARTIFACTS_PATH}/ae_min_max_scaler.pkl', 'rb') as f:
        std_scaler = pickle.load(f)

    return std_scaler

class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(127, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 10)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(10, 50),
            torch.nn.ReLU(),
            torch.nn.Linear(50, 127)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return decoded

In [3]:
# Load the dataset
df = pd.read_parquet(GOLD_PARQUET_PATH)

# Grab 100 random samples
df = df.sample(100)
df = df.reset_index(drop=True)

# Remove irrelevant data
X = df.drop(['id', 'effective_date', 'hour', 'li'], axis=1)

# Scale the dataset
ss = load_min_max_scaler()
X = ss.transform(X)

# Convert to tensor
X = torch.from_numpy(X).float()
X.shape

torch.Size([100, 127])

In [4]:
model = AutoEncoder()
model.load_state_dict(torch.load(f'{ARTIFACTS_PATH}/ae_fnn.pt'))
model.eval()

AutoEncoder(
  (encoder): Sequential(
    (0): Linear(in_features=127, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=10, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=10, out_features=50, bias=True)
    (1): ReLU()
    (2): Linear(in_features=50, out_features=127, bias=True)
  )
)

In [5]:
# Encode the values
logits = model.encoder(X).detach()

# Decode the values
decoded = model.decoder(logits).detach()

# Undo the initial scaling transform
decoded = ss.inverse_transform(decoded).round(1)
decoded.shape

(100, 127)

In [6]:
# Add the index back to the dataframe
output = pd.DataFrame(decoded, columns=df.columns[3:-1])
output.insert(0, 'id', df['id'] )
output.insert(1, 'effective_date', df['effective_date'] )
output.insert(2, 'hour', df['hour'] )
output.head()

Unnamed: 0,id,effective_date,hour,day_num,0_gph,0_pres,0_temp,0_dp,0_u,0_v,...,19_temp,19_dp,19_u,19_v,20_gph,20_pres,20_temp,20_dp,20_u,20_v
0,USM00072440,2005-03-03,12,-0.4,376.8,972.3,0.2,-2.9,-0.0,-0.9,...,-50.0,-66.2,15.6,-4.3,10000.0,259.7,-51.8,-68.3,16.5,-3.8
1,USM00072747,2012-10-29,12,-0.5,330.7,981.1,-4.1,-3.9,-2.3,2.2,...,-53.8,-61.7,24.2,-15.2,10000.0,257.3,-56.2,-64.7,25.3,-15.2
2,USM00072440,2020-09-23,0,0.1,348.5,977.7,18.1,15.2,-2.9,-0.1,...,-34.5,-36.8,6.9,4.6,10000.0,284.8,-38.5,-41.8,8.1,4.6
3,USM00072261,2002-10-30,0,-0.5,314.0,976.5,18.5,12.6,-0.1,-0.2,...,-35.0,-48.4,29.9,14.4,10000.0,282.3,-38.7,-51.8,31.2,14.6
4,USM00072451,2002-05-06,0,0.6,719.3,927.1,23.4,14.5,-0.7,5.1,...,-37.9,-45.5,40.8,6.8,10000.0,281.2,-42.4,-49.7,41.9,6.6


In [7]:
# Interleave the results into the source for visual comparison
interleave = pd.concat([df, output])
interleave = interleave.sort_index()# .sort_values(by=['id', 'effective_date'])
interleave.head()

Unnamed: 0,id,effective_date,hour,day_num,0_gph,0_pres,0_temp,0_dp,0_u,0_v,...,19_dp,19_u,19_v,20_gph,20_pres,20_temp,20_dp,20_u,20_v,li
0,USM00072440,2005-03-03,12,-0.47,391.0,971.4,0.1,-2.0,-0.5,1.4,...,-66.0,22.0,-1.2,10000.0,258.2,-52.8,-68.0,26.3,-0.6,13.9
0,USM00072440,2005-03-03,12,-0.4,376.8,972.3,0.2,-2.9,-0.0,-0.9,...,-66.2,15.6,-4.3,10000.0,259.7,-51.8,-68.3,16.5,-3.8,
1,USM00072747,2012-10-29,12,-0.5,330.7,981.1,-4.1,-3.9,-2.3,2.2,...,-61.7,24.2,-15.2,10000.0,257.3,-56.2,-64.7,25.3,-15.2,
1,USM00072747,2012-10-29,12,-0.54,357.0,983.7,-1.0,-1.4,-2.1,0.3,...,-60.7,18.3,-13.0,10000.0,256.8,-51.0,-62.1,17.6,-10.9,13.5
2,USM00072440,2020-09-23,0,0.05,391.0,973.2,21.3,16.9,-1.4,-0.4,...,-37.1,13.7,11.7,10000.0,284.6,-37.5,-41.3,15.5,15.2,-0.4


In [8]:
interleave.to_csv(f'{ARTIFACTS_PATH}/ae_compare.csv')