In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
import os
import csv
import pandas as pd
from datetime import datetime
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm  # For progress bar


In [2]:
class DiabetesDataset(Dataset):
    """
    Offline-RL transition dataset built from a processed patient CSV.

    Item ``t`` is the transition ``(state_t, action_t, reward_t, next_state_t, done_t)``:
      * state / next_state: the 8 columns in ``STATE_COLUMNS``,
      * action: the 2 columns in ``ACTION_COLUMNS``,
      * reward: ``compute_reward_torch(glu_raw[t + 1]) / 15``,
      * done: the CSV ``done`` flag of row ``t``.

    When ``done[t] == 1`` the next state is replaced by the current state, so no
    transition pairs rows across an episode boundary (the TD target masks that
    value with ``1 - done`` anyway).

    NOTE: ``compute_reward_torch`` is defined in a separate cell; that cell must
    have been executed before a DiabetesDataset is constructed.
    """

    STATE_COLUMNS = ["glu", "glu_d", "glu_t", "hr", "hr_d", "hr_t", "iob", "hour"]
    ACTION_COLUMNS = ["basal", "bolus"]

    def __init__(self, csv_file):
        # Load, then fill gaps forwards and backwards (bfill covers leading NaNs)
        self.df = pd.read_csv(csv_file).ffill().bfill()

        states = self.df[self.STATE_COLUMNS].values.astype(np.float32)
        actions = self.df[self.ACTION_COLUMNS].values.astype(np.float32)
        dones = self.df["done"].values.astype(np.float32)

        # Reward for row k is computed from glu_raw at row k; it is shifted by one
        # step below so that transition t is scored by the glucose seen at t + 1.
        glucose = torch.tensor(self.df["glu_raw"].values, dtype=torch.float32)
        rewards = compute_reward_torch(glucose) / 15.0  # 15 = magnitude of the hypo penalty

        # next_states[t] = states[t + 1]; np.roll wraps the last row to the first,
        # and that wrapped row is removed by the slicing below.
        next_states = np.roll(states, shift=-1, axis=0)
        terminal = dones == 1
        next_states[terminal] = states[terminal]

        # Only the final row lacks a successor (no next state and no glu_raw[t + 1]),
        # so keep rows 0..N-2 and take rewards from rows 1..N-1.
        self.states = states[:-1]
        self.actions = actions[:-1]
        self.rewards = rewards[1:]
        self.next_states = next_states[:-1]
        self.dones = torch.tensor(dones[:-1], dtype=torch.float32)
        # NOTE(review): for a row with done == 1 the reward still uses glu_raw of the
        # following row, which presumably belongs to the next episode — confirm the
        # meaning of the `done` flag in preprocessing.

        # Sanity check
        n = len(self.states)
        assert all(len(arr) == n for arr in [self.actions, self.rewards, self.next_states, self.dones]), \
            f"Inconsistent lengths in dataset components: {n}"

    def __len__(self):
        """Number of transitions (one fewer than the number of CSV rows)."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return transition ``idx`` as a dict of float32 tensors."""
        return {
            "state":      torch.from_numpy(self.states[idx]).float(),
            "action":     torch.from_numpy(self.actions[idx]).float(),
            "reward":     self.rewards[idx].float(),
            "next_state": torch.from_numpy(self.next_states[idx]).float(),
            "done":       self.dones[idx]
        }


