# Confidence-aware LSTM-based Intelligent Caching

**This project explores the use of a Long Short-Term Memory (LSTM) network combined with Confidence Intervals (CIs) for an intelligent caching system. The aim is to improve prefetching, TTL assignment, and eviction policies and, in doing so, to outperform traditional baseline strategies in terms of hit rate and miss rate.**

This notebook is organized as follows:
- ⚙️ Configuration Settings
- 🎲 Data Generation
- 🧹 Data Preprocessing
- 🧭 Validation
- 🧠 Training
- 🧪 Evaluation (Model Standalone)
- 🖥️ Simulation (Overall System)

## ⚙️ Configuration Settings

**Configuration settings** are centralized in the `config.yaml` file. It is composed of the following sections:

- `data`: Settings for data generation, access patterns, temporal patterns, sequences, and datasets.
- `model`: General model and parameter settings.
- `validation`: Time series cross-validation, early stopping, and hyperparameter search space.
- `training`: Training and optimizer settings.
- `testing`: Testing options.
- `evaluation`: Evaluation metrics.
- `inference`: Confidence intervals and MC dropout for uncertainty estimation.
- `simulation`: Simulation of caching policies using either traditional or LSTM-based methods.

<br>

## `data`

### `data.distribution`

- `seed`: Random seed (`int >= 0`)
- `type`: Dataset type (`static` or `dynamic`)
- `num_requests`: Number of requests to generate (`int > 0`)
- `num_keys`: Number of unique keys (`int > 1`)
- `key_range.first_key`: First key ID (`int`)
- `key_range.last_key`: Last key ID (`int > first_key`)

### `data.access_pattern`

#### `zipf`

- `alpha`: Zipf parameter for static data (`float > 0`)
- `alpha_start`: Initial alpha for dynamic data (`float > 0`)
- `alpha_end`: Final alpha for dynamic data (`float > 0`)
- `time_steps`: Number of time steps to transition alpha (`int > 0`)

#### `access_behavior`

- `repetition_interval`: Re-access interval for repeated keys (`int > 0`)
- `repetition_offset`: Offset to apply when repeating (`int > 0`)
- `toggle_interval`: Toggle interval for alternating accesses (`int > 0`)
- `cycle_base`: Base length for cyclic scanning (`int > 0`)
- `cycle_mod`: Modulus to vary cycle length (`int > 0`)
- `cycle_divisor`: Divisor for cycle variability (`int > 0`)
- `distortion_interval`: Interval for distorted history pattern (`int > 0`)
- `noise_range`: Range of noise applied to the distorted history pattern (`[min, max]`)
- `memory_interval`: Interval for memory recall pattern (`int > 0`)
- `memory_offset`: Offset to recall historical accesses (`int > 0`)

### `data.temporal_pattern`

#### `burstiness`

- `burst_high`: Scaling factor for burst peaks (`float in [0, burst_low)`)
- `burst_low`: Scaling factor for non-burst periods (`float > burst_high`)
- `burst_hour_start`: Hour when burst starts (`int in [0, 23]`)
- `burst_hour_end`: Hour when burst ends (`int in [0, 23]`)

#### `periodic`

- `base_scale`: Base frequency of periodic pattern (`int > 0`)
- `amplitude`: Amplitude of the periodic variation (`int >= 0`)

### `data.sequence`

- `len`: Input sequence length (`int > 0`)
- `embedding_dim`: Embedding dimension for keys (`int > 0`)
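
Assuming the model is trained for next-key prediction (the text implies this but does not spell it out), `len` controls how the request stream is cut into training examples. A minimal sketch with a hypothetical `make_windows` helper that works on raw key IDs only; in the real pipeline each step presumably also carries the time features and a key embedding of size `embedding_dim`:

```python
from typing import List, Sequence, Tuple


def make_windows(keys: Sequence[int], seq_len: int) -> Tuple[List[List[int]], List[int]]:
    """Split a request stream into (input window, next key) pairs."""
    if seq_len <= 0:
        raise ValueError("seq_len must be > 0")
    inputs: List[List[int]] = []
    targets: List[int] = []
    for start in range(len(keys) - seq_len):
        inputs.append(list(keys[start:start + seq_len]))  # seq_len consecutive keys
        targets.append(keys[start + seq_len])             # the key requested next
    return inputs, targets


X, y = make_windows([3, 1, 4, 1, 5, 9], seq_len=3)
print(X)  # [[3, 1, 4], [1, 4, 1], [4, 1, 5]]
print(y)  # [1, 5, 9]
```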

### `data.dataset`

- `training_perc`: Fraction of data for training (`float in [0.0, 1.0]`)
- `validation_perc`: Fraction of training data for validation (`float in [0.0, 1.0)`)
- `static_save_path`: Path to save static dataset (`string`)
- `dynamic_save_path`: Path to save dynamic dataset (`string`)

<br>

## `model`

### `model.general`

- `num_features`: Number of input features for the model (`int > 0`)
- `save_path`: Path to save trained model (`string`)

### `model.params`

- `hidden_size`: Size of LSTM hidden state (`int > 0`)
- `num_layers`: Number of LSTM layers (`int > 0`)
- `bias`: Whether to use bias in LSTM (`bool`)
- `batch_first`: Use batch-first input format (`bool`)
- `dropout`: Dropout between LSTM layers (`float in [0.0, 1.0)`)
- `bidirectional`: Use bidirectional LSTM (`bool`)
- `proj_size`: Size of the LSTM projection layer (`int >= 0` and `< hidden_size`; `0` disables the projection)

<br>

## `validation`

### `cross_validation`

- `num_folds`: Number of cross-validation folds (`int > 1`)
- `num_epochs`: Epochs per fold (`int > 0`)

### `early_stopping`

- `patience`: Epochs to wait for improvement (`int >= 0`)
- `delta`: Minimum loss improvement (`float >= 0`)

### `search_space`

#### `model.params`

- `hidden_size_range`: List of hidden sizes to try (`List[int > 0]`)
- `num_layers_range`: List of layer counts to try (`List[int > 0]`)
- `dropout_range`: List of dropout values (`List[float in [0.0, 1.0))`)

#### `training.optimizer`

- `learning_rate_range`: Learning rates to try (`List[float > 0]`)

<br>

## `training`

### `training.general`

- `num_epochs`: Total training epochs (`int > 0`)
- `batch_size`: Batch size during training (`int > 0`)

### `training.optimizer`

- `type`: Optimizer type (`adam`, `adamw`, `sgd`)
- `learning_rate`: Initial learning rate (`float > 0`)
- `weight_decay`: L2 regularization (`float >= 0`)
- `momentum`: Momentum for optimizer (if supported) (`float in [0.0, 1.0]`)

### `training.early_stopping`

- `patience`: Epochs without improvement before stopping (`int >= 0`)
- `delta`: Minimum validation improvement to continue training (`float >= 0`)

<br>

## `testing`

### `testing.general`

- `batch_size`: Batch size during evaluation (`int > 0`)

<br>

## `evaluation`

- `top_k`: `k` for Top-k accuracy metric (`int > 0`)

<br>

## `inference`

### `confidence_intervals`

- `confidence_level`: Confidence level for interval estimates (`float in (0.0, 1.0)`)

### `mc_dropout`

- `num_samples`: Number of stochastic passes for MC Dropout (`int > 0`)
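
One common way to turn `num_samples` MC Dropout passes into an interval is a normal approximation: take the mean and standard deviation of each key's predicted probability across the passes and report `mean ± z·std`, where `z` is the two-sided critical value for `confidence_level`. The sketch below assumes that approach (the project may compute its intervals differently) and replaces the network with a hypothetical `fake_forward` callable that returns one probability vector per stochastic pass. Note that `NormalDist().inv_cdf` rejects a level of exactly 1.0.

```python
import random
import statistics
from statistics import NormalDist
from typing import Callable, List, Tuple


def mc_dropout_interval(
    stochastic_forward: Callable[[], List[float]],
    num_samples: int,
    confidence_level: float,
) -> List[Tuple[float, float, float]]:
    """Return (mean, lower, upper) per key from repeated dropout-enabled passes."""
    passes = [stochastic_forward() for _ in range(num_samples)]
    z = NormalDist().inv_cdf(0.5 + confidence_level / 2)  # two-sided critical value
    results = []
    for key_probs in zip(*passes):  # one key's probability across all passes
        mean = statistics.fmean(key_probs)
        std = statistics.stdev(key_probs) if num_samples > 1 else 0.0
        # Clip to [0, 1] because the quantity is a probability
        results.append((mean, max(0.0, mean - z * std), min(1.0, mean + z * std)))
    return results


rng = random.Random(0)


def fake_forward() -> List[float]:
    # Hypothetical stand-in for one model pass with dropout kept active
    p = 0.6 + rng.uniform(-0.05, 0.05)
    return [p, 1.0 - p]


for mean, lower, upper in mc_dropout_interval(fake_forward, num_samples=50, confidence_level=0.95):
    print(f"mean={mean:.3f}  interval=[{lower:.3f}, {upper:.3f}]")
```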

<br>

## `simulation`

### `general`

- `cache_size`: Maximum number of keys in the cache (`int > 0`)

### `traditional_cache`

- `ttl`: Fixed TTL value for each key in the cache (`float > 0`)

### `lstm_cache`

- `prediction_interval`: How often the model makes predictions (`int > 0`)
- `threshold_prob`: Minimum prediction probability required to insert a key in the cache (`float in [0.0, 1.0]`)
- `threshold_ci`: Confidence interval threshold applied to predictions when deciding whether to insert a key in the cache (`float in [0.0, 1.0]`)
- `ttl_base`: Base TTL for dynamic TTL calculation (`float > 0`)
- `alpha`: Multiplicative factor applied to the predicted probability for TTL calculation (`float > 0`)
- `beta`: Multiplicative factor applied to the confidence interval width for TTL adjustment (`float > 0`)
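
The parameters above name the ingredients of the LSTM cache's decisions but not the exact formulas. The sketch below assumes two hypothetical rules: a key is admitted only if its predicted probability is at least `threshold_prob` and its interval width is at most `threshold_ci`, and its TTL is `ttl_base * (1 + alpha * p) / (1 + beta * width)`, so that likely, confidently predicted keys live longer. Both rules and the names `LSTMCacheParams` and `admission_ttl` are illustrative assumptions, not the project's implementation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class LSTMCacheParams:
    # Mirrors the simulation.lstm_cache keys documented above
    threshold_prob: float
    threshold_ci: float
    ttl_base: float
    alpha: float
    beta: float


def admission_ttl(prob: float, ci_lower: float, ci_upper: float,
                  params: LSTMCacheParams) -> Optional[float]:
    """Return a TTL if the key should be inserted, otherwise None (hypothetical rule)."""
    width = ci_upper - ci_lower
    if prob < params.threshold_prob or width > params.threshold_ci:
        return None  # not likely enough, or not confident enough, to prefetch
    # Higher probability -> longer TTL; wider interval -> shorter TTL
    return params.ttl_base * (1 + params.alpha * prob) / (1 + params.beta * width)


params = LSTMCacheParams(threshold_prob=0.3, threshold_ci=0.2, ttl_base=10.0, alpha=1.0, beta=2.0)
print(admission_ttl(0.5, 0.45, 0.55, params))  # roughly 12.5
print(admission_ttl(0.5, 0.20, 0.80, params))  # None: interval wider than threshold_ci
```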

```python
from config import prepare_config

config_settings = prepare_config()
```

## 🎲 Data Generation

Before training, evaluating, or running experiments, we need to **generate synthetic data** that reflects realistic memory workload behaviour. We generate diverse **access patterns**, which determine the key to be accessed based on the time of day:
- **Repetition (05:00-09:00)**: Models short-term locality by periodically re-accessing recently used keys.
- **Toggle (09:00-12:00)**: Simulates oscillating access behaviour.
- **Cyclic scanning (12:00-18:00)**: Models sequential scanning over a subset of keys.
- **Distorted history (18:00-23:00)**: Introduces noise to past accesses to simulate imprecise repetition or mutation.
- **Memory recall & Zipfian sampling (23:00-05:00)**: Alternates between accessing deep historical keys and sampling from a Zipf distribution.
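
The time-of-day schedule above amounts to a lookup from the hour to the active pattern. A small sketch that encodes only the hour ranges listed here, assuming each range includes its start hour and excludes its end hour; `pattern_for_hour` and the pattern labels are hypothetical, and the generators themselves are not reproduced:

```python
def pattern_for_hour(hour: int) -> str:
    """Map an hour in [0, 23] to the access pattern active at that time."""
    if not 0 <= hour <= 23:
        raise ValueError("hour must be in [0, 23]")
    if 5 <= hour < 9:
        return "repetition"
    if 9 <= hour < 12:
        return "toggle"
    if 12 <= hour < 18:
        return "cyclic_scanning"
    if 18 <= hour < 23:
        return "distorted_history"
    return "memory_recall_zipf"  # 23:00-05:00 wraps around midnight


for h in (0, 6, 10, 15, 20, 23):
    print(h, pattern_for_hour(h))
```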

Inter-request times are modeled as a combination of **periodic** and **bursty** behaviours. Bursts occur in the middle of the day (10:00-18:00).
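
One way the two behaviours could be combined, assuming the burst factors in `data.temporal_pattern` scale the gap between consecutive requests (so a smaller `burst_high` means denser traffic during bursts): a sinusoidal daily base gap `base_scale + amplitude * sin(2π·hour/24)`, multiplied by `burst_high` inside the burst window and by `burst_low` outside it. This formula and the helper names are hypothetical; the project's generator may use a different functional form and add randomness, for example by drawing gaps from an exponential distribution with this mean.

```python
import math


def in_burst(hour: float, start: int, end: int) -> bool:
    """True if `hour` falls in [start, end), handling windows that wrap midnight."""
    if start <= end:
        return start <= hour < end
    return hour >= start or hour < end


def mean_gap(hour: float, base_scale: float, amplitude: float,
             burst_high: float, burst_low: float,
             burst_hour_start: int, burst_hour_end: int) -> float:
    """Hypothetical mean inter-request time (arbitrary units) at a given hour."""
    if amplitude >= base_scale:
        raise ValueError("sketch assumes amplitude < base_scale so gaps stay positive")
    periodic = base_scale + amplitude * math.sin(2 * math.pi * hour / 24)
    factor = burst_high if in_burst(hour, burst_hour_start, burst_hour_end) else burst_low
    return periodic * factor


# Burst window 10:00-18:00 as in the text; other values are made up
print(mean_gap(12, base_scale=10, amplitude=3, burst_high=0.2, burst_low=1.0,
               burst_hour_start=10, burst_hour_end=18))
```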

We generate 30,000 **requests** over a 25-day period for 30 distinct **keys**, organized into two **datasets** (each with two columns: `timestamp` and `request`, the latter indicating the ID of the requested key):
- **Static dataset**: Assumes fixed key popularities over time, with the Zipf parameter set to 0.8.
- **Dynamic dataset**: Models time-varying key popularity by linearly increasing the Zipf parameter from 0.5 to 2.0 over five steps.
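
A finite Zipf distribution gives the key of popularity rank `r` a probability proportional to `1 / r**alpha`, so a larger `alpha` concentrates traffic on fewer keys. The sketch below assumes that key index equals popularity rank and that the dynamic dataset splits the request stream into `time_steps` equal segments with linearly spaced alphas; the helper names are hypothetical and the project's exact schedule may differ. It uses only the standard library (`random.Random.choices`).

```python
import random
from typing import List


def zipf_probs(num_keys: int, alpha: float) -> List[float]:
    """Finite Zipf distribution: P(rank r) proportional to 1 / r**alpha, r = 1..num_keys."""
    weights = [1.0 / (rank ** alpha) for rank in range(1, num_keys + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def alpha_schedule(alpha_start: float, alpha_end: float, time_steps: int) -> List[float]:
    """Linearly spaced alphas, one per time step, including both endpoints."""
    if time_steps == 1:
        return [alpha_start]
    step = (alpha_end - alpha_start) / (time_steps - 1)
    return [alpha_start + i * step for i in range(time_steps)]


def sample_requests(num_requests: int, keys: List[int], alphas: List[float], seed: int) -> List[int]:
    """Sample key IDs, giving each alpha an equal share of the request stream."""
    rng = random.Random(seed)
    requests: List[int] = []
    per_step = num_requests // len(alphas)
    for i, alpha in enumerate(alphas):
        # The last segment absorbs any remainder so exactly num_requests are produced
        n = per_step if i < len(alphas) - 1 else num_requests - per_step * i
        requests.extend(rng.choices(keys, weights=zipf_probs(len(keys), alpha), k=n))
    return requests


keys = list(range(30))  # 30 distinct keys, as in the text
static = sample_requests(30_000, keys, [0.8], seed=42)
dynamic = sample_requests(30_000, keys, alpha_schedule(0.5, 2.0, 5), seed=42)
print(alpha_schedule(0.5, 2.0, 5))  # [0.5, 0.875, 1.25, 1.625, 2.0]
```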

```python
from main import config_settings
from data_generation import data_generation

data_generation(config_settings)
```

## 🧹 Data Preprocessing

**Data preprocessing** carries out two activities:
- **Missing value removal**: Drops rows with missing values from the dataset.
- **Feature engineering**: Creates two new columns, `sin_time` and `cos_time`, which replace the original `timestamp` with a trigonometric representation so that the LSTM can better capture cyclical temporal patterns.
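
A sketch of the cyclical encoding with pandas, assuming `timestamp` is parseable by `pd.to_datetime` and that the relevant cycle is one day (neither is stated explicitly above; `preprocess` is a hypothetical name). Mapping the time of day onto the unit circle places 23:59 right next to 00:00, which a raw timestamp does not.

```python
import numpy as np
import pandas as pd

SECONDS_PER_DAY = 24 * 60 * 60


def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    """Drop rows with missing values and replace `timestamp` with sin/cos time-of-day features."""
    df = df.dropna().copy()
    ts = pd.to_datetime(df["timestamp"])
    # Seconds elapsed since midnight, in [0, 86400)
    seconds = ts.dt.hour * 3600 + ts.dt.minute * 60 + ts.dt.second
    angle = 2 * np.pi * seconds / SECONDS_PER_DAY
    df["sin_time"] = np.sin(angle)
    df["cos_time"] = np.cos(angle)
    return df.drop(columns=["timestamp"])


raw = pd.DataFrame({
    "timestamp": ["2024-01-01 00:00:00", "2024-01-01 06:00:00", None, "2024-01-01 18:00:00"],
    "request": [3, 7, 1, 3],
})
clean = preprocess(raw)
print(clean)
```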

```python
from data_preprocessing import data_preprocessing

data_preprocessing(config_settings)
```

## 🧭 Validation

**Validation** aims at finding the **best hyperparameters** to be used for training the final model. We define the **hyperparameter search space** as follows:
- `hidden_size`: [128, 256]
- `num_layers`: [2, 3]
- `dropout`: [0.1, 0.3]
- `learning_rate`: [0.001, 0.005]

We run a **Grid Search** over $2^4=16$ hyperparameter combinations. For each combination, we perform a **10-fold Time Series Cross-Validation** on the training set (70% of the dataset), which avoids data leakage by preserving the temporal order of events.
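
An expanding-window split keeps every validation block strictly after the data used for training. A minimal sketch in the spirit of scikit-learn's `TimeSeriesSplit` with its default fold sizing; `time_series_folds` is a hypothetical helper, and the project may size its folds differently:

```python
from typing import Iterator, Tuple


def time_series_folds(n_samples: int, num_folds: int) -> Iterator[Tuple[range, range]]:
    """Yield (train_indices, val_indices) with an expanding training window."""
    fold_size = n_samples // (num_folds + 1)
    if fold_size == 0:
        raise ValueError("not enough samples for the requested number of folds")
    for i in range(num_folds):
        # Validation blocks are the last num_folds blocks of size fold_size
        val_start = n_samples - (num_folds - i) * fold_size
        yield range(0, val_start), range(val_start, val_start + fold_size)


for train_idx, val_idx in time_series_folds(n_samples=22, num_folds=10):
    print(f"train [0, {train_idx.stop})  val [{val_idx.start}, {val_idx.stop})")
```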

**Early Stopping** (`patience`=5, `delta`=0.0005) is applied while training on each fold (using **AdamW** as optimizer): training, capped at 20 epochs, stops once the validation loss (computed with a **weighted Cross Entropy Loss**) has failed to improve by more than `delta` for `patience` consecutive epochs.
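
The stopping rule can be captured by a small helper. A sketch assuming the usual semantics: a new loss counts as an improvement only if it beats the best loss so far by more than `delta`, and training stops after `patience` consecutive epochs without one. The `EarlyStopping` class is hypothetical and the losses are made up for illustration:

```python
class EarlyStopping:
    """Signal a stop after `patience` epochs without an improvement larger than `delta`."""

    def __init__(self, patience: int, delta: float) -> None:
        self.patience = patience
        self.delta = delta
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
            return False
        self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=5, delta=0.0005)
losses = [1.00, 0.80, 0.70, 0.6999, 0.71, 0.72, 0.70, 0.73]
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss)  # 8 0.7
```

The 0.6999 loss does not count as an improvement because it beats 0.70 by less than `delta`.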

Whenever a new hyperparameter combination achieves the **best average validation loss** seen so far, it is saved as the new best. At the end, we obtain the best hyperparameters (i.e., those yielding the lowest average validation loss).
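
The outer loop over the 16 combinations, with best-so-far tracking, might look like the sketch below. `cross_validate` stands in for training and evaluating one combination on all folds; here a hypothetical `fake_cross_validate` returns a made-up deterministic score so that the loop can run without a model.

```python
import itertools
from typing import Callable, Dict, List, Optional, Tuple

SEARCH_SPACE: Dict[str, List[float]] = {
    "hidden_size": [128, 256],
    "num_layers": [2, 3],
    "dropout": [0.1, 0.3],
    "learning_rate": [0.001, 0.005],
}


def grid_search(
    search_space: Dict[str, List[float]],
    cross_validate: Callable[[Dict[str, float]], float],
) -> Tuple[Optional[Dict[str, float]], float]:
    """Return the combination with the lowest average validation loss."""
    names = list(search_space)
    best_params: Optional[Dict[str, float]] = None
    best_loss = float("inf")
    for values in itertools.product(*(search_space[name] for name in names)):
        params = dict(zip(names, values))
        avg_val_loss = cross_validate(params)  # mean validation loss over all folds
        if avg_val_loss < best_loss:           # new best seen so far
            best_params, best_loss = params, avg_val_loss
    return best_params, best_loss


def fake_cross_validate(params: Dict[str, float]) -> float:
    # Made-up placeholder objective, NOT a real model evaluation
    return (params["dropout"] + params["learning_rate"]
            - params["hidden_size"] / 1000 - params["num_layers"] / 100)


best, best_loss = grid_search(SEARCH_SPACE, fake_cross_validate)
print(best)
```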

```python
from validation import validation

config_settings = validation(config_settings)
```

## 🧠 Training

The final model is obtained by **training** with the optimal hyperparameters identified during validation. We hold out 20% of the training set as a **validation set** and allow a much higher number of epochs (500) than during validation. **Early Stopping** (`patience`=25, `delta`=0.0005) lets the model train for as many epochs as are useful while stopping before it overfits.

Whenever the model reaches the best validation loss seen so far, its weights are saved. At the end of training, the **best trained model** (the one with the lowest validation loss) is saved.
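
Assuming every split is chronological (earlier requests for fitting, later ones for validation and testing), the percentages quoted above translate into index boundaries as follows. `chronological_split` is a hypothetical helper, and the real pipeline may split windowed sequences rather than raw requests, which shifts the counts slightly.

```python
from typing import Tuple


def chronological_split(n: int, training_perc: float,
                        validation_perc: float) -> Tuple[range, range, range]:
    """Return (fit, validation, test) index ranges without shuffling."""
    train_end = round(n * training_perc)                      # first part: training set
    fit_end = train_end - round(train_end * validation_perc)  # tail of training: validation
    return range(0, fit_end), range(fit_end, train_end), range(train_end, n)


fit, val, test = chronological_split(30_000, training_perc=0.7, validation_perc=0.2)
print(len(fit), len(val), len(test))  # 16800 4200 9000
```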

```python
from training import training

training(config_settings)
```

## 🧪 Evaluation (Model Standalone)

After training, we **evaluate** the model standalone on the testing set (30% of the dataset). The following evaluation **metrics** are computed:
- **Average loss** and **average loss per class**.
- **Class report**: Precision, Recall, and F1-score for each class, plus macro-averaged results.
- **Confusion matrix**: Summarizes the number of correct and incorrect predictions for each class.
- **Top-k accuracy**: Fraction of samples whose target key is among the `k=3` most probable predicted keys.
- **Kappa statistic**: Measures agreement between predictions and targets, corrected for the agreement expected by chance (i.e., compares the model against a random one).
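
Top-k accuracy and the kappa statistic are short to compute by hand. A plain-Python sketch, assuming the kappa statistic is Cohen's kappa; the function names are hypothetical, and scikit-learn's `top_k_accuracy_score` and `cohen_kappa_score` provide library versions of the same metrics:

```python
from collections import Counter
from typing import Sequence


def top_k_accuracy(probs: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Fraction of samples whose target is among the k highest-probability classes."""
    hits = 0
    for row, target in zip(probs, targets):
        top_k = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)


def cohen_kappa(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Agreement between predictions and targets, corrected for chance agreement."""
    n = len(y_true)
    observed = sum(t == p for t, p in zip(y_true, y_pred)) / n
    true_counts, pred_counts = Counter(y_true), Counter(y_pred)
    # Agreement expected if predictions were independent of the targets
    expected = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)
    return (observed - expected) / (1 - expected)


probs = [[0.1, 0.6, 0.3], [0.5, 0.2, 0.3], [0.2, 0.3, 0.5]]
targets = [2, 1, 0]
print(top_k_accuracy(probs, targets, k=2))           # 0.333...
print(cohen_kappa([0, 0, 1, 1], [0, 0, 1, 0]))       # 0.5
```

Kappa is 1 for perfect agreement and 0 when predictions match the targets no more often than chance; this sketch divides by zero in the degenerate case where chance agreement is already 1 (a single class).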

```python
from testing import testing

avg_loss, avg_loss_per_class, metrics, cost_perc = testing(config_settings)

# show results
print("\n" + "="*85)
print(" " * 30 + "Model Standalone Evaluation Report")
print("="*85 + "\n")
print(f"Average Loss:       {avg_loss:.4f}\n")
print("📉 Class Report:")
print(metrics['class_report'] + "\n")
print(f"Top-k Accuracy:     {metrics['top_k_accuracy']:.4f}")
print(f"Kappa Statistic:    {metrics['kappa_statistic']:.4f}")
print("\n" + "="*85 + "\n")
```

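## 🖥️ Simulation (Overall System)

Finally, we simulate the overall caching system, comparing the LSTM-based cache against LRU, LFU, FIFO, and random baselines and reporting the hit rate and miss rate of each policy.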

```python
from simulation import simulate
from cachetools import LRUCache, LFUCache, FIFOCache
from simulation.LSTMCache import LSTMCache
from simulation.RandomCache import RandomCache
from simulation.CacheWrapper import CacheWrapper

# setup cache strategies
strategies = {
    'LRU': CacheWrapper(LRUCache, config_settings),
    'LFU': CacheWrapper(LFUCache, config_settings),
    'FIFO': CacheWrapper(FIFOCache, config_settings),
    'RANDOM': RandomCache(config_settings),
    'LSTM': LSTMCache(config_settings),
}

# run simulation
results = []
for policy, cache in strategies.items():
    result = simulate(cache, policy, config_settings)
    results.append(result)

# show results
print("\n" + "="*90)
print(" " * 30 + "Overall System Evaluation Report")
print("="*90 + "\n")
print(f"{'Policy':<25} | {'Hit Rate (%)':>12} | {'Miss Rate (%)':>13}")
print("-"*90)
for res in results:
    print(f"{res['policy']:<25} | {res['hit_rate']:>12.2f} | {res['miss_rate']:>13.2f}")
print("\n" + "="*90 + "\n")
```
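
For reference, a request is a hit if its key is already cached and a miss otherwise, so the miss rate is 100 minus the hit rate. The sketch below replays a request list through a hand-written LRU cache based on `collections.OrderedDict`; it is a hypothetical illustration that ignores TTLs and prefetching, not the project's `simulate` function.

```python
from collections import OrderedDict
from typing import Dict, Iterable


def lru_hit_miss(requests: Iterable[int], cache_size: int) -> Dict[str, float]:
    """Replay requests through an LRU cache and return hit/miss rates in percent."""
    cache: "OrderedDict[int, None]" = OrderedDict()
    hits = total = 0
    for key in requests:
        total += 1
        if key in cache:
            hits += 1
            cache.move_to_end(key)         # mark as most recently used
        else:
            cache[key] = None
            if len(cache) > cache_size:
                cache.popitem(last=False)  # evict the least recently used key
    hit_rate = 100.0 * hits / total if total else 0.0
    return {"hit_rate": hit_rate, "miss_rate": 100.0 - hit_rate}


print(lru_hit_miss([1, 2, 1, 3, 2, 4, 1], cache_size=2))
```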