# Phase 3: Model Architecture Design and Implementation (The Model Core)

Building the architecture involves defining three components: the Drug Encoder (GNN), the Target Encoder (CNN), and the Prediction Head (Fusion & FNN).

Because each sample has two heterogeneous inputs (a graph for the drug and a padded tensor for the protein), we must also define a custom `collate_fn` for the DataLoader.

In [None]:
from typing import List, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch_geometric.data import Batch, Data
from torch_geometric.nn import GCNConv, global_max_pool

### 1. Drug Encoder (Graph Neural Network - GNN)
We'll use a simple Graph Convolutional Network (GCN) followed by a global max-pooling operation to generate a fixed-size drug feature vector ($V_D$).

In [None]:
class DrugGNN(nn.Module):
    """Graph encoder that maps a (batched) molecular graph to one vector per drug.

    Pipeline: linear projection of node features -> ReLU -> ``gnn_layers``
    GCN layers (each followed by ReLU and dropout with p=0.2) -> per-graph
    max pooling -> linear projection to ``embedding_dim``.

    Args:
        in_features: Length of each node's input feature vector.
        hidden_dim: Width of the initial projection and of every GCN layer.
        gnn_layers: Number of stacked ``GCNConv`` layers (0 is allowed).
        embedding_dim: Length of the returned drug embedding (V_D).
    """

    def __init__(self, in_features, hidden_dim, gnn_layers, embedding_dim):
        super().__init__()
        # NOTE: attribute names and creation order are kept stable so saved
        # state_dicts and seeded weight initialisation stay compatible.
        self.initial_lin = nn.Linear(in_features, hidden_dim)
        self.convs = nn.ModuleList(
            [GCNConv(hidden_dim, hidden_dim) for _layer_idx in range(gnn_layers)]
        )
        self.final_lin = nn.Linear(hidden_dim, embedding_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, data):
        """Encode a PyG graph batch.

        Args:
            data: Object exposing ``x`` (node features), ``edge_index`` and
                ``batch`` (graph id per node), e.g. a PyG ``Batch``.

        Returns:
            V_D, one row per graph, shape (num_graphs, embedding_dim).
        """
        node_feats, edges, graph_ids = data.x, data.edge_index, data.batch

        hidden = F.relu(self.initial_lin(node_feats))

        for gcn in self.convs:
            hidden = self.dropout(F.relu(gcn(hidden, edges)))

        # Collapse the variable number of nodes per graph into a single vector.
        per_graph = global_max_pool(hidden, graph_ids)

        return self.final_lin(per_graph)

### 2. Target Encoder (1D Convolutional Neural Network - CNN)
This module processes the padded, one-hot encoded amino acid sequence matrix to generate the protein feature vector ($V_P$).

