In [1]:
import os
import sys

from ipywidgets import interactive, IntSlider
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import PercentFormatter
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import pandas as pd
import polars as pl
import sqlite3
import torch

In [2]:
sys.path.append('../..')

from endure.data.io import Reader
from endure.lsm.cost import EndureCost
from endure.lsm.types import Policy, System, LSMDesign, LSMBounds, Workload
from endure.ltune.util import LTuneEvalUtil
from endure.ltune.model import LTuneModelBuilder
from endure.lsm.solver import ClassicSolver, KLSMSolver

In [3]:
WL_COLUMNS = ['empty_reads', 'non_empty_reads', 'range_queries', 'writes']
SYS_COLUMNS = ['entry_size', 'selectivity', 'entries_per_page', 'num_elmement', 'bits_per_elem_max', 'read_write_asym']

# Database Connection

In [4]:
connection = sqlite3.connect("../../axe_data.db")
env_table = pl.read_database("SELECT * FROM environments;", connection)
run_table = pl.read_database("SELECT * FROM tunings;", connection)

In [37]:
env_table.to_dicts()[0]

{'env_id': 1,
 'empty_reads': 0.25,
 'non_empty_reads': 0.25,
 'range_queries': 0.25,
 'writes': 0.25,
 'entry_size': 8192,
 'selectivity': 4e-07,
 'entries_per_page': 4,
 'num_elmement': 1000000000,
 'bits_per_elem_max': 10.0,
 'read_write_asym': 1.0}

In [5]:
# env_table = env_table.with_columns(pl.lit(100000000).alias('num_elmement'))

# Learned Tuner

In [17]:
path = '/scratchNVMe/ndhuynh/other_data/models/ltune/klsm_100824_1324/'
config = Reader.read_config(os.path.join(path, 'endure.toml'))
design_type = getattr(Policy, config["lsm"]["design"])
bounds = LSMBounds(**config["lsm"]["bounds"])
cf = EndureCost(bounds.max_considered_levels)

In [20]:
model = LTuneModelBuilder(
    size_ratio_range=bounds.size_ratio_range,
    max_levels=bounds.max_considered_levels,
    **config["ltune"]["model"],
).build_model(design_type)
model_name = os.path.join(path, 'best.model')
model_name = os.path.join(path, 'checkpoints/epoch_10.checkpoint')
model_data = torch.load(model_name, weights_only=True)
model_params = model_data['model_state_dict'] if model_name.endswith('checkpoint') else model_data
status = model.load_state_dict(model_params)
model.eval()
status

<All keys matched successfully>

In [21]:
ltune_util = LTuneEvalUtil(config, model, design_type)

## Example of how to evaluate ltune

In [22]:
z0, z1, q, w = workload = ltune_util.gen._sample_workload(4)
system = ltune_util.gen._sample_system()
out = ltune_util.get_ltune_out(system, z0, z1, q, w)
design = ltune_util.convert_ltune_output(out)

## Populating Learned Tuner Table

