In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.model_selection import train_test_split

In [16]:
df_ryiadh = pd.read_csv("../possible_datasets/KSMC_Hospital/data_preprocessed.csv")
df_ryiadh.drop(columns=["Admission_Date"], inplace=True)        #dates are for every time series the same so using them in modeling does not make sense
df_ryiadh

Unnamed: 0,Age,Gender,Nationality,Hospital_Name,Admission_Counts
0,0-17,Female,Non-Saudi,King Abdulaziz Medical City,6
1,0-17,Female,Non-Saudi,King Abdulaziz Medical City,7
2,0-17,Female,Non-Saudi,King Abdulaziz Medical City,2
3,0-17,Female,Non-Saudi,King Abdulaziz Medical City,5
4,0-17,Female,Non-Saudi,King Abdulaziz Medical City,3
...,...,...,...,...,...
4795,65+,Male,Saudi,King Saud Medical City,23
4796,65+,Male,Saudi,King Saud Medical City,31
4797,65+,Male,Saudi,King Saud Medical City,31
4798,65+,Male,Saudi,King Saud Medical City,24


In [17]:
df_ryiadh = df_ryiadh.pivot_table(index=["Age", "Gender", "Nationality", "Hospital_Name"], aggfunc=list).reset_index()
df_ryiadh["target"] = [a[63:] for a in df_ryiadh["Admission_Counts"]]
df_ryiadh["Admission_Counts"] = [a[:63] for a in df_ryiadh["Admission_Counts"]]

In [18]:
# 80% train, 20% temp
x_dyn_train, x_dyn_temp, y_train, y_temp, x_static_train, x_static_temp = train_test_split(
    df_ryiadh[["Admission_Counts"]], df_ryiadh["target"], df_ryiadh[["Age", "Gender", "Nationality", "Hospital_Name"]], test_size=0.20, random_state=42, 
)

# 50/50 split of the 20% -> 10% val, 10% test
x_dyn_val, x_dyn_test, y_val, y_test, x_static_val, x_static_test = train_test_split(
    x_dyn_temp, y_temp, x_static_temp, test_size=0.50, random_state=42
)

In [19]:
ohe = OneHotEncoder(sparse_output=False, dtype=np.float32, drop='first')
ohe = ohe.fit(df_ryiadh[["Hospital_Name"]])

# encode categorical variables (Label Encoding and One-Hot Encoding)
def preprocess_static(df):
    df["Age"] = df["Age"].apply(lambda x: 0. if x == '0-17' else 1. if x == '18-45' else 2. if x == '46-65' else 3)
    df["Gender"] = df["Gender"].apply(lambda x: 0. if x == 'Female' else 1.)
    df["Nationality"] = df["Nationality"].apply(lambda x: 0. if x == 'Non-Saudi' else 1.)
    
    df_ohe = pd.DataFrame(
        ohe.fit_transform(df[["Hospital_Name"]]), columns=ohe.get_feature_names_out(["Hospital_Name"])
    )
    df = pd.concat([df.reset_index(drop=True), df_ohe.reset_index(drop=True)], axis=1).reset_index(drop=True)
    df = df.drop(columns=["Hospital_Name"])
    return torch.tensor(df.values).float()
#todo admission month circular encoding

x_static_train = preprocess_static(x_static_train)
x_static_val = preprocess_static(x_static_val)
x_static_test = preprocess_static(x_static_test)

In [20]:
#standardize over all time steps and time series, todo ask if this is right
scaler = StandardScaler()
all_counts = np.concatenate(df_ryiadh["Admission_Counts"].values)  
all_counts = np.asarray(all_counts, dtype=float)

overall_max = all_counts.max()
overall_std = all_counts.std(ddof=0)  # ddof=1 für Stichproben-Std

In [21]:
def standardize_counts(df):
    standardize_counts = [(a - overall_max) / overall_std for a in df["Admission_Counts"]]
    return torch.tensor(standardize_counts).float().unsqueeze(-1)

x_dyn_train = standardize_counts(x_dyn_train)
x_dyn_val = standardize_counts(x_dyn_val)
x_dyn_test = standardize_counts(x_dyn_test)

def convert_y_to_tensor(df):
    return torch.tensor(np.array(df.values.tolist(), dtype=np.float32)).float().unsqueeze(-1)

y_train = convert_y_to_tensor(y_train)
y_val = convert_y_to_tensor(y_val)
y_test = convert_y_to_tensor(y_test)

In [22]:
class TemporalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1):
        super().__init__()
        padding = (kernel_size - 1) * dilation

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size,
                               padding=padding, dilation=dilation)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size,
                               padding=padding, dilation=dilation)
        self.relu2 = nn.ReLU()

        self.downsample = nn.Conv1d(in_channels, out_channels, kernel_size=1) \
            if in_channels != out_channels else None
        
    def forward(self, x):
        out = self.relu1(self.conv1(x))
        out = self.relu2(self.conv2(out))
        
        res = x if self.downsample is None else self.downsample(x)
        # Output kürzen wegen Padding
        out = out[..., :res.shape[-1]]
        return out + res


