In [1]:
import numpy as np
import pandas as pd
import random
import os
import sys
sys.path.append('..')
# from modules import constants
from modules import former_constants as constants
from modules.env import LupusEnv
from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback
import torch as th
from torch.nn import functional as F

In [2]:
SEED = constants.SEED
random.seed(SEED)
np.random.seed(SEED)
os.environ['PYTHONHASHSEED']=str(SEED)
th.manual_seed(SEED)
th.use_deterministic_algorithms(True)
SEED

126

In [3]:
constants.BETA

9

In [4]:
class robustDQN(DQN):
    def __init__(self, *args, beta, al_r, al_p, p_proxy, **kwargs):
        self.beta = beta #(from fig 2)
        self.al_r = al_r
        self.al_p = al_p
        self.p_proxy = p_proxy
        self.norm_estimate = 0
        super(robustDQN, self).__init__(*args, **kwargs)
        
    
    def estimate_norm_r2(self, four_tuple_batch, q_net, q_net_target):
        state = four_tuple_batch[0]  # (batch, state)
#         print(f'state: {state}')
        q_net_val = q_net(state)  # (batch, actions)
        q_net_argmax = q_net_val.argmax(dim=1)  # (batch, )
        q_target_val = q_net_target(state)  # (batch, actions)
        qmax = q_target_val.gather(dim=1, index=q_net_argmax.unsqueeze(-1))  # (batch, )

        # calculate norm p for qmax
        if self.p_proxy == 'l2-norm':
#             print('using l2 norm')
            norm_estimate = (sum(qmax ** 2)) ** (1/2)  # dual norm (itself)
        elif self.p_proxy == 'l1-norm':
            norm_estimate = max(abs(qmax))  # dual norm (l_infinity)
        elif self.p_proxy == 'var-norm':  
#             print('Using var norm')
            norm_estimate = np.var(qmax.numpy()) ** (1/2)
        return norm_estimate, q_net_argmax

        
    
    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        self.policy.set_training_mode(True)
        self._update_learning_rate(self.policy.optimizer)

        losses = []
        for _ in range(gradient_steps):
            replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
            s = replay_data.observations
            a = replay_data.actions
            r = replay_data.rewards
            s_prime = replay_data.next_observations
            d = replay_data.dones
            
            with th.no_grad():
                next_q_values = self.q_net_target(s_prime)
                next_q_values, _ = next_q_values.max(dim=1)
                next_q_values = next_q_values.reshape(-1, 1) #max_q_prime
                
                #added
                current_norm_estimate, _ = self.estimate_norm_r2((s, a, r, s_prime), self.q_net, self.q_net_target)
                norm_estimate = self.beta * self.norm_estimate + (1 - self.beta) * current_norm_estimate  # moving avg
                self.norm_estimate = norm_estimate  # update last norm
                
                #changed this
                # try also without the 1-d. check github for reference
                target_q_values = r - self.al_r + (1 - d) * self.gamma * (next_q_values - self.al_p * norm_estimate)
                

            # Get current Q-values estimates
            current_q_values = self.q_net(s)

            # Retrieve the q-values for the actions from the replay buffer
            current_q_values = th.gather(current_q_values, dim=1, index=a.long())

            # Compute Huber loss (less sensitive to outliers)
            loss = F.smooth_l1_loss(current_q_values, target_q_values)
            losses.append(loss.item())

            # Optimize the policy
            self.policy.optimizer.zero_grad()
            loss.backward()
            # Clip gradient norm
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.policy.optimizer.step()

        # Increase update counter
        self._n_updates += gradient_steps

        self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        self.logger.record("train/loss", np.mean(losses))

#### Testing 

In [5]:
def create_env(X, y, random=True):
    '''
    Creates and environment using the given data
    '''
    env = LupusEnv(X, y, random)
    print(f'The environment seed is {env.seed()}') #to delete
    return env

In [6]:
constants.CHECKPOINT_FREQ

1000000

In [7]:
def stable_baselines3_robust_dqn(X_train, y_train, steps, save, log_path, log_prefix, filename, beta, al_r, al_p, 
                                 p_proxy):
    training_env = create_env(X_train, y_train)
    model = robustDQN('MlpPolicy', training_env, verbose=1, seed=constants.SEED, beta=beta, al_r=al_r, al_p=al_p,
                     p_proxy=p_proxy)
    checkpoint_callback = CheckpointCallback(save_freq=constants.CHECKPOINT_FREQ, save_path=log_path, 
                                            name_prefix=log_prefix)
    model.learn(total_timesteps=steps, log_interval=100000, callback=checkpoint_callback)
    if save:
#         model.save(f'{log_path}/{filename}.pkl')
        model.save(f'{log_path}/{filename}_full_model')
    training_env.close()
    return model

In [8]:
def run_robust_dqn_model(steps, beta, al_r, al_p, p_proxy):
    dir_name = f'seed_{SEED}_{steps}'
    parent_dir = f'../models/logs/robust_dqn3/noisiness/0.3/biopsy_9/var_norm'
    path = os.path.join(parent_dir, dir_name)
#     os.mkdir(path)
    model = stable_baselines3_robust_dqn(X_train, y_train, steps, True, log_path=path, log_prefix='robust_dqn3', 
                                  filename=f'robust_dqn3_{steps}', beta=beta, al_r=al_r, al_p=al_p, p_proxy=p_proxy)
    return model

