In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import torch
from sklearn.decomposition import PCA
from torch.utils.data import Dataset, DataLoader

class DiabetesDataset(Dataset):
    """Offline (state, action, reward, next_state, done) transitions from a processed CSV.

    The CSV must provide the columns listed in ``STATE_COLUMNS`` and
    ``ACTION_COLUMNS``, a ``glu_raw`` column used to compute the reward, and a
    ``done`` flag. Row ``i`` is paired with row ``i + 1`` as its successor, so
    the final row (which has no successor) is dropped and the dataset length is
    ``number_of_rows - 1``.
    """

    # Columns forming the 8-dimensional state vector.
    STATE_COLUMNS = ["glu", "glu_d", "glu_t", "hr", "hr_d", "hr_t", "iob", "hour"]
    # Columns forming the 2-dimensional action vector.
    ACTION_COLUMNS = ["basal", "bolus"]

    # Reward parameters; glucose is interpreted as mg/dL.
    GLUCOSE_MIN = 40
    GLUCOSE_MAX = 400
    HYPO_THRESHOLD = 54
    HYPO_PENALTY = -5.0

    def __init__(self, csv_file):
        """Load ``csv_file`` and build aligned transition arrays.

        Args:
            csv_file: Path (or any object accepted by ``pd.read_csv``) of the
                processed CSV.

        Raises:
            ValueError: If any column used by the dataset still contains NaN
                after forward/backward filling (i.e. the column is entirely
                missing values).
            KeyError: If a required column is absent from the CSV.
        """
        df = pd.read_csv(csv_file)

        # Handle missing values by forward-filling and backward-filling
        df = df.ffill().bfill()

        # Verify no remaining NaNs in any column that feeds the arrays below,
        # not only the state columns: a NaN in glu_raw/actions/done would
        # otherwise flow silently into rewards, actions or the done mask.
        required = self.STATE_COLUMNS + self.ACTION_COLUMNS + ["glu_raw", "done"]
        if df[required].isna().any().any():
            raise ValueError("Dataset contains NaN values after preprocessing")

        # State features (8 dimensions)
        self.states = df[self.STATE_COLUMNS].values.astype(np.float32)

        # Actions (2 dimensions)
        self.actions = df[self.ACTION_COLUMNS].values.astype(np.float32)

        # Rewards are computed from the NEXT row's glucose, so shift by one
        # step to match next_states. The wrapped-around last entry is removed
        # by _sanitize_transitions.
        glucose_raw = df["glu_raw"].values.astype(np.float64)
        self.rewards = self._compute_rewards(np.roll(glucose_raw, -1))

        # Transition handling: successor of row i is row i + 1.
        # NOTE: for rows with done == 1 the successor (and hence the reward)
        # comes from the following row in the file; consumers are expected to
        # mask bootstrapped values with `done`.
        self.next_states = np.roll(self.states, -1, axis=0)
        self.dones = df["done"].values.astype(np.float32)

        # Remove last invalid transition
        self._sanitize_transitions()

    def _compute_rewards(self, glucose_next):
        """Map next-step glucose values (mg/dL) to rewards.

        Glucose is clipped to ``[GLUCOSE_MIN, GLUCOSE_MAX]`` and converted to
        the Kovatchev blood-glucose risk ``10 * f**2`` with
        ``f = 1.509 * (ln(BG)**1.084 - 5.381)``, which is zero near 112.5 mg/dL
        and finite over the whole clipped range. The risk is squashed with
        ``-sigmoid(risk / 50)``; because the risk is non-negative the result
        lies in ``(-1, -0.5]``. Samples whose clipped glucose is below
        ``HYPO_THRESHOLD`` receive ``HYPO_PENALTY`` instead.

        Args:
            glucose_next: Array-like of glucose readings in mg/dL.

        Returns:
            ``np.float32`` array of rewards with the same shape as the input.
        """
        import warnings

        glucose_raw = np.asarray(glucose_next, dtype=np.float64)

        # If every reading sits below the physiological floor the input is
        # almost certainly not in mg/dL (e.g. standardized values); every
        # sample would then be clipped to GLUCOSE_MIN and get the hypo penalty.
        if glucose_raw.size and np.nanmax(glucose_raw) < self.GLUCOSE_MIN:
            warnings.warn(
                "All glucose values are below %s; input does not look like "
                "mg/dL (is it normalized?). Every reward will saturate at the "
                "hypoglycemia penalty." % self.GLUCOSE_MIN,
                RuntimeWarning,
                stacklevel=2,
            )

        glucose = np.clip(glucose_raw, self.GLUCOSE_MIN, self.GLUCOSE_MAX)

        # ln(glucose) >= ln(40) > 0, so the fractional power is always real;
        # the previous log(glucose / 180) ** 1.084 was NaN below 180 mg/dL.
        symmetrized = 1.509 * (np.log(glucose) ** 1.084 - 5.381)
        risk_index = 10.0 * symmetrized ** 2

        # Sigmoid squashing of a non-negative risk: range is (-1, -0.5].
        rewards = -1.0 / (1.0 + np.exp(-risk_index / 50.0))
        rewards[glucose < self.HYPO_THRESHOLD] = self.HYPO_PENALTY  # Stronger hypo penalty
        return rewards.astype(np.float32)

    def _sanitize_transitions(self):
        """Drop the final row, whose successor wrapped around to row 0.

        Slicing (rather than indexing element ``-1`` of a mask) keeps this
        safe for an empty dataset.
        """
        self.states = self.states[:-1]
        self.actions = self.actions[:-1]
        self.rewards = self.rewards[:-1]
        self.next_states = self.next_states[:-1]
        self.dones = self.dones[:-1]

    def __len__(self):
        """Return the number of usable transitions."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return transition ``idx`` as a dict of float32 tensors.

        ``state``/``next_state`` have shape ``(8,)``, ``action`` has shape
        ``(2,)`` and ``reward``/``done`` have shape ``(1,)`` for an integer
        ``idx``.
        """
        return {
            'state': torch.tensor(self.states[idx], dtype=torch.float32),
            'action': torch.tensor(self.actions[idx], dtype=torch.float32),
            'reward': torch.tensor([self.rewards[idx]], dtype=torch.float32),
            'next_state': torch.tensor(self.next_states[idx], dtype=torch.float32),
            'done': torch.tensor([self.dones[idx]], dtype=torch.float32)
        }

