# Phase 2: Hybrid RL-Enhanced Cache with Dynamic Workloads
This notebook trains an RL agent on workloads with temporal access patterns to demonstrate the value of RL-guided eviction over a plain LRU baseline.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from collections import OrderedDict, defaultdict
from typing import List, Dict, Any, Optional, Tuple, Iterator
import gymnasium as gym
from gymnasium import spaces

In [None]:
class LRUCache:
    """Least-recently-used cache that tracks membership and hit/miss counts.

    Only keys are meaningful; the stored values are placeholders. A cache
    built with ``capacity <= 0`` never stores anything, so every access is a
    miss.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        # Keys are ordered from least- to most-recently used.
        self.cache = OrderedDict()
        self.hits = 0
        self.misses = 0

    def access(self, item: int) -> bool:
        """Record an access to ``item`` and return True on a hit.

        On a miss the item is inserted, evicting the least-recently-used
        entry first when the cache is full.
        """
        if item in self.cache:
            self.cache.move_to_end(item)
            self.hits += 1
            return True

        self.misses += 1
        if self.capacity <= 0:
            # Nothing can be stored; popitem() on the empty dict would raise.
            return False
        if len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)
        self.cache[item] = True
        return False

    def get_metrics(self) -> Dict[str, float]:
        """Return hit rate and raw counters; the hit rate is 0 before any access."""
        total = self.hits + self.misses
        return {
            'hit_rate': self.hits / total if total > 0 else 0,
            'hits': self.hits,
            'misses': self.misses,
            'total_accesses': total
        }

    def reset(self):
        """Drop all cached items and zero the counters."""
        self.cache.clear()
        self.hits = 0
        self.misses = 0

In [None]:
class HybridCache:
    """Cache whose eviction victim blends a base policy with external (RL) scores.

    ``access`` only records the request and reports hit/miss; inserting a
    missed item is a separate ``admit`` call so the caller can supply
    per-item eviction scores. In every score dictionary used here a HIGHER
    score means "evict this item first".
    """

    def __init__(self, capacity: int, base_policy: str = 'lru', rl_weight: float = 0.5,
                 miss_penalty: float = 10.0, hit_latency: float = 1.0):
        self.capacity = capacity
        self.base_policy = base_policy
        self.rl_weight = rl_weight
        self.miss_penalty = miss_penalty
        self.hit_latency = hit_latency

        # Keys ordered from least- to most-recently used; values are placeholders.
        self.cache = OrderedDict()
        self.access_frequency = defaultdict(int)
        self.last_access_time = {}
        self.current_time = 0

        self.hits = 0
        self.misses = 0
        self.evictions = 0
        self.rl_influenced_evictions = 0

    def access(self, item: int) -> Tuple[bool, float]:
        """Record a request for ``item``.

        Returns ``(is_hit, latency)`` where latency is ``hit_latency`` on a
        hit and ``miss_penalty`` on a miss. A miss does NOT insert the item.
        """
        self.current_time += 1
        self.last_access_time[item] = self.current_time
        self.access_frequency[item] += 1

        if item in self.cache:
            self.cache.move_to_end(item)
            self.hits += 1
            return True, self.hit_latency

        self.misses += 1
        return False, self.miss_penalty

    def admit(self, item: int, rl_scores: Optional[Dict[int, float]] = None) -> Optional[int]:
        """Insert ``item`` as most-recently used, evicting one entry if full.

        ``rl_scores`` maps cached items to eviction scores; it is used only
        when non-empty and ``rl_weight > 0``. Returns the evicted item, or
        ``None`` when nothing was evicted.
        """
        if item in self.cache:
            # Already resident: refresh recency only. Evicting here would
            # drop an entry for no reason (possibly ``item`` itself).
            self.cache.move_to_end(item)
            return None

        evicted = None
        if len(self.cache) >= self.capacity:
            if rl_scores and self.rl_weight > 0:
                evicted = self._evict_with_rl(rl_scores)
            else:
                evicted = self._base_eviction()

        self.cache[item] = True
        self.cache.move_to_end(item)
        return evicted

    def _evict_with_rl(self, rl_scores: Dict[int, float]) -> int:
        """Evict the item with the highest blended base/RL score and return it."""
        base_scores = self._get_base_scores()
        combined_scores = {}

        for item in self.cache:
            base_score = base_scores.get(item, 0.5)
            # Items the caller did not score get a neutral 0.5.
            rl_score = rl_scores.get(item, 0.5)
            combined_scores[item] = (1 - self.rl_weight) * base_score + self.rl_weight * rl_score

        base_victim = max(base_scores.items(), key=lambda x: x[1])[0]
        rl_victim = max(combined_scores.items(), key=lambda x: x[1])[0]

        # Count how often the RL scores changed the base policy's choice.
        if base_victim != rl_victim:
            self.rl_influenced_evictions += 1

        self.evictions += 1
        del self.cache[rl_victim]
        return rl_victim

    def _base_eviction(self) -> int:
        """Evict the item the base policy scores highest and return it."""
        base_scores = self._get_base_scores()
        victim = max(base_scores.items(), key=lambda x: x[1])[0]
        self.evictions += 1
        del self.cache[victim]
        return victim

    def _get_base_scores(self) -> Dict[int, float]:
        """Return base-policy eviction scores in [0, 1] for every cached item."""
        scores = {}
        cache_items = list(self.cache.keys())

        if self.base_policy == 'lru':
            # The oldest entry (index 0) scores 1.0; newer entries score lower.
            for idx, item in enumerate(cache_items):
                scores[item] = 1.0 - (idx / len(cache_items))
        elif self.base_policy == 'lfu':
            # ``or 1`` guards the division when no cached item was ever accessed.
            max_freq = max((self.access_frequency[item] for item in cache_items), default=1) or 1
            for item in cache_items:
                freq = self.access_frequency[item]
                scores[item] = 1.0 - (freq / max_freq)
        else:
            # Unknown policy: every item is equally evictable.
            for item in cache_items:
                scores[item] = 0.5

        return scores

    def get_metrics(self) -> Dict[str, float]:
        """Return hit/miss statistics; ratios are 0 when their denominator is 0."""
        total = self.hits + self.misses
        return {
            'hit_rate': self.hits / total if total > 0 else 0,
            'miss_rate': self.misses / total if total > 0 else 0,
            'hits': self.hits,
            'misses': self.misses,
            'evictions': self.evictions,
            'rl_influence_rate': self.rl_influenced_evictions / self.evictions if self.evictions > 0 else 0,
            'total_accesses': total,
            'avg_latency': (self.hits * self.hit_latency + self.misses * self.miss_penalty) / total if total > 0 else 0
        }

    def reset(self):
        """Empty the cache and clear all counters and access history."""
        self.cache.clear()
        self.access_frequency.clear()
        self.last_access_time.clear()
        self.hits = 0
        self.misses = 0
        self.evictions = 0
        self.rl_influenced_evictions = 0
        self.current_time = 0

    def __len__(self):
        return len(self.cache)

    def __contains__(self, item):
        return item in self.cache

In [None]:
class TemporalShiftWorkload:
    """Request stream whose popular item set rotates every ``phase_length`` requests.

    ``num_popular_sets`` disjoint slices of a shuffled item list are built up
    front. In each phase 80% of requests are Zipf-distributed over that
    phase's set and 20% are uniform over all items. If ``num_items`` is too
    small to fill every set, a phase with an empty set falls back to uniform
    draws.
    """

    def __init__(self, num_items: int, phase_length: int = 200, num_popular_sets: int = 3,
                 popular_set_size: int = 20, alpha: float = 1.5, seed: Optional[int] = None):
        self.num_items = num_items
        self.phase_length = phase_length
        self.num_popular_sets = num_popular_sets
        self.popular_set_size = popular_set_size
        self.alpha = alpha
        self.rng = np.random.RandomState(seed)
        self.request_count = 0

        self.popular_sets = []
        items = list(range(num_items))
        self.rng.shuffle(items)
        for i in range(num_popular_sets):
            start = i * popular_set_size
            end = start + popular_set_size
            # Slicing past the end of ``items`` yields a short or empty set.
            self.popular_sets.append(items[start:end])

    def _current_phase(self) -> int:
        """Return the index of the popular set active for the next request."""
        return (self.request_count // self.phase_length) % self.num_popular_sets

    def _next_item(self) -> int:
        """Draw one request and advance the request counter."""
        popular_items = self.popular_sets[self._current_phase()]

        # The uniform draw always happens first so the RNG stream is the
        # same whether or not the popular branch is taken.
        if self.rng.random() < 0.8 and popular_items:
            # Zipf ranks start at 1; clamp the heavy tail to the last rank.
            idx = min(self.rng.zipf(self.alpha), len(popular_items)) - 1
            item = popular_items[idx]
        else:
            item = self.rng.randint(0, self.num_items)

        self.request_count += 1
        return item

    def generate(self, num_requests: int) -> List[int]:
        """Return the next ``num_requests`` requests as a list."""
        return [self._next_item() for _ in range(num_requests)]

    def reset(self):
        """Restart the phase schedule; the RNG state is left untouched."""
        self.request_count = 0

    def __iter__(self) -> Iterator[int]:
        """Yield requests forever, sharing state with ``generate``."""
        while True:
            yield self._next_item()

In [None]:
class PopularitySpikeWorkload:
    """Zipf-like background traffic interrupted by bursts on one random item.

    Outside a spike, each request starts a new spike with probability
    ``spike_probability``. A spike lasts ``spike_duration`` requests; during
    it each request hits the spike item with probability ``spike_intensity``
    and otherwise follows the background distribution.
    """

    def __init__(self, num_items: int, alpha: float = 1.5, spike_probability: float = 0.01,
                 spike_duration: int = 50, spike_intensity: float = 0.9, seed: Optional[int] = None):
        self.num_items = num_items
        self.alpha = alpha
        self.spike_probability = spike_probability
        self.spike_duration = spike_duration
        self.spike_intensity = spike_intensity
        self.rng = np.random.RandomState(seed)

        self.current_spike_item = None
        self.spike_remaining = 0
        self.request_count = 0

    def _background_item(self) -> int:
        """Draw from the background distribution (Zipf sample folded into the id range)."""
        return self.rng.zipf(self.alpha) % self.num_items

    def _draw(self) -> int:
        """Produce a single request and update the spike and request counters."""
        # Possibly start a new spike; the RNG is consulted only while idle.
        if self.spike_remaining == 0 and self.rng.random() < self.spike_probability:
            self.current_spike_item = self.rng.randint(0, self.num_items)
            self.spike_remaining = self.spike_duration

        if self.spike_remaining > 0:
            hit_spike = self.rng.random() < self.spike_intensity
            chosen = self.current_spike_item if hit_spike else self._background_item()
            self.spike_remaining -= 1
        else:
            chosen = self._background_item()

        self.request_count += 1
        return chosen

    def generate(self, num_requests: int) -> List[int]:
        """Return the next ``num_requests`` requests as a list."""
        batch = []
        for _ in range(num_requests):
            batch.append(self._draw())
        return batch

    def reset(self):
        """Cancel any active spike and zero the request counter (RNG state is kept)."""
        self.current_spike_item = None
        self.spike_remaining = 0
        self.request_count = 0

    def __iter__(self) -> Iterator[int]:
        """Yield requests forever, sharing state with ``generate``."""
        while True:
            yield self._draw()

In [None]:
class TimeOfDayWorkload:
    """Cyclic workload: each cycle is split into ``num_cycles`` equal phases.

    The item id range is cut into ``num_cycles`` contiguous slices of
    ``num_items // num_cycles`` items, one per phase (leftover ids belong to
    no phase). Requests are Zipf-distributed over the active phase's slice;
    with probability ``phase_overlap`` the next phase's slice is appended to
    the candidate list for that request.
    """

    def __init__(self, num_items: int, cycle_length: int = 500, num_cycles: int = 4,
                 phase_overlap: float = 0.1, seed: Optional[int] = None):
        self.num_items = num_items
        self.cycle_length = cycle_length
        self.num_cycles = num_cycles
        self.phase_overlap = phase_overlap
        self.rng = np.random.RandomState(seed)
        self.request_count = 0

        self.cycle_items = []
        items_per_cycle = num_items // num_cycles
        for i in range(num_cycles):
            start = i * items_per_cycle
            end = start + items_per_cycle
            self.cycle_items.append(list(range(start, end)))

    def _next_item(self) -> int:
        """Draw one request and advance the request counter."""
        # Integer arithmetic keeps the phase exact. The float form
        # int((k / L) * n) can land one phase low (int(0.29 * 100) == 28).
        offset = self.request_count % self.cycle_length
        phase = (offset * self.num_cycles) // self.cycle_length

        main_items = self.cycle_items[phase]

        if self.rng.random() < self.phase_overlap:
            adjacent_phase = (phase + 1) % self.num_cycles
            main_items = main_items + self.cycle_items[adjacent_phase]

        if main_items:
            # Zipf ranks start at 1; clamp the heavy tail to the last rank.
            idx = min(self.rng.zipf(1.5), len(main_items)) - 1
            item = main_items[idx]
        else:
            # Phases are empty when num_items < num_cycles.
            item = self.rng.randint(0, self.num_items)

        self.request_count += 1
        return item

    def generate(self, num_requests: int) -> List[int]:
        """Return the next ``num_requests`` requests as a list."""
        return [self._next_item() for _ in range(num_requests)]

    def reset(self):
        """Rewind to the start of a cycle; the RNG state is left untouched."""
        self.request_count = 0

    def __iter__(self) -> Iterator[int]:
        """Yield requests forever, sharing state with ``generate``."""
        while True:
            yield self._next_item()

In [None]:
class HybridCacheEnv(gym.Env):
    """Gymnasium environment in which the agent scores cached items for eviction.

    Each step serves one request drawn from ``workload_generator`` against a
    ``HybridCache``. The action holds one value in [0, 1] per cache slot; on
    a miss those values are blended with the base policy to pick the victim
    (higher value = evicted first).

    Observation layout (indices 0-29 are written, so ``state_size`` must be
    at least 30; cached items are listed in the cache's own order):
        0      cache fill ratio
        1      current request id / num_items
        2      access count of the current request, capped at 10 and scaled
        3      1.0 if the current request is cached
        4      unique items among the last 20 requests / 20
        5      running hit rate
        6      1.0 if the current request is among the last 5 requests
        7-16   scaled access counts of the first 10 cached items
        17     RL influence rate
        18     1.0 if the current request is among the last 3 requests
        19     steps since the current request was last seen / 100 (1.0 if never)
        20     access trend of the current request
        21-25  steps since last access / 100 for the first 5 cached items
        26-29  access trend of the first 4 cached items
    """

    metadata = {'render_modes': ['human']}

    def __init__(self, cache_capacity: int, num_items: int, workload_generator,
                 episode_length: int = 1000, state_size: int = 30, base_policy: str = 'lru',
                 rl_weight: float = 0.5, alpha: float = 1.0, beta: float = 0.1, gamma: float = 0.05):
        """Build the environment.

        ``alpha``, ``beta`` and ``gamma`` weight the hit/miss, latency and
        miss ("bandwidth") terms of the reward respectively.
        """
        super().__init__()

        self.cache = HybridCache(
            capacity=cache_capacity,
            base_policy=base_policy,
            rl_weight=rl_weight,
            miss_penalty=10.0,
            hit_latency=1.0
        )
        self.num_items = num_items
        self.workload = workload_generator
        self.episode_length = episode_length
        self.state_size = state_size
        self.base_policy = base_policy

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        self.current_step = 0
        self.current_request = None

        self.action_space = spaces.Box(low=0.0, high=1.0, shape=(cache_capacity,), dtype=np.float32)
        self.observation_space = spaces.Box(low=0, high=1, shape=(state_size,), dtype=np.float32)

        self.request_history = []
        self.item_to_cache_idx = {}
        self.item_last_access_time = {}
        self.item_access_trend = {}
        self.item_recent_frequency = {}
        self.item_historical_frequency = {}

    def reset(self, seed: Optional[int] = None,
              options: Optional[Dict[str, Any]] = None) -> Tuple[np.ndarray, Dict]:
        """Clear the cache and temporal features, then fetch the first request.

        ``options`` is accepted for compatibility with the Gymnasium
        ``Env.reset`` signature and is otherwise ignored. The workload itself
        is not reset, so it continues from where the last episode stopped.
        """
        super().reset(seed=seed)
        self.cache.reset()
        self.current_step = 0
        self.request_history = []
        self.item_to_cache_idx = {}
        self.item_last_access_time = {}
        self.item_access_trend = {}
        self.item_recent_frequency = {}
        self.item_historical_frequency = {}
        self.current_request = next(iter(self.workload))
        return self._get_observation(), {}

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Serve the pending request using ``action`` as eviction scores.

        Returns ``(obs, reward, done, truncated, info)``. ``done`` becomes
        True once ``episode_length`` steps have run; ``truncated`` is always
        False.
        """
        is_hit, latency = self.cache.access(self.current_request)
        rl_eviction_scores = self._action_to_eviction_scores(action)

        evicted = None
        if not is_hit:
            evicted = self.cache.admit(self.current_request, rl_eviction_scores)

        hit_reward = self.alpha if is_hit else -self.alpha
        latency_penalty = -self.beta * latency / self.cache.miss_penalty
        bandwidth_penalty = 0 if is_hit else -self.gamma
        reward = hit_reward + latency_penalty + bandwidth_penalty

        if evicted is not None:
            # Extra penalties for evicting an item that looks likely to return:
            # it was requested within the last 10 steps, or its trend is rising.
            future_access_penalty = 0.0
            if evicted in self.request_history[-10:]:
                future_access_penalty = -0.2
            trend = self.item_access_trend.get(evicted, 0.0)
            if trend > 0.5:
                reward += -0.3
            reward += future_access_penalty

        self._update_temporal_features(self.current_request)
        self.request_history.append(self.current_request)
        # Keep only the most recent 100 requests.
        if len(self.request_history) > 100:
            self.request_history.pop(0)

        self.current_step += 1
        done = self.current_step >= self.episode_length
        truncated = False
        self.current_request = next(iter(self.workload))

        obs = self._get_observation()
        info = {
            'hit_rate': self.cache.get_metrics()['hit_rate'],
            'cache_size': len(self.cache),
            'evicted': evicted,
            'rl_influenced': self.cache.rl_influenced_evictions
        }
        return obs, reward, done, truncated, info

    def _action_to_eviction_scores(self, action: np.ndarray) -> Dict[Any, float]:
        """Map action entries onto cached items by their position in the cache.

        Items whose position lies beyond the action vector get a neutral 0.5.
        Also refreshes ``item_to_cache_idx``.
        """
        cache_items = list(self.cache.cache.keys())
        self.item_to_cache_idx = {item: idx for idx, item in enumerate(cache_items)}
        return {
            item: float(action[idx]) if idx < len(action) else 0.5
            for item, idx in self.item_to_cache_idx.items()
        }

    def _update_temporal_features(self, item: int):
        """Update recency/frequency bookkeeping for ``item``.

        Whenever ``current_step`` is a multiple of 100 (including step 0),
        every tracked item's trend is recomputed, capped at 1.0, and its
        recent-frequency counter is reset.
        """
        self.item_last_access_time[item] = self.current_step
        self.item_recent_frequency[item] = self.item_recent_frequency.get(item, 0) + 1
        self.item_historical_frequency[item] = self.item_historical_frequency.get(item, 0) + 1

        if self.current_step % 100 == 0:
            for it in list(self.item_recent_frequency.keys()):
                recent_freq = self.item_recent_frequency.get(it, 0)
                hist_freq = self.item_historical_frequency.get(it, 1)
                # Recent count relative to the average count per 100-step window.
                trend = recent_freq / max(hist_freq / (self.current_step / 100 + 1), 1)
                self.item_access_trend[it] = min(trend, 1.0)
                self.item_recent_frequency[it] = 0

    def _get_observation(self) -> np.ndarray:
        """Build the float32 feature vector described in the class docstring."""
        state = np.zeros(self.state_size, dtype=np.float32)
        state[0] = len(self.cache) / self.cache.capacity
        state[1] = self.current_request / self.num_items
        freq = self.cache.access_frequency.get(self.current_request, 0)
        state[2] = min(freq / 10.0, 1.0)
        if self.current_request in self.cache:
            state[3] = 1.0
        recent_unique = len(set(self.request_history[-20:]))
        state[4] = recent_unique / 20.0 if self.request_history else 0.0
        metrics = self.cache.get_metrics()
        state[5] = metrics['hit_rate']
        if len(self.request_history) >= 5:
            recent_requests = self.request_history[-5:]
            state[6] = 1.0 if self.current_request in recent_requests else 0.0
        cache_items = list(self.cache.cache.keys())
        for i, item in enumerate(cache_items[:10]):
            freq = self.cache.access_frequency.get(item, 0)
            state[7 + i] = min(freq / 10.0, 1.0)
        state[17] = metrics.get('rl_influence_rate', 0.0)
        if len(self.request_history) >= 3:
            pattern = [self.request_history[-3], self.request_history[-2], self.request_history[-1]]
            state[18] = 1.0 if self.current_request in pattern else 0.0
        if self.current_request in self.item_last_access_time:
            time_since_access = self.current_step - self.item_last_access_time[self.current_request]
            state[19] = min(time_since_access / 100.0, 1.0)
        else:
            state[19] = 1.0
        state[20] = self.item_access_trend.get(self.current_request, 0.0)
        for idx, item in enumerate(cache_items[:5]):
            if item in self.item_last_access_time:
                recency = self.current_step - self.item_last_access_time[item]
                state[21 + idx] = min(recency / 100.0, 1.0)
            else:
                state[21 + idx] = 1.0
        for idx, item in enumerate(cache_items[:4]):
            state[26 + idx] = self.item_access_trend.get(item, 0.0)
        return state

    def render(self):
        """Rendering is not implemented."""
        pass

In [None]:
class PriorityNetwork(torch.nn.Module):
    """Three-layer MLP mapping a state vector to ``action_size`` values in (0, 1)."""

    def __init__(self, state_size: int, action_size: int, hidden_size: int = 128):
        super().__init__()
        # Layer names are part of the checkpoint format (state_dict keys).
        self.fc1 = torch.nn.Linear(state_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, hidden_size)
        self.fc3 = torch.nn.Linear(hidden_size, action_size)

    def forward(self, x):
        """Apply two ReLU hidden layers followed by a sigmoid output layer."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden).sigmoid()

class ReplayBuffer:
    """Fixed-capacity ring buffer of ``(state, action, reward, next_state, done)`` tuples."""

    def __init__(self, capacity: int = 10000):
        self.capacity = capacity
        self.buffer = []
        # Index of the slot the next push writes to.
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest once the buffer is full."""
        slot = self.position
        if len(self.buffer) < self.capacity:
            # Still growing: make room for the new slot.
            self.buffer.append(None)
        self.buffer[slot] = (state, action, reward, next_state, done)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size: int):
        """Return ``batch_size`` distinct transitions as five parallel tuples."""
        chosen = np.random.choice(len(self.buffer), batch_size, replace=False)
        columns = zip(*(self.buffer[i] for i in chosen))
        states, actions, rewards, next_states, dones = columns
        return states, actions, rewards, next_states, dones

    def __len__(self):
        return len(self.buffer)

