# Cancer Cell Line Feature Extraction with a Transformer Encoder

In this Jupyter notebook, a Transformer Encoder is used to obtain cancer cell line feature representations.

<br>

### File Requirements

The following file is required in the `data/STEP00` folder:

1. [Cell_line_RMA_proc_basalExp.txt](https://www.cancerrxgene.org/gdsc1000/GDSC1000_WebResources/Home.html)

<br>

### Output
The trained model is saved as `models/CCL_TRANSFORMER.pth`.

<br>

### Evaluation
Visualization of the learning curve and output of performance metrics.


In [None]:
import pandas as pd
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error, mean_squared_error
from scipy.stats import pearsonr

In [None]:
# Seed NumPy and PyTorch with one shared, fixed value so repeated runs draw
# the same random numbers.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Read the tab-separated expression table.
ccl_file = pd.read_csv("data/STEP00/Cell_line_RMA_proc_basalExp.txt", sep="\t")

# Keep every row but skip the first two columns (presumably gene identifier /
# annotation columns - verify against the file), then transpose so that each
# remaining file column becomes one row.
ccl_df = ccl_file.iloc[:, 2:].transpose()

# Remove exact duplicate rows and hand the values over as nested Python lists.
rna_values = ccl_df.drop_duplicates().to_numpy().tolist()

# Transformer Encoder
PyTorch's built-in Transformer modules could be used instead, but this from-scratch implementation yielded better results, most likely due to the custom `MultiHeadAttention` module.

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: Size of the last dimension of the input and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim`` evenly.
        dropout: Dropout probability applied to the projected attention output.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Ensure that the embedding dimension is divisible by the number of heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        # One fused projection producing query, key and value (3 * embed_dim)
        self.qkv_proj = nn.Linear(embed_dim, embed_dim * 3)

        # Projection applied to the concatenated head outputs
        self.o_proj = nn.Linear(embed_dim, embed_dim)

        # Dropout applied to the projected output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape (batch_size, seq_length, embed_dim).

        Returns:
            Tensor of shape (batch_size, seq_length, embed_dim). Each sample's
            output depends only on that sample's own sequence.
        """
        batch_size, seq_length, embed_dim = x.size()

        # (B, S, 3E) -> (B, S, H, 3 * hd): every head owns a contiguous [q | k | v] slice
        qkv = self.qkv_proj(x)
        qkv = qkv.view(batch_size, seq_length, self.num_heads, 3 * self.head_dim)

        # Put the head axis next to the batch axis: (B, H, S, 3 * hd)
        qkv = qkv.permute(0, 2, 1, 3)

        # Split into query, key and value, each (B, H, S, hd)
        q, k, v = qkv.chunk(3, dim=-1)

        # Scaled dot-product attention scores, (B, H, S, S)
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / self.head_dim**0.5
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Weighted sum of the values, (B, H, S, hd)
        attn_output = torch.matmul(attn_weights, v)

        # Bring the batch and sequence axes back to the front BEFORE flattening
        # the heads. Flattening a head-major tensor directly into
        # (B, S, E) would interleave features that belong to different samples.
        attn_output = attn_output.permute(0, 2, 1, 3).reshape(batch_size, seq_length, embed_dim)

        # Project the concatenated heads and apply dropout
        attn_output = self.o_proj(attn_output)
        attn_output = self.dropout(attn_output)

        return attn_output


class TransformerEncoderBlock(nn.Module):
    """One encoder block: self-attention, then a feed-forward network.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalization.

    Args:
        embed_dim: Size of the last dimension of the input and output.
        num_heads: Number of attention heads passed to the attention layer.
        ff_hidden_dim: Width of the hidden layer in the feed-forward network.
        dropout: Dropout probability used in both sub-layers.
    """

    def __init__(self, embed_dim, num_heads, ff_hidden_dim, dropout=0.1):
        super().__init__()

        # Sub-layer 1: multi-head self-attention and its normalization
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)

        # Sub-layer 2: expand -> ReLU -> contract -> dropout, and its normalization
        expand = nn.Linear(embed_dim, ff_hidden_dim)
        contract = nn.Linear(ff_hidden_dim, embed_dim)
        self.ff = nn.Sequential(expand, nn.ReLU(), contract, nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Return the block output; it has the same shape as ``x``."""
        # Residual around the attention layer, then normalize
        attended = self.norm1(x + self.attention(x))

        # Residual around the feed-forward network, then normalize
        return self.norm2(attended + self.ff(attended))


class TransformerEncoder(nn.Module):
    """Encoder that maps an input vector to an output vector of the same size.

    The input is embedded with a linear layer, passed through a stack of
    encoder blocks as a length-1 sequence, and projected back to
    ``input_dim``.

    Args:
        input_dim: Number of features per input sample.
        embed_dim: Width of the embedding used inside the encoder blocks.
        num_heads: Number of attention heads per block.
        ff_hidden_dim: Hidden width of each block's feed-forward network.
        num_layers: Number of encoder blocks to stack.
        dropout: Dropout probability forwarded to every block.
    """

    def __init__(self, input_dim, embed_dim, num_heads, ff_hidden_dim, num_layers, dropout=0.1):
        super().__init__()

        # Input features -> embedding
        self.embedding = nn.Linear(input_dim, embed_dim)

        # Encoder stack, built one block at a time
        blocks = []
        for _ in range(num_layers):
            blocks.append(TransformerEncoderBlock(embed_dim, num_heads, ff_hidden_dim, dropout))
        self.layers = nn.ModuleList(blocks)

        # Embedding -> input features
        self.fc_out = nn.Linear(embed_dim, input_dim)

    def forward(self, x):
        """Run ``x`` through the embedding, every block, and the output layer."""
        # Insert a sequence axis of length 1 at position 1, then embed
        hidden = self.embedding(x.unsqueeze(1))

        for block in self.layers:
            hidden = block(hidden)

        # Project back to the input size and drop the sequence axis again
        return self.fc_out(hidden).squeeze(1)

# Training

In [None]:
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create training and validation splits
data = torch.tensor(rna_values, dtype=torch.float32)
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
train_dataset = TensorDataset(train_data)
val_dataset = TensorDataset(val_data)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

# Hyperparameters
# Number of expression values per sample, taken from the data itself
# (previously hard-coded as 17737) so model and data cannot disagree.
input_dim = data.shape[1]
embed_dim = 512
num_heads = 8
ff_hidden_dim = 1024
num_layers = 4
dropout = 0.5
epochs = 300

# Early stopping parameters
patience = 20  # Number of epochs to wait for improvement
min_delta = 1e-4  # Minimum change to qualify as an improvement
best_val_loss = float("inf")
patience_counter = 0

train_losses = []
val_losses = []

# Checkpoint location; create the folder up front so the first save cannot fail
checkpoint_path = "models/CCL_TRANSFORMER.pth"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Model definition
model = TransformerEncoder(input_dim, embed_dim, num_heads, ff_hidden_dim, num_layers, dropout).to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.0008752047398730367)
criterion = nn.MSELoss()


