# MIMIC-IV Table2Graph Training Pipeline
## Semantic Relationship Detection with GNN

This notebook trains the Table2Graph pipeline on MIMIC-IV healthcare data to learn semantic relationships between table columns.

## Setup & Installation

In [None]:
# Install dependencies
!pip install -q sentence-transformers torch-geometric scikit-learn

print("✓ Dependencies installed")

In [None]:
# Mount Google Drive so checkpoints are written to Drive rather than the local runtime disk
import os

from google.colab import drive

drive.mount('/content/drive')

# Create the checkpoint folder on Drive (no error if it already exists)
CHECKPOINT_DIR = os.path.join('/content/drive/MyDrive', 'table2graph_checkpoints')
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print("✓ Checkpoints will be saved to: " + CHECKPOINT_DIR)

## Upload Files

**Upload these to `/content/` in Colab** (the paths the next cell checks):
1. `table2graph_sem.py` — main pipeline
2. `gcn_conv.py` — TableGCN implementation
3. `hosp/` folder — MIMIC-IV table CSV files

In [None]:
# Verify uploads: every later cell depends on these, so stop here with a clear
# message instead of failing later with an ImportError / FileNotFoundError.
import os

required_files = ['table2graph_sem.py', 'gcn_conv.py']
missing_uploads = []

for fname in required_files:
    if os.path.exists(f'/content/{fname}'):
        print(f"✓ {fname} found")
    else:
        print(f"✗ {fname} MISSING - please upload!")
        missing_uploads.append(fname)

if os.path.exists('/content/hosp'):
    csv_count = len([name for name in os.listdir('/content/hosp') if name.endswith('.csv')])
    print(f"✓ hosp/ folder found with {csv_count} CSV files")
else:
    print("✗ hosp/ folder MISSING - please upload!")
    missing_uploads.append('hosp/')

if missing_uploads:
    raise FileNotFoundError(
        f"Missing uploads in /content: {', '.join(missing_uploads)}"
    )

## Import Pipeline

In [None]:
import sys

# Make the modules uploaded to /content importable
sys.path.extend(['/content'])

from table2graph_sem import (
    Table2GraphPipeline,
    ColumnContentExtractor,
    LightweightFeatureTokenizer,
    RelationshipGenerator,
    SemanticLabelGenerator,
    GraphBuilder,
    GNNEdgePredictor,
)

print("✓ Pipeline imported successfully")

## Load MIMIC-IV Data

In [None]:
import pandas as pd
import numpy as np

def load_mimic_tables(hosp_dir, max_rows=500, extensions=('.csv', '.csv.gz')):
    """Load MIMIC-IV tables from ``hosp_dir``, keeping at most ``max_rows`` rows each.

    Files are read in sorted filename order, so the returned dict (and therefore
    the table order used downstream) is the same on every run.

    Args:
        hosp_dir: Directory containing the table files.
        max_rows: Row cap per table, passed to ``pd.read_csv(nrows=...)`` to
            bound memory use.
        extensions: Sequence of filename suffixes to load. ``.csv.gz`` is
            accepted because pandas infers gzip compression from the suffix.

    Returns:
        Dict mapping table name (filename minus its suffix) to a DataFrame.
        Tables with fewer than two columns or no rows are skipped. Files that
        fail to parse are reported and skipped. If two files map to the same
        table name, only the first (in sorted order) is kept.
    """
    suffixes = tuple(extensions)
    tables = {}

    data_files = sorted(name for name in os.listdir(hosp_dir) if name.endswith(suffixes))
    print(f"Found {len(data_files)} CSV files\n")

    for data_file in data_files:
        # Strip the longest matching suffix, so 'x.csv.gz' -> 'x'
        suffix = max((ext for ext in suffixes if data_file.endswith(ext)), key=len)
        table_name = data_file[:-len(suffix)]
        filepath = os.path.join(hosp_dir, data_file)

        if table_name in tables:
            print(f"  - {table_name}: already loaded, skipping duplicate {data_file}")
            continue

        try:
            df = pd.read_csv(filepath, nrows=max_rows, low_memory=False)
        except Exception as e:
            # Best-effort: one unreadable file should not abort the whole load
            print(f"  ✗ {table_name}: {e}")
            continue

        # Skip tables with no column pair to relate, or no content at all
        if len(df.columns) < 2 or len(df) == 0:
            print(f"  - {table_name}: skipped ({df.shape[0]} × {df.shape[1]})")
            continue

        tables[table_name] = df
        print(f"  ✓ {table_name}: {df.shape[0]} × {df.shape[1]}")

    return tables

