In [14]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

class SingleCellPerturbationMultiModalModel(nn.Module):
    """Regress perturbation targets from SMILES token ids plus a cell-type vector.

    Pipeline (see ``forward``): frozen ChemBERTa embeddings -> cell-type vector
    concatenated at every token position -> 1D convolution block -> learnable
    [CLS] token prepended -> single-layer transformer encoder -> regression
    head applied to the [CLS] position.

    Args:
        output_dim: Number of regression outputs per sample.
        num_heads: Attention heads in the encoder layer.
        cell_type_dim: Width of the ``cell_type`` vector given to ``forward``.
        hidden_dim: Channel width shared by the CNN output, the [CLS] token and
            the encoder; the head's hidden layer is ``hidden_dim // 2`` wide.
        dropout: Dropout probability used in the CNN block, encoder and head.

    The defaults reproduce the original hard-coded architecture (6 cell-type
    features, 64 channels, dropout 0.8), and sub-module names are unchanged, so
    existing checkpoints keep loading.
    """

    def __init__(self, output_dim=5, num_heads=8, cell_type_dim=6, hidden_dim=64, dropout=0.8):
        super().__init__()

        # Load the pretrained ChemBERTa model and keep only its embedding module.
        model = AutoModelForCausalLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1", is_decoder=True)
        self.smiles_embedding = model.roberta.embeddings
        # The pretrained embeddings are frozen; only the layers below are trained.
        for param in self.smiles_embedding.parameters():
            param.requires_grad = False
        embedding_dim = self.smiles_embedding.word_embeddings.embedding_dim

        # 1D CNN over the token axis.
        self.conv_block = nn.Sequential(
            nn.Conv1d(
                in_channels=embedding_dim + cell_type_dim,  # token embedding + cell-type features
                out_channels=hidden_dim,
                kernel_size=9,
                padding=4  # kernel_size=9 with padding=4 keeps the sequence length
            ),
            nn.ReLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Dropout(dropout)
        )

        # Learnable [CLS] token, same width as the CNN output channels.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))

        # Transformer encoder (single layer, batch-first layout).
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=num_heads,
                dropout=dropout,
                batch_first=True
            ),
            num_layers=1
        )

        # Regression head applied to the [CLS] representation.
        self.head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Sigmoid(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, output_dim)
        )

    def forward(self, input_ids, attention_mask, cell_type):
        """Compute regression outputs for a batch.

        Args:
            input_ids: Token ids, indexed as [batch, seq_len].
            attention_mask: Same [batch, seq_len] layout; non-zero marks a real
                token, zero marks padding.
            cell_type: Per-sample feature vector, [batch, cell_type_dim].

        Returns:
            Tensor of shape [batch, output_dim].
        """
        # Token embeddings: [batch, seq_len, embedding_dim]
        smile_emb = self.smiles_embedding(input_ids)

        # Broadcast the cell-type vector to every token position. ``expand``
        # creates a view; the concatenation below materialises the copy.
        cell_type_emb = cell_type.unsqueeze(1).expand(-1, smile_emb.size(1), -1)

        # Combine features and apply the CNN.
        combined = torch.cat((smile_emb, cell_type_emb), dim=-1)
        combined = combined.permute(0, 2, 1)  # [batch, features, seq_len] for Conv1d
        conv_out = self.conv_block(combined).permute(0, 2, 1)  # back to [batch, seq_len, features]

        # Prepend the [CLS] token to every sequence.
        cls_tokens = self.cls_token.expand(conv_out.size(0), -1, -1)
        inputs = torch.cat([cls_tokens, conv_out], dim=1)

        # Extend the attention mask with an always-visible slot for [CLS].
        mask = torch.cat([torch.ones_like(attention_mask[:, :1]), attention_mask], dim=1)

        # The encoder was built with batch_first=True, so inputs stay
        # [batch, seq_len + 1, features]. src_key_padding_mask marks positions
        # to IGNORE, hence the inversion of the attention mask.
        out = self.encoder(
            inputs,
            src_key_padding_mask=~mask.bool()
        )

        # Use the [CLS] position for regression.
        return self.head(out[:, 0])

In [15]:
from sklearn.preprocessing import MinMaxScaler
import pandas as pd