In [None]:
def compute_reward_torch(glucose_next, hypo_threshold=39.0, hypo_penalty=-15.0):
    """
    Risk-index (RI) based reward for the glucose reading that follows an action.

    Applies the blood-glucose risk transform
        f(g) = 1.509 * (ln(g) ** 1.084 - 5.381),   RI(g) = 10 * f(g) ** 2
    and returns ``-clamp(RI / 100, 0, 1)``, i.e. a reward in [-1, 0]. Readings at or
    below ``hypo_threshold`` instead receive ``hypo_penalty``.
    (These constants are the ones published for glucose in mg/dL — assumes the
    input is in mg/dL; TODO confirm for glu_raw.)

    Args:
        glucose_next: tensor of glucose readings (any shape).
        hypo_threshold: readings ``<=`` this get the fixed hypoglycaemia penalty.
        hypo_penalty: reward assigned to those readings.

    Returns:
        Tensor of rewards with the same shape as ``glucose_next``.
    """
    # Clamp at 1 so ln(g) >= 0: a negative base raised to the fractional power
    # 1.084 yields NaN. Such low readings are replaced by the hypo penalty anyway
    # (for any threshold >= 1).
    safe_glucose = torch.clamp(glucose_next, min=1.0)
    log_term = torch.log(safe_glucose) ** 1.084
    f = 1.509 * (log_term - 5.381)
    ri = 10 * f ** 2

    reward = -torch.clamp(ri / 100.0, 0, 1)
    # Out-of-place fill keeps the result independent of any autograd history.
    return reward.masked_fill(glucose_next <= hypo_threshold, hypo_penalty)



In [4]:
# Inspect the first transition of the dataset.
# NOTE(review): `dataset` is created in the configuration cell further down;
# run that cell first (or move this cell below it), otherwise a fresh kernel
# raises NameError here.
sample = dataset[0]
for label, key in (
    ("State:", "state"),
    ("Action:", "action"),
    ("Reward:", "reward"),
    ("Next State:", "next_state"),
    ("Done:", "done"),
):
    print(label, sample[key])

State: tensor([-0.9303, -0.4338,  1.1159,  1.8670, -3.1943, -1.2770, -1.8526,  0.0000])
Action: tensor([-1.0147, -0.0992])
Reward: tensor(-0.0003)
Next State: tensor([-0.9303, -0.4338,  1.1159,  1.8670, -3.1943, -1.2770, -1.8258,  0.0000])
Done: tensor(0.)


In [None]:
class SACCQL(nn.Module):
    """
    Actor plus twin critics (with target copies) for offline CQL-style training.

    The actor is deterministic: an MLP ending in tanh, so actions lie in [-1, 1];
    no entropy term is computed by this module.

    Args:
        state_dim: length of the state vector (default 8, the DiabetesDataset state columns).
        action_dim: length of the action vector (default 2: basal, bolus).
        hidden_dim: width of the two hidden layers of every network.
    """

    def __init__(self, state_dim=8, action_dim=2, hidden_dim=256):
        super().__init__()
        # Actor (policy) network; tanh keeps outputs in [-1, 1]
        self.actor = self._mlp(state_dim, action_dim, hidden_dim, out_act=nn.Tanh())

        # Critic networks (twin Q-functions) take concatenated (state, action)
        self.q1 = self._mlp(state_dim + action_dim, 1, hidden_dim)
        self.q2 = self._mlp(state_dim + action_dim, 1, hidden_dim)

        # Target critics: same architecture, start as exact copies, updated only
        # by update_targets(), never by an optimizer — so they need no gradients.
        self.q1_target = self._mlp(state_dim + action_dim, 1, hidden_dim)
        self.q2_target = self._mlp(state_dim + action_dim, 1, hidden_dim)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())
        for param in list(self.q1_target.parameters()) + list(self.q2_target.parameters()):
            param.requires_grad_(False)

    @staticmethod
    def _mlp(in_dim, out_dim, hidden_dim, out_act=None):
        """Two-hidden-layer ReLU MLP with an optional output activation."""
        layers = [
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        ]
        if out_act is not None:
            layers.append(out_act)
        return nn.Sequential(*layers)

    def act(self, state):
        """Return the policy action for ``state`` as a NumPy array (no gradients).

        The input is moved to whatever device the actor's parameters live on.
        """
        param_device = next(self.actor.parameters()).device
        with torch.no_grad():
            state_t = torch.as_tensor(state, dtype=torch.float32, device=param_device)
            action = self.actor(state_t)
        return action.cpu().numpy()

    def update_targets(self, tau=0.005):
        """Soft (Polyak) update: target <- tau * online + (1 - tau) * target."""
        with torch.no_grad():
            for online, target in ((self.q1, self.q1_target), (self.q2, self.q2_target)):
                for t_param, m_param in zip(target.parameters(), online.parameters()):
                    t_param.mul_(1 - tau).add_(m_param, alpha=tau)






