# Phase 2: Feature Engineering (The Encoding Core)

We convert the cleaned strings from Phase 1 into the numerical tensors needed for the GNN and CNN modules.

We'll create two main helper functions: one for the Drug (SMILES) using RDKit and PyTorch Geometric, and one for the Target (Sequence) using simple one-hot encoding.


In [None]:
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from torch_geometric.data import Data

### 1. Drug Encoding (SMILES → Graph Data Object)
This function converts a SMILES string into a torch_geometric.data.Data object, which includes a node feature matrix (X), an edge index (E_index), and edge features (E_attr).
##### Helper Functions: Feature Definitions
We first define the exact numerical features we want to extract from each atom and bond, typically using one-hot encoding.

In [None]:
# --- 1.1. Feature Maps (Defining the Encoding Space) ---
#
# Each dictionary maps a feature name to the ordered list of values that get
# their own one-hot slot. The position of a value in its list is the slot
# index, so reordering or extending a list changes the layout (and width) of
# every feature vector built from it. `one_hot_encode` adds one extra trailing
# slot per feature for any value that is not listed here.

# All possible values for Atom Features
ATOM_FEATS = {
    'atomic_num': list(range(1, 19)), # Atomic numbers 1-18 (H through Ar); anything heavier is not listed
    'degree': [0, 1, 2, 3, 4, 5],
    'formal_charge': [-2, -1, 0, 1, 2],
    'chirality': [Chem.ChiralType.CHI_UNSPECIFIED, Chem.ChiralType.CHI_TETRAHEDRAL_CW, Chem.ChiralType.CHI_TETRAHEDRAL_CCW, Chem.ChiralType.CHI_OTHER],
    'hybridization': [Chem.HybridizationType.SP, Chem.HybridizationType.SP2, Chem.HybridizationType.SP3, Chem.HybridizationType.SP3D, Chem.HybridizationType.SP3D2, Chem.HybridizationType.UNSPECIFIED],
    'num_hs': [0, 1, 2, 3, 4],
    'is_aromatic': [False, True],
    'is_in_ring': [False, True],
}

# All possible values for Bond Features
BOND_FEATS = {
    'bond_type': [Chem.BondType.SINGLE, Chem.BondType.DOUBLE, Chem.BondType.TRIPLE, Chem.BondType.AROMATIC],
    'stereo': [Chem.BondStereo.STEREONONE, Chem.BondStereo.STEREOANY, Chem.BondStereo.STEREOZ, Chem.BondStereo.STEREOE, Chem.BondStereo.STEREOCIS, Chem.BondStereo.STEREOTRANS],
    'is_conjugated': [False, True],
    'is_in_ring': [False, True],
}

def one_hot_encode(value: Any, choices: List[Any]) -> List[int]:
    """One-hot encode `value` against `choices`, with a trailing catch-all slot.

    Args:
        value (Any): The value to encode.
        choices (List[Any]): The ordered list of known values.

    Returns:
        List[int]: A list of length ``len(choices) + 1`` holding a single 1.
        The 1 sits at the position of the first entry of `choices` equal to
        `value`, or in the last slot when `value` is not listed.
    """
    # Unlisted values fall through to the extra 'Other'/'Unknown' slot.
    hot_slot = choices.index(value) if value in choices else len(choices)
    return [1 if slot == hot_slot else 0 for slot in range(len(choices) + 1)]

def get_atom_features(atom: Chem.Atom) -> List[int]:
    """Build an atom's feature vector by concatenating its one-hot blocks.

    The blocks appear in this fixed order: atomic number, total degree, formal
    charge, chiral tag, hybridization, total hydrogen count, aromaticity and
    ring membership. Each block is produced by `one_hot_encode` against the
    matching entry of `ATOM_FEATS`.
    """
    # (ATOM_FEATS key, value read from the atom), in output order.
    readings = (
        ('atomic_num', atom.GetAtomicNum()),
        ('degree', atom.GetTotalDegree()),
        ('formal_charge', atom.GetFormalCharge()),
        ('chirality', atom.GetChiralTag()),
        ('hybridization', atom.GetHybridization()),
        ('num_hs', atom.GetTotalNumHs()),
        ('is_aromatic', atom.GetIsAromatic()),
        ('is_in_ring', atom.IsInRing()),
    )
    vector: List[int] = []
    for key, observed in readings:
        vector.extend(one_hot_encode(observed, ATOM_FEATS[key]))
    return vector