# Training loop: the model is trained to reproduce its own input
for epoch in range(epochs):
    model.train()
    epoch_train_loss = 0

    for batch in train_loader:
        x_batch = batch[0].to(device)

        # Forward pass
        outputs = model(x_batch)
        loss = criterion(outputs, x_batch)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_train_loss += loss.item()

    epoch_train_loss /= len(train_loader)
    train_losses.append(epoch_train_loss)

    # Validation loss
    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            x_batch = batch[0].to(device)
            outputs = model(x_batch)
            loss = criterion(outputs, x_batch)
            epoch_val_loss += loss.item()
    epoch_val_loss /= len(val_loader)
    val_losses.append(epoch_val_loss)

    print(f"Epoch {epoch + 1}, Training Loss: {epoch_train_loss:.4f}, Validation Loss: {epoch_val_loss:.4f}")

    # Early stopping
    if epoch_val_loss < best_val_loss - min_delta:
        best_val_loss = epoch_val_loss
        patience_counter = 0
        torch.save(model.state_dict(), checkpoint_path)  # Save the best model
    else:
        patience_counter += 1

    if patience_counter >= patience:
        print(f"Early stopping triggered at epoch {epoch + 1}")
        break

# After the loop `model` holds the weights of the LAST epoch, which are up to
# `patience` epochs past the best one. Restore the best checkpoint so that any
# later evaluation or saving works with the best weights, not the final ones.
# (best_val_loss stays infinite only if no checkpoint was ever written.)
if best_val_loss < float("inf"):
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

print("Training complete")

In [None]:
# Write the weights currently held by `model` to the output file. This is the
# same path the training loop uses for its best-validation checkpoint, so that
# file is replaced by whatever `model` contains when this cell runs.
current_weights = model.state_dict()
torch.save(current_weights, "models/CCL_TRANSFORMER.pth")

## Evaluation: Learning curve and performance metrics

In [None]:
# Global plot styling ("classic" is an alternative style sheet)
plt.style.use("seaborn-v0_8-ticks")
plt.rc("font", family="Times New Roman", size=12)

fig, ax = plt.subplots(figsize=(8, 6))
ax.tick_params(axis="both", which="both", direction="in", length=6, width=1)

# One curve per recorded loss history, indexed by epoch
for loss_history, curve_label in (
    (train_losses, "Training Loss"),
    (val_losses, "Validation Loss"),
):
    ax.plot(loss_history, label=curve_label, linewidth=1.5)

# Axis labels, fixed y-range and title in a single call
ax.set(
    xlabel="Training Epochs",
    ylabel="Mean Squared Error Loss",
    ylim=(0.15, 0.8),
    title="TransformerEncoder",
)
ax.legend()

# Save the figure as a PDF, then display it
fig.savefig("transformer_loss.pdf", format="pdf", bbox_inches="tight")
plt.show()