In [None]:


def debug_tensor(tensor, name="", check_grad=False, threshold=1e6):
    """
    Prints diagnostic information about a tensor.

    Stats are computed on a detached copy (no autograd graph is built).
    Non-floating tensors (int, bool) are cast to float so their stats can be shown.
    The std of a single-element tensor is reported as 0 instead of NaN.

    Args:
        tensor (torch.Tensor): The tensor to check.
        name (str): Optional name for logging.
        check_grad (bool): Also check gradients if available.
        threshold (float): Warn if values exceed this.
    """
    try:
        values = tensor.detach()
        if not values.is_floating_point():
            values = values.float()
        t_min = values.min().item()
        t_max = values.max().item()
        t_mean = values.mean().item()
        # Unbiased std is undefined for one element; report 0 rather than NaN.
        t_std = values.std().item() if values.numel() > 1 else 0.0
    except Exception as e:
        # e.g. empty tensors, whose min()/max() raise
        print(f"⚠️ Could not extract stats for {name}: {e}")
        return

    print(f"🧪 [{name}] Shape: {tuple(tensor.shape)} | min: {t_min:.4f}, max: {t_max:.4f}, mean: {t_mean:.4f}, std: {t_std:.4f}")

    if torch.isnan(tensor).any():
        print(f"❌ NaNs detected in {name}")
    if torch.isinf(tensor).any():
        print(f"❌ Infs detected in {name}")
    if abs(t_min) > threshold or abs(t_max) > threshold:
        print(f"⚠️ Extreme values detected in {name}: values exceed ±{threshold}")

    # .grad is only populated on leaf tensors after backward()
    if check_grad and tensor.requires_grad and tensor.grad is not None:
        grad = tensor.grad
        print(f"🔁 [{name}.grad] norm: {grad.norm().item():.4f}")
        if torch.isnan(grad).any():
            print(f"❌ NaNs in gradient of {name}")
        if torch.isinf(grad).any():
            print(f"❌ Infs in gradient of {name}")



In [None]:
def compute_cql_penalty(q1, q2, num_action_samples=100, states=None, model=None):
    """
    CQL-style conservative penalty for a twin-critic model.

        penalty = mean_s[ logsumexp_a Q(s, a) ] - mean_(s,a)~batch[ Q(s, a) ]

    The logsumexp runs over ``num_action_samples`` uniform random actions in
    [-1, 1] plus the current policy action for each state. Q is the average of
    the two critics.

    Args:
        q1, q2: Q-values of the dataset actions, shape (batch_size, 1).
        num_action_samples: number of uniform random candidate actions per state (>= 0).
        states: batch of states, shape (batch_size, state_dim). Required.
        model: object exposing ``actor``, ``q1`` and ``q2`` callables (e.g. SACCQL). Required.

    Returns:
        Scalar tensor, differentiable w.r.t. the critics (the policy actions are
        sampled without gradients).

    Raises:
        ValueError: if ``states`` or ``model`` is missing, or ``num_action_samples`` < 0.
    """
    if states is None or model is None:
        raise ValueError("compute_cql_penalty requires both `states` and `model`")
    if num_action_samples < 0:
        raise ValueError("num_action_samples must be >= 0")

    # 1. Current-policy actions (treated as constants for the critic loss)
    with torch.no_grad():
        policy_actions = model.actor(states)  # (B, A)
    batch_size, action_dim = policy_actions.shape

    # 2. Candidate actions: uniform random samples + the policy action -> (N + 1, B, A)
    random_actions = torch.rand(
        num_action_samples, batch_size, action_dim,
        device=policy_actions.device, dtype=policy_actions.dtype,
    ) * 2 - 1
    candidates = torch.cat([random_actions, policy_actions.unsqueeze(0)], dim=0)

    # 3. Q-values for every (state, candidate) pair -> (N + 1, B, 1)
    tiled_states = states.unsqueeze(0).expand(candidates.shape[0], -1, -1)
    sa = torch.cat([tiled_states, candidates], dim=-1)
    q_all = 0.5 * (model.q1(sa) + model.q2(sa))

    # 4. Soft maximum over candidates per state, averaged over the batch
    logsumexp_val = torch.logsumexp(q_all, dim=0).mean()

    # 5. Push down sampled-action values, push up dataset-action values
    return logsumexp_val - 0.5 * (q1.mean() + q2.mean())