def preprocess(df, svd, scaler=None, fit=True):
    """Min-max scale the gene-expression columns and reduce them with ``svd``.

    Args:
        df: DataFrame with ``cell_type`` and ``SMILES`` columns; every column
            from ``'A1BG'`` through the last column is treated as gene expression.
        svd: Decomposition object exposing ``fit_transform`` / ``transform``.
        scaler: Optional scaler exposing ``fit_transform`` / ``transform``.
            When omitted, a new ``MinMaxScaler(feature_range=(0, 1))`` is created.
        fit: When True (the default, matching the original behaviour) the scaler
            and ``svd`` are fitted on ``df``. Pass False together with an
            already-fitted ``scaler`` and ``svd`` to apply a transform learned
            on other data (e.g. the training split) without refitting it.

    Returns:
        DataFrame with columns ``cell_type``, ``SMILES`` and
        ``gene_expressions`` (one list of reduced components per row).

    Raises:
        ValueError: If ``fit`` is False and no fitted ``scaler`` is supplied.
    """
    if scaler is None:
        if not fit:
            raise ValueError("fit=False requires an already fitted `scaler`.")
        scaler = MinMaxScaler(feature_range=(0, 1))

    # Copy the gene expression columns to avoid SettingWithCopyWarning.
    gene_data = df.loc[:, 'A1BG':].copy()

    if fit:
        # Learn the per-column scaling and the SVD basis from this frame.
        gene_data_normalized = scaler.fit_transform(gene_data)
        gene_expression_reduced = svd.fit_transform(gene_data_normalized)
    else:
        # Reuse a previously learned scaling / basis so that the reduced
        # components are comparable with the data they were fitted on.
        gene_data_normalized = scaler.transform(gene_data)
        gene_expression_reduced = svd.transform(gene_data_normalized)

    # gene_expression_reduced has shape (num_samples, num_components).
    processed_df = pd.DataFrame({
        'cell_type': df['cell_type'].values,
        'SMILES': df['SMILES'].values,
        'gene_expressions': [list(row) for row in gene_expression_reduced]
    })

    return processed_df

In [16]:
from torch.utils.data import Dataset
from sklearn.preprocessing import OneHotEncoder
import torch
from transformers import AutoTokenizer
from rdkit import Chem
import random
import numpy as np

class SingleCellPerturbationDataset(Dataset):
    """Dataset yielding tokenized SMILES, one-hot cell types and standardized targets.

    Args:
        smiles_list: Indexable collection of SMILES strings.
        cell_types: Cell-type labels, one per sample (numpy array, pandas
            Series or plain sequence).
        targets: Array-like of shape (num_samples, num_targets).
        augment_prob: Probability of replacing a SMILES string with a randomly
            ordered equivalent each time a sample is fetched.
        cell_type_categories: Optional fixed list of cell-type labels that
            defines the one-hot column order. When omitted, the categories are
            inferred from ``cell_types`` (original behaviour).
        target_mean: Optional per-target mean used for standardization. When
            omitted it is computed from ``targets`` (original behaviour).
        target_std: Optional per-target standard deviation, handled like
            ``target_mean``.

    Attributes set here: ``cell_types`` (one-hot tensor), ``cell_type_categories``
    (labels in column order), ``target_mean``, ``target_std``, ``targets``
    (standardized tensor) and ``tokenizer``. Passing the training dataset's
    ``cell_type_categories`` / ``target_mean`` / ``target_std`` to a validation
    or test dataset keeps both in the same encoding and scale.
    """

    def __init__(self, smiles_list, cell_types, targets, augment_prob=0.8,
                 cell_type_categories=None, target_mean=None, target_std=None):
        self.smiles_list = smiles_list
        self.augment_prob = augment_prob

        # Convert targets to a torch tensor first.
        targets_tensor = torch.tensor(targets, dtype=torch.float32)

        # One-hot encode the cell types. np.asarray accepts numpy arrays,
        # pandas Series and plain Python sequences alike.
        cell_types_reshaped = np.asarray(cell_types).reshape(-1, 1)
        categories = "auto" if cell_type_categories is None else [list(cell_type_categories)]
        encoder = OneHotEncoder(sparse_output=False, categories=categories)
        self.cell_types = torch.tensor(
            encoder.fit_transform(cell_types_reshaped),
            dtype=torch.float32
        )
        # Exposed so another dataset can be built with the same column order.
        self.cell_type_categories = list(encoder.categories_[0])

        # Target standardization; statistics come from this dataset unless
        # the caller supplies them.
        if target_mean is None:
            self.target_mean = torch.mean(targets_tensor, dim=0)
        else:
            self.target_mean = torch.as_tensor(target_mean, dtype=torch.float32)
        if target_std is None:
            self.target_std = torch.std(targets_tensor, dim=0)
        else:
            self.target_std = torch.as_tensor(target_std, dtype=torch.float32)
        # The small constant guards against division by a zero std.
        self.targets = (targets_tensor - self.target_mean) / (self.target_std + 1e-8)

        # Tokenizer matching the ChemBERTa checkpoint.
        self.tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

    def augment_smiles(self, smiles):
        """Return a randomly ordered, non-canonical SMILES for the same molecule.

        Falls back to the input string when RDKit cannot parse it or fails
        while writing the randomized form (best-effort augmentation).
        """
        # The earlier Compute2DCoords call was dropped: coordinates are not part
        # of a SMILES string, so it had no effect on the output.
        # NOTE(review): `Chem.rdDepictor` is not guaranteed to be loaded by
        # `from rdkit import Chem`; if it was not, the old bare `except` would
        # have hidden the AttributeError and returned the input unchanged.
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return smiles
            return Chem.MolToSmiles(mol, doRandom=True, isomericSmiles=True, canonical=False)
        except Exception:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit still propagate.
            return smiles

    def __len__(self):
        """Required by DataLoader - returns number of samples"""
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Return one sample as a dict of tensors."""
        smiles = str(self.smiles_list[idx])
        if random.random() < self.augment_prob:
            smiles = self.augment_smiles(smiles)

        # Pad / truncate every sequence to 100 tokens so samples stack in a batch.
        encoded = self.tokenizer(smiles, truncation=True, return_tensors="pt", padding="max_length", max_length=100)
        return {
            "input_ids": encoded["input_ids"].squeeze(0),  # [seq_len]
            "attention_mask": encoded["attention_mask"].squeeze(0),  # [seq_len]
            "cell_type": self.cell_types[idx],  # [num_cell_type_categories]
            "target": self.targets[idx]  # [num_targets]
        }

In [17]:
def mean_rowwise_rmse(preds: np.ndarray, targets: np.ndarray) -> float:
    """Average of the per-row root-mean-squared errors.

    The squared error is averaged along axis 1 for each row, the square root
    is taken row by row, and the resulting values are averaged over all rows.
    """
    residuals = preds - targets
    per_row_rmse = np.sqrt(np.square(residuals).mean(axis=1))
    return per_row_rmse.mean()

class MeanRowwiseRMSELoss(nn.Module):
    """Loss module: mean over rows of each row's root-mean-squared error.

    ``epsilon`` is added to every row's mean squared error before the square
    root is taken, to prevent numerical issues.
    """

    def __init__(self, epsilon=1e-8):
        super().__init__()
        self.epsilon = epsilon

    def forward(self, preds: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the scalar loss for ``preds`` against ``targets``."""
        squared_error = (preds - targets).pow(2)
        # Reduce along dim 1 to get one MSE per row, then stabilise and root it.
        per_row_rmse = (squared_error.mean(dim=1) + self.epsilon).sqrt()
        return per_row_rmse.mean()

