In [1]:
import os
import pickle
import pandas as pd
import torch

from torch.nn import Module, Sequential, Linear, Tanh

ARTIFACTS_PATH = '/usr/datalake/silver/stormevents/artifacts/igra_storm_event_autoencoder'
STATION_LIST = '/usr/datalake/silver/igra/doc/igra2-station-list.csv'
IGRA_SOURCE = '/usr/datalake/silver/igra/gph20s10k'
IGRA_DESTINATION = '/usr/datalake/silver/stormevents/igra_encoded'

In [2]:
def load_data(station_id: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read one station's CSV and split it into an index frame and a data frame.

    The file is read from ``{IGRA_SOURCE}/{station_id}-data-gph20s10k.csv``.
    Only rows whose ``hour`` column equals 12 are kept. The ``id`` and ``hour``
    columns and the ``0_gph`` through ``20_gph`` columns are dropped, and the
    row labels are renumbered from zero.

    Args:
        station_id: Station identifier used to build the file name.

    Returns:
        ``(index, data)`` where ``index`` holds the first two remaining
        columns and ``data`` holds every column after them.
    """
    source_file = f'{IGRA_SOURCE}/{station_id}-data-gph20s10k.csv'
    frame = pd.read_csv(source_file)

    # Columns that are not passed on: the identifiers plus every *_gph column.
    unused_columns = ['id', 'hour'] + [f'{level}_gph' for level in range(21)]

    kept = (
        frame[frame['hour'] == 12]
        .drop(unused_columns, axis=1)
        .reset_index(drop=True)
    )

    return kept.iloc[:, :2], kept.iloc[:, 2:]

def load_scaler(station_id: str):
    """Unpickle and return the scaler stored for ``station_id``.

    The object is read from ``{ARTIFACTS_PATH}/{station_id}_scaler.pkl``.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file, so this must only be pointed at artifacts from a trusted source.
    with open(f'{ARTIFACTS_PATH}/{station_id}_scaler.pkl', 'rb') as scaler_file:
        return pickle.load(scaler_file)

def scale_data(scaler, sample: pd.DataFrame):
    """Apply ``scaler.transform`` to ``sample`` and return a float32 tensor.

    Args:
        scaler: Any object exposing ``transform(sample)`` that returns a
            NumPy array.
        sample: The frame to transform.

    Returns:
        A ``torch.Tensor`` built from the transformed array and cast with
        ``.float()``.
    """
    scaled_array = scaler.transform(sample)
    return torch.from_numpy(scaled_array).float()

def load_model(station_id: str):
    """Build an ``AutoEncoder`` and restore the weights saved for ``station_id``.

    The state dict is read from
    ``{ARTIFACTS_PATH}/{station_id}_autoencoder.pt`` and the model is switched
    to evaluation mode before it is returned.

    Args:
        station_id: Station identifier used to build the checkpoint file name.

    Returns:
        The ``AutoEncoder`` instance with loaded weights, in eval mode.
    """
    model = AutoEncoder()

    # Force tensors onto the CPU while deserializing. Without map_location a
    # checkpoint written from a CUDA device cannot be loaded on a host that
    # has no GPU, and the encoding cell converts the encoder output with
    # .numpy(), which needs CPU tensors anyway.
    # NOTE(review): torch.load is pickle-based; only load trusted checkpoints.
    state_dict = torch.load(
        f'{ARTIFACTS_PATH}/{station_id}_autoencoder.pt',
        map_location='cpu',
    )
    model.load_state_dict(state_dict)
    model.eval()

    return model

class AutoEncoder(Module):
    """Fully connected autoencoder with Tanh activations.

    The encoder maps 105 inputs through 80, 60 and 40 units down to 20
    outputs; the decoder mirrors those widths back up to 105. A Tanh sits
    between consecutive Linear layers, and the last Linear of each stack has
    no activation after it.
    """

    # Layer widths from the input to the bottleneck; the decoder reverses them.
    _WIDTHS = (105, 80, 60, 40, 20)

    def __init__(self):
        super().__init__()

        # The encoder is built first so parameter creation order is unchanged.
        self.encoder = self._stack(self._WIDTHS)
        self.decoder = self._stack(self._WIDTHS[::-1])

    @staticmethod
    def _stack(widths):
        """Return a Sequential of Linear layers separated by Tanh."""
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            if layers:
                # Activation goes between Linear layers, never after the last.
                layers.append(Tanh())
            layers.append(Linear(fan_in, fan_out))
        return Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        return self.decoder(self.encoder(x))

In [3]:
# Create the destination directory for the encoded CSV files; exist_ok=True
# makes this a no-op when the directory is already there.
os.makedirs(IGRA_DESTINATION, exist_ok=True)

In [4]:
# Load the station list stored with the model artifacts; the encoding cell
# iterates over its 'id' column to pick the stations to process.
df_brains = pd.read_csv(f'{ARTIFACTS_PATH}/brains_station_list.csv')
df_brains.head()

Unnamed: 0,id,latitude,longitude,elevation,state,name,fst_year,lst_year,nobs
0,BBM00078954,13.0716,-59.4922,56.6,,GRANTLEY ADAMS,1965,2025,31817
1,BHM00078583,17.5333,-88.3,5.0,,BELIZE/PHILLIP GOLDSTON INTL.,1980,2025,21481
2,CJM00078384,19.2944,-81.3632,3.0,,OWEN ROBERTS AIRPORT GRAND CAY,1956,2025,37597
3,COM00080001,12.5833,-81.7167,1.0,,SAN ANDRES (ISLA)/SESQUICENTEN,1956,2025,27604
4,DRM00078486,18.4734,-69.8705,14.0,,SANTO DOMINGO (78486-0),1962,2025,28418


In [5]:
# Encode every listed station with its own scaler and autoencoder, and write
# one CSV per station to IGRA_DESTINATION.
# Only the 'id' column is needed, so iterate over it directly instead of
# materializing a Series per row with iterrows().
for station_id in df_brains['id']:
    print(station_id)

    scaler = load_scaler(station_id)
    model = load_model(station_id)

    df_index, df_data = load_data(station_id)
    data = scale_data(scaler, df_data)

    # Encode the values. Gradients are never used here, so disable autograd
    # for the forward pass instead of building a graph and detaching a copy.
    with torch.no_grad():
        logits = model.encoder(data).numpy()

    # Add the index back to the dataframe. Both frames carry a fresh 0..n-1
    # row index, so the column-wise concat lines up row for row.
    output = pd.DataFrame(logits)
    output = pd.concat([df_index, output], axis=1)

    output.to_csv(f'{IGRA_DESTINATION}/{station_id}_igra_encoded.csv', index=False)

BBM00078954
BHM00078583
CJM00078384
COM00080001
DRM00078486
JMM00078397
NNM00078866
RQM00078526
TDM00078970
UCM00078988
USM00070026
USM00070133
USM00070200
USM00070219
USM00070231
USM00070261
USM00070273
USM00070308
USM00070316
USM00070326
USM00070350
USM00070361
USM00070398
USM00072201
USM00072202
USM00072206
USM00072208
USM00072210
USM00072215
USM00072230
USM00072233
USM00072235
USM00072240
USM00072248
USM00072249
USM00072250
USM00072251
USM00072261
USM00072265
USM00072274
USM00072293
USM00072305
USM00072317
USM00072318
USM00072327
USM00072340
USM00072357
USM00072363
USM00072364
USM00072365
USM00072376
USM00072381
USM00072388
USM00072402
USM00072403
USM00072426
USM00072440
USM00072451
USM00072456
USM00072476
USM00072489
USM00072493
USM00072501
USM00072518
USM00072520
USM00072528
USM00072558
USM00072562
USM00072572
USM00072582
USM00072597
USM00072632
USM00072634
USM00072645
USM00072649
USM00072659
USM00072662
USM00072672
USM00072681
USM00072694
USM00072712
USM00072747
USM00072764
USM0

In [6]:
# Preview the encoded frame left in `output`, i.e. the one written for the
# last station the encoding loop processed.
output.head()

Unnamed: 0,effective_date,day_num,0,1,2,3,4,5,6,7,...,10,11,12,13,14,15,16,17,18,19
0,1993-01-02,-1.0,0.073122,-0.074968,-0.007362,-0.223213,-0.092808,-0.039658,0.043359,-0.118171,...,0.162569,0.276106,0.240923,0.308072,-0.064227,-0.124865,0.003078,0.295161,-0.269146,0.065332
1,1993-01-04,-1.0,0.262698,-0.106609,0.011647,-0.051817,-0.203466,-0.073859,-0.1201,-0.173953,...,0.08515,-0.033903,0.184961,0.194819,-0.270973,0.038257,0.034263,0.257189,0.012582,-0.077878
2,1993-01-06,-0.99,0.121747,-0.141006,-0.124942,0.082351,-0.055091,-0.1522,0.02396,-0.092244,...,-0.156705,0.160974,0.507969,0.153353,-0.049423,-0.045559,-0.017898,0.340577,-0.168742,0.126351
3,1993-01-08,-0.99,-0.012959,-0.033137,0.051791,0.060087,-0.190153,-0.252145,0.042759,-0.099491,...,0.154783,0.072165,0.163632,0.240157,-0.124703,-0.102542,-0.024001,0.156921,-0.082751,-0.112238
4,1993-01-09,-0.99,0.128621,0.028059,0.032118,-0.128122,-0.092963,0.05922,0.03105,-0.10768,...,0.084866,0.216784,0.187969,0.173121,-0.098716,-0.195191,0.139028,0.068395,-0.19544,0.045995