class TCN(nn.Module):
    def __init__(self, input_features=3, output_steps=12, hidden=32):
        super().__init__()
        
        self.tblock1 = TemporalBlock(input_features, hidden, dilation=1)
        self.tblock2 = TemporalBlock(hidden, hidden, dilation=2)
        self.tblock3 = TemporalBlock(hidden, hidden, dilation=4)

        # letzer linearer Kopf für Forecast-Horizon
        self.head = nn.Conv1d(hidden, output_steps, kernel_size=1)

    def forward(self, x_dyn, x_static=None):
        """
        x_dyn:    [batch, seq_len, features]
        x_static: [batch, n_static] (wird NICHT in Modell verwendet)
        """
        x = x_dyn.transpose(1, 2)  # → [batch, features, seq_len]
        
        x = self.tblock1(x)
        x = self.tblock2(x)
        x = self.tblock3(x)

        out = self.head(x)  # [batch, output_steps, seq_len]

        # Letzten Zeitschritt extrahieren (klassisch bei TCN Forecasting)
        out = out[:, :, -1]  # [batch, output_steps]

        return out.unsqueeze(-1)  # → [batch, output_steps, 1]
    


class GroupFairnessMAEVariance(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, preds, target, static_features):
        # preds, target: [batch, seq_len, 1]
        # static_features: [batch, n_static]
        mae = torch.abs(preds - target).mean(dim=1)  # mitteln über seq_len -> [batch, 1]
        group_vars = []
        for i in range(static_features.shape[1]):
            group_vals = static_features[:, i]  # [batch]
            unique_groups = torch.unique(group_vals)
            if len(unique_groups) <= 1:
                continue  # überspringen, wenn nur eine Gruppe vorhanden ist
            group_maes = []
            for g in unique_groups:
                mask = (group_vals == g)
                if mask.sum() > 0:
                    group_maes.append(mae[mask].mean())
            group_maes = torch.stack(group_maes)
            group_vars.append(group_maes.var())  # Varianz der Gruppen-MAEs
        return torch.stack(group_vars).mean()  # mitteln über alle statischen Merkmale

class MAE(nn.Module):
    def __init__(self):
        super().__init__()
    
    def forward(self, preds, target, static_features):
        return torch.abs(preds - target).mean()
    
class MAE_GroupFairness(nn.Module):
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha
        self.mae = MAE()
        self.fairness = GroupFairnessMAEVariance()
    
    def forward(self, preds, target, static_features):
        return self.alpha * self.mae(preds, target, static_features) + (1 - self.alpha) * self.fairness(preds, target, static_features)


def train_model(
    x_dyn_train, y_train, x_static_train,
    x_dyn_val, y_val, x_static_val, loss_fn, log_metrics={},
    epochs=50, lr=1e-3, batch_size=16,
    patience=5, print_every=1
):
    torch.manual_seed(42)
    
    model = TCN(input_features=x_dyn_train.shape[-1], output_steps=y_train.shape[1])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    
    best_val_loss = float("inf")
    epochs_no_improve = 0
    best_model_state = None
    
    n_train = x_dyn_train.shape[0]
    
    for epoch in range(epochs):
        # Shuffle training data
        permutation = torch.randperm(n_train)
        for i in range(0, n_train, batch_size):
            idx = permutation[i:i+batch_size]
            bx, by, bs = x_dyn_train[idx], y_train[idx], x_static_train[idx]
            
            optimizer.zero_grad()
            preds = model(bx, bs)
            loss = loss_fn(preds, by, bs)
            loss.backward()
            optimizer.step()
        
        # Validation
        with torch.no_grad():
            val_preds = model(x_dyn_val, x_static_val)
            val_loss = loss_fn(val_preds, y_val, x_static_val).item()
        
        # Logging
        if (epoch+1) % print_every == 0:
            msg = f"Epoch {epoch+1}/{epochs} | Val Loss: {val_loss:.4f}"
            for name, metric_fn in log_metrics.items():
                val = metric_fn(val_preds, y_val, x_static_val)
                msg += f" | {name}: {val.item():.4f}"
            print(msg)
        
        # --- Early Stopping Check ---
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = model.state_dict()  # besten Zustand speichern
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
        
        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
    
    # Load best model
    model.load_state_dict(best_model_state)
    return model

In [23]:
model = train_model(x_dyn_train, y_train, x_static_train, x_dyn_val, y_val, x_static_val, loss_fn=MAE_GroupFairness(),
                    log_metrics={"Fairness_MAE_Var": GroupFairnessMAEVariance(), "MAE": MAE()})