def get_bond_features(bond: Chem.Bond) -> List[int]:
    """Build a bond's feature vector by concatenating its one-hot blocks.

    The blocks appear in this fixed order: bond type, stereo configuration,
    conjugation flag and ring membership. Each block is produced by
    `one_hot_encode` against the matching entry of `BOND_FEATS`.
    """
    # (BOND_FEATS key, value read from the bond), in output order.
    readings = (
        ('bond_type', bond.GetBondType()),
        ('stereo', bond.GetStereo()),
        ('is_conjugated', bond.GetIsConjugated()),
        ('is_in_ring', bond.IsInRing()),
    )
    vector: List[int] = []
    for key, observed in readings:
        vector.extend(one_hot_encode(observed, BOND_FEATS[key]))
    return vector

# --- 1.2. Main SMILES to PyG Data Function ---

def smiles_to_graph(smiles: str, label: float) -> Optional[Data]:
    """
    Converts a SMILES string to a torch_geometric.data.Data object.

    Args:
        smiles (str): The canonical SMILES string.
        label (float): The pIC50 or binary label associated with the molecule.

    Returns:
        Optional[Data]: A PyG Data object containing:
            - x: float tensor with one row of atom features per atom.
            - edge_index: long tensor of shape (2, 2 * num_bonds); every bond
              is stored as two directed edges (i -> j and j -> i).
            - edge_attr: float tensor with one row of bond features per
              directed edge, in the same order as edge_index.
            - y: 1-element float tensor holding the label.
        Returns None when `smiles` is not a string, cannot be parsed by RDKit,
        or describes a molecule without atoms (e.g. an empty string).
    """
    if not isinstance(smiles, str):
        # e.g. NaN coming from a DataFrame cell with a missing value.
        return None

    mol = Chem.MolFromSmiles(smiles)
    if mol is None or mol.GetNumAtoms() == 0:
        # Invalid SMILES, or an empty molecule whose feature matrix would
        # collapse to shape (0,) with no feature axis.
        return None

    # --- 1. Node Features (X) ---
    atom_features_list = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(atom_features_list, dtype=torch.float)

    # --- 2. Edge Index (E) and Edge Attributes (E_attr) ---
    edge_indices, edge_attrs = [], []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        features = get_bond_features(bond)

        # Add two directed edges for each undirected bond
        edge_indices += [[begin, end], [end, begin]]
        edge_attrs += [features, features]

    if edge_indices:
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attrs, dtype=torch.float)
    else:
        # Molecule with no bonds (e.g. a single atom): build empty tensors that
        # still carry the right feature width. get_bond_features emits one
        # block of len(choices) + 1 slots per BOND_FEATS entry.
        bond_feature_dim = sum(len(choices) + 1 for choices in BOND_FEATS.values())
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, bond_feature_dim), dtype=torch.float)

    # --- 3. Label (y) ---
    y = torch.tensor([label], dtype=torch.float)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

# Example usage (for testing the encoding)
# These statements run whenever this cell/module is executed.
sample_smiles = 'CC(=O)Oc1ccccc1C(=O)O' # Aspirin
sample_label = 7.0  # Placeholder label value for the demo
aspirin_graph = smiles_to_graph(sample_smiles, sample_label)
# Uncomment to inspect the tensor shapes. smiles_to_graph returns None for an
# unparsable SMILES, in which case these attribute accesses would fail.
# print(f"Aspirin Graph Node Features (X shape): {aspirin_graph.x.shape}") 
# print(f"Aspirin Graph Edge Index (E shape): {aspirin_graph.edge_index.shape}")


### 2. Target Encoding (Sequence → One-Hot Matrix)
This function converts the protein's amino acid sequence into a One-Hot Encoded (OHE) matrix, which is the canonical input for a 1D Convolutional Neural Network (CNN).

In [None]:
# --- 2.1. Feature Map (Defining the Amino Acid Vocabulary) ---

# Vocabulary: the 20 standard residues followed by 'X', which serves as both
# the unknown-residue and the padding symbol (21 symbols in total).
AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWYX'
# Residue letter -> column index, following the order of AMINO_ACIDS.
AA_DICT = dict(zip(AMINO_ACIDS, range(len(AMINO_ACIDS))))

# --- 2.2. Main Sequence Encoding Function ---

def sequence_to_ohe_matrix(sequence: str, max_len: int = 1200) -> torch.Tensor:
    """
    Converts a protein sequence to a fixed-length, one-hot encoded matrix.

    The sequence is cut to `max_len` residues, or right-padded with 'X' up to
    `max_len`. Padding positions are therefore encoded in the 'X' channel
    rather than left as all-zero rows. Any character missing from AA_DICT is
    encoded as 'X' as well.

    Args:
        sequence (str): The amino acid sequence.
        max_len (int): The fixed length to pad/truncate the sequence to.

    Returns:
        torch.Tensor: Float tensor of shape (1, max_len, len(AMINO_ACIDS)),
        i.e. (Batch_size=1, Sequence_Length, Features).
    """
    fallback_column = AA_DICT['X']
    residues = sequence[:max_len].ljust(max_len, 'X')

    # Column to switch on for every position of the padded sequence.
    columns = np.array(
        [AA_DICT.get(residue, fallback_column) for residue in residues],
        dtype=np.intp,
    )

    encoded = np.zeros((max_len, len(AMINO_ACIDS)), dtype=np.float32)
    # Set one cell per row in a single indexed assignment.
    encoded[np.arange(len(residues)), columns] = 1.0

    # Add the leading batch axis.
    return torch.tensor(encoded, dtype=torch.float).unsqueeze(0)