class PriorityAgent:
    """Epsilon-greedy agent that outputs one eviction priority per cache slot.

    Holds a policy network, a periodically synchronised target network, an
    Adam optimiser and a replay buffer.
    """

    def __init__(self, state_size: int, action_size: int, lr: float = 0.001, gamma: float = 0.99,
                 epsilon: float = 1.0, epsilon_min: float = 0.01, epsilon_decay: float = 0.995,
                 device: str = 'cpu', batch_size: int = 64, buffer_capacity: int = 10000):
        self.state_size = state_size
        self.action_size = action_size
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.device = device
        self.batch_size = batch_size

        self.policy_net = PriorityNetwork(state_size, action_size).to(device)
        self.target_net = PriorityNetwork(state_size, action_size).to(device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=lr)
        self.replay_buffer = ReplayBuffer(buffer_capacity)
        self.update_counter = 0
        # The target network is re-synced every this many training updates.
        self.target_update_freq = 10

    def select_action(self, state: np.ndarray) -> np.ndarray:
        """Return a float32 priority vector of length ``action_size``.

        With probability ``epsilon`` the vector is uniform random; otherwise
        it is the policy network's output for ``state``.
        """
        if np.random.random() < self.epsilon:
            return np.random.rand(self.action_size).astype(np.float32)

        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            priorities = self.policy_net(state_tensor)
            return priorities.cpu().numpy()[0]

    def store_transition(self, state, action, reward, next_state, done):
        """Append one transition to the replay buffer."""
        self.replay_buffer.push(state, action, reward, next_state, done)

    def train(self):
        """Run one optimisation step on a sampled batch and return the loss.

        Returns 0.0 without training until the buffer holds at least
        ``batch_size`` transitions. Epsilon is decayed once per training
        step, not once per episode.
        """
        if len(self.replay_buffer) < self.batch_size:
            return 0.0

        # The loss below never reads the stored actions, so they are not
        # converted to a tensor or moved to the device.
        states, _actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        states = torch.FloatTensor(np.array(states)).to(self.device)
        rewards = torch.FloatTensor(rewards).unsqueeze(1).to(self.device)
        next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
        dones = torch.FloatTensor(dones).unsqueeze(1).to(self.device)

        current_priorities = self.policy_net(states)
        next_priorities = self.target_net(next_states).detach()

        # TD-style target built on the MEAN priority of each state.
        target_priorities = rewards + (1 - dones) * self.gamma * next_priorities.mean(dim=1, keepdim=True)
        loss = torch.nn.functional.mse_loss(current_priorities.mean(dim=1, keepdim=True), target_priorities)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.update_counter += 1
        if self.update_counter % self.target_update_freq == 0:
            self.target_net.load_state_dict(self.policy_net.state_dict())

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        return loss.item()

    def save(self, path: str):
        """Write both networks, the optimiser state and epsilon to ``path``."""
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon
        }, path)

    def load(self, path: str):
        """Restore the state written by ``save``, mapping tensors to ``self.device``."""
        checkpoint = torch.load(path, map_location=self.device)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.epsilon = checkpoint['epsilon']

