In [None]:
import numpy as np
import torch
import sys
import os

# 📌 Import your trained model
from model.sac_cql import SACCQL

In [None]:
# --------------------------
# Data Handling
# --------------------------
class DiabetesDataset(Dataset):
    """Offline transition dataset built from a preprocessed diabetes CSV.

    Each item is a dict of tensors with the keys ``state``, ``action``,
    ``reward``, ``next_state`` and ``done``. The CSV must provide the columns
    ``glu``, ``glu_d``, ``glu_t``, ``iob``, ``hour``, ``action``, ``glu_raw``
    and ``done``.
    """

    # Columns that make up the state vector, in order.
    _STATE_COLUMNS = ("glu", "glu_d", "glu_t", "iob", "hour")

    def __init__(self, csv_file):
        """Load the CSV and materialise the transition arrays.

        Args:
            csv_file: Anything ``pandas.read_csv`` accepts (path or file-like).

        Raises:
            ValueError: If a state column still contains NaN after
                forward/backward filling, or if any action lies outside
                ``[-1, 1]``.
        """
        # Load data and fill missing values (forward first, then backward so
        # that leading gaps are covered too).
        df = pd.read_csv(csv_file)
        df = df.ffill().bfill()

        state_columns = list(self._STATE_COLUMNS)

        # A column can only still hold NaN here if it was entirely empty.
        if df[state_columns].isna().any().any():
            raise ValueError("Dataset contains NaN values after preprocessing")

        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`, which would silently disable this validation.
        if not df["action"].between(-1, 1).all():
            raise ValueError("Actions must be between -1 and 1")

        # Prepare state features and action values
        self.states = df[state_columns].values.astype(np.float32)
        self.actions = df["action"].values.astype(np.float32).reshape(-1, 1)

        # Compute rewards from the glu_raw values.
        # NOTE(review): the reward for row i is derived from glu_raw at row i
        # (not row i + 1), although the helper's parameter is called
        # `glucose_next` -- confirm this is the intended reward timing.
        self.rewards = self._compute_rewards(df["glu_raw"].values)

        # next_states[i] is states[i + 1]; np.roll wraps the last row around to
        # the first one, which is why the last transition is dropped below.
        # NOTE(review): rows are shifted regardless of the `done` flag, so a
        # terminal row's next_state is the following row of the file.
        self.next_states = np.roll(self.states, -1, axis=0)
        self.dones = df["done"].values.astype(np.float32)

        # Remove last transition (invalid next state)
        self._sanitize_transitions()

    def _compute_rewards(self, glucose_next):
        """
        Compute rewards using a rescaled Risk Index (RI)-based function.
        Based on Kovatchev et al. (2005), extended with a severe hypoglycemia penalty.

        Args:
            glucose_next: NumPy array of glucose readings. Values are clamped
                to ``[10, 400]`` before the risk transform.

        Returns:
            ``float32`` array of the same length with rewards in ``[-15, 0]``.
        """
        glucose = np.clip(glucose_next.astype(np.float32), 10, 400)  # Clamp extreme values

        # Step 1: Risk transformation function
        log_glucose = np.log(glucose)
        f = 1.509 * (np.power(log_glucose, 1.084) - 5.381)
        r = 10 * np.square(f)

        # Step 2: LBGI and HBGI (risk on the low / high side of the transform)
        lbgi = np.where(f < 0, r, 0)
        hbgi = np.where(f > 0, r, 0)

        # Step 3: Total Risk Index (RI)
        ri = lbgi + hbgi

        # Step 4: Rescale RI and convert to reward
        normalized_ri = -ri / 10.0  # Stronger signal than /100
        rewards = np.clip(normalized_ri, -5.0, 0.0)

        # Step 5: Severe hypoglycemia penalty
        severe_hypo_penalty = np.where(glucose <= 39, -15.0, 0.0)
        rewards += severe_hypo_penalty

        # Step 6: Optional time penalty
        rewards -= 0.01  # Encourage faster correction

        # Final clamp: risk term plus penalty can go below -15.
        return np.clip(rewards, -15.0, 0.0).astype(np.float32)

    def _sanitize_transitions(self):
        """Remove the last transition which lacks a valid next state."""
        valid_mask = np.ones(len(self.states), dtype=bool)
        valid_mask[-1] = False
        self.states = self.states[valid_mask]
        self.actions = self.actions[valid_mask]
        self.rewards = self.rewards[valid_mask]
        self.next_states = self.next_states[valid_mask]
        self.dones = self.dones[valid_mask]

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return transition ``idx`` as a dict of float tensors.

        ``reward`` and ``done`` are wrapped in a one-element list so they come
        back as 1-D tensors of length 1.
        """
        return {
            'state': torch.FloatTensor(self.states[idx]),
            'action': torch.FloatTensor(self.actions[idx]),
            'reward': torch.FloatTensor([self.rewards[idx]]),
            'next_state': torch.FloatTensor(self.next_states[idx]),
            'done': torch.FloatTensor([self.dones[idx]])
        }


In [None]:
# --------------------------
# Actor: Gaussian Policy Network
# --------------------------
class GaussianActor(nn.Module):
    """Policy network producing a tanh-squashed diagonal Gaussian over actions.

    Three ReLU hidden layers feed two linear heads: one for the mean and one
    for the log standard deviation of the pre-squash Gaussian.
    """

    def __init__(self, state_dim, action_dim, hidden_units=32):
        super().__init__()
        # Bounds applied to the predicted log-std in forward().
        self.LOG_STD_MIN = -20
        self.LOG_STD_MAX = 2

        # Hidden trunk: three dense layers of `hidden_units` each.
        self.fc1 = nn.Linear(state_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, hidden_units)
        # Output heads for the Gaussian parameters.
        self.mean_head = nn.Linear(hidden_units, action_dim)
        self.log_std_head = nn.Linear(hidden_units, action_dim)

        # Xavier-uniform weights and zero biases for every linear layer.
        every_layer = (self.fc1, self.fc2, self.fc3, self.mean_head, self.log_std_head)
        for linear in every_layer:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, state):
        """Return ``(mean, std)`` of the pre-squash Gaussian for ``state``."""
        hidden = state
        for linear in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(linear(hidden))

        mean = self.mean_head(hidden)
        # Clamp log-std before exponentiating to keep std in a bounded range.
        bounded_log_std = torch.clamp(
            self.log_std_head(hidden), self.LOG_STD_MIN, self.LOG_STD_MAX
        )
        return mean, torch.exp(bounded_log_std)

    def sample(self, state):
        """Draw a reparameterised action in [-1, 1] and its log-probability.

        Returns:
            ``(action, log_prob)`` where ``log_prob`` is summed over the last
            dimension (kept as size 1) and includes the tanh change-of-variables
            correction.
        """
        mean, std = self.forward(state)
        gaussian = Normal(mean, std)

        pre_squash = gaussian.rsample()  # differentiable w.r.t. mean/std
        action = torch.tanh(pre_squash)

        gaussian_log_prob = gaussian.log_prob(pre_squash).sum(dim=-1, keepdim=True)
        # 1e-6 guards the log against actions saturating at +/-1.
        squash_correction = torch.log(1 - action.pow(2) + 1e-6).sum(dim=-1, keepdim=True)
        return action, gaussian_log_prob - squash_correction

# --------------------------
# Critic: Q-Network
# --------------------------
class QNetwork(nn.Module):
    """State-action value network: concat(state, action) -> scalar Q estimate."""

    def __init__(self, state_dim, action_dim, hidden_units=32):
        super().__init__()
        input_dim = state_dim + action_dim
        self.fc1 = nn.Linear(input_dim, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, hidden_units)
        self.out = nn.Linear(hidden_units, 1)

        # Xavier-uniform weights and zero biases for every linear layer.
        for linear in (self.fc1, self.fc2, self.fc3, self.out):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, state, action):
        """Return the Q-value for each (state, action) row.

        ``state`` and ``action`` are concatenated along dim 1, so both must be
        2-D with matching first dimension.
        """
        hidden = torch.cat([state, action], dim=1)
        for linear in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(linear(hidden))
        return self.out(hidden)

In [None]:
# --------------------------
# SAC Agent
# --------------------------
class SACAgent(nn.Module):
    """Soft Actor-Critic agent: one Gaussian actor, twin critics and their targets."""

    def __init__(self, state_dim=8, action_dim=1,
                 actor_lr=3e-4, critic_lr=3e-4, alpha_lr=3e-4,
                 target_entropy=-1, gamma=0.997, tau=0.005):
        """Build the networks, the temperature parameter and the optimizers.

        All networks and ``log_alpha`` are placed on the module-level ``device``.
        ``gamma``, ``tau`` and ``target_entropy`` are only stored here.
        """
        super().__init__()
        self.gamma = gamma
        self.tau = tau
        self.target_entropy = target_entropy

        # Policy network.
        self.actor = GaussianActor(state_dim, action_dim).to(device)

        # Online critics followed by their target copies.
        self.q1 = QNetwork(state_dim, action_dim).to(device)
        self.q2 = QNetwork(state_dim, action_dim).to(device)
        self.q1_target = QNetwork(state_dim, action_dim).to(device)
        self.q2_target = QNetwork(state_dim, action_dim).to(device)
        # Targets start as exact copies of the online critics.
        for target_net, online_net in self._critic_pairs():
            target_net.load_state_dict(online_net.state_dict())

        # Learnable temperature, stored in log space.
        # NOTE(review): the stored value is log_alpha = 0.1; if consumers use
        # exp(log_alpha) the initial temperature is ~1.105, not 0.1 -- confirm.
        self.log_alpha = torch.tensor([0.1], requires_grad=True, device=device)

        # One optimizer each for the actor, both critics jointly, and log_alpha.
        critic_params = list(self.q1.parameters()) + list(self.q2.parameters())
        self.actor_optim = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optim = optim.Adam(critic_params, lr=critic_lr)
        self.alpha_optim = optim.Adam([self.log_alpha], lr=alpha_lr)

    def _critic_pairs(self):
        """Return ``(target, online)`` critic pairs in q1, q2 order."""
        return ((self.q1_target, self.q1), (self.q2_target, self.q2))

    def act(self, state):
        """Return the deterministic action ``tanh(mean)`` for evaluation."""
        policy_mean = self.actor.forward(state)[0]
        return torch.tanh(policy_mean)

    def update_targets(self):
        """Polyak-average the online critics into the targets with rate ``tau``."""
        with torch.no_grad():
            for target_net, online_net in self._critic_pairs():
                for target_param, param in zip(target_net.parameters(), online_net.parameters()):
                    blended = self.tau * param.data + (1 - self.tau) * target_param.data
                    target_param.data.copy_(blended)

In [None]:

# Resolve the Gloop checkout relative to this file (three directories up) and
# make it importable, without adding a duplicate sys.path entry.
_THIS_DIR = os.path.dirname(__file__)
GLOOP_PATH = os.path.abspath(os.path.join(_THIS_DIR, "../../../Gloop"))
if GLOOP_PATH not in sys.path:
    sys.path.append(GLOOP_PATH)



class Controller:
    """Simulator controller that doses insulin from a trained SAC-CQL policy.

    On each ``run`` call the controller builds an 8-feature state from the
    simulator's ``states`` array, queries the model for an action in
    ``[-1, 1]``, converts it to a dose and writes it into ``inputs``.
    """

    name = "GloopController"

    def __init__(self, scenario_instance):
        """Load the trained model from ``GLOOP_PATH/checkpoints``.

        Args:
            scenario_instance: Accepted for interface compatibility; not used
                by this constructor.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        print(">> GloopController initialized")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # 🔁 Load SAC-CQL model
        self.model = SACCQL().to(self.device)
        checkpoint_path = os.path.join(GLOOP_PATH, "checkpoints/saccql_trained.pt")

        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"❌ Checkpoint not found at {checkpoint_path}")

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint)
        self.model.eval()
        print("✅ SAC-CQL model loaded")

    def convert_to_dose(self, x, min_dose=0.0, max_dose=10.0):
        """Scale [-1, 1] model output to actual insulin range.

        Args:
            x: A Python/NumPy scalar, a one-element array, or a one-element
                torch tensor (on any device).
            min_dose: Lower clamp applied after scaling.
            max_dose: Upper clamp; also the value that ``x == 1`` maps to.

        Returns:
            The dose as a Python ``float`` in ``[min_dose, max_dose]``.
        """
        # Tensors must be detached and moved to host memory before NumPy can
        # see them; passing a CUDA tensor to np.clip raises.
        if isinstance(x, torch.Tensor):
            x = x.detach().cpu().numpy()
        scaled = (np.asarray(x) + 1) / 2 * max_dose
        # .item() extracts the single element explicitly instead of relying on
        # float() of an ndim > 0 array, which NumPy has deprecated.
        return float(np.clip(scaled, min_dose, max_dose).item())

    @staticmethod
    def _build_state(states, sample):
        """Assemble the 8-feature model input for row ``sample`` of ``states``."""
        return np.array([
            states[sample, 0],  # glucose
            states[sample, 1],  # glucose_derivative
            0.0,                # glucose_trend (placeholder)
            states[sample, 2],  # heart_rate
            states[sample, 3],  # hr_derivative
            0.0,                # heart_rate_trend (placeholder)
            states[sample, 4],  # insulin_on_board
            (sample % 1440) / 60.0  # hour of day
        ], dtype=np.float32)

    @staticmethod
    def _inject_dose(inputs, sample, dose):
        """Write ``dose`` at row ``sample`` of the simulator's insulin input."""
        try:
            # 💉 Inject into simulator
            if isinstance(inputs, dict):
                target = None
                # Accept either key spelling, but only if the entry exposes
                # a `sampled_signal` attribute to write into.
                for key in ("u_insulin", "uInsulin"):
                    if key in inputs and hasattr(inputs[key], "sampled_signal"):
                        target = inputs[key]
                        break

                # Identity check: a valid signal object must not be skipped
                # just because it happens to evaluate as falsy.
                if target is not None:
                    target.sampled_signal[sample, 0] = dose
                    print(f"[step={sample}] ✅ Dose injected: {dose:.2f} U/hr")
                else:
                    print(f"[step={sample}] ❌ No insulin input signal found in dict")

            elif isinstance(inputs, np.ndarray):
                inputs[sample, 0] = dose
                print(f"[step={sample}] ✅ Dose injected via array: {dose:.2f} U/hr")

            else:
                print(f"[step={sample}] ❌ Unknown input structure")

        except Exception as e:
            # Best-effort: a failed write is reported but never propagates.
            print(f"[step={sample}] ❌ Injection failed: {e}")

    def run(self, measurements, states, inputs, sample):
        """Compute the dose for step ``sample`` and write it into ``inputs``.

        Args:
            measurements: Not used by this controller.
            states: 2-D array indexed as ``states[sample, column]``; columns
                0-4 are read.
            inputs: Either a dict holding an object with ``sampled_signal``
                under ``"u_insulin"`` / ``"uInsulin"``, or a NumPy array.
            sample: Current step index.

        Returns:
            None. If inference fails, a dose of 0.0 is injected instead.
        """
        print(f"[step={sample}] GloopController running...")

        if sample >= states.shape[0]:
            print(f"[step={sample}] sample out of bounds")
            return

        try:
            # ⚙️ Build 8D input vector (match model expectations)
            state = self._build_state(states, sample)
            state_tensor = torch.tensor(state).unsqueeze(0).to(self.device)

            # 🧠 Inference
            with torch.no_grad():
                action = self.model.act(state_tensor)[0]
            dose = self.convert_to_dose(action)

        except Exception as e:
            # Fall back to no insulin rather than aborting the simulation.
            print(f"[step={sample}] ❌ Model inference error: {e}")
            dose = 0.0

        self._inject_dose(inputs, sample, dose)