In [18]:
def eval_model(model, val_loader, device):
    """Evaluate ``model`` on ``val_loader`` and return the mean row-wise RMSE.

    The model is moved to ``device`` and switched to eval mode. Each batch is
    expected to be a dict with ``input_ids``, ``attention_mask``, ``cell_type``
    and ``target`` entries. Outputs and targets of all batches are gathered on
    the CPU as numpy arrays and scored with ``mean_rowwise_rmse``.
    """
    model.to(device)
    model.eval()

    batch_outputs = []
    batch_targets = []

    # No gradients are needed for evaluation; this saves memory and compute.
    with torch.no_grad():
        for batch in val_loader:
            on_device = {
                key: batch[key].to(device)
                for key in ("input_ids", "cell_type", "attention_mask", "target")
            }

            output = model(
                on_device["input_ids"],
                on_device["attention_mask"],
                on_device["cell_type"],
            )

            # Bring results back to the CPU so they can be stored as numpy arrays.
            batch_outputs.append(output.cpu().numpy())
            batch_targets.append(on_device["target"].cpu().numpy())

    # Stack the per-batch arrays along the sample axis.
    all_outputs = np.concatenate(batch_outputs, axis=0)
    all_targets = np.concatenate(batch_targets, axis=0)

    # The metric is symmetric in its two arguments.
    return mean_rowwise_rmse(all_targets, all_outputs)

In [19]:
from sklearn.decomposition import TruncatedSVD
import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader

# Evaluate the saved checkpoint on the held-out test split.
torch.manual_seed(42)

# Keep the raw frame under its own name so re-running later steps never
# operates on an already-processed frame.
df_test_raw = pd.read_parquet("dataset/de_test_split.parquet")

# Preprocess the data with normalization and SVD of the gene expressions.
# NOTE(review): preprocess() is called here with a fresh TruncatedSVD, so the
# 5 target components are fitted on the test split itself — confirm this is
# the same target space that best_model.pt was trained to predict.
svd = TruncatedSVD(n_components=5, random_state=42)
df_test = preprocess(df_test_raw, svd)

test_dataset = SingleCellPerturbationDataset(
    smiles_list=df_test['SMILES'].to_numpy(),
    cell_types=df_test['cell_type'].to_numpy(),
    targets=np.array(df_test['gene_expressions'].tolist(), dtype=np.float32),
    augment_prob=0.0  # No augmentation for test set
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 4
print(f"Using device: {device}")

test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Initialize the model and load the trained weights.
model = SingleCellPerturbationMultiModalModel(output_dim=5).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine,
# which is exactly the fallback the device selection above allows for.
model.load_state_dict(torch.load('best_model.pt', map_location=device))

test_loss = eval_model(model, test_loader, device)

print(f"Test Mean Rowwise RMSE Loss: {test_loss:.4f}")

Using device: cuda
Test Mean Rowwise RMSE Loss: 0.5682
