# Drug Response Prediction Model

The goal is to develop a model that leverages the zero-shot capabilities of CLIP to identify the best-responding match based on the LN IC50 for a given new drug or cancer cell line.

This model trains a CCL encoder and a SMILES encoder to map cancer cell lines and drug compounds into a common embedding space, where pairs with lower LN IC50 values are expected to be closer to each other, indicating a better drug response.

A continuously weighted contrastive loss based on LN IC50 values is used in the training process.

<br>

### File Requirements
1. The `CCL_SMILES_IC50.csv` file from STEP00_DataProcessor.ipynb.
2. The `ccl_lookup.pkl` and `smiles_lookup.pkl` files from STEP00_DataProcessor.ipynb.
3. The `ChEMBL_SMILES_2kk.csv` file from [ChEMBL](https://www.ebi.ac.uk/chembl/explore/compounds/)
4. The `SPE_ChEMBL_1500freq.txt` file from STEP01_SPE.ipynb.
5. The `SMILES_ENCODER_50k.pth` file from STEP01_SMILES_Encoder.ipynb.
6. The `CCL_TRANSFORMER.pth` file from STEP02_Transformer.ipynb.

Training on the entire data set is not recommended: it takes a considerable amount of time and does not change the result. Therefore, for debugging purposes, only the first 50,000 pairs are used for training.

In [None]:
!pip install timm

In [None]:
import pandas as pd
import torch
import torch.nn as nn
import itertools
import numpy as np
import torch.nn.functional as F
import torchtext; torchtext.disable_torchtext_deprecation_warning()
import math
import pickle
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
from tqdm import tqdm
from torchmetrics import MeanMetric

## Data Loading


In [None]:
# Table of (cell line, drug, response) rows; the columns read below are COSMIC_ID, DRUG_ID, index and LN_IC50
df = pd.read_csv("data/STEP00/CCL_SMILES_IC50.csv")

# NOTE: these columns are taken from the full table, i.e. before the 50,000-row truncation below
cosmic_id = df["COSMIC_ID"]
smiles_id = df["DRUG_ID"]
ic50_index = df["index"]

# NOTE(review): pickle.load can execute arbitrary code; only load lookup files from a trusted source
# COSMIC_ID -> RNA values of the cell line (read in DRPDataset.__getitem__)
with open("data/STEP00/ccl_lookup.pkl", "rb") as f:
    cosmic_dict = pickle.load(f)
    
# DRUG_ID -> SMILES string of the compound (read in DRPDataset.__getitem__)
with open("data/STEP00/smiles_lookup.pkl", "rb") as f:
    smiles_dict = pickle.load(f)

# Keep only the first 50,000 pairs to shorten training
df = df[:50000]

# (COSMIC_ID, DRUG_ID) -> LN_IC50 for the retained rows; read by generate_W
lookup_table = df.set_index(["COSMIC_ID", "DRUG_ID"])["LN_IC50"].to_dict()

## Config

In [None]:
class Config:
    """Hyperparameters and run-wide constants shared by the cells below."""

    # Use the GPU when one is available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    batch_size = 64  # samples per mini-batch
    weight_decay = 1e-3  # AdamW weight decay for the projection-head parameter group
    patience = 1  # ReduceLROnPlateau: epochs without improvement before the LR is reduced
    factor = 0.8  # ReduceLROnPlateau: multiplicative LR reduction factor
    epochs = 5  # number of training epochs
    temperature = 1.0  # temperature for scaling the contrastive logits
    lr = 1e-3  # learning rate of the projection heads
    smiles_encoder_lr = 1e-3  # learning rate of the SMILES encoder parameter group
    ccl_encoder_lr = 1e-3  # learning rate of the CCL encoder parameter group
    projection_dim = 256  # default output size of ProjectionHead
    dropout = 0.1  # default dropout probability of ProjectionHead
    
    # LN_IC50 range of the loaded table; generate_W uses it for min-max normalization
    IC50_min = df["LN_IC50"].min()
    IC50_max = df["LN_IC50"].max()

## Tokenization and Vocabulary

In [None]:
# Vocab training data
vocab_df = pd.read_csv("data/STEP01/ChEMBL_SMILES_2kk.csv")

# Load custom tokens from SMILES Pair Encoding (first whitespace-separated field of every line)
with open("data/STEP01/SPE_ChEMBL_1500freq.txt", "r") as f:
    custom_tokens = [line.strip().split()[0] for line in f]

# Select the first 50000 SMILES strings (first column) for vocab training.
# NOTE(review): this presumably has to match the subset used when the pre-trained SMILES encoder
# was trained, because the vocabulary size fixes the encoder's embedding shape - confirm.
# Alternative: Use a larger maximum vocabulary size and add more paddings
smiles_training = vocab_df.iloc[:50000, 0]


def custom_tokenizer(smiles_string):
    """
    Splits a SMILES string into tokens using the custom token list.

    Every custom token that occurs in the string is surrounded with spaces
    (tokens are processed in list order), then the result is split on whitespace.
    :param smiles_string: The SMILES string to tokenize.
    :return: list: A list of tokens.
    """
    spaced = smiles_string
    for piece in custom_tokens:
        if piece not in spaced:
            continue
        spaced = spaced.replace(piece, f' {piece} ')
    return spaced.split()


def yield_tokens(data_iter):
    """
    Lazily tokenizes every SMILES string of an iterable.
    :param data_iter: An iterable of SMILES strings.
    :return: A generator yielding one token list per SMILES string.
    """
    yield from map(custom_tokenizer, data_iter)


# Build vocabulary with special tokens
# It is also possible to save and load the vocabulary
vocab = build_vocab_from_iterator(yield_tokens(smiles_training), specials=["<pad>", "<unk>", "<sos>", "<eos>"])
# Tokens that are not in the vocabulary are mapped to <unk>
vocab.set_default_index(vocab["<unk>"])


def encode_smiles_to_indices(smiles):
    """
    Converts a SMILES string into vocabulary indices framed by <sos> and <eos>.
    :param smiles: The SMILES string to encode.
    :return: A list of integer indices.
    """
    indices = [vocab["<sos>"]]
    indices.extend(vocab[token] for token in custom_tokenizer(smiles))
    indices.append(vocab["<eos>"])
    return indices


# Length of the longest encoded sequence (including <sos>/<eos>) in the vocab training set;
# used as the common length all sequences are padded to
max_len = max(len(encode_smiles_to_indices(smile)) for smile in smiles_training)


def pad_sequence(seq, max_len):
    """
    Right-pads a sequence with the <pad> index up to the requested length.

    Sequences that already have max_len or more elements get no padding and are not truncated.
    :param seq: The sequence to pad.
    :param max_len: The maximum length.
    :return: The padded sequence.
    """
    missing = max_len - len(seq)
    padding = [vocab["<pad>"]] * missing
    return seq + padding


def smiles_to_padded_tensor(smiles):
    """
    Turns a SMILES string into a padded index tensor with a leading dimension of size 1.
    :param smiles: The SMILES string to encode.
    :return: The encoded SMILES string as a PyTorch tensor.
    """
    padded = pad_sequence(encode_smiles_to_indices(smiles), max_len)
    as_tensor = torch.tensor(padded)
    return as_tensor.unsqueeze(0)

## CCL Transformer Encoder

In [None]:
class MultiHeadAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values with one linear layer, attention is
    computed independently per head and per sample, the heads are concatenated again and
    passed through an output projection followed by dropout.

    NOTE(review): the head split/merge is done with explicit permutes so that every output
    row only depends on the same sample of the batch. Checkpoints that were trained with a
    different merge layout load without error (parameter names and shapes are unchanged) but
    should be re-validated.

    :param embed_dim: Size of the input and output feature dimension.
    :param num_heads: Number of attention heads; must divide embed_dim.
    :param dropout: Dropout probability applied to the projected attention output.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Ensure that the embedding dimension is divisible by the number of heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # Joint projection of the input to query, key and value (3 * embed_dim features)
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)

        # Projection of the concatenated heads back to the embedding dimension
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        # Dropout applied to the final output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Applies self-attention to x.
        :param x: Tensor of shape (batch_size, seq_length, embed_dim).
        :return: Tensor of shape (batch_size, seq_length, embed_dim).
        """
        batch_size, seq_length, embed_dim = x.size()

        # (batch, seq, 3 * embed) -> (batch, seq, heads, 3 * head_dim): each head owns a
        # contiguous slice that holds its query, key and value features in that order
        qkv = self.qkv_proj(x)
        qkv = qkv.view(batch_size, seq_length, self.num_heads, 3 * self.head_dim)

        # (batch, heads, seq, 3 * head_dim), then split into q, k, v of head_dim features each
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)

        # Scaled dot-product attention: (batch, heads, seq, seq)
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / self.head_dim ** 0.5
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Weighted sum of the values: (batch, heads, seq, head_dim)
        attn_output = torch.matmul(attn_weights, v)

        # Bring the head axis next to the feature axis BEFORE flattening; flattening without
        # this permute would interleave features of different samples and positions
        attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_length, embed_dim)

        # Output projection and dropout
        attn_output = self.o_proj(attn_output)
        attn_output = self.dropout(attn_output)

        return attn_output


class TransformerEncoderBlock(nn.Module):
    """
    One post-norm transformer encoder block: self-attention and a two-layer feed-forward
    network, each wrapped in a residual connection followed by layer normalization.

    :param embed_dim: Feature size of the block input and output.
    :param num_heads: Number of attention heads.
    :param ff_hidden_dim: Hidden size of the feed-forward network.
    :param dropout: Dropout probability used in the attention and feed-forward parts.
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()

        # Self-attention sub-layer and its normalization
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)

        # Position-wise feed-forward sub-layer and its normalization
        feed_forward_layers = [
            nn.Linear(embed_dim, ff_hidden_dim),
            nn.ReLU(),
            nn.Linear(ff_hidden_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*feed_forward_layers)
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Residual around the attention, then normalize
        attended = self.norm1(x + self.attention(x))

        # Residual around the feed-forward network, then normalize
        return self.norm2(attended + self.ff(attended))


class TransformerEncoder(nn.Module):
    """
    Transformer encoder that treats every input vector as a sequence of length one.

    forward() maps the input to the embedding space, runs the encoder blocks and projects
    back to the input dimension; extract_embeddings() stops before that output projection.

    :param input_dim: Size of the input vectors.
    :param embed_dim: Size of the internal embedding.
    :param num_heads: Number of attention heads per block.
    :param ff_hidden_dim: Hidden size of the feed-forward network in every block.
    :param num_layers: Number of encoder blocks.
    :param dropout: Dropout probability passed to every block.
    """

    def __init__(self, input_dim, embed_dim, num_heads, ff_hidden_dim, num_layers, dropout=0.1):
        super().__init__()

        # Input vectors -> embedding space
        self.embedding = nn.Linear(input_dim, embed_dim)

        # Encoder blocks, applied one after another
        blocks = [
            TransformerEncoderBlock(embed_dim, num_heads, ff_hidden_dim, dropout)
            for _ in range(num_layers)
        ]
        self.layers = nn.ModuleList(blocks)

        # Embedding space -> input dimension
        self.fc_out = nn.Linear(embed_dim, input_dim)

    def _encode(self, x):
        # Shared part of forward() and extract_embeddings(); keeps the length-one sequence axis
        hidden = self.embedding(x.unsqueeze(1))
        for block in self.layers:
            hidden = block(hidden)
        return hidden

    def forward(self, x):
        # Encode, project back to the input dimension, then drop the sequence axis
        projected = self.fc_out(self._encode(x))
        return projected.squeeze(1)

    # Function to extract embeddings
    def extract_embeddings(self, x):
        # Encode and drop the sequence axis without applying the output projection
        return self._encode(x).squeeze(1)

## SMILES Transformer Encoder

In [None]:
class TransformerModel(nn.Module):
    """
    Token-level transformer for SMILES sequences: embedding, a stack of
    nn.TransformerEncoderLayer copies and a linear layer producing vocabulary-sized outputs.

    NOTE(review): the encoder layer is created with PyTorch's default batch_first setting,
    while DRPDataset produces (batch, seq) token tensors - confirm the intended axis order.

    :param vocab_size: Number of entries in the vocabulary.
    :param embed_size: Size of the token embeddings.
    :param num_heads: Number of attention heads per encoder layer.
    :param num_layers: Number of stacked encoder layers.
    :param dropout: Dropout probability of the encoder layers.
    """

    def __init__(self, vocab_size, embed_size, num_heads, num_layers, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)

        # The template layer stays registered under `pos_encoder`, so its parameters remain
        # part of the state dict in addition to the stacked copies
        encoder_layer = nn.TransformerEncoderLayer(
            embed_size, num_heads, dim_feedforward=512, dropout=dropout
        )
        self.pos_encoder = encoder_layer
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.fc_out = nn.Linear(embed_size, vocab_size)
        self.embed_size = embed_size

    def forward(self, src, src_mask=None):
        # Scale the embeddings by sqrt(embed_size) before encoding
        scaled = self.embedding(src) * math.sqrt(self.embed_size)
        encoded = self.transformer_encoder(scaled, src_mask)
        return self.fc_out(encoded)

## Load the pre-trained models

In [None]:
# Load the SMILES Transformer Encoder
# map_location="cpu" lets checkpoints that were saved on a GPU load on CPU-only machines;
# the encoders are moved to the target device together with the model that wraps them.
SMILES_encoder = TransformerModel(len(vocab), 512, 8, 3, 0.1)
SMILES_encoder.load_state_dict(torch.load("models/SMILES_ENCODER_50k.pth", map_location="cpu"))
SMILES_encoder.eval()

# Load the CCL Transformer Encoder
CCL_encoder = TransformerEncoder(input_dim=17737, embed_dim=512, num_heads=8, ff_hidden_dim=1024, num_layers=4, dropout=0.5)
CCL_encoder.load_state_dict(torch.load("models/CCL_TRANSFORMER.pth", map_location="cpu"))
CCL_encoder.eval()

## Create the dataset

In [None]:
class DRPDataset(torch.utils.data.Dataset):
    """
    Dataset of (cell line, drug) pairs.

    Each item is a dict with the keys "cosmic_id", "smiles_id", "smiles_tokens"
    (padded token indices of the drug's SMILES string), "rna_values" (float32 tensor of the
    cell line's RNA values) and "ic50_index".

    :param cosmic_id: Iterable of cell line IDs, one per pair.
    :param smiles_id: Iterable of drug IDs, one per pair.
    :param ic50_index: Iterable of IC50 row indices, one per pair.
    :param vocab: The SMILES vocabulary (stored on the instance).
    """

    def __init__(self, cosmic_id, smiles_id, ic50_index, vocab):
        self.cosmic_id = list(cosmic_id)
        self.smiles_id = list(smiles_id)
        self.ic50_index = list(ic50_index)
        self.vocab = vocab

    def tokenize_smiles(self, smiles):
        # Tokenize, encode, and pad the SMILES string
        sequence = encode_smiles_to_indices(smiles)
        sequence = pad_sequence(sequence, max_len)
        return torch.tensor(sequence)

    def __getitem__(self, idx):
        cosmic = self.cosmic_id[idx]
        drug = self.smiles_id[idx]
        item = {"cosmic_id": cosmic, "smiles_id": drug}

        # Look up the SMILES string of the drug and convert it to padded token indices
        item["smiles_tokens"] = self.tokenize_smiles(smiles_dict[drug])

        # Look up the RNA values of the cell line; they may be stored as a "[v1, v2, ...]" string
        rna_values = cosmic_dict[cosmic]
        if isinstance(rna_values, str):
            rna_values = [float(x) for x in rna_values.strip('[]').split(',')]
        item['rna_values'] = torch.tensor(rna_values, dtype=torch.float32)

        # Use the index stored on this dataset: idx is a position within this dataset
        # and does not address the module-level IC50 index column
        item["ic50_index"] = torch.tensor(self.ic50_index[idx])

        return item

    def __len__(self):
        return len(self.smiles_id)

## ProjectionHead, DRPModel, Weight Matrix and CWCL Loss

In [None]:
class ProjectionHead(nn.Module):
    """
    Maps encoder features into the shared embedding space.

    A linear projection is followed by a GELU -> linear -> dropout branch that is added back
    to the projection (residual connection) and layer-normalized.

    :param embedding_dim: Size of the incoming feature vectors.
    :param projection_dim: Size of the produced embeddings.
    :param dropout: Dropout probability of the residual branch.
    """

    def __init__(self, embedding_dim, projection_dim=Config.projection_dim, dropout=Config.dropout):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        # Residual branch on top of the plain projection
        refined = self.dropout(self.fc(self.gelu(projected)))
        return self.layer_norm(refined + projected)


class DRPModel(nn.Module):
    """
    Drug response model: projects cell line and drug features into a common embedding space
    and returns the symmetric continuously weighted contrastive loss of a batch.

    The pre-trained encoders are taken from the module-level `CCL_encoder` and
    `SMILES_encoder` objects.

    :param temperature: Temperature the pairwise logits are divided by in the loss.
    """

    def __init__(self, temperature=1.0):
        super().__init__()
        self.CCL_encoder = CCL_encoder
        self.SMILES_encoder = SMILES_encoder

        # Derive the projection input sizes from the encoders instead of hard-coding them:
        # extract_embeddings() returns embed_dim features, the SMILES encoder outputs
        # vocabulary-sized vectors (the output size of its fc_out layer)
        self.RNA_projection = ProjectionHead(embedding_dim=self.CCL_encoder.embedding.out_features)
        self.SMILES_projection = ProjectionHead(embedding_dim=self.SMILES_encoder.fc_out.out_features)

        self.temperature = temperature

    def forward(self, batch):
        """
        Computes the loss of one batch.
        :param batch: Dict with the keys "rna_values", "smiles_tokens", "smiles_id" and "cosmic_id".
        :return: Scalar loss tensor.
        """
        # Step 1: Get feature representations from the encoders
        CCL_features = self.CCL_encoder.extract_embeddings(batch["rna_values"].float())
        SMILES_features = self.SMILES_encoder(batch["smiles_tokens"])

        # Average over dim 1 (the token axis of the (batch, seq) input) to get one vector per drug
        SMILES_features = SMILES_features.mean(dim=1)

        # Step 2: Get projection representations from the ProjectionHead
        CCL_embeddings = self.RNA_projection(CCL_features)  # Shape: [batch_size, projection_dim]
        SMILES_embeddings = self.SMILES_projection(SMILES_features)  # Shape: [batch_size, projection_dim]

        # Step 3: Create the weight matrix W; W[i, j] belongs to drug i and cell line j
        W = generate_W(batch).to(CCL_embeddings.device)

        # Step 4: Compute the symmetric CWCL loss. The logits of the first term are indexed
        # [cell line, drug], hence the transposed weights; the second term uses [drug, cell line].
        loss_1 = cwcl_loss(CCL_embeddings, SMILES_embeddings, W.T, temperature=self.temperature)
        loss_2 = cwcl_loss(SMILES_embeddings, CCL_embeddings, W, temperature=self.temperature)
        loss = (loss_1 + loss_2) / 2

        return loss
    
    
# Generate the weight matrix W of shape [batch_size, batch_size] from the drug and cell line IDs
def generate_W(batch):
    """
    Builds the pairwise weight matrix of a batch.

    Entry [i, j] is derived from the LN IC50 of drug i (batch["smiles_id"][i]) and cell line j
    (batch["cosmic_id"][j]) as stored in `lookup_table`. Values are reverse min-max normalized
    with the global IC50 range from Config, so that the lowest LN IC50 maps to 1 and the
    highest to 0. Pairs without a measured value are NaN.

    :param batch: Dict with tensors under "smiles_id" and "cosmic_id".
    :return: float32 tensor of shape [batch_size, batch_size] on the CPU.
    """
    # Plain Python IDs, usable as keys of the (COSMIC_ID, DRUG_ID) lookup table
    smiles_ids = batch['smiles_id'].cpu().tolist()
    cosmic_ids = batch['cosmic_id'].cpu().tolist()
    batch_size = len(smiles_ids)

    missing = float("nan")
    # Rows follow the drugs, columns follow the cell lines
    ic50_rows = [
        [lookup_table.get((cosmic, smiles), missing) for cosmic in cosmic_ids]
        for smiles in smiles_ids
    ]
    # reshape keeps the result two-dimensional even for an empty batch
    ic50 = torch.tensor(ic50_rows, dtype=torch.float32).reshape(batch_size, batch_size)

    # Reverse min-max normalization, done entirely in torch so the result type and dtype
    # do not depend on mixed torch/numpy operator dispatch
    ic50_min = float(Config.IC50_min)
    ic50_max = float(Config.IC50_max)
    if ic50_max > ic50_min:
        return (ic50_max - ic50) / (ic50_max - ic50_min)

    # Degenerate range (all values equal): return the raw values unchanged
    return ic50


# Computes the Continuously Weighted Contrastive Loss (CWCL), ignoring NaN values in the weights
def cwcl_loss(p, q, weights, temperature=1.0):
    """
    Weighted cross-entropy between the row-wise softmax of the pairwise logits and the weights.

    :param p: Embeddings of the first modality, one row per sample.
    :param q: Embeddings of the second modality, one row per sample.
    :param weights: Pairwise weights aligned with the logits p @ q.T; NaN marks missing pairs.
    :param temperature: Value the logits are divided by.
    :return: Mean loss over all rows that contain at least one non-NaN weight.
    """
    # True where a weight is available
    valid = torch.isnan(weights).logical_not()

    # Missing weights contribute nothing to the weighted sums
    zero = torch.zeros((), device=weights.device)
    clean_weights = torch.where(valid, weights, zero)

    # Row-wise log-probabilities of the temperature-scaled dot products
    log_probs = torch.log_softmax(torch.matmul(p, q.T) / temperature, dim=1)

    # Weighted average of the log-probabilities per row; the epsilon avoids division by zero
    weighted_sum = (clean_weights * log_probs).sum(dim=1)
    normalizer = clean_weights.sum(dim=1) + 1e-8
    per_row_loss = -(weighted_sum / normalizer)

    # Average only over rows that have at least one valid weight
    rows_with_data = valid.any(dim=1)
    return per_row_loss[rows_with_data].mean()

## Split the data into training and validation sets

In [None]:
# Splits the input DataFrame into train and validation sets.
def split_dataset(df, debug=False, test_size=0.2, random_state=42, verbose=False):
    """
    Randomly splits a DataFrame into a training and a validation part.

    :param df: The DataFrame to split.
    :param debug: If True, only the first 10000 rows are used.
    :param test_size: Share of the rows that goes into the validation part.
    :param random_state: Seed of the random split.
    :param verbose: If True, the sizes of both parts are printed.
    :return: Tuple (train_df, valid_df), both with a fresh 0..n-1 index.
    """
    row_limit = len(df)
    if debug:
        row_limit = 10000
    subset = df.iloc[:row_limit]

    parts = train_test_split(subset, test_size=test_size, random_state=random_state)
    train_part, valid_part = (part.reset_index(drop=True) for part in parts)

    if verbose:
        print(f"Train size: {len(train_part)}, Valid size: {len(valid_part)}")

    return train_part, valid_part

# Builds a DataLoader for the given dataset.
def build_loaders(df, vocab, batch_size, num_workers=0, mode="train", shuffle=None):
    """
    Wraps the pairs of a DataFrame in a DRPDataset and returns a DataLoader for it.

    :param df: DataFrame with the columns "COSMIC_ID", "DRUG_ID" and "index".
    :param vocab: The SMILES vocabulary passed on to the dataset.
    :param batch_size: Number of samples per batch.
    :param num_workers: Number of DataLoader worker processes.
    :param mode: "train" enables shuffling by default, any other value disables it.
    :param shuffle: Explicit shuffle setting; overrides the default derived from mode.
    :return: The configured DataLoader (pinned memory is always requested).
    """
    should_shuffle = (mode == "train") if shuffle is None else shuffle

    dataset = DRPDataset(
        cosmic_id=df["COSMIC_ID"],
        smiles_id=df["DRUG_ID"],
        ic50_index=df["index"],
        vocab=vocab,
    )

    loader_options = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=should_shuffle,
        pin_memory=True,
    )
    return DataLoader(dataset, **loader_options)

