In [1]:
import torch
import torch.nn as nn
import json
from torch_geometric.loader import DataLoader
from pathlib import Path
import numpy as np
from tqdm.notebook import tqdm
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK

# Import your existing modules
from src.models import CoGraphNet
from src.data.document_dataset import DocumentGraphDataset
from src.train import FocalLoss

# -------------------------
# Utility Functions
# -------------------------
def get_all_categories(train_dir: str, test_dir: str):
    """Collect every document category that appears in the two directories.

    Each ``*.json`` file located directly in ``train_dir`` or ``test_dir`` is
    parsed. A document contributes its ``category`` value only if it has both
    a ``text`` and a ``category`` key, and both values are non-blank once
    surrounding whitespace is stripped. The category is stored exactly as
    written in the file (unstripped).

    Returns:
        A ``set`` of category strings.
    """
    found = set()
    for directory in (train_dir, test_dir):
        for json_path in Path(directory).glob('*.json'):
            with open(json_path, 'r', encoding='utf-8') as handle:
                record = json.load(handle)
            # Both keys must exist before either value is inspected.
            if 'text' not in record or 'category' not in record:
                continue
            # Skip documents whose text or category is empty/whitespace-only.
            if not record['text'].strip() or not record['category'].strip():
                continue
            found.add(record['category'])
    return found

def create_dataloaders(
    root: str,
    train_dir: str,
    test_dir: str,
    batch_size: int,
    seed: int = 42,
    val_fraction: float = 0.2,
    num_workers: int = 4,
):
    """Build train/validation DataLoaders from the training documents.

    The label map (category -> index) is built from the sorted union of
    categories in ``train_dir`` and ``test_dir``, so indices are stable and
    cover every class. Only the training documents are loaded. They are split
    into train/validation subsets with a seeded generator, so repeated calls
    (e.g. one per hyperopt trial) produce the same split and validation
    scores stay comparable.

    Args:
        root: Base directory; ``f"{root}/train"`` is passed as the first
            argument of ``DocumentGraphDataset`` (presumably its processed
            graph cache; confirm against the dataset class).
        train_dir: Directory of raw training ``*.json`` documents.
        test_dir: Directory of raw test ``*.json`` documents (used only for
            the label map).
        batch_size: Graphs per batch for both loaders.
        seed: Seed for the train/validation split.
        val_fraction: Fraction of the training documents held out for
            validation; must be in ``[0, 1)``.
        num_workers: Worker processes per DataLoader.

    Returns:
        ``(train_loader, val_loader, num_classes, category_to_idx)``.

    Raises:
        ValueError: if ``val_fraction`` is outside ``[0, 1)``.
    """
    if not 0.0 <= val_fraction < 1.0:
        raise ValueError(f"val_fraction must be in [0, 1), got {val_fraction}")

    all_categories = get_all_categories(train_dir, test_dir)
    category_to_idx = {cat: idx for idx, cat in enumerate(sorted(all_categories))}
    num_classes = len(category_to_idx)

    full_train_dataset = DocumentGraphDataset(
        f"{root}/train",
        train_dir,
        category_to_idx=category_to_idx,
    )

    dataset_size = len(full_train_dataset)
    val_size = int(dataset_size * val_fraction)
    train_size = dataset_size - val_size
    # Without an explicit generator, random_split draws from the global RNG and
    # every call yields a different split. Fixing the seed keeps the held-out
    # set identical across hyperparameter trials.
    split_generator = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_train_dataset, [train_size, val_size], generator=split_generator
    )

    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only machines.
    pin_memory = torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader, num_classes, category_to_idx

def _class_weights(dataset, num_classes: int, device: torch.device) -> torch.Tensor:
    """Return inverse-frequency weights ``N / (num_classes * count[c])`` for ``dataset``'s labels.

    Each item is expected to carry a single-element ``y`` (it is read with
    ``.item()``). Counts are clamped to at least 1. The label map can contain
    categories absent from this split: it is built from train *and* test
    files, and the split is random. An unclamped zero count would give an
    ``inf`` weight.
    """
    # Iterating the dataset touches every sample once.
    labels = torch.tensor([data.y.item() for data in dataset], dtype=torch.long)
    counts = torch.bincount(labels, minlength=num_classes).clamp(min=1)
    return (len(labels) / (num_classes * counts.float())).to(device)