In [3]:
# Build the transition dataset from the processed training CSV
# (path is relative to the notebook's working directory).
dataset = DiabetesDataset(csv_file="datasets/processed/563-train.csv")

In [4]:
# Inspect the computed reward array.
# NOTE(review): every entry visible in the recorded output is -5.0, the value
# _compute_rewards assigns when the clipped glucose is below 54. The glu_raw
# values displayed further down are roughly 1.5-2.2, so the clip to 40 would put
# every sample under that threshold -- presumably glu_raw is not in mg/dL here;
# TODO confirm against the preprocessing step.
dataset.rewards

array([-5., -5., -5., ..., -5., -5., -5.], shape=(13247,), dtype=float32)

In [None]:
# Reload the processed training CSV for inspection and fill gaps the same way
# DiabetesDataset does: forward-fill first, then backward-fill what remains.
csv_file = "datasets/processed/563-train.csv"
df = pd.read_csv(csv_file).ffill().bfill()

In [6]:
# Display the filled frame.
# NOTE(review): in the recorded output glu_raw equals glu on every visible row,
# so glu_raw appears to hold the same scaled values as glu rather than raw
# glucose readings -- TODO confirm upstream preprocessing.
df

Unnamed: 0,time,glu_raw,glu,glu_d,glu_t,hr,hr_d,hr_t,iob,hour,basal,bolus,done
0,2021-09-13 00:00:00,1.468114,1.468114,-0.000418,-0.152496,-0.665125,0.002825,0.955246,-1.301074,0.000000,0.1,-1.0,0
1,2021-09-13 00:05:00,1.468114,1.468114,-0.000418,-0.152496,-0.665125,0.002825,0.955246,-1.286116,0.000000,0.1,-1.0,0
2,2021-09-13 00:10:00,1.468114,1.468114,-0.000418,-0.152496,-0.665125,0.002825,0.955246,-1.271195,0.000000,0.1,-1.0,0
3,2021-09-13 00:15:00,1.468114,1.468114,-0.000418,-0.152496,-0.665125,0.002825,0.955246,-1.256336,0.000000,0.1,-1.0,0
4,2021-09-13 00:20:00,1.468114,1.468114,-0.000418,-0.152496,-0.665125,0.002825,0.955246,-1.241561,0.000000,0.1,-1.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
13243,2021-10-28 23:35:00,2.132251,2.132251,0.440427,0.655150,-0.236627,-0.672682,0.130752,0.250901,0.958333,-0.2,-1.0,0
13244,2021-10-28 23:40:00,2.172502,2.172502,0.440427,0.407259,-0.665125,-0.576181,-0.840681,0.210328,0.958333,-0.2,-1.0,0
13245,2021-10-28 23:45:00,2.092000,2.092000,-0.882108,-0.096521,-0.879373,-0.286678,-1.428439,0.171224,0.958333,-0.2,-1.0,0
13246,2021-10-28 23:50:00,2.011499,2.011499,-0.882108,-0.424377,-0.950790,-0.093676,-1.199867,0.133788,0.958333,-0.2,-1.0,0
