In [None]:
# !pip install langchain sentence-transformers
# !conda install -c conda-forge rdkit
import json
import numpy as np
import pandas as pd
import anthropic
import re
import time
import random
from typing import List, Dict, Tuple, Optional
import logging
import os
import math
from collections import Counter

# LangChain imports (로컬 임베딩 사용)
try:
    from langchain.vectorstores import FAISS
    from langchain.embeddings import HuggingFaceEmbeddings
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain.schema import Document
    LANGCHAIN_AVAILABLE = True
    print("✅ LangChain loaded successfully!")
except ImportError:
    LANGCHAIN_AVAILABLE = False
    print("⚠️ LangChain not available. Install with: pip install langchain sentence-transformers")

# RDKit imports for molecular similarity
try:
    from rdkit import Chem, DataStructs
    from rdkit.Chem import AllChem, MACCSkeys, Descriptors
    from rdkit.Chem.Fingerprints import FingerprintMols
    RDKIT_AVAILABLE = True
    print("✅ RDKit loaded successfully!")
except ImportError:
    RDKIT_AVAILABLE = False
    print("⚠️ RDKit not available. Install with: conda install -c conda-forge rdkit")

# 로깅 설정
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class DataParser:
    """Split a raw prompt string into its experimental-condition and SMILES parts."""

    # Patterns are compiled once, when the class is defined, and reused on every call.
    _SMILES_RE = re.compile(r'SMILES:\s*([^\n\r]+)')
    _ASSAY_RE = re.compile(r'Assay:\s*([^\n\r]+)')
    # From the first "TOX21_" up to and including the next period; the negated
    # character class also spans newlines.
    _ASSAY_DESC_RE = re.compile(r'(TOX21_[^\.]+[^\.]*\.)', re.DOTALL)
    # Lazily everything from the task preamble up to the first "SMILES:" (or end).
    _INSTRUCTION_RE = re.compile(r'(Given an Assay and SMILES.*?)(?:SMILES:|$)', re.DOTALL)

    @staticmethod
    def _first_group(pattern, text: str) -> str:
        """Return the whitespace-stripped first capture group of *pattern* in *text*, or ''."""
        found = pattern.search(text)
        if found is None:
            return ""
        return found.group(1).strip()

    @staticmethod
    def parse_input_text(input_text: str) -> Dict:
        """Extract the prompt components from *input_text*.

        Returns:
            A dict with keys 'smiles', 'assay_name', 'assay_description',
            'instruction' (each '' when absent) and 'full_text' (the input as given).
            If parsing raises (e.g. a non-string input), the error is logged and
            every extracted field is ''.
        """
        try:
            grab = DataParser._first_group
            fields = {
                'smiles': grab(DataParser._SMILES_RE, input_text),
                'assay_name': grab(DataParser._ASSAY_RE, input_text),
                'assay_description': grab(DataParser._ASSAY_DESC_RE, input_text),
                'instruction': grab(DataParser._INSTRUCTION_RE, input_text),
            }
        except Exception as e:
            logger.error(f"Error parsing input text: {e}")
            fields = dict.fromkeys(('smiles', 'assay_name', 'assay_description', 'instruction'), '')

        fields['full_text'] = input_text
        return fields

