In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
import os
import csv
import pandas as pd
from datetime import datetime


In [None]:

class DiabetesDataset(Dataset):
    def __init__(self, csv_file):
        # Load CSV data
        self.df = pd.read_csv(csv_file)
        
        self.df = self.df.ffill().bfill()
        # Extract state features: 8 dimensions
        # [glu, glu_d, glu_t, hr, hr_d, hr_t, iob, hour_norm]
        self.states = self.df[[
            "glu", "glu_d", "glu_t",
            "hr", "hr_d", "hr_t",
            "iob", "hour"
        ]].values.astype(np.float32)
        
        # Extract action features: 2 dimensions [basal, bol]
        self.actions = self.df[["basal", "bolus"]].values.astype(np.float32)
        
        # Extract done flags (1 at episode boundaries, 0 otherwise)
        self.dones = self.df["done"].values.astype(np.float32)
        
       
        target_glucose = 0.0  
        self.rewards = -np.abs(self.states[:, 0] - target_glucose)
        
        # Compute next_states using a vectorized roll
        self.next_states = np.roll(self.states, shift=-1, axis=0)
        
        # For transitions where the current step is an episode end,
        # set the next state to be the current state so that transitions do not cross episodes.
        self.next_states[self.dones == 1] = self.states[self.dones == 1]
        
        # Remove the final row since it doesn't have a valid next state
        self.states = self.states[:-1]
        self.actions = self.actions[:-1]
        self.rewards = self.rewards[:-1]
        self.next_states = self.next_states[:-1]
        self.dones = self.dones[:-1]

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        # Return a dictionary for the transition (s, a, r, s', done)
        return {
            "state":      torch.tensor(self.states[idx],      dtype=torch.float32),
            "action":     torch.tensor(self.actions[idx],     dtype=torch.float32),
            "reward":     torch.tensor(self.rewards[idx],     dtype=torch.float32),
            "next_state": torch.tensor(self.next_states[idx], dtype=torch.float32),
            "done":       torch.tensor(self.dones[idx],       dtype=torch.float32)
        }


In [11]:
def compute_reward_torch(glucose_next):
    """
    Compute RI-based reward in PyTorch.
    """
    glucose_next = torch.clamp(glucose_next, min=1e-6)
    log_term = torch.log(glucose_next) ** 1.084
    f = 1.509 * (log_term - 5.381)
    ri = 10 * f ** 2

    reward = -torch.clamp(ri / 100.0, 0, 1)
    reward[glucose_next <= 39.0] = -15.0
    return reward

In [None]:
import numpy as np

def compute_risk_index(glucose_mgdl):
    """
    Computes the blood glucose Risk Index (RI) for each glucose value.
    Applies the formula from Kovatchev et al. (2005).
    """
    glucose_mgdl = np.clip(glucose_mgdl, 1e-6, None)  # Prevent log(0)
    log_term = np.log(glucose_mgdl) ** 1.084
    f = 1.509 * (log_term - 5.381)
    ri = 10 * f ** 2
    return ri

def compute_reward(glucose_next):
    """
    Computes the reward for each next glucose level (g_t+1)
    using the RI-based reward function.
    
    Inputs:
        glucose_next: array of glucose values in mg/dL (NumPy array)
    
    Outputs:
        reward: array of shape (len(glucose_next),), float
    """
    reward = np.zeros_like(glucose_next, dtype=np.float32)

    # Severe hypoglycemia penalty
    severe_hypo = glucose_next <= 39
    reward[severe_hypo] = -15

    # For all others, compute -normalized RI
    safe = ~severe_hypo
    ri = compute_risk_index(glucose_next[safe])
    ri_normalized = np.clip(ri / 100.0, 0, 1)  # Normalize to [0, 1]
    reward[safe] = -ri_normalized  # In range [-1, 0]

    return reward


In [3]:
# State: [current_glucose, glucose_trend, heart_rate, heart_rate_trend, insulin_on_board]
state_dim = 8
action_dim = 2  # Continuous insulin dose (steps of 0.05 units)

