# Drug-Drug Interaction - Exploratory Data Analysis

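This notebook explores the DrugBank drug-drug interaction (DDI) dataset to understand:
- Dataset statistics
- Interaction type distribution
- Molecular properties of drugs
- Graph structure characteristics
- Sample molecular structures
- How frequently each drug appears in interactions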

In [None]:
import sys
from collections import Counter
from pathlib import Path

sys.path.insert(0, '..')

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# RDKit for molecular analysis
from rdkit import Chem
from rdkit.Chem import Descriptors, Draw

# Set style
# NOTE(review): the 'seaborn-v0_8-*' style names appear to require
# matplotlib >= 3.6 (older releases used 'seaborn-whitegrid') - confirm the
# pinned matplotlib version.
plt.style.use('seaborn-v0_8-whitegrid')
# Use seaborn's 'husl' palette for the plots created below.
sns.set_palette('husl')

# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline

## 1. Load Dataset

In [None]:
# Load DrugBank DDI dataset
# Only a missing PyTDC install earns the "pip install" hint. The error is
# re-raised so the notebook stops here instead of failing later with a
# confusing NameError on `df`. Any other failure (download, parsing, ...)
# propagates with its real traceback.
try:
    from tdc.multi_pred import DDI
except ImportError as e:
    print(f"Error loading data: {e}")
    print("Please install PyTDC: pip install PyTDC")
    raise

data = DDI(name='DrugBank')
df = data.get_data()

# Every later cell reads these columns; fail early with a clear message if
# the dataset schema is not what the notebook expects.
required_columns = {'Drug1_ID', 'Drug1', 'Drug2_ID', 'Drug2', 'Y'}
missing_columns = required_columns - set(df.columns)
assert not missing_columns, f"Dataset is missing expected columns: {sorted(missing_columns)}"

print("Dataset loaded successfully!")

In [None]:
# Basic statistics: pair count, distinct drugs per column and overall,
# and the number of distinct interaction labels.
first_ids = df['Drug1_ID']
second_ids = df['Drug2_ID']
every_drug = set(first_ids) | set(second_ids)

stat_lines = [
    f"Total drug pairs: {len(df):,}",
    f"Unique Drug1: {first_ids.nunique():,}",
    f"Unique Drug2: {second_ids.nunique():,}",
    f"Total unique drugs: {len(every_drug):,}",
    f"Interaction types: {df['Y'].nunique()}",
]
print("\n".join(stat_lines))

In [None]:
# Show the first five rows; as the cell's last expression the frame gets
# the notebook's rich table rendering.
df.head(5)

## 2. Interaction Type Distribution

In [None]:
# Count interactions per interaction type (label column 'Y'); value_counts
# returns them ordered from most to least frequent.
interaction_counts = df['Y'].value_counts()

# Slice once and size the bar positions from the slice itself: hard-coding
# 20 positions raises a shape-mismatch error when there are fewer classes.
top_interactions = interaction_counts.head(20)
top_positions = range(len(top_interactions))

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Top interactions (at most 20), most frequent at the top of the chart
axes[0].barh(top_positions, top_interactions.values)
axes[0].set_yticks(top_positions)
axes[0].set_yticklabels([f"Type {i}" for i in top_interactions.index])
axes[0].set_xlabel('Count')
axes[0].set_title(f'Top {len(top_interactions)} Interaction Types')
axes[0].invert_yaxis()

# Histogram of class sizes; the y axis is log-scaled so sparsely populated
# bins stay visible next to the dominant ones.
axes[1].hist(interaction_counts.values, bins=50, edgecolor='black', alpha=0.7)
axes[1].set_xlabel('Number of Samples per Class')
axes[1].set_ylabel('Number of Classes')
axes[1].set_title('Class Size Distribution')
axes[1].set_yscale('log')

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
Path('../paper/figures').mkdir(parents=True, exist_ok=True)
plt.savefig('../paper/figures/class_distribution.png', dpi=150)
plt.show()

In [None]:
# Class imbalance statistics: extremes, median, and the max/min ratio
# of the per-class sample counts.
largest_class = interaction_counts.max()
smallest_class = interaction_counts.min()
typical_class = interaction_counts.median()

print(f"Max class size: {largest_class:,}")
print(f"Min class size: {smallest_class:,}")
print(f"Median class size: {typical_class:,.0f}")
print(f"Imbalance ratio: {largest_class / smallest_class:.1f}x")

## 3. Molecular Properties Analysis

In [None]:
# Get unique SMILES from both sides of every pair. Missing values are dropped
# so RDKit is never handed a NaN.
smiles_column = pd.concat([df['Drug1'], df['Drug2']], ignore_index=True).dropna()
all_smiles = set(smiles_column.unique())
print(f"Total unique molecules: {len(all_smiles)}")

# Sample for speed. Slicing a set directly is not reproducible: the iteration
# order of a set of strings changes between kernel sessions, so each run would
# analyse different molecules. Sort first, then draw a seeded random sample.
MAX_MOLECULES = 1000
SAMPLE_SEED = 42
candidate_smiles = sorted(all_smiles)
if len(candidate_smiles) > MAX_MOLECULES:
    candidate_smiles = (
        pd.Series(candidate_smiles)
        .sample(n=MAX_MOLECULES, random_state=SAMPLE_SEED)
        .tolist()
    )

# Calculate molecular properties
properties = []
valid_smiles = []

for smiles in candidate_smiles:
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # SMILES that RDKit cannot parse are left out of the analysis.
        continue
    properties.append({
        'smiles': smiles,
        'molecular_weight': Descriptors.MolWt(mol),
        'logp': Descriptors.MolLogP(mol),
        'tpsa': Descriptors.TPSA(mol),
        'num_atoms': mol.GetNumAtoms(),
        'num_bonds': mol.GetNumBonds(),
        'num_rings': Descriptors.RingCount(mol),
        'num_hba': Descriptors.NumHAcceptors(mol),
        'num_hbd': Descriptors.NumHDonors(mol),
        'num_rotatable': Descriptors.NumRotatableBonds(mol),
    })
    # Kept in step with `properties`: later cells reuse the parseable SMILES.
    valid_smiles.append(smiles)

props_df = pd.DataFrame(properties)
print(f"Valid molecules: {len(props_df)}")

In [None]:
# Property distributions: one histogram per descriptor on a 2x3 grid, each
# with a dashed red line marking the median.
fig, axes = plt.subplots(2, 3, figsize=(14, 8))

# Parallel lists: titles[i] is the x-axis label for props_to_plot[i].
props_to_plot = ['molecular_weight', 'logp', 'tpsa', 'num_atoms', 'num_rings', 'num_rotatable']
titles = ['Molecular Weight', 'LogP', 'TPSA', 'Number of Atoms', 'Number of Rings', 'Rotatable Bonds']

for ax, prop, title in zip(axes.flat, props_to_plot, titles):
    values = props_df[prop]
    ax.hist(values, bins=30, edgecolor='black', alpha=0.7)
    ax.set_xlabel(title)
    ax.set_ylabel('Count')
    ax.axvline(values.median(), color='red', linestyle='--', label='Median')
    ax.legend()

plt.suptitle('Molecular Property Distributions', fontsize=14)
plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
Path('../paper/figures').mkdir(parents=True, exist_ok=True)
plt.savefig('../paper/figures/molecular_properties.png', dpi=150)
plt.show()

In [None]:
# Property statistics
props_df.describe()

## 4. Graph Structure Analysis

In [None]:
from src.data.featurizers import smiles_to_graph

# Featurise at most the first 500 parseable molecules. Molecules for which
# smiles_to_graph returns None are left out of the statistics.
# NOTE(review): if smiles_to_graph stores every bond as two directed edges,
# `num_edges` is twice the bond count - confirm against the featurizer.
candidate_graphs = (smiles_to_graph(smi) for smi in valid_smiles[:500])

graph_stats = [
    {
        'num_nodes': g.num_nodes,
        'num_edges': g.num_edges,
        # Guard against a graph without nodes to avoid dividing by zero.
        'avg_degree': g.num_edges / g.num_nodes if g.num_nodes > 0 else 0,
        'node_feature_dim': g.x.shape[1],
    }
    for g in candidate_graphs
    if g is not None
]

graph_df = pd.DataFrame(graph_stats)
print(f"Graphs analyzed: {len(graph_df)}")

In [None]:
# Graph statistics: histograms of node count, edge count and average degree.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# (column, x label, title, extra hist kwargs) for each panel. The first panel
# passes no colour and therefore keeps the default colour cycle.
panels = [
    ('num_nodes', 'Number of Nodes (Atoms)', 'Graph Size Distribution', {}),
    ('num_edges', 'Number of Edges (Bonds)', 'Edge Count Distribution', {'color': 'orange'}),
    ('avg_degree', 'Average Degree', 'Average Node Degree', {'color': 'green'}),
]

for ax, (column, xlabel, title, hist_kwargs) in zip(axes, panels):
    ax.hist(graph_df[column], bins=30, edgecolor='black', alpha=0.7, **hist_kwargs)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    ax.set_title(title)

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
Path('../paper/figures').mkdir(parents=True, exist_ok=True)
plt.savefig('../paper/figures/graph_statistics.png', dpi=150)
plt.show()

In [None]:
# Graph statistics summary: feature width of the first analysed graph plus
# the mean node count, edge count and degree over all analysed graphs.
feature_dim = graph_df['node_feature_dim'].iloc[0]
mean_nodes = graph_df['num_nodes'].mean()
mean_edges = graph_df['num_edges'].mean()
mean_degree = graph_df['avg_degree'].mean()

print("Graph Statistics:")
print(f"  Node feature dimension: {feature_dim}")
print(f"  Avg nodes per graph: {mean_nodes:.1f}")
print(f"  Avg edges per graph: {mean_edges:.1f}")
print(f"  Avg degree: {mean_degree:.2f}")

## 5. Sample Molecular Visualizations

In [None]:
# Visualize sample molecules: the first nine parseable SMILES on a 3x3 grid.
sample_smiles = valid_smiles[:9]
mols = [Chem.MolFromSmiles(s) for s in sample_smiles]

# NOTE(review): `.save` below assumes MolsToGridImage returns a PIL image;
# confirm this still holds if rdkit.Chem.Draw.IPythonConsole is ever imported.
img = Draw.MolsToGridImage(mols, molsPerRow=3, subImgSize=(300, 300))
# Saving does not create missing directories, so make sure the target exists.
Path('../paper/figures').mkdir(parents=True, exist_ok=True)
img.save('../paper/figures/sample_molecules.png')
img

## 6. Drug Frequency Analysis

In [None]:
# Drug frequency in interactions: count how often each drug ID appears on
# either side of a pair.
drug_counts = Counter(df['Drug1_ID'].tolist() + df['Drug2_ID'].tolist())
drug_freq = pd.Series(drug_counts)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Most frequent drugs (at most 20). Bar positions are sized from the
# selection itself: hard-coding 20 positions raises a shape-mismatch error
# when the dataset contains fewer than 20 drugs.
top_drugs = drug_freq.nlargest(20)
top_positions = range(len(top_drugs))
axes[0].barh(top_positions, top_drugs.values)
axes[0].set_yticks(top_positions)
axes[0].set_yticklabels(top_drugs.index)
axes[0].set_xlabel('Number of Interactions')
axes[0].set_title(f'Top {len(top_drugs)} Most Frequent Drugs')
axes[0].invert_yaxis()

# Frequency distribution; the y axis is log-scaled so sparsely populated
# bins stay visible.
axes[1].hist(drug_freq.values, bins=50, edgecolor='black', alpha=0.7)
axes[1].set_xlabel('Number of Interactions per Drug')
axes[1].set_ylabel('Number of Drugs')
axes[1].set_title('Drug Interaction Frequency Distribution')
axes[1].set_yscale('log')

plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
Path('../paper/figures').mkdir(parents=True, exist_ok=True)
plt.savefig('../paper/figures/drug_frequency.png', dpi=150)
plt.show()

## 7. Summary Statistics

In [None]:
# Final report: recompute the headline numbers from the frames built in the
# earlier cells and print them between two 60-character rules.
rule = "=" * 60
n_pairs = len(df)
n_drugs = len(set(df['Drug1_ID']) | set(df['Drug2_ID']))
n_types = df['Y'].nunique()
imbalance = interaction_counts.max() / interaction_counts.min()

report = [
    rule,
    "DATASET SUMMARY",
    rule,
    "\nDataset: DrugBank DDI",
    f"Total drug pairs: {n_pairs:,}",
    f"Unique drugs: {n_drugs:,}",
    f"Interaction types: {n_types}",
    f"\nClass imbalance ratio: {imbalance:.1f}x",
    "\nGraph characteristics:",
    f"  - Average nodes: {graph_df['num_nodes'].mean():.1f}",
    f"  - Average edges: {graph_df['num_edges'].mean():.1f}",
    f"  - Feature dimension: {graph_df['node_feature_dim'].iloc[0]}",
    "\nMolecular properties (median):",
    f"  - Molecular weight: {props_df['molecular_weight'].median():.1f}",
    f"  - LogP: {props_df['logp'].median():.2f}",
    f"  - TPSA: {props_df['tpsa'].median():.1f}",
    rule,
]
print("\n".join(report))