## Training Configuration

In [None]:
# --- Workload and cache policy ---
# 'temporal_shift', 'popularity_spike' or 'time_of_day'
WORKLOAD_TYPE = 'temporal_shift'
BASE_POLICY = 'lru'
RL_WEIGHT = 0.5

# --- Problem size ---
CACHE_SIZE = 100
NUM_ITEMS = 1000

# --- Training schedule ---
EPISODES = 1000
STEPS_PER_EPISODE = 1000
LEARNING_RATE = 0.001

# Use the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

## Initialize Workload and Environment

In [None]:
# Pick the workload class by name; unknown names fall back to the temporal-shift workload.
workload_cls = {
    'temporal_shift': TemporalShiftWorkload,
    'popularity_spike': PopularitySpikeWorkload,
    'time_of_day': TimeOfDayWorkload,
}.get(WORKLOAD_TYPE, TemporalShiftWorkload)
workload = workload_cls(num_items=NUM_ITEMS, seed=42)

# The environment wraps the hybrid cache and serves requests from the workload.
env = HybridCacheEnv(
    cache_capacity=CACHE_SIZE,
    num_items=NUM_ITEMS,
    workload_generator=workload,
    episode_length=STEPS_PER_EPISODE,
    state_size=30,
    base_policy=BASE_POLICY,
    rl_weight=RL_WEIGHT,
)