Epoch 1/50 | Val Loss: 26.7783 | Fairness_MAE_Var: 38.7627 | MAE: 14.7939
Epoch 2/50 | Val Loss: 26.5834 | Fairness_MAE_Var: 38.1320 | MAE: 15.0348
Epoch 3/50 | Val Loss: 26.3858 | Fairness_MAE_Var: 37.4228 | MAE: 15.3489
Epoch 4/50 | Val Loss: 26.1074 | Fairness_MAE_Var: 36.5290 | MAE: 15.6858
Epoch 5/50 | Val Loss: 25.6954 | Fairness_MAE_Var: 35.2814 | MAE: 16.1094
Epoch 6/50 | Val Loss: 25.2008 | Fairness_MAE_Var: 33.8064 | MAE: 16.5953
Epoch 7/50 | Val Loss: 24.4788 | Fairness_MAE_Var: 31.7968 | MAE: 17.1609
Epoch 8/50 | Val Loss: 23.4384 | Fairness_MAE_Var: 29.0050 | MAE: 17.8718
Epoch 9/50 | Val Loss: 21.9975 | Fairness_MAE_Var: 25.1011 | MAE: 18.8940
Epoch 10/50 | Val Loss: 19.7815 | Fairness_MAE_Var: 19.1072 | MAE: 20.4558
Epoch 11/50 | Val Loss: 16.9937 | Fairness_MAE_Var: 10.7505 | MAE: 23.2370
Epoch 12/50 | Val Loss: 14.3709 | Fairness_MAE_Var: 3.0891 | MAE: 25.6526
Epoch 13/50 | Val Loss: 12.7990 | Fairness_MAE_Var: 0.1195 | MAE: 25.4785
Epoch 14/50 | Val Loss: 9.1783 | Fai

In [24]:
group_fairness_metric = GroupFairnessMAEVariance()
mae_metric = MAE()
mae_group_fairness_metric = MAE_GroupFairness()

with torch.no_grad():
    preds_test = model(x_dyn_test, x_static_test)
    test_mae = mae_metric(preds_test, y_test, x_static_test).item()
    test_fairness = group_fairness_metric(preds_test, y_test, x_static_test).item()
    test_combined = mae_group_fairness_metric(preds_test, y_test, x_static_test).item()

print("Test MAE:", test_mae)
print("Test Fairness:", test_fairness)
print("Test Combined:", test_combined)

Test MAE: 3.4493825435638428
Test Fairness: 0.24558070302009583
Test Combined: 1.847481608390808


In [25]:
model = train_model(x_dyn_train, y_train, x_static_train, x_dyn_val, y_val, x_static_val, loss_fn=MAE(),
                    log_metrics={"Fairness_MAE_Var": GroupFairnessMAEVariance(), "MAE": MAE()})

Epoch 1/50 | Val Loss: 13.7788 | Fairness_MAE_Var: 40.7681 | MAE: 13.7788
Epoch 2/50 | Val Loss: 12.9184 | Fairness_MAE_Var: 41.7945 | MAE: 12.9184
Epoch 3/50 | Val Loss: 11.6873 | Fairness_MAE_Var: 43.0028 | MAE: 11.6873
Epoch 4/50 | Val Loss: 10.5048 | Fairness_MAE_Var: 37.7481 | MAE: 10.5048
Epoch 5/50 | Val Loss: 9.7977 | Fairness_MAE_Var: 25.1472 | MAE: 9.7977
Epoch 6/50 | Val Loss: 9.4295 | Fairness_MAE_Var: 16.7060 | MAE: 9.4295
Epoch 7/50 | Val Loss: 9.2546 | Fairness_MAE_Var: 12.2287 | MAE: 9.2546
Epoch 8/50 | Val Loss: 9.1791 | Fairness_MAE_Var: 10.8502 | MAE: 9.1791
Epoch 9/50 | Val Loss: 9.0674 | Fairness_MAE_Var: 13.9285 | MAE: 9.0674
Epoch 10/50 | Val Loss: 8.9734 | Fairness_MAE_Var: 17.3466 | MAE: 8.9734
Epoch 11/50 | Val Loss: 8.8904 | Fairness_MAE_Var: 25.3736 | MAE: 8.8904
Epoch 12/50 | Val Loss: 8.9168 | Fairness_MAE_Var: 29.7285 | MAE: 8.9168
Epoch 13/50 | Val Loss: 8.8834 | Fairness_MAE_Var: 30.1334 | MAE: 8.8834
Epoch 14/50 | Val Loss: 8.9543 | Fairness_MAE_Var: 3

In [26]:
group_fairness_metric = GroupFairnessMAEVariance()
mae_metric = MAE()
mae_group_fairness_metric = MAE_GroupFairness()

with torch.no_grad():
    preds_test = model(x_dyn_test, x_static_test)
    test_mae = mae_metric(preds_test, y_test, x_static_test).item()
    test_fairness = group_fairness_metric(preds_test, y_test, x_static_test).item()
    test_combined = mae_group_fairness_metric(preds_test, y_test, x_static_test).item()

print("Test MAE:", test_mae)
print("Test Fairness:", test_fairness)
print("Test Combined:", test_combined)

Test MAE: 2.9408762454986572
Test Fairness: 1.771399736404419
Test Combined: 2.356137990951538
