In [None]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class DiabetesDataset(Dataset):
    """Offline RL transitions ``(s, a, r, s', done)`` built from a tabular CSV log.

    Every CSV row is one time step. Consecutive rows are assumed to be
    consecutive steps of a single trajectory (episode breaks are NOT detected).
    Row ``i`` paired with row ``i + 1`` forms transition ``i``, so a file with
    ``N`` rows yields ``max(N - 1, 0)`` transitions.

    Required columns:
        state  (7 features, in order): ``glu, glu_d, glu_t, hr, hr_d, hr_t, iob``
        action (2 features, in order): ``basal, bol``

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.
        target_glucose: Glucose value the reward is centred on. It must be in
            the same units/scale as the ``glu`` column. Defaults to ``100.0``.
            NOTE(review): the usage example loads "normalized_data.csv"; if
            ``glu`` is normalized there, 100 (mg/dL) is probably on the wrong
            scale -- confirm and pass a matching value.
    """

    def __init__(self, csv_file, target_glucose=100.0):
        # =======================
        # 1) Load CSV
        # =======================
        self.df = pd.read_csv(csv_file)
        self.target_glucose = target_glucose

        # =========================
        # 2) Extract State Columns
        # =========================
        # State is 7D: [glu, glu_d, glu_t, hr, hr_d, hr_t, iob]
        self.glucose = self.df["glu"].values
        self.glucose_deriv = self.df["glu_d"].values
        self.glucose_trend = self.df["glu_t"].values
        self.hr = self.df["hr"].values
        self.hr_deriv = self.df["hr_d"].values
        self.hr_trend = self.df["hr_t"].values
        self.iob = self.df["iob"].values

        # shape (N, 7)
        self.states = np.column_stack([
            self.glucose,
            self.glucose_deriv,
            self.glucose_trend,
            self.hr,
            self.hr_deriv,
            self.hr_trend,
            self.iob,
        ])

        # =========================
        # 3) Extract Action Columns
        # =========================
        # Action is 2D: [basal, bol]; shape (N, 2)
        self.basal = self.df["basal"].values
        self.bol = self.df["bol"].values
        self.actions = np.column_stack([self.basal, self.bol])

        # =====================
        # 4) Rewards
        # =====================
        # Negative absolute distance of the *current* step's glucose from the
        # target. (Computed from row i, not from the next state.)
        self.rewards = -np.abs(self.glucose - target_glucose)

        # =====================
        # 5) Next-State
        # =====================
        # Row i's successor is row i + 1. The final row has no successor; it
        # repeats itself rather than wrapping around to row 0 (np.roll would
        # wrap). That row is never served by __getitem__.
        self.next_states = np.concatenate(
            [self.states[1:], self.states[-1:]], axis=0
        )

        # =====================
        # 6) Done Flags
        # =====================
        # Index N-2 is the last transition __getitem__ serves, so it must carry
        # done=1 for the end-of-log flag to be visible. Index N-1 stays 1 as
        # before (it is never served). The slice is a no-op for an empty CSV,
        # and for N == 1 it marks only index 0.
        self.dones = np.zeros(len(self.states))
        self.dones[-2:] = 1.0

    def __len__(self):
        """Return the number of transitions, ``max(N - 1, 0)`` for ``N`` rows.

        The final row has no real next state, so it only appears as the
        ``next_state`` of the previous transition.
        """
        return max(len(self.states) - 1, 0)

    def __getitem__(self, idx):
        """Return transition ``idx`` as a dict of float32 tensors.

        Keys: ``state`` (shape (7,)), ``action`` (shape (2,)), ``reward`` (0-d),
        ``next_state`` (shape (7,)) and ``done`` (0-d).

        Negative indices count from the end, as for Python sequences.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        if idx < 0:
            idx += n
        # Without this check, idx == n would silently return the unserved
        # final row, whose next_state is not a real successor.
        if not 0 <= idx < n:
            raise IndexError(
                f"index {idx} is out of range for dataset of length {n}"
            )
        return {
            "state":      torch.tensor(self.states[idx],      dtype=torch.float32),
            "action":     torch.tensor(self.actions[idx],     dtype=torch.float32),
            "reward":     torch.tensor(self.rewards[idx],     dtype=torch.float32),
            "next_state": torch.tensor(self.next_states[idx], dtype=torch.float32),
            "done":       torch.tensor(self.dones[idx],       dtype=torch.float32),
        }

# ===================
# USAGE EXAMPLE
# ===================
if __name__ == "__main__":
    # Build the dataset from the CSV and print a quick summary of the
    # very first transition it produces.
    ds = DiabetesDataset(csv_file="normalized_data.csv")
    sample = ds[0]

    summary = (
        ("State shape:", sample["state"].shape),
        ("Action shape:", sample["action"].shape),
        ("Reward:", sample["reward"]),
        ("Done:", sample["done"]),
    )
    for label, value in summary:
        print(label, value)