# Load every table in the hosp folder, capped at 500 rows each
HOSP_DIR = '/content/hosp'
tables = load_mimic_tables(hosp_dir=HOSP_DIR, max_rows=500)

print("\n✓ Loaded {} tables".format(len(tables)))

## Initialize Pipeline

In [None]:
# Initialize the pipeline for training.
# NODE_DIM is shared by the call and the summary print so they cannot disagree.
NODE_DIM = 512

pipeline = Table2GraphPipeline(embedding_strategy='hybrid')
pipeline.initialize_for_training(model_manager=None, node_dim=NODE_DIM)

print("✓ Pipeline initialized")
print(f"  - Semantic labels: {pipeline.train_builder.num_classes}")
print(f"  - Node dimension: {NODE_DIM}")
# NOTE(review): the layer count below is a hardcoded string, not read from the
# pipeline — confirm it matches the GNN configuration in table2graph_sem.py
print("  - GNN layers: 3")

## Training Configuration

In [None]:
# Training hyperparameters; all values are printed below for the run log
CONFIG = {
    'num_epochs': 50,
    'batch_size': 4,               # tables per train_epoch call
    'early_stopping_patience': 10,  # epochs without accuracy improvement before stopping
    'checkpoint_every': 5,         # save a checkpoint every N epochs
    'seed': 42,                    # RNG seed for reproducible runs
}

print("Training Configuration:")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

## Training Loop

In [None]:
import time
import torch

table_list = list(tables.values())
table_names = list(tables.keys())

# Seed the shuffle RNG and torch so the table order (and torch-side randomness
# during training) is reproducible. Falls back to 42 if CONFIG has no 'seed'.
SEED = CONFIG.get('seed', 42)
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

batch_size = CONFIG['batch_size']
history = {'epoch': [], 'loss': [], 'accuracy': [], 'time': []}
best_accuracy = 0.0
patience_counter = 0

print(f"Training on {len(table_list)} tables\n")
print("=" * 60)

for epoch in range(CONFIG['num_epochs']):
    epoch_start = time.time()

    print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
    print("-" * 40)

    # Shuffle tables; names are permuted in lockstep so failures can be reported
    indices = rng.permutation(len(table_list))
    shuffled_tables = [table_list[i] for i in indices]
    shuffled_names = [table_names[i] for i in indices]

    # Train in batches
    epoch_losses = []
    epoch_accuracies = []

    for batch_idx in range(0, len(shuffled_tables), batch_size):
        batch_tables = shuffled_tables[batch_idx:batch_idx + batch_size]
        batch_names = shuffled_names[batch_idx:batch_idx + batch_size]
        batch_num = batch_idx // batch_size + 1

        try:
            # NOTE(review): train_epoch is invoked once per batch, so each call presumably
            # runs one training pass over just `batch_tables` — confirm in table2graph_sem.py
            avg_loss, avg_accuracy = pipeline.train_epoch(batch_tables)
            epoch_losses.append(avg_loss)
            epoch_accuracies.append(avg_accuracy)

            print(f"  Batch {batch_num}: "
                  f"Loss={avg_loss:.4f}, Acc={avg_accuracy:.3f}")
        except Exception as e:
            # Best-effort: report which tables broke and continue with the next batch
            print(f"  ✗ Batch {batch_num} failed ({', '.join(batch_names)}): {e}")

    if not epoch_losses:
        # No batch succeeded: count this as a non-improving epoch so a persistently
        # broken pipeline reaches early stopping instead of silently using every epoch
        patience_counter += 1
        print(f"\n⚠ Epoch {epoch+1}: no batch trained successfully")
    else:
        # Epoch summary
        epoch_loss = np.mean(epoch_losses)
        epoch_accuracy = np.mean(epoch_accuracies)
        epoch_time = time.time() - epoch_start

        history['epoch'].append(epoch + 1)
        history['loss'].append(float(epoch_loss))
        history['accuracy'].append(float(epoch_accuracy))
        history['time'].append(float(epoch_time))

        print(f"\nEpoch Summary: Loss={epoch_loss:.4f}, Acc={epoch_accuracy:.3f}, Time={epoch_time:.2f}s")

        # Periodic checkpointing
        if (epoch + 1) % CONFIG['checkpoint_every'] == 0:
            checkpoint_path = f"{CHECKPOINT_DIR}/checkpoint_epoch_{epoch+1}.pt"
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': pipeline.predictor.state_dict(),
                'loss': epoch_loss,
                'accuracy': epoch_accuracy,
            }, checkpoint_path)
            print("✓ Checkpoint saved")

        # Track the best epoch by mean training accuracy
        if epoch_accuracy > best_accuracy:
            best_accuracy = epoch_accuracy
            patience_counter = 0
            torch.save(pipeline.predictor.state_dict(), f"{CHECKPOINT_DIR}/best_model.pt")
            print(f"✓ New best: {best_accuracy:.3f}")
        else:
            patience_counter += 1

    # Early stopping
    if patience_counter >= CONFIG['early_stopping_patience']:
        print(f"\n⚠ Early stopping (no improvement for {patience_counter} epochs)")
        break

