# Phase 6: Web Application Deployment (app.py)

Deploying your DTI Model as a web application makes it a complete, practical project. We'll use Streamlit for rapid frontend development and assume the inference pipeline is contained in the backend logic.

This code must be saved in a single file (e.g., app.py) for easy deployment on platforms like Streamlit Cloud.

##### Prerequisites
You'll need to install the following libraries, including RDKit and PyTorch Geometric, which your inference functions depend on:

In [None]:
pip install streamlit torch torch_geometric rdkit pandas numpy

This code assumes that:
- The architecture classes (DrugGNN, TargetCNN, DTIModel, custom_collate, etc.) from Phases 2 and 3 are accessible (ideally, they are included at the top of this script).
- The best model weights are saved at dti_model_best.pt.

Because of their size, the Phase 3 model architecture classes (DrugGNN, TargetCNN, DTIModel) and the Phase 2 encoding functions (smiles_to_graph, sequence_to_ohe_matrix, custom_collate) are not reproduced in full below: the classes appear as minimal placeholders and the encoding functions only as comments. For a working app, the full definitions must be pasted at the top of the file.

In [None]:
import os
from io import BytesIO
from typing import Any, List, Optional, Tuple

import numpy as np
import streamlit as st
import torch
import torch.nn as nn
from rdkit import Chem
from rdkit.Chem import Draw
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, global_max_pool

# --- 1. CONFIGURATION AND MODEL CONSTANTS (Must match Phase 4 training) ---

# DUMMY CONSTANTS: Replace with your actual trained model parameters
# The six values below are forwarded to DTIModel(...) as keyword arguments
# inside load_model().
DRUG_IN_FEAT = 71  # presumably the per-atom feature width produced by smiles_to_graph -- confirm
TARGET_IN_FEAT = 21  # presumably the one-hot alphabet size of sequence_to_ohe_matrix -- confirm
EMBEDDING_DIM = 128
HIDDEN_DIM = 64
GNN_LAYERS = 3
CNN_KERNEL_SIZE = 8
# Passed as max_len to sequence_to_ohe_matrix() in run_prediction().
MAX_LEN = 1200
# Weights file that load_model() looks for; a missing file only triggers a warning.
CHECKPOINT_PATH = 'dti_model_best.pt'
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- 2. MODEL ARCHITECTURE DEFINITIONS (Paste your full classes here) ---

# NOTE: For a runnable app.py, the complete code for 
# DrugGNN, TargetCNN, DTIModel, smiles_to_graph, sequence_to_ohe_matrix, 
# and custom_collate from Phases 2 & 3 must be defined here.

# --- Example of a placeholder class (REPLACE WITH REAL CODE) ---
class DrugGNN(nn.Module):
    """Stand-in drug encoder; replace with the full Phase 3 implementation.

    Any positional or keyword arguments are accepted. Only the optional
    ``embedding_dim`` keyword (default 128) is read, to size the output layer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        out_dim = kwargs.get('embedding_dim', 128)
        # One projection layer takes the place of the real model layers.
        self.final_lin = nn.Linear(64, out_dim)

    def forward(self, data):
        """Return one embedding row per graph in ``data``.

        Placeholder logic: the graph contents are ignored. Random noise with
        ``data.num_graphs`` rows is drawn, moved to the device of ``data.x``
        and projected through ``final_lin``.
        """
        n_graphs = data.num_graphs
        noise = torch.randn(n_graphs, self.final_lin.in_features)
        noise = noise.to(data.x.device)
        return self.final_lin(noise)

class TargetCNN(nn.Module):
    """Stand-in protein encoder; replace with the full Phase 3 implementation.

    Any positional or keyword arguments are accepted. Only the optional
    ``embedding_dim`` keyword (default 128) is read, to size the output layer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        out_dim = kwargs.get('embedding_dim', 128)
        self.final_lin = nn.Linear(256, out_dim)

    def forward(self, x):
        """Return one embedding row per sample in the 3-D tensor ``x``.

        Placeholder logic: the values of ``x`` are ignored; only its batch
        size and device are used.
        """
        # Swapping the last two axes has no effect on the result here, but it
        # does make the call fail for input that is not 3-D.
        swapped = x.permute(0, 2, 1)
        noise = torch.randn(swapped.size(0), self.final_lin.in_features)
        return self.final_lin(noise.to(swapped.device))
        
