# Confidence-aware LSTM-based Intelligent Caching

**This project explores the use of a Long Short-Term Memory (LSTM) network combined with Confidence Intervals (CIs) for an intelligent caching system implemented in Redis. The aim is to improve prefetching, TTL assignment, and eviction policies, providing a framework that outperforms traditional baseline strategies in terms of hit rate and miss rate.**

This notebook is organized as follows:
1. [**Configuration Settings**](#-configuration-settings)
2. [**Data Generation**](#-data-generation)
3. [**Data Preprocessing**](#-data-preprocessing)
4. [**Validation**](#-validation)
5. [**Training**](#-training)
6. [**Testing**](#-testing-model-standalone)

## ⚙️ Configuration Settings

**Configuration settings** are centralized in the `config.yaml` file, which is composed of the following main sections:
- `data`: Contains settings regarding data generation (`distribution`, `access_pattern`, and `temporal_pattern`), the data sequences (`sequence`), and the dataset (`dataset`).
- `model`: Contains general model settings (`general`) and model parameter settings (`params`).
- `validation`: Contains settings about the Time Series Cross-Validation (`cross_validation`), Early Stopping (`early_stopping`), and hyperparameter search space (`search_space`).
- `training`: Contains general training settings (`general`) and optimizer settings (`optimizer`).
- `testing`: Contains general testing options (`general`).
- `evaluation`: Contains settings about metrics (`top_k`).
- `inference`: Contains settings about confidence intervals (`confidence_intervals`) and Monte Carlo (MC) dropout (`mc_dropout`).

<br>

## `data`
### `data.distribution`
- `data.distribution.seed`: The seed used for random generations (`int >= 0`).
- `data.distribution.type`: The type of dataset to generate (`static` or `dynamic`).
- `data.distribution.num_requests`: The number of requests to generate (i.e., the number of rows in the dataset) (`int > 0`).
- `data.distribution.num_keys`: The number of unique keys (`int > 1`, equal to `data.distribution.key_range.last_key - data.distribution.key_range.first_key + 1`).
- `data.distribution.freq_windows`: The window sizes (in number of requests) within which the relative frequencies of the keys are calculated (`List[int > 0]`).

#### `data.distribution.key_range`
- `data.distribution.key_range.first_key`: The first key's ID to generate (`int < data.distribution.key_range.last_key`).
- `data.distribution.key_range.last_key`: The last key's ID to generate (`int > data.distribution.key_range.first_key`).
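
Since `num_keys` must agree with the key range, a small consistency check can catch misconfigurations before data generation starts. The sketch below is illustrative only: the function name `check_key_range` is hypothetical and not part of the project, and it assumes the three values have already been read from `data.distribution`.

```python
def check_key_range(num_keys: int, first_key: int, last_key: int) -> None:
    """Raise ValueError if the key settings violate the documented constraints.

    Hypothetical helper mirroring the constraints listed for `data.distribution`.
    """
    if first_key >= last_key:
        raise ValueError("first_key must be smaller than last_key")
    expected = last_key - first_key + 1
    # first_key < last_key already guarantees expected >= 2, i.e. num_keys > 1.
    if num_keys != expected:
        raise ValueError(f"num_keys should be {expected}, got {num_keys}")


# Example: keys 1..100 require num_keys == 100.
check_key_range(num_keys=100, first_key=1, last_key=100)
```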

### `data.access_pattern`
#### `data.access_pattern.zipf`
- `data.access_pattern.zipf.alpha`: The fixed Zipf parameter (for the static data distribution) (`float > 0`).
- `data.access_pattern.zipf.alpha_start`: The initial Zipf parameter (for dynamic data distribution) (`float > 0`, `<= data.access_pattern.zipf.alpha_end`).
- `data.access_pattern.zipf.alpha_end`: The final Zipf parameter (for dynamic data distribution) (`float > 0`, `>= data.access_pattern.zipf.alpha_start`).
- `data.access_pattern.zipf.time_steps`: The number of time steps to be considered while generating dynamic data distribution (`int > 0`).
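
For reference, a Zipf distribution over `num_keys` ranked keys gives the key of rank `r` a probability proportional to `1 / r**alpha`. The sketch below computes these probabilities and, for the dynamic case, a schedule of `time_steps` exponents. It rests on an assumption: the linear interpolation from `alpha_start` to `alpha_end` is our choice and not necessarily how the project varies the parameter; both function names are hypothetical.

```python
from typing import List


def zipf_probabilities(num_keys: int, alpha: float) -> List[float]:
    """Probability of each key rank (rank 1 first) under a Zipf law with exponent alpha."""
    weights = [1.0 / rank**alpha for rank in range(1, num_keys + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def alpha_schedule(alpha_start: float, alpha_end: float, time_steps: int) -> List[float]:
    """Assumed linear schedule of Zipf exponents for the dynamic distribution."""
    if time_steps == 1:
        return [alpha_start]
    step = (alpha_end - alpha_start) / (time_steps - 1)
    return [alpha_start + i * step for i in range(time_steps)]


# A larger alpha concentrates more probability mass on the first (most popular) keys.
probs = zipf_probabilities(num_keys=5, alpha=1.0)
alphas = alpha_schedule(alpha_start=0.8, alpha_end=1.2, time_steps=5)
```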

#### `data.access_pattern.locality`
- `data.access_pattern.locality.prob`: The probability of local access pattern (`float in [0.0, 1.0]`).

### `data.temporal_pattern`
#### `data.temporal_pattern.burstiness`
- `data.temporal_pattern.burstiness.burst_high`: Scaling factor used during the peak phase of a burst (the lower the value, the higher the number of requests) (`float > 0`, `< data.temporal_pattern.burstiness.burst_low`).
- `data.temporal_pattern.burstiness.burst_low`: Scaling factor used outside the peak phase of a burst (the higher the value, the lower the number of requests) (`float > 0`, `> data.temporal_pattern.burstiness.burst_high`).
- `data.temporal_pattern.burstiness.burst_every`: The periodicity of bursts, in number of requests (e.g., a new burst period starts every 100 requests) (`int > 0`).
- `data.temporal_pattern.burstiness.burst_peak`: The duration (in number of requests) of the peak phase of a burst (`int >= 0`).

#### `data.temporal_pattern.periodic`
- `data.temporal_pattern.periodic.base_scale`: The base frequency scaling factor (controls the request density baseline) (`int/float > 0`).
- `data.temporal_pattern.periodic.amplitude`: The amplitude of the periodic variation (controls how much the request frequency oscillates) (`int/float >= 0`).
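
The burstiness and periodic settings can be read as scale factors for the time between consecutive requests. The sketch below shows one plausible interpretation under assumptions not stated in the configuration: `delta_time` is drawn from an exponential distribution whose mean is the current scale, the peak phase occupies the first `burst_peak` requests of every `burst_every`-request cycle, and the periodic pattern is a sine wave with a hypothetical `period` parameter (the documented config has no such key).

```python
import math
import random


def burst_scale(i: int, burst_high: float, burst_low: float, burst_every: int, burst_peak: int) -> float:
    """Scale for request i: the first `burst_peak` requests of each cycle form the peak phase."""
    return burst_high if (i % burst_every) < burst_peak else burst_low


def periodic_scale(i: int, base_scale: float, amplitude: float, period: int = 100) -> float:
    """Sinusoidal scale around `base_scale`; `period` is a hypothetical parameter."""
    # Clamp to a small positive value so the scale stays valid when amplitude > base_scale.
    return max(base_scale + amplitude * math.sin(2 * math.pi * i / period), 1e-6)


def sample_delta_time(scale: float, rng: random.Random) -> float:
    """Exponential inter-arrival time with mean `scale`: smaller scale, denser requests."""
    return rng.expovariate(1.0 / scale)


rng = random.Random(42)
deltas = [sample_delta_time(burst_scale(i, 0.1, 1.0, 100, 10), rng) for i in range(200)]
```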

### `data.sequence`
- `data.sequence.len`: The length of the input sequences fed to the model (`int > 0`).
- `data.sequence.embedding_dim`: The dimensionality of the embedding space used to encode input keys (`int > 0`).
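
To illustrate how `data.sequence.len` is used, the sketch below slices time-ordered rows into sliding windows of `len` consecutive requests and pairs each window with the key requested right after it. Using the next requested key as the target is an assumption consistent with the Top-k accuracy metric used in testing; the helper name `make_sequences` is hypothetical, and embedding keys into `embedding_dim` dimensions is left to the model.

```python
from typing import List, Sequence, Tuple


def make_sequences(
    features: Sequence[Sequence[float]],
    requests: Sequence[int],
    seq_len: int,
) -> Tuple[List[List[Sequence[float]]], List[int]]:
    """Build (window, next_key) pairs from time-ordered rows."""
    windows, targets = [], []
    for start in range(len(features) - seq_len):
        windows.append(list(features[start:start + seq_len]))  # seq_len rows of features
        targets.append(requests[start + seq_len])  # key requested right after the window
    return windows, targets
```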

### `data.dataset`
- `data.dataset.training_perc`: The percentage of the dataset to be used as training set (`float in [0.0, 1.0]`).
- `data.dataset.validation_perc`: The percentage of the training set to be used as validation set (`float in [0.0, 1.0)`).
- `data.dataset.static_save_path`: Path where the static dataset will be saved (`str`).
- `data.dataset.dynamic_save_path`: Path where the dynamic dataset will be saved (`str`).
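
Because the data is a time series, the splits are best taken in chronological order rather than at random. A minimal sketch, assuming the rows are already sorted by time and that the validation set is taken from the end of the training portion (the function name is hypothetical):

```python
from typing import List, Sequence, Tuple, TypeVar

Row = TypeVar("Row")


def chronological_split(
    rows: Sequence[Row], training_perc: float, validation_perc: float
) -> Tuple[List[Row], List[Row], List[Row]]:
    """Split time-ordered rows into (training, validation, testing) without shuffling."""
    n_train_total = round(len(rows) * training_perc)
    train_total, test = list(rows[:n_train_total]), list(rows[n_train_total:])
    # The validation set is the most recent slice of the training portion.
    n_val = round(len(train_total) * validation_perc)
    n_fit = len(train_total) - n_val
    return train_total[:n_fit], train_total[n_fit:], test


train, val, test = chronological_split(list(range(100)), training_perc=0.8, validation_perc=0.25)
```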

<br>

## `model`
### `model.general`
- `model.general.num_features`: The number of features per input element (`int > 0`).
- `model.general.save_path`: Path to save the trained model (`str`).

### `model.params`
- `model.params.hidden_size`: Number of units in the LSTM hidden layer (`int > 0`).
- `model.params.num_layers`: Number of LSTM layers (`int > 0`).
- `model.params.bias`: Whether to use bias weights in the LSTM (`bool`).
- `model.params.batch_first`: Whether input/output tensors are provided as (batch, seq, feature) (`bool`).
- `model.params.dropout`: Dropout probability between LSTM layers (`float in [0.0, 1.0)`).
- `model.params.bidirectional`: Whether to use a bidirectional LSTM (`bool`).
- `model.params.proj_size`: Size of the LSTM projection layer; `0` disables projections (`int >= 0`, `< model.params.hidden_size`).
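
The `model.params` keys match the constructor arguments of PyTorch's `torch.nn.LSTM`, which suggests (but does not confirm) that the model is built on it. PyTorch rejects `proj_size >= hidden_size`, and it applies `dropout` only between stacked layers, so a non-zero value has no effect when `num_layers` is 1. The dependency-free sketch below checks these rules before a model is built; `check_lstm_params` is a hypothetical helper, not a PyTorch function.

```python
from typing import List


def check_lstm_params(hidden_size: int, num_layers: int, dropout: float, proj_size: int) -> List[str]:
    """Validate LSTM settings: raise on invalid values, return warnings for ineffective ones."""
    if hidden_size <= 0 or num_layers <= 0:
        raise ValueError("hidden_size and num_layers must be positive")
    if not 0.0 <= dropout < 1.0:
        raise ValueError("dropout must be in [0.0, 1.0)")
    if proj_size < 0 or proj_size >= hidden_size:
        raise ValueError("proj_size must be >= 0 and smaller than hidden_size")
    warnings = []
    if dropout > 0.0 and num_layers == 1:
        # Dropout is only inserted between stacked LSTM layers.
        warnings.append("dropout has no effect when num_layers == 1")
    return warnings


print(check_lstm_params(hidden_size=128, num_layers=1, dropout=0.2, proj_size=0))
# ['dropout has no effect when num_layers == 1']
```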

<br>

## `validation`
### `validation.cross_validation`
- `validation.cross_validation.num_folds`: Number of folds to split the training set into during Time Series Cross-Validation (`int > 1`).
- `validation.cross_validation.num_epochs`: Number of training epochs for each fold while validating the model (`int > 0`).
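
A common way to build the folds is an expanding window, as in scikit-learn's `TimeSeriesSplit`: every fold trains on all data up to a point and validates on the block that immediately follows. The sketch below assumes this scheme with equally sized validation blocks; the project's exact fold layout may differ, and `time_series_folds` is a hypothetical name.

```python
from typing import List, Tuple


def time_series_folds(n_samples: int, num_folds: int) -> List[Tuple[range, range]]:
    """Expanding-window folds: (train_indices, validation_indices) in time order."""
    fold_size = n_samples // (num_folds + 1)
    if fold_size == 0:
        raise ValueError("not enough samples for the requested number of folds")
    first_val_start = n_samples - num_folds * fold_size
    return [
        (range(0, val_start), range(val_start, val_start + fold_size))
        for val_start in range(first_val_start, n_samples, fold_size)
    ]


# With 10 samples and 4 folds, the first fold trains on [0, 2) and validates on [2, 4).
folds = time_series_folds(10, 4)
```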

### `validation.early_stopping`
- `validation.early_stopping.patience`: Number of epochs to wait after the last improvement before stopping training (`int >= 0`).
- `validation.early_stopping.delta`: Minimum change in validation loss to qualify as an improvement (`int/float >= 0`).
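
A minimal early-stopping tracker built from `patience` and `delta`. It is a sketch of one common convention, not necessarily the project's: an epoch counts as an improvement only if the loss falls more than `delta` below the best value so far, and training stops once more than `patience` consecutive epochs have passed without improvement. The class name is hypothetical.

```python
import math


class EarlyStopping:
    """Track validation losses and signal when training should stop."""

    def __init__(self, patience: int, delta: float) -> None:
        self.patience = patience
        self.delta = delta
        self.best_loss = math.inf
        self.counter = 0  # consecutive epochs without improvement

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter > self.patience


stopper = EarlyStopping(patience=2, delta=0.0)
for epoch, val_loss in enumerate([1.0, 0.9, 0.95, 0.93, 0.91]):
    if stopper.step(val_loss):
        print(f"stopping after epoch {epoch}")  # stopping after epoch 4
        break
```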

### `validation.search_space.model.params`
- `validation.search_space.model.params.hidden_size_range`: Non-empty list of possible values for `hidden_size` hyperparameter (`List[int > 0]`).
- `validation.search_space.model.params.num_layers_range`: Non-empty list of possible values for `num_layers` hyperparameter (`List[int > 0]`).
- `validation.search_space.model.params.dropout_range`: Non-empty list of possible values for `dropout` hyperparameter (`List[float in [0.0, 1.0))`).

### `validation.search_space.training.optimizer`
- `validation.search_space.training.optimizer.learning_rate_range`: Non-empty list of learning rates to try (`List[float > 0]`).
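
Grid search takes the Cartesian product of all the ranges above. The sketch below enumerates the combinations and keeps the one with the lowest average validation loss across folds; `fold_losses_fn` is a hypothetical stand-in for training and validating the model on every fold with a given combination, and the flat key names are illustrative.

```python
import itertools
import math
from typing import Callable, Dict, Iterator, List, Optional, Sequence, Tuple

Params = Dict[str, object]


def grid(search_space: Dict[str, Sequence]) -> Iterator[Params]:
    """Yield every combination of the given hyperparameter ranges as a dict."""
    names = list(search_space)
    for values in itertools.product(*(search_space[name] for name in names)):
        yield dict(zip(names, values))


def select_best(
    search_space: Dict[str, Sequence], fold_losses_fn: Callable[[Params], List[float]]
) -> Tuple[Optional[Params], float]:
    """Return the combination with the lowest average validation loss over the folds."""
    best_params, best_loss = None, math.inf
    for params in grid(search_space):
        losses = fold_losses_fn(params)
        avg_loss = sum(losses) / len(losses)
        if avg_loss < best_loss:  # strictly better: ties keep the earlier combination
            best_params, best_loss = params, avg_loss
    return best_params, best_loss


space = {
    "hidden_size": [64, 128],
    "num_layers": [1, 2],
    "dropout": [0.0, 0.2],
    "learning_rate": [1e-3, 1e-4],
}
print(sum(1 for _ in grid(space)))  # 16 combinations
```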

<br>

## `training`
### `training.general`
- `training.general.num_epochs`: Number of epochs to train the final model (`int > 0`).
- `training.general.batch_size`: Size of training batches (`int > 0`).

### `training.optimizer`
- `training.optimizer.type`: Type of optimizer used during training (`adam`, `adamw`, or `sgd`).
- `training.optimizer.learning_rate`: Initial learning rate (`float > 0`).
- `training.optimizer.weight_decay`: Weight decay factor (`float >= 0`).
- `training.optimizer.momentum`: Momentum factor (if supported by the optimizer) (`float in [0.0, 1.0]`).

<br>

## `testing`
### `testing.general`
- `testing.general.batch_size`: Size of batches during testing (`int > 0`).

<br>

## `evaluation`
- `evaluation.top_k`: Value of `k` to compute Top-k accuracy (`int > 0`).

<br>

## `inference`
### `inference.confidence_intervals`
- `inference.confidence_intervals.confidence_level`: Confidence level used to calculate confidence intervals (`float in [0.0, 1.0]`).

### `inference.mc_dropout`
- `inference.mc_dropout.num_samples`: The number of stochastic forward passes (with dropout kept active) performed for MC dropout (`int > 0`).
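
To illustrate how the two settings can work together: with MC dropout, dropout stays active at inference time and the model is run `num_samples` times on the same input, producing a sample of predictions (e.g., the probability of a key). The sketch below summarizes such a sample with a normal-approximation confidence interval for its mean at `confidence_level`. The normal approximation and the function name are assumptions; note that `statistics.NormalDist.inv_cdf` needs the level to lie strictly between 0 and 1.

```python
import math
import statistics
from typing import Sequence, Tuple


def mc_dropout_interval(samples: Sequence[float], confidence_level: float) -> Tuple[float, float]:
    """Normal-approximation CI for the mean of MC-dropout predictions."""
    if len(samples) < 2:
        raise ValueError("need at least two MC-dropout samples")
    if not 0.0 < confidence_level < 1.0:
        raise ValueError("confidence_level must be strictly between 0 and 1")
    mean = statistics.fmean(samples)
    std = statistics.stdev(samples)  # sample standard deviation across forward passes
    z = statistics.NormalDist().inv_cdf((1.0 + confidence_level) / 2.0)
    half_width = z * std / math.sqrt(len(samples))
    return mean - half_width, mean + half_width


low, high = mc_dropout_interval([0.61, 0.58, 0.66, 0.63, 0.60], confidence_level=0.95)
```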

```python
from config import prepare_config

config_settings = prepare_config()
```

## 🎲 Data Generation

Before training, evaluating, or running experiments, we need to **generate synthetic data** to work on. The data should reflect the nature of accesses in real-world systems, which are often characterized by:
- **Skewed popularity**: Some objects are much more popular than others.
- **Locality**: Some objects are correlated and are often accessed in sequence.
- **Periodic access patterns**: Objects are accessed at recurring time intervals.
- **Bursty periods**: Sudden spikes in access frequency occur over short time periods.

To simulate this behavior, we generate two types of synthetic access patterns:

- **Spatial accesses**: Governed by a Zipf distribution (the first keys in the key range are more likely to be requested than later ones) and by locality (after a key is accessed, its neighboring keys are more likely to be accessed next). The first access always targets a popular key; after that, each access targets a neighboring key with a probability of 70% and a popular key with a probability of 30%.
- **Temporal accesses**: Modelled by using a combination of periodic and bursty patterns. While some time periods follow predictable, recurring intervals, others exhibit intervals of sudden, high-frequency access.
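
The spatial access rules above can be sketched as follows. This is an illustration, not the project's generator: it assumes 0-based key indices, treats a "neighboring" key as one at distance at most `max_offset` (a hypothetical parameter, 1 by default), clamps at the ends of the key range, and uses `locality_prob` (0.7 in the description above) as the probability of a local step.

```python
import random
from typing import List


def generate_keys(num_requests: int, num_keys: int, alpha: float,
                  locality_prob: float, rng: random.Random, max_offset: int = 1) -> List[int]:
    """Generate a key sequence mixing Zipf popularity with local (neighboring) accesses."""
    keys = list(range(num_keys))
    # Zipf weights: key 0 is the most popular, then key 1, and so on.
    zipf_weights = [1.0 / (rank ** alpha) for rank in range(1, num_keys + 1)]
    offsets = [o for o in range(-max_offset, max_offset + 1) if o != 0]

    current = rng.choices(keys, weights=zipf_weights)[0]  # first access: a popular key
    sequence = [current]
    for _ in range(num_requests - 1):
        if rng.random() < locality_prob:
            # Local access: move to a nearby key, clamped to the valid key range.
            current = min(max(current + rng.choice(offsets), 0), num_keys - 1)
        else:
            # Popularity-driven access: draw a fresh key from the Zipf distribution.
            current = rng.choices(keys, weights=zipf_weights)[0]
        sequence.append(current)
    return sequence


requests = generate_keys(num_requests=1_000, num_keys=100, alpha=1.0,
                         locality_prob=0.7, rng=random.Random(42))
```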

To make the experiments more realistic, we generate both a **static** and a **dynamic dataset**. The first assumes that key popularities are fixed over time (i.e., the Zipf parameter remains constant), whereas the second simulates changes in key popularity over time (i.e., the Zipf parameter varies).

The final datasets are composed of the following columns:
- `id`: The ID of the current request.
- `delta_time`: The time elapsed between the current request and the previous one.
- `freq_last_10`: The relative frequency of the currently requested key in the last 10 accesses.
- `freq_last_100`: The relative frequency of the currently requested key in the last 100 accesses.
- `freq_last_1000`: The relative frequency of the currently requested key in the last 1000 accesses.
- `request`: The ID of the requested key.
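
The `freq_last_*` columns can be computed in a single pass with a sliding window of counts. The sketch below assumes that a window of size N covers the N most recent requests including the current one, and that the count is divided by the number of requests actually in the window (fewer than N at the start of the dataset); both conventions, and the function name, are assumptions.

```python
from collections import Counter, deque
from typing import Deque, Hashable, List, Sequence


def frequency_features(requests: Sequence[Hashable], window: int) -> List[float]:
    """Relative frequency of each request's key within the last `window` requests."""
    recent: Deque[Hashable] = deque()
    counts: Counter = Counter()
    features = []
    for key in requests:
        recent.append(key)
        counts[key] += 1
        if len(recent) > window:
            counts[recent.popleft()] -= 1  # drop the request that left the window
        features.append(counts[key] / len(recent))
    return features


print(frequency_features([1, 1, 2, 1], window=2))  # [1.0, 1.0, 0.5, 0.5]
```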

```python
from main import config_settings
from data_generation import data_generation

data_generation(config_settings)
```

## 🧹 Data Preprocessing

**Data preprocessing** prepares the data for the subsequent steps. Three operations are performed:
- **Duplicate removal**: Removes duplicate rows based on one or more columns. In this case, rows with duplicate `id` values are removed.
- **Missing value removal**: Removes all rows containing missing values.
- **Standardization**: Rescales one or more columns to zero mean and unit variance, so that features with very different value ranges become comparable. In this case, we standardize `id`, `delta_time`, `freq_last_10`, `freq_last_100`, and `freq_last_1000`.
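
The three operations can be sketched on a list of row dictionaries with the standard library only. This is an illustration rather than the project's `data_preprocessing` implementation: it keeps the first occurrence of each `id`, uses the population standard deviation (as scikit-learn's `StandardScaler` does), and treats a zero standard deviation as 1 to avoid dividing by zero on constant columns.

```python
import statistics
from typing import Dict, List, Sequence


def preprocess(rows: Sequence[Dict[str, object]], standardize_cols: Sequence[str]) -> List[Dict[str, object]]:
    """Drop duplicate ids, drop rows with missing values, then z-score the given columns."""
    # 1. Duplicate removal: keep the first row seen for each id.
    seen, deduped = set(), []
    for row in rows:
        if row["id"] not in seen:
            seen.add(row["id"])
            deduped.append(dict(row))  # copy so the caller's rows are not modified
    # 2. Missing value removal.
    cleaned = [r for r in deduped if all(v is not None for v in r.values())]
    # 3. Standardization: (x - mean) / std for each requested column.
    for col in standardize_cols:
        values = [r[col] for r in cleaned]
        mean = statistics.fmean(values)
        std = statistics.pstdev(values) or 1.0
        for r in cleaned:
            r[col] = (r[col] - mean) / std
    return cleaned
```

With pandas, the same steps roughly correspond to `df.drop_duplicates(subset=["id"])`, `df.dropna()`, and `(df[col] - df[col].mean()) / df[col].std(ddof=0)`.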

```python
from data_preprocessing import data_preprocessing

data_preprocessing(config_settings)
```

## 🧭 Validation

**Validation** aims at finding the **best hyperparameters** for training the final model. After defining the hyperparameter search space, we perform a **Grid Search** that explores all possible combinations. For each combination, we run a **Time Series Cross-Validation**, which avoids data leakage by preserving the temporal order of events (each fold is validated only on data that comes after its training data). **Early Stopping** is applied while training on each fold, stopping the process when the validation loss no longer improves by at least `delta` within `patience` epochs. Whenever a hyperparameter combination achieves the best average validation loss seen so far, it is saved as the new best. At the end, we obtain the best hyperparameters (i.e., those yielding the lowest average validation loss).

```python
from validation import validation

config_settings = validation(config_settings)
```

## 🧠 Training

After identifying the best hyperparameters, the **final model** is obtained by **training** with those optimal values. We reserve a percentage of the training set as a **validation set** and use a higher number of epochs than during validation, relying on **Early Stopping** so that the model trains for as many epochs as are useful while avoiding overfitting. Whenever the model achieves a new best validation loss, its weights are saved; at the end, the best trained model is saved.

```python
from training import training

training(config_settings)
```

## 🧪 Testing (Model Standalone)

After training the model, we **evaluate** it standalone on the testing set. The evaluation **metrics** computed are:
- **Average loss**: Weighted Cross-Entropy averaged over the testing set.
- **Average loss per class**: Weighted Cross-Entropy averaged separately for each class.
- **Class report**: Precision, Recall, and F1-score for each class.
- **Confusion matrix**: Summarizes the number of correct and incorrect predictions for each class.
- **Top-k accuracy**: The fraction of samples whose target key is among the k most probable predicted keys.
- **Kappa statistic**: Measures the agreement between predictions and targets, corrected for the agreement expected by chance (i.e., it compares the model with a random one).
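
Two of these metrics are easy to get subtly wrong, so here is a dependency-free sketch of them. It assumes that the Kappa statistic is Cohen's kappa and that predictions are given as one probability list per sample; the function names are hypothetical and not the project's `testing` API.

```python
from collections import Counter
from typing import Hashable, Sequence


def top_k_accuracy(probabilities: Sequence[Sequence[float]], targets: Sequence[int], k: int) -> float:
    """Fraction of samples whose target class is among the k highest-probability classes."""
    hits = 0
    for probs, target in zip(probabilities, targets):
        top_k = sorted(range(len(probs)), key=lambda c: probs[c], reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)


def cohen_kappa(y_true: Sequence[Hashable], y_pred: Sequence[Hashable]) -> float:
    """Observed agreement corrected for the agreement expected by chance."""
    n = len(y_true)
    observed = sum(t == p for t, p in zip(y_true, y_pred)) / n
    true_counts, pred_counts = Counter(y_true), Counter(y_pred)
    # Agreement expected if predictions were independent of the true labels.
    expected = sum(true_counts[c] * pred_counts[c] for c in true_counts) / (n * n)
    if expected == 1.0:
        return float("nan")  # undefined: both label lists contain one identical class
    return (observed - expected) / (1.0 - expected)
```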

```python
from testing import testing

avg_loss, avg_loss_per_class, metrics, cost_perc = testing(config_settings)
```