Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 29 additions & 5 deletions run_stl.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,34 @@
set -ex

# Experiment parameters shared by every server invocation below.
num_keys=20
workers=1
slide_size=288
window_size=864
source_sleep_per_batch=0.01

# make sure update_throughput < event_throughput
# update throughput estimate: workers * 2 (printed for manual comparison).
let update_throughput="$workers*2"
echo $update_throughput
# event throughput estimate: keys arriving per window slide.
echo "scale=2 ; $num_keys/ ($source_sleep_per_batch*$slide_size)" | bc

# Cumulative-error scheduler run.
# NOTE(review): removed the stale pre-change argument lines left over from the
# diff (duplicate --window_size/--slide_size/--workers/--num_keys flags) and
# added the missing space before each line-continuation backslash so arguments
# do not concatenate.
python workloads/stl/stl_server.py \
    --scheduler=ce \
    --window_size=${window_size} \
    --slide_size=${slide_size} \
    --workers=${workers} \
    --azure_database /home/ubuntu/cleaned_sqlite_3_days_min_ts.db \
    --num_keys=${num_keys} \
    --source_sleep_per_batch ${source_sleep_per_batch}

# Round-robin scheduler run with identical parameters for comparison.
python workloads/stl/stl_server.py \
    --scheduler=rr \
    --window_size=${window_size} \
    --slide_size=${slide_size} \
    --workers=${workers} \
    --azure_database /home/ubuntu/cleaned_sqlite_3_days_min_ts.db \
    --num_keys=${num_keys} \
    --source_sleep_per_batch ${source_sleep_per_batch}




77 changes: 70 additions & 7 deletions workloads/stl/stl_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,20 @@
required=False,
)

# Optional error floor for the cumulative-error ("ce") scheduler: when set,
# each computed update error is clamped to at least this value before being
# added to the key's priority (see CumulativeErrorScheduler).
flags.DEFINE_float(
    "epsilon",
    default=None,
    help="Default error to add to each time a query is made",
    required=False,
)

KeyType = int


class BasePriorityScheduler(BaseScheduler):
def __init__(self):
self.key_to_event: Dict[KeyType, Record] = dict()
# set initial priority scores to infinity
self.key_to_priority: Dict[KeyType, float] = dict()
self.sorted_keys_by_timestamp = SortedSet(
key=lambda key: self.key_to_priority[key]
Expand All @@ -87,6 +95,8 @@ def __init__(self):
self.stop_iteration = None
self._writer_lock = None

self.max_prio = 100000000000

@property
def writer_lock(self):
if self._writer_lock is None:
Expand All @@ -109,8 +119,12 @@ def push_event(self, record: Record):
# Remove the key so we can recompute sort key
if record_key in self.sorted_keys_by_timestamp:
self.sorted_keys_by_timestamp.remove(record_key)

# Update priority
self.key_to_priority[record_key] = self.compute_priority(record)

self.sorted_keys_by_timestamp.add(record_key)
print("add", record_key)
self.wake_waiter_if_needed()

def pop_event(self) -> Record:
Expand All @@ -120,9 +134,14 @@ def pop_event(self) -> Record:
with self.writer_lock:
if len(self.key_to_event) == 0:
return Record.make_wait_event(self.new_waker())

print(len(self.key_to_event), len(self.sorted_keys_by_timestamp))
latest_key = self.sorted_keys_by_timestamp.pop()
record = self.key_to_event.pop(latest_key)
self.key_to_priority.pop(latest_key)
prio = self.key_to_priority.pop(latest_key)
#self.key_to_priority[latest_key] = 0
if self.qsize() == 0:
logger.msg(f"Queue size is zero - system not fully utilized")
return record

def qsize(self) -> int:
Expand All @@ -133,39 +152,62 @@ class KeyAwareLifo(BasePriorityScheduler):
"""Always prioritize the latest record by arrival time."""

def compute_priority(self, record: Record) -> float:
    """Return the current wall-clock time so the newest arrival sorts highest (LIFO).

    Fix: the parameter was previously named ``_`` while the body referenced
    ``record``, which raised NameError on every call. Renamed to ``record``
    (positional call sites are unaffected).
    """
    # A key with no fitted feature yet has never been updated — give it
    # maximum priority so it is served first.
    feature: Record[TimeSeriesValue] = self._operator.get(record.shard_key)
    if feature is None:
        logger.msg(
            f"Missing feature for key {record.shard_key}, returning max_prio"
        )
        return self.max_prio

    return time.time()


class RoundRobinScheduler(BasePriorityScheduler):
    """Prioritize the key that hasn't been updated for longest."""

    def compute_priority(self, record: Record) -> float:
        """Bump the key's previous priority by one; keys without a fitted
        feature get ``max_prio`` so they are scheduled immediately."""
        feature: Record[TimeSeriesValue] = self._operator.get(record.shard_key)
        if feature is not None:
            # One more pending update than this key had before — keys touched
            # least often therefore sort lowest.
            previous = self.key_to_priority.get(record.shard_key, 0)
            return previous + 1

        logger.msg(
            f"Missing feature for key {record.shard_key}, returning max_prio"
        )
        return self.max_prio


class CumulativeErrorScheduler(BasePriorityScheduler):
"""Prioritize the key that has highest prediction error so far"""

max_prio = 10000000

def __init__(self):
def __init__(self, epsilon=None):
    """Initialize per-key cumulative-error bookkeeping.

    Args:
        epsilon: optional floor applied to each per-update error before it is
            accumulated into a key's priority (``None`` disables the floor).
    """
    # TODO: bring back the logic that temporarily disables a key while an
    # update for it is pending, if that ever becomes an issue.
    # self.pending_updates: Dict[KeyType, float] = []
    super().__init__()
    # Highest sequence number already accounted for, per key.
    self.last_seqno = {}
    self.epsilon = epsilon

def compute_priority(self, record: Record["WindowValue"]) -> float:
assert isinstance(record.entry, WindowValue)

# start from last seen seqno
start = self.last_seqno.get(record.shard_key, -1)
incoming_seqnos = np.array([n for n in record.entry.seq_nos if n > start])
self.last_seqno[record.shard_key] = incoming_seqnos.max()
#print("length", incoming_seqnos.shape, start)

# lookup current feature
feature: Record[TimeSeriesValue] = self._operator.get(record.shard_key)
if feature is None:
logger.msg(
f"Missing feature for key {record.shard_key}, returning max_prio"
)
return self.max_prio

incoming_seqnos = np.array(record.entry.seq_nos)

forecast = np.array(feature.forecast)
window_last_seqno = feature.last_seqno
forecast_indicies = incoming_seqnos - window_last_seqno - 1
Expand All @@ -177,12 +219,22 @@ def compute_priority(self, record: Record["WindowValue"]) -> float:
y_pred = np.take(forecast, forecast_indicies)
y_train = np.array(feature.y_train)

assert len(record.entry.seq_nos) == 864, f"Unexpected length {len(record.entry.seq_nos)}"
assert len(y_true) % 288 == 0, f"Unexpected length {len(y_true)}"

# TODO: sample if too heavy weight
# TODO: maybe scale this by staleness
error = mean_absolute_scaled_error(
y_true=y_true, y_pred=y_pred, y_train=y_train
)
return error
) * len(y_true)