In [None]:

# Problem dimensions — must match the column lists used by DiabetesDataset
state_dim = 8   # glu, glu_d, glu_t, hr, hr_d, hr_t, iob, hour
action_dim = 2  # basal, bolus insulin

# Hyperparameters
alpha = 0.2  # Entropy coefficient — NOTE(review): not used by any training loop in this notebook
cql_weight = 1.0  # CQL penalty strength
batch_size = 256
SEED = 42  # seeds the torch and numpy RNGs
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed before building the DataLoader (shuffling) and the model (weight init)
# so that runs are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

dataset  = DiabetesDataset(csv_file="datasets/processed/559-train.csv")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


# Initialize networks and optimizers (the target critics are soft-updated via
# model.update_targets(), so they are deliberately not in any optimizer)
model = SACCQL().to(device)
optimizer_actor = optim.Adam(model.actor.parameters(), lr=3e-4)
optimizer_critic = optim.Adam(list(model.q1.parameters()) + list(model.q2.parameters()), lr=3e-4)

# Anomaly detection slows every backward pass considerably; enable it only while
# hunting NaNs or in-place autograd errors.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
print_interval = 100
print_interval = 100

In [None]:
# DEBUG VERSION of the training loop: the same algorithm as the reference loop,
# plus debug_tensor() diagnostics on inputs, Q-values and losses every batch.
# It used to fail: the actor loss was built from q1/q2 *before*
# optimizer_critic.step() modified those critic weights in place, so
# actor_loss.backward() raised "one of the variables needed for gradient
# computation has been modified by an inplace operation". The critic is now
# stepped before the actor's forward pass.
for epoch in range(1000):
    for i, batch in enumerate(dataloader):
        states = batch["state"].to(device)
        actions = batch["action"].to(device)
        rewards = batch["reward"].to(device).unsqueeze(1)
        next_states = batch["next_state"].to(device)
        dones = batch["done"].to(device).unsqueeze(1)

        # 🧪 Check raw inputs
        debug_tensor(states, "states")
        debug_tensor(actions, "actions")

        # -----------------------------
        # Critic Loss (TD error + CQL penalty)
        # -----------------------------
        with torch.no_grad():
            next_actions = model.actor(next_states)
            debug_tensor(next_actions, "next_actions")
            # Use TARGET critics for Q_next, taking the min of the pair
            next_sa = torch.cat([next_states, next_actions], dim=1)
            q_next = torch.min(model.q1_target(next_sa), model.q2_target(next_sa))
            target_q = rewards + (1 - dones) * 0.99 * q_next
            target_q = torch.clamp(target_q, min=-200, max=0)

        sa = torch.cat([states, actions], dim=1)
        current_q1 = model.q1(sa)
        current_q2 = model.q2(sa)

        debug_tensor(current_q1, "current_q1")
        debug_tensor(current_q2, "current_q2")

        td_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        # CQL penalty from one uniform random action per state
        random_actions = torch.rand_like(actions) * 2 - 1
        q1_rand = model.q1(torch.cat([states, random_actions], dim=1))
        q2_rand = model.q2(torch.cat([states, random_actions], dim=1))
        q_rand = torch.cat([q1_rand, q2_rand], dim=1)
        q_rand_clamped = torch.clamp(q_rand, min=-100, max=100)
        logsumexp_val = torch.logsumexp(q_rand_clamped, dim=1).mean()
        cql_penalty = logsumexp_val - (current_q1.mean() + current_q2.mean()) / 2
        critic_loss = td_loss + cql_weight * cql_penalty

        debug_tensor(critic_loss, "critic_loss")

        # -----------------------------
        # Update Critic — must happen before the actor forward pass (see header)
        # -----------------------------
        optimizer_critic.zero_grad()
        critic_loss.backward()
        optimizer_critic.step()

        # -----------------------------
        # Actor Loss, against the freshly updated critics
        # -----------------------------
        pred_actions = model.actor(states)

        # 🔒 Safe skip if NaNs
        if torch.isnan(pred_actions).any():
            print("❌ NaNs in actor output, skipping batch.")
            continue

        debug_tensor(pred_actions, "pred_actions")

        sa_actor = torch.cat([states, pred_actions], dim=1)
        q_pred = torch.min(model.q1(sa_actor), model.q2(sa_actor))
        actor_loss = -q_pred.mean()

        debug_tensor(actor_loss, "actor_loss")

        if torch.isnan(actor_loss).any():
            print("❌ Skipping batch due to NaN in actor loss.")
            continue

        # -----------------------------
        # Update Actor
        # -----------------------------
        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()

        # Soft-update target critics after both updates
        model.update_targets()

        # -----------------------------
        # Print Stats
        # -----------------------------
        if i % print_interval == 0:
            print(f"Epoch: {epoch}, Iteration: {i}")
            print(f"TD Loss: {td_loss.item():.4f}, CQL Penalty: {cql_penalty.item():.4f}")
            print(f"Critic Loss: {critic_loss.item():.4f}, Actor Loss: {actor_loss.item():.4f}")
            print(f"Current Q1: min={current_q1.min().item():.2f}, max={current_q1.max().item():.2f}, mean={current_q1.mean().item():.2f}")
            print(f"Current Q2: min={current_q2.min().item():.2f}, max={current_q2.max().item():.2f}, mean={current_q2.mean().item():.2f}")
            print("-" * 50)


