In [9]:
import pandas as pd
import pickle
import torch

# from sklearn.preprocessing import MinMaxScaler, StandardScaler

IGRA_PATH = '/usr/datalake/silver/igra/gph20s10k'
ARTIFACTS_PATH = '/usr/datalake/silver/stormevents/csvfiles/igra_maidenhead'
STATION_ID = 'USM00072649'

In [10]:
def load_data(station_id: str) -> pd.DataFrame:
    # Load the dataset
    filename = f'{IGRA_PATH}/{station_id}-data-gph20s10k.csv'
    result = pd.read_csv(filename)

    # Remove irrelevant data
    result = result[result['hour'] == 12]
    result = result.drop(['0_gph',
                '1_gph', '2_gph', '3_gph', '4_gph', '5_gph',
                '6_gph', '7_gph', '8_gph', '9_gph', '10_gph',
                '11_gph', '12_gph', '13_gph', '14_gph', '15_gph',
                '16_gph', '17_gph', '18_gph', '19_gph', '20_gph'
                ], axis=1)
    
    # Grab 100 random samples
    result = result.sample(100)
    result = result.reset_index(drop=True)

    return result

def load_scaler(station_id: str):
    with open(f'{ARTIFACTS_PATH}/{station_id}_scaler.pkl', 'rb') as f:
        scaler = pickle.load(f)

    return scaler

def scale_data(scaler, sample: pd.DataFrame):
    result = scaler.transform(sample)

    return result

def unscale_data(scaler, values):
    return scaler.inverse_transform(values).round(1)

class AutoEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self.encoder = torch.nn.Sequential(
            torch.nn.Linear(105, 80),
            torch.nn.Tanh(),
            torch.nn.Linear(80, 60),
            torch.nn.Tanh(),
            torch.nn.Linear(60, 40),
            torch.nn.Tanh(),
            torch.nn.Linear(40, 20)
        )

        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(20, 40),
            torch.nn.Tanh(),
            torch.nn.Linear(40, 60),
            torch.nn.Tanh(),
            torch.nn.Linear(60, 80),
            torch.nn.Tanh(),
            torch.nn.Linear(80, 105)
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return decoded

In [11]:
model = AutoEncoder()
model.load_state_dict(torch.load(f'{ARTIFACTS_PATH}/{STATION_ID}_fnn.pt'))
model.eval()

AutoEncoder(
  (encoder): Sequential(
    (0): Linear(in_features=105, out_features=80, bias=True)
    (1): Tanh()
    (2): Linear(in_features=80, out_features=60, bias=True)
    (3): Tanh()
    (4): Linear(in_features=60, out_features=40, bias=True)
    (5): Tanh()
    (6): Linear(in_features=40, out_features=20, bias=True)
  )
  (decoder): Sequential(
    (0): Linear(in_features=20, out_features=40, bias=True)
    (1): Tanh()
    (2): Linear(in_features=40, out_features=60, bias=True)
    (3): Tanh()
    (4): Linear(in_features=60, out_features=80, bias=True)
    (5): Tanh()
    (6): Linear(in_features=80, out_features=105, bias=True)
  )
)

In [12]:
# Load the dataset
df_original = load_data(STATION_ID)
ss = load_scaler(STATION_ID)

# Prepare the sample
x = df_original.drop(['id', 'effective_date', 'hour', 'day_num'], axis=1)
x = scale_data(ss, x)

# Convert to tensor
x = torch.from_numpy(x).float()
x.shape

torch.Size([100, 105])

In [13]:
# Encode the values
logits = model.encoder(x).clone().detach()

# Decode the values
decoded = model.decoder(logits).clone().detach()

# Undo the initial scaling transform
decoded = unscale_data(ss, decoded)

decoded.shape



(100, 105)

In [14]:
# Add the index back to the dataframe
output = pd.DataFrame(decoded, columns=df_original.columns[4:])
output.insert(0, 'id', df_original['id'] )
output.insert(1, 'effective_date', df_original['effective_date'] )
output.insert(2, 'hour', df_original['hour'] )
output.insert(3, 'day_num', df_original['day_num'] )
output.head()