# Hyperparameters
alpha = 0.2  # Entropy coefficient
cql_weight = 5.0  # CQL penalty strength
batch_size = 256
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
class SACCQL(nn.Module):
    def __init__(self):
        super().__init__()
        # Actor (policy) network
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Tanh()  # Output in [-1, 1] (rescale to insulin range)
        )
        
        # Critic networks (twin Q-functions)
        self.q1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        self.q2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        
    def act(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action = self.actor(state)
        return action.cpu().numpy()



In [5]:
dataset  = DiabetesDataset(csv_file="datasets/processed/563-test.csv")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [6]:
sample = dataset[0]
print("State:", sample["state"])
print("Action:", sample["action"])
print("Reward:", sample["reward"])
print("Next State:", sample["next_state"])
print("Done:", sample["done"])

State: tensor([ 1.5537, -0.1646, -0.5483, -0.0855, -1.3954, -0.5561, -1.2715,  0.0000])
Action: tensor([ 3.1836, -0.1611])
Reward: tensor(-1.5537)
Next State: tensor([ 1.5320, -0.1646, -0.5483, -1.0180, -1.3954, -0.5561, -1.2561,  0.0000])
Done: tensor(0.)


In [7]:
# Initialize networks and optimizers
model = SACCQL().to(device)
optimizer_actor = optim.Adam(model.actor.parameters(), lr=3e-4)
optimizer_critic = optim.Adam(list(model.q1.parameters()) + list(model.q2.parameters()), lr=3e-4)



In [10]:
# Enable anomaly detection
torch.autograd.set_detect_anomaly(True)

# Assume cql_weight, device, model, optimizer_critic, optimizer_actor, dataloader, etc. are defined.

print_interval = 100  # adjust as needed for debugging

for epoch in range(1000):
    for i, batch in enumerate(dataloader):
        states = batch["state"].to(device)
        actions = batch["action"].to(device)
        rewards = batch["reward"].to(device).unsqueeze(1)
        next_states = batch["next_state"].to(device)
        dones = batch["done"].to(device).unsqueeze(1)

        # -----------------------------
        # Critic Loss (CQL + TD error)
        # -----------------------------
        with torch.no_grad():
            next_actions = model.actor(next_states)
            q1_next = model.q1(torch.cat([next_states, next_actions], dim=1))
            q2_next = model.q2(torch.cat([next_states, next_actions], dim=1))
            q_next = torch.min(q1_next, q2_next)
            target_q = rewards + (1 - dones) * 0.99 * q_next

        # Current Q-values for state-action pairs
        current_q1 = model.q1(torch.cat([states, actions], dim=1))
        current_q2 = model.q2(torch.cat([states, actions], dim=1))

        # TD loss (MSE between current Q and target Q)
        mse_loss = nn.MSELoss()
        td_loss = mse_loss(current_q1, target_q) + mse_loss(current_q2, target_q)

        # -----------------------------
        # CQL Penalty
        # -----------------------------
        random_actions = torch.rand_like(actions) * 2 - 1  # random actions in [-1, 1]
        q1_rand = model.q1(torch.cat([states, random_actions], dim=1))
        q2_rand = model.q2(torch.cat([states, random_actions], dim=1))
        # Concatenate and clamp to avoid numerical issues
        q_rand = torch.cat([q1_rand, q2_rand], dim=1)
        q_rand_clamped = torch.clamp(q_rand, min=-100, max=100)
        logsumexp_val = torch.logsumexp(q_rand_clamped, dim=1).mean()
        cql_penalty = logsumexp_val - (current_q1.mean() + current_q2.mean()) / 2

        # Total critic loss
        critic_loss = td_loss + cql_weight * cql_penalty

        # -----------------------------
        # Actor Loss (Maximize Q-value)
        # -----------------------------

        pred_actions = model.actor(states)

        # Check for NaNs early
        if torch.isnan(pred_actions).any():
            print("❌ NaNs in actor outputs! Fix this before training.")
        pred_actions = torch.nan_to_num(pred_actions)

        # Evaluate predicted Q-values under current policy
        q1_pred = model.q1(torch.cat([states, pred_actions], dim=1))
        q2_pred = model.q2(torch.cat([states, pred_actions], dim=1))

        # Use safe min (not in-place)
        q_pred = torch.min(q1_pred, q2_pred)

        # Actor loss (maximize Q)
        actor_loss = -q_pred.mean()

        # Check for NaNs again
        if torch.isnan(actor_loss).any():
            print("❌ NaN in actor loss!")
        actor_loss = torch.tensor(0.0, requires_grad=True).to(device)

        # Backward pass (safe)
        optimizer_actor.zero_grad()
        actor_loss = actor_loss.clone()
        actor_loss.backward()
        optimizer_actor.step()


        # -----------------------------
        # Update Critic
        # -----------------------------
        optimizer_critic.zero_grad()
        critic_loss.backward()
        optimizer_critic.step()

        # -----------------------------
        # Update Actor
        # -----------------------------
        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()


        # -----------------------------
        # Print diagnostic stats every 'print_interval' iterations
        # -----------------------------
        if i % print_interval == 0:
            print(f"Epoch: {epoch}, Iteration: {i}")
            print(f"TD Loss: {td_loss.item():.4f}, CQL Penalty: {cql_penalty.item():.4f}")
            print(f"Critic Loss: {critic_loss.item():.4f}, Actor Loss: {actor_loss.item():.4f}")
            print(f"Current Q1: min={current_q1.min().item():.2f}, max={current_q1.max().item():.2f}, mean={current_q1.mean().item():.2f}")
            print(f"Current Q2: min={current_q2.min().item():.2f}, max={current_q2.max().item():.2f}, mean={current_q2.mean().item():.2f}")
            print("-"*50)


Epoch: 0, Iteration: 0
TD Loss: 2.0368, CQL Penalty: 0.6973
Critic Loss: 5.5235, Actor Loss: 0.0000
Current Q1: min=-0.43, max=0.03, mean=-0.11
Current Q2: min=-0.10, max=0.41, mean=0.05
--------------------------------------------------
Epoch: 1, Iteration: 0
TD Loss: 2.0320, CQL Penalty: 0.7304
Critic Loss: 5.6841, Actor Loss: 0.0000
Current Q1: min=-3.03, max=-0.46, mean=-1.06
Current Q2: min=-2.05, max=-0.33, mean=-0.79
--------------------------------------------------
Epoch: 2, Iteration: 0
TD Loss: 2.3584, CQL Penalty: 0.7881
Critic Loss: 6.2987, Actor Loss: 0.0000
Current Q1: min=-5.53, max=-0.93, mean=-2.12
Current Q2: min=-5.77, max=-0.75, mean=-1.77
--------------------------------------------------
Epoch: 3, Iteration: 0
TD Loss: 3.8145, CQL Penalty: 0.8716
Critic Loss: 8.1724, Actor Loss: 0.0000
Current Q1: min=-9.81, max=-1.87, mean=-3.54
Current Q2: min=-9.46, max=-1.53, mean=-3.38
--------------------------------------------------
Epoch: 4, Iteration: 0
TD Loss: 3.9231,

KeyboardInterrupt: 