In [12]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.optim as optim
import os
import csv
import pandas as pd
from datetime import datetime


In [13]:

class DiabetesDataset(Dataset):
    def __init__(self, csv_file):
        # Load CSV data
        self.df = pd.read_csv(csv_file)
        
        # Extract state features: 8 dimensions
        # [glu, glu_d, glu_t, hr, hr_d, hr_t, iob, hour_norm]
        self.states = self.df[[
            "glu", "glu_d", "glu_t",
            "hr", "hr_d", "hr_t",
            "iob", "hour"
        ]].values.astype(np.float32)
        
        # Extract action features: 2 dimensions [basal, bol]
        self.actions = self.df[["basal", "bolus"]].values.astype(np.float32)
        
        # Extract done flags (1 at episode boundaries, 0 otherwise)
        self.dones = self.df["done"].values.astype(np.float32)
        
        # Define rewards (example: negative absolute deviation from target)
        target_glucose = 0.0  # For normalized glucose, target might be 0 after scaling
        self.rewards = -np.abs(self.states[:, 0] - target_glucose)
        
        # Compute next_states using a vectorized roll
        self.next_states = np.roll(self.states, shift=-1, axis=0)
        
        # For transitions where the current step is an episode end,
        # set the next state to be the current state so that transitions do not cross episodes.
        self.next_states[self.dones == 1] = self.states[self.dones == 1]
        
        # Remove the final row since it doesn't have a valid next state
        self.states = self.states[:-1]
        self.actions = self.actions[:-1]
        self.rewards = self.rewards[:-1]
        self.next_states = self.next_states[:-1]
        self.dones = self.dones[:-1]

    def __len__(self):
        return len(self.states)

    def __getitem__(self, idx):
        # Return a dictionary for the transition (s, a, r, s', done)
        return {
            "state":      torch.tensor(self.states[idx],      dtype=torch.float32),
            "action":     torch.tensor(self.actions[idx],     dtype=torch.float32),
            "reward":     torch.tensor(self.rewards[idx],     dtype=torch.float32),
            "next_state": torch.tensor(self.next_states[idx], dtype=torch.float32),
            "done":       torch.tensor(self.dones[idx],       dtype=torch.float32)
        }


In [6]:

# ===================
# USAGE EXAMPLE
# ===================
if __name__ == "__main__":
    dataset = DiabetesDataset(csv_file="datasets/processed/563-test.csv")
    
    # Peek at the first sample
    sample = dataset[0]
    print("State:", sample["state"])
    print("Action:", sample["action"])
    print("Reward:", sample["reward"])
    print("Next State:", sample["next_state"])
    print("Done:", sample["done"])


State: tensor([ 1.5537,     nan,     nan, -0.0855,     nan,     nan, -1.2715,  0.0000])
Action: tensor([ 3.1836, -0.1611])
Reward: tensor(-1.5537)
Next State: tensor([ 1.5320, -0.1646,     nan, -1.0180, -1.3954,     nan, -1.2561,  0.0000])
Done: tensor(0.)


In [14]:
# State: [current_glucose, glucose_trend, heart_rate, heart_rate_trend, insulin_on_board]
state_dim = 8
action_dim = 2  # Continuous insulin dose (steps of 0.05 units)

# Hyperparameters
alpha = 0.2  # Entropy coefficient
cql_weight = 5.0  # CQL penalty strength
batch_size = 256
device = "cuda" if torch.cuda.is_available() else "cpu"

In [15]:
class SACCQL(nn.Module):
    def __init__(self):
        super().__init__()
        # Actor (policy) network
        self.actor = nn.Sequential(
            nn.Linear(state_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, action_dim),
            nn.Tanh()  # Output in [-1, 1] (rescale to insulin range)
        )
        
        # Critic networks (twin Q-functions)
        self.q1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        self.q2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        
    def act(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(device)
            action = self.actor(state)
        return action.cpu().numpy()



In [16]:
dataset  = DiabetesDataset(csv_file="datasets/processed/563-test.csv")
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [17]:
# Initialize networks and optimizers
model = SACCQL().to(device)
optimizer_actor = optim.Adam(model.actor.parameters(), lr=3e-4)
optimizer_critic = optim.Adam(list(model.q1.parameters()) + list(model.q2.parameters()), lr=3e-4)



In [None]:

# Training loop
for epoch in range(1000):
    for states, actions, rewards, next_states, dones in dataloader:
        states = torch.FloatTensor(states).to(device)
        actions = torch.FloatTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device).unsqueeze(1)
        next_states = torch.FloatTensor(next_states).to(device)
        dones = torch.FloatTensor(dones).to(device).unsqueeze(1)
        
        # Critic loss (CQL + TD error)
        with torch.no_grad():
            next_actions = model.actor(next_states)
            q1_next = model.q1(torch.cat([next_states, next_actions], 1))
            q2_next = model.q2(torch.cat([next_states, next_actions], 1))
            q_next = torch.min(q1_next, q2_next)
            target_q = rewards + (1 - dones) * 0.99 * q_next
        
        # Current Q-values
        current_q1 = model.q1(torch.cat([states, actions], 1))
        current_q2 = model.q2(torch.cat([states, actions], 1))
        
        # TD loss
        td_loss = nn.MSELoss()(current_q1, target_q) + nn.MSELoss()(current_q2, target_q)
        
        # CQL penalty: logsumexp(Q(s, a')) - Q(s, a)
        random_actions = torch.rand_like(actions) * 2 - 1  # Random actions in [-1, 1]
        q1_rand = model.q1(torch.cat([states, random_actions], 1))
        q2_rand = model.q2(torch.cat([states, random_actions], 1))
        cql_penalty = (
            torch.logsumexp(torch.cat([q1_rand, q2_rand], 1), dim=1).mean() -
            (current_q1.mean() + current_q2.mean()) / 2
        )
        
        # Total critic loss
        critic_loss = td_loss + cql_weight * cql_penalty
        
        # Actor loss (maximize Q-value + entropy)
        pred_actions = model.actor(states)
        q1_pred = model.q1(torch.cat([states, pred_actions], 1))
        q2_pred = model.q2(torch.cat([states, pred_actions], 1))
        actor_loss = -torch.min(q1_pred, q2_pred).mean()
        
        # Update critic
        optimizer_critic.zero_grad()
        critic_loss.backward()
        optimizer_critic.step()
        
        # Update actor
        optimizer_actor.zero_grad()
        actor_loss.backward()
        optimizer_actor.step()