## Training

In [None]:
# Runs a single epoch for training or validation and returns the average loss over the epoch
def run_epoch(model, data_loader, optimizer=None, lr_scheduler=None, step=None, training=True):
    """
    Iterates once over the data loader and tracks the mean batch loss.

    :param model: Model that returns the loss when called with a batch dict.
    :param data_loader: Loader that yields dicts of tensors.
    :param optimizer: Optimizer used for the parameter updates (training only).
    :param lr_scheduler: Optional learning rate scheduler (training only).
    :param step: "batch" steps the scheduler after every batch, "epoch" steps it once per
                 epoch with the mean loss.
    :param training: If True, gradients are computed and the parameters are updated.
    :return: The mean loss of the epoch as computed by the MeanMetric.
    """
    label = "Valid"
    if training:
        label = "Train"

    meter = MeanMetric().to(Config.device)
    progress = tqdm(data_loader, total=len(data_loader), desc=f"{label} Epoch", leave=True)

    has_scheduler = lr_scheduler is not None
    step_per_batch = training and step == "batch" and has_scheduler
    step_per_epoch = training and step == "epoch" and has_scheduler

    for raw_batch in progress:
        batch = {key: value.to(Config.device) for key, value in raw_batch.items()}
        loss = model(batch)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step_per_batch:
                lr_scheduler.step()

        meter.update(loss.item())
        progress.set_postfix(loss=meter.compute())

    if step_per_epoch:
        lr_scheduler.step(meter.compute())

    return meter.compute()