Unnamed: 0,id,effective_date,hour,day_num,0_pres,0_temp,0_dp,0_u,0_v,1_pres,...,19_pres,19_temp,19_dp,19_u,19_v,20_pres,20_temp,20_dp,20_u,20_v
0,USM00072649,2005-12-12,12,-0.97,977.3,-4.8,-7.6,0.0,-0.5,928.9,...,272.3,-49.7,-84.9,16.9,-38.6,252.5,-50.8,-85.2,17.3,-36.2
1,USM00072649,2001-11-04,12,-0.62,993.9,5.1,1.6,0.3,-2.8,946.6,...,294.6,-46.3,-59.0,17.0,-31.5,273.7,-50.1,-61.9,17.2,-32.7
2,USM00072649,2020-06-03,12,0.91,977.7,16.7,14.5,-0.1,-0.1,933.7,...,298.9,-38.9,-57.4,38.8,-1.3,278.3,-42.2,-61.0,40.3,-1.3
3,USM00072649,2011-05-19,12,0.75,980.5,12.2,10.1,0.0,0.1,935.8,...,292.4,-44.9,-57.1,14.0,3.2,271.7,-48.4,-60.6,15.9,4.0
4,USM00072649,2025-04-30,12,0.5,983.4,4.5,2.1,0.3,0.3,937.0,...,289.6,-43.0,-58.3,35.3,20.6,269.2,-45.9,-60.2,38.0,21.4


In [15]:
# Interleave the results into the source for visual comparison
interleave = pd.concat([df_original, output])
interleave = interleave.sort_index()
interleave.head(10)

Unnamed: 0,id,effective_date,hour,day_num,0_pres,0_temp,0_dp,0_u,0_v,1_pres,...,19_pres,19_temp,19_dp,19_u,19_v,20_pres,20_temp,20_dp,20_u,20_v
0,USM00072649,2005-12-12,12,-0.97,977.4,-7.3,-8.1,-0.0,-0.0,928.6,...,271.9,-49.9,-83.8,17.0,-43.3,252.0,-51.4,-84.8,19.2,-42.5
0,USM00072649,2005-12-12,12,-0.97,977.3,-4.8,-7.6,0.0,-0.5,928.9,...,272.3,-49.7,-84.9,16.9,-38.6,252.5,-50.8,-85.2,17.3,-36.2
1,USM00072649,2001-11-04,12,-0.62,993.9,5.1,1.6,0.3,-2.8,946.6,...,294.6,-46.3,-59.0,17.0,-31.5,273.7,-50.1,-61.9,17.2,-32.7
1,USM00072649,2001-11-04,12,-0.62,993.1,4.6,2.1,0.5,-2.6,946.3,...,294.3,-46.8,-59.2,17.9,-33.7,273.6,-51.4,-62.1,15.2,-33.3
2,USM00072649,2020-06-03,12,0.91,977.3,16.6,16.4,-0.0,-0.0,933.4,...,298.9,-38.3,-57.6,38.3,-2.5,278.1,-42.7,-58.9,39.8,0.7
2,USM00072649,2020-06-03,12,0.91,977.7,16.7,14.5,-0.1,-0.1,933.7,...,298.9,-38.9,-57.4,38.8,-1.3,278.3,-42.2,-61.0,40.3,-1.3
3,USM00072649,2011-05-19,12,0.75,980.5,12.2,10.1,0.0,0.1,935.8,...,292.4,-44.9,-57.1,14.0,3.2,271.7,-48.4,-60.6,15.9,4.0
3,USM00072649,2011-05-19,12,0.75,981.1,12.7,10.8,-0.0,-0.0,936.0,...,292.5,-45.3,-57.2,11.5,3.4,271.6,-49.0,-60.2,18.3,4.9
4,USM00072649,2025-04-30,12,0.5,982.6,2.3,0.8,0.0,-0.0,936.5,...,289.8,-43.2,-58.7,43.1,25.8,269.3,-46.9,-54.0,47.9,31.7
4,USM00072649,2025-04-30,12,0.5,983.4,4.5,2.1,0.3,0.3,937.0,...,289.6,-43.0,-58.3,35.3,20.6,269.2,-45.9,-60.2,38.0,21.4


In [16]:
interleave.to_csv(f'{ARTIFACTS_PATH}/ae_compare.csv')