In [None]:
# WORKING VERSION — reference training loop.
# Per batch: (1) build the TD target with the target critics, (2) update the
# critics with TD loss + CQL-style penalty, (3) run the actor forward pass on the
# *already updated* critics and update the actor, (4) soft-update the targets.
# The actor loss must be built after optimizer_critic.step(): step() modifies the
# critic weights in place, so an autograd graph recorded before it can no longer
# be back-propagated.
for epoch in range(1000):
    for i, batch in enumerate(dataloader):
        
        # Move batch data to device
        states = batch["state"].to(device)
        actions = batch["action"].to(device)
        rewards = batch["reward"].to(device).unsqueeze(1)
        next_states = batch["next_state"].to(device)
        dones = batch["done"].to(device).unsqueeze(1)

        # -----------------------------
        # 1. Compute target Q for Critic
        # -----------------------------
        with torch.no_grad():
            next_actions = model.actor(next_states)
            q1_next = model.q1_target(torch.cat([next_states, next_actions], dim=1))
            q2_next = model.q2_target(torch.cat([next_states, next_actions], dim=1))
            q_next = torch.min(q1_next, q2_next)
            # Discount 0.99; (1 - dones) drops the bootstrap term on terminal rows
            target_q = rewards + (1 - dones) * 0.99 * q_next
            # Out-of-place clamp. Upper bound 0 is consistent with the
            # non-positive rewards produced by compute_reward_torch.
            target_q = target_q.clamp(-200, 0)

        # -----------------------------
        # 2. Critic forward pass
        # -----------------------------
        sa = torch.cat([states, actions], dim=1)
        current_q1 = model.q1(sa)
        current_q2 = model.q2(sa)

        # 2a. TD Loss (MSE between current Q and target Q)
        td_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        # 2b. CQL penalty
        # NOTE(review): this draws ONE uniform random action per state and takes
        # logsumexp across the two critics' values for it — a coarse stand-in for
        # CQL's logsumexp over many sampled actions per critic.
        # NOTE(review): dataset actions look standardized (the sample cell shows
        # -1.0147), so they can fall outside the [-1, 1] range used here and by
        # the actor's tanh — confirm the action scaling.
        random_actions = (torch.rand_like(actions) * 2) - 1
        q1_rand = model.q1(torch.cat([states, random_actions], dim=1))
        q2_rand = model.q2(torch.cat([states, random_actions], dim=1))
        q_rand = torch.cat([q1_rand, q2_rand], dim=1)
        q_rand_clamped = q_rand.clamp(-100, 100)  # not in-place; bounds the logsumexp inputs
        logsumexp_val = torch.logsumexp(q_rand_clamped, dim=1).mean()
        cql_penalty = logsumexp_val - 0.5*(current_q1.mean() + current_q2.mean())

        critic_loss = td_loss + cql_weight * cql_penalty

        # -----------------------------
        # 3. Update Critic
        # -----------------------------
        optimizer_critic.zero_grad()
        critic_loss.backward()
        optimizer_critic.step()

        # -----------------------------
        # 4. Actor forward pass
        # -----------------------------
        # For actor training, we want the Q-value for the action predicted by the current actor
        pred_actions = model.actor(states)  # no .clone() or .detach()
        # Skipping here also skips update_targets() and the stats print for this batch
        if torch.isnan(pred_actions).any():
            print("NaNs in actor output, skipping batch.")
            continue

        sa_actor = torch.cat([states, pred_actions], dim=1)
        q1_pred = model.q1(sa_actor)
        q2_pred = model.q2(sa_actor)
        q_pred = torch.min(q1_pred, q2_pred)

        actor_loss = -q_pred.mean()

        # -----------------------------
        # 5. Update Actor
        # -----------------------------
        # backward() also writes gradients into the q1/q2 parameters; those are
        # cleared by optimizer_critic.zero_grad() before the next critic backward,
        # so only optimizer_actor uses them (and it only holds actor parameters).
        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()

        # -----------------------------
        # 6. Soft-update target networks
        # -----------------------------
        model.update_targets()

        # Print stats
        if i % print_interval == 0:
            print(f"Epoch: {epoch}, Iteration: {i}")
            print(f"TD Loss: {td_loss.item():.4f}, CQL Penalty: {cql_penalty.item():.4f}")
            print(f"Critic Loss: {critic_loss.item():.4f}, Actor Loss: {actor_loss.item():.4f}")
            print("-"*50)


