## Imports

In [1]:
! pip install stable-baselines3[extra] keras rdkit sb3-contrib 


Collecting stable-baselines3[extra]
  Downloading stable_baselines3-1.8.0-py3-none-any.whl (174 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m174.5/174.5 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
Collecting rdkit
  Downloading rdkit-2023.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.5/29.5 MB[0m [31m28.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sb3-contrib
  Downloading sb3_contrib-1.8.0-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
Collecting importlib-metadata~=4.13
  Downloading importlib_metadata-4.13.0-py3-none-any.whl (23 kB)
Collecting gym==0.21
  Downloading gym-0.21.0.tar.gz (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m76.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... 

In [2]:
import time
import random
import numpy as np

import gym
from gym import spaces
from gym.envs.registration import register

import torch as th
import torch.nn as nn

from stable_baselines3 import DQN,TD3,PPO,SAC
from stable_baselines3.common.vec_env import DummyVecEnv
# from stable_baselines3.dqn.policies import MlpPolicy, CnnPolicy
# from stable_baselines3.sac.policies import MlpPolicy, CnnPolicy
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.policies import MaskableActorCriticCnnPolicy
from stable_baselines3.ppo.policies import MlpPolicy, CnnPolicy
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from sklearn.model_selection import train_test_split
# from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.policies import ActorCriticPolicy
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical


from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs

# Load DATA

## .h5 with loader

In [3]:
INPUT_SHAPE = (1,1024)


y_train.max()
y_train.min()

In [5]:
# # x_train= np.load("./data/x_train.npy")
# # y_train= np.load("./data/y_train.npy")
x= np.load("/kaggle/input/mddr-multitask/mddr_x.npy")
y= np.load("/kaggle/input/mddr-multitask/mddr_y.npy")

# .h5 without loader


# Model

In [7]:
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
# x,y = x_train,y_train
x= x.reshape((x.shape[0],)+INPUT_SHAPE)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, shuffle=True,random_state=42)
# x_train , y_train= shuffle(x, y, random_state=42 )

del x
del y

In [8]:
fingerprint_size = 1024
class_num = y_train.shape[1]


31

### Custom Gym Env

In [10]:
from sb3_contrib.common.wrappers import ActionMasker
class MutliClassEnv(gym.Env):
    def __init__(self, images_per_episode=1, dataset=(x_train, y_train), random=True,mode="train"):
        super().__init__()

        self.action_space = gym.spaces.MultiBinary(class_num)
        
        self.observation_space = gym.spaces.Box(low=0,
                                                high=1,
                                                shape=INPUT_SHAPE,
                                                dtype=np.float32
                                                )
        self.images_per_episode = images_per_episode
        self.step_count = 0

        self.x, self.y = dataset
        self.random = random
        self.dataset_idx = 0
        self.mode = mode
        self.expected_action = self.y[0]

    def step(self, action):
        done = False
        
        mae = abs(action-self.expected_action)
        self.expected_action = np.array(self.expected_action)
        reward1 = len(np.where((action == self.expected_action) &(self.expected_action==1))[0]) 
        reward0 = len(np.where((action == self.expected_action) &(self.expected_action==0))[0])
        reward = (-(class_num - reward0) +reward0)*reward1
     
        obs = self._next_obs()

        self.step_count += 1
        if self.step_count >= self.images_per_episode:
            done = True

        return obs, reward, done, {"mae":mae,"acc":reward1,"rew1":reward1,"rew0":reward0}

    def reset(self):
        self.step_count = 0
        # if self.mode != "train":
        #     self.dataset_idx = 0
        
        obs = self._next_obs()
        return obs

    def _next_obs(self):
        if self.random:
            next_obs_idx = random.randint(0, self.x.shape[0]-1)
            shape_len = self.x.shape
            self.expected_action = self.y[next_obs_idx]
            obs = self.x[next_obs_idx]

        else:

            if self.dataset_idx >= len(self.x) -1:
                if self.mode == "train":
                    self.dataset_idx = 0
                else :
                    raise StopIteration()
            else :
                self.dataset_idx += 1
            obs = self.x[self.dataset_idx]
            self.expected_action = self.y[self.dataset_idx]
            

        return obs
    
        

In [11]:
env = MutliClassEnv(images_per_episode=1)

### CNN Feature extractor class

In [12]:
import torch
import torch.nn as nn

# Define the multi head attention class
class MultiHeadAttention(nn.Module):
  # Initialize the multi head attention
    def __init__(self, d_model, n_heads):
        # Call the parent class constructor
        super(MultiHeadAttention, self).__init__()
        # Check if the model dimension is divisible by the number of heads
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        # Define the dimension of each head
        self.d_head = d_model // n_heads
        # Define the number of heads
        self.n_heads = n_heads
        # Define the linear layers for the query, key and value projections
        self.q_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        # Define the linear layer for the output projection
        self.out_linear = nn.Linear(d_model, d_model)
        # Define the scaling factor
        self.scale = torch.sqrt(torch.tensor(self.d_head, dtype=torch.float32))
  
      # Define the forward pass of the multi head attention
    def forward(self, q, k, v):
        # Get the batch size from the query
        batch_size = q.size(0)
        # Project the query, key and value using the linear layers
        q = self.q_linear(q)
        k = self.k_linear(k)
        v = self.v_linear(v)
        # Reshape the query, key and value to split the heads
        q = q.view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(batch_size, -1, self.n_heads, self.d_head).transpose(1, 2)
        # Compute the attention scores using scaled dot product
        scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        # Apply softmax to get the attention weights
        weights = torch.softmax(scores, dim=-1)
        # Compute the weighted sum of the values
        output = torch.matmul(weights, v)
        # Reshape the output to concatenate the heads
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_head)
        # Project the output using the linear layer
        output = self.out_linear(output)
        # Return the output and the weights
        return output, weights

# Define the feature extractor class
fp_size=1024
class CustomMHA(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.MultiBinary, features_dim: int = 1024):
        super().__init__(observation_space, features_dim)
  
        #
        # Define the model dimension
        d_model = features_dim
        # Define the number of heads
        n_heads = 8
        # Define the linear layer that maps the fingerprint to a hidden vector
        self.linear = nn.Linear(fp_size, d_model)
        # Define the activation function
        self.relu = nn.ReLU()
        # Define the multi head attention layer
        self.mha = MultiHeadAttention(d_model, n_heads)
        self.linear_out = nn.Linear(d_model,features_dim)

  # Define the forward pass of the feature extractor
    def forward(self, x:th.Tensor)->th.Tensor:
        # Convert the input to a tensor
        # x = torch.tensor(x, dtype=torch.float32)
        # Apply the linear layer and the activation function
        x = self.relu(self.linear(x))
        # Apply the multi head attention layer (use x as query, key and value)
        x, _ = self.mha(x, x ,x)
        # Return the output vector 
        x = x.squeeze(1)
#         print(x.shape)
        return x


In [13]:
class CustomCNN(BaseFeaturesExtractor):
    def __init__(self, observation_space: spaces.MultiBinary, features_dim: int = 512):
        super().__init__(observation_space, features_dim)
        
        n_input_channels = observation_space.shape[0]
        # n_input_channels = 1
        # print(n_input_channels)
        self.cnn = nn.Sequential(
            nn.Conv1d(n_input_channels, 32,
                      kernel_size=8, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv1d(32, 64,
                      kernel_size=4, stride=2, padding=0),
            nn.MaxPool1d(1,2),
            nn.ReLU(),
            nn.Conv1d(64, 32, kernel_size=4, stride=2, padding=0),
            nn.MaxPool1d(1,2),
            nn.ReLU(),
            nn.Flatten(),
        )
        with th.no_grad():
            n_flatten = self.cnn(
                th.as_tensor(observation_space.sample()[None]).float()
            ).shape[1]
        
        self.linear = nn.Sequential(
            nn.Linear(n_flatten, 2048),
            nn.ReLU(),
            nn.Linear(2048,features_dim),
            nn.ReLU()
#             nn.Sigmoid()
            )
        d_model = features_dim
        n_heads = 8
        self.mha = MultiHeadAttention(d_model, n_heads)



    def forward(self, observations: th.Tensor) -> th.Tensor:
#         x = self.linear(self.cnn(observations))
        x = self.linear(self.cnn(observations))
        x,_ = self.mha(x,x,x)
        
#         x = nn.Linear(2048,self.features_dim)(x)
        
        x = x.squeeze(1)
       
        return x

### Custom Network model

### Custom Actor Critic Policy

In [14]:
fingerprint_size=1024
policy_kwargs = dict(
    features_extractor_class=CustomCNN,
    features_extractor_kwargs=dict(features_dim=fingerprint_size),
    
)

# Declare Agent

In [15]:
def mask_fn(env: MutliClassEnv) -> np.ndarray:
    return env.valid_action_mask()

# env = ActionMasker(env,mask_fn)
agent = PPO (
            CnnPolicy,
            env,
            verbose=1,
            learning_rate=0.00003,
            batch_size= 32,
            gamma=1,
            n_steps=32,
            tensorboard_log="./board/",           
            policy_kwargs=policy_kwargs
            )

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [16]:
from stable_baselines3 import SAC
from stable_baselines3.common.callbacks import CheckpointCallback

# Save a checkpoint every 1000 steps
checkpoint_callback = CheckpointCallback(
  save_freq=100000,
  save_path="./logs/",
  name_prefix="rl_model",
  # save_replay_buffer=True,
  save_vecnormalize=True,
)

## Train Agent

In [17]:
import torch as th
from stable_baselines3 import DQN,PPO

env = MutliClassEnv(images_per_episode=1)

In [18]:
try :
    agent.learn(total_timesteps=int(5e6),
            log_interval=100,
            tb_log_name="ppo_mddr_31c_cnn+mha_5e6",
#             callback=checkpoint_callback
            )
except StopIteration:
    print("stopped")
    # agent.save("RL_chembl_700k_random_False")

Logging to ./board/ppo_mddr_31c_cnn+mha_5e6_1
-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 1             |
|    ep_rew_mean          | 2.53          |
| time/                   |               |
|    fps                  | 163           |
|    iterations           | 100           |
|    time_elapsed         | 19            |
|    total_timesteps      | 3200          |
| train/                  |               |
|    approx_kl            | 0.00023646094 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -19.3         |
|    explained_variance   | 0.000225      |
|    learning_rate        | 3e-05         |
|    loss                 | 9.4           |
|    n_updates            | 990           |
|    policy_gradient_loss | -0.0049       |
|    value_loss           | 18.9          |
-------------------------------------------
------------------------------

In [19]:
agent.save("./RL_PPO_mddr_(-n+2r0)*r1_cnn+mha_5e6tstep")

## Test Agent

# Load Model

In [None]:
attempts, correct,corr1 = 0, 0,0

env = MutliClassEnv(images_per_episode=1,random=False,mode="test")
i =1
maes = []
try:
    while True:
        obs, done = env.reset(), False
        while not done:
            obs, rew, done, info = env.step(agent.predict(obs[None],deterministic=False)[0])
            maes.append(info["mae"])
            attempts += 1
            if (info["rew1"] +info["rew0"]) == class_num:
                correct +=1
#          
            i +=1
        # env.render("rgb_array")
except StopIteration:
    print()
    print('validation done...')
    print('Accuracy: {0}'.format((float(correct) / attempts)))
    print("mean absolute error : ",np.mean(maes))

In [None]:
attempts, correct,corr1 = 0, 0,0

env = MutliClassEnv(images_per_episode=1, dataset=(x_test, y_test),random=False,mode="test")
i =1
maes = []
try:
    while True:
        obs, done = env.reset(), False
        while not done:
            obs, rew, done, info = env.step(agent.predict(obs[None],deterministic=False)[0])
            maes.append(info["mae"])
            attempts += 1
            if (info["rew1"] +info["rew0"]) == class_num:
                correct +=1
            i +=1
except StopIteration:
    print()
    print('test validation done...')
    print('Accuracy: {0}'.format((float(correct) / attempts)))
    print("mean absolute error : ",np.mean(maes))

In [23]:
def accuracy_vec_train(agent):
    acc = [0]*31
    c = [0]*31
    for ix,xt in enumerate(x_train):
        yp = agent.predict(xt)[0]

        yt = y_train[ix]
        for i,a in enumerate(acc):
   
            if yp[i]==yt[i]:
                    acc[i] += 1

    return  np.array(acc)/np.array([x_train.shape[0]]*31)
acct1= accuracy_vec_train(agent)
print("train class accuracy:", acct1)
print("mean : ",acct1.mean(),
     )

train class accuracy: [0.99808515 0.99310653 0.97102189 0.99617029 0.97951107 0.93298015
 0.99712772 0.97587285 0.89659794 0.9822557  0.99604264 0.99208527
 0.98448969 0.99814898 0.98602157 0.98544712 0.99055339 0.99795749
 0.99795749 0.98212804 0.99942554 0.99387247 0.98544712 0.99182996
 0.99814898 0.9310653  0.95602221 0.99297887 0.9930427  0.96942618
 0.93336312]
mean :  0.9799414013840473


In [24]:
def accuracy_vec_test(agent):
    acc = [0]*31
    c = [0]*31
    for ix,xt in enumerate(x_test):
        yp = agent.predict(xt)[0]

        yt = y_test[ix]
        for i,a in enumerate(acc):
    
            if yp[i]==yt[i]:
                    acc[i] += 1

    return  np.array(acc)/np.array([x_test.shape[0]]*31)
def accuracy1_vec_test(agent):
    acc = [0]*31
    c = [0]*31
    for ix,xt in enumerate(x_test):
        yp = agent.predict(xt)[0]

        yt = y_test[ix]
        for i,a in enumerate(acc):
            if yt[i] ==0:
                c[i] +=1
                if yp[i]==yt[i]:
                    acc[i] += 1


    return np.array(acc)/np.array(c)

acc1 =accuracy_vec_test(agent)
print("test class accuracy:",acc1 )
print("mean : ",acc1)

test class accuracy: [0.99642584 0.99131989 0.95608884 0.99514935 0.96527955 0.93311208
 0.9902987  0.95659944 0.89839163 0.97498085 0.98953281 0.990554
 0.9805974  0.99617054 0.98468215 0.98187388 0.98519275 0.99591524
 0.99642584 0.97676793 0.99897881 0.98697983 0.96783252 0.97855502
 0.99719173 0.91881542 0.94791933 0.98927751 0.98468215 0.95736533
 0.92264488]
mean :  [0.99642584 0.99131989 0.95608884 0.99514935 0.96527955 0.93311208
 0.9902987  0.95659944 0.89839163 0.97498085 0.98953281 0.990554
 0.9805974  0.99617054 0.98468215 0.98187388 0.98519275 0.99591524
 0.99642584 0.97676793 0.99897881 0.98697983 0.96783252 0.97855502
 0.99719173 0.91881542 0.94791933 0.98927751 0.98468215 0.95736533
 0.92264488]
