In [27]:
import sys, os, toml
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, HTML
from itertools import combinations_with_replacement
from torch.utils.data import Dataset

sys.path.append(os.path.join(sys.path[0], '../..'))

from data.kcost_dataset import KCostDataSet
from data.io import Writer, Reader
from lsm.cost import EndureQFixedCost, EndureTierLevelCost, EndureKHybridCost
from lsm.lsmtype import LSMTree, LSMSystem, Policy

In [2]:
def wl_to_array(wl_dict):
    """Extract the workload fields of ``wl_dict`` in a fixed order.

    Args:
        wl_dict: Mapping that provides the keys ``'id'``, ``'z0'``, ``'z1'``,
            ``'q'`` and ``'w'``. Any other keys are ignored.

    Returns:
        A tuple ``(id, z0, z1, q, w)`` holding the values stored under those
        keys.

    Raises:
        KeyError: If one of the five keys is missing from ``wl_dict``.
    """
    field_order = ('id', 'z0', 'z1', 'q', 'w')
    return tuple(wl_dict[field] for field in field_order)

# Load the shared experiment settings from config/endure.toml, located two
# directories above sys.path[0], then build the project's reader and writer
# helpers from that same configuration.
config = Reader.read_config(
    os.path.join(sys.path[0], '../..', 'config', 'endure.toml')
)
reader, writer = Reader(config), Writer(config)

# System description built from the 'system' section of the configuration.
system_vars = LSMSystem(**config['system'])

In [23]:
class KCostDataSetSplit(Dataset):
    """Map-style dataset built from one or more feather files of cost samples.

    Every input file must provide the continuous columns ``h, z0, z1, q, w``,
    the categorical columns ``T`` and ``K_0 .. K_{MAX_LEVELS - 1}``, and the
    target column ``new_cost``.

    Continuous columns are z-score normalised with the statistics of the
    concatenated data; the ``(mean, std)`` pair that was used is kept in
    ``normalize_vars``. Categorical columns are one-hot encoded lazily in
    ``__getitem__`` with ``NUM_CLASSES`` classes each.
    """

    MAX_LEVELS = 16
    # Width of each one-hot vector. Every value in ``T`` / ``K_i`` must lie in
    # ``[0, NUM_CLASSES)``, otherwise ``one_hot`` raises when the row is read.
    NUM_CLASSES = 50

    def __init__(self, paths: list[str], transform=None, target_transform=None):
        """Read, concatenate and normalise the feather files in ``paths``.

        Args:
            paths: Feather files to load; they are concatenated in order.
            transform: Optional callable applied to the assembled input
                tensor in ``__getitem__``.
            target_transform: Optional callable applied to the label tensor
                in ``__getitem__``.

        Raises:
            ValueError: If ``paths`` is empty (raised by ``pd.concat``).
            KeyError: If a required column is missing from the data.
        """
        frames = []
        for path in paths:
            print(f'Reading in {path}')
            frames.append(pd.read_feather(path))
        df = pd.concat(frames)

        cont_inputs = ['h', 'z0', 'z1', 'q', 'w']
        cate_inputs = ['T'] + [f'K_{i}' for i in range(self.MAX_LEVELS)]
        output_cols = ['new_cost']

        mean = df[cont_inputs].mean()
        std = df[cont_inputs].std()
        # A constant column has std == 0 and a single-row dataset has
        # std == NaN (ddof=1). Scale by 1 in both cases so the normalised
        # column becomes 0 instead of inf / NaN.
        std = std.where(std > 0, 1.0)
        normalized = (df[cont_inputs] - mean) / std
        self.normalize_vars = (mean, std)

        self.cont_inputs = torch.from_numpy(normalized.values).float()
        print('Normalized continous vars')
        self.cate_inputs = torch.from_numpy(df[cate_inputs].values).to(torch.int64)

        self.outputs = torch.from_numpy(df[output_cols].values).float()

        # Keep the optional hooks so __getitem__ can honour them; previously
        # both arguments were accepted but silently ignored.
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows loaded across all input files."""
        return len(self.cont_inputs)

    def __getitem__(self, idx):
        """Return the ``(inputs, label)`` pair for ``idx``.

        ``inputs`` is the 5 normalised continuous values followed by the
        flattened one-hot encoding of the ``1 + MAX_LEVELS`` categorical
        columns. ``idx`` may be a single index or anything that selects
        several rows (e.g. a slice), in which case a leading batch dimension
        is preserved.
        """
        one_hot = nn.functional.one_hot(
            self.cate_inputs[idx], num_classes=self.NUM_CLASSES
        )
        # Merge the (column, class) axes into one feature axis.
        categories = torch.flatten(one_hot, start_dim=-2)
        inputs = torch.cat((self.cont_inputs[idx], categories), dim=-1)
        label = self.outputs[idx]

        if self.transform is not None:
            inputs = self.transform(inputs)
        if self.target_transform is not None:
            label = self.target_transform(label)

        return inputs, label

In [24]:
%%time

base_path = '/scratchNVM0/ndhuynh/data/cost_k_feather'
paths = [os.path.join(base_path, f'k_wl_{i}.feather') for i in range(15)]
paths

data = KCostDataSetSplit(paths)

Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_0.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_1.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_2.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_3.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_4.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_5.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_6.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_7.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_8.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_9.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_10.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_11.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_12.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_13.feather
Reading in /scratchNVM0/ndhuynh/data/cost_k_feather/k_wl_14.feather

In [5]:
def create_k_levels(levels, max_size_ratio):
    """Lazily enumerate non-increasing tuples of length ``levels``.

    Args:
        levels: Length of every tuple that is produced.
        max_size_ratio: Largest value that may appear in a tuple; values are
            drawn from ``max_size_ratio, max_size_ratio - 1, ..., 1``.

    Returns:
        An iterator over all combinations with replacement of those values,
        starting with ``(max_size_ratio,) * levels``.
    """
    descending_values = range(max_size_ratio, 0, -1)
    return combinations_with_replacement(descending_values, levels)

# Take the first combination for 5 levels with a maximum value of 3, then
# right-pad it with zeros to a fixed length of 16 entries.
arr = next(iter(create_k_levels(5, 3)))
arr = np.pad(arr, pad_width=(0, 16 - len(arr)))
arr

array([3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])

In [31]:
# Number of samples in the dataset: one per row read from the feather files.
len(data)

22056240