In [9]:
train_df = pd.read_csv('../new_data/train_set_noisiness_0.3.csv')
# train_df = train_df.fillna(-1)
train_df.head()

Unnamed: 0,ana,fever,leukopenia,thrombocytopenia,auto_immune_hemolysis,delirium,psychosis,seizure,non_scarring_alopecia,oral_ulcers,...,proteinuria,biopsy_proven_lupus_nephritis,anti_cardioliphin_antibodies,anti_β2gp1_antibodies,lupus_anti_coagulant,low_c3,low_c4,anti_dsdna_antibody,anti_smith_antibody,label
0,1,1,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,1,0,0,1
1,1,0,1,0,0,0,0,0,1,0,...,0,0,0,1,0,0,0,1,0,1
2,0,0,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
3,0,0,0,0,0,0,0,0,0,0,...,1,0,0,0,0,0,0,0,0,0
4,1,0,0,0,0,1,0,0,1,0,...,1,3,0,0,0,0,0,0,0,0


In [10]:
train_df.label.value_counts()

1    25210
0    25190
Name: label, dtype: int64

In [11]:
X_train = train_df.iloc[:, 0:-1]
y_train = train_df.iloc[:, -1]
X_train, y_train = np.array(X_train), np.array(y_train)
X_train.shape, y_train.shape

((50400, 24), (50400,))

In [12]:
steps = 100000000
beta = 1 #
al_r = 0.01
al_p = 0.01
p_proxy = 'var-norm'

In [13]:
model = run_robust_dqn_model(steps, beta, al_r, al_p, p_proxy)

The environment seed is [126]
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.62     |
|    ep_rew_mean      | -0.67    |
|    exploration_rate | 0.954    |
|    success_rate     | 0.19     |
| time/               |          |
|    episodes         | 100000   |
|    fps              | 2860     |
|    time_elapsed     | 169      |
|    total_timesteps  | 484209   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.072    |
|    n_updates        | 108552   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 4.71     |
|    ep_rew_mean      | -0.712   |
|    exploration_rate | 0.907    |
|    success_rate     | 0.17     |
| time/               |          |
|    episodes         | 200000   |
|    fps              | 2679     |
|

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 14.3     |
|    ep_rew_mean      | -0.379   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.41     |
| time/               |          |
|    episodes         | 1600000  |
|    fps              | 2288     |
|    time_elapsed     | 5754     |
|    total_timesteps  | 13171161 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.214    |
|    n_updates        | 3280290  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 14.1     |
|    ep_rew_mean      | -0.424   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.4      |
| time/               |          |
|    episodes         | 1700000  |
|    fps              | 2244     |
|    time_elapsed     | 6492     |
|    total_timesteps  | 14573337 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 13.9     |
|    ep_rew_mean      | -0.00571 |
|    exploration_rate | 0.05     |
|    success_rate     | 0.58     |
| time/               |          |
|    episodes         | 3100000  |
|    fps              | 2120     |
|    time_elapsed     | 15979    |
|    total_timesteps  | 33888504 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0985   |
|    n_updates        | 8459625  |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 14.1     |
|    ep_rew_mean      | -0.198   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.48     |
| time/               |          |
|    episodes         | 3200000  |
|    fps              | 2116     |
|    time_elapsed     | 16687    |
|    total_timesteps  | 35317433 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 11.9     |
|    ep_rew_mean      | 0.255    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.68     |
| time/               |          |
|    episodes         | 4600000  |
|    fps              | 2084     |
|    time_elapsed     | 25253    |
|    total_timesteps  | 52650291 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0147   |
|    n_updates        | 13150072 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 10.6     |
|    ep_rew_mean      | 0.225    |
|    exploration_rate | 0.05     |
|    success_rate     | 0.66     |
| time/               |          |
|    episodes         | 4700000  |
|    fps              | 2083     |
|    time_elapsed     | 25814    |
|    total_timesteps  | 53783918 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 11.5     |
|    ep_rew_mean      | 0.2      |
|    exploration_rate | 0.05     |
|    success_rate     | 0.65     |
| time/               |          |
|    episodes         | 6100000  |
|    fps              | 2068     |
|    time_elapsed     | 33695    |
|    total_timesteps  | 69713476 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00222  |
|    n_updates        | 17415868 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 12.2     |
|    ep_rew_mean      | -0.0121  |
|    exploration_rate | 0.05     |
|    success_rate     | 0.55     |
| time/               |          |
|    episodes         | 6200000  |
|    fps              | 2068     |
|    time_elapsed     | 34273    |
|    total_timesteps  | 70881368 |
| train/              |          |
|    learning_rate  

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 11.7     |
|    ep_rew_mean      | 0.0926   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.6      |
| time/               |          |
|    episodes         | 7600000  |
|    fps              | 2056     |
|    time_elapsed     | 42391    |
|    total_timesteps  | 87187134 |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0105   |
|    n_updates        | 21784283 |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 12.6     |
|    ep_rew_mean      | 0.0306   |
|    exploration_rate | 0.05     |
|    success_rate     | 0.57     |
| time/               |          |
|    episodes         | 7700000  |
|    fps              | 2052     |
|    time_elapsed     | 43048    |
|    total_timesteps  | 88339360 |
| train/              |          |
|    learning_rate  