class HybridSMILESRAG:
    """
    Hybrid RAG system: natural-language assay retrieval (LangChain) + chemical
    similarity (RDKit), fused into a single prompt for a Claude model.

    Features:
    - Assay-condition similarity: FAISS store over local HuggingFace embeddings
    - Chemical similarity: weighted Tanimoto over several RDKit fingerprints
    - Adaptive fusion: per-query context weights derived from retrieval results
    - LLM reasoning: the configured Claude model (``self.model``) sees both contexts
    """

    # Default retrieval depths. calculate_context_weights normalises assay
    # "availability" against DEFAULT_K_ASSAY, so the two stay in sync.
    DEFAULT_K_ASSAY = 3
    DEFAULT_K_CHEMICAL = 5

    # Relative contribution of each fingerprint type to the combined chemical
    # similarity; iteration order is also the order fingerprints are compared in.
    FINGERPRINT_WEIGHTS = {'morgan': 0.4, 'maccs': 0.3, 'rdkit': 0.2, 'atompair': 0.1}

    # USD per 1K tokens ($3 / $15 per 1M input / output tokens).
    INPUT_COST_PER_1K = 0.003
    OUTPUT_COST_PER_1K = 0.015

    # Value returned when the LLM call fails or no number can be parsed.
    DEFAULT_PREDICTION = 50

    _FINAL_PREDICTION_MARKER_RE = re.compile(r'FINAL\s+PREDICTION', re.IGNORECASE)
    # Each pattern allows up to 10 non-word characters (e.g. "**: ") between the
    # label and the number, because the prompt asks for "**FINAL PREDICTION**: N".
    _PREDICTION_PATTERNS = (
        re.compile(r'FINAL\s+PREDICTION\W{0,10}?(\d{1,3})(?!\d)', re.IGNORECASE),
        re.compile(r'\bPREDICTION\W{0,10}?(\d{1,3})(?!\d)', re.IGNORECASE),
        re.compile(r'\bPredicted\s+LogAC50\W{0,10}?(\d{1,3})(?!\d)', re.IGNORECASE),
    )
    # A standalone integer: not part of a longer number nor of a decimal like "0.40".
    _BARE_INT_RE = re.compile(r'(?<![\d.])(\d{1,3})(?!\d|\.\d)')

    def __init__(self, claude_api_key: Optional[str] = None, model: str = "claude-sonnet-4-20250514", temperature: float = 0.1):
        """Create the predictor.

        Args:
            claude_api_key: Anthropic API key. A falsy value (None or "") falls back
                to the ANTHROPIC_API_KEY environment variable.
            model: Claude model id used for every prediction.
            temperature: sampling temperature passed to the Messages API.

        Raises:
            ValueError: if no API key is available.
        """
        # Validate before constructing the client so callers get this explicit
        # error instead of whatever the SDK raises for a missing key.
        api_key = claude_api_key or os.getenv("ANTHROPIC_API_KEY")
        if not api_key:
            raise ValueError("Claude API key not found. Set ANTHROPIC_API_KEY environment variable.")
        self.claude_client = anthropic.Anthropic(api_key=api_key)

        self.model = model
        self.temperature = temperature

        # Data stores
        self.train_data = []
        self.parsed_train_data = []

        # LangChain components (local embeddings). The store attribute always
        # exists so later checks never depend on LANGCHAIN_AVAILABLE alone.
        self.assay_vectorstore = None
        if LANGCHAIN_AVAILABLE:
            self.embeddings = HuggingFaceEmbeddings(
                model_name="all-MiniLM-L6-v2",
                model_kwargs={'device': 'cpu'}
            )
            logger.info("🔤 Using local HuggingFace embeddings for assay similarity")
        else:
            self.embeddings = None
            logger.warning("⚠️ LangChain not available, assay similarity disabled")

        # RDKit caches keyed by SMILES
        self.mol_objects = {}
        self.fingerprints = {}

        # Token / cost tracking
        self.cost_tracker = {
            'input_tokens': 0,
            'output_tokens': 0,
            'total_cost': 0.0,
            'api_calls': 0
        }

        logger.info("🔬 Hybrid RAG System initialized")

    def load_jsonl_data(self, file_path: str) -> List[Dict]:
        """Load one JSON object per line from *file_path* (blank lines are skipped).

        Raises:
            OSError / json.JSONDecodeError: re-raised after logging.
        """
        data = []
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue  # tolerate blank / trailing empty lines
                    data.append(json.loads(line))
        except Exception as e:
            logger.error(f"Error loading data: {e}")
            raise

        logger.info(f"Loaded {len(data)} samples from {file_path}")
        return data

    def simple_train_test_split(self, data: List[Dict], test_size: float = 0.2, random_state: int = 42) -> Tuple[List[Dict], List[Dict]]:
        """Shuffle *data* reproducibly and split it into (train, test).

        Seeds the global ``random`` and ``numpy.random`` generators with
        *random_state*, so later global random draws are reproducible too.

        Raises:
            ValueError: if *test_size* is outside [0, 1].
        """
        if not 0.0 <= test_size <= 1.0:
            raise ValueError(f"test_size must be in [0, 1], got {test_size}")

        random.seed(random_state)
        np.random.seed(random_state)

        shuffled_data = list(data)
        random.shuffle(shuffled_data)

        split_idx = int(len(shuffled_data) * (1 - test_size))
        train_data = shuffled_data[:split_idx]
        test_data = shuffled_data[split_idx:]

        logger.info(f"Train set: {len(train_data)} samples")
        logger.info(f"Test set: {len(test_data)} samples")

        return train_data, test_data

    def prepare_hybrid_training_data(self, train_data: List[Dict]):
        """Parse *train_data*, compute RDKit features and build the assay vector store.

        Each item must provide 'input_text' and an integer-parsable 'output_text'.
        Populates ``self.parsed_train_data`` (one dict per item, in order), the
        SMILES-keyed caches, and ``self.assay_vectorstore`` (one document per item
        that has an assay description; reset to None when there are none).
        """
        logger.info("🏗️ Preparing hybrid training data...")

        self.train_data = train_data
        self.parsed_train_data = []
        self.assay_vectorstore = None  # drop any store left from a previous call
        assay_documents = []
        n_total = len(train_data)

        for idx, item in enumerate(train_data):
            # 1. Text parsing
            parsed = DataParser.parse_input_text(item['input_text'])
            parsed['logac50'] = int(item['output_text'])
            parsed['idx'] = idx
            parsed['activity_category'] = self._categorize_activity(parsed['logac50'])

            # 2. Chemical structure (RDKit)
            if RDKIT_AVAILABLE and parsed['smiles']:
                mol = self._create_mol_object(parsed['smiles'])
                if mol is not None:
                    fps = self._generate_fingerprints(mol, parsed['smiles'])
                    parsed['mol'] = mol
                    parsed['fingerprints'] = fps
                    parsed['molecular_props'] = self._calculate_molecular_properties(mol)

                    self.mol_objects[parsed['smiles']] = mol
                    self.fingerprints[parsed['smiles']] = fps

            # 3. Assay document (LangChain)
            if LANGCHAIN_AVAILABLE and parsed['assay_description']:
                assay_documents.append(Document(
                    page_content=self._build_assay_document_text(parsed),
                    metadata={
                        'assay_name': parsed['assay_name'],
                        'logac50': parsed['logac50'],
                        'idx': idx
                    }
                ))

            self.parsed_train_data.append(parsed)

            if (idx + 1) % 100 == 0:
                logger.info(f"Processed {idx + 1}/{n_total} samples...")

        # 4. Vector store: one vector per sample. Documents are short, so they are
        # not split; splitting would separate the Activity line from the assay text.
        if LANGCHAIN_AVAILABLE and assay_documents:
            logger.info("🔤 Building assay vector store...")
            self.assay_vectorstore = FAISS.from_documents(assay_documents, self.embeddings)
            logger.info(f"✅ Built assay vector store with {len(assay_documents)} documents")

        logger.info(f"✅ Prepared {len(self.parsed_train_data)} training examples")

    @staticmethod
    def _build_assay_document_text(parsed: Dict) -> str:
        """Render the compact, unindented text embedded for one training sample.

        The generic task instruction is deliberately omitted: it is shared across
        samples (it would dominate a mean-pooled embedding) and the query's own
        instruction is already in the prompt.
        """
        return "\n".join([
            f"Assay: {parsed['assay_name']}",
            f"Description: {parsed['assay_description']}",
            f"Activity: {parsed['logac50']}",
            f"Category: {parsed['activity_category']}",
        ])

    def _categorize_activity(self, logac50: int) -> str:
        """Map an activity value to a coarse label (>=85, >=70, >=50, >=30, else)."""
        if logac50 >= 85:
            return "Very High"
        elif logac50 >= 70:
            return "High"
        elif logac50 >= 50:
            return "Medium"
        elif logac50 >= 30:
            return "Low"
        else:
            return "Very Low"

    def _create_mol_object(self, smiles: str):
        """Parse *smiles* into an RDKit Mol with explicit hydrogens, or return None.

        NOTE(review): explicit Hs are added before fingerprinting, which likely
        alters Morgan/atom-pair bits versus heavy-atom fingerprints. It is applied
        identically to queries and training molecules, so comparisons stay consistent.
        """
        if not RDKIT_AVAILABLE:
            return None

        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                return Chem.AddHs(mol)
            return None
        except Exception as e:
            logger.debug(f"Error creating mol object for {smiles}: {e}")
            return None

    def _generate_fingerprints(self, mol, smiles: str) -> Dict:
        """Compute the fingerprints named in FINGERPRINT_WEIGHTS for *mol*.

        Each type is computed independently, so one failure does not drop the
        others; failed types are simply absent from the returned dict.
        """
        if not RDKIT_AVAILABLE or mol is None:
            return {}

        generators = {
            'morgan': lambda m: AllChem.GetMorganFingerprintAsBitVect(m, radius=2, nBits=2048),
            'maccs': MACCSkeys.GenMACCSKeys,
            # Path-based RDKit fingerprint (2048 bits, paths up to 7 bonds by default).
            'rdkit': Chem.RDKFingerprint,
            'atompair': lambda m: AllChem.GetHashedAtomPairFingerprintAsBitVect(m, nBits=2048),
        }

        fingerprints = {}
        for fp_type, generate in generators.items():
            try:
                fingerprints[fp_type] = generate(mol)
            except Exception as e:
                logger.debug(f"Error generating {fp_type} fingerprint for {smiles}: {e}")

        return fingerprints

    def _calculate_molecular_properties(self, mol) -> Dict:
        """Return basic RDKit descriptors for *mol*, or {} if unavailable/failed."""
        if not RDKIT_AVAILABLE or mol is None:
            return {}

        try:
            return {
                'mw': Descriptors.MolWt(mol),
                'logp': Descriptors.MolLogP(mol),
                'hbd': Descriptors.NumHDonors(mol),
                'hba': Descriptors.NumHAcceptors(mol),
                'tpsa': Descriptors.TPSA(mol),
                'rotatable_bonds': Descriptors.NumRotatableBonds(mol),
                'aromatic_rings': Descriptors.NumAromaticRings(mol),
                'heavy_atoms': Descriptors.HeavyAtomCount(mol)
            }
        except Exception:
            return {}

    def hybrid_similarity_search(self, query_input: str, k_assay: int = DEFAULT_K_ASSAY, k_chemical: int = DEFAULT_K_CHEMICAL) -> Tuple[List[Dict], List[Dict]]:
        """Retrieve similar assays (text) and similar molecules (fingerprints).

        Returns:
            (similar_assays, similar_molecules); either list is empty when its
            backend is unavailable, the query lacks the needed field, or search fails.
        """
        parsed_query = DataParser.parse_input_text(query_input)
        similar_assays = self._search_similar_assays(parsed_query, k_assay)
        similar_molecules = self._search_similar_molecules(parsed_query, k_chemical)
        return similar_assays, similar_molecules

    @staticmethod
    def _distance_to_similarity(distance: float) -> float:
        """Convert a FAISS L2 distance into a similarity clamped to [0, 1].

        LangChain's FAISS store defaults to an IndexFlatL2 index, which reports
        *squared* Euclidean distance d². For unit-length embeddings d² = 2 - 2·cos,
        so cos = 1 - d²/2. Assumes the embeddings are normalised (TODO confirm for
        the configured model); the clamp keeps the value bounded regardless.
        """
        return max(0.0, min(1.0, 1.0 - float(distance) / 2.0))

    def _search_similar_assays(self, parsed_query: Dict, k_assay: int) -> List[Dict]:
        """Top-*k_assay* training documents by assay-text similarity."""
        if not (LANGCHAIN_AVAILABLE and self.assay_vectorstore is not None and parsed_query['assay_description']):
            return []

        similar_assays = []
        try:
            assay_query = f"{parsed_query['assay_name']} {parsed_query['assay_description']}"
            assay_docs = self.assay_vectorstore.similarity_search_with_score(assay_query, k=k_assay)

            for doc, distance in assay_docs:
                similar_assays.append({
                    'content': doc.page_content,
                    'metadata': doc.metadata,
                    'similarity_score': self._distance_to_similarity(distance),
                    'assay_name': doc.metadata.get('assay_name', ''),
                    'logac50': doc.metadata.get('logac50', 0)
                })

            logger.debug(f"Found {len(similar_assays)} similar assays")
        except Exception as e:
            logger.warning(f"Assay search failed: {e}")

        return similar_assays

    def _search_similar_molecules(self, parsed_query: Dict, k_chemical: int) -> List[Dict]:
        """Top-*k_chemical* training molecules by combined fingerprint similarity."""
        if not (RDKIT_AVAILABLE and parsed_query['smiles']):
            return []

        similar_molecules = []
        try:
            query_mol = self._create_mol_object(parsed_query['smiles'])
            if query_mol is None:
                return []
            query_fps = self._generate_fingerprints(query_mol, parsed_query['smiles'])
            if not query_fps:
                return []  # nothing to compare; avoid returning zero-score "matches"

            scored = []
            for example in self.parsed_train_data:
                target_fps = example.get('fingerprints')
                if not target_fps:
                    continue
                breakdown = self._calculate_multi_fingerprint_similarity(query_fps, target_fps)
                scored.append((self._combine_similarity_scores(breakdown), example, breakdown))

            # Stable sort: ties keep training-set order.
            scored.sort(key=lambda entry: entry[0], reverse=True)

            for sim_score, example, breakdown in scored[:k_chemical]:
                similar_molecules.append({
                    'smiles': example['smiles'],
                    'logac50': example['logac50'],
                    'activity_category': example.get('activity_category', ''),
                    'molecular_props': example.get('molecular_props', {}),
                    'similarity_score': sim_score,
                    'similarity_breakdown': breakdown,
                    'assay_name': example.get('assay_name', '')
                })

            logger.debug(f"Found {len(similar_molecules)} similar molecules")
        except Exception as e:
            logger.warning(f"Chemical search failed: {e}")

        return similar_molecules

    def _calculate_multi_fingerprint_similarity(self, query_fps: Dict, target_fps: Dict) -> Dict:
        """Tanimoto similarity per fingerprint type present in BOTH dicts.

        Types missing from either side (or whose comparison fails) are omitted
        rather than scored 0.0, so _combine_similarity_scores can renormalise
        over the types that were actually compared.
        """
        similarities = {}
        for fp_type in self.FINGERPRINT_WEIGHTS:
            if fp_type not in query_fps or fp_type not in target_fps:
                continue
            try:
                similarities[fp_type] = DataStructs.TanimotoSimilarity(query_fps[fp_type], target_fps[fp_type])
            except Exception as e:
                logger.debug(f"Error calculating {fp_type} similarity: {e}")
        return similarities

    def _combine_similarity_scores(self, similarity_scores: Dict) -> float:
        """Weighted mean of per-fingerprint similarities over the types present (0.0 if none)."""
        weighted_sum = 0.0
        total_weight = 0.0
        for fp_type, weight in self.FINGERPRINT_WEIGHTS.items():
            if fp_type in similarity_scores:
                weighted_sum += similarity_scores[fp_type] * weight
                total_weight += weight
        return weighted_sum / total_weight if total_weight > 0 else 0.0

    def calculate_context_weights(self, similar_assays: List[Dict], similar_molecules: List[Dict], k_assay: int = DEFAULT_K_ASSAY) -> Dict:
        """Decide how much the prompt should lean on assay vs chemical context.

        A context with no retrieved items gets weight 0 (both empty -> 0.5/0.5).
        Otherwise the best similarity on each side selects a preset split, falling
        back to a chemistry-leaning default adjusted by how many of *k_assay*
        assays were found.

        Returns:
            {'assay': float, 'chemical': float}, summing to 1.
        """
        if not similar_assays and not similar_molecules:
            return {'assay': 0.5, 'chemical': 0.5}
        if not similar_assays:
            return {'assay': 0.0, 'chemical': 1.0}
        if not similar_molecules:
            return {'assay': 1.0, 'chemical': 0.0}

        max_assay_sim = max(assay.get('similarity_score', 0) for assay in similar_assays)
        max_chem_sim = max(mol.get('similarity_score', 0) for mol in similar_molecules)

        # Fraction of the requested assay neighbours actually retrieved.
        assay_availability = min(1.0, len(similar_assays) / max(k_assay, 1))

        if max_assay_sim > 0.8 and max_chem_sim < 0.5:
            # Assay conditions very similar, chemistry different
            weights = {'assay': 0.75, 'chemical': 0.25}
        elif max_chem_sim > 0.8 and max_assay_sim < 0.5:
            # Chemistry very similar, assay conditions different
            weights = {'assay': 0.25, 'chemical': 0.75}
        elif max_assay_sim > 0.7 and max_chem_sim > 0.7:
            # Both similar: balance
            weights = {'assay': 0.5, 'chemical': 0.5}
        else:
            # Default split, shifted toward assays as more of them are available
            weights = {
                'assay': 0.4 + (assay_availability * 0.2),
                'chemical': 0.6 - (assay_availability * 0.2),
            }

        total = weights['assay'] + weights['chemical']
        return {k: v / total for k, v in weights.items()}

    @staticmethod
    def _fmt_number(value, spec: str) -> str:
        """Format *value* with *spec*, or return 'N/A' when it is not a real number."""
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return 'N/A'
        return format(value, spec)

    def _format_assay_context(self, similar_assays: List[Dict], weights: Dict) -> str:
        """Render the experimental-protocol section of the prompt."""
        if not similar_assays:
            return "🧪 EXPERIMENTAL PROTOCOL CONTEXT: No similar assays found.\n\n"
        parts = [f"🧪 EXPERIMENTAL PROTOCOL CONTEXT (Weight: {weights['assay']:.2f}):\n\n"]
        for i, assay in enumerate(similar_assays, 1):
            parts.append(f"Similar Assay {i} (Similarity: {assay['similarity_score']:.3f}):\n")
            parts.append(f"  {assay['content']}\n\n")
        return "".join(parts)

    def _format_chemical_context(self, similar_molecules: List[Dict], weights: Dict) -> str:
        """Render the chemical-structure section of the prompt."""
        if not similar_molecules:
            return "🧬 CHEMICAL STRUCTURE CONTEXT: No similar molecules found.\n\n"
        parts = [f"🧬 CHEMICAL STRUCTURE CONTEXT (Weight: {weights['chemical']:.2f}):\n\n"]
        for i, mol in enumerate(similar_molecules, 1):
            parts.append(f"Similar Molecule {i} (Tanimoto: {mol['similarity_score']:.3f}):\n")
            parts.append(f"  SMILES: {mol['smiles']}\n")
            parts.append(f"  LogAC50: {mol['logac50']}\n")
            parts.append(f"  Activity: {mol['activity_category']}\n")

            breakdown = mol.get('similarity_breakdown')
            if breakdown is not None:
                parts.append("  Fingerprint Details:\n")
                parts.append(f"    - Morgan: {breakdown.get('morgan', 0):.3f}\n")
                parts.append(f"    - MACCS: {breakdown.get('maccs', 0):.3f}\n")
                parts.append(f"    - RDKit: {breakdown.get('rdkit', 0):.3f}\n")
                parts.append(f"    - AtomPair: {breakdown.get('atompair', 0):.3f}\n")

            props = mol.get('molecular_props')
            if props:
                # _fmt_number avoids a ValueError when a descriptor is missing.
                parts.append(f"  Properties: MW={self._fmt_number(props.get('mw'), '.1f')}, "
                             f"LogP={self._fmt_number(props.get('logp'), '.2f')}\n")

            parts.append("\n")
        return "".join(parts)

    def create_hybrid_prompt(self, query_input: str, similar_assays: List[Dict], similar_molecules: List[Dict], weights: Dict) -> str:
        """Build the single user prompt combining both contexts and the query."""
        parsed_query = DataParser.parse_input_text(query_input)

        assay_context = self._format_assay_context(similar_assays, weights)
        chemical_context = self._format_chemical_context(similar_molecules, weights)

        priority = "experimental context" if weights['assay'] > weights['chemical'] else "chemical structure analysis"
        if weights['assay'] > 0.6:
            focus = "Focus on experimental protocol patterns and assay-specific factors"
        elif weights['chemical'] > 0.6:
            focus = "Focus on chemical structure-activity relationships"
        else:
            focus = "Balance both experimental and chemical contexts equally"

        n_assays = len(similar_assays)
        n_mols = len(similar_molecules)
        assay_evidence = "Strong" if n_assays >= 2 else "Moderate" if n_assays == 1 else "Weak"
        chemical_evidence = "Strong" if n_mols >= 3 else "Moderate" if n_mols >= 1 else "Weak"

        return f"""<thinking>
I am performing a hybrid analysis combining experimental protocol knowledge and chemical structure similarity for toxicity prediction.

Query Details:
- Assay: {parsed_query['assay_name']}
- SMILES: {parsed_query['smiles']}

Context Analysis:
- Assay context weight: {weights['assay']:.2f}
- Chemical context weight: {weights['chemical']:.2f}

This weighting suggests I should prioritize {priority} while considering both sources of information.

Let me analyze the patterns systematically...
</thinking>

{assay_context}

{chemical_context}

🎯 HYBRID TOXICITY PREDICTION TASK:

Query Input:
- Assay: {parsed_query['assay_name']}
- SMILES: {parsed_query['smiles']}
- Task: {parsed_query['instruction']}

📊 INTEGRATED ANALYSIS FRAMEWORK:

1. **Context Weighting Strategy**:
   - Experimental Protocol Weight: {weights['assay']:.2f}
   - Chemical Structure Weight: {weights['chemical']:.2f}

2. **Primary Analysis Focus**:
   {focus}

3. **Cross-Validation Approach**:
   - Compare patterns from both experimental and chemical contexts
   - Identify consistent vs conflicting predictions
   - Resolve conflicts using the higher-weighted context

4. **Evidence Integration**:
   - Experimental evidence: {assay_evidence}
   - Chemical evidence: {chemical_evidence}

🔬 REQUIRED ANALYSIS:

**EXPERIMENTAL CONTEXT ANALYSIS**:
[Analyze the experimental protocol patterns and assay-specific factors]

**CHEMICAL STRUCTURE ANALYSIS**:
[Analyze the molecular structure and chemical similarity patterns]

**INTEGRATED PREDICTION LOGIC**:
[Combine both contexts using the calculated weights]

**FINAL PREDICTION**: [INTEGER 0-100]

**CONFIDENCE ASSESSMENT**: [High/Medium/Low with justification based on context quality]

Remember: Weight your analysis according to the calculated context weights, but always provide reasoning from both experimental and chemical perspectives when available."""

    def predict_single(self, query_input: str) -> Tuple[int, str, float, Dict]:
        """Run retrieval, prompt construction and one Claude call for *query_input*.

        Returns:
            (prediction, response_text, elapsed_seconds, metadata). On any failure
            returns (DEFAULT_PREDICTION, "Error: ...", 0.0, {'error': message}) so
            callers can tell a fallback from a real prediction.
        """
        start_time = time.perf_counter()

        try:
            similar_assays, similar_molecules = self.hybrid_similarity_search(
                query_input, k_assay=self.DEFAULT_K_ASSAY, k_chemical=self.DEFAULT_K_CHEMICAL
            )
            weights = self.calculate_context_weights(
                similar_assays, similar_molecules, k_assay=self.DEFAULT_K_ASSAY
            )
            prompt = self.create_hybrid_prompt(query_input, similar_assays, similar_molecules, weights)

            response = self.claude_client.messages.create(
                model=self.model,
                max_tokens=3000,  # hybrid analysis needs a longer answer
                temperature=self.temperature,
                messages=[{"role": "user", "content": prompt}]
            )

            # Record usage first so the cost report is right even if parsing fails.
            self.cost_tracker['input_tokens'] += response.usage.input_tokens
            self.cost_tracker['output_tokens'] += response.usage.output_tokens
            self.cost_tracker['api_calls'] += 1

            # Join every text block; the first content block is not guaranteed to be text.
            result_text = "".join(
                block.text for block in response.content if getattr(block, 'type', None) == 'text'
            )
            if not result_text:
                raise ValueError("Claude response contained no text content")

            prediction = self._extract_prediction(result_text)
            elapsed_time = time.perf_counter() - start_time

            metadata = {
                'weights': weights,
                'n_similar_assays': len(similar_assays),
                'n_similar_molecules': len(similar_molecules),
                'max_assay_similarity': max((a.get('similarity_score', 0) for a in similar_assays), default=0),
                'max_chemical_similarity': max((m.get('similarity_score', 0) for m in similar_molecules), default=0)
            }

            logger.debug(f"Hybrid prediction: {prediction} (assay_weight: {weights['assay']:.2f}, chem_weight: {weights['chemical']:.2f})")

            return prediction, result_text, elapsed_time, metadata

        except Exception as e:
            logger.error(f"Error in hybrid prediction: {e}")
            return self.DEFAULT_PREDICTION, f"Error: {str(e)}", 0.0, {'error': str(e)}

    def _extract_prediction(self, result_text: str) -> int:
        """Pull the 0-100 integer prediction out of the model's response.

        Labelled patterns are tried in priority order, taking the LAST in-range
        match (the final answer comes after the analysis). Failing that, the first
        standalone in-range integer after the last "FINAL PREDICTION" marker (or in
        the whole text if there is no marker) is used; decimals such as "0.40" are
        ignored. Returns DEFAULT_PREDICTION when nothing qualifies.
        """
        for pattern in self._PREDICTION_PATTERNS:
            for match in reversed(list(pattern.finditer(result_text))):
                value = int(match.group(1))
                if 0 <= value <= 100:
                    return value

        markers = list(self._FINAL_PREDICTION_MARKER_RE.finditer(result_text))
        region = result_text[markers[-1].end():] if markers else result_text
        for match in self._BARE_INT_RE.finditer(region):
            value = int(match.group(1))
            if 0 <= value <= 100:
                return value

        logger.warning("Could not extract valid prediction, using default value 50")
        return self.DEFAULT_PREDICTION

    def evaluate_test_set(self, test_data: List[Dict]) -> Dict:
        """Predict every item of *test_data* and compute overall and per-weighting metrics.

        Each sample is recorded exactly once. A sample counts as an error when
        predict_single reports one (metadata['error']) or raises.
        """
        predictions = []
        actuals = []
        explanations = []
        times = []
        errors = []
        metadata_list = []
        abs_error_sum = 0  # running total for the live MAE log line
        n_total = len(test_data)

        logger.info(f"🔬 Starting hybrid evaluation on {n_total} test samples...")

        for i, item in enumerate(test_data):
            input_text = item['input_text']
            actual = int(item['output_text'])

            try:
                pred, explanation, elapsed_time, metadata = self.predict_single(input_text)
                error = metadata.get('error')
            except Exception as e:
                logger.error(f"Error processing sample {i}: {e}")
                pred, explanation, elapsed_time, metadata = self.DEFAULT_PREDICTION, f"Error: {str(e)}", 0.0, {}
                error = str(e)

            predictions.append(pred)
            actuals.append(actual)
            explanations.append(explanation)
            times.append(elapsed_time)
            errors.append(error)
            metadata_list.append(metadata)
            abs_error_sum += abs(actual - pred)

            # Live monitoring; error samples have no weights to report.
            weights = metadata.get('weights')
            if i > 0 and weights:
                current_mae = abs_error_sum / len(predictions)
                logger.info(f"Sample {i+1}/{n_total} | Pred: {pred} | Actual: {actual} | "
                            f"Weights: A={weights['assay']:.2f}/C={weights['chemical']:.2f} | "
                            f"Running MAE: {current_mae:.2f}")
            else:
                logger.info(f"Sample {i+1}/{n_total} | Pred: {pred} | Actual: {actual}")

        return {
            'predictions': predictions,
            'actuals': actuals,
            'explanations': explanations,
            'times': times,
            'errors': errors,
            'metadata': metadata_list,
            'metrics': self._compute_metrics(actuals, predictions, times, errors, metadata_list, n_total)
        }

    @staticmethod
    def _subset_mae(abs_errors: List[float], indices: List[int]) -> float:
        """MAE over the given sample indices; inf when the subset is empty."""
        if not indices:
            return float('inf')
        return sum(abs_errors[i] for i in indices) / len(indices)

    def _compute_metrics(self, actuals, predictions, times, errors, metadata_list, n_samples) -> Dict:
        """Aggregate metrics for evaluate_test_set; NaN where a mean is undefined."""
        nan = float('nan')
        n = len(actuals)
        abs_errors = [abs(a - p) for a, p in zip(actuals, predictions)]

        if n:
            mae = sum(abs_errors) / n
            mse = sum(e * e for e in abs_errors) / n
            rmse = math.sqrt(mse)
            within_10 = sum(1 for e in abs_errors if e <= 10) / n * 100
            within_20 = sum(1 for e in abs_errors if e <= 20) / n * 100
        else:
            mae = mse = rmse = within_10 = within_20 = nan
        r2 = self._calculate_r2(actuals, predictions)

        # Only samples that actually produced weights take part in weight analyses;
        # failed samples would otherwise be miscounted as "balanced".
        weighted = [(i, meta['weights']) for i, meta in enumerate(metadata_list) if meta.get('weights')]
        assay_weighted = [i for i, w in weighted if w.get('assay', 0) > 0.6]
        chemical_weighted = [i for i, w in weighted if w.get('chemical', 0) > 0.6]
        balanced = [i for i, w in weighted if 0.4 <= w.get('assay', 0.5) <= 0.6]

        high_assay_sim = [i for i, meta in enumerate(metadata_list) if meta.get('max_assay_similarity', 0) > 0.7]
        high_chem_sim = [i for i, meta in enumerate(metadata_list) if meta.get('max_chemical_similarity', 0) > 0.7]

        positive_times = [t for t in times if t > 0]

        return {
            'mae': mae,
            'mse': mse,
            'rmse': rmse,
            'r2': r2,
            'within_10_pct': within_10,
            'within_20_pct': within_20,
            'assay_weighted_mae': self._subset_mae(abs_errors, assay_weighted),
            'chemical_weighted_mae': self._subset_mae(abs_errors, chemical_weighted),
            'balanced_mae': self._subset_mae(abs_errors, balanced),
            'n_assay_weighted': len(assay_weighted),
            'n_chemical_weighted': len(chemical_weighted),
            'n_balanced': len(balanced),
            'n_high_assay_sim': len(high_assay_sim),
            'n_high_chem_sim': len(high_chem_sim),
            'avg_assay_weight': sum(w.get('assay', 0.5) for _, w in weighted) / len(weighted) if weighted else nan,
            'avg_chemical_weight': sum(w.get('chemical', 0.5) for _, w in weighted) / len(weighted) if weighted else nan,
            'n_samples': n_samples,
            'n_errors': sum(1 for e in errors if e is not None),
            'avg_time': sum(positive_times) / len(positive_times) if positive_times else nan
        }

    def _calculate_r2(self, y_true: List[float], y_pred: List[float]) -> float:
        """Coefficient of determination; 0.0 when undefined (empty or constant y_true)."""
        if len(y_true) == 0:
            return 0.0
        y_mean = np.mean(y_true)
        ss_tot = sum((y - y_mean) ** 2 for y in y_true)
        ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
        return 1 - (ss_res / ss_tot) if ss_tot != 0 else 0.0

    def calculate_cost(self) -> Dict:
        """Compute API cost from the tracked token counts (also stores it in cost_tracker)."""
        input_cost = (self.cost_tracker['input_tokens'] / 1000) * self.INPUT_COST_PER_1K
        output_cost = (self.cost_tracker['output_tokens'] / 1000) * self.OUTPUT_COST_PER_1K
        total_cost = input_cost + output_cost
        self.cost_tracker['total_cost'] = total_cost

        return {
            'input_tokens': self.cost_tracker['input_tokens'],
            'output_tokens': self.cost_tracker['output_tokens'],
            'api_calls': self.cost_tracker['api_calls'],
            'input_cost': input_cost,
            'output_cost': output_cost,
            'total_cost': total_cost
        }

    def print_results(self, results: Dict, cost_breakdown: Dict):
        """Print metrics, cost and the three best / worst predictions to stdout."""
        print("\n" + "=" * 90)
        print("🔬 HYBRID RAG SYSTEM - SMILES TOXICITY PREDICTION RESULTS")
        print("=" * 90)

        metrics = results['metrics']

        print(f"\n📊 Overall Performance Metrics:")
        print(f"   • Mean Absolute Error (MAE): {metrics['mae']:.2f}")
        print(f"   • Root Mean Square Error (RMSE): {metrics['rmse']:.2f}")
        print(f"   • R² Score: {metrics['r2']:.3f}")
        print(f"   • Accuracy within ±10: {metrics['within_10_pct']:.1f}%")
        print(f"   • Accuracy within ±20: {metrics['within_20_pct']:.1f}%")

        print(f"\n🔀 Hybrid Analysis Breakdown:")
        print(f"   • Average Assay Weight: {metrics['avg_assay_weight']:.2f}")
        print(f"   • Average Chemical Weight: {metrics['avg_chemical_weight']:.2f}")
        print(f"   • Assay-Weighted Samples: {metrics['n_assay_weighted']} (MAE: {metrics['assay_weighted_mae']:.2f})")
        print(f"   • Chemical-Weighted Samples: {metrics['n_chemical_weighted']} (MAE: {metrics['chemical_weighted_mae']:.2f})")
        print(f"   • Balanced Samples: {metrics['n_balanced']} (MAE: {metrics['balanced_mae']:.2f})")

        print(f"\n🎯 Similarity Analysis:")
        print(f"   • High Assay Similarity (>0.7): {metrics['n_high_assay_sim']} samples")
        print(f"   • High Chemical Similarity (>0.7): {metrics['n_high_chem_sim']} samples")

        print(f"\n⚡ Performance:")
        print(f"   • Average Prediction Time: {metrics['avg_time']:.2f}s")
        print(f"   • Test Samples: {metrics['n_samples']}")
        print(f"   • Errors: {metrics['n_errors']}")

        print(f"\n💰 Cost Breakdown (Claude Sonnet 4):")
        print(f"   • API Calls: {cost_breakdown['api_calls']}")
        print(f"   • Input Tokens: {cost_breakdown['input_tokens']:,}")
        print(f"   • Output Tokens: {cost_breakdown['output_tokens']:,}")
        print(f"   • Total Cost: ${cost_breakdown['total_cost']:.4f}")

        # Best / worst predictions by absolute error
        abs_errors = [abs(a - p) for a, p in zip(results['actuals'], results['predictions'])]
        order = np.argsort(abs_errors)
        sections = (
            ("\n🏆 Best Hybrid Predictions:", order[:3]),
            ("\n🤔 Most Challenging Predictions:", order[-3:][::-1]),
        )
        for title, indices in sections:
            print(title)
            for rank, idx in enumerate(indices, 1):
                actual = results['actuals'][idx]
                pred = results['predictions'][idx]
                weights = results['metadata'][idx].get('weights', {})
                error = abs(actual - pred)
                print(f"   {rank}. Actual: {actual:2d}, Predicted: {pred:2d}, Error: {error:2d}, "
                      f"Weights: A={weights.get('assay', 0):.2f}/C={weights.get('chemical', 0):.2f}")


