In [20]:
import csv
import os
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter


In [8]:


# Observation layout (5 features, in this order):
#   current_glucose, glucose_trend, heart_rate, heart_rate_trend, insulin_on_board
state_dim = 5
# One continuous control value: the insulin dose (steps of 0.05 units).
action_dim = 1

# --- Hyperparameters ---
alpha = 0.2         # entropy coefficient
cql_weight = 5.0    # strength of the CQL penalty
batch_size = 256    # samples per mini-batch

# Use the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"



In [9]:

class SACCQL(nn.Module):
    """Deterministic tanh actor with twin Q-critics and target copies of the critics.

    The training loops in this notebook read ``q1_target`` / ``q2_target`` and call
    ``update_targets()``; this class provides them alongside the online networks.

    Args:
        obs_dim: Size of a state vector. Defaults to the notebook-level ``state_dim``.
        act_dim: Size of an action vector. Defaults to the notebook-level ``action_dim``.
        hidden_dim: Width of the two hidden layers of every network.
        tau: Default interpolation factor used by :meth:`update_targets`.

    Attributes:
        actor: Maps a state to an action squashed into [-1, 1] by ``Tanh``.
        q1, q2: Critics mapping a concatenated (state, action) vector to one value.
        q1_target, q2_target: Copies of the critics that start with identical
            weights, never receive gradients and only change through
            :meth:`update_targets`.
    """

    def __init__(self, obs_dim=None, act_dim=None, hidden_dim=256, tau=0.005):
        super().__init__()
        # Fall back to the notebook configuration only when no size is supplied.
        obs_dim = state_dim if obs_dim is None else obs_dim
        act_dim = action_dim if act_dim is None else act_dim
        self.tau = tau

        # Actor (policy) network
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, act_dim),
            nn.Tanh()  # Output in [-1, 1] (rescale to insulin range)
        )

        # Critic networks (twin Q-functions)
        critic_in = obs_dim + act_dim
        self.q1 = self._make_critic(critic_in, hidden_dim)
        self.q2 = self._make_critic(critic_in, hidden_dim)

        # Target critics are built last so the initialisation of the online
        # networks above consumes random numbers in the same order as before.
        self.q1_target = self._make_target(self.q1, critic_in, hidden_dim)
        self.q2_target = self._make_target(self.q2, critic_in, hidden_dim)

    @staticmethod
    def _make_critic(in_dim, hidden_dim):
        """Build one Q-network: (state, action) vector -> scalar value."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    @staticmethod
    def _make_target(online, in_dim, hidden_dim):
        """Return a gradient-free critic initialised with ``online``'s weights."""
        target = SACCQL._make_critic(in_dim, hidden_dim)
        target.load_state_dict(online.state_dict())
        for param in target.parameters():
            param.requires_grad_(False)
        return target

    @torch.no_grad()
    def update_targets(self, tau=None):
        """Move each target critic towards its online critic.

        Applies ``target = (1 - tau) * target + tau * online`` to every parameter.

        Args:
            tau: Interpolation factor in [0, 1]; ``None`` uses ``self.tau``.
                ``tau=1`` copies the online weights outright.
        """
        tau = self.tau if tau is None else tau
        pairs = ((self.q1, self.q1_target), (self.q2, self.q2_target))
        for online, target in pairs:
            for src, dst in zip(online.parameters(), target.parameters()):
                dst.mul_(1.0 - tau).add_(src, alpha=tau)

    def act(self, state):
        """Return the actor's action for ``state`` as a NumPy array.

        Args:
            state: Array-like (list, NumPy array or tensor) holding one state or a
                batch of states; it is converted to float32.

        Returns:
            NumPy array of actions in [-1, 1], computed without tracking gradients.
        """
        with torch.no_grad():
            # Place the input on whatever device the actor's weights live on, so the
            # call works whether or not the model was moved to the global `device`.
            weights = next(self.actor.parameters())
            state = torch.as_tensor(state, dtype=torch.float32, device=weights.device)
            action = self.actor(state)
        return action.cpu().numpy()




In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset

class DiabetesDataset(Dataset):
    """Transition dataset built from six parallel time series.

    Every sample is a dict with the keys ``state``, ``action``, ``reward``,
    ``next_state`` and ``done``. All values are float32 so that collated batches
    match the float32 weights of ``nn.Linear`` layers, and ``action`` is a
    length-1 vector so that a batch can be concatenated with ``state`` along
    ``dim=1``. ``reward`` and ``done`` stay scalars (batched shape ``(B,)``).

    NOTE(review): a later cell defines another class with this same name that
    returns tuples; whichever cell ran last decides what ``DiabetesDataset`` means.

    Args:
        sequence_length: Accepted for interface compatibility; not used by this
            implementation.
    """

    def __init__(self, sequence_length=1):
        num_timesteps = 10000

        # Placeholder series drawn from a standard normal; replace with the real
        # recordings, each of shape (num_timesteps,).
        self.glucose = np.random.randn(num_timesteps)
        self.glucose_deriv = np.random.randn(num_timesteps)
        self.heart_rate = np.random.randn(num_timesteps)
        self.hr_deriv = np.random.randn(num_timesteps)
        self.iob = np.random.randn(num_timesteps)
        self.insulin_doses = np.random.randn(num_timesteps)

        # Reward: negative absolute deviation from the 100 mg/dL target.
        self.rewards = (-np.abs(self.glucose - 100)).astype(np.float32)

        # States: the five features side by side -> (num_timesteps, 5).
        self.states = np.column_stack([
            self.glucose,
            self.glucose_deriv,
            self.heart_rate,
            self.hr_deriv,
            self.iob
        ]).astype(np.float32)

        # Actions as a (num_timesteps, 1) column so each sample has shape (1,).
        self.actions = self.insulin_doses.astype(np.float32).reshape(-1, 1)

        # Next states: states shifted by one timestep. np.roll wraps the first row
        # round to the last position; __len__ keeps that row out of reach.
        self.next_states = np.roll(self.states, shift=-1, axis=0)

        # "Done" flags (0 = episode continues, 1 = episode ends).
        # Only the final timestep is flagged, and that index is excluded by
        # __len__, so served samples all have done == 0 (modify for real data).
        self.dones = np.zeros(len(self.states), dtype=np.float32)
        self.dones[-1] = 1

    def __len__(self):
        """Number of usable transitions: the last timestep has no successor."""
        return len(self.states) - 1

    def __getitem__(self, idx):
        """Return the transition at ``idx`` as a dict of float32 NumPy values.

        Raises:
            IndexError: If an integer ``idx`` lies outside ``[-len(self), len(self))``.
                Without this check, index ``len(self)`` would return the wrap-around
                row produced by ``np.roll`` and plain iteration would yield it.
        """
        if isinstance(idx, (int, np.integer)):
            size = len(self)
            if idx < 0:
                idx += size
            if not 0 <= idx < size:
                raise IndexError("DiabetesDataset index out of range")
        return {
            "state": self.states[idx],            # Shape: (5,)
            "action": self.actions[idx],          # Shape: (1,)
            "reward": self.rewards[idx],          # Scalar
            "next_state": self.next_states[idx],  # Shape: (5,)
            "done": self.dones[idx]               # Scalar
        }

In [None]:
# Load your historical dataset (replace with your data)
class DiabetesDataset(torch.utils.data.Dataset):
    """Placeholder transition dataset returning ``(state, action, reward, next_state, done)``.

    All arrays are stored as float32 so that collated batches are FloatTensors,
    matching the float32 weights of the networks they are fed to.

    NOTE(review): this class shares its name with the dict-returning dataset defined
    in an earlier cell; whichever cell ran last decides what ``DiabetesDataset`` means.

    Args:
        num_samples: Number of random transitions to generate.
        obs_dim: Size of a state vector. Defaults to the notebook-level ``state_dim``.
        act_dim: Size of an action vector. Defaults to the notebook-level ``action_dim``.
    """

    def __init__(self, num_samples=10000, obs_dim=None, act_dim=None):
        # Fall back to the notebook configuration only when no size is supplied.
        obs_dim = state_dim if obs_dim is None else obs_dim
        act_dim = action_dim if act_dim is None else act_dim

        # Random stand-ins, drawn in the same order as before; replace with real data.
        self.states = np.random.randn(num_samples, obs_dim).astype(np.float32)
        # Insulin doses, scaled by 0.05.
        self.actions = (np.random.randn(num_samples, act_dim) * 0.05).astype(np.float32)
        # Reward = f(glucose) in the real dataset; random here.
        self.rewards = np.random.randn(num_samples).astype(np.float32)
        self.next_states = np.random.randn(num_samples, obs_dim).astype(np.float32)
        # No sample is marked as terminal.
        self.dones = np.zeros(num_samples, dtype=np.float32)

    def __len__(self):
        """Number of stored transitions."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return the transition at ``idx`` as a 5-tuple of float32 NumPy values."""
        return (
            self.states[idx], self.actions[idx], self.rewards[idx],
            self.next_states[idx], self.dones[idx]
        )