In [23]:
cursor = connection.cursor()
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS learned_tunings (
        env_id INTEGER PRIMARY KEY AUTOINCREMENT,
        bits_per_elem REAL,
        size_ratio INTEGER,
        kap0 REAL, kap1 REAL, kap2 REAL, kap3 REAL, kap4 REAL,
        kap5 REAL, kap6 REAL, kap7 REAL, kap8 REAL, kap9 REAL,
        kap10 REAL, kap11 REAL, kap12 REAL, kap13 REAL, kap14 REAL,
        kap15 REAL, kap16 REAL, kap17 REAL, kap18 REAL, kap19 REAL,
        cost REAL,
        FOREIGN KEY (env_id) REFERENCES workloads(env_id)
    );
    """
)
cursor.close()
connection.commit()

In [24]:
cursor = connection.cursor()
environment_ids = cursor.execute("SELECT env_id FROM environments").fetchall()
for (env_id,) in environment_ids:
    env = env_table.filter(pl.col('env_id') == env_id)
    data = run_table.filter(pl.col('env_id') == env_id)
    wl = Workload(*env[WL_COLUMNS].rows()[0])
    system = System(*env[SYS_COLUMNS].rows()[0])
    with torch.no_grad():
        out = ltune_util.get_ltune_out(system, wl.z0, wl.z1, wl.q, wl.w, hard=True)
    ltune_design = ltune_util.convert_ltune_output(out)
    ltune_cost = cf.calc_cost(ltune_design, system, wl.z0, wl.z1, wl.q, wl.w)
    cursor.execute(
        """
        INSERT OR REPLACE INTO learned_tunings (
        env_id,
        bits_per_elem,
        size_ratio,
        kap0, kap1, kap2, kap3, kap4,
        kap5, kap6, kap7, kap8, kap9,
        kap10, kap11, kap12, kap13, kap14,
        kap15, kap16, kap17, kap18, kap19,
        cost
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ? ,? ,? ,? ,?, ?, ?)
        """,
        (env_id, ltune_design.h, int(ltune_design.T)) + tuple(ltune_design.K) + (ltune_cost,)
    )
cursor.close()
connection.commit()

In [25]:
ltune_table = pl.read_database("SELECT * FROM learned_tunings;", connection)
ltune_table

env_id,bits_per_elem,size_ratio,kap0,kap1,kap2,kap3,kap4,kap5,kap6,kap7,kap8,kap9,kap10,kap11,kap12,kap13,kap14,kap15,kap16,kap17,kap18,kap19,cost
i64,f64,i64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
1,0.578876,10,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,28.118148
2,5.881237,10,6.0,2.0,1.0,2.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.313381
3,7.180369,10,4.0,2.0,2.0,2.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,2.128365
4,0.022507,30,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,99.084018
5,0.860694,11,6.0,4.0,4.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,3.399901
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
11,0.082107,10,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,54.028529
12,0.283133,30,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,34.477474
13,4.244865,14,6.0,4.0,4.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,2.651698
14,0.277557,10,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,36.804842


# Optimizer

In [26]:
solver = ClassicSolver(bounds)
ksolver = KLSMSolver(bounds)

Create the monkey tunings table, this will hold tunings and their associated cost

In [27]:
cursor = connection.cursor()
cursor.execute(
    """
    CREATE TABLE IF NOT EXISTS monkey_tunings (
        env_id INTEGER PRIMARY KEY AUTOINCREMENT,
        bits_per_elem REAL,
        size_ratio REAL,
        is_leveling INTEGER,
        cost REAL,
        kcost REAL,
        FOREIGN KEY (env_id) REFERENCES workloads(env_id)
    );
    """
)
cursor.close()
connection.commit()

Finding all of the tunings, here we limit everything to be leveling or tiering

In [28]:
cursor = connection.cursor()
environment_ids = cursor.execute("SELECT env_id FROM environments").fetchall()
for (env_id,) in environment_ids:
    env = env_table.filter(pl.col('env_id') == env_id)
    data = run_table.filter(pl.col('env_id') == env_id)
    wl = Workload(*env[WL_COLUMNS].rows()[0])
    system = System(*env[SYS_COLUMNS].rows()[0])
    design, _ = solver.get_nominal_design(system, wl.z0, wl.z1, wl.q, wl.w)
    design.T = np.ceil(design.T)
    kdesign, _ = ksolver.get_nominal_design(system, wl.z0, wl.z1, wl.q, wl.w)
    kdesign.T = np.ceil(design.T)
    kdesign.K = [np.ceil(ki) for ki in kdesign.K]
    cost = cf.calc_cost(design, system, wl.z0, wl.z1, wl.q, wl.w)
    kcost = cf.calc_cost(kdesign, system, wl.z0, wl.z1, wl.q, wl.w)
    cursor.execute(
        """
        INSERT OR REPLACE INTO monkey_tunings (
            env_id,
            bits_per_elem,
            size_ratio,
            is_leveling,
            cost,
            kcost)
        VALUES (?, ?, ?, ?, ?, ?)
        """,
        (env_id, design.h, design.T, 1 if design.policy == Policy.Leveling else 0, cost, kcost)
    )
cursor.close()
connection.commit()

In [29]:
monkey_table = pl.read_database("SELECT * FROM monkey_tunings;", connection)
monkey_table

env_id,bits_per_elem,size_ratio,is_leveling,cost,kcost
i64,f64,f64,i64,f64,f64
1,3.211367,6.0,1,27.830134,28.141494
2,8.778876,7.0,1,1.157607,1.171342
3,6.72313,7.0,1,2.08895,2.095246
4,1.0,30.0,1,99.111233,100.044999
5,3.311264,20.0,0,2.682739,2.670772
…,…,…,…,…,…
11,1.0,6.0,1,53.681294,54.396088
12,4.514403,30.0,1,34.245665,34.539509
13,8.006566,9.0,0,2.378837,2.424117
14,3.005914,6.0,1,36.39797,36.792728