# One action value per cache slot, so the action size equals the cache size.
agent = PriorityAgent(
    state_size=30,
    action_size=CACHE_SIZE,
    lr=LEARNING_RATE,
    device=device,
)

## Training Loop

In [None]:
# Per-episode training history plus the best episode hit rate seen so far.
metrics = {
    'episode_rewards': [],
    'episode_hit_rates': [],
    'episode_rl_influence': [],
    'episode_avg_latency': [],
    'best_hit_rate': 0.0
}

for episode in range(EPISODES):
    state, _ = env.reset()
    total_reward = 0
    episode_hit_rates = []

    for step in range(STEPS_PER_EPISODE):
        action = agent.select_action(state)
        next_state, reward, done, truncated, info = env.step(action)
        episode_over = done or truncated

        # Learn online: store the transition, then take one optimisation step.
        agent.store_transition(state, action, reward, next_state, episode_over)
        agent.train()

        total_reward += reward
        episode_hit_rates.append(info['hit_rate'])
        state = next_state
        if episode_over:
            break

    cache_metrics = env.cache.get_metrics()
    # NOTE(review): this averages the RUNNING hit rate reported at every step,
    # which is not the same number as the end-of-episode hit rate
    # (cache_metrics['hit_rate']) -- confirm which one is intended.
    avg_hit_rate = np.mean(episode_hit_rates)

    metrics['episode_rewards'].append(total_reward)
    metrics['episode_hit_rates'].append(avg_hit_rate)
    metrics['episode_rl_influence'].append(cache_metrics.get('rl_influence_rate', 0.0))
    metrics['episode_avg_latency'].append(cache_metrics.get('avg_latency', 0.0))
    metrics['best_hit_rate'] = max(metrics['best_hit_rate'], avg_hit_rate)

    report_due = (episode + 1) % 50 == 0
    if report_due:
        print(f"Ep {episode + 1}/{EPISODES} | Reward: {total_reward:.2f} | "
              f"Hit Rate: {avg_hit_rate:.4f} | RL Influence: {cache_metrics.get('rl_influence_rate', 0.0):.4f}")