def _forward(model, batch, device):
    """Move ``batch`` to ``device`` once, run ``model`` and return ``(logits, targets)``.

    Logits are truncated to the number of graph labels in the batch
    (``outputs[:batch.y.size(0)]``), as the training/validation loops expect.
    """
    batch = batch.to(device)
    word = batch['word']
    word_rel = batch['word', 'co_occurs', 'word']
    sent = batch['sentence']
    sent_rel = batch['sentence', 'related_to', 'sentence']
    outputs = model(
        word.x, word_rel.edge_index, word.batch, word_rel.edge_attr,
        sent.x, sent_rel.edge_index, sent.batch, sent_rel.edge_attr,
    )
    targets = batch.y
    return outputs[:targets.size(0)], targets


def objective(params):
    """Hyperopt objective: briefly train a CoGraphNet and score it on the validation split.

    The model is trained for a few epochs with ``FocalLoss``, using the trial's
    ``gamma`` and inverse-frequency class weights from the training split. It is
    then scored on the validation split with class-weighted cross-entropy; that
    value is what ``fmin`` minimizes.

    Why validation is not scored with focal loss: ``gamma`` is one of the tuned
    hyperparameters. Assuming ``FocalLoss`` uses the standard ``(1 - p_t) ** gamma``
    modulation (confirm in ``src.train``), a larger gamma lowers the loss for
    identical predictions. Focal-loss scores from trials with different gamma
    are therefore not comparable, and minimizing them steers the search toward
    large gamma regardless of model quality.

    Args:
        params: One sample of the search space defined under ``__main__``.
            Numeric values arrive as floats and are cast to ``int`` where the
            model or loader needs integers.

    Returns:
        Hyperopt result dict with ``loss`` (weighted validation cross-entropy),
        ``status`` and ``val_acc`` (plain validation accuracy, kept in ``Trials``
        for inspection).

    Raises:
        ValueError: if the validation split is empty.
    """
    print("Hyperparameters:", params)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # hp.quniform samples are floats; the DataLoader needs an int.
    batch_size = int(params['batch_size'])
    root = "processed_graphs_ohsumed"
    train_dir = "processed_data_ohsumed/train"
    test_dir = "processed_data_ohsumed/test"

    train_loader, val_loader, num_classes, _ = create_dataloaders(root, train_dir, test_dir, batch_size)

    # Per-component dropout switch and rate, keyed the way CoGraphNet consumes them.
    components = ('word', 'sent', 'fusion', 'co_graph', 'final')
    dropout_config = {name: params[f'dropout_{name}_enabled'] for name in components}
    dropout_rate = {name: params[f'dropout_rate_{name}'] for name in components}

    # NOTE(review): the recorded output of this cell shows CoGraphNet rejecting the
    # keyword `num_sentence_layers`; confirm `num_sent_layers` matches CoGraphNet.__init__.
    model = CoGraphNet(
        word_in_channels=768,
        sent_in_channels=768,
        hidden_channels=int(params['hidden_dim']),
        num_word_layers=int(params['num_word_layers']),
        num_sent_layers=int(params['num_sentence_layers']),
        num_classes=num_classes,
        dropout_config=dropout_config,
        dropout_rate=dropout_rate,
    ).to(device)

    optimizer = torch.optim.Adam(
        model.parameters(), lr=params['learning_rate'], weight_decay=params['weight_decay']
    )

    # Each split is weighted by its own label distribution.
    train_class_weights = _class_weights(train_loader.dataset, num_classes, device)
    val_class_weights = _class_weights(val_loader.dataset, num_classes, device)

    train_criterion = FocalLoss(gamma=params['gamma'], weight=train_class_weights)
    # Summed per batch so the exact split-level weighted mean can be formed below.
    val_criterion = nn.CrossEntropyLoss(weight=val_class_weights, reduction='sum')

    num_epochs = 5  # Small on purpose: this is a tuning run, not final training.
    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}", leave=False, ncols=100):
            optimizer.zero_grad()
            logits, targets = _forward(model, batch, device)
            loss = train_criterion(logits, targets)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch+1} Training Loss: {epoch_loss/max(len(train_loader), 1):.4f}")

    model.eval()
    weighted_loss_sum = 0.0
    weight_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in val_loader:
            logits, targets = _forward(model, batch, device)
            weighted_loss_sum += val_criterion(logits, targets).item()
            weight_sum += val_class_weights[targets].sum().item()
            correct += (logits.argmax(dim=1) == targets).sum().item()
            total += targets.size(0)

    if total == 0:
        raise ValueError("Validation split is empty; cannot score this trial.")

    # Equals CrossEntropyLoss(weight=..., reduction='mean') evaluated on the whole split at once.
    val_loss = weighted_loss_sum / weight_sum
    val_acc = correct / total
    print("Validation Loss (weighted CE):", val_loss, "Validation Accuracy:", val_acc)

    # Hyperopt minimizes 'loss'; extra keys are stored in Trials for later inspection.
    return {'loss': val_loss, 'status': STATUS_OK, 'val_acc': val_acc}