def main():
    """Run the hybrid RAG pipeline end to end.

    Builds a ``HybridSMILESRAG`` predictor, then:
    - loads ``DATA_FILE``;
    - splits it into train/test sets;
    - prepares the hybrid training data;
    - evaluates a small test subset;
    - prints metrics and API cost;
    - saves everything to ``hybrid_rag_results.json``.

    The Claude API key is taken from ``CLAUDE_API_KEY`` below when it is
    non-empty. Otherwise it is read from the ``ANTHROPIC_API_KEY``
    environment variable, so the key does not have to be pasted into source.

    Raises:
        Exception: any error raised by the pipeline is logged and re-raised.
    """
    DATA_FILE = "combined_train_sampled2.jsonl"
    TEST_SIZE = 0.2
    RANDOM_STATE = 42
    # Number of test samples actually evaluated; kept small to limit API cost.
    TEST_SUBSET_SIZE = 20

    # API key: only hard-code it here for quick local experiments. Otherwise
    # fall back to the environment so the secret stays out of version control.
    CLAUDE_API_KEY = ""
    api_key = CLAUDE_API_KEY or os.environ.get("ANTHROPIC_API_KEY", "")
    if not api_key:
        logger.warning(
            "No Claude API key configured; set ANTHROPIC_API_KEY "
            "or CLAUDE_API_KEY in main()."
        )

    try:
        print("🔬 Initializing Hybrid RAG System...")
        print("   • Natural Language Processing: LangChain + Local Embeddings")
        print("   • Chemical Similarity: RDKit + Molecular Fingerprints")
        print("   • Integration: Claude Sonnet 4")

        predictor = HybridSMILESRAG(claude_api_key=api_key)

        print("\n📚 Loading data...")
        data = predictor.load_jsonl_data(DATA_FILE)

        print("✂️ Splitting data...")
        train_data, test_data = predictor.simple_train_test_split(
            data, test_size=TEST_SIZE, random_state=RANDOM_STATE
        )

        print("🏗️ Preparing hybrid training data...")
        print("   • Parsing experimental conditions and SMILES")
        print("   • Building assay vector store (LangChain)")
        print("   • Generating molecular fingerprints (RDKit)")
        predictor.prepare_hybrid_training_data(train_data)

        print("\n🔬 Evaluating with Hybrid RAG System...")
        print("   • Experimental similarity search + Chemical similarity search")
        print("   • Dynamic context weighting + Integrated reasoning")

        # Evaluate only a prefix of the test set to keep the number of paid
        # API calls bounded.
        test_subset = test_data[:TEST_SUBSET_SIZE]
        results = predictor.evaluate_test_set(test_subset)

        print("\n💰 Calculating costs...")
        cost_breakdown = predictor.calculate_cost()

        print("\n📊 Generating comprehensive results...")
        predictor.print_results(results, cost_breakdown)

        # Persist detailed results together with the run configuration.
        output_data = {
            'model': 'hybrid-rag-claude-sonnet-4',
            'components': {
                'assay_similarity': 'LangChain + HuggingFace Embeddings' if LANGCHAIN_AVAILABLE else 'Disabled',
                'chemical_similarity': 'RDKit + Molecular Fingerprints' if RDKIT_AVAILABLE else 'String Fallback',
                'integration': 'Claude Sonnet 4'
            },
            'results': results,
            'cost_breakdown': cost_breakdown,
            'settings': {
                'data_file': DATA_FILE,
                'test_size': TEST_SIZE,
                'random_state': RANDOM_STATE,
                'test_subset_size': len(test_subset),
                'langchain_available': LANGCHAIN_AVAILABLE,
                'rdkit_available': RDKIT_AVAILABLE
            }
        }

        filename = 'hybrid_rag_results.json'
        # Explicit encoding so the file is written identically on every
        # platform. default=str stringifies values json cannot serialize
        # natively (e.g. numpy scalars).
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(output_data, f, indent=2, default=str)

        print(f"\n💾 Results saved to '{filename}'")

        # Performance summary
        metrics = results['metrics']
        mae = metrics['mae']
        r2 = metrics['r2']
        within_10 = metrics['within_10_pct']
        avg_assay_weight = metrics['avg_assay_weight']
        avg_chem_weight = metrics['avg_chemical_weight']

        print("\n🏆 HYBRID SYSTEM PERFORMANCE SUMMARY:")
        print(f"   MAE: {mae:.2f} | R²: {r2:.3f} | Within ±10: {within_10:.1f}%")
        print(f"   Avg Context Weights: Assay={avg_assay_weight:.2f}, Chemical={avg_chem_weight:.2f}")

        # Report which similarity engines were available for this run.
        print("\n🔧 System Status:")
        if LANGCHAIN_AVAILABLE and RDKIT_AVAILABLE:
            print("   ✅ Full Hybrid System: Both LangChain and RDKit operational!")
            print("   🎯 Optimal performance with dual similarity engines")
        elif RDKIT_AVAILABLE:
            print("   ⚠️ Partial System: RDKit operational, LangChain disabled")
            print("   📝 Install LangChain for assay similarity: pip install langchain sentence-transformers")
        elif LANGCHAIN_AVAILABLE:
            print("   ⚠️ Partial System: LangChain operational, RDKit disabled")
            print("   🧪 Install RDKit for chemical similarity: conda install -c conda-forge rdkit")
        else:
            print("   ❌ Minimal System: Both engines disabled")
            print("   📦 Install dependencies for full functionality")

        print("\n🎉 Hybrid RAG Analysis complete!")

    except Exception as e:
        # Top-level boundary: log for context, then re-raise so the caller
        # (or notebook) still sees the full traceback.
        logger.error("Error in main execution: %s", e)
        raise

# Entry point: run the full hybrid RAG pipeline when this file is executed directly.
if __name__ == "__main__":
    main()