## Baseline Evaluation

In [None]:
# Build a FRESH workload of the configured type with the training seed.
# Reusing the training ``workload`` object would replay it from wherever
# training left it rather than from the start.
baseline_workload_cls = {
    'temporal_shift': TemporalShiftWorkload,
    'popularity_spike': PopularitySpikeWorkload,
    'time_of_day': TimeOfDayWorkload,
}.get(WORKLOAD_TYPE, TemporalShiftWorkload)
baseline_workload = baseline_workload_cls(num_items=NUM_ITEMS, seed=42)
baseline_cache = LRUCache(CACHE_SIZE)
baseline_requests = baseline_workload.generate(STEPS_PER_EPISODE)

for req in baseline_requests:
    baseline_cache.access(req)

baseline_metrics = baseline_cache.get_metrics()
baseline_hit_rate = baseline_metrics['hit_rate']

# NOTE(review): this compares the BEST training episode against a single
# baseline episode, which favours the RL number -- consider averaging both.
if baseline_hit_rate > 0:
    improvement = ((metrics['best_hit_rate'] - baseline_hit_rate) / baseline_hit_rate) * 100
else:
    # A relative improvement is undefined when the baseline never hits.
    improvement = float('nan')

print(f"\nBaseline LRU Hit Rate: {baseline_hit_rate:.4f}")
print(f"RL-Enhanced Hit Rate: {metrics['best_hit_rate']:.4f}")
print(f"Improvement: {improvement:.2f}%")

