In [2]:
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool
import numpy as np
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import roc_auc_score
import pandas as pd

class RadGraphEmbedder:
    def __init__(self, clinical_bert_model="emilyalsentzer/Bio_ClinicalBERT"):
        """
        Initialize the RadGraph embedder with a clinical BERT model
        """
        self.tokenizer = AutoTokenizer.from_pretrained(clinical_bert_model)
        self.bert_model = AutoModel.from_pretrained(clinical_bert_model)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.bert_model.to(self.device)
        
    def process_radgraph(self, nodes, edges):
        """
        Convert RadGraph nodes and edges into a PyG graph
        
        Args:
            nodes: Dict of node_id -> {text, type, span}
            edges: List of (source_id, target_id, relation_type)
            
        Returns:
            torch_geometric.data.Data object
        """
        # Generate BERT embeddings for each node
        node_embeddings = []
        node_types = []
        
        for node_id in sorted(nodes.keys()):
            node = nodes[node_id]
            # Get BERT embedding for node text
            inputs = self.tokenizer(node['text'], 
                                  return_tensors='pt',
                                  padding=True,
                                  truncation=True,
                                  max_length=128)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            
            with torch.no_grad():
                outputs = self.bert_model(**inputs)
                # Use CLS token embedding as node representation
                node_emb = outputs.last_hidden_state[:, 0, :].cpu()
                node_embeddings.append(node_emb)
            
            # Convert node type to one-hot
            type_mapping = {
                'ANATOMY': 0,
                'OBSERVATION': 1,
                'MEASUREMENT': 2,
                'PROCEDURE': 3,
                'MODIFIER': 4
            }
            node_type = torch.zeros(len(type_mapping))
            node_type[type_mapping[node['type']]] = 1
            node_types.append(node_type)
            
        node_embeddings = torch.cat(node_embeddings, dim=0)
        node_types = torch.stack(node_types)
        
        # Combine BERT embeddings with node type
        node_features = torch.cat([node_embeddings, node_types], dim=1)
        
        # Create edge index and edge attributes
        edge_index = []
        edge_attr = []
        relation_mapping = {
            'SUGGESTIVE_OF': 0,
            'LOCATED_AT': 1, 
            'MODIFY': 2,
            'RELATED_TO': 3
        }
        
        for source, target, rel_type in edges:
            edge_index.append([source, target])
            edge_type = torch.zeros(len(relation_mapping))
            edge_type[relation_mapping[rel_type]] = 1
            edge_attr.append(edge_type)
            
        edge_index = torch.tensor(edge_index).t().contiguous()
        edge_attr = torch.stack(edge_attr)
        
        return Data(x=node_features, 
                   edge_index=edge_index,
                   edge_attr=edge_attr)

class InterpretableGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
        """
        Interpretable GCN for processing RadGraph embeddings
        
        Args:
            in_channels: Input feature dimension
            hidden_channels: Hidden layer dimension
            out_channels: Output dimension (number of clinical outcomes)
            num_layers: Number of GCN layers
        """
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        self.attention = torch.nn.ModuleList()
        
        # Input layer
        self.convs.append(GCNConv(in_channels, hidden_channels))
        self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels))
        self.attention.append(torch.nn.Linear(hidden_channels, 1))
        
        # Hidden layers
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
            self.batch_norms.append(torch.nn.BatchNorm1d(hidden_channels))
            self.attention.append(torch.nn.Linear(hidden_channels, 1))
            
        # Output layer
        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.batch_norms.append(torch.nn.BatchNorm1d(out_channels))
        self.attention.append(torch.nn.Linear(out_channels, 1))
        
    def forward(self, x, edge_index, edge_attr, batch):
        """
        Forward pass with attention weights for interpretability
        """
        attention_weights = []
        
        for conv, batch_norm, attn in zip(self.convs, self.batch_norms, self.attention):
            # Graph convolution
            x = conv(x, edge_index)
            x = batch_norm(x)
            x = torch.relu(x)
            
            # Calculate attention weights
            attn_weight = torch.sigmoid(attn(x))
            attention_weights.append(attn_weight)
            
            # Apply attention
            x = x * attn_weight
            
        # Global pooling
        out = global_mean_pool(x, batch)
        
        return out, attention_weights
    
    def interpret_predictions(self, graph_data, predictions, attention_weights):
        """
        Generate interpretations for predictions
        
        Returns dictionary mapping outcomes to most influential nodes/edges
        """
        interpretations = {}
        
        # Combine attention weights across layers
        combined_attention = torch.cat(attention_weights, dim=1).mean(dim=1)
        
        # Get top-k most important nodes
        k = min(5, len(combined_attention))
        top_nodes = combined_attention.topk(k)
        
        for outcome_idx, pred in enumerate(predictions):
            important_nodes = []
            for node_idx in top_nodes.indices:
                node_importance = {
                    'node_idx': node_idx.item(),
                    'attention': combined_attention[node_idx].item(),
                    'node_features': graph_data.x[node_idx].tolist()
                }
                important_nodes.append(node_importance)
                
            interpretations[f'outcome_{outcome_idx}'] = {
                'prediction': pred.item(),
                'important_nodes': important_nodes
            }
            
        return interpretations