# Example usage (for testing the encoding)
# These statements run whenever this cell/module is executed.
sample_sequence = "MKTWETLLV"
# For real targets, max_len should be chosen based on the distribution of your dataset.
# Here the 9-residue sample is padded with 'X' up to 100 positions.
sample_target_tensor = sequence_to_ohe_matrix(sample_sequence, max_len=100) 
# Uncomment to inspect the shape: (1, max_len, len(AMINO_ACIDS)).
# print(f"Target Sequence Tensor shape: {sample_target_tensor.shape}")


### 3. Custom PyTorch Dataset (The Integrator)
Finally, we wrap these functions into a custom Dataset class to handle the batching of heterogeneous data (graphs and tensors) for training.

In [None]:
from torch.utils.data import Dataset

class DTIDataset(Dataset):
    """
    Custom PyTorch Dataset for Drug-Target Interaction (DTI) pairs.
    Handles encoding of both molecule (SMILES) and protein (Sequence).

    The input DataFrame must provide the columns 'standard_smiles',
    'target_sequence' and 'label'. Every row is encoded eagerly when the
    dataset is constructed and the results are kept in memory, so memory use
    grows with both the number of rows and `max_len`.

    Rows that cannot be encoded (missing/non-string SMILES or sequence, or a
    SMILES that `smiles_to_graph` rejects) are dropped. `len(dataset)` can
    therefore be smaller than `len(dataframe)`, and dataset indices do not
    necessarily line up with DataFrame row positions.
    """
    def __init__(self, dataframe: pd.DataFrame, max_len: int = 1200):
        """
        Args:
            dataframe (pd.DataFrame): The DTI pairs to encode.
            max_len (int): Fixed protein length passed to
                `sequence_to_ohe_matrix` for padding/truncation.
        """
        self.df = dataframe.reset_index(drop=True)
        self.max_len = max_len
        self.processed_data = self._process_data()

    def _process_data(self) -> List[Tuple[Data, torch.Tensor]]:
        """Encodes all SMILES and Sequences in the DataFrame.

        Returns:
            List[Tuple[Data, torch.Tensor]]: One (drug_graph, target_tensor)
            pair per successfully encoded row. The label lives in
            `drug_graph.y`; `target_tensor` has its batch axis removed.
        """
        print("Starting heterogeneous data encoding...")
        data_list = []
        for _, row in self.df.iterrows():
            smiles = row['standard_smiles']
            sequence = row['target_sequence']
            label = row['label'] # Assuming 'label' is the target variable (pIC50 or binary)

            # Missing values (e.g. NaN) are not strings; skip the row instead
            # of letting a single bad cell abort the whole encoding pass.
            if not isinstance(smiles, str) or not isinstance(sequence, str):
                continue

            # 1. Encode Drug (SMILES -> PyG Data object)
            drug_graph = smiles_to_graph(smiles, label)
            if drug_graph is None:
                continue

            # 2. Encode Target (Sequence -> OHE Tensor)
            # squeeze(0) drops the batch axis added by sequence_to_ohe_matrix.
            target_tensor = sequence_to_ohe_matrix(sequence, max_len=self.max_len).squeeze(0)

            # The label is already in drug_graph.y, so only the graph and the
            # sequence tensor are stored.
            data_list.append((drug_graph, target_tensor))

        print(f"Finished encoding. {len(data_list)} valid pairs processed.")
        return data_list

    def __len__(self):
        """Returns the number of successfully encoded pairs."""
        return len(self.processed_data)

    def __getitem__(self, idx):
        """Returns the (drug_graph, target_tensor) tuple stored at `idx`."""
        return self.processed_data[idx]

# --- Example of creating the dataset (assuming 'train_data' from Phase 1) ---
# train_dataset = DTIDataset(train_data, max_len=1000) 
# first_pair = train_dataset[0] # first_pair is a tuple (drug_graph, target_tensor)

# Note: The next step (Phase 3) will require defining a custom PyTorch DataLoader 
# and a 'collate_fn' to properly batch these heterogeneous (Graph + Tensor) objects.

# You can learn how to build the custom molecular graph data object for PyTorch Geometric
# by following a tutorial on [How to turn a SMILES string into a molecular graph for Pytorch Geometric].