In [7]:
# Training loop with windowed-average logging to TensorBoard and CSV.
# Every `print_interval` batches, the metrics accumulated since the last log are
# averaged, written out and reset. Metrics are also reset at the start of each
# epoch, so batches after the last logged index of an epoch are not logged.
writer = SummaryWriter()
csv_file = 'training_stats.csv'

# Write CSV header (mode 'w' truncates any previous run's file)
with open(csv_file, 'w', newline='') as f:
    csv.writer(f).writerow(['Epoch', 'Iteration', 'TD Loss', 'CQL Penalty',
                            'Critic Loss', 'Actor Loss', 'Q1 Value', 'Q2 Value'])

try:
    for epoch in tqdm(range(1000), desc="Training Progress"):
        metrics = {'td': 0.0, 'cql': 0.0, 'critic': 0.0, 'actor': 0.0,
                   'q1': 0.0, 'q2': 0.0, 'count': 0}

        for i, batch in enumerate(dataloader):
            states = batch["state"].to(device)
            actions = batch["action"].to(device)
            rewards = batch["reward"].to(device).unsqueeze(1)
            next_states = batch["next_state"].to(device)
            dones = batch["done"].to(device).unsqueeze(1)

            # --- Critic target (no gradients through the target networks) ---
            with torch.no_grad():
                next_actions = model.actor(next_states)
                q1_next = model.q1_target(torch.cat([next_states, next_actions], dim=1))
                q2_next = model.q2_target(torch.cat([next_states, next_actions], dim=1))
                q_next = torch.min(q1_next, q2_next)
                target_q = rewards + (1 - dones) * 0.99 * q_next
                target_q = target_q.clamp(-200, 0)

            sa = torch.cat([states, actions], dim=1)
            current_q1 = model.q1(sa)
            current_q2 = model.q2(sa)

            # TD Loss
            td_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            # CQL Penalty (one uniform random action per state)
            random_actions = (torch.rand_like(actions) * 2) - 1
            q1_rand = model.q1(torch.cat([states, random_actions], dim=1))
            q2_rand = model.q2(torch.cat([states, random_actions], dim=1))
            q_rand = torch.cat([q1_rand, q2_rand], dim=1).clamp(-100, 100)
            logsumexp_val = torch.logsumexp(q_rand, dim=1).mean()
            cql_penalty = logsumexp_val - 0.5*(current_q1.mean() + current_q2.mean())

            critic_loss = td_loss + cql_weight * cql_penalty

            # Critic update — before the actor forward pass, which must see the
            # updated critic weights (in-place step() would break an older graph)
            optimizer_critic.zero_grad()
            critic_loss.backward()
            optimizer_critic.step()

            # --- Actor update ---
            pred_actions = model.actor(states)
            actor_has_nan = bool(torch.isnan(pred_actions).any())
            if actor_has_nan:
                print("NaNs detected, skipping actor update")
                # Placeholders so the metric bookkeeping below stays defined
                q1_pred = torch.tensor(0.0)
                q2_pred = torch.tensor(0.0)
                actor_loss = torch.tensor(0.0)
            else:
                sa_actor = torch.cat([states, pred_actions], dim=1)
                q1_pred = model.q1(sa_actor)
                q2_pred = model.q2(sa_actor)
                actor_loss = -torch.min(q1_pred, q2_pred).mean()

                optimizer_actor.zero_grad()
                actor_loss.backward()
                optimizer_actor.step()

            # Target network updates
            model.update_targets()

            # --- Metrics collection ---
            metrics['td'] += td_loss.item()
            metrics['cql'] += cql_penalty.item()
            metrics['critic'] += critic_loss.item()
            metrics['actor'] += actor_loss.item()  # 0.0 when the actor update was skipped
            metrics['q1'] += q1_pred.mean().item()  # Q of the policy action, not the dataset action
            metrics['q2'] += q2_pred.mean().item()
            metrics['count'] += 1

            # --- Logging of averages over the current window ---
            if i % print_interval == 0 and metrics['count'] > 0:
                avg = {k: metrics[k] / metrics['count']
                       for k in ('td', 'cql', 'critic', 'actor', 'q1', 'q2')}

                global_step = epoch * len(dataloader) + i
                writer.add_scalar('Loss/TD', avg['td'], global_step)
                writer.add_scalar('Loss/CQL', avg['cql'], global_step)
                writer.add_scalar('Loss/Critic', avg['critic'], global_step)
                writer.add_scalar('Loss/Actor', avg['actor'], global_step)
                writer.add_scalar('Q_Values/Q1', avg['q1'], global_step)
                writer.add_scalar('Q_Values/Q2', avg['q2'], global_step)

                with open(csv_file, 'a', newline='') as f:
                    csv.writer(f).writerow([epoch, i, avg['td'], avg['cql'], avg['critic'],
                                            avg['actor'], avg['q1'], avg['q2']])

                # Reset metrics for the next window
                metrics = {k: 0.0 for k in metrics}
                metrics['count'] = 0
finally:
    # Flush buffered events and close the event file, even if training is interrupted.
    writer.close()

Training Progress: 100%|██████████| 1000/1000 [41:50<00:00,  2.51s/it]


In [None]:
·tensorboard --logdir=runs

SyntaxError: cannot assign to expression here. Maybe you meant '==' instead of '='? (3224537314.py, line 1)