# Data Preparation Tutorial

This notebook shows how to load, check, and preprocess drug-target affinity (DTA) data for the Unified DTA System.

## What you'll learn:
1. Load standard datasets (KIBA, Davis, BindingDB)
2. Analyze data quality and molecular properties
3. Preprocess SMILES strings and protein sequences
4. Create custom datasets and data loaders
5. Apply data augmentation techniques
6. Split the data and save the results

In [None]:
# Install required packages (run if needed)
# !pip install unified-dta rdkit-pypi torch-geometric transformers pandas matplotlib seaborn

import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors
import warnings
warnings.filterwarnings('ignore')

# Add project root to path (for development)
sys.path.insert(0, os.path.join(os.getcwd(), '../..'))

from unified_dta.data import DTADataset, DataProcessor
from unified_dta.utils import validate_smiles, validate_protein_sequence

print("Environment setup complete!")

## 1. Loading Standard Datasets

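The system supports three major DTA datasets: KIBA, Davis, and BindingDB. The cell below looks for `{name}_train.csv` and `{name}_test.csv` under `../../data/` and falls back to a five-row sample dataset if KIBA is not available.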

In [None]:
# Load sample datasets
datasets = {}
dataset_names = ['kiba', 'davis', 'bindingdb']

for name in dataset_names:
    try:
        train_path = f'../../data/{name}_train.csv'
        test_path = f'../../data/{name}_test.csv'
        
        if os.path.exists(train_path):
            train_df = pd.read_csv(train_path)
            test_df = pd.read_csv(test_path) if os.path.exists(test_path) else None
            
            datasets[name] = {
                'train': train_df,
                'test': test_df
            }
            print(f"✓ Loaded {name.upper()} dataset: {len(train_df)} train samples")
            if test_df is not None:
                print(f"  Test samples: {len(test_df)}")
        else:
            print(f"✗ {name.upper()} dataset not found at {train_path}")
    except Exception as e:
        print(f"✗ Error loading {name.upper()}: {e}")

# Use KIBA as example if available, otherwise create sample data
if 'kiba' in datasets and datasets['kiba']['train'] is not None:
    sample_df = datasets['kiba']['train'].head(100)
    print(f"\nUsing KIBA dataset sample: {len(sample_df)} entries")
else:
    # Create sample data for demonstration
    sample_data = {
        'compound_iso_smiles': [
            'CCO',  # Ethanol
            'CC(=O)O',  # Acetic acid
            'CC(C)O',  # Isopropanol
            'C1=CC=CC=C1',  # Benzene
            'CCN(CC)CC',  # Triethylamine
        ],
        'target_sequence': [
            'MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
            'MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
            'MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
            'MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
            'MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG',
        ],
        'affinity': [5.2, 6.1, 4.8, 7.3, 5.9]
    }
    sample_df = pd.DataFrame(sample_data)
    print(f"\nCreated sample dataset: {len(sample_df)} entries")

print("\nDataset columns:", list(sample_df.columns))
sample_df.head()
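
Before working with a loaded file, it can help to confirm that it contains the columns the rest of this notebook relies on (`compound_iso_smiles`, `target_sequence`, `affinity`). The sketch below is a minimal standalone check using only the standard library. `REQUIRED_COLUMNS` and `missing_columns` are hypothetical names chosen for illustration and are not part of `unified_dta`.

```python
import csv
import io

# Column names used by the sample data in this notebook.
REQUIRED_COLUMNS = ("compound_iso_smiles", "target_sequence", "affinity")


def missing_columns(csv_file, required=REQUIRED_COLUMNS):
    """Return the required column names that are absent from the CSV header."""
    reader = csv.DictReader(csv_file)
    header = reader.fieldnames or []  # fieldnames is None for an empty file
    return [name for name in required if name not in header]


# Illustrative in-memory CSV that lacks the affinity column.
example = io.StringIO("compound_iso_smiles,target_sequence\nCCO,MKTV\n")
print(missing_columns(example))
```

The same function accepts an open file handle, for example `open(train_path, newline="")`, so the check can run before the full file is read.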

## 2. Data Quality Analysis

The cell below checks for missing values, validates each SMILES string with RDKit, and summarizes protein sequence lengths and affinity values.

In [None]:
# Analyze data quality
print("=== Data Quality Analysis ===")
print(f"Total samples: {len(sample_df)}")
print("Missing values:")
print(sample_df.isnull().sum())
print()

# SMILES validation
valid_smiles = []
invalid_smiles = []

for idx, smiles in enumerate(sample_df['compound_iso_smiles']):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            valid_smiles.append(idx)
        else:
            invalid_smiles.append(idx)
    except Exception:  # e.g. non-string input such as NaN
        invalid_smiles.append(idx)

print("SMILES validation:")
print(f"  Valid: {len(valid_smiles)} ({len(valid_smiles)/len(sample_df)*100:.1f}%)")
print(f"  Invalid: {len(invalid_smiles)} ({len(invalid_smiles)/len(sample_df)*100:.1f}%)")
print()

# Protein sequence analysis
protein_lengths = [len(seq) for seq in sample_df['target_sequence']]
print("Protein sequence lengths:")
print(f"  Min: {min(protein_lengths)} residues")
print(f"  Max: {max(protein_lengths)} residues")
print(f"  Mean: {np.mean(protein_lengths):.1f} residues")
print(f"  Std: {np.std(protein_lengths):.1f} residues")
print()

# Affinity distribution
print("Affinity values:")
print(f"  Min: {sample_df['affinity'].min():.2f}")
print(f"  Max: {sample_df['affinity'].max():.2f}")
print(f"  Mean: {sample_df['affinity'].mean():.2f}")
print(f"  Std: {sample_df['affinity'].std():.2f}")
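
The analysis above reports protein lengths but does not check sequence content. As a minimal sketch, the function below flags characters outside the 20 standard one-letter amino acid codes. It is a standalone illustration with a hypothetical name, `find_invalid_residues`, and is not the implementation of `validate_protein_sequence` from `unified_dta`.

```python
# The 20 standard amino acids in one-letter code.
STANDARD_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")


def find_invalid_residues(sequence):
    """Return (position, character) pairs that are not standard amino acids."""
    return [
        (i, ch)
        for i, ch in enumerate(sequence.upper())
        if ch not in STANDARD_AMINO_ACIDS
    ]


print(find_invalid_residues("MKTVRQ"))  # no invalid residues
print(find_invalid_residues("MKXVR1"))  # 'X' and '1' are flagged
```

Sequence databases sometimes contain ambiguity codes such as `X`, `B`, or `Z`. Whether to reject, replace, or keep them is a project decision, so this check only reports them.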

## 3. Molecular Property Analysis

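The cells below use RDKit to compute molecular weight, LogP, atom, bond, and ring counts, TPSA, and hydrogen-bond donor and acceptor counts for each compound, then plot their distributions.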

In [None]:
# Calculate molecular properties
properties = []

for smiles in sample_df['compound_iso_smiles']:
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            props = {
                'smiles': smiles,
                'molecular_weight': Descriptors.MolWt(mol),
                'logp': Descriptors.MolLogP(mol),
                'num_atoms': mol.GetNumAtoms(),
                'num_bonds': mol.GetNumBonds(),
                'num_rings': rdMolDescriptors.CalcNumRings(mol),
                'tpsa': Descriptors.TPSA(mol),
                'hbd': Descriptors.NumHDonors(mol),
                'hba': Descriptors.NumHAcceptors(mol)
            }
            properties.append(props)
    except Exception as e:
        print(f"Error processing {smiles}: {e}")

props_df = pd.DataFrame(properties)
print(f"Calculated properties for {len(props_df)} compounds")
props_df.head()

In [None]:
# Visualize molecular properties
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

properties_to_plot = ['molecular_weight', 'logp', 'num_atoms', 'tpsa', 'hbd', 'hba']
property_labels = ['Molecular Weight', 'LogP', 'Number of Atoms', 'TPSA', 'H-Bond Donors', 'H-Bond Acceptors']

for i, (prop, label) in enumerate(zip(properties_to_plot, property_labels)):
    if prop in props_df.columns:
        axes[i].hist(props_df[prop], bins=10, alpha=0.7, color='skyblue', edgecolor='black')
        axes[i].set_title(label)
        axes[i].set_xlabel(label)
        axes[i].set_ylabel('Frequency')
    else:
        axes[i].text(0.5, 0.5, f'{label}\nNo data', ha='center', va='center', transform=axes[i].transAxes)

plt.tight_layout()
plt.show()
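
One common use of these properties is a drug-likeness screen. The sketch below counts violations of Lipinski's rule of five (molecular weight above 500, LogP above 5, more than 5 H-bond donors, more than 10 H-bond acceptors) for a property dictionary that uses the same keys as the rows of `props_df`. The function name `lipinski_violations` and the example values are hypothetical and for illustration only.

```python
def lipinski_violations(props):
    """Count Lipinski rule-of-five violations for one property dict."""
    checks = [
        props["molecular_weight"] > 500,
        props["logp"] > 5,
        props["hbd"] > 5,
        props["hba"] > 10,
    ]
    return sum(checks)  # True counts as 1


# Illustrative values only, not computed from real molecules.
small = {"molecular_weight": 180.2, "logp": 1.3, "hbd": 1, "hba": 4}
large = {"molecular_weight": 720.9, "logp": 6.4, "hbd": 3, "hba": 12}
print(lipinski_violations(small), lipinski_violations(large))
```

Applied to the notebook's data, this could be called as `props_df.apply(lipinski_violations, axis=1)`, assuming the column names shown above.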

## 4. Data Preprocessing

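`DataProcessor` converts the raw table into model-ready records. In the cell below it truncates proteins to 200 residues. The `validate_smiles` and `remove_invalid` flags suggest that rows with unparseable SMILES are dropped, which the before and after counts let you confirm.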

In [None]:
# Initialize data processor
processor = DataProcessor(
    max_protein_length=200,  # Truncate proteins to 200 residues
    validate_smiles=True,
    remove_invalid=True
)

print("Processing data...")
processed_data = processor.process_dataframe(sample_df)

print(f"Original samples: {len(sample_df)}")
print(f"Processed samples: {len(processed_data)}")
print(f"Removed samples: {len(sample_df) - len(processed_data)}")

# Show processed data structure
if len(processed_data) > 0:
    sample_item = processed_data[0]
    print("\nProcessed data structure:")
    for key, value in sample_item.items():
        if hasattr(value, 'shape'):
            print(f"  {key}: {type(value).__name__} {value.shape}")
        else:
            print(f"  {key}: {type(value).__name__} - {str(value)[:50]}...")
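
To make the truncation step concrete, here is a minimal sketch of one common way to turn a protein sequence into a fixed-length integer vector: truncate, map residues to indices, then right-pad. This is an assumption about what such preprocessing can look like, not the actual behavior of `DataProcessor`. The names `encode_protein`, `PAD_INDEX`, and `UNKNOWN_INDEX` are hypothetical.

```python
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
# Index 0 is reserved for padding, and the last index for unknown residues.
PAD_INDEX = 0
UNKNOWN_INDEX = len(AMINO_ACIDS) + 1
RESIDUE_TO_INDEX = {aa: i + 1 for i, aa in enumerate(AMINO_ACIDS)}


def encode_protein(sequence, max_length=200):
    """Truncate to max_length, map residues to integers, right-pad with PAD_INDEX."""
    truncated = sequence[:max_length]
    encoded = [RESIDUE_TO_INDEX.get(ch, UNKNOWN_INDEX) for ch in truncated]
    return encoded + [PAD_INDEX] * (max_length - len(encoded))


print(encode_protein("MKTX", max_length=6))  # → [11, 9, 17, 21, 0, 0]
```

A fixed length keeps every sample the same shape, which simplifies batching at the cost of discarding residues beyond `max_length`.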

## 5. Creating Custom Datasets

Wrap the processed records in a `DTADataset` and batch them with a PyTorch `DataLoader`.

In [None]:
# Create a custom dataset
from torch.utils.data import DataLoader

# Create dataset from processed data
dataset = DTADataset(processed_data)

print(f"Created dataset with {len(dataset)} samples")

# Create data loader
dataloader = DataLoader(
    dataset, 
    batch_size=2, 
    shuffle=True,
    collate_fn=dataset.collate_fn
)

print("Created dataloader with batch size 2")

# Test the dataloader
print("\nTesting dataloader...")
for i, batch in enumerate(dataloader):
    print(f"Batch {i+1}:")
    for key, value in batch.items():
        if hasattr(value, 'shape'):
            print(f"  {key}: {value.shape}")
        elif isinstance(value, list):
            print(f"  {key}: list of {len(value)} items")
        else:
            print(f"  {key}: {type(value).__name__}")
    
    if i >= 1:  # Only show first 2 batches
        break
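
To show what the dataset and loader contribute, the sketch below reproduces the idea without PyTorch: a map-style dataset only needs `__len__` and `__getitem__`, and a loader repeatedly picks indices and passes the selected samples to a collate function. `PairDataset`, `collate`, `iter_batches`, and the record keys are hypothetical names for illustration. The real `DTADataset.collate_fn` may build tensors or graph batches instead of lists.

```python
import random


class PairDataset:
    """Map-style dataset: supports len() and integer indexing."""

    def __init__(self, records):
        self.records = list(records)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, index):
        return self.records[index]


def collate(samples):
    """Turn a list of per-sample dicts into one dict of lists."""
    return {key: [s[key] for s in samples] for key in samples[0]}


def iter_batches(dataset, batch_size, shuffle=False, seed=None):
    """Yield collated batches in index order, or in shuffled order if requested."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        yield collate([dataset[i] for i in indices[start:start + batch_size]])


# Illustrative records with hypothetical keys.
records = [
    {"smiles": s, "affinity": a}
    for s, a in [("CCO", 5.2), ("CC(=O)O", 6.1), ("CC(C)O", 4.8)]
]
for batch in iter_batches(PairDataset(records), batch_size=2):
    print(batch)
```

With three records and a batch size of 2, the last batch holds a single sample, which is also what a `DataLoader` does unless `drop_last=True` is set.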

## 6. Data Augmentation

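Augmentation creates extra training variants from existing samples to improve model robustness. The cell below generates alternative SMILES strings for a compound and noisy variants of a protein sequence.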

In [None]:
# Data augmentation examples
from unified_dta.data import DataAugmenter

# Initialize augmenter
augmenter = DataAugmenter(
    smiles_augmentation=True,
    protein_augmentation=True,
    noise_level=0.1
)

# Example SMILES augmentation (canonical vs random)
original_smiles = "CCO"
print(f"Original SMILES: {original_smiles}")

# Generate augmented versions
augmented_smiles = augmenter.augment_smiles(original_smiles, n_variants=3)
print("Augmented SMILES:")
for i, smi in enumerate(augmented_smiles):
    print(f"  {i+1}. {smi}")

print()

# Protein sequence augmentation (adding noise)
original_protein = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
print(f"Original protein: {original_protein[:30]}...")

augmented_proteins = augmenter.augment_protein(original_protein, n_variants=2)
print("Augmented proteins:")
for i, prot in enumerate(augmented_proteins):
    print(f"  {i+1}. {prot[:30]}...")
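
As a rough illustration of what adding noise to a protein sequence can mean, the sketch below substitutes each residue with a different random amino acid with probability `noise_level`. Treating `noise_level` as a per-residue substitution probability is an assumption made here for illustration. `DataAugmenter` may work differently, and `substitute_residues` is a hypothetical name.

```python
import random

AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"


def substitute_residues(sequence, noise_level=0.1, seed=None):
    """Replace each residue with a different amino acid with probability noise_level."""
    rng = random.Random(seed)
    out = []
    for ch in sequence:
        if rng.random() < noise_level:
            # Pick a replacement that differs from the original residue.
            out.append(rng.choice([aa for aa in AMINO_ACIDS if aa != ch]))
        else:
            out.append(ch)
    return "".join(out)


original = "MKTVRQERLKSIVRILERSK"
variant = substitute_residues(original, noise_level=0.1, seed=0)
changed = sum(a != b for a, b in zip(original, variant))
print(variant, f"({changed} substitutions)")
```

Random substitutions can change a protein's real binding behavior while the affinity label stays the same, so it is worth keeping the noise level low and checking its effect on validation metrics.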

## 7. Data Splitting Strategies

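How you split the data determines what the test score measures. A random split lets the same proteins and drugs appear in both the training and test sets. Protein-based and drug-based splits keep each protein or drug in a single set, so the test set measures generalization to unseen targets or compounds.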

In [None]:
from unified_dta.data import DTASplitter

# Initialize splitter
splitter = DTASplitter()

# Random split (standard)
train_data, val_data, test_data = splitter.random_split(
    processed_data, 
    train_ratio=0.7, 
    val_ratio=0.15, 
    test_ratio=0.15,
    random_state=42
)

print("Random Split:")
print(f"  Train: {len(train_data)} samples")
print(f"  Validation: {len(val_data)} samples")
print(f"  Test: {len(test_data)} samples")
print()

# Protein-based split (no protein overlap between sets)
try:
    train_prot, val_prot, test_prot = splitter.protein_split(
        processed_data,
        train_ratio=0.7,
        val_ratio=0.15,
        test_ratio=0.15,
        random_state=42
    )
    
    print("Protein-based Split:")
    print(f"  Train: {len(train_prot)} samples")
    print(f"  Validation: {len(val_prot)} samples")
    print(f"  Test: {len(test_prot)} samples")
except Exception as e:
    print(f"Protein-based split not possible with current data: {e}")

print()

# Drug-based split (no drug overlap between sets)
try:
    train_drug, val_drug, test_drug = splitter.drug_split(
        processed_data,
        train_ratio=0.7,
        val_ratio=0.15,
        test_ratio=0.15,
        random_state=42
    )
    
    print("Drug-based Split:")
    print(f"  Train: {len(train_drug)} samples")
    print(f"  Validation: {len(val_drug)} samples")
    print(f"  Test: {len(test_drug)} samples")
except Exception as e:
    print(f"Drug-based split not possible with current data: {e}")
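
The protein-based split can be sketched in a few lines: shuffle the unique proteins, assign each protein to one subset, then route every record to its protein's subset. This is a minimal standalone version under the assumption that records are dictionaries with a `target_sequence` key. It is not the implementation of `DTASplitter.protein_split`, and the function below is hypothetical.

```python
import random


def protein_split(records, train_ratio=0.7, val_ratio=0.15, seed=42, key="target_sequence"):
    """Split records so that each protein appears in exactly one subset."""
    proteins = sorted({r[key] for r in records})  # sort for reproducibility
    random.Random(seed).shuffle(proteins)
    n_train = int(len(proteins) * train_ratio)
    n_val = int(len(proteins) * val_ratio)

    # Assign whole proteins, not individual records, to subsets.
    assignment = {}
    for i, protein in enumerate(proteins):
        if i < n_train:
            assignment[protein] = "train"
        elif i < n_train + n_val:
            assignment[protein] = "val"
        else:
            assignment[protein] = "test"

    splits = {"train": [], "val": [], "test": []}
    for r in records:
        splits[assignment[r[key]]].append(r)
    return splits["train"], splits["val"], splits["test"]


# Illustrative records: 10 hypothetical proteins with 3 compounds each.
records = [
    {"target_sequence": f"PROT{p}", "compound_iso_smiles": "C" * (c + 1)}
    for p in range(10) for c in range(3)
]
train, val, test = protein_split(records)
print(len(train), len(val), len(test))
```

The ratios apply to proteins rather than samples, so subset sizes can drift from the requested proportions when proteins have uneven numbers of measurements. With very few distinct proteins, such as the single protein in the fallback sample data, this sketch places every record in one subset.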

## 8. Data Export and Saving

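Save the processed records, the random-split subsets, and a small metadata file so that later notebooks can reuse them without reprocessing.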

In [None]:
import pickle
import json

# Save processed data
output_dir = "processed_data"
os.makedirs(output_dir, exist_ok=True)

# Save as pickle (preserves all data types)
with open(f"{output_dir}/processed_data.pkl", "wb") as f:
    pickle.dump(processed_data, f)
print(f"✓ Saved processed data to {output_dir}/processed_data.pkl")

# Save splits
splits = {
    'train': train_data,
    'val': val_data,
    'test': test_data
}

for split_name, split_data in splits.items():
    with open(f"{output_dir}/{split_name}_data.pkl", "wb") as f:
        pickle.dump(split_data, f)
    print(f"✓ Saved {split_name} split: {len(split_data)} samples")

# Save metadata
metadata = {
    'total_samples': len(processed_data),
    'train_samples': len(train_data),
    'val_samples': len(val_data),
    'test_samples': len(test_data),
    'max_protein_length': 200,
    'processing_date': pd.Timestamp.now().isoformat()
}

with open(f"{output_dir}/metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)
print(f"✓ Saved metadata to {output_dir}/metadata.json")

print(f"\nAll processed data saved to '{output_dir}/' directory")
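
After saving, a quick reload confirms that the files can be read back and agree with the metadata. The sketch below performs that round trip in a temporary directory with stand-in records, so it runs without the rest of the notebook. The record contents are illustrative only.

```python
import json
import os
import pickle
import tempfile

# Illustrative records standing in for processed_data.
records = [{"compound_iso_smiles": "CCO", "affinity": 5.2}]
metadata = {"total_samples": len(records), "max_protein_length": 200}

with tempfile.TemporaryDirectory() as tmp_dir:
    data_path = os.path.join(tmp_dir, "processed_data.pkl")
    meta_path = os.path.join(tmp_dir, "metadata.json")

    with open(data_path, "wb") as f:
        pickle.dump(records, f)
    with open(meta_path, "w") as f:
        json.dump(metadata, f, indent=2)

    # Reload both files while the temporary directory still exists.
    with open(data_path, "rb") as f:
        reloaded = pickle.load(f)
    with open(meta_path) as f:
        reloaded_meta = json.load(f)

consistent = len(reloaded) == reloaded_meta["total_samples"]
print(f"Reloaded {len(reloaded)} records, metadata consistent: {consistent}")
```

In the notebook itself, the same check would read from `output_dir` instead of a temporary directory.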

## Summary

In this tutorial, you learned how to:

1. **Load standard datasets** (KIBA, Davis, BindingDB)
2. **Analyze data quality** and identify issues
3. **Calculate molecular properties** for compounds
4. **Preprocess data** for model training
5. **Create custom datasets** and data loaders
6. **Apply data augmentation** techniques
7. **Use different splitting strategies** (random, protein-based, drug-based)
8. **Save and export** processed data

## Next Steps

- **Model Training**: See `03_model_training.ipynb` to learn how to train models with your processed data
- **Advanced Configuration**: Explore `04_advanced_configuration.ipynb` for custom model architectures
- **Performance Optimization**: Check `05_performance_optimization.ipynb` for memory and speed optimization

## 💡 Tips for Success

- Always validate your SMILES strings before training
- Consider protein sequence length limits for memory efficiency
- Choose a splitting strategy that matches your evaluation goal (for example, a protein-based split to test generalization to unseen targets)
- Save processed data to avoid recomputation
- Monitor data quality throughout your pipeline