In [11]:
%load_ext autoreload
%autoreload 2

import os
import random
import sys
import time
from argparse import Namespace
from collections import Counter

import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import yaml
from liftoff import parse_opts

from experiment_src import *
sns.set_style("whitegrid", {'axes.grid' : False})

# Directory three levels above the notebook's working directory (presumably the
# repository root). It is put on sys.path, presumably so the `experiments` /
# `overfitting` imports below resolve.
root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath("."))))
# Guard keeps the cell idempotent: re-running it must not append duplicates.
if root_dir not in sys.path:
    sys.path.append(root_dir)

from experiment_src import (
    train_net_with_value_function_approximation,
    generate_random_policy_transitions,
    generate_transitions_observations
)
from experiments.experiment_utils import setup_logger, seed_everything
from overfitting.src.policy_iteration import random_policy_evaluation_q_stochastic
from overfitting.src.utils import (
    create_random_policy,
    extract_V_from_Q_for_stochastic_policy,
)
from overfitting.src.visualize import draw_simple_gridworld
from experiment_src import generate_train_test_split_with_valid_path, check_path_existence_to_any_terminal

import logging

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
def count_occurrences_and_compute_percentage(
    sampled_transitions_list, total_unique_transitions, N
):
    """Tally sampled transition identifiers and report how many are well covered.

    Args:
        sampled_transitions_list: iterable of hashable identifiers (e.g.
            transition indices), possibly containing repeats.
        total_unique_transitions: denominator of the percentage.
        N: minimum number of occurrences for an identifier to count.

    Returns:
        A ``(percentage, occurrences_count)`` pair. ``percentage`` is the share
        (0-100) of ``total_unique_transitions`` made up by identifiers seen at
        least ``N`` times. ``occurrences_count`` is a dict mapping each
        identifier to its count, in first-seen order.

    Raises:
        ZeroDivisionError: if ``total_unique_transitions`` is 0.
    """
    tally = {}
    for item in sampled_transitions_list:
        if item in tally:
            tally[item] += 1
        else:
            tally[item] = 1

    covered = len([n for n in tally.values() if n >= N])

    # Evaluated as (covered / total) * 100 so float results match exactly.
    return covered / total_unique_transitions * 100, tally

In [13]:
logger = logging.getLogger(__name__)

# Environment: grid dimensions, start cell, terminal cell(s).
rows, cols = 10, 10
start_state = (0, 0)
terminal_states = {(rows - 2, cols - 2): 1.0}  # cell -> value (presumably its reward)
p_success = 1  # presumably the probability an action moves as intended
seed = 3

# Data collection
num_steps = 40_000
min_samples = 20
# alternative: min_samples = 0

# Learning hyperparameters
alpha = 0.1  # Learning rate (NOTE: not read in the visible cells; the optimizer below hard-codes lr=0.001)
gamma = 0.9  # Discount factor
epsilon = 0.05  # Convergence criterion
tau = 100  # NOTE: the next cell reassigns tau = 1000 before it is used
batch_size = 32
train_max_iterations = 50
theta = 1e-6

# Seed the global RNGs (Python, NumPy, PyTorch). This makes the stochastic
# steps in later cells reproducible on Restart & Run All: network init,
# DataLoader shuffling, and transition sampling (presumably drawn from these
# RNGs). `make_env` and `train_test_split` already receive `seed` explicitly.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

env = make_env(rows, cols, start_state, p_success, terminal_states, seed)

# set() order is arbitrary; the visible cells only use len(states[0]) and len(actions).
states = list(set([s for s, _ in env.mdp.keys()]))
actions = list(set([a for _, a in env.mdp.keys()]))

In [14]:
# Override of the config-cell tau for this data-generation run.
tau = 1000

# One tuple per MDP (state, action) key: (state, action, *value[0]).
# Only the first entry of each value is kept. That is presumably the sole
# outcome while p_success == 1; TODO confirm before using a stochastic env.
transitions_list = [(key[0], key[1], *value[0]) for key, value in env.mdp.items()]
# NOTE(review): the train/val split is not used in the visible cells. The
# training data below is sampled from the full transitions_list.
transitions_train, transitions_val = train_test_split(
    transitions_list, test_size=0.2, random_state=seed
)

