In [None]:
import os
import pandas as pd
import pickle
import torch
import glob

from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import PowerTransformer
from torch.utils.data import DataLoader
from torch.nn import Module, Sequential, Linear, Tanh, MSELoss

# Directory holding the per-station input CSVs; files are read as
# '<IGRA_PATH>/<station_id>-data-gph20s10k.csv'.
IGRA_PATH = '/usr/datalake/silver/igra/gph20s10k'
# Path to a station list CSV.
# NOTE(review): not referenced by any cell visible here - confirm it is still needed.
STATION_LIST = '/usr/datalake/silver/igra/doc/igra2-station-list.csv'
# Output directory for per-station artifacts: the pickled scaler
# ('<station_id>_scaler.pkl') and the model weights ('<station_id>_fnn.pt').
ARTIFACTS_PATH = '/usr/datalake/silver/stormevents/artifacts/igra_storm_event_autoencoder'

In [None]:
class AutoEncoder(Module):
    """Symmetric fully connected autoencoder with Tanh activations.

    The encoder narrows ``n_features`` through ``hidden_sizes`` down to
    ``latent_size``; the decoder mirrors those widths back up to
    ``n_features``. No activation follows the last layer of either half.

    The defaults reproduce the fixed 105 -> 80 -> 60 -> 40 -> 20 layout, and
    the layers sit at the same ``Sequential`` indices, so ``state_dict`` keys
    (``encoder.0.weight``, ``encoder.2.weight``, ...) are unchanged.

    Args:
        n_features: Width of the input (and of the reconstruction).
        hidden_sizes: Widths of the encoder's hidden layers, outermost first.
        latent_size: Width of the bottleneck produced by the encoder.
    """

    def __init__(self, n_features: int = 105, hidden_sizes=(80, 60, 40), latent_size: int = 20):
        super().__init__()

        encoder_widths = [n_features, *hidden_sizes, latent_size]
        # The decoder is the exact mirror image of the encoder.
        decoder_widths = encoder_widths[::-1]

        # Encoder is built before the decoder so parameter initialisation
        # consumes the RNG in the same order as a hand-written layer list.
        self.encoder = Sequential(*self._dense_stack(encoder_widths))
        self.decoder = Sequential(*self._dense_stack(decoder_widths))

    @staticmethod
    def _dense_stack(widths):
        """Return Linear layers joining consecutive ``widths``, with Tanh between them."""
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(Linear(fan_in, fan_out))
            layers.append(Tanh())

        # Drop the trailing Tanh: the final layer of each half stays linear.
        return layers[:-1]

    def forward(self, x):
        """Encode ``x`` to the latent space and return its reconstruction."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return decoded

In [None]:
class olie_igra_trainer:
    """Trains a reconstruction model on the feature CSV of a single station.

    Typical call order: ``load_transform_dataset`` -> ``train_orch`` ->
    ``save_weights`` -> ``dispose``.
    """

    batch_size = 256              # rows per optimisation step
    epochs = 1024                 # full passes over the training split
    learning_rate = 0.005         # initial Adam learning rate
    learning_rate_gamma = 0.99    # per-epoch decay factor passed to ExponentialLR
    lambda_param = .0001          # not read by any method of this class; kept for compatibility

    def __init__(self, igra_path: str, artifact_path: str, station_id: str, model):
        """Store the paths, station id and model; performs no I/O.

        Args:
            igra_path: Directory containing '<station_id>-data-gph20s10k.csv'.
            artifact_path: Directory the scaler and weights are written to.
            station_id: Station identifier used to build file names.
            model: The torch module to train.
        """
        self.igra_path = igra_path
        self.station_id = station_id
        self.artifact_path = artifact_path
        self.model = model

    def _device(self):
        """Return the device of the model's parameters (CPU if it has none)."""
        first_param = next(iter(self.model.parameters()), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def _weights_path(self) -> str:
        """Return the file path used for this station's saved weights."""
        return f'{self.artifact_path}/{self.station_id}_fnn.pt'

    def load_transform_dataset(self):
        """Load the station CSV, scale it, persist the scaler and split it.

        Sets ``x_train``, ``x_test`` and ``n_batches`` on success.

        Returns:
            ``False`` when no rows remain after filtering, ``True`` otherwise.
        """
        frame = pd.read_csv(f'{self.igra_path}/{self.station_id}-data-gph20s10k.csv')

        # Remove irrelevant data: keep rows whose hour is 12, then drop the
        # bookkeeping columns and the '0_gph' .. '20_gph' columns.
        frame = frame[frame['hour'] == 12]
        dropped = ['id', 'effective_date', 'hour', 'day_num']
        dropped += [f'{level}_gph' for level in range(21)]
        frame = frame.drop(dropped, axis=1)
        if frame.shape[0] == 0:
            return False

        # Scale the dataset
        scaler = PowerTransformer()
        scaled = scaler.fit_transform(frame)

        # Save the fitted transform next to the other artifacts
        os.makedirs(self.artifact_path, exist_ok=True)
        with open(f'{self.artifact_path}/{self.station_id}_scaler.pkl', 'wb') as f:
            pickle.dump(scaler, f)

        train, test = train_test_split(scaled, test_size=0.2)

        # Place the data on whatever device the model lives on, so model and
        # inputs can never end up on different devices.
        device = self._device()
        self.x_train = torch.from_numpy(train).float().to(device)
        self.x_test = torch.from_numpy(test).float().to(device)

        # Ceiling division: the DataLoader also yields the final partial
        # batch, and floor division gives 0 for fewer than batch_size rows.
        self.n_batches = -(-self.x_train.size()[0] // self.batch_size)

        print(f"Station ID: {self.station_id}, Training size: {self.x_train.size()[0]:,}, Predict size: {self.x_test.size()[0]:,}, Feature count: {self.x_train.size()[1]}, Number of batches: {self.n_batches}")

        return True

    def train(self, inputs, labels) -> float:
        """Run one optimisation step and return the batch loss as a float."""
        self.optimizer.zero_grad()

        # Calculate error
        logits = self.model(inputs)
        cost = self.loss_function(logits, labels)

        # Back propagation
        cost.backward()
        self.optimizer.step()

        return float(cost.item())

    def predict(self, inputs):
        """Return the model output for ``inputs`` as a detached copy.

        Runs under ``torch.no_grad()`` so no autograd graph is built, and does
        not need an optimizer to exist.
        """
        with torch.no_grad():
            logits = self.model(inputs).clone().detach()

        return logits

    def r2_score_manual(self, preds, target):
        """Return the coefficient of determination of ``preds`` against ``target``.

        The mean is taken over every element of ``target``. If ``target`` has
        zero variance the result is non-finite.
        """
        target_mean = torch.mean(target)
        ss_tot = torch.sum((target - target_mean) ** 2)  # Total sum of squares
        ss_res = torch.sum((target - preds) ** 2)        # Residual sum of squares
        r2 = 1 - (ss_res / ss_tot)

        return float(r2.item())

    def output_progress(self, epoch: int, cost: float):
        """Print the mean batch cost, the test-split R^2 and the learning rate."""
        preds = self.predict(self.x_test)
        # Argument order matters: the held-out data is the target whose
        # variance normalises the score, not the reconstruction.
        acc = self.r2_score_manual(preds, self.x_test)
        # Guard against an empty training split.
        n_batches = max(self.n_batches, 1)
        print(f"Epoch: {epoch+1}, cost: {cost / n_batches:.4f}, acc: {acc:.3f}, lr: {self.scheduler.get_last_lr()[0]:.2e}\r", end="")

    def train_orch(self):
        """Create optimizer, loss and scheduler, then train for ``epochs`` epochs.

        The input is also the target. Progress is printed every 32 epochs and
        once more after the final epoch.
        """
        self.optimizer = torch.optim.Adam(self.model.parameters(), self.learning_rate)
        self.loss_function = MSELoss()
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=self.learning_rate_gamma)

        # One loader is enough: with shuffle=True it draws a fresh permutation
        # every time it is iterated.
        loader = DataLoader(dataset=self.x_train, batch_size=self.batch_size, shuffle=True)
        # Keep the divisor used for the mean cost equal to the real batch count.
        self.n_batches = len(loader)

        epoch, cost = -1, 0.0
        for epoch in range(self.epochs):
            cost = 0.0

            for batch in loader:
                cost += self.train(batch, batch)

            self.scheduler.step()

            if epoch % 32 == 0:
                self.output_progress(epoch, cost)

        # Nothing to report when no epoch ran.
        if epoch >= 0:
            self.output_progress(epoch, cost)
        print()

    def save_weights(self):
        """Write the model's ``state_dict`` to the station's weights file."""
        os.makedirs(self.artifact_path, exist_ok=True)
        torch.save(self.model.state_dict(), self._weights_path())

    def exists_weights(self):
        """Return whether a weights file already exists for this station."""
        return os.path.exists(self._weights_path())

    def dispose(self):
        """Drop references to tensors, training objects and the model.

        Safe to call at any point, including before loading or training and
        more than once.
        """
        for name in ('x_train', 'x_test', 'optimizer', 'loss_function', 'scheduler', 'model'):
            vars(self).pop(name, None)

In [None]:
def process_station(station_id: str):
    """Train and save an autoencoder for one station.

    Stations that already have a weights file, or whose dataset has no usable
    rows, are reported and skipped.

    Args:
        station_id: Station identifier used to locate the input CSV and to
            name the output artifacts.
    """
    # Build the model on the CPU; it is only moved to the GPU once we know
    # the station still needs training.
    model = AutoEncoder()
    trainer = olie_igra_trainer(IGRA_PATH, ARTIFACTS_PATH, station_id, model)

    if trainer.exists_weights():
        print(f"Station {station_id} already processed")
        return

    # nn.Module.cuda() moves the parameters in place, so `trainer.model`
    # refers to the GPU-resident model from here on.
    model.cuda()

    if not trainer.load_transform_dataset():
        print(f"Station {station_id} has zero usable rows")
        return

    try:
        trainer.train_orch()
        trainer.save_weights()
    finally:
        # Release tensors and training objects even when training or saving
        # fails, so a failed station does not keep GPU memory alive.
        trainer.dispose()

In [None]:
# Suffix shared by every per-station input file: '<station_id>-data-gph20s10k.csv'.
STATION_FILE_SUFFIX = '-data-gph20s10k.csv'

# Sorted so stations are processed in the same order on every run;
# glob itself returns files in arbitrary order.
for filepath in sorted(glob.glob(f'{IGRA_PATH}/*{STATION_FILE_SUFFIX}')):
    # Strip the known suffix instead of splitting on '-', so the id is the
    # exact inverse of the file-name template even if it contains a hyphen.
    station_id = Path(filepath).name[:-len(STATION_FILE_SUFFIX)]

    process_station(station_id)