In [None]:
def calculate_mse(model, data_loader, device):
    """Return the mean squared reconstruction error over every element in the loader.

    The squared-error sum and the element count are accumulated across
    batches, so a smaller final batch is weighted by its actual size instead
    of counting as much as a full batch.

    Args:
        model: Module called as ``model(x)``; its output must have the shape of ``x``.
        data_loader: Iterable of batches whose first item is the input tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The MSE as a Python float, or NaN if the loader yields no elements.
    """
    model.eval()
    squared_error_sum = 0.0
    element_count = 0

    with torch.no_grad():
        for batch in data_loader:
            x_batch = batch[0].to(device)
            reconstructed = model(x_batch)
            # Accumulate in float64 on the CPU to limit rounding error
            residual = (reconstructed - x_batch).cpu().double()
            squared_error_sum += residual.pow(2).sum().item()
            element_count += x_batch.numel()

    if element_count == 0:
        return float("nan")
    return squared_error_sum / element_count


def calculate_mae(model, data_loader, device):
    """Return the mean absolute reconstruction error over every element in the loader.

    The absolute-error sum and the element count are accumulated across
    batches, so a smaller final batch is weighted by its actual size instead
    of counting as much as a full batch.

    Args:
        model: Module called as ``model(x)``; its output must have the shape of ``x``.
        data_loader: Iterable of batches whose first item is the input tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The MAE as a Python float, or NaN if the loader yields no elements.
    """
    model.eval()
    absolute_error_sum = 0.0
    element_count = 0

    with torch.no_grad():
        for batch in data_loader:
            x_batch = batch[0].to(device)
            reconstructed = model(x_batch)
            # Accumulate in float64 on the CPU to limit rounding error
            residual = (reconstructed - x_batch).cpu().double()
            absolute_error_sum += residual.abs().sum().item()
            element_count += x_batch.numel()

    if element_count == 0:
        return float("nan")
    return absolute_error_sum / element_count


def calculate_pcc(model, data_loader, device):
    """Return the mean per-column Pearson correlation between inputs and reconstructions.

    All batches are gathered first and the correlation of every column (gene)
    is computed across ALL samples in the loader. The result therefore does
    not depend on batch size or shuffling, and a final batch with a single
    sample is not a problem.

    Columns that are constant in either the inputs or the reconstructions
    have no defined correlation and are left out of the mean.

    Args:
        model: Module called as ``model(x)``; its output must have the shape of ``x``.
        data_loader: Iterable of batches whose first item is a 2-D input tensor
            (samples x columns).
        device: Device the batches are moved to before the forward pass.

    Returns:
        The mean correlation as a Python float, or NaN when fewer than two
        samples are available or no column has a defined correlation.
    """
    model.eval()
    original_batches = []
    reconstructed_batches = []

    with torch.no_grad():
        for batch in data_loader:
            x_batch = batch[0].to(device)
            reconstructed = model(x_batch)
            original_batches.append(x_batch.cpu().numpy())
            reconstructed_batches.append(reconstructed.cpu().numpy())

    if not original_batches:
        return float("nan")

    # float64 keeps the centred sums numerically stable
    x_original = np.concatenate(original_batches, axis=0).astype(np.float64)
    x_reconstructed = np.concatenate(reconstructed_batches, axis=0).astype(np.float64)
    if x_original.shape[0] < 2:
        return float("nan")

    # Detect constant columns by their value range; testing the centred values
    # for zero would be unreliable because of rounding in the mean.
    original_range = x_original.max(axis=0) - x_original.min(axis=0)
    reconstructed_range = x_reconstructed.max(axis=0) - x_reconstructed.min(axis=0)
    varying = (original_range > 0) & (reconstructed_range > 0)
    if not varying.any():
        return float("nan")

    # Vectorised Pearson correlation for all remaining columns at once
    x_centered = x_original[:, varying] - x_original[:, varying].mean(axis=0)
    y_centered = x_reconstructed[:, varying] - x_reconstructed[:, varying].mean(axis=0)
    covariance = (x_centered * y_centered).sum(axis=0)
    scale = np.sqrt((x_centered**2).sum(axis=0) * (y_centered**2).sum(axis=0))
    per_gene_r = np.clip(covariance / scale, -1.0, 1.0)

    return float(per_gene_r.mean())

In [None]:
# Compute each MSE once and reuse it for the RMSE. This avoids two extra full
# passes over the data and guarantees the printed RMSE is exactly the square
# root of the printed MSE.
train_mse = calculate_mse(model, train_loader, device)
val_mse = calculate_mse(model, val_loader, device)

print(f"Training MSE: {train_mse:.4f}")
print(f"Validation MSE: {val_mse:.4f}")
print(f"Training RMSE: {math.sqrt(train_mse):.4f}")
print(f"Validation RMSE: {math.sqrt(val_mse):.4f}")
print(f"Training MAE: {calculate_mae(model, train_loader, device):.4f}")
print(f"Validation MAE: {calculate_mae(model, val_loader, device):.4f}")
print(f"Training PCC: {calculate_pcc(model, train_loader, device):.4f}")
print(f"Validation PCC: {calculate_pcc(model, val_loader, device):.4f}")