In [1]:
import json
import re
from pathlib import Path
from typing import Dict, List, Set, Tuple, Optional
import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict, Counter
import pandas as pd
from owlready2 import get_ontology
import numpy as np
from itertools import combinations
import random
from datetime import datetime
import warnings
import logging
from tqdm import tqdm
import gc
# Suppress every Python warning for the rest of the session. This keeps the
# notebook output short, but it also hides deprecation notices from libraries.
warnings.filterwarnings('ignore')
# Configure the root logger: INFO and above, timestamped "time - level - message" lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# spaCy is optional: remember whether the import worked so later code can
# consult SPACY_AVAILABLE instead of failing on a missing package.
try:
    import spacy
    SPACY_AVAILABLE = True
except ImportError:
    SPACY_AVAILABLE = False
    logger.warning("spaCy not available. NLP entity extraction will be limited.")
# Let only ERROR-level records from matplotlib's legend logger through.
logging.getLogger('matplotlib.legend').setLevel(logging.ERROR)
class LegalKnowledgeGraphPipeline:    
    def __init__(self, owl_path: str, processed_dir: str = "dataset_processed", 
                 rules_dir: str = "official_documents", cache_dir: str = "entity_cache"):
        """Set up paths, in-memory stores and optional NLP, then load the ontology.

        Args:
            owl_path: Location of the OWL ontology passed to ``get_ontology``.
            processed_dir: Root directory with one sub-directory of case JSON per court.
            rules_dir: Directory that holds the statute PDFs.
            cache_dir: Directory for per-case entity cache files; created
                (including missing parents) if it does not exist.

        Raises:
            Exception: Whatever ``load_ontology`` re-raises when the ontology
                cannot be loaded.
        """
        self.processed_dir = Path(processed_dir)
        self.rules_dir = Path(rules_dir)
        self.owl_path = owl_path
        self.onto = None
        self.cases_data = {}
        self.entity_mappings = {}
        self.global_knowledge_graph = nx.DiGraph()
        self.legal_doctrines = self._initialize_legal_doctrines()
        # Rulebook stores start out empty so that entity extraction (which
        # consults provision_index) also works before, or without, a call to
        # load_official_documents().
        self.official_docs = {}
        self.provision_index = defaultdict(dict)
        self.provision_to_act_map = {}
        # Level 1 for the Supreme Court, level 2 for every high court.
        self.court_hierarchy = {
            'supreme_court': 1, 'allahabad_high_court': 2,
            'bombay_high_court': 2, 'calcutta_high_court': 2,
            'delhi_high_court': 2, 'madras_high_court': 2}
        self.stats = {
            'total_cases': 0, 'by_court': defaultdict(int),
            'by_year': defaultdict(int)}
        self.cache_dir = Path(cache_dir)
        # parents=True: a nested cache path must not fail on a fresh checkout.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.entity_cache_enabled = True
        self.nlp = None
        if SPACY_AVAILABLE:
            try:
                self.nlp = spacy.load("en_core_web_sm")
            except Exception:
                # Best effort: a missing/broken model only disables NLP.
                # `Exception` (not a bare except) lets Ctrl-C and SystemExit through.
                logger.warning("Spacy model not found.")
        self.load_ontology()
    def _initialize_legal_doctrines(self) -> Dict[str, List[str]]:
        return {
            'constitutional': ['fundamental rights', 'directive principles', 'judicial review'],
            'procedural': ['natural justice', 'due process', 'res judicata'],
            'criminal': ['presumption of innocence', 'burden of proof', 'reasonable doubt'],
            'civil': ['promissory estoppel', 'unjust enrichment', 'specific performance'],
            'administrative': ['legitimate expectation', 'proportionality', 'ultra vires']}
    def load_ontology(self):
        """Load the ontology found at ``self.owl_path`` and index its entities.

        The loaded ontology is stored on ``self.onto`` and
        ``_index_ontology_entities`` is called afterwards. Any exception is
        logged at ERROR level and then re-raised unchanged.
        """
        logger.info(f"Loading ontology from {self.owl_path}...")
        try:
            ontology_handle = get_ontology(self.owl_path)
            self.onto = ontology_handle.load()
            logger.info(f"Ontology loaded: {self.onto.base_iri}")
            self._index_ontology_entities()
        except Exception as e:
            logger.error(f"Error loading ontology: {e}")
            raise
    def _index_ontology_entities(self):
        self.onto_index = {'classes': {}, 'individuals': {}}
        for cls in self.onto.classes():
            labels = [str(l).lower() for l in cls.label] if cls.label else []
            self.onto_index['classes'][cls.name.lower()] = cls
            for label in labels:
                self.onto_index['classes'][label] = cls
        for ind in self.onto.individuals():
            labels = [str(l).lower() for l in ind.label] if ind.label else []
            self.onto_index['individuals'][ind.name.lower()] = ind
            for label in labels:
                self.onto_index['individuals'][label] = ind	    
    def load_official_documents(self):
        """Read the four statute PDFs from ``self.rules_dir`` and index them.

        ``self.official_docs`` is reset, then filled with whatever
        ``_extract_provisions_from_pdf`` returns for each PDF that exists and
        yields a non-empty result. Missing files and per-file failures are
        logged as warnings and skipped. ``_create_provision_index`` is called
        only when at least one document was loaded.
        """
        logger.info("Loading official legal documents from PDFs...")
        self.official_docs = {}
        pdf_by_doc_type = {
            'ipc': 'Indian Penal Code.pdf', 'crpc': 'Code of Criminal Procedure.pdf',
            'constitution': 'Constitution of India.pdf', 'evidence': 'Indian Evidence Act.pdf'}
        for doc_type, filename in pdf_by_doc_type.items():
            filepath = self.rules_dir / filename
            if not filepath.exists():
                logger.warning(f"File not found: {filepath}")
                continue
            try:
                logger.info(f"Loading {filename}...")
                parsed = self._extract_provisions_from_pdf(filepath, doc_type)
                if parsed:
                    self.official_docs[doc_type] = parsed
            except Exception as e:
                logger.warning(f"Could not load {filename}: {e}")
        if not self.official_docs:
            logger.warning("No official documents loaded.")
            return
        self._create_provision_index()
    def _extract_provisions_from_pdf(self, pdf_path: Path, doc_type: str) -> Dict:
        """Extract the text of one PDF and split it into provisions.

        Args:
            pdf_path: PDF file to read with ``PyPDF2.PdfReader``.
            doc_type: ``'constitution'`` selects the article parser; any other
                value selects the section parser.

        Returns:
            ``{'provisions': [...], 'full_text': str}`` on success, or ``{}``
            when PyPDF2 is not installed or any other error occurs (both are
            logged at ERROR level).
        """
        try:
            import PyPDF2
            with open(pdf_path, 'rb') as file:
                pdf_reader = PyPDF2.PdfReader(file)
                logger.info(f"  Reading {len(pdf_reader.pages)} pages...")
                page_chunks = []
                for page in pdf_reader.pages:
                    page_text = page.extract_text()
                    # Pages without extractable text contribute nothing.
                    if page_text:
                        page_chunks.append(page_text + "\n")
                full_text = "".join(page_chunks)
                if doc_type == 'constitution':
                    split_into_provisions = self._extract_articles_from_text
                else:
                    split_into_provisions = self._extract_sections_from_text
                provisions = split_into_provisions(full_text)
                return {'provisions': provisions, 'full_text': full_text}
        except ImportError:
            logger.error("PyPDF2 not installed.")
            return {}
        except Exception as e:
            logger.error(f"Error extracting from PDF: {e}")
            return {}
    def _extract_sections_from_text(self, text: str) -> List[Dict]:
        sections = []
        section_pattern = r'(?:Section\s+)?(\d+[A-Z]*)\.\s*([^\n]{10,200}?)'
        matches = re.finditer(section_pattern, text, re.MULTILINE)
        for match in matches:
            section_num = match.group(1).strip()
            section_text = match.group(2).strip()
            section_text = re.sub(r'\s+', ' ', section_text)
            sections.append({
                'number': section_num, 'text': section_text[:500],
                'title': section_text.split('.')[0][:100] if '.' in section_text else section_text[:100]})
        return sections[:500]
    def _extract_articles_from_text(self, text: str) -> List[Dict]:
        articles = []
        article_pattern = r'(?:Article\s+)?(\d+[A-Z]*)\.\s*([^\n]{10,200}?)'
        matches = re.finditer(article_pattern, text, re.MULTILINE)
        for match in matches:
            article_num = match.group(1).strip()
            article_text = match.group(2).strip()
            article_text = re.sub(r'\s+', ' ', article_text)
            articles.append({
                'number': article_num, 'text': article_text[:500],
                'title': article_text.split('.')[0][:100] if '.' in article_text else article_text[:100]})
        return articles[:500]
    def _create_provision_index(self):
        """Rebuild the provision lookup tables from ``self.official_docs``.

        ``self.provision_index[doc_type][number]`` receives the provision's
        text, title and act name. ``self.provision_to_act_map`` maps
        ``"Article <n>"`` (constitution) or ``"Section <n>"`` (everything else)
        to the act name. A later provision with the same number overwrites an
        earlier one. Totals are logged at INFO level.
        """
        doc_name_map = {
            'ipc': 'Indian Penal Code', 'crpc': 'Code of Criminal Procedure',
            'constitution': 'Constitution of India', 'evidence': 'Indian Evidence Act'}
        index = defaultdict(dict)
        act_by_label = {}
        self.provision_index = index
        self.provision_to_act_map = act_by_label
        for doc_type, doc_data in self.official_docs.items():
            if 'provisions' not in doc_data:
                continue
            act_name = doc_name_map.get(doc_type, doc_type)
            is_constitution = doc_type == 'constitution'
            for provision in doc_data['provisions']:
                prov_num = provision.get('number', '')
                # index[doc_type] is touched only when there is something to
                # store, so documents without provisions get no empty bucket.
                index[doc_type][prov_num] = dict(
                    text=provision.get('text', ''),
                    title=provision.get('title', ''),
                    act=act_name)
                label = f"Article {prov_num}" if is_constitution else f"Section {prov_num}"
                act_by_label[label] = act_name
        total = sum(map(len, self.provision_index.values()))
        logger.info(f"Indexed {total} provisions")
        for doc_type, provisions in self.provision_index.items():
            logger.info(f"  {doc_name_map.get(doc_type)}: {len(provisions)} provisions")
    def load_all_cases(self):
        """Load every case JSON file of the six known courts into memory.

        For each court directory under ``self.processed_dir`` the ``*.json``
        files are read in sorted order, tagged with ``'court_name'`` and
        appended to ``self.cases_data[court]``. Files that fail to load are
        logged at ERROR level and skipped; missing directories are logged as
        warnings and skipped.

        ``self.stats`` (``by_court``, ``by_year``, ``total_cases``) is
        recomputed from scratch on every call and describes the cases read
        during this call.
        """
        logger.info("Loading all case files...")
        courts = [
            'supreme_court', 'allahabad_high_court', 'bombay_high_court',
            'calcutta_high_court', 'delhi_high_court', 'madras_high_court']
        # Start from zero: the counters used to keep growing, so re-running
        # this method (e.g. re-executing the notebook cell) double-counted.
        self.stats['by_court'].clear()
        self.stats['by_year'].clear()
        for court in courts:
            court_dir = self.processed_dir / court
            if not court_dir.exists():
                logger.warning(f"Directory not found: {court_dir}")
                continue
            self.cases_data[court] = []
            # glob() order is filesystem dependent; sorting makes the case
            # order (and therefore any seeded sampling) reproducible.
            json_files = sorted(court_dir.glob("*.json"))
            logger.info(f"Loading {len(json_files)} cases from {court}...")
            for json_file in tqdm(json_files, desc=f"Loading {court}"):
                try:
                    with open(json_file, 'r', encoding='utf-8') as f:
                        case_data = json.load(f)
                    case_data['court_name'] = court
                    self.cases_data[court].append(case_data)
                    self.stats['by_court'][court] += 1
                    year = self._extract_year(case_data)
                    if year:
                        self.stats['by_year'][year] += 1
                except Exception as e:
                    # Best effort: one unreadable or malformed file must not
                    # abort the whole load.
                    logger.error(f"Error loading {json_file}: {e}")
            logger.info(f"Loaded {len(self.cases_data[court])} cases from {court}")
        self.stats['total_cases'] = sum(self.stats['by_court'].values())
        logger.info(f"Total cases loaded: {self.stats['total_cases']}")
    def _extract_year(self, case: Dict) -> Optional[str]:
        if 'metadata' in case and 'date' in case['metadata']:
            date_str = case['metadata']['date']
            try:
                date_obj = datetime.strptime(date_str, '%Y-%m-%d')
                return str(date_obj.year)
            except:
                pass
        if 'file_name' in case:
            match = re.search(r'_(19\d{2}|20\d{2})_', case['file_name'])
            if match:
                return match.group(1)
        return None
    def entity_extraction(self, case: Dict) -> Dict:
        case_id = case.get('file_name', '')
        if self.entity_cache_enabled and case_id:
            cache_file = self.cache_dir / f"{case_id.replace('/', '_')}.json"
            if cache_file.exists():
                try:
                    with open(cache_file, 'r', encoding='utf-8') as f:
                        return json.load(f)
                except:
                    pass        
        entities = self._extract_entities_internal(case)
        if self.entity_cache_enabled and case_id:
            cache_file = self.cache_dir / f"{case_id.replace('/', '_')}.json"
            try:
                with open(cache_file, 'w', encoding='utf-8') as f:
                    json.dump(entities, f)
            except:
                pass
        return entities
    def _extract_entities_internal(self, case: Dict) -> Dict:
        entities = {
            'court': None, 'parties': {'petitioner': [], 'respondent': []}, 'judges': [],
            'provisions': {'sections': [], 'articles': [], 'rules': [], 'orders': []},
            'citations': [], 'acts': [], 'legal_concepts': [], 'case_types': [],
            'doctrines': [], 'outcomes': []}        
        metadata = case.get('metadata', {})
        entities['court'] = case.get('court_name', metadata.get('court', ''))
        entities['date'] = metadata.get('date', '')        
        if metadata.get('petitioner'):
            entities['parties']['petitioner'] = self._normalize_party_name(metadata['petitioner'])
        if metadata.get('respondent'):
            entities['parties']['respondent'] = self._normalize_party_name(metadata['respondent'])
        entities['citations'] = metadata.get('citations', [])[:15]
        text = case.get('text', '')
        if not text:
            return entities
        text_sample = text[:45000]        
        entities['provisions'] = self._extract_provisions(text_sample)
        entities['acts'] = self._extract_acts(text_sample)
        entities['legal_concepts'] = self._extract_legal_concepts(text_sample)
        entities['doctrines'] = self._extract_doctrines(text_sample)
        entities['case_types'] = self._classify_case_type(text_sample, metadata)
        entities['judges'] = self._extract_judges(text_sample)
        entities['outcomes'] = self._extract_outcomes(text_sample)        
        entities['ontology_mappings'] = self._map_to_ontology_entities(entities)        
        entities = self.enrich_entities_with_rulebooks(entities)
        return entities
    def _normalize_party_name(self, party_name: str) -> List[str]:
        if not party_name:
            return []
        party_name = re.sub(r'\s+(Ltd\.|Limited|Pvt\.|Private|Co\.)\.?$', 
                           '', party_name, flags=re.IGNORECASE)
        parties = re.split(r'\s+(?:vs?\.?|versus|and others?)\s+', 
                          party_name, flags=re.IGNORECASE)
        return [p.strip() for p in parties if p.strip()][:3]
    def _extract_provisions(self, text: str) -> Dict[str, List[str]]:
        provisions = {'sections': [], 'articles': [], 'rules': [], 'orders': []}        
        patterns = {
            'sections': [r"Section\s+(\d+[A-Z]*(?:\s*\([a-z0-9]+\))*)"],
            'articles': [r"Article\s+(\d+[A-Z]*(?:\s*\([a-z0-9]+\))*)"],
            'rules': [r"Rule\s+(\d+[A-Z]*)"], 'orders': [r"Order\s+([IVX]+)"]}
        for prov_type, pattern_list in patterns.items():
            for pattern in pattern_list:
                matches = re.findall(pattern, text, re.IGNORECASE)
                provisions[prov_type].extend(matches)
        for prov_type in provisions:
            provisions[prov_type] = list(set([p.strip() for p in provisions[prov_type] if p.strip()]))[:30]
        return provisions
    def _extract_acts(self, text: str) -> List[Dict[str, str]]:
        act_pattern = r"([A-Z][A-Za-z\s&,\(\)]+(?:Act|Code)(?:\s*,?\s*(\d{4}))?)"
        matches = re.findall(act_pattern, text)
        acts = []
        seen = set()
        for act_name, year in matches:
            act_name = act_name.strip()
            if 10 < len(act_name) < 150 and act_name not in seen:
                acts.append({'name': act_name, 'year': year if year else None})
                seen.add(act_name)
        return acts[:25]    
    def _extract_legal_concepts(self, text: str) -> List[str]:
        concept_pattern = (
            r"\b(jurisdiction|appeal|writ|mandamus|certiorari|"
            r"natural justice|due process|fundamental right|"
            r"res judicata|mens rea|actus reus|burden of proof|"
            r"negligence|damages|injunction)\b")
        concepts = re.findall(concept_pattern, text, re.IGNORECASE)
        return list(set([c.title() for c in concepts]))[:40]    
    def _extract_doctrines(self, text: str) -> List[str]:
        doctrines = []
        text_lower = text.lower()
        for doctrine_type, doctrine_list in self.legal_doctrines.items():
            for doctrine in doctrine_list:
                if doctrine.lower() in text_lower:
                    doctrines.append(doctrine)        
        return list(set(doctrines))    
    def _classify_case_type(self, text: str, metadata: Dict) -> List[str]:
        case_types = []
        text_lower = text.lower()        
        type_patterns = {
            'Criminal': [
                ('criminal appeal', 2), ('conviction', 1), ('ipc', 1), 
                ('penal code', 1), ('accused', 2)],
            'Civil': [
                ('civil appeal', 2), ('contract', 1), ('property dispute', 2),
                ('damages', 1), ('plaintiff', 2)],
            'Constitutional': [
                ('writ petition', 2), ('article 32', 2), ('article 226', 2),
                ('fundamental right', 2), ('constitutional validity', 2)],
            'Tax': [
                ('income tax', 2), ('sales tax', 1), ('customs', 1),
                ('tax appeal', 2), ('assessment', 1)],
            'Service': [
                ('service matter', 2), ('employment', 1), ('termination', 1),
                ('pension', 1)]}
        for case_type, patterns in type_patterns.items():
            score = 0
            for keyword, weight in patterns:
                if keyword in text_lower:
                    score += weight            
            threshold = 2 if case_type == 'Constitutional' else 1
            if score >= threshold:
                case_types.append(case_type)
        return list(set(case_types))[:5] if case_types else ['General']    
    def _extract_judges(self, text: str) -> List[str]:
        judge_patterns = [r"(?:Hon'ble\s+)?Justice\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)"]        
        judges = []
        for pattern in judge_patterns:
            matches = re.findall(pattern, text[:5000])
            judges.extend(matches)
        return list(set(judges))[:10]    
    def _extract_outcomes(self, text: str) -> List[str]:
        outcomes = []
        text_lower = text.lower()        
        outcome_keywords = {
            'allowed': ['appeal allowed', 'petition allowed'],
            'dismissed': ['appeal dismissed', 'petition dismissed'],
            'quashed': ['order quashed', 'judgment quashed'],
            'upheld': ['upheld', 'affirmed']}
        for outcome, keywords in outcome_keywords.items():
            if any(keyword in text_lower for keyword in keywords):
                outcomes.append(outcome)
        return outcomes
    def _map_to_ontology_entities(self, entities: Dict) -> Dict[str, List[str]]:
        mappings = {'classes': [], 'individuals': []}
        for prov_list in entities['provisions'].values():
            for prov in prov_list:
                prov_lower = prov.lower()
                if prov_lower in self.onto_index['individuals']:
                    mappings['individuals'].append(prov)
        for concept in entities.get('legal_concepts', []):
            concept_lower = concept.lower()
            if concept_lower in self.onto_index['classes']:
                mappings['classes'].append(concept)
        return mappings
    def enrich_entities_with_rulebooks(self, entities: Dict) -> Dict:
        enriched = entities.copy()
        enriched['validated_provisions'] = []
        enriched['provision_details'] = {}
        enriched['detected_rulebooks'] = set()        
        for prov_type, prov_list in entities['provisions'].items():
            for provision in prov_list:
                matched_doc, matched_details = self._match_provision_to_rulebook(provision, prov_type)
                if matched_doc:
                    enriched['validated_provisions'].append(provision)
                    enriched['provision_details'][provision] = {
                        'rulebook': matched_doc,
                        'full_text': matched_details.get('text', ''),
                        'title': matched_details.get('title', ''),
                        'act': matched_details.get('act', '')}
                    enriched['detected_rulebooks'].add(matched_doc)
        enriched['act_to_rulebook_map'] = {}
        for act in entities.get('acts', []):
            act_name = act['name'] if isinstance(act, dict) else act
            rulebook = self._map_act_to_rulebook(act_name)
            if rulebook:
                enriched['act_to_rulebook_map'][act_name] = rulebook
                enriched['detected_rulebooks'].add(rulebook)
        enriched['detected_rulebooks'] = list(enriched['detected_rulebooks'])
        return enriched
    def _match_provision_to_rulebook(self, provision: str, prov_type: str) -> Tuple[Optional[str], Dict]:
        match = re.search(r'(\d+[A-Z]*)', provision)
        if not match:
            return None, {}
        prov_num = match.group(1)
        if prov_type == 'sections':
            if prov_num in self.provision_index.get('ipc', {}):
                return 'Indian Penal Code', self.provision_index['ipc'][prov_num]
            if prov_num in self.provision_index.get('crpc', {}):
                return 'Code of Criminal Procedure', self.provision_index['crpc'][prov_num]
            if prov_num in self.provision_index.get('evidence', {}):
                return 'Indian Evidence Act', self.provision_index['evidence'][prov_num]
        elif prov_type == 'articles':
            if prov_num in self.provision_index.get('constitution', {}):
                return 'Constitution of India', self.provision_index['constitution'][prov_num]
        return None, {}    
    def _map_act_to_rulebook(self, act_name: str) -> Optional[str]:
        act_lower = act_name.lower()
        if any(k in act_lower for k in ['penal code', 'ipc']):
            return 'Indian Penal Code'
        elif any(k in act_lower for k in ['criminal procedure', 'crpc']):
            return 'Code of Criminal Procedure'
        elif 'constitution' in act_lower:
            return 'Constitution of India'
        elif 'evidence' in act_lower:
            return 'Indian Evidence Act'
        return None
    def create_court_knowledge_graph(self, court_name: str, num_cases: int = 10,
                                    random_sample: bool = True) -> nx.DiGraph:
        """Build a directed graph for a sample of one court's loaded cases.

        The graph contains one court node, four rulebook nodes, and per
        sampled case a case node linked to its parties, provisions, acts,
        legal concepts, doctrines and judges.

        Args:
            court_name: Key into ``self.cases_data``.
            num_cases: Maximum number of cases to include.
            random_sample: If true and more than ``num_cases`` cases are
                loaded, pick them with ``random.sample``; otherwise the first
                ``num_cases`` cases are used.

        Returns:
            The graph; it is empty when no cases are loaded for the court.

        Side effects:
            Stores each case's entities in ``self.entity_mappings`` under the
            case's ``file_name`` (or ``"<court_name>_<index>"`` if absent).
        """
        logger.info(f"Creating knowledge graph for {court_name}...")        
        G = nx.DiGraph()
        court_cases = self.cases_data.get(court_name, [])
        if not court_cases:
            logger.warning(f"No cases found for {court_name}")
            return G        
        if random_sample and len(court_cases) > num_cases:
            sampled_cases = random.sample(court_cases, num_cases)
        else:
            sampled_cases = court_cases[:num_cases]        
        court_node_id = f"Court_{court_name}"
        G.add_node(court_node_id, 
                  node_type='court', label=court_name.replace('_', ' ').title(),
                  hierarchy_level=self.court_hierarchy.get(court_name, 2),
                  size=7000)        
        # Rulebook nodes are always created, whether or not a case refers to them.
        rulebook_nodes = {}
        for rulebook in ['Indian Penal Code', 'Code of Criminal Procedure', 
                        'Constitution of India', 'Indian Evidence Act']:
            rb_node_id = f"Rulebook_{rulebook.replace(' ', '_')}"
            G.add_node(rb_node_id, node_type='rulebook', label=rulebook, size=5000)
            rulebook_nodes[rulebook] = rb_node_id        
        # Registries shared by all cases of this graph: key -> node id, so a
        # repeated act or provision reuses its node instead of creating a new one.
        global_acts = {}
        global_provisions = {}        
        for idx, case in enumerate(sampled_cases):
            case_id = case.get('file_name', f"{court_name}_{idx}")
            entities = self.entity_extraction(case)
            self.entity_mappings[case_id] = entities            
            case_node_id = f"Case_{court_name}_{idx+1}"
            case_types = entities.get('case_types', ['General'])
            year = self._extract_year(case)
            G.add_node(case_node_id, node_type='case',
                      label=f"Case {idx+1}", case_types=case_types,
                      year=year, court=court_name, size=3000,
                      validated_provisions=len(entities.get('validated_provisions', [])))
            G.add_edge(court_node_id, case_node_id, relation='adjudicated', weight=1.0)
            for rulebook in entities.get('detected_rulebooks', []):
                if rulebook in rulebook_nodes:
                    G.add_edge(case_node_id, rulebook_nodes[rulebook],
                              relation='references_rulebook', weight=1.0)
            # Parties: at most two per side; the node id is the name cut to 30 characters.
            for party_type in ['petitioner', 'respondent']:
                parties = entities['parties'].get(party_type, [])
                for party in parties[:2]:
                    party_node_id = f"Party_{party[:30]}"
                    if party_node_id not in G:
                        G.add_node(party_node_id, node_type='party', 
                                  label=party[:30], size=2200)
                    G.add_edge(case_node_id, party_node_id, relation=party_type, weight=1.0)
            # Provisions: up to five per provision type, then at most eight overall per case.
            all_provisions = []
            for prov_type, prov_list in entities['provisions'].items():
                all_provisions.extend([(prov_type, p) for p in prov_list[:5]])
            for prov_type, provision in all_provisions[:8]:
                prov_details = entities.get('provision_details', {}).get(provision, {})
                parent_act = prov_details.get('act', 'Unknown')
                # Key combines act, type and reference so equal numbers from
                # different acts or types map to different nodes.
                prov_key = f"{parent_act}_{prov_type}_{provision[:25]}"
                is_validated = provision in entities.get('validated_provisions', [])
                parent_rulebook = prov_details.get('rulebook', '')                
                if prov_key not in global_provisions:
                    prov_node_id = f"Provision_{len(global_provisions)}"
                    global_provisions[prov_key] = prov_node_id
                    G.add_node(prov_node_id,
                              node_type='provision',
                              label=provision[:25],
                              size=2000 if is_validated else 1800,
                              validated=is_validated,
                              rulebook=parent_rulebook,
                              citation_count=1)
                else:
                    prov_node_id = global_provisions[prov_key]
                    G.nodes[prov_node_id]['citation_count'] = G.nodes[prov_node_id].get('citation_count', 0) + 1
                G.add_edge(case_node_id, prov_node_id, relation='cites_provision',
                          weight=1.5 if is_validated else 1.0)
                if is_validated and parent_rulebook and parent_rulebook in rulebook_nodes:
                    if not G.has_edge(prov_node_id, rulebook_nodes[parent_rulebook]):
                        G.add_edge(prov_node_id, rulebook_nodes[parent_rulebook],
                                  relation='defined_in', weight=2.0)            
            # Acts: at most five per case, keyed by the name cut to 50 characters.
            for act in entities.get('acts', [])[:5]:
                act_name = act['name'] if isinstance(act, dict) else act
                act_key = act_name[:50]
                mapped_rulebook = entities.get('act_to_rulebook_map', {}).get(act_name, '')                
                if act_key not in global_acts:
                    act_node_id = f"Act_{len(global_acts)}"
                    global_acts[act_key] = act_node_id
                    G.add_node(act_node_id, node_type='act', label=act_name[:50],
                              size=2600 if mapped_rulebook else 2400,
                              citation_count=1)
                else:
                    act_node_id = global_acts[act_key]
                    G.nodes[act_node_id]['citation_count'] = G.nodes[act_node_id].get('citation_count', 0) + 1
                G.add_edge(case_node_id, act_node_id, relation='governed_by',
                          weight=1.5 if mapped_rulebook else 1.0)
                if mapped_rulebook and mapped_rulebook in rulebook_nodes:
                    if not G.has_edge(act_node_id, rulebook_nodes[mapped_rulebook]):
                        G.add_edge(act_node_id, rulebook_nodes[mapped_rulebook],
                                  relation='codified_in', weight=2.0)
            # Concepts (max 6), doctrines (max 4) and judges (max 3) use the
            # name, cut to 30 characters, as part of the node id.
            for concept in entities.get('legal_concepts', [])[:6]:
                concept_node_id = f"Concept_{concept[:30]}"
                if concept_node_id not in G:
                    G.add_node(concept_node_id, node_type='concept', 
                              label=concept[:30], size=1600)
                G.add_edge(case_node_id, concept_node_id, 
                          relation='involves_concept', weight=1.0)
            for doctrine in entities.get('doctrines', [])[:4]:
                doctrine_node_id = f"Doctrine_{doctrine[:30]}"
                if doctrine_node_id not in G:
                    G.add_node(doctrine_node_id, node_type='doctrine', 
                              label=doctrine[:30], size=1700)
                G.add_edge(case_node_id, doctrine_node_id, 
                          relation='applies_doctrine', weight=1.0)
            for judge in entities.get('judges', [])[:3]:
                judge_node_id = f"Judge_{judge[:30]}"
                if judge_node_id not in G:
                    G.add_node(judge_node_id, node_type='judge', 
                              label=judge[:30], size=2000)
                G.add_edge(case_node_id, judge_node_id, relation='decided_by', weight=1.0)
        logger.info(f"{court_name} graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
        return G    
    def create_unified_knowledge_graph(self, max_cases_per_court: Optional[int] = None,
                                      batch_size: int = 500) -> nx.DiGraph:
        """Build one graph covering the loaded cases of every court.

        A root "Supreme Court of India" node is linked to the five high-court
        nodes; a separate node collects the Supreme Court's own cases. Cases
        are then added court by court in batches via ``_add_case_to_graph``,
        with ``gc.collect()`` after each batch. Entities already present in
        ``self.entity_mappings`` are reused; others are extracted and stored
        there. Finally ``_calculate_graph_metrics`` is run and the graph is
        stored on ``self.global_knowledge_graph``.

        Args:
            max_cases_per_court: If truthy and a court has more cases than
                this, a random sample of that size is used for the court.
            batch_size: Number of cases processed between garbage collections.

        Returns:
            The unified graph.
        """
        logger.info("Creating unified knowledge graph with batch processing...")
        G = nx.DiGraph()

        supreme_node = "Supreme_Court_of_India"
        G.add_node(supreme_node, node_type='supreme_court',
                  label='Supreme Court of India', hierarchy_level=0, size=10000)
        for court in ('allahabad_high_court', 'bombay_high_court', 'calcutta_high_court',
                      'delhi_high_court', 'madras_high_court'):
            high_court_node = f"Court_{court}"
            G.add_node(high_court_node, node_type='high_court',
                      label=court.replace('_', ' ').title(), hierarchy_level=1, size=7000)
            G.add_edge(supreme_node, high_court_node, relation='superior_to', weight=2.0)
        G.add_node("Court_supreme_court", node_type='supreme_court_cases',
                  label='Supreme Court Cases', hierarchy_level=0, size=8000)

        # One registry per entity kind, shared by every case of the graph.
        global_entities = {
            kind: {} for kind in ('acts', 'provisions', 'concepts',
                                  'doctrines', 'parties', 'judges')}
        case_counter = 0
        for court_name, court_cases in self.cases_data.items():
            logger.info(f"Processing {court_name}...")
            sampled_cases = (
                random.sample(court_cases, max_cases_per_court)
                if max_cases_per_court and len(court_cases) > max_cases_per_court
                else court_cases)
            court_node = f"Court_{court_name}"
            case_total = len(sampled_cases)
            for batch_start in range(0, case_total, batch_size):
                batch_end = min(batch_start + batch_size, case_total)
                logger.info(f"  Batch {batch_start//batch_size + 1}: Processing cases {batch_start} to {batch_end}")
                progress = tqdm(sampled_cases[batch_start:batch_end],
                                desc=f"Batch {batch_start}-{batch_end}")
                for case in progress:
                    case_counter += 1
                    case_id = case.get('file_name', f"{court_name}_{case_counter}")
                    if case_id in self.entity_mappings:
                        entities = self.entity_mappings[case_id]
                    else:
                        entities = self.entity_extraction(case)
                        self.entity_mappings[case_id] = entities
                    self._add_case_to_graph(G, case_id, case_counter, court_name,
                                           entities, court_node, global_entities)
                gc.collect()
                logger.info(f"  Memory: Nodes={G.number_of_nodes():,}, Edges={G.number_of_edges():,}")

        logger.info("Calculating graph metrics...")
        self._calculate_graph_metrics(G)
        logger.info(f"Unified graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
        self.global_knowledge_graph = G
        return G
    def _add_case_to_graph(self, G, case_id, case_counter, court_name, 
                           entities, court_node, global_entities):
        case_node_id = f"Case_{case_counter}"
        year = self._extract_year({'file_name': case_id, **entities})
        case_types = entities.get('case_types', ['General'])
        G.add_node(case_node_id, node_type='case', label=f"Case {case_counter}",
                  case_id=case_id, court=court_name, year=year,
                  case_types=case_types, size=2500)
        G.add_edge(court_node, case_node_id, relation='adjudicated', weight=1.0)
        for act_item in entities.get('acts', [])[:6]:
            act_name = act_item['name'] if isinstance(act_item, dict) else act_item
            act_key = act_name[:50]
            if act_key not in global_entities['acts']:
                act_node_id = f"Act_{len(global_entities['acts'])}"
                global_entities['acts'][act_key] = act_node_id
                G.add_node(act_node_id, node_type='act', label=act_name[:50],
                          size=3500, citation_count=1)
            else:
                act_node_id = global_entities['acts'][act_key]
                G.nodes[act_node_id]['citation_count'] += 1
            G.add_edge(case_node_id, act_node_id, relation='governed_by', weight=1.0)
        for prov_type, prov_list in entities['provisions'].items():
            for provision in prov_list[:8]:
                prov_details = entities.get('provision_details', {}).get(provision, {})
                parent_act = prov_details.get('act', 'Unknown')
                prov_key = f"{parent_act}_{prov_type}_{provision[:30]}"
                if prov_key not in global_entities['provisions']:
                    prov_node_id = f"Prov_{len(global_entities['provisions'])}"
                    global_entities['provisions'][prov_key] = prov_node_id
                    G.add_node(prov_node_id, node_type='provision',
                              label=provision[:30], size=2000, citation_count=1,
                              parent_act=parent_act)
                else:
                    prov_node_id = global_entities['provisions'][prov_key]
                    G.nodes[prov_node_id]['citation_count'] += 1
                G.add_edge(case_node_id, prov_node_id, relation='cites_provision', weight=1.0)
        for concept in entities.get('legal_concepts', [])[:8]:
            concept_key = concept[:40]
            if concept_key not in global_entities['concepts']:
                concept_node_id = f"Concept_{len(global_entities['concepts'])}"
                global_entities['concepts'][concept_key] = concept_node_id
                G.add_node(concept_node_id, node_type='concept',
                          label=concept[:40], size=2200, mention_count=1)
            else:
                concept_node_id = global_entities['concepts'][concept_key]
                G.nodes[concept_node_id]['mention_count'] += 1
            G.add_edge(case_node_id, concept_node_id, relation='involves_concept', weight=1.0)
        for doctrine in entities.get('doctrines', [])[:5]:
            doctrine_key = doctrine[:40]
            if doctrine_key not in global_entities['doctrines']:
                doctrine_node_id = f"Doctrine_{len(global_entities['doctrines'])}"
                global_entities['doctrines'][doctrine_key] = doctrine_node_id
                G.add_node(doctrine_node_id, node_type='doctrine',
                          label=doctrine[:40], size=2400, application_count=1)
            else:
                doctrine_node_id = global_entities['doctrines'][doctrine_key]
                G.nodes[doctrine_node_id]['application_count'] += 1
            G.add_edge(case_node_id, doctrine_node_id, relation='applies_doctrine', weight=1.0)
        for judge in entities.get('judges', [])[:3]:
            judge_key = judge[:40]
            if judge_key not in global_entities['judges']:
                judge_node_id = f"Judge_{len(global_entities['judges'])}"
                global_entities['judges'][judge_key] = judge_node_id
                G.add_node(judge_node_id, node_type='judge',
                          label=judge[:40], size=2600, case_count=1)
            else:
                judge_node_id = global_entities['judges'][judge_key]
                G.nodes[judge_node_id]['case_count'] += 1
            G.add_edge(case_node_id, judge_node_id, relation='decided_by', weight=1.0)
    def _calculate_graph_metrics(self, G: nx.DiGraph):
        """Attach centrality scores to every node of ``G`` in place.

        Writes ``degree_centrality`` always, ``pagerank`` when the computation
        succeeds, and ``betweenness`` only for graphs with fewer than 3000
        nodes. A failure in PageRank or betweenness is logged with its cause
        and leaves that attribute unset rather than aborting the caller.

        Args:
            G: Graph whose node attribute dicts are updated.
        """
        logger.info("Computing centrality measures...")
        node_count = G.number_of_nodes()

        for node, centrality in nx.degree_centrality(G).items():
            G.nodes[node]['degree_centrality'] = centrality

        try:
            # Fewer iterations on very large graphs to bound the runtime.
            max_iter = 30 if node_count > 10000 else 50
            pagerank = nx.pagerank(G, max_iter=max_iter, tol=1e-4)
        except Exception as exc:
            logger.warning("PageRank calculation failed: %s", exc)
        else:
            for node, score in pagerank.items():
                G.nodes[node]['pagerank'] = score

        if node_count < 3000:
            # Sample roughly a tenth of the nodes (at most 50) as sources.
            # Below ten nodes that rule gives k=0, i.e. no source nodes at
            # all, so switch to the exact computation (k=None) instead.
            k_sample = min(50, node_count // 10) or None
            try:
                # Fixed seed so the sampled estimate is reproducible.
                betweenness = nx.betweenness_centrality(G, k=k_sample, seed=42)
            except Exception as exc:
                logger.warning("Betweenness calculation failed: %s", exc)
            else:
                for node, score in betweenness.items():
                    G.nodes[node]['betweenness'] = score
        else:
            logger.info("Skipping betweenness centrality for large graph (>3000 nodes)")
    def visualize_knowledge_graph(self, G: nx.DiGraph, title: str, figsize=(28, 22)):
        """Draw the whole graph with nodes coloured by ``node_type``.

        Graphs with fewer than 500 nodes use a seeded spring layout, larger
        ones the Kamada-Kawai layout. Every node is labelled below 300 nodes;
        above that only court-level nodes are labelled.

        Args:
            G: Graph to draw; ``None`` or an empty graph is skipped with a
                warning.
            title: Heading shown above the node and edge counts.
            figsize: Matplotlib figure size in inches.
        """
        if G is None or G.number_of_nodes() == 0:
            logger.warning("Empty graph, skipping visualization")
            return
        logger.info(f"Visualizing: {title}...")
        fig, ax = plt.subplots(figsize=figsize)
        color_map = {
            'supreme_court': '#8B0000', 'high_court': '#DC143C',
            'supreme_court_cases': '#B22222', 'court': '#E74C3C',
            'case': '#3498DB', 'party': '#2ECC71', 'provision': '#9B59B6',
            'act': '#F39C12', 'concept': '#1ABC9C', 'doctrine': '#E67E22',
            'judge': '#34495E', 'cited_case': '#95A5A6', 'rulebook': '#8E44AD'}
        # Nodes without a type are coloured as cases; unknown types fall back
        # to grey.
        node_colors = [color_map.get(attrs.get('node_type', 'case'), '#95A5A6')
                      for _, attrs in G.nodes(data=True)]
        node_sizes = [attrs.get('size', 1500) for _, attrs in G.nodes(data=True)]

        logger.info("Computing layout...")
        node_count = G.number_of_nodes()
        if node_count < 500:
            pos = nx.spring_layout(G, k=3, iterations=50, seed=42, scale=3)
        else:
            # NOTE(review): Kamada-Kawai needs all-pairs distances, which may
            # be slow on large graphs - confirm this threshold is intended.
            pos = nx.kamada_kawai_layout(G, scale=3)

        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=node_sizes,
                              alpha=0.85, edgecolors='black', linewidths=1.5, ax=ax)

        if node_count < 300:
            labelled_nodes = list(G.nodes())
            font_size = 6
        else:
            # 'supreme_court_cases' is included so the Supreme Court hub keeps
            # its label on large graphs, as in the hierarchical view.
            court_types = ['supreme_court', 'high_court', 'court', 'supreme_court_cases']
            labelled_nodes = [n for n in G.nodes()
                              if G.nodes[n].get('node_type') in court_types]
            font_size = 8
        labels = {node: G.nodes[node].get('label', str(node))[:25] for node in labelled_nodes}
        nx.draw_networkx_labels(G, pos, labels, font_size=font_size,
                                font_weight='bold', ax=ax)

        nx.draw_networkx_edges(G, pos, alpha=0.2, arrows=True, arrowsize=8,
                              edge_color='#34495E', width=0.8, ax=ax)

        # The legend lists every known node type, present in G or not.
        legend_elements = [plt.Line2D([0], [0], marker='o', color='w',
                          markerfacecolor=color, markersize=12,
                          label=ntype.replace('_', ' ').title())
                          for ntype, color in color_map.items()]
        ax.legend(handles=legend_elements, loc='upper left', fontsize=10, ncol=2)
        stats_text = f"{title}\nNodes: {G.number_of_nodes()} | Edges: {G.number_of_edges()}"
        ax.set_title(stats_text, fontsize=18, fontweight='bold', pad=20)
        ax.axis('off')
        fig.tight_layout()
        plt.show()
        plt.close(fig)
    def visualize_hierarchical_graph(self, G: nx.DiGraph, title: str, figsize=(30, 24)):
        """Draw ``G`` in tiers, with edges coloured by their ``relation``.

        Node positions come from ``self._compute_hierarchical_layout``. Edges
        are drawn one relation at a time so each group gets its own colour;
        ``adjudicated`` and ``governed_by`` edges are drawn thicker and more
        opaque than the rest. Two legends are shown: node types on the left,
        the first eight relation colours on the right.

        Args:
            G: Graph to draw; ``None`` or an empty graph is skipped with a
                warning.
            title: Heading shown above the node and edge counts.
            figsize: Matplotlib figure size in inches.
        """
        if G is None or G.number_of_nodes() == 0:
            logger.warning("Empty graph, skipping visualization")
            return
        logger.info(f"Creating hierarchical visualization: {title}...")
        fig, ax = plt.subplots(figsize=figsize)

        node_palette = {
            'supreme_court': '#8B0000', 'high_court': '#DC143C',
            'supreme_court_cases': '#B22222', 'court': '#E74C3C',
            'case': '#3498DB', 'party': '#2ECC71', 'provision': '#9B59B6',
            'act': '#F39C12', 'concept': '#1ABC9C', 'doctrine': '#E67E22',
            'judge': '#34495E', 'cited_case': '#95A5A6', 'rulebook': '#8E44AD'}
        relation_palette = {
            'adjudicated': '#E74C3C', 'decided_by': '#34495E',
            'governed_by': '#F39C12', 'cites_provision': '#9B59B6',
            'involves_concept': '#1ABC9C', 'applies_doctrine': '#E67E22',
            'references_rulebook': '#8E44AD', 'defined_in': '#6C3483',
            'codified_in': '#7D3C98', 'superior_to': '#8B0000'}

        logger.info("Computing hierarchical layout...")
        positions = self._compute_hierarchical_layout(G)

        # One colour and one size per node, in G's node order.
        fill_colors = []
        marker_sizes = []
        for _, attrs in G.nodes(data=True):
            fill_colors.append(node_palette.get(attrs.get('node_type', 'case'), '#95A5A6'))
            marker_sizes.append(attrs.get('size', 1500))

        # Bucket the edges by relation so each bucket can be styled separately.
        edges_by_relation = {}
        for src, dst, relation in G.edges(data='relation', default='unknown'):
            edges_by_relation.setdefault(relation, []).append((src, dst))

        for relation, edge_group in edges_by_relation.items():
            prominent = relation in ('adjudicated', 'governed_by')
            nx.draw_networkx_edges(G, positions, edgelist=edge_group,
                                   edge_color=relation_palette.get(relation, '#34495E'),
                                   alpha=0.4 if prominent else 0.2,
                                   arrows=True, arrowsize=10,
                                   width=1.5 if prominent else 0.8, ax=ax)

        nx.draw_networkx_nodes(G, positions, node_color=fill_colors, node_size=marker_sizes,
                               alpha=0.9, edgecolors='black', linewidths=2, ax=ax)

        # Small graphs get every label; larger ones only the court tiers.
        if G.number_of_nodes() < 200:
            labelled = list(G.nodes())
            label_font = 7
        else:
            court_tiers = ['supreme_court', 'high_court', 'court', 'supreme_court_cases']
            labelled = [n for n in G.nodes() if G.nodes[n].get('node_type') in court_tiers]
            label_font = 8
        captions = {n: G.nodes[n].get('label', str(n))[:20] for n in labelled}
        nx.draw_networkx_labels(G, positions, captions, font_size=label_font,
                                font_weight='bold', ax=ax)

        node_handles = []
        for ntype, color in node_palette.items():
            node_handles.append(plt.Line2D([0], [0], marker='o', color='w',
                                           markerfacecolor=color, markersize=12,
                                           label=ntype.replace('_', ' ').title()))
        relation_handles = []
        for relation, color in list(relation_palette.items())[:8]:
            relation_handles.append(plt.Line2D([0], [0], color=color, linewidth=2,
                                               label=relation.replace('_', ' ').title()))

        # A second ax.legend() call replaces the first, so the node legend is
        # re-attached as a plain artist before the relation legend is created.
        node_legend = ax.legend(handles=node_handles, loc='upper left', fontsize=9,
                                title='Node Types', title_fontsize=10)
        ax.add_artist(node_legend)
        ax.legend(handles=relation_handles, loc='upper right', fontsize=9,
                  title='Relationships', title_fontsize=10)

        heading = f"{title} (Hierarchical Layout)\nNodes: {G.number_of_nodes()} | Edges: {G.number_of_edges()}"
        ax.set_title(heading, fontsize=18, fontweight='bold', pad=20)
        ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close()
    def _compute_hierarchical_layout(self, G: nx.DiGraph) -> Dict:
        """Place the nodes of ``G`` on horizontal rows chosen by ``node_type``.

        Each node type maps to a level (0 at the top); nodes without a type
        are treated as ``case`` and unknown types go to level 5. Nodes on one
        level are spread evenly over a 10-unit wide band centred on x = 0, so
        a level holding a single node puts it in the middle.

        Args:
            G: Graph whose nodes are positioned.

        Returns:
            Dict mapping each node to an ``(x, y)`` tuple.
        """
        hierarchy_levels = {
            'supreme_court': 0, 'supreme_court_cases': 1, 'high_court': 1,
            'court': 2, 'rulebook': 2, 'case': 3, 'judge': 4, 'act': 4,
            'provision': 5, 'doctrine': 5, 'concept': 5, 'party': 6, 'cited_case': 6}
        levels = defaultdict(list)
        for node in G.nodes():
            node_type = G.nodes[node].get('node_type', 'case')
            levels[hierarchy_levels.get(node_type, 5)].append(node)

        pos = {}
        y_spacing = 1.5
        max_level = max(levels.keys()) if levels else 0
        for level, nodes in levels.items():
            # Level 0 gets the largest y so it is drawn at the top.
            y = (max_level - level) * y_spacing
            count = len(nodes)
            x_spacing = 10.0 / max(count, 1)
            # Offset by (count - 1) / 2 so the row is symmetric around x = 0;
            # offsetting by count / 2 would shift every row left by half a slot.
            half_span = (count - 1) / 2
            for i, node in enumerate(nodes):
                pos[node] = ((i - half_span) * x_spacing, y)
        return pos
    def visualize_graph_statistics_panel(self, G: nx.DiGraph, title: str):
        """Show three summary charts for ``G``.

        Panels: a pie of node types, a horizontal bar chart of the eight most
        common edge relations, and a pie of the eight most common case types
        (read from the ``case_types`` attribute of ``case`` nodes).

        Args:
            G: Graph to summarise; ``None`` or an empty graph is skipped with
                a warning, like the other visualisation methods.
            title: Prefix of the figure heading.
        """
        if G is None or G.number_of_nodes() == 0:
            logger.warning("Empty graph, skipping visualization")
            return
        logger.info(f"Creating statistics panel: {title}...")
        fig = plt.figure(figsize=(24, 16))
        # NOTE(review): the grid is 3x3 but only the top row is populated;
        # confirm whether the remaining six panels were removed on purpose.
        gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

        # Panel 1: node type distribution.
        ax_nodes = fig.add_subplot(gs[0, 0])
        node_type_counts = Counter(attrs.get('node_type', 'unknown')
                                   for _, attrs in G.nodes(data=True))
        colors_list = ['#3498DB', '#E74C3C', '#F39C12', '#9B59B6', '#1ABC9C',
                      '#2ECC71', '#E67E22', '#34495E', '#95A5A6']
        ax_nodes.pie(node_type_counts.values(),
                     labels=[t.replace('_', ' ').title() for t in node_type_counts.keys()],
                     autopct='%1.1f%%', colors=colors_list[:len(node_type_counts)],
                     startangle=90)
        ax_nodes.set_title('Node Type Distribution', fontsize=12, fontweight='bold')

        # Panel 2: most frequent edge relations, largest at the top.
        ax_edges = fig.add_subplot(gs[0, 1])
        relation_counts = Counter(relation for _, _, relation
                                  in G.edges(data='relation', default='unknown'))
        top_edges = dict(relation_counts.most_common(8))
        ax_edges.barh(list(top_edges.keys()), list(top_edges.values()), color='#3498DB')
        ax_edges.set_xlabel('Count', fontsize=10)
        ax_edges.set_title('Top Edge Types', fontsize=12, fontweight='bold')
        ax_edges.invert_yaxis()

        # Panel 3: case type distribution. A case without the attribute counts
        # as 'General'; a case with an empty or None value contributes nothing.
        ax_cases = fig.add_subplot(gs[0, 2])
        case_type_counts = Counter()
        for _, attrs in G.nodes(data=True):
            if attrs.get('node_type') != 'case':
                continue
            types = attrs.get('case_types', ['General'])
            if types:
                case_type_counts.update(types)
        if case_type_counts:
            top_types = dict(case_type_counts.most_common(8))
            colors_case = plt.cm.Set3(range(len(top_types)))
            ax_cases.pie(top_types.values(), labels=top_types.keys(),
                         autopct='%1.1f%%', colors=colors_case, startangle=90)
        else:
            ax_cases.text(0.5, 0.5, 'No case data', ha='center', va='center')
        # Title is set on both paths so the panel is never left unlabelled.
        ax_cases.set_title('Case Type Distribution', fontsize=12, fontweight='bold')

        fig.suptitle(f'{title} - Comprehensive Statistics Dashboard',
                    fontsize=18, fontweight='bold', y=0.98)
        plt.tight_layout()
        plt.show()
        plt.close(fig)
    def visualize_subgraph_citation_network(self, G: nx.DiGraph, title: str, figsize=(24, 18)):
        """Draw only the citation part of ``G``.

        The drawn subgraph consists of the edges whose ``relation`` is
        ``cites_case`` or ``cites_provision`` together with their end nodes.
        Other edges that happen to connect two of those nodes are left out,
        so the "Links" figure in the title counts citation edges only.

        Args:
            G: Source graph.
            title: Prefix of the plot heading.
            figsize: Matplotlib figure size in inches.
        """
        logger.info(f"Creating citation network subgraph: {title}...")
        citation_relations = ('cites_case', 'cites_provision')
        citation_edges = [(u, v) for u, v, relation in G.edges(data='relation')
                          if relation in citation_relations]
        if not citation_edges:
            logger.warning("No citation edges found")
            return

        # edge_subgraph keeps exactly these edges plus their end nodes; a
        # node-induced subgraph would also pull in unrelated edges between them.
        subG = G.edge_subgraph(citation_edges).copy()

        fig, ax = plt.subplots(figsize=figsize)
        color_map = {'case': '#3498DB', 'cited_case': '#95A5A6', 'provision': '#9B59B6'}
        node_colors = [color_map.get(attrs.get('node_type', 'case'), '#95A5A6')
                      for _, attrs in subG.nodes(data=True)]
        # Nodes are drawn 20% larger than in the full-graph views.
        node_sizes = [attrs.get('size', 1500) * 1.2 for _, attrs in subG.nodes(data=True)]
        pos = nx.spring_layout(subG, k=2, iterations=50, seed=42)
        nx.draw_networkx_nodes(subG, pos, node_color=node_colors, node_size=node_sizes,
                              alpha=0.85, edgecolors='black', linewidths=2, ax=ax)
        nx.draw_networkx_edges(subG, pos, alpha=0.4, arrows=True, arrowsize=12,
                              edge_color='#8E44AD', width=1.5, ax=ax)
        if subG.number_of_nodes() < 100:
            labels = {node: attrs.get('label', str(node))[:20]
                      for node, attrs in subG.nodes(data=True)}
            nx.draw_networkx_labels(subG, pos, labels, font_size=7, font_weight='bold', ax=ax)
        ax.set_title(f"{title} - Citation Network\nNodes: {subG.number_of_nodes()} | Links: {subG.number_of_edges()}",
                    fontsize=16, fontweight='bold', pad=20)
        ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)
    def visualize_subgraph_statutory_framework(self, G: nx.DiGraph, title: str, figsize=(24, 18)):
        """Draw cases, acts and provisions with their statutory links.

        Only ``governed_by`` and ``cites_provision`` edges between ``act``,
        ``provision`` and ``case`` nodes are kept. Positions come from
        ``self._compute_hierarchical_layout``.

        Args:
            G: Source graph.
            title: Prefix of the plot heading.
            figsize: Matplotlib figure size in inches.
        """
        logger.info(f"Creating statutory framework subgraph: {title}...")
        statutory_nodes = [n for n in G.nodes() if G.nodes[n].get('node_type') in ['act', 'provision', 'case']]
        if not statutory_nodes:
            logger.warning("No statutory nodes found")
            return
        subG = G.subgraph(statutory_nodes).copy()
        kept_relations = ('governed_by', 'cites_provision')
        subG.remove_edges_from([(u, v) for u, v, relation in subG.edges(data='relation')
                                if relation not in kept_relations])

        fig, ax = plt.subplots(figsize=figsize)
        color_map = {'act': '#F39C12', 'provision': '#9B59B6', 'case': '#3498DB'}
        node_colors = [color_map.get(attrs.get('node_type', 'case'), '#95A5A6')
                      for _, attrs in subG.nodes(data=True)]
        node_sizes = [attrs.get('size', 1500) * 1.2 for _, attrs in subG.nodes(data=True)]
        pos = self._compute_hierarchical_layout(subG)
        nx.draw_networkx_nodes(subG, pos, node_color=node_colors, node_size=node_sizes,
                              alpha=0.85, edgecolors='black', linewidths=2, ax=ax)

        # (relation, legend text, colour, line width, alpha, arrow size)
        edge_styles = [
            ('governed_by', 'Governed By', '#E67E22', 2, 0.5, 12),
            ('cites_provision', 'Cites Provision', '#8E44AD', 1.5, 0.4, 10)]
        # Legend entries are built by hand: the arrows drawn for a directed
        # graph are individual patches, so relying on a label passed to
        # draw_networkx_edges does not give one tidy entry per relation.
        legend_handles = []
        for relation, caption, color, width, alpha, arrowsize in edge_styles:
            edges = [(u, v) for u, v, rel in subG.edges(data='relation') if rel == relation]
            if not edges:
                continue
            nx.draw_networkx_edges(subG, pos, edgelist=edges, alpha=alpha,
                                  arrows=True, arrowsize=arrowsize, edge_color=color,
                                  width=width, ax=ax)
            legend_handles.append(plt.Line2D([0], [0], color=color, linewidth=width,
                                             label=caption))

        if subG.number_of_nodes() < 100:
            labels = {node: attrs.get('label', str(node))[:25]
                      for node, attrs in subG.nodes(data=True)}
            nx.draw_networkx_labels(subG, pos, labels, font_size=7, font_weight='bold', ax=ax)
        ax.set_title(f"{title} - Statutory Framework\nNodes: {subG.number_of_nodes()} | Relationships: {subG.number_of_edges()}",
                    fontsize=16, fontweight='bold', pad=20)
        if legend_handles:
            ax.legend(handles=legend_handles, fontsize=12, loc='upper right')
        ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)
    def visualize_subgraph_concept_doctrine_network(self, G: nx.DiGraph, title: str, figsize=(24, 18)):
        """Draw cases together with the concepts and doctrines they touch.

        All ``concept``, ``doctrine`` and ``case`` nodes are kept; of the edges
        between them only ``involves_concept`` and ``applies_doctrine`` survive.
        The result is laid out with a seeded spring layout.

        Args:
            G: Source graph.
            title: Prefix of the plot heading.
            figsize: Matplotlib figure size in inches.
        """
        logger.info(f"Creating concept-doctrine network subgraph: {title}...")
        wanted_types = ('concept', 'doctrine', 'case')
        selected = []
        for node in G.nodes():
            if G.nodes[node].get('node_type') in wanted_types:
                selected.append(node)
        if len(selected) == 0:
            logger.warning("No concept/doctrine nodes found")
            return

        induced = G.subgraph(selected).copy()
        # Rebuild a fresh graph holding every selected node but only the two
        # relations of interest.
        wanted_relations = ('involves_concept', 'applies_doctrine')
        kept_edges = [(src, dst, data) for src, dst, data in induced.edges(data=True)
                      if data.get('relation', '') in wanted_relations]
        view_graph = nx.DiGraph()
        view_graph.add_nodes_from(induced.nodes(data=True))
        view_graph.add_edges_from(kept_edges)

        fig, ax = plt.subplots(figsize=figsize)
        palette = {'concept': '#1ABC9C', 'doctrine': '#E67E22', 'case': '#3498DB'}
        fill_colors = []
        marker_sizes = []
        for _, attrs in view_graph.nodes(data=True):
            # Untyped nodes are coloured like concepts here.
            fill_colors.append(palette.get(attrs.get('node_type', 'concept'), '#95A5A6'))
            marker_sizes.append(attrs.get('size', 1500) * 1.2)

        layout = nx.spring_layout(view_graph, k=2.5, iterations=50, seed=42)
        nx.draw_networkx_nodes(view_graph, layout, node_color=fill_colors,
                               node_size=marker_sizes, alpha=0.85,
                               edgecolors='black', linewidths=2, ax=ax)
        nx.draw_networkx_edges(view_graph, layout, alpha=0.3, arrows=True, arrowsize=10,
                               edge_color='#16A085', width=1.2, ax=ax)

        if view_graph.number_of_nodes() < 100:
            captions = {}
            for node, attrs in view_graph.nodes(data=True):
                captions[node] = attrs.get('label', str(node))[:25]
            nx.draw_networkx_labels(view_graph, layout, captions, font_size=8,
                                    font_weight='bold', ax=ax)

        handles = []
        for ntype, color in palette.items():
            handles.append(plt.Line2D([0], [0], marker='o', color='w',
                                      markerfacecolor=color, markersize=14,
                                      label=ntype.title()))
        ax.legend(handles=handles, fontsize=12, loc='upper right')

        heading = (f"{title} - Legal Concepts & Doctrines\n"
                   f"Nodes: {view_graph.number_of_nodes()} | Relationships: {view_graph.number_of_edges()}")
        ax.set_title(heading, fontsize=16, fontweight='bold', pad=20)
        ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close()
    def visualize_rulebook_integration_graph(self, G: nx.DiGraph, title: str, figsize=(26, 20)):
        """Draw how cases, acts and provisions connect to rulebook nodes.

        Keeps ``rulebook``, ``case``, ``act`` and ``provision`` nodes and, of
        the edges between them, only ``references_rulebook``, ``defined_in``,
        ``codified_in``, ``governed_by`` and ``cites_provision``. Nodes whose
        ``validated`` attribute is truthy are drawn 30% larger. Positions come
        from ``self._compute_hierarchical_layout``.

        Args:
            G: Source graph.
            title: Prefix of the plot heading.
            figsize: Matplotlib figure size in inches.
        """
        logger.info(f"Creating rulebook integration graph: {title}...")
        rulebook_nodes = [n for n in G.nodes() if G.nodes[n].get('node_type') in
                         ['rulebook', 'case', 'act', 'provision']]
        if not rulebook_nodes:
            logger.warning("No rulebook nodes found")
            return
        subG = G.subgraph(rulebook_nodes).copy()

        # Colour per kept relation; the key order is also the drawing and
        # legend order.
        edge_colors = {
            'references_rulebook': '#8E44AD', 'defined_in': '#6C3483',
            'codified_in': '#7D3C98', 'governed_by': '#F39C12', 'cites_provision': '#9B59B6'}
        edge_types = {relation: [] for relation in edge_colors}
        kept_edges = []
        for u, v, data in subG.edges(data=True):
            relation = data.get('relation', '')
            if relation in edge_types:
                kept_edges.append((u, v, data))
                edge_types[relation].append((u, v))
        filtered_G = nx.DiGraph()
        filtered_G.add_nodes_from(subG.nodes(data=True))
        filtered_G.add_edges_from(kept_edges)

        fig, ax = plt.subplots(figsize=figsize)
        color_map = {'rulebook': '#8E44AD', 'case': '#3498DB', 'act': '#F39C12', 'provision': '#9B59B6'}
        node_colors = [color_map.get(attrs.get('node_type', 'case'), '#95A5A6')
                      for _, attrs in filtered_G.nodes(data=True)]
        node_sizes = [attrs.get('size', 2000) * (1.3 if attrs.get('validated', False) else 1)
                      for _, attrs in filtered_G.nodes(data=True)]
        pos = self._compute_hierarchical_layout(filtered_G)

        # Legend entries are built by hand: the arrows drawn for a directed
        # graph are individual patches, so relying on a label passed to
        # draw_networkx_edges does not give one tidy entry per relation.
        relation_legend = []
        for relation, edges in edge_types.items():
            if not edges:
                continue
            width = 2.5 if 'rulebook' in relation else 1.5
            nx.draw_networkx_edges(filtered_G, pos, edgelist=edges,
                                  edge_color=edge_colors[relation], alpha=0.6,
                                  arrows=True, arrowsize=12, width=width, ax=ax)
            relation_legend.append(plt.Line2D([0], [0], color=edge_colors[relation],
                                              linewidth=width,
                                              label=relation.replace('_', ' ').title()))

        nx.draw_networkx_nodes(filtered_G, pos, node_color=node_colors, node_size=node_sizes,
                              alpha=0.9, edgecolors='black', linewidths=2, ax=ax)

        if filtered_G.number_of_nodes() < 150:
            labelled_nodes = list(filtered_G.nodes())
            font_size = 7
        else:
            labelled_nodes = [n for n in filtered_G.nodes()
                              if filtered_G.nodes[n].get('node_type') in ['rulebook', 'case']]
            font_size = 8
        labels = {node: filtered_G.nodes[node].get('label', str(node))[:30]
                  for node in labelled_nodes}
        nx.draw_networkx_labels(filtered_G, pos, labels, font_size=font_size,
                                font_weight='bold', ax=ax)

        node_legend = [plt.Line2D([0], [0], marker='o', color='w',
                      markerfacecolor=color, markersize=14, label=ntype.title())
                      for ntype, color in color_map.items()]
        first_legend = ax.legend(handles=node_legend, loc='upper left', fontsize=10,
                                title='Node Types', title_fontsize=11)
        if relation_legend:
            # A second ax.legend() call replaces the first, so keep the node
            # legend alive as a plain artist.
            ax.add_artist(first_legend)
            ax.legend(handles=relation_legend, fontsize=10, loc='upper right',
                      title='Relationships', title_fontsize=11)
        ax.set_title(f"{title} - Rulebook Integration\nNodes: {filtered_G.number_of_nodes()} | "
                    f"Relationships: {filtered_G.number_of_edges()}",
                    fontsize=16, fontweight='bold', pad=20)
        ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)
    def export_for_gnn(self, output_dir: str = "gnn_data"):
        """Write ``self.global_knowledge_graph`` to ``output_dir`` as GNN inputs.

        Files produced:
            node_features.npy        (num_nodes, 7) array, one row per node in
                                     ``G.nodes()`` order
            node_metadata.json       id, label, type and remaining attributes
                                     of every node
            edge_list.npy            (num_edges, 2) array of node row indices
            edge_features.npy        (num_edges, 4) array: weight plus flags
                                     for cites_provision / superior_to /
                                     governed_by
            edge_types.json          relation name per edge, aligned with
                                     edge_list
            graph_stats.json         counts, type distributions, average degree
            knowledge_graph.gpickle  pickle of the graph (best effort)

        Afterwards ``self._export_pyg_format`` is attempted; a failure there is
        logged and does not abort the export.

        Args:
            output_dir: Target directory; created (with parents) if missing.
        """
        logger.info("Exporting data for GNN...")
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        G = self.global_knowledge_graph
        node_list = list(G.nodes())
        node_to_idx = {node: idx for idx, node in enumerate(node_list)}

        node_features = []
        node_metadata = []
        for node in node_list:
            attrs = G.nodes[node]
            # Metrics or counters a node does not carry default to 0.
            node_features.append([
                attrs.get('size', 1500) / 10000, attrs.get('degree_centrality', 0),
                attrs.get('pagerank', 0), attrs.get('betweenness', 0),
                attrs.get('citation_count', 0) / 100,
                attrs.get('mention_count', 0) / 100,
                attrs.get('hierarchy_level', 2) / 2])
            node_metadata.append({
                'id': node, 'label': attrs.get('label', ''),
                'type': attrs.get('node_type', ''),
                'attributes': {k: v for k, v in attrs.items() if k != 'label'}})
        # reshape keeps the array two-dimensional even when the graph is empty.
        np.save(output_path / 'node_features.npy',
                np.array(node_features, dtype=float).reshape(-1, 7))
        with open(output_path / 'node_metadata.json', 'w', encoding='utf-8') as f:
            # default=str: attribute values json cannot encode are written as
            # their string form instead of aborting the export.
            json.dump(node_metadata, f, indent=2, default=str)

        edge_list = []
        edge_features = []
        edge_types = []
        for u, v, attrs in G.edges(data=True):
            relation = attrs.get('relation')
            edge_list.append([node_to_idx[u], node_to_idx[v]])
            edge_features.append([
                attrs.get('weight', 1.0),
                1 if relation == 'cites_provision' else 0,
                1 if relation == 'superior_to' else 0,
                1 if relation == 'governed_by' else 0])
            edge_types.append(attrs.get('relation', 'unknown'))
        np.save(output_path / 'edge_list.npy',
                np.array(edge_list, dtype=np.int64).reshape(-1, 2))
        np.save(output_path / 'edge_features.npy',
                np.array(edge_features, dtype=float).reshape(-1, 4))
        with open(output_path / 'edge_types.json', 'w', encoding='utf-8') as f:
            json.dump(edge_types, f, default=str)

        stats = {
            'num_nodes': G.number_of_nodes(), 'num_edges': G.number_of_edges(),
            'node_type_distribution': dict(Counter(G.nodes[n].get('node_type', 'unknown')
                                                   for n in G.nodes())),
            'edge_type_distribution': dict(Counter(edge_types)),
            'avg_degree': sum(dict(G.degree()).values()) / max(G.number_of_nodes(), 1)}
        with open(output_path / 'graph_stats.json', 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=2, default=str)

        try:
            import pickle
            # Pickle files execute code when loaded; only load this file back
            # from a trusted location.
            with open(output_path / 'knowledge_graph.gpickle', 'wb') as f:
                pickle.dump(G, f)
            logger.info("Graph saved as knowledge_graph.gpickle")
        except Exception as e:
            logger.warning(f"Failed to save graph with pickle: {e}")
        try:
            self._export_pyg_format(output_path, node_to_idx)
        except Exception as e:
            logger.warning(f"PyG export failed: {e}")
        logger.info(f"GNN data exported to {output_path}")
    def _export_pyg_format(self, output_path: Path, node_to_idx: Dict):
        """Bundle the saved ``.npy`` arrays into a PyTorch Geometric ``Data`` file.

        Reads ``node_features.npy``, ``edge_list.npy`` and
        ``edge_features.npy`` from ``output_path`` and writes ``pyg_data.pt``
        next to them. When ``torch`` or ``torch_geometric`` cannot be imported
        a warning is logged and nothing is written.

        Args:
            output_path: Directory holding the ``.npy`` files.
            node_to_idx: Not used here; kept so existing callers keep working.
        """
        # Only the optional imports are guarded, so problems while loading or
        # saving the arrays surface to the caller.
        try:
            import torch
            from torch_geometric.data import Data
        except ImportError:
            logger.warning("PyTorch or PyTorch Geometric not installed, skipping")
            return

        x = torch.tensor(np.load(output_path / 'node_features.npy'), dtype=torch.float)
        # reshape(-1, 2) turns an edge list saved from an edge-less graph into
        # (0, 2), so the transposed index is (2, 0) rather than one-dimensional.
        edge_pairs = np.load(output_path / 'edge_list.npy').reshape(-1, 2)
        edge_index = torch.tensor(edge_pairs.T, dtype=torch.long).contiguous()
        edge_attr = torch.tensor(np.load(output_path / 'edge_features.npy'), dtype=torch.float)
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
        torch.save(data, output_path / 'pyg_data.pt')
        logger.info("PyTorch Geometric format exported")
    def generate_comprehensive_report(self):
        """Log a text report on the loaded corpus and the unified knowledge graph.

        Everything is written through the module-level ``logger`` at INFO level;
        nothing is returned. The report always contains the case counts held in
        ``self.stats`` (total, per court, and the ten most frequent years). When
        ``self.global_knowledge_graph`` is non-empty it additionally contains:

        * node/edge totals and the average degree,
        * node counts grouped by the ``node_type`` attribute and edge counts
          grouped by the ``relation`` attribute,
        * the ten ``act`` nodes and the ten ``provision`` nodes with the highest
          ``citation_count`` attribute,
        * provision totals: all provisions, the provisions whose ``validated``
          attribute is truthy, and those validated provisions grouped by their
          ``rulebook`` attribute.
        """
        logger.info("\n")
        logger.info("Knowledge Graph Analysis Report:")
        logger.info("\nOverall Statistics:")
        logger.info(f"Total Cases: {self.stats['total_cases']:,}")
        logger.info("Cases By Court:")
        for court, count in sorted(self.stats['by_court'].items()):
            logger.info(f"  {court.replace('_', ' ').title()}: {count:,}")
        logger.info("Cases By Year (Top 10):")
        for year, count in Counter(self.stats['by_year']).most_common(10):
            logger.info(f"  {year}: {count:,}")

        G = self.global_knowledge_graph
        # A missing or empty graph is falsy; skipping it here also keeps the
        # average-degree computation below from dividing by zero.
        if G:
            def citation_count(item):
                # item is a (node, attribute-dict) pair; uncited nodes sort as 0.
                return item[1].get('citation_count', 0)

            # Single pass over the nodes: tally types and keep the two node
            # families that are ranked further down.
            node_types = Counter()
            acts = []
            provisions = []
            for node, attrs in G.nodes(data=True):
                node_type = attrs.get('node_type')
                node_types['unknown' if node_type is None else node_type] += 1
                if node_type == 'act':
                    acts.append((node, attrs))
                elif node_type == 'provision':
                    provisions.append((node, attrs))

            logger.info("\nUnified Knowledge Graph:")
            logger.info(f"  Total Nodes: {G.number_of_nodes():,}")
            logger.info(f"  Total Edges: {G.number_of_edges():,}")
            logger.info(f"  Avg Degree: {sum(dict(G.degree()).values()) / G.number_of_nodes():.2f}")

            logger.info("\nNode Type Distribution:")
            for ntype, count in node_types.most_common():
                logger.info(f"  {ntype.replace('_', ' ').title()}: {count:,}")

            logger.info("\nEdge Type Distribution:")
            edge_types = Counter(attrs.get('relation', 'unknown') for _, _, attrs in G.edges(data=True))
            for etype, count in edge_types.most_common():
                logger.info(f"  {etype.replace('_', ' ').title()}: {count:,}")

            logger.info("\nTop Cited Acts:")
            # sorted() is stable, so ties keep graph insertion order.
            for act, attrs in sorted(acts, key=citation_count, reverse=True)[:10]:
                logger.info(f"  {attrs.get('label', act)}: {attrs.get('citation_count', 0)} citations")

            logger.info("\nTop Cited Provisions:")
            for prov, attrs in sorted(provisions, key=citation_count, reverse=True)[:10]:
                count = attrs.get('citation_count', 0)
                label = attrs.get('label', prov)
                check_mark = " ✓" if attrs.get('validated', False) else ""
                parent_act = attrs.get('parent_act', '')
                rulebook = attrs.get('rulebook', '')
                # Prefer the parent act as the source tag; fall back to the
                # rulebook name, and print no tag when neither is known.
                if parent_act and parent_act != 'Unknown':
                    source = f" [{parent_act}]"
                elif rulebook:
                    source = f" [{rulebook}]"
                else:
                    source = ""
                logger.info(f"  {label}{check_mark}: {count} citations{source}")

            logger.info("\nRulebook Integration Statistics:")
            total_provisions = len(provisions)
            validated_attrs = [attrs for _, attrs in provisions if attrs.get('validated', False)]
            validated_provisions = len(validated_attrs)
            rulebook_stats = Counter(attrs.get('rulebook', 'Unknown') for attrs in validated_attrs)
            logger.info(f"  Total Provisions Extracted: {total_provisions}")
            # The validated count was previously computed but never reported.
            if total_provisions:
                logger.info(f"  Validated Provisions: {validated_provisions} "
                            f"({validated_provisions / total_provisions:.1%})")
            else:
                logger.info(f"  Validated Provisions: {validated_provisions}")
            if rulebook_stats:
                logger.info("\n  Provisions by Rulebook:")
                for rulebook, count in rulebook_stats.most_common():
                    logger.info(f"    {rulebook}: {count}")
        logger.info("\n")
    def step1_setup(self):
        """Pipeline step 1: load the official documents, then the case data.

        Calls ``load_official_documents()`` followed by ``load_all_cases()`` and
        logs a short summary: the total case count from ``self.stats``, the
        number of entries in ``self.cases_data`` (reported as courts), and the
        summed sizes of the collections in ``self.provision_index``.
        """
        for banner in ("\n", "Step 1: Setup and Data Loading -"):
            logger.info(banner)

        self.load_official_documents()
        self.load_all_cases()

        logger.info("\nStep 1 Completed.")
        case_total = self.stats['total_cases']
        court_total = len(self.cases_data)
        logger.info(f"   Loaded {case_total:,} cases from {court_total} courts")
        provision_total = sum(map(len, self.provision_index.values()))
        logger.info(f"   Indexed {provision_total} provisions from rulebooks")
    def step2_create_individual_graphs(self, num_cases_per_court: int = 10):
        """Pipeline step 2: build and visualize one knowledge graph per court.

        For each of the six courts listed below, a graph is created from
        ``num_cases_per_court`` randomly sampled cases, stored in
        ``self.court_graphs`` under the court key, and rendered through the
        seven visualization methods in a fixed order. ``self.court_graphs`` is
        reset at the start of every call.

        Args:
            num_cases_per_court: Number of cases requested per court.
        """
        logger.info("\n")
        logger.info("Step 2: Creating Individual Court Knowledge Graphs -")
        logger.info(f"Sampling {num_cases_per_court} random cases from each court...")

        court_keys = (
            'supreme_court', 'allahabad_high_court', 'bombay_high_court',
            'calcutta_high_court', 'delhi_high_court', 'madras_high_court',
        )
        self.court_graphs = {}

        for court_key in court_keys:
            logger.info("\n")
            logger.info(f"Processing: {court_key.upper()}")
            graph = self.create_court_knowledge_graph(
                court_key, num_cases=num_cases_per_court, random_sample=True)
            self.court_graphs[court_key] = graph

            # Human-readable court name shared by every plot title.
            display_name = court_key.replace('_', ' ').title()
            sample_title = f"{display_name} - {num_cases_per_court} Random Cases"

            # (progress message, renderer) pairs, executed strictly in order.
            stages = (
                ("\n[1/7] Creating circular layout visualization...",
                 lambda: self.visualize_knowledge_graph(graph, sample_title, figsize=(26, 20))),
                ("[2/7] Creating hierarchical layout visualization...",
                 lambda: self.visualize_hierarchical_graph(graph, sample_title, figsize=(28, 22))),
                ("[3/7] Creating statistics dashboard...",
                 lambda: self.visualize_graph_statistics_panel(graph, display_name)),
                ("[4/7] Creating citation network subgraph...",
                 lambda: self.visualize_subgraph_citation_network(graph, display_name)),
                ("[5/7] Creating statutory framework subgraph...",
                 lambda: self.visualize_subgraph_statutory_framework(graph, display_name)),
                ("[6/7] Creating concept-doctrine network subgraph...",
                 lambda: self.visualize_subgraph_concept_doctrine_network(graph, display_name)),
                ("[7/7] Creating rulebook integration subgraph...",
                 lambda: self.visualize_rulebook_integration_graph(graph, display_name)),
            )
            for announcement, render in stages:
                logger.info(announcement)
                render()

            logger.info(f"Completed {court_key}: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
            gc.collect()

        logger.info("\n")
        logger.info("Step 2 Completed.")
        graph_total = len(self.court_graphs)
        logger.info(f"   Created {graph_total} court-specific knowledge graphs")
        logger.info(f"   Generated {graph_total * 7} visualizations")
    def step3_create_unified_graph(self, max_cases_per_court: Optional[int] = None,
                                   skip_large_visualizations: bool = False,
                                   large_graph_threshold: int = 10000):
        """Pipeline step 3: build the unified knowledge graph and visualize it.

        Args:
            max_cases_per_court: Passed through to
                ``create_unified_knowledge_graph``. A falsy value (``None`` or
                ``0``) is announced as "processing all cases".
            skip_large_visualizations: When true, a graph with more than
                ``large_graph_threshold`` nodes only gets the statistics
                dashboard and the rulebook integration plot instead of all
                seven visualizations.
            large_graph_threshold: Node count above which a graph counts as
                "large" for ``skip_large_visualizations``. Defaults to 10000,
                the value that used to be hard-coded.
        """
        logger.info("\n")
        logger.info("Step 3: Creating Unified Knowledge Graph -")
        if max_cases_per_court:
            logger.warning(f"Limited to {max_cases_per_court} cases per court")
            # NOTE(review): 6 is presumably the number of courts handled by
            # create_unified_knowledge_graph - confirm against that method.
            logger.warning(f"Total cases: ~{max_cases_per_court * 6} (instead of {self.stats['total_cases']:,})")
        else:
            logger.info(f"Processing all {self.stats['total_cases']:,} cases...")

        G_unified = self.create_unified_knowledge_graph(max_cases_per_court=max_cases_per_court)
        logger.info("\n")
        logger.info("Visualizing Unified Graph...")

        # Decide once whether to fall back to the reduced plot set, so the
        # branch below and the final summary can never disagree.
        node_count = G_unified.number_of_nodes()
        reduced_plots = skip_large_visualizations and node_count > large_graph_threshold

        if reduced_plots:
            logger.warning(f"Graph has {node_count:,} nodes, skipping large visualizations")
            logger.info("Generating statistics dashboard only...")
            self.visualize_graph_statistics_panel(G_unified, "Unified Graph (All Courts)")
            logger.info("Creating focused subgraph visualizations...")
            self.visualize_rulebook_integration_graph(G_unified, "Unified Graph")
        else:
            logger.info("\n[1/7] Creating circular layout...")
            self.visualize_knowledge_graph(G_unified, "Unified Indian Legal Knowledge Graph (All Courts)",
                                          figsize=(32, 26))
            logger.info("[2/7] Creating hierarchical layout...")
            self.visualize_hierarchical_graph(G_unified, "Unified Indian Legal Knowledge Graph (All Courts)",
                                             figsize=(34, 28))
            logger.info("[3/7] Creating statistics dashboard...")
            self.visualize_graph_statistics_panel(G_unified, "Unified Graph (All Courts)")
            logger.info("[4/7] Creating unified citation network...")
            self.visualize_subgraph_citation_network(G_unified, "Unified Graph")
            logger.info("[5/7] Creating unified statutory framework...")
            self.visualize_subgraph_statutory_framework(G_unified, "Unified Graph")
            logger.info("[6/7] Creating unified concept-doctrine network...")
            self.visualize_subgraph_concept_doctrine_network(G_unified, "Unified Graph")
            logger.info("[7/7] Creating unified rulebook integration...")
            self.visualize_rulebook_integration_graph(G_unified, "Unified Graph")

        logger.info("\nStep 3 Completed.")
        logger.info(f"   Unified graph: {G_unified.number_of_nodes():,} nodes, {G_unified.number_of_edges():,} edges")
        if reduced_plots:
            logger.info("   Generated 2 visualizations (limited for large graph)")
        else:
            logger.info("   Generated 7 comprehensive visualizations")
    def step4_export_for_gnn(self, output_dir: str = "gnn_data"):
        """Pipeline step 4: export the unified graph via ``export_for_gnn``.

        Logs an error and returns without exporting when
        ``self.global_knowledge_graph`` is missing or has no nodes.

        Args:
            output_dir: Directory name forwarded to ``export_for_gnn``.
        """
        logger.info("\n")
        logger.info("Step 4: Exporting for GNN Processing -")

        unified = self.global_knowledge_graph
        has_nodes = bool(unified) and unified.number_of_nodes() != 0
        if has_nodes:
            self.export_for_gnn(output_dir=output_dir)
            logger.info("\nStep 4 Completed.")
            logger.info(f"   GNN-ready datasets exported to '{output_dir}/' directory")
        else:
            logger.error("No unified graph found.")
    def step5_generate_analysis_report(self):
        """Pipeline step 5: log the analysis report for the unified graph.

        Delegates to ``generate_comprehensive_report`` and then logs the
        closing "All Steps Completed." line. Logs an error and returns early
        when ``self.global_knowledge_graph`` is missing or has no nodes.
        """
        for banner in ("\n", "Step 5: Generating Analysis Report -"):
            logger.info(banner)

        unified = self.global_knowledge_graph
        graph_ready = bool(unified) and unified.number_of_nodes() != 0
        if not graph_ready:
            logger.error("No unified graph found.")
            return

        self.generate_comprehensive_report()
        for closing_line in (
            "\nStep 5 Completed.",
            "   Comprehensive analysis report generated",
            "\nAll Steps Completed.",
        ):
            logger.info(closing_line)
if __name__ == "__main__":
    # Input locations, given as relative paths: the OWL ontology file, the
    # directory of processed case data, and the directory of official documents.
    OWL_PATH = "IndiLegalOnt.owl"
    PROCESSED_DIR = "dataset_processed"
    RULES_DIR = "official_documents"

    # `pipeline` is the object that every later cell calls its step methods on.
    pipeline = LegalKnowledgeGraphPipeline(
        owl_path=OWL_PATH,
        processed_dir=PROCESSED_DIR,
        rules_dir=RULES_DIR,
        cache_dir="entity_cache",
    )

2025-11-06 09:36:39,478 - INFO - Loading ontology from IndiLegalOnt.owl...
2025-11-06 09:36:40,194 - INFO - Ontology loaded: http://lmss.sali.org/


In [2]:
# Step 1: load the official documents and the case files
# (step1_setup calls load_official_documents() and load_all_cases()).
pipeline.step1_setup()
# Binding the result to `_` keeps gc.collect()'s return value out of the cell output.
_ = gc.collect()

2025-11-06 09:36:44,180 - INFO - 

2025-11-06 09:36:44,181 - INFO - Step 1: Setup and Data Loading -
2025-11-06 09:36:44,182 - INFO - Loading official legal documents from PDFs...
2025-11-06 09:36:44,184 - INFO - Loading Indian Penal Code.pdf...
2025-11-06 09:36:44,335 - INFO -   Reading 119 pages...
2025-11-06 09:36:46,920 - INFO - Loading Code of Criminal Procedure.pdf...
2025-11-06 09:36:46,945 - INFO -   Reading 263 pages...
2025-11-06 09:36:52,718 - INFO - Loading Constitution of India.pdf...
2025-11-06 09:36:52,740 - INFO -   Reading 402 pages...
2025-11-06 09:37:01,330 - INFO - Loading Indian Evidence Act.pdf...
2025-11-06 09:37:01,334 - INFO -   Reading 60 pages...
2025-11-06 09:37:02,713 - INFO - Indexed 1684 provisions
2025-11-06 09:37:02,714 - INFO -   Indian Penal Code: 498 provisions
2025-11-06 09:37:02,714 - INFO -   Code of Criminal Procedure: 493 provisions
2025-11-06 09:37:02,714 - INFO -   Constitution of India: 497 provisions
2025-11-06 09:37:02,714 - INFO -   Indian

In [None]:
# Step 2, small run: one graph per court from 10 randomly sampled cases each,
# followed by the seven visualizations per court.
pipeline.step2_create_individual_graphs(num_cases_per_court=10)
_ = gc.collect()

In [None]:
# Step 2, larger run: 100 randomly sampled cases per court.
# This call resets pipeline.court_graphs, replacing the graphs from the 10-case run.
pipeline.step2_create_individual_graphs(num_cases_per_court=100)
_ = gc.collect()

In [None]:
# Step 3, capped run: unified graph limited to 250 cases per court.
# skip_large_visualizations=False means all seven visualizations are drawn.
pipeline.step3_create_unified_graph(max_cases_per_court=250, skip_large_visualizations=False)
_ = gc.collect()

In [None]:
# Step 3, full run: no per-court cap. With skip_large_visualizations=True, a graph
# of more than 10,000 nodes only gets the statistics dashboard and the
# rulebook integration plot.
pipeline.step3_create_unified_graph(max_cases_per_court=None, skip_large_visualizations=True)
_ = gc.collect()

In [7]:
# Step 4: export the unified graph to gnn_data/. step4_export_for_gnn logs an error
# and returns if pipeline.global_knowledge_graph is empty, so a step 3 cell
# presumably has to run first.
pipeline.step4_export_for_gnn(output_dir="gnn_data")
_ = gc.collect()

2025-11-06 10:15:35,981 - INFO - 

2025-11-06 10:15:35,983 - INFO - Step 4: Exporting for GNN Processing -
2025-11-06 10:15:35,985 - INFO - Exporting data for GNN...
2025-11-06 10:15:41,043 - INFO - Graph saved as knowledge_graph.gpickle
2025-11-06 10:15:43,068 - INFO - PyTorch Geometric format exported
2025-11-06 10:15:43,069 - INFO - GNN data exported to gnn_data
2025-11-06 10:15:43,227 - INFO - 
Step 4 Completed.
2025-11-06 10:15:43,228 - INFO -    GNN-ready datasets exported to 'gnn_data/' directory


In [8]:
# Step 5: log the analysis report for the unified graph (written through the logger,
# so the report appears as log lines in the cell output).
pipeline.step5_generate_analysis_report()

2025-11-06 10:15:52,893 - INFO - 

2025-11-06 10:15:52,896 - INFO - Step 5: Generating Analysis Report -
2025-11-06 10:15:52,897 - INFO - 

2025-11-06 10:15:52,899 - INFO - Knowledge Graph Analysis Report:
2025-11-06 10:15:52,900 - INFO - 
Overall Statistics:
2025-11-06 10:15:52,901 - INFO - Total Cases: 56,025
2025-11-06 10:15:52,901 - INFO - Cases By Court:
2025-11-06 10:15:52,902 - INFO -   Allahabad High Court: 8,398
2025-11-06 10:15:52,902 - INFO -   Bombay High Court: 9,436
2025-11-06 10:15:52,903 - INFO -   Calcutta High Court: 9,584
2025-11-06 10:15:52,904 - INFO -   Delhi High Court: 8,821
2025-11-06 10:15:52,904 - INFO -   Madras High Court: 9,807
2025-11-06 10:15:52,904 - INFO -   Supreme Court: 9,979
2025-11-06 10:15:52,905 - INFO - Cases By Year (Top 10):
2025-11-06 10:15:52,905 - INFO -   2013: 2,400
2025-11-06 10:15:52,905 - INFO -   2010: 2,396
2025-11-06 10:15:52,905 - INFO -   2003: 2,395
2025-11-06 10:15:52,906 - INFO -   2023: 2,394
2025-11-06 10:15:52,906 - INFO - 