## Visualization

In [None]:
# 2x2 dashboard: hit rate, reward, RL influence, and gain over the LRU baseline.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
hit_ax, reward_ax, influence_ax, gain_ax = axes.ravel()

hit_ax.plot(metrics['episode_hit_rates'], alpha=0.6)
hit_ax.axhline(y=baseline_hit_rate, color='r', linestyle='--', label='Baseline LRU')
hit_ax.legend()

reward_ax.plot(metrics['episode_rewards'], alpha=0.6)

influence_ax.plot(metrics['episode_rl_influence'], alpha=0.6)

# Relative change of each episode's hit rate versus the baseline, in percent.
improvement_per_ep = [(hr - baseline_hit_rate) / baseline_hit_rate * 100 for hr in metrics['episode_hit_rates']]
gain_ax.plot(improvement_per_ep, alpha=0.6)
gain_ax.axhline(y=0, color='black', linestyle='-')

# Titles and axis labels shared by all four panels.
panel_labels = [
    (hit_ax, 'Hit Rate Over Training', 'Hit Rate'),
    (reward_ax, 'Episode Rewards', 'Total Reward'),
    (influence_ax, 'RL Influence Rate', 'RL Influence Rate'),
    (gain_ax, 'Improvement Over Baseline (%)', 'Improvement (%)'),
]
for panel, panel_title, panel_ylabel in panel_labels:
    panel.set_title(panel_title)
    panel.set_xlabel('Episode')
    panel.set_ylabel(panel_ylabel)
    panel.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{WORKLOAD_TYPE}_training_results.png', dpi=300)
plt.show()

## Save Model and Metrics

In [None]:
# Persist the trained agent and the training history next to the notebook.
model_path = f'dynamic_agent_{WORKLOAD_TYPE}_colab.pth'
metrics_path = f'metrics_{WORKLOAD_TYPE}_colab.json'
plot_path = f'{WORKLOAD_TYPE}_training_results.png'

agent.save(model_path)
with open(metrics_path, 'w') as f:
    json.dump(metrics, f, indent=2)

# Browser downloads only exist on Google Colab. Elsewhere the files are
# already on disk, so skip the download instead of failing on the import.
try:
    from google.colab import files
except ImportError:
    print(f"Not running on Colab; files saved locally: {model_path}, {metrics_path}, {plot_path}")
else:
    for artifact_path in (model_path, metrics_path, plot_path):
        files.download(artifact_path)