random_policy_transitions = generate_transitions_observations(
    transitions_list,
    num_steps,
    tau=tau,
    min_samples=min_samples,
)


### Training
input_size = len(states[0])  # length of a state tuple (presumably (row, col) -> 2)
output_size = len(actions)  # one Q-value output per action

# Initialize the DQN
qnet_random_policy = QNET(input_size, output_size)

In [15]:
# Neural fitted-Q training with inverse-frequency loss weighting.
#
# `random_policy_transitions` over-samples some transitions. Each sample's
# squared TD error is therefore multiplied by expected_frequency / count(t),
# where expected_frequency = N_total / N_unique is the count every unique
# transition would have under uniform sampling. Summed over the dataset the
# weights equal N_total, so the overall loss scale is preserved.
# (Inlined body of a former `train_net_with_neural_fitted_q_scaled_loss`.)

max_iterations = 10
transitions = random_policy_transitions

net = QNET(input_size, output_size)
net.train()

# Hashable key for every sample (done stored as int), in the order of `transitions`.
transitions_for_counting = [(s, a, ns, r, int(d)) for s, a, ns, r, d, _ in transitions]
if not transitions_for_counting:
    raise ValueError("no transitions to train on")
transition_counts = Counter(transitions_for_counting)

# Calculate expected frequency under uniform distribution
N_total = len(transitions)
N_unique = len(transition_counts)
expected_frequency = N_total / N_unique

# Compute scaling factor relative to uniform distribution
inverse_frequency_scaling = {
    t: expected_frequency / count for t, count in transition_counts.items()
}

# Per-sample weights, looked up later by dataset index. Keys must NOT be
# rebuilt from batch tensors: torch.Tensor hashes by identity, so a tuple of
# tensors never equals the Python-valued keys above. Every such lookup would
# silently fall back to a weight of 1.0.
sample_weights = torch.tensor(
    [inverse_frequency_scaling[key] for key in transitions_for_counting],
    dtype=torch.float32,
)


class _IndexedDataset(torch.utils.data.Dataset):
    """Wrap a map-style dataset so each item also yields its integer index."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        return (*self.base[index], index)


dataset = TransitionDataset(transitions)
# Weights are aligned by position. This assumes dataset[i] is built from
# transitions[i] — TODO confirm against TransitionDataset.
assert len(dataset) == N_total, "TransitionDataset must map 1:1 onto transitions"
# Yields (state, action, next_state, reward, done, index) batches.
dataloader = DataLoader(_IndexedDataset(dataset), batch_size=batch_size, shuffle=True)
optimizer = optim.Adam(net.parameters(), lr=0.001)
loss_fn = nn.MSELoss(reduction="none")
loss_record = []

total_loss = 0.0
for epoch in range(max_iterations):
    total_loss = 0.0
    for state, action, next_state, reward, done, index in dataloader:
        optimizer.zero_grad()
        q_values = net(state)

        # TD target: r for terminal samples, r + gamma * max_a' Q(s', a') otherwise.
        with torch.no_grad():
            max_next_q_values = net(next_state).max(dim=1).values.unsqueeze(1)
        reward_col = reward.unsqueeze(1)
        target_q_values = torch.where(
            done.unsqueeze(1), reward_col, reward_col + gamma * max_next_q_values
        )

        # Q(s, a) of the action taken. Assumes `action` holds integer action
        # indices of shape (batch,) — TODO confirm against TransitionDataset.
        q_taken = q_values.gather(1, action.long().view(-1, 1))
        per_sample_loss = loss_fn(q_taken, target_q_values)
        batch_weights = sample_weights[index].to(
            device=per_sample_loss.device, dtype=per_sample_loss.dtype
        ).view(-1, 1)
        loss = (per_sample_loss * batch_weights).mean()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    loss_record.append((epoch, total_loss, None))

logger.info(f"Exiting after {len(loss_record)} epochs with total loss {total_loss}")