print(record.shard_key, "Marginal error", error, forecast_indicies.max() - start, len(y_true), self.key_to_priority.get(record.shard_key, 0))

if self.epsilon is not None:
error = max(error, self.epsilon) # minimum error

# add to current error
return self.key_to_priority.get(record.shard_key, 0) + error


@dataclass
Expand Down Expand Up @@ -268,6 +320,8 @@ def on_event(self, _: Record) -> List[Record[SourceValue]]:
if len(batch) == 0:
return

#print(f"Sending {len(batch)} rows at {self.ts} at {ingest_time}")

self.result_file.write(json.dumps([i.entry.__dict__ for i in batch]))
self.result_file.write("\n")
self.result_file.flush()
Expand Down Expand Up @@ -334,6 +388,8 @@ class STLFitForecast(BaseTransform):

def __init__(self, results_dir):
    # Directory under which forecast results/metrics are written.
    self.results_dir = results_dir
    # Wall-clock time of the first processed event; set lazily in on_event
    # and used as the denominator for average update throughput.
    self.start_time = None
    # Number of model refits performed so far (throughput numerator).
    self.num_updates = 0

def prepare(self):
self.data = defaultdict(lambda: None)
Expand All @@ -347,6 +403,8 @@ def get(self, key):
def on_event(self, record: Record[WindowValue]):
key_id = record.shard_key

if self.start_time is None: self.start_time = time.time()

with warnings.catch_warnings():
# catch warning for ML fit
warnings.filterwarnings("ignore")
Expand All @@ -358,6 +416,9 @@ def on_event(self, record: Record[WindowValue]):
).fit()
forecast = model.forecast(9000)

self.num_updates += 1
print("avg throughput", self.num_updates, self.num_updates / (time.time() - self.start_time))

forecast_record = TimeSeriesValue(
key_id=key_id,
forecast=forecast.tolist(),
Expand Down Expand Up @@ -407,6 +468,7 @@ def _get_config() -> Dict:

def main(argv):
logger.msg("Running STL pipeline on ralf...")
print("Results", f"{FLAGS.results_dir}/metrics")

# Setup dataset directory
conn = sqlite3.connect(FLAGS.azure_database)
Expand Down Expand Up @@ -453,7 +515,7 @@ def main(argv):
schedulers = {
"lifo": KeyAwareLifo(),
"rr": RoundRobinScheduler(),
"ce": CumulativeErrorScheduler(),
"ce": CumulativeErrorScheduler(FLAGS.epsilon),
}

app.source(
Expand Down Expand Up @@ -483,6 +545,7 @@ def main(argv):

app.deploy()
app.wait()
print("Finished", FLAGS.results_dir)


if __name__ == "__main__":
Expand Down