In [None]:
class TargetCNN(nn.Module):
    """1D-CNN encoder that maps a padded protein sequence matrix to one vector.

    Three Conv1d layers (channels: ``hidden_channels`` -> x2 -> x4, each with
    ReLU and "same"-style padding of ``kernel_size // 2``), dropout, global max
    pooling over the sequence axis and a final linear projection.

    Args:
        in_features: Per-residue feature size (alphabet size of the one-hot
            encoding; the original notes mention 21 — confirm against Phase 2).
        hidden_channels: Output channels of the first convolution.
        kernel_size: Convolution kernel width used by all three layers.
        embedding_dim: Length of the returned protein embedding (V_P).
    """

    def __init__(self, in_features, hidden_channels, kernel_size, embedding_dim):
        super(TargetCNN, self).__init__()

        # Input arrives as (Batch, Length, in_features); Conv1d wants
        # (Batch, Channels, Length), so forward() transposes first.

        # 1. 1D convolutional layers
        self.conv1 = nn.Conv1d(
            in_channels=in_features,
            out_channels=hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.conv2 = nn.Conv1d(
            hidden_channels,
            hidden_channels * 2,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )
        self.conv3 = nn.Conv1d(
            hidden_channels * 2,
            hidden_channels * 4,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
        )

        # 2. Global max pooling: fixed-size output regardless of sequence length
        self.global_pool = nn.AdaptiveMaxPool1d(1)

        # 3. Projection of pooled features to the final embedding dimension
        self.final_lin = nn.Linear(hidden_channels * 4, embedding_dim)
        self.dropout = nn.Dropout(0.2)

    def forward(self, sequence_tensor):
        """Encode a batch of sequences.

        Args:
            sequence_tensor: Tensor of shape (B, L, in_features).

        Returns:
            V_P of shape (B, embedding_dim).
        """
        # (B, L, F) -> (B, F, L) for Conv1d
        x = sequence_tensor.permute(0, 2, 1)

        # 1. Convolution blocks
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.dropout(x)

        # 2. Global pooling: (B, C, L) -> (B, C, 1) -> (B, C)
        x = self.global_pool(x).squeeze(-1)

        # 3. Final embedding
        v_p = self.final_lin(x)
        return v_p


### 3. Combined DTI Prediction Model (Fusion Head)
This class combines the two encoders and uses a simple Concatenation Fusion followed by a Feed-Forward Network (FNN) for the final prediction.

In [None]:
class DTIModel(nn.Module):
    """Drug-target interaction model: GNN drug encoder + CNN target encoder + MLP head.

    The two embeddings are concatenated and passed through a feed-forward head
    whose final sigmoid yields one probability per drug-target pair.

    Args:
        drug_in_feat: Node feature size of the drug graphs.
        target_in_feat: Per-residue feature size of the target tensors.
        hidden_dim: Hidden width of the GNN and first-layer channels of the CNN.
        gnn_layers: Number of GCN layers in the drug encoder.
        cnn_kernel_size: Kernel width of the target CNN.
        embedding_dim: Size of each encoder's output embedding.
        fc_layers: Number of Linear layers in the prediction head (>= 1).
            Each hidden layer halves the width and is followed by ReLU and
            Dropout(0.5). The default of 2 gives
            Linear(2E, E) -> ReLU -> Dropout -> Linear(E, 1).

    Raises:
        ValueError: If ``fc_layers`` < 1.
    """

    def __init__(self, drug_in_feat, target_in_feat, hidden_dim, gnn_layers, cnn_kernel_size, embedding_dim, fc_layers=2):
        super(DTIModel, self).__init__()

        # Encoders
        self.drug_encoder = DrugGNN(drug_in_feat, hidden_dim, gnn_layers, embedding_dim)
        self.target_encoder = TargetCNN(target_in_feat, hidden_dim, cnn_kernel_size, embedding_dim)

        # Concatenating V_D and V_P doubles the feature size fed to the head.
        self.fc_input_size = embedding_dim * 2

        self.fnn = self._build_head(self.fc_input_size, fc_layers)

    @staticmethod
    def _build_head(input_size, fc_layers, dropout=0.5):
        """Build the MLP head: ``fc_layers`` Linear layers ending in a 1-unit output."""
        if fc_layers < 1:
            raise ValueError(f"fc_layers must be >= 1, got {fc_layers}")
        layers = []
        width = input_size
        for _ in range(fc_layers - 1):
            next_width = max(1, width // 2)
            layers += [nn.Linear(width, next_width), nn.ReLU(), nn.Dropout(dropout)]
            width = next_width
        layers.append(nn.Linear(width, 1))  # single interaction score
        return nn.Sequential(*layers)

    def forward(self, drug_data, target_tensor):
        """Score a batch of drug-target pairs.

        Args:
            drug_data: Batched drug graphs, as consumed by ``DrugGNN``.
            target_tensor: Target tensor of shape (B, L, target_in_feat).

        Returns:
            Probabilities of shape (B, 1). The sigmoid is already applied, so
            pair this output with ``nn.BCELoss`` rather than ``BCEWithLogitsLoss``.
        """
        # 1. Encode drug and target
        v_d = self.drug_encoder(drug_data)
        v_p = self.target_encoder(target_tensor)

        # 2. Feature fusion: (B, 2 * embedding_dim)
        v_pair = torch.cat([v_d, v_p], dim=1)

        # 3. Prediction head + sigmoid for a probability output
        score = torch.sigmoid(self.fnn(v_pair))

        return score

### 4. Custom Collate Function (The Glue)
The DataLoader needs a custom function that batches the PyG `Data` objects (drug graphs) and the padded PyTorch tensors (protein sequences) together, and collects the labels stored on each graph.

In [None]:
def custom_collate(batch: List[Tuple[Data, torch.Tensor]]) -> Tuple[Batch, torch.Tensor, torch.Tensor]:
    """
    Collate DTI samples into model-ready batches.

    Args:
        batch: List of ``(drug_graph, target_tensor)`` pairs. Each drug graph
            carries its label in ``.y``. All target tensors must already be
            padded to the same shape (done in Phase 2).

    Returns:
        Tuple ``(drug_batch, target_batch, labels)``:
            - drug_batch: PyG ``Batch`` with node/edge features stacked and the
              per-node ``batch`` vector built.
            - target_batch: Tensor of shape (B, L, F).
            - labels: Tensor of shape (B, k), where k is the number of label
              values per graph (k = 1 for a single score). This matches the
              model's (B, 1) output.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("custom_collate received an empty batch")

    # 1. Separate the components
    drug_graphs, target_tensors = zip(*batch)

    # 2. Batch drug graphs with PyG's Batch (stacks features, builds `batch` vector)
    drug_batch = Batch.from_data_list(list(drug_graphs))

    # 3. Batch target sequences; they are pre-padded so a plain stack works
    target_batch = torch.stack(target_tensors, dim=0)  # (B, L, F)

    # 4. Labels: flatten each graph's y first so scalar, (1,) and (1, 1) labels
    #    all become one row. torch.cat on (1,)-shaped labels would give (B,),
    #    which silently broadcasts against the (B, 1) model output in losses.
    labels = torch.stack([g.y.reshape(-1) for g in drug_graphs], dim=0)  # (B, k)

    return drug_batch, target_batch, labels

# Note: You would initialize your DataLoader like this:
# from torch.utils.data import DataLoader
# dti_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate)

This completes the architecture setup, providing all the necessary classes to initialize, train, and test your heterogeneous DTI model.