In [1]:
import sys
import os
import torch
import toml

In [2]:
sys.path.append(os.path.join(sys.path[0], '../..'))

from endure.lsm.cost import EndureCost
from endure.lsm.types import Policy, System, LSMDesign, LSMBounds
from endure.ltune.util import LTuneEvalUtil
from endure.ltune.data.generator import LTuneDataGenerator
from endure.ltune.model import LTuneModelBuilder

In [3]:
config = toml.load('../../endure.toml')

In [4]:
config['lsm']['bounds']['size_ratio_range'] = [2, 10]
config['lsm']['bounds']['max_considered_levels'] = 10

In [5]:
cf = EndureCost(config)
design_type = getattr(Policy, config["lsm"]["design"])
bounds = LSMBounds(**config["lsm"]["bounds"])
gen = LTuneDataGenerator(bounds)
design_type, bounds

(<Policy.KHybrid: 3>,
 LSMBounds(max_considered_levels=10, bits_per_elem_range=(1, 10), size_ratio_range=[2, 10], page_sizes=[4, 8, 16], entry_sizes=[1024, 2048, 4096, 8192], memory_budget_range=[5, 20], selectivity_range=[1e-07, 1e-09], elements_range=[100000000, 1000000000]))

In [6]:
model = LTuneModelBuilder(
    size_ratio_range=bounds.size_ratio_range,
    max_levels=bounds.max_considered_levels,
    **config["ltune"]["model"],
).build_model(design_type)
model

KapLSMTuner(
  (in_norm): BatchNorm1d(9, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (in_layer): Linear(in_features=9, out_features=64, bias=True)
  (relu): ReLU(inplace=True)
  (dropout): Dropout(p=0, inplace=False)
  (hidden): Sequential(
    (0): Linear(in_features=64, out_features=64, bias=True)
  )
  (k_path): Linear(in_features=64, out_features=64, bias=True)
  (t_path): Linear(in_features=64, out_features=64, bias=True)
  (bits_path): Linear(in_features=64, out_features=64, bias=True)
  (k_decision): KapDecision(
    (decision_layers): ModuleList(
      (0-9): 10 x Linear(in_features=64, out_features=8, bias=True)
    )
  )
  (t_decision): Linear(in_features=64, out_features=8, bias=True)
  (bits_decision): Linear(in_features=64, out_features=1, bias=True)
)

In [21]:
batch = [gen.generate_row_csv() for _ in range(2)]
x = torch.Tensor(batch)
x

tensor([[2.1700e-01, 1.8700e-01, 3.1100e-01, 2.8500e-01, 6.4000e+01, 6.7556e-08,
         2.0480e+03, 1.0666e+01, 3.7697e+08],
        [3.2700e-01, 3.2000e-02, 2.3100e-01, 4.1000e-01, 6.4000e+01, 9.3230e-08,
         2.0480e+03, 1.0824e+01, 4.3409e+08]])

In [22]:
model(x, temp=1, hard=False)

tensor([[-0.0121,  0.0274,  0.0618,  0.7062,  0.0358,  0.0604,  0.0706,  0.0146,
          0.0231,  0.2362,  0.2335,  0.0576,  0.0611,  0.0700,  0.0586,  0.1042,
          0.1786,  0.0514,  0.1569,  0.0623,  0.0199,  0.1548,  0.3299,  0.0292,
          0.1957,  0.1307,  0.1213,  0.0362,  0.1206,  0.1541,  0.1705,  0.0231,
          0.2434,  0.0829,  0.0214,  0.0260,  0.3989,  0.1421,  0.0600,  0.2338,
          0.0350,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  1.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000],
        [ 0.1076,  0.0234,  0.0235,  0.0481,  0.0273,  0.0844,  0.5703,  0.0538,
         

In [23]:
out = model.in_norm(x)
out = model.in_layer(out)
out = model.relu(out)
out = model.dropout(out)
out = model.hidden(out)

bits_out = model.bits_path(out)
bits = model.bits_decision(bits_out)

t_out = model.t_path(out)
t = model.t_decision(t_out)
t = torch.nn.functional.gumbel_softmax(t, tau=1, hard=False)

k_out = model.k_path(out)
k = model.k_decision(k_out, temp=1, hard=False)

max_levels = model.calc_max_level(x, bits, t)
max_levels = max_levels - 1
max_levels = max_levels.to(torch.long)

mask = torch.nn.functional.one_hot(max_levels, num_classes=model.num_kap)
cum_sum = torch.cumsum(mask, dim=1)
mask = 1 - cum_sum + mask
default = torch.zeros(model.capacity_range)
default[0] = 1
k = mask.unsqueeze(-1) * k
k[mask == 0] += default
k = torch.flatten(k, start_dim=1)
out = torch.concat([bits, t, k], dim=-1

tensor([[-4.3348e-02, -6.3075e-03,  1.0505e-03,  4.6666e-01,  1.8032e-03,
         -2.3303e-01,  1.3103e-01, -5.7893e-01,  4.5307e-01,  7.0589e-02,
         -1.6229e-01, -1.2630e-01,  3.0667e-01, -4.0167e-01, -4.0689e-02,
         -6.1250e-02,  1.0734e-01,  1.3212e-02, -1.1114e-01, -1.9225e-01,
          1.6383e-01, -1.9113e-02, -1.4412e-02, -4.1255e-02,  2.3014e-02,
         -3.4731e-01,  2.3220e-01, -1.8513e-01,  1.7293e-01, -1.5584e-01,
         -6.1889e-02, -1.3527e-01,  4.0011e-01, -2.8953e-01, -2.5096e-01,
          2.6681e-02, -3.3066e-01,  3.8333e-02,  9.1707e-02, -1.5441e-01,
         -2.6939e-01,  2.8793e-01, -3.1192e-01, -2.2801e-02,  2.7427e-01,
          3.6499e-01, -3.6190e-01,  1.2182e-02,  2.7682e-01,  1.2048e-01,
          8.5851e-02,  4.5942e-02, -1.4887e-01, -3.6888e-02, -5.4908e-02,
          2.6361e-01, -3.5416e-01,  4.2265e-01, -3.8471e-01, -5.0420e-02,
          1.7632e-01, -5.4793e-01,  6.8953e-02, -1.1594e-01],
        [-6.3974e-02,  2.7086e-01, -1.9346e-01, -1