def train_model(model, train_loader, val_loader, num_epochs=100):
    """
    Train the GCN model
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.BCEWithLogitsLoss()
    
    best_val_auc = 0
    patience = 10
    counter = 0
    
    for epoch in range(num_epochs):
        # Training
        model.train()
        total_loss = 0
        
        for batch in train_loader:
            optimizer.zero_grad()
            out, _ = model(batch.x, batch.edge_index, 
                          batch.edge_attr, batch.batch)
            loss = criterion(out, batch.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            
        # Validation
        model.eval()
        val_preds = []
        val_labels = []
        
        with torch.no_grad():
            for batch in val_loader:
                out, _ = model(batch.x, batch.edge_index,
                             batch.edge_attr, batch.batch)
                val_preds.append(torch.sigmoid(out))
                val_labels.append(batch.y)
                
        val_preds = torch.cat(val_preds, dim=0)
        val_labels = torch.cat(val_labels, dim=0)
        val_auc = roc_auc_score(val_labels, val_preds)
        
        # Early stopping
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break
                
        if epoch % 10 == 0:
            print(f"Epoch {epoch}: Loss = {total_loss:.4f}, Val AUC = {val_auc:.4f}")
            
    return model

def predict_outcomes(model, graph_data):
    """
    Generate predictions and interpretations for a single graph
    """
    model.eval()
    with torch.no_grad():
        predictions, attention_weights = model(graph_data.x,
                                            graph_data.edge_index,
                                            graph_data.edge_attr,
                                            graph_data.batch)
        probabilities = torch.sigmoid(predictions)
        
        interpretations = model.interpret_predictions(
            graph_data, probabilities, attention_weights)
        
    return probabilities, interpretations

In [4]:
import networkx as nx
from typing import Dict, List, Tuple, Set
from collections import defaultdict

class RadGraphAnalyzer:
    """
    Analyzes RadGraph extracts according to the official schema for pneumonia classification
    """
    
    ENTITY_TYPES = {
        'Anatomy::definitely present',
        'Observation::definitely present',
        'Observation::uncertain',
        'Observation::definitely absent'
    }
    
    RELATION_TYPES = {
        'suggestive_of',  # Observation -> Observation
        'located_at',     # Observation -> Anatomy
        'modify'          # Observation -> Observation or Anatomy -> Anatomy
    }
    
    def __init__(self):
        self.pneumonia_terms = {
            'consolidation', 'opacity', 'infiltrate', 'pneumonia',
            'airspace disease', 'ground glass', 'patchy', 'focal'
        }
        
    def parse_radgraph_format(self, extract: Dict) -> Tuple[List[Dict], List[Dict]]:
        """
        Converts RadGraph extract format into structured entities and relations
        
        Args:
            extract: RadGraph dictionary extract
        
        Returns:
            Tuple of (entities, relations)
        """
        text = extract['text']
        entities_dict = extract['entities']
        
        entities = []
        relations = []
        
        # Process entities
        for entity_id, entity_info in entities_dict.items():
            entity = {
                'id': entity_id,
                'text': ' '.join(text.split()[entity_info['start_ix']:entity_info['end_ix'] + 1]),
                'type': entity_info['label'],
                'start_ix': entity_info['start_ix'],
                'end_ix': entity_info['end_ix']
            }
            entities.append(entity)
            
            # Process relations
            for rel_type, target_id in entity_info.get('relations', []):
                relation = {
                    'source': entity_id,
                    'target': target_id,
                    'type': rel_type
                }
                relations.append(relation)
        
        return entities, relations

    def analyze_certainty(self, entities: List[Dict]) -> Dict[str, int]:
        """
        Analyzes certainty levels of observations
        """
        certainty_counts = defaultdict(int)
        for entity in entities:
            if '::' in entity['type']:
                certainty_counts[entity['type'].split('::')[1]] += 1
        return dict(certainty_counts)

    def find_pneumonia_patterns(self, entities: List[Dict], relations: List[Dict]) -> Dict:
        """
        Identifies patterns relevant to pneumonia classification
        """
        patterns = {
            'findings': defaultdict(list),
            'locations': defaultdict(list),
            'modifiers': defaultdict(list)
        }
        
        # Create entity lookup
        entity_map = {e['id']: e for e in entities}
        
        for relation in relations:
            source = entity_map[relation['source']]
            target = entity_map[relation['target']]
            
            if relation['type'] == 'located_at':
                if 'Observation' in source['type']:
                    patterns['findings'][source['text']].append(target['text'])
                    
            elif relation['type'] == 'modify':
                if 'Observation' in source['type'] and 'Observation' in target['type']:
                    patterns['modifiers'][target['text']].append(source['text'])
                    
            elif relation['type'] == 'suggestive_of':
                if any(term in source['text'].lower() for term in self.pneumonia_terms):
                    patterns['findings']['suggestive_patterns'].append(
                        (source['text'], target['text'])
                    )
        
        return {k: dict(v) for k, v in patterns.items()}

    def extract_features(self, extract: Dict) -> Dict:
        """
        Extracts features relevant to pneumonia classification
        
        Args:
            extract: RadGraph dictionary extract
            
        Returns:
            Dictionary of extracted features
        """
        entities, relations = self.parse_radgraph_format(extract)
        
        # Create graph representation
        G = nx.DiGraph()
        for entity in entities:
            G.add_node(entity['id'], **entity)
        for relation in relations:
            G.add_edge(relation['source'], relation['target'], type=relation['type'])
            
        # Analyze patterns
        patterns = self.find_pneumonia_patterns(entities, relations)
        certainty = self.analyze_certainty(entities)
        
        # Extract observations and their properties
        observations = defaultdict(lambda: {
            'text': '',
            'certainty': '',
            'locations': [],
            'modifiers': [],
            'suggestions': []
        })
        
        for entity in entities:
            if 'Observation' in entity['type']:
                obs_id = entity['id']
                observations[obs_id]['text'] = entity['text']
                observations[obs_id]['certainty'] = entity['type'].split('::')[1]
                
                # Find connected entities
                for _, neighbor, rel_data in G.edges(obs_id, data=True):
                    neighbor_data = G.nodes[neighbor]
                    if rel_data['type'] == 'located_at':
                        observations[obs_id]['locations'].append(neighbor_data['text'])
                    elif rel_data['type'] == 'modify':
                        observations[obs_id]['modifiers'].append(neighbor_data['text'])
                    elif rel_data['type'] == 'suggestive_of':
                        observations[obs_id]['suggestions'].append(neighbor_data['text'])
        
        features = {
            'observations': dict(observations),
            'patterns': patterns,
            'certainty_analysis': certainty,
            'graph_metrics': {
                'num_entities': len(entities),
                'num_relations': len(relations),
                'num_anatomical_sites': len([e for e in entities if 'Anatomy' in e['type']]),
                'num_observations': len([e for e in entities if 'Observation' in e['type']])
            }
        }
        
        return features

# Example usage
def main():
    # Example RadGraph extract
    sample_extract = {
        'text': 'no evidence of acute cardiopulmonary process moderate hiatal hernia',
        'entities': {
            '1': {'tokens': 'acute', 'label': 'Observation::definitely absent', 
                  'start_ix': 3, 'end_ix': 3, 'relations': [['modify', '3']]},
            '2': {'tokens': 'cardiopulmonary', 'label': 'Anatomy::definitely present', 
                  'start_ix': 4, 'end_ix': 4, 'relations': []},
            '3': {'tokens': 'process', 'label': 'Observation::definitely absent', 
                  'start_ix': 5, 'end_ix': 5, 'relations': [['located_at', '2']]},
            '4': {'tokens': 'moderate', 'label': 'Observation::definitely present', 
                  'start_ix': 6, 'end_ix': 6, 'relations': [['modify', '5']]},
            '5': {'tokens': 'hiatal hernia', 'label': 'Observation::definitely present', 
                  'start_ix': 7, 'end_ix': 8, 'relations': []}
        }
    }
    
    analyzer = RadGraphAnalyzer()
    features = analyzer.extract_features(sample_extract)
    
    print("\nExtracted Features:")
    for key, value in features.items():
        print(f"\n{key}:")
        print(value)

if __name__ == "__main__":
    main()


Extracted Features:

observations:
{'1': {'text': 'acute', 'certainty': 'definitely absent', 'locations': [], 'modifiers': ['process'], 'suggestions': []}, '3': {'text': 'process', 'certainty': 'definitely absent', 'locations': ['cardiopulmonary'], 'modifiers': [], 'suggestions': []}, '4': {'text': 'moderate', 'certainty': 'definitely present', 'locations': [], 'modifiers': ['hiatal hernia'], 'suggestions': []}, '5': {'text': 'hiatal hernia', 'certainty': 'definitely present', 'locations': [], 'modifiers': [], 'suggestions': []}}

patterns:
{'findings': {'process': ['cardiopulmonary']}, 'locations': {}, 'modifiers': {'process': ['acute'], 'hiatal hernia': ['moderate']}}

certainty_analysis:
{'definitely absent': 2, 'definitely present': 3}

graph_metrics:
{'num_entities': 5, 'num_relations': 3, 'num_anatomical_sites': 1, 'num_observations': 4}