def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Runs one training epoch via run_epoch and returns its mean loss."""
    return run_epoch(
        model,
        train_loader,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        step=step,
        training=True,
    )


def valid_epoch(model, valid_loader):
    """Runs one validation epoch via run_epoch (no parameter updates) and returns its mean loss."""
    return run_epoch(model, data_loader=valid_loader, training=False)


def main(save_model_path="models/DRP.pth", verbose=False):
    """
    Trains the drug response model and saves the weights with the best validation loss.

    :param save_model_path: File the best model state dict is written to.
    :param verbose: If True, split sizes and per-epoch losses are printed.
    """
    # Prepare datasets and loaders; batch size comes from Config instead of a literal
    train_df, valid_df = split_dataset(df, debug=False, verbose=verbose)
    train_loader = build_loaders(train_df, vocab, batch_size=Config.batch_size, num_workers=0, mode="train")
    valid_loader = build_loaders(valid_df, vocab, batch_size=Config.batch_size, num_workers=0, mode="valid")

    # Initialize model, optimizer, and scheduler
    model = DRPModel(temperature=Config.temperature).to(Config.device)
    # Separate learning rates per component; only the projection heads use weight decay
    params = [
        {"params": model.CCL_encoder.parameters(), "lr": Config.ccl_encoder_lr},
        {"params": model.SMILES_encoder.parameters(), "lr": Config.smiles_encoder_lr},
        {"params": itertools.chain(
            model.RNA_projection.parameters(), model.SMILES_projection.parameters()
        ), "lr": Config.lr, "weight_decay": Config.weight_decay}
    ]
    optimizer = torch.optim.AdamW(params, weight_decay=0.)
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=Config.patience, factor=Config.factor
    )

    # Training loop
    best_loss = float("inf")
    for epoch in range(1, Config.epochs + 1):
        if verbose:
            print(f"Epoch {epoch}/{Config.epochs}")

        # Train and validate; losses are converted to plain floats so that comparisons and
        # the printed values do not carry tensor/device information
        model.train()
        train_loss = float(train_epoch(model, train_loader, optimizer, lr_scheduler, step="epoch"))

        model.eval()
        with torch.no_grad():
            valid_loss = float(valid_epoch(model, valid_loader))

        if verbose:
            print(f"Train Loss: {train_loss:.4f} | Valid Loss: {valid_loss:.4f}")

        # Save the best model
        if valid_loss < best_loss:
            best_loss = valid_loss
            torch.save(model.state_dict(), save_model_path)
            if verbose:
                print(f"Saved Best Model to {save_model_path}!")

    print("Training complete. Best validation loss:", best_loss)

In [None]:
# Run the training with the default checkpoint path and without verbose output
main()