In [1]:
import pandas as pd
import numpy as np
import wandb, gc, optuna, shutil
import torch, json
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader
from rl import make_transition, make_df, make_transition_test, Sampler, Model

In [2]:
import random

def set_seed(seed):
    """Seed every RNG this notebook touches and make cuDNN deterministic.

    Seeds Python's `random`, NumPy's global RNG and PyTorch's CPU generator, then (when a GPU
    is present) the current and all CUDA devices. Finally forces cuDNN into deterministic,
    non-benchmarking mode.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
import os

# Debug setting: makes CUDA kernel launches synchronous so errors surface at the failing
# line. It noticeably slows GPU training; consider removing it for long tuning runs.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ["WANDB_SILENT"] = "true"

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

print(torch.__version__)
print(torch.cuda.is_available())
print(torch.cuda.device_count())
print(torch.version.cuda)
# get_device_name(0) raises on CPU-only machines; only query it when CUDA is available,
# consistent with the CPU fallback used for `device` above.
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
print(device)

torch.set_printoptions(precision=4, sci_mode=False)
torch.backends.cudnn.enabled = False
# NOTE(review): this flag governs TF32 inside cuDNN; with cuDNN disabled on the line above
# it presumably has no effect — confirm whether cuDNN should really be disabled.
torch.backends.cudnn.allow_tf32 = True

2.2.1
True
1
12.1
NVIDIA GeForce RTX 4080
학습을 진행하는 기기: cuda:0


In [4]:
# Train / validation / test ID splits prepared upstream; the first CSV column becomes the index.
train_id, valid_id, test_id = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in ('processed/train_id.csv', 'processed/val_id.csv', 'processed/test_id.csv')
)

BCQ (Batch-Constrained deep Q-learning): behaviour VAE and perturbation scorer

In [None]:
class VAE(nn.Module):
    """Conditional VAE over discrete actions.

    The encoder maps (state, one-hot action) to a diagonal Gaussian over a latent code z;
    the decoder maps (state, z) back to action logits.

    Args:
        state_dim: length of the state vector.
        action_dim: number of discrete actions (width of the one-hot input and logit output).
        latent_dim: size of the latent code z.
        hidden_size: width of the hidden layers.
    """

    def __init__(self, state_dim, action_dim, latent_dim=16, hidden_size=128):
        super(VAE, self).__init__()
        # Submodule creation order/structure is kept fixed so parameter initialisation and
        # state_dict keys stay compatible with previously saved checkpoints.
        self.encoder = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(hidden_size, latent_dim)
        self.fc_logvar = nn.Linear(hidden_size, latent_dim)
        self.decoder = nn.Sequential(
            nn.Linear(state_dim + latent_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_dim),
        )

    def forward(self, s, a=None):
        """Reconstruct action logits for observed 0-based action indices `a`.

        `a` may be shape (B,) or (B, 1). Returns (logits, mu, logvar).

        Raises:
            NotImplementedError: if `a` is None (use `sample_actions` instead).
        """
        if a is None:
            raise NotImplementedError("Use sample_actions method to generate actions without a label.")
        if a.dim() == 2 and a.shape[1] == 1:
            a = a.squeeze(1)
        n_actions = self.decoder[-1].out_features
        encoder_in = torch.cat([s, F.one_hot(a, num_classes=n_actions).float()], dim=1)
        hidden = self.encoder(encoder_in)
        mu, logvar = self.fc_mu(hidden), self.fc_logvar(hidden)
        # Reparameterisation trick: z = mu + sigma * eps with eps ~ N(0, I).
        sigma = torch.exp(0.5 * logvar)
        z = mu + torch.randn_like(sigma) * sigma
        logits = self.decoder(torch.cat([s, z], dim=1))
        return logits, mu, logvar

    def sample_actions(self, s, num_samples):
        """Draw `num_samples` actions per state by decoding z ~ N(0, I).

        Returns a (batch, num_samples) tensor of 0-based action indices.
        """
        n_states = s.size(0)
        tiled_states = s.unsqueeze(1).repeat(1, num_samples, 1).view(n_states * num_samples, -1)
        z = torch.randn(n_states * num_samples, self.fc_mu.out_features, device=s.device)
        probs = F.softmax(self.decoder(torch.cat([tiled_states, z], dim=1)), dim=1)
        return torch.multinomial(probs, 1).view(n_states, num_samples)

def vae_loss(logits, a, mu, logvar):
    """VAE objective: mean cross-entropy reconstruction + 0.5 * KL term.

    The KL term is -0.5 * mean(1 + logvar - mu^2 - exp(logvar)), averaged over both the
    batch and the latent dimensions (not summed over latent dims).
    """
    reconstruction = F.cross_entropy(logits, a, reduction='mean')
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
    return reconstruction + 0.5 * kl_divergence

class Perturbation(nn.Module):
    """Small MLP scoring a (state, discrete action) pair with one scalar per row.

    Args:
        state_dim: length of the state vector.
        action_dim: number of discrete actions (actions are one-hot encoded internally).
        hidden_size: width of the single hidden layer.
    """

    def __init__(self, state_dim, action_dim, hidden_size=128):
        super(Perturbation, self).__init__()
        self.action_dim = action_dim
        self.net = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, s, a):
        """Return a (B, 1) score for states `s` and 0-based actions `a` of shape (B,) or (B, 1)."""
        flat_a = a.squeeze(1) if (a.dim() == 2 and a.shape[1] == 1) else a
        features = torch.cat([s, F.one_hot(flat_a, num_classes=self.action_dim).float()], dim=1)
        return self.net(features)

def pretrain_vae(train_data, vae, vae_lr=1e-3, vae_epochs=10, batch_size=256):
    """Fit `vae` with Adam on the (state, action) pairs of `train_data` and return it.

    Each dataset item is unpacked as (s, a, r, s2, t); only s and a are used. Action labels
    are shifted down by one before use (the code assumes 1-based labels). After each epoch
    the mean per-batch loss is printed. The module is trained in place.
    """
    opt = optim.Adam(vae.parameters(), lr=vae_lr)
    batches = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    vae.train()
    for epoch_idx in range(1, vae_epochs + 1):
        running = 0
        for states, actions, _r, _s2, _t in batches:
            states, actions = states.to(device), actions.to(device)
            if actions.dim() == 2 and actions.shape[1] == 1:
                actions = actions.squeeze(1)
            actions = actions.long() - 1
            logits, mu, logvar = vae(states, actions)
            batch_loss = vae_loss(logits, actions, mu, logvar)
            opt.zero_grad()
            batch_loss.backward()
            opt.step()
            running += batch_loss.item()
        print(f"[VAE pretrain epoch {epoch_idx}/{vae_epochs}] loss={running/len(batches):.4f}")
    return vae

def load_best_vae_perturbation(study, obs_dim, nb_actions, path):
    """Rebuild the VAE and Perturbation of `study.best_trial` and load their saved weights.

    Checkpoints are looked up as `{path}/vae_pretrained_{n}.pth` and
    `{path}/perturb_pretrained_{n}.pth`, where n is the best trial's number + 1 (the same
    1-based numbering `objective_vae` uses when saving). Returns (vae, perturbation) on `device`.
    """
    best = study.best_trial
    trial_no = best.number + 1
    hidden = best.params['bcq_hidden_size']
    vae = VAE(state_dim=obs_dim, action_dim=nb_actions,
              latent_dim=best.params['bcq_latent_dim'], hidden_size=hidden).to(device)
    perturbation = Perturbation(state_dim=obs_dim, action_dim=nb_actions,
                                hidden_size=hidden).to(device)
    vae.load_state_dict(torch.load(f"{path}/vae_pretrained_{trial_no}.pth"))
    perturbation.load_state_dict(torch.load(f"{path}/perturb_pretrained_{trial_no}.pth"))
    print(f"Loaded best VAE and Perturbation from trial {trial_no}")
    return vae, perturbation


In [None]:
def objective_vae(trial, train_data, val_data, path):
    """Optuna objective: pretrain a BCQ behaviour VAE and return its validation loss.

    Samples VAE hyperparameters, pretrains on `train_data`, then evaluates the VAE loss on
    `val_data`. Both the VAE and a freshly initialised Perturbation model are saved to
    `{path}/vae_pretrained_{n}.pth` / `{path}/perturb_pretrained_{n}.pth` (n = trial.number + 1).

    Returns:
        Mean per-sample validation VAE loss (lower is better). It is weighted by batch size,
        so it does not depend on the tuned `batch_size`.
    """
    latent_dim = trial.suggest_int('bcq_latent_dim', 16, 32, step=8)
    hidden_size = trial.suggest_categorical('bcq_hidden_size', [32, 64, 128])
    vae_lr = trial.suggest_float('vae_lr', 1e-4, 5e-3, log=True)
    vae_epochs = trial.suggest_int('bcq_pretrain_epochs', 5, 15, step=5)
    batch_size = trial.suggest_categorical('batch_size', [64, 128, 256])

    obs_dim = train_data.tensors[0].shape[1]
    # Largest action label; labels are shifted by -1 below, so this is the number of actions.
    nb_actions = int(train_data.tensors[1].max())

    vae = VAE(state_dim=obs_dim, action_dim=nb_actions, latent_dim=latent_dim, hidden_size=hidden_size).to(device)
    # NOTE(review): the perturbation model is never trained here; it is saved at its random init.
    perturbation = Perturbation(state_dim=obs_dim, action_dim=nb_actions, hidden_size=hidden_size).to(device)

    vae = pretrain_vae(train_data, vae, vae_lr=vae_lr, vae_epochs=vae_epochs, batch_size=batch_size)

    vae.eval()
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():
        for s, a, _, _, _ in DataLoader(val_data, batch_size=batch_size, num_workers=4, pin_memory=True):
            s, a = s.to(device), a.to(device)

            if a.dim() == 2 and a.shape[1] == 1:
                a = a.squeeze(1)
            a = a.long() - 1

            logits, mu, logvar = vae(s, a)
            loss = vae_loss(logits, a, mu, logvar)
            # vae_loss returns a batch mean. Weighting by batch size gives a true per-sample
            # mean. The old sum-of-batch-means / len(val_data) scaled the score by
            # ~1/batch_size, which biased the search toward large batches.
            total_loss += loss.item() * s.size(0)
            n_samples += s.size(0)

    avg_loss = total_loss / max(n_samples, 1)
    print(f"[Evaluation] VAE Loss: {avg_loss:.4f}")

    torch.save(vae.state_dict(), f"{path}/vae_pretrained_{trial.number+1}.pth")
    torch.save(perturbation.state_dict(), f"{path}/perturb_pretrained_{trial.number+1}.pth")

    return avg_loss

IQL (Implicit Q-Learning): expectile loss for the state-value network

In [None]:
def expectile_loss(advantage, tau):
    """Asymmetric squared loss used to fit an expectile.

    Each squared advantage is weighted by `tau` when the advantage is strictly positive and by
    `1 - tau` otherwise (including zero). Returns the mean over all elements.
    """
    above = (advantage > 0).float()
    below = 1.0 - above
    weight = tau * above + (1.0 - tau) * below
    return torch.mean(weight * advantage.pow(2))

Train

In [None]:
def _bcq_select_actions(q2_tgt, s2, vae, perturbation, bcq_num_samples, bcq_threshold, re_score_lambda):
    """Pick one BCQ candidate action per next state for the bootstrap target.

    Samples `bcq_num_samples` candidate actions per state from the VAE. Each candidate is
    scored as target-Q + `re_score_lambda` * perturbation score. The best-scoring candidate
    among those with target-Q >= `bcq_threshold` is kept; if none passes the threshold, the
    candidate with the highest target-Q is used. Ties resolve to the first candidate, as with
    `torch.argmax`. This is a vectorised form of a per-row Python loop.

    Returns:
        LongTensor of shape (B,) with 0-based action indices.
    """
    candidates = vae.sample_actions(s2, bcq_num_samples)          # (B, n) action indices
    n_states, n_cand = candidates.shape
    q2_for_cand = q2_tgt.gather(1, candidates)                     # (B, n) target-Q per candidate

    re_score = torch.zeros_like(q2_for_cand)
    if perturbation is not None:
        s_expand = s2.unsqueeze(1).repeat(1, n_cand, 1).view(n_states * n_cand, -1)
        pert_scores = perturbation(s_expand, candidates.view(n_states * n_cand)).view(n_states, n_cand)
        re_score = re_score_lambda * pert_scores
    total_score = q2_for_cand + re_score

    valid_mask = q2_for_cand >= bcq_threshold
    # Invalid candidates get -inf so they can never win the masked argmax.
    idx_valid = total_score.masked_fill(~valid_mask, float('-inf')).argmax(dim=1)
    idx_fallback = q2_for_cand.argmax(dim=1)
    idx = torch.where(valid_mask.any(dim=1), idx_valid, idx_fallback)
    return candidates.gather(1, idx.unsqueeze(1)).squeeze(1)


def _bootstrap_value(algorithm, network, target_network, s2, vae, perturbation,
                     bcq_num_samples, bcq_threshold, re_score_lambda):
    """Next-state Q-value used in the Bellman target for the non-IQL algorithms.

    - '_dqn': the target network evaluates its own greedy action.
    - '_bcq': the target network evaluates the action chosen by `_bcq_select_actions`.
    - anything else ('_ddqn', '_cql'): Double-DQN rule, where the online network selects
      and the target network evaluates.

    Intended to be called under `torch.no_grad()`.
    """
    q2_tgt = target_network(s2)
    if algorithm == '_dqn':
        next_actions = torch.max(q2_tgt, dim=1)[1]
    elif algorithm == '_bcq':
        next_actions = _bcq_select_actions(q2_tgt, s2, vae, perturbation,
                                           bcq_num_samples, bcq_threshold, re_score_lambda)
    else:
        next_actions = torch.max(network(s2), dim=1)[1]
    return q2_tgt.gather(1, next_actions.unsqueeze(1)).squeeze()


def _save_history(path, version, history, train_losses, valid_losses):
    """Write the per-epoch AUROC curves and loss curves of a `train` run under `path`."""
    metrics = ['gat', 'med', 'min', 'max'] if version != '_both' else []
    metrics += ['p_gat', 'p_med', 'p_min', 'p_max']
    for metric in metrics:
        torch.save(history[metric], f'{path}/valid_auroc_{metric}.pth')
    torch.save(train_losses, f'{path}/train_losses.pth')
    torch.save(valid_losses, f'{path}/valid_losses.pth')


def train(
    batch_size, lr, lr_decay, lr_epoch, ns, epochs, update_freq,
    mlp_size, mlp_num_layers, activation_type,
    algorithm,
    version,
    vae,
    perturbation,
    path,
    train_data, val_data, val_transition,
    alpha=0,
    bcq_num_samples=0,
    bcq_threshold=0.0,
    re_score_lambda=0.0,
    tau=0.7,
    trial=None,
    device=device
):
    """Train one offline-RL Q-network configuration and track validation AUROCs per epoch.

    Algorithms:
        '_dqn'  : target = max over the target network's Q-values.
        '_ddqn' : Double-DQN target.
        '_cql'  : Double-DQN target + `alpha` * (logsumexp Q - Q(s, a)).
        '_bcq'  : target action chosen among VAE-sampled candidates (needs `vae`).
        '_iql'  : separate state-value network fitted with an expectile loss at level `tau`.

    `version` picks the clamping range for rewards / Q-values: '_negative' -> [-1, 0],
    '_positive' -> [0, 1], anything else -> [-1, 1].

    Each epoch:
      - writes network checkpoints under `path`;
      - logs metrics to wandb;
      - reports patient AUROC(gather) + AUROC(median) to `trial`, if one is given.

    Returns:
        Best per-epoch patient AUROC(gather) + AUROC(median), or 0.0 if no epoch ran.

    Raises:
        ValueError: unknown `algorithm`, or '_bcq' without a `vae`.
        optuna.TrialPruned: when the Optuna pruner stops the trial.
    """
    if algorithm not in ('_dqn', '_ddqn', '_cql', '_bcq', '_iql'):
        raise ValueError(f"Unknown algorithm: {algorithm!r}")
    if algorithm == '_bcq' and vae is None:
        raise ValueError("algorithm='_bcq' requires a pretrained `vae`.")

    def clamping(data, version):
        """Clamp rewards / Q-values into the range implied by `version`."""
        if version == '_negative':
            return torch.clamp(data, min=-1.0, max=0.0)
        if version == '_positive':
            return torch.clamp(data, min=0.0, max=1.0)
        return torch.clamp(data, min=-1.0, max=1.0)

    d_f = 1.0          # discount factor
    num_workers = 4
    patience = 10      # epochs without joint improvement of both patient AUROCs before stopping
    best_auroc_p_gat = 0.0
    best_auroc_p_med = 0.0
    no_improve_count = 0

    history = {key: [] for key in ('gat', 'med', 'min', 'max', 'p_gat', 'p_med', 'p_min', 'p_max')}
    train_losses, valid_losses = [], []

    obs_dim = train_data.tensors[0].shape[1]
    # Largest action label; labels are shifted by -1 below, so this is the number of actions.
    nb_actions = int(train_data.tensors[1].max())

    def build_model(n_outputs):
        return Model(
            obs_dim=obs_dim,
            nb_actions=n_outputs,
            mlp_size=mlp_size,
            mlp_num_layers=mlp_num_layers,
            activation_type=activation_type
        ).to(device)

    network = build_model(nb_actions)
    target_network = build_model(nb_actions)
    target_network.load_state_dict(network.state_dict())

    trainable = list(network.parameters())
    if algorithm == '_bcq' and perturbation is not None:
        # NOTE(review): perturbation outputs are only used under torch.no_grad() below, so
        # these parameters never receive gradients and Adam leaves them unchanged.
        trainable += list(perturbation.parameters())
    optimizer = optim.Adam(trainable, lr=lr)
    # Created after the final optimizer is chosen so LR decay also applies to BCQ's optimizer.
    scheduler = ExponentialLR(optimizer, gamma=lr_decay)

    if algorithm == '_iql':
        value_network = build_model(1)
        target_value_network = build_model(1)
        target_value_network.load_state_dict(value_network.state_dict())
        value_optimizer = optim.Adam(value_network.parameters(), lr=lr)
        value_scheduler = ExponentialLR(value_optimizer, gamma=lr_decay)

    train_loader = DataLoader(
        train_data,
        batch_sampler=Sampler(train_data, batch_size, version, ns),
        num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(val_data, batch_size=4096, shuffle=False, num_workers=num_workers, pin_memory=True)

    bcq_kwargs = dict(vae=vae, perturbation=perturbation, bcq_num_samples=bcq_num_samples,
                      bcq_threshold=bcq_threshold, re_score_lambda=re_score_lambda)

    for epoch in range(1, epochs + 1):
        network.train()
        if perturbation is not None:
            perturbation.train()
        if algorithm == '_iql':
            value_network.train()

        train_loss = 0.0

        for s, a, r, s2, t in train_loader:
            s, a, r, s2, t = s.to(device), a.to(device), r.to(device), s2.to(device), t.to(device)
            a = a.long() - 1   # shift labels to 0-based column indices for gather
            r_clamped = clamping(r, version)
            q_values = network(s)
            q_pred = q_values.gather(1, a).squeeze(1)

            if algorithm == '_iql':
                # 1) Fit V(s) toward an expectile of Q(s, a) (Q treated as a constant).
                v_pred = value_network(s).squeeze(1)
                value_loss = expectile_loss(q_pred.detach() - v_pred, tau)
                value_optimizer.zero_grad()
                value_loss.backward()
                value_optimizer.step()

                # 2) Fit Q(s, a) toward r + d_f * V_target(s').
                with torch.no_grad():
                    v_next = target_value_network(s2).squeeze(1)
                bellman_target = r_clamped + d_f * clamping(v_next, version) * (1. - t)
                q_loss = F.smooth_l1_loss(q_pred, bellman_target)
                optimizer.zero_grad()
                q_loss.backward()
                optimizer.step()

                train_loss += value_loss.item() + q_loss.item()
            else:
                with torch.no_grad():
                    q2_max = _bootstrap_value(algorithm, network, target_network, s2, **bcq_kwargs)
                bellman_target = r_clamped + d_f * clamping(q2_max, version) * (1 - t)
                td_loss = F.smooth_l1_loss(q_pred, bellman_target)

                if algorithm == '_cql':
                    # Conservative penalty: logsumexp over all actions minus the logged action's Q.
                    cql_term = (torch.logsumexp(q_values, dim=1) - q_pred).mean()
                    loss = td_loss + alpha * cql_term
                else:
                    loss = td_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

        network.eval()
        if perturbation is not None:
            perturbation.eval()
        if algorithm == '_iql':
            value_network.eval()

        with torch.no_grad():
            valid_loss = 0.0
            for s, a, r, s2, t in val_loader:
                s, a, r, s2, t = s.to(device), a.to(device), r.to(device), s2.to(device), t.to(device)
                a = a.long() - 1
                r_clamped = clamping(r, version)
                q_pred = network(s).gather(1, a).squeeze()

                if algorithm == '_iql':
                    next_value = target_value_network(s2).squeeze(1)
                else:
                    next_value = _bootstrap_value(algorithm, network, target_network, s2, **bcq_kwargs)
                bellman_target = r_clamped + d_f * clamping(next_value, version) * (1 - t)
                valid_loss += F.smooth_l1_loss(q_pred, bellman_target).item()

            q_value_list, reward_list, patient_list, action_list = [], [], [], []
            for s, a, r, rp in val_transition:
                s, a = s.to(device), a.to(device)
                a = a.long() - 1
                q_value_list.append(clamping(network(s), version).detach().cpu().numpy())
                reward_list.append(clamping(r, version).detach().cpu().numpy())
                patient_list.append(clamping(rp, version).detach().cpu().numpy())
                action_list.append(a.detach().cpu().numpy())

            # reshape(-1) guards against (N, 1) action arrays, which would otherwise broadcast
            # the fancy index below into an (N, N) matrix.
            action_space = np.concatenate(action_list, axis=0).reshape(-1)
            q_value = np.concatenate(q_value_list, axis=0)
            reward = np.concatenate(reward_list, axis=0)
            patient = np.concatenate(patient_list, axis=0)

            q_max = np.max(q_value, axis=1)
            q_min = np.min(q_value, axis=1)
            q_median = np.median(q_value, axis=1)
            q_gather = q_value[np.arange(q_value.shape[0]), action_space]

            auroc_p_max = roc_auc_score(patient, q_max)
            auroc_p_min = roc_auc_score(patient, q_min)
            auroc_p_med = roc_auc_score(patient, q_median)
            auroc_p_gat = roc_auc_score(patient, q_gather)

            history['p_gat'].append(auroc_p_gat)
            history['p_med'].append(auroc_p_med)
            history['p_min'].append(auroc_p_min)
            history['p_max'].append(auroc_p_max)

            train_losses.append(train_loss)
            valid_losses.append(valid_loss)

            log = {'epoch': epoch, 'train_loss': train_loss, 'valid_loss': valid_loss}
            if version != '_both':
                auroc_max = roc_auc_score(reward, q_max)
                auroc_min = roc_auc_score(reward, q_min)
                auroc_med = roc_auc_score(reward, q_median)
                auroc_gat = roc_auc_score(reward, q_gather)
                history['gat'].append(auroc_gat)
                history['med'].append(auroc_med)
                history['min'].append(auroc_min)
                history['max'].append(auroc_max)
                log.update({
                    'reward_max': auroc_max,
                    'reward_min': auroc_min,
                    'reward_median': auroc_med,
                    'reward_gather': auroc_gat,
                })
            log.update({
                'patient_max': auroc_p_max,
                'patient_min': auroc_p_min,
                'patient_median': auroc_p_med,
                'patient_gather': auroc_p_gat,
            })
            wandb.log(log)

        # Early stopping only resets when BOTH patient AUROCs improve in the same epoch.
        if auroc_p_gat > best_auroc_p_gat and auroc_p_med > best_auroc_p_med:
            best_auroc_p_gat = auroc_p_gat
            best_auroc_p_med = auroc_p_med
            no_improve_count = 0
        else:
            no_improve_count += 1

        gc.collect()
        torch.cuda.empty_cache()

        torch.save(network.state_dict(), os.path.join(path, f'network_{epoch}.pth'))
        torch.save(target_network.state_dict(), os.path.join(path, f'target_network_{epoch}.pth'))

        if algorithm == '_iql':
            torch.save(value_network.state_dict(), os.path.join(path, f'value_network_{epoch}.pth'))
            torch.save(target_value_network.state_dict(), os.path.join(path, f'target_value_network_{epoch}.pth'))

        if epoch % lr_epoch == 0:
            scheduler.step()
            if algorithm == '_iql':
                value_scheduler.step()

        if epoch % update_freq == 0:
            target_network.load_state_dict(network.state_dict())
            if algorithm == '_iql':
                target_value_network.load_state_dict(value_network.state_dict())

        if no_improve_count >= patience:
            print(f"Early stopping at epoch {epoch} due to no improvement.")
            break

        # `trial` defaults to None, so only talk to Optuna when a trial was supplied.
        if trial is not None:
            trial.report(auroc_p_gat + auroc_p_med, step=epoch)
            if trial.should_prune():
                _save_history(path, version, history, train_losses, valid_losses)
                raise optuna.TrialPruned()

    _save_history(path, version, history, train_losses, valid_losses)

    patient_sums = [g + m for g, m in zip(history['p_gat'], history['p_med'])]
    return max(patient_sums) if patient_sums else 0.0

Optuna Objective

In [None]:
def objective(trial, train_data, val_data, val_transition, target, version, algorithm, vae, perturbation):
    """Optuna objective: sample a training config, run `train` under a wandb run, return its score.

    Algorithm-specific hyperparameters are added for '_bcq', '_cql' and '_iql'.
    Checkpoints go to experiments/<target lower>/<algorithm>/<version>/trial_<n>.

    Returns:
        The value returned by `train`, or None if training raised an unexpected exception.
        Optuna then records the trial as FAILED; the traceback is printed and the study continues.

    Raises:
        optuna.TrialPruned: re-raised so Optuna records the trial as PRUNED.
    """

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    config = {
        'batch_size': trial.suggest_categorical('batch_size', [128, 256, 512]),
        'lr': trial.suggest_categorical('lr', [1e-4, 1e-3, 1e-2]),
        'lr_decay': trial.suggest_categorical('lr_decay', [0.9, 0.95, 0.99]),
        'lr_epoch': trial.suggest_categorical('lr_epoch', [1, 2, 3]),
        'ns': trial.suggest_categorical('negative_sampling', [2, 4, 8]),
        'epochs': trial.suggest_categorical('epochs', [200]),
        'update_freq': trial.suggest_categorical('update_freq', [1, 2, 3]),
        'mlp_size': trial.suggest_categorical('mlp_size', [32, 64, 128]),
        'mlp_num_layers': trial.suggest_int('mlp_num_layers', 1, 3),
        'activation_type': trial.suggest_categorical('activation_type', ['ReLU', 'LeakyReLU', 'Tanh']),
        'algorithm': trial.suggest_categorical('algorithm', [algorithm]),
        'version': trial.suggest_categorical('version', [version]),
    }

    if algorithm == '_bcq':
        config['bcq_num_samples'] = trial.suggest_int('bcq_num_samples', 7, 9)
        config['bcq_threshold'] = trial.suggest_float('bcq_threshold', 0.1, 0.3, step=0.1)
        config['re_score_lambda'] = trial.suggest_float('re_score_lambda', 0.1, 0.3, step=0.1)

    if algorithm == '_cql':
        config['alpha'] = trial.suggest_categorical('alpha', [1e-4, 1e-2, 1.0])

    if algorithm == '_iql':
        config['tau'] = trial.suggest_float('tau', 0.7, 0.9, step=0.1)

    result = None

    try:
        print(f"[INFO] Starting trial {trial.number+1}")

        if wandb.run is not None:
            print("[INFO] Existing wandb run detected. Finishing it.")
            wandb.finish()

        wandb.init(
            project=f'MedClap_{target}',
            name=f'H-{algorithm}-{version}-{trial.number+1}',
            config=config,
            settings=wandb.Settings(init_timeout=180, start_method="thread")
        )

        print("[INFO] wandb initialized.")

        model_save_path = os.path.join(
            'experiments', target.lower(), algorithm, version, f'trial_{trial.number+1}'
        )
        os.makedirs(model_save_path, exist_ok=True)
        print(f"[INFO] Model save path created: {model_save_path}")

        result = train(
            path=model_save_path,
            train_data=train_data,
            val_data=val_data,
            val_transition=val_transition,
            vae=vae,
            perturbation=perturbation,
            trial=trial,
            device=device,
            **config
        )
        print(f"[INFO] Trial {trial.number+1} finished successfully.")
    except optuna.TrialPruned:
        print(f"[INFO] Trial {trial.number+1} pruned.")
        # Re-raise: returning None here made Optuna mark pruned trials as FAILED, which hides
        # them from the pruner/sampler's pruned-trial bookkeeping.
        raise
    except Exception as e:
        import traceback
        print(f"[ERROR] Error in trial {trial.number+1}: {str(e)}")
        traceback.print_exc()
    finally:
        if wandb.run is not None:
            wandb.finish()

    return result

In [11]:
from optuna.pruners import BasePruner, HyperbandPruner, ThresholdPruner

class CompositePruner(BasePruner):
    """Prune a trial as soon as either of two wrapped pruners votes to prune.

    Both pruners are always consulted (no short-circuit), so each one still runs its own
    `prune` logic on every report, in case it keeps per-trial state.
    """

    def __init__(self, pruner1, pruner2):
        self.pruner1, self.pruner2 = pruner1, pruner2

    def prune(self, study, trial):
        votes = [self.pruner1.prune(study, trial), self.pruner2.prune(study, trial)]
        return any(votes)

# Hyperband over training epochs, combined with a hard floor: after the warm-up steps, any
# trial whose reported value (patient AUROC gather + median, see `train`) is below 1.5 is pruned.
hyperband = HyperbandPruner(reduction_factor=3, min_resource=10, max_resource='auto')
threshold = ThresholdPruner(lower=1.5, n_warmup_steps=10)
pruner = CompositePruner(hyperband, threshold)

In [None]:
# Multivariate TPE: 20 random start-up trials, then 40 expected-improvement candidates per
# suggestion; warns when a parameter falls back to independent sampling.
sampler = optuna.samplers.TPESampler(
    multivariate=True,
    n_startup_trials=20,
    n_ei_candidates=40,
    warn_independent_sampling=True,
)

Data

In [5]:
# Prediction targets and their lowercase directory prefixes under experiments/ (same order).
targets = ['Dead_icu', 'Dead_hosp', 'Dead_90', 'AKI_rrt', 'AKI_48', 'AKI_24', 'AKI_12', 'Septic_shock']
prefixes = ['dead_icu', 'dead_hosp', 'dead_90', 'aki_rrt', 'aki_48', 'aki_24', 'aki_12', 'septic_shock']
algorithms = ['_ddqn', '_cql', '_iql', '_bcq']
versions = ['_negative', '_positive', '_both']
base_path = 'experiments'
n_trials = 100          # Optuna trials per (target, algorithm, version) study
pre_train_trials = 20   # Optuna trials for BCQ VAE pretraining

# One pretrained BCQ VAE / perturbation model per target family, reused across its targets.
vae_cache = dict.fromkeys(['Dead', 'AKI', 'Septic_shock'])
perturb_cache = dict.fromkeys(['Dead', 'AKI', 'Septic_shock'])

In [None]:
if __name__ == "__main__":
    import copy
    from optuna.trial import TrialState

    # set_seed() is defined above but was never called; seed once so runs are reproducible.
    set_seed(42)

    for idx, target in enumerate(targets):

        if 'Septic' in target:
            reward = 'r:reward_septic_shock'
        elif 'Dead' in target:
            reward = 'r:reward_dead'
        else:
            reward = 'r:reward_aki'

        data = pd.read_csv(f'processed/df_{target}.csv')
        #data[reward] = data.groupby('traj')[reward].transform(lambda x: x[:-1].tolist() + ([1] if x.iloc[-1] == 0 else [x.iloc[-1]]))

        train_df, valid_df, test_df = make_df(data, reward, train_id, valid_id, test_id)
        train_data, val_data = (make_transition(df, reward, rolling_size=1) for df in (train_df, valid_df))
        val_transition = make_transition_test(valid_df, reward, rolling_size=1)

        for algorithm in algorithms:

            algo_path = os.path.join('experiments', prefixes[idx], algorithm)
            # Do NOT wipe this directory: it holds <version>/best_params.json, which the resume
            # check below relies on. Trial folders of unfinished versions are reset further down.
            os.makedirs(algo_path, exist_ok=True)

            pending_versions = []
            for version in versions:
                if os.path.exists(os.path.join(algo_path, version, 'best_params.json')):
                    print(f"[Skip] Already found best_params.json for version: {version}")
                else:
                    pending_versions.append(version)
            if not pending_versions:
                continue  # nothing left to tune; also skips BCQ VAE pretraining

            vae, perturbation = None, None

            if algorithm == '_bcq':
                if 'Dead' in target:
                    group = 'Dead'
                elif 'AKI' in target:
                    group = 'AKI'
                else:
                    group = 'Septic_shock'

                if vae_cache[group] is None:
                    print(f"Training VAE for {group} group...")

                    vae_study = optuna.create_study(direction='minimize')
                    vae_study.optimize(
                        lambda trial: objective_vae(trial, train_data, val_data, algo_path),
                        n_trials=pre_train_trials, n_jobs=1, gc_after_trial=True,
                    )

                    obs_dim = train_data.tensors[0].shape[1]
                    nb_actions = int(train_data.tensors[1].max())
                    vae, perturbation = load_best_vae_perturbation(vae_study, obs_dim, nb_actions, algo_path)

                    vae_cache[group] = vae
                    perturb_cache[group] = perturbation
                else:
                    print(f"Using cached VAE for {group} group.")
                    vae = vae_cache[group]
                    perturbation = perturb_cache[group]

            for version in pending_versions:
                version_path = os.path.join(algo_path, version)

                for i in range(1, n_trials + 1):
                    trial_dir = os.path.join(version_path, f'trial_{i}')
                    shutil.rmtree(trial_dir, ignore_errors=True)
                    os.makedirs(trial_dir)

                def save_trial_parameters(study, trial, version_path=version_path):
                    """Optuna callback: dump each finished trial's params next to its checkpoints."""
                    params_path = os.path.join(version_path, f'trial_{trial.number+1}')
                    os.makedirs(params_path, exist_ok=True)
                    with open(os.path.join(params_path, 'params.json'), 'w') as f:
                        json.dump(trial.params, f)

                # Give every study its own copy of the configured sampler and pruner. A multivariate
                # TPESampler caches an intersection search space tied to a single study, so sharing
                # one instance lets one study's cached search space leak into the next.
                study = optuna.create_study(
                    direction='maximize',
                    sampler=copy.deepcopy(sampler),
                    pruner=copy.deepcopy(pruner),
                )

                study.optimize(
                    lambda trial: objective(trial, train_data, val_data, val_transition,
                                            target, version, algorithm, vae, perturbation),
                    n_trials=n_trials, n_jobs=1, catch=(optuna.TrialPruned,), gc_after_trial=True,
                    callbacks=[save_trial_parameters]
                )

                completed_trials = [t for t in study.trials if t.state == TrialState.COMPLETE]

                if len(completed_trials) > 0:
                    with open(os.path.join(version_path, 'best_params.json'), 'w') as f:
                        json.dump(study.best_params, f)
                else:
                    print(f"[Skip] No completed trials for version: {version}, skipping best_params.json save.")