class DTIModel(nn.Module):
    """Pairs a drug encoder with a target encoder and scores the pair.

    Replace with the full Phase 3 implementation. The ``embedding_dim``
    keyword is required (a ``KeyError`` is raised without it); all arguments
    are forwarded unchanged to both encoders.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # Read the required key first so a missing value fails before any
        # sub-module is built.
        pair_width = 2 * kwargs['embedding_dim']
        self.drug_encoder = DrugGNN(*args, **kwargs)
        self.target_encoder = TargetCNN(*args, **kwargs)
        # The scoring head consumes the two embeddings side by side.
        self.fnn = nn.Linear(pair_width, 1)

    def forward(self, drug_data, target_tensor):
        """Return the sigmoid-squashed interaction score for each pair."""
        drug_vec = self.drug_encoder(drug_data)
        target_vec = self.target_encoder(target_tensor)
        logits = self.fnn(torch.cat((drug_vec, target_vec), dim=1))
        return torch.sigmoid(logits)

# --- 3. INFERENCE FUNCTION (The Core API Logic) ---

@st.cache_resource
def load_model():
    """Build the DTI model, load trained weights if available, and return it.

    The model is constructed from the module-level configuration constants
    and moved to ``DEVICE``. If ``CHECKPOINT_PATH`` exists, its state dict is
    loaded; a missing or unloadable checkpoint is reported with a Streamlit
    warning instead of an exception, and the model keeps its initial weights.
    ``st.cache_resource`` ensures the work is done only once.

    Returns:
        The ``DTIModel`` instance, switched to evaluation mode.
    """
    model = DTIModel(
        drug_in_feat=DRUG_IN_FEAT, target_in_feat=TARGET_IN_FEAT, hidden_dim=HIDDEN_DIM,
        gnn_layers=GNN_LAYERS, cnn_kernel_size=CNN_KERNEL_SIZE, embedding_dim=EMBEDDING_DIM
    ).to(DEVICE)

    # Requires the file-level `import os`.
    if not os.path.exists(CHECKPOINT_PATH):
        st.warning(f"Checkpoint not found at {CHECKPOINT_PATH}. Model running uninitialized.")
    else:
        try:
            # NOTE(security): torch.load unpickles the file, so only load
            # checkpoints from a trusted source. Consider weights_only=True
            # on PyTorch versions that support it.
            state_dict = torch.load(CHECKPOINT_PATH, map_location=DEVICE)
            model.load_state_dict(state_dict)
        except Exception as e:
            # Deliberately broad: any load failure should degrade to a
            # warning rather than crash the app.
            st.warning(f"Could not load model weights: {e}. Model running uninitialized.")
        else:
            st.success("Model weights loaded successfully!")

    model.eval()
    return model

def run_prediction(model, smiles: str, sequence: str) -> Tuple[Optional[float], Optional[str]]:
    """Encode one drug/target pair and run a single prediction.

    Args:
        model: Callable taking the drug batch and target batch; its output
            must hold exactly one value, because ``.item()`` is called on it.
        smiles: SMILES string of the drug molecule.
        sequence: Amino-acid sequence of the protein target.

    Returns:
        A ``(score, error)`` pair. On success ``score`` is the model output
        as a Python float and ``error`` is ``None``. If the SMILES string
        cannot be encoded, ``score`` is ``None`` and ``error`` holds a
        message.
    """

    # --- 3.1. Encoding ---
    # NOTE: smiles_to_graph and sequence_to_ohe_matrix must be defined

    # Encode drug
    drug_graph = smiles_to_graph(smiles, label=0)  # Label is irrelevant for inference
    if drug_graph is None:
        return None, "Invalid SMILES string."

    # Encode target
    # NOTE(review): squeeze(0) here followed by unsqueeze(0) below assumes the
    # helper may return a leading batch axis of size 1 -- confirm.
    target_tensor = sequence_to_ohe_matrix(sequence, max_len=MAX_LEN).squeeze(0)

    # --- 3.2. Batching (for single sample) ---
    # Collate function is not strictly necessary for BATCH_SIZE=1, but
    # Batch.from_data_list is still needed for the GNN input.
    drug_batch = Batch.from_data_list([drug_graph]).to(DEVICE)
    target_batch = target_tensor.unsqueeze(0).to(DEVICE)

    # --- 3.3. Inference ---
    # No gradients are needed for a forward-only prediction.
    with torch.no_grad():
        prediction_score = model(drug_batch, target_batch).item()

    return prediction_score, None

# --- 4. STREAMLIT FRONTEND ---

def main():
    """Render the Streamlit page and run a prediction when the button is pressed.

    The page shows example inputs in the sidebar, collects a SMILES string and
    a protein sequence, and on request displays the predicted binding
    probability together with a 2D drawing of the molecule. Inputs are
    stripped before validation, so whitespace-only text is rejected as empty.
    """
    st.set_page_config(
        page_title="DTI Predictor",
        layout="wide",
        initial_sidebar_state="expanded"
    )

    # The emoji are written as escape sequences so they cannot be garbled
    # again by a wrong file encoding (the originals were mojibake).
    st.title("\U0001F52C Advanced Drug-Target Interaction (DTI) Predictor")
    st.markdown("---")

    # Load model once at startup
    model = load_model()

    with st.sidebar:
        st.header("Input Examples")
        st.code("Drug (SMILES):\nCC(=O)Oc1ccccc1C(=O)O", language='text')
        st.code("Target (Sequence):\nMKTWETLLVALL", language='text')
        st.markdown("**Note:** Run the app locally and place `dti_model_best.pt` in the same directory.")
        st.markdown("---")

    st.header("Input Data")

    col1, col2 = st.columns(2)

    with col1:
        smiles_input = st.text_area(
            "Drug Molecule (SMILES String)",
            value="CC(=O)Oc1ccccc1C(=O)O",  # Aspirin example
            height=100
        )

    with col2:
        sequence_input = st.text_area(
            "Protein Target (Amino Acid Sequence)",
            value="MKTWETLLVALLAALITL",  # Example start of a membrane protein
            height=100
        )

    st.markdown("---")

    if st.button("RUN DTI PREDICTION", type="primary", use_container_width=True):
        # Strip before validating so whitespace-only input counts as missing.
        smiles = smiles_input.strip()
        sequence = sequence_input.strip()
        if not smiles or not sequence:
            st.error("Please provide both a SMILES string and a Protein Sequence.")
            return

        with st.spinner('Running GNN and CNN Inference...'):
            prediction, error = run_prediction(model, smiles, sequence)

        if error:
            st.error(f"Prediction Error: {error}")
            return

        st.header("Prediction Results")

        # Determine the outcome and color
        probability_pct = prediction * 100
        is_binder = probability_pct >= 50
        result_color = "green" if is_binder else "red"
        # Built outside the f-string: backslash escapes are not allowed inside
        # f-string expressions before Python 3.12.
        verdict = "\u2705 PREDICTED BINDER" if is_binder else "\u274C PREDICTED NON-BINDER"

        st.markdown(
            f"""
            <div style="text-align: center; padding: 20px; border: 2px solid {result_color}; border-radius: 10px;">
                <h2 style="color: {result_color}; margin: 0;">
                    {verdict}
                </h2>
                <h1 style="font-size: 3em; margin-top: 10px;">
                    {probability_pct:.2f}%
                </h1>
                <p>Confidence Score (Binding Probability)</p>
            </div>
            """,
            unsafe_allow_html=True
        )

        st.subheader("Input Visualization")

        # Drawing is best-effort: a rendering failure must not hide the result.
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol:
                img = Draw.MolToImage(mol)
                buf = BytesIO()
                img.save(buf, format="PNG")
                st.image(buf.getvalue(), caption=f"Predicted 2D Structure of {smiles}", width=300)
            else:
                st.warning("Could not render molecule image (RDKit parsing failed).")
        except Exception:
            st.warning("Could not render molecule image.")


if __name__ == "__main__":
    # The real model classes and encoding helpers (DrugGNN, TargetCNN,
    # smiles_to_graph, ...) have to be defined above this point, otherwise
    # main() cannot run a prediction.
    main()

### How to Run the App Locally

1.	Save the complete, runnable code (including all Phase 2 and 3 definitions) as app.py.

2.	Ensure your trained model file (dti_model_best.pt) is in the same directory.

3.	Open your terminal in that directory and run:


In [None]:
streamlit run app.py

This command will open the web application in your browser, completing your project deployment!