print("\n" + "=" * 60)
print(f"Training Complete! Best Accuracy: {best_accuracy:.3f}")
print("=" * 60)

## Plot Training History

In [None]:
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# One panel per tracked metric: (axes, history key, line format, label/ylabel, title)
panel_specs = (
    (ax1, 'loss', 'b-', 'Loss', 'Training Loss'),
    (ax2, 'accuracy', 'g-', 'Accuracy', 'Training Accuracy'),
)
for ax, metric, fmt, label, title in panel_specs:
    ax.plot(history['epoch'], history[metric], fmt, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.set_title(title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()
curves_path = f'{CHECKPOINT_DIR}/training_curves.png'
plt.savefig(curves_path, dpi=300, bbox_inches='tight')
plt.show()

print("✓ Training curves saved")

## Test Predictions

In [None]:
# Switch the pipeline into inference mode
pipeline.initialize_for_testing()

# NOTE(review): 'admissions' is also among the tables passed to training above,
# so these are in-sample predictions, not a held-out evaluation.
test_table = 'admissions'
MAX_SHOWN = 10
MEANING_CHARS = 80

# Default to an empty list so the export cell still works if the table is absent
predictions = []

if test_table in tables:
    test_df = tables[test_table]

    print(f"Testing on: {test_table}")
    print(f"Shape: {test_df.shape}")
    print(f"Columns: {list(test_df.columns)}\n")

    predictions = pipeline.predict_relationships(test_df)

    print(f"✓ Predicted {len(predictions)} relationships\n")
    print(f"Top {MAX_SHOWN} Predictions:")
    print("=" * 80)

    for rank, pred in enumerate(predictions[:MAX_SHOWN], start=1):
        meaning = pred['semantic_meaning']
        # Only mark truncation when text was actually cut
        if len(meaning) > MEANING_CHARS:
            meaning = meaning[:MEANING_CHARS] + "..."
        print(f"{rank}. {pred['col1']} ↔ {pred['col2']}")
        print(f"   Label: {pred['predicted_label']}")
        print(f"   Confidence: {pred['confidence']:.3f}")
        print(f"   Meaning: {meaning}")
        print()
else:
    print(f"✗ Table '{test_table}' was not loaded - skipping predictions")

## Export Results

In [None]:
import json


def _json_default(obj):
    """Convert array-like or scalar objects (e.g. numpy/torch values) to JSON-native types.

    json.dump only calls this for objects it cannot serialize itself, so plain
    Python values are written exactly as before. Anything else still raises
    TypeError, matching json's own behavior.
    """
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    if hasattr(obj, 'item'):
        return obj.item()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Save training history
with open(f'{CHECKPOINT_DIR}/history.json', 'w') as f:
    json.dump(history, f, indent=2, default=_json_default)

# Save predictions. They may carry numpy/torch scalars (e.g. confidence) — TODO confirm
with open(f'{CHECKPOINT_DIR}/predictions_{test_table}.json', 'w') as f:
    json.dump(predictions, f, indent=2, default=_json_default)

print("✓ Results exported to Google Drive")
print(f"  Location: {CHECKPOINT_DIR}")