# Data pipeline: shuffled mini-batches of `batch_size` samples.
dataset = DiabetesDataset()
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Model on the selected device, with one Adam optimizer for the actor and a
# single Adam optimizer shared by both critics (same learning rate for both).
model = SACCQL().to(device)
optimizer_actor = optim.Adam(model.actor.parameters(), lr=3e-4)
optimizer_critic = optim.Adam(
    [*model.q1.parameters(), *model.q2.parameters()],
    lr=3e-4,
)

# Training loop
# Expects `dataloader` to yield (states, actions, rewards, next_states, dones) tuples.
num_epochs = 1000
gamma = 0.99            # discount factor used in the TD target
mse = nn.MSELoss()      # built once instead of once per loss term

for epoch in range(num_epochs):
    for states, actions, rewards, next_states, dones in dataloader:
        # The DataLoader already yields tensors; cast to float32 (NumPy float64 data
        # collates to DoubleTensor) and give per-sample scalars a column shape.
        states = states.to(device=device, dtype=torch.float32)
        batch_len = states.shape[0]
        actions = actions.to(device=device, dtype=torch.float32).reshape(batch_len, -1)
        rewards = rewards.to(device=device, dtype=torch.float32).reshape(batch_len, 1)
        next_states = next_states.to(device=device, dtype=torch.float32)
        dones = dones.to(device=device, dtype=torch.float32).reshape(batch_len, 1)

        # ---- Critic loss (CQL + TD error) ----
        with torch.no_grad():
            # Bootstrapped target from the smaller of the two critic estimates.
            next_sa = torch.cat([next_states, model.actor(next_states)], 1)
            q_next = torch.min(model.q1(next_sa), model.q2(next_sa))
            target_q = rewards + (1 - dones) * gamma * q_next

        # Current Q-values for the dataset actions
        sa = torch.cat([states, actions], 1)
        current_q1 = model.q1(sa)
        current_q2 = model.q2(sa)

        # TD loss
        td_loss = mse(current_q1, target_q) + mse(current_q2, target_q)

        # CQL penalty: logsumexp(Q(s, a')) - Q(s, a)
        # NOTE(review): the logsumexp runs over the two critics' values for ONE
        # random action per state, not over several sampled actions.
        random_actions = torch.rand_like(actions) * 2 - 1  # Random actions in [-1, 1]
        rand_sa = torch.cat([states, random_actions], 1)
        q_rand = torch.cat([model.q1(rand_sa), model.q2(rand_sa)], 1)
        cql_penalty = (
            torch.logsumexp(q_rand, dim=1).mean() -
            (current_q1.mean() + current_q2.mean()) / 2
        )

        # Total critic loss
        critic_loss = td_loss + cql_weight * cql_penalty

        # ---- Update critic ----
        optimizer_critic.zero_grad()
        critic_loss.backward()
        optimizer_critic.step()

        # ---- Actor loss (maximize Q-value) ----
        # Built only AFTER the critic step: optimizer.step() changes the critic
        # weights in place, and a graph recorded through the old weights can no
        # longer be back-propagated.
        # NOTE(review): no entropy term is applied here; `alpha` is unused.
        actor_sa = torch.cat([states, model.actor(states)], 1)
        actor_loss = -torch.min(model.q1(actor_sa), model.q2(actor_sa)).mean()

        # ---- Update actor ----
        # This backward pass also fills critic gradients; they are cleared by
        # optimizer_critic.zero_grad() at the start of the next critic update.
        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()

In [None]:
#WORKING VERSION
# NOTE(review): this loop needs `model` to expose `q1_target`, `q2_target` and
# `update_targets()`, and `dataloader` to yield dict batches with the keys read below.
print_interval = 100  # batches between console reports; set here so this cell does not depend on a later one
gamma = 0.99          # discount factor used in the TD target