if __name__ == "__main__":
    from hyperopt import space_eval

    # Search space for objective(): optimizer settings, focal-loss gamma, model size,
    # batch size, and a per-component dropout on/off flag plus dropout rate.
    space = {
        'learning_rate': hp.loguniform('learning_rate', np.log(1e-5), np.log(1e-3)),
        'gamma': hp.uniform('gamma', 0.5, 5.0),
        'weight_decay': hp.loguniform('weight_decay', np.log(1e-8), np.log(1e-3)),
        'hidden_dim': hp.quniform('hidden_dim', 64, 256, 1),
        'num_word_layers': hp.quniform('num_word_layers', 1, 5, 1),
        'num_sentence_layers': hp.quniform('num_sentence_layers', 1, 5, 1),
        'batch_size': hp.quniform('batch_size', 16, 512, 16),

        # Dropout enable/disable flags
        'dropout_word_enabled': hp.choice('dropout_word_enabled', [True, False]),
        'dropout_sent_enabled': hp.choice('dropout_sent_enabled', [True, False]),
        'dropout_fusion_enabled': hp.choice('dropout_fusion_enabled', [True, False]),
        'dropout_co_graph_enabled': hp.choice('dropout_co_graph_enabled', [True, False]),
        'dropout_final_enabled': hp.choice('dropout_final_enabled', [True, False]),

        # Dropout rates for each component
        'dropout_rate_word': hp.uniform('dropout_rate_word', 0.1, 0.5),
        'dropout_rate_sent': hp.uniform('dropout_rate_sent', 0.1, 0.5),
        'dropout_rate_fusion': hp.uniform('dropout_rate_fusion', 0.1, 0.5),
        'dropout_rate_co_graph': hp.uniform('dropout_rate_co_graph', 0.1, 0.5),
        'dropout_rate_final': hp.uniform('dropout_rate_final', 0.1, 0.5),
    }

    trials = Trials()
    best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=40, trials=trials)
    # fmin reports each hp.choice entry as the *index* of the chosen option
    # (here 0 -> True, 1 -> False for the dropout flags), which reads inverted.
    # space_eval maps the raw result back to the values objective() received.
    print("Best hyperparameters (raw fmin output):", best)
    print("Best hyperparameters:", space_eval(space, best))
    print("Best validation loss:", trials.best_trial['result']['loss'])


Hyperparameters:                                                                                                                                                                                                   
{'batch_size': 496.0, 'dropout_co_graph_enabled': False, 'dropout_final_enabled': False, 'dropout_fusion_enabled': False, 'dropout_rate_co_graph': 0.4643111397148273, 'dropout_rate_final': 0.34970308072328576, 'dropout_rate_fusion': 0.45256714273560494, 'dropout_rate_sent': 0.42843059544507023, 'dropout_rate_word': 0.1774289678106419, 'dropout_sent_enabled': False, 'dropout_word_enabled': False, 'gamma': 4.000674057390592, 'hidden_dim': 79.0, 'learning_rate': 2.0791434687645655e-05, 'num_sentence_layers': 4.0, 'num_word_layers': 5.0, 'weight_decay': 1.6769135932290969e-06}
Loading existing valid indices from metadata                                                                                                                                                                    

job exception: CoGraphNet.__init__() got an unexpected keyword argument 'num_sentence_layers'



  0%|                                                                                                                                                                       | 0/40 [00:06<?, ?trial/s, best loss=?]


TypeError: CoGraphNet.__init__() got an unexpected keyword argument 'num_sentence_layers'