for epoch in range(1000):
    for i, batch in enumerate(dataloader):

        # Move batch data to device as float32 (NumPy float64 data collates to
        # DoubleTensor) and give per-sample scalars a column shape.
        states = batch["state"].to(device=device, dtype=torch.float32)
        batch_len = states.shape[0]
        actions = batch["action"].to(device=device, dtype=torch.float32).reshape(batch_len, -1)
        rewards = batch["reward"].to(device=device, dtype=torch.float32).reshape(batch_len, 1)
        next_states = batch["next_state"].to(device=device, dtype=torch.float32)
        dones = batch["done"].to(device=device, dtype=torch.float32).reshape(batch_len, 1)

        # -----------------------------
        # 1. Compute target Q for Critic
        # -----------------------------
        with torch.no_grad():
            next_sa = torch.cat([next_states, model.actor(next_states)], dim=1)
            q_next = torch.min(model.q1_target(next_sa), model.q2_target(next_sa))
            target_q = rewards + (1 - dones) * gamma * q_next
            # clamp is fine, but do it as a separate (non-inplace) assignment
            target_q = target_q.clamp(-200, 0)

        # -----------------------------
        # 2. Critic forward pass
        # -----------------------------
        sa = torch.cat([states, actions], dim=1)
        current_q1 = model.q1(sa)
        current_q2 = model.q2(sa)

        # 2a. TD Loss (MSE between current Q and target Q)
        td_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

        # 2b. CQL penalty
        # NOTE(review): the logsumexp runs over the two critics' values for ONE
        # random action per state, not over several sampled actions.
        random_actions = (torch.rand_like(actions) * 2) - 1
        rand_sa = torch.cat([states, random_actions], dim=1)
        q_rand = torch.cat([model.q1(rand_sa), model.q2(rand_sa)], dim=1)
        q_rand_clamped = q_rand.clamp(-100, 100)  # not in-place
        logsumexp_val = torch.logsumexp(q_rand_clamped, dim=1).mean()
        cql_penalty = logsumexp_val - 0.5*(current_q1.mean() + current_q2.mean())

        critic_loss = td_loss + cql_weight * cql_penalty

        # -----------------------------
        # 3. Update Critic
        # -----------------------------
        optimizer_critic.zero_grad()
        critic_loss.backward()
        optimizer_critic.step()

        # -----------------------------
        # 4./5. Actor forward pass and update
        # -----------------------------
        # For actor training, we want the Q-value for the action predicted by the current actor
        pred_actions = model.actor(states)  # no .clone() or .detach()
        actor_ok = not torch.isnan(pred_actions).any()
        if actor_ok:
            sa_actor = torch.cat([states, pred_actions], dim=1)
            q_pred = torch.min(model.q1(sa_actor), model.q2(sa_actor))
            actor_loss = -q_pred.mean()

            optimizer_actor.zero_grad()
            actor_loss.backward()
            optimizer_actor.step()
        else:
            print("NaNs in actor output, skipping batch.")

        # -----------------------------
        # 6. Soft-update target networks
        # -----------------------------
        # Runs on the NaN path too: the critic was already stepped above, so the
        # targets should keep tracking it even when the actor update is skipped.
        model.update_targets()

        if not actor_ok:
            continue  # no actor loss to report for this batch

        # Print stats
        if i % print_interval == 0:
            print(f"Epoch: {epoch}, Iteration: {i}")
            print(f"TD Loss: {td_loss.item():.4f}, CQL Penalty: {cql_penalty.item():.4f}")
            print(f"Critic Loss: {critic_loss.item():.4f}, Actor Loss: {actor_loss.item():.4f}")
            print("-"*50)

In [None]:
# NOTE(review): anomaly detection is a debugging aid; per the PyTorch docs it slows
# execution down, so switch it off once the NaN hunt is over.
torch.autograd.set_detect_anomaly(True)
print_interval = 100  # not read in this cell; the plain training loop in the previous cell uses this name
log_every = 1         # batches averaged into each logged row (1 = log every batch)
gamma = 0.99          # discount factor used in the TD target
writer = SummaryWriter()
csv_file = 'training_stats1.csv'

# NOTE(review): `tqdm` and `compute_cql_penalty` are not defined in the visible
# cells; they must come from a cell that ran earlier. The loop also needs `model`
# to expose `q1_target`, `q2_target` and `update_targets()`.

# Write CSV header
with open(csv_file, 'w', newline='') as f:
    csv_writer = csv.writer(f)
    csv_writer.writerow(['Epoch', 'Iteration', 'TD Loss', 'CQL Penalty', 
                        'Critic Loss', 'Actor Loss', 'Q1 Value', 'Q2 Value'])

# Training loop
try:
    for epoch in tqdm(range(1000), desc="Training Progress"):
        # Running sums since the last logged row, plus the number of batches summed.
        metrics = {
            'td': 0.0,
            'cql': 0.0,
            'critic': 0.0,
            'actor': 0.0,
            'q1': 0.0,
            'q2': 0.0,
            'count': 0
        }

        for i, batch in enumerate(dataloader):
            # Move the batch to the device once, as float32 (NumPy float64 data
            # collates to DoubleTensor), with per-sample scalars as columns.
            states = batch["state"].to(device=device, dtype=torch.float32)
            batch_len = states.shape[0]
            actions = batch["action"].to(device=device, dtype=torch.float32).reshape(batch_len, -1)
            rewards = batch["reward"].to(device=device, dtype=torch.float32).reshape(batch_len, 1)
            next_states = batch["next_state"].to(device=device, dtype=torch.float32)
            dones = batch["done"].to(device=device, dtype=torch.float32).reshape(batch_len, 1)

            # --- Critic calculations ---
            with torch.no_grad():
                next_sa = torch.cat([next_states, model.actor(next_states)], dim=1)
                q_next = torch.min(model.q1_target(next_sa), model.q2_target(next_sa))
                target_q = rewards + (1 - dones) * gamma * q_next
                target_q = target_q.clamp(-200, 0)

            sa = torch.cat([states, actions], dim=1)
            current_q1 = model.q1(sa)
            current_q2 = model.q2(sa)

            # TD Loss
            td_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(current_q2, target_q)

            # Compute CQL penalty using DATASET actions (not random ones!)
            cql_penalty = compute_cql_penalty(states, actions, model)

            critic_loss = td_loss + cql_weight * cql_penalty

            # Critic update
            optimizer_critic.zero_grad()
            critic_loss.backward()
            optimizer_critic.step()

            # --- Actor calculations ---
            pred_actions = model.actor(states)
            actor_ok = not torch.isnan(pred_actions).any()
            if not actor_ok:
                print("NaNs detected, skipping actor update")
                q1_pred = torch.tensor(0.0)  # Default values
                q2_pred = torch.tensor(0.0)
                actor_loss = torch.tensor(0.0)
            else:
                sa_actor = torch.cat([states, pred_actions], dim=1)
                q1_pred = model.q1(sa_actor)
                q2_pred = model.q2(sa_actor)
                # The former `+ alpha * policy_entropy` term is dropped: nothing in
                # this loop computes `policy_entropy` from the actor output, so it
                # could not contribute a gradient, and adding an entropy to a
                # minimised loss would have penalised entropy rather than reward it.
                actor_loss = -torch.min(q1_pred, q2_pred).mean()

                # Actor update
                optimizer_actor.zero_grad()
                actor_loss.backward()
                optimizer_actor.step()

            # Target network updates
            model.update_targets()

            # --- Metrics collection ---
            metrics['td'] += td_loss.item()
            metrics['cql'] += cql_penalty.item()
            metrics['critic'] += critic_loss.item()
            metrics['actor'] += actor_loss.item() if actor_ok else 0
            metrics['q1'] += q1_pred.mean().item()
            metrics['q2'] += q2_pred.mean().item()
            metrics['count'] += 1

            # --- Logging ---
            if metrics['count'] >= log_every:
                # Calculate averages over the batches seen since the last row
                seen = metrics['count']
                avg_td = metrics['td'] / seen
                avg_cql = metrics['cql'] / seen
                avg_critic = metrics['critic'] / seen
                avg_actor = metrics['actor'] / seen
                avg_q1 = metrics['q1'] / seen
                avg_q2 = metrics['q2'] / seen

                # TensorBoard logging
                global_step = epoch * len(dataloader) + i
                writer.add_scalar('Loss/TD', avg_td, global_step)
                writer.add_scalar('Loss/CQL', avg_cql, global_step)
                writer.add_scalar('Loss/Critic', avg_critic, global_step)
                writer.add_scalar('Loss/Actor', avg_actor, global_step)
                writer.add_scalar('Q_Values/Q1', avg_q1, global_step)
                writer.add_scalar('Q_Values/Q2', avg_q2, global_step)
                writer.add_scalar('Actions/Mean', pred_actions.mean().item(), global_step)
                writer.add_scalar('Actions/Std', pred_actions.std().item(), global_step)

                # CSV logging (re-opened per row so each row is on disk immediately)
                with open(csv_file, 'a', newline='') as f:
                    csv_writer = csv.writer(f)
                    csv_writer.writerow([epoch, i, avg_td, avg_cql, avg_critic,
                                        avg_actor, avg_q1, avg_q2])

                # Reset metrics
                metrics = {k: 0.0 for k in metrics}
                metrics['count'] = 0
finally:
    # Flush and release the TensorBoard event file even if training is interrupted.
    writer.close()