In [2]:
# Switch the kernel's working directory to the parent folder (resulting path is echoed below).
%cd ..

c:\Users\admin\as.drug_indexer\backup_repo


In [3]:
import pandas as pd

# Read the crawled drug records from the raw JSON dump and preview the first rows.
raw_json_path = r'C:\Users\admin\medical_data_saved\drugs\crawled\raw.json'
df = pd.read_json(raw_json_path)
print(df.head())

                 drug_name                                               link  \
0             Parafizz 650  https://songkhoe.medplus.vn/thuoc-parafizz-650...   
1                   Damrin  https://songkhoe.medplus.vn/thuoc-damrin-lieu-...   
2                   Gonesi                https://songkhoe.medplus.vn/gonesi/   
3              Bayer (Đức)  https://www.nhathuocankhang.com/thuoc-tiet-nie...   
4  Bổ thận dương Nhất Nhất  https://songkhoe.medplus.vn/thuoc-bo-than-duon...   

                                           chemicals  \
0  Paracetamol 650 mg\n(Tá dược gồm: acid citric ...   
1               Mỗi viên nang chứa:\nDiacerein 50 mg   
2                           Mỗi viên của Gonesi chứa   
3  Trong mỗi viên Progynova 2mg chứa:\nHoạt chất:...   
4                        Mỗi viên nén bao phim chứa:   

                                               usage  \
0  \nĐiều trị các triệu chứng đau nhức và sốt từ ...   
1  Thuốc Damrin là thuốc ETC được dùng để điều...   
2  Gones

In [4]:
# Character length of every `chemicals` entry, stored as a new column on `df`.
df['chemicals_length'] = df['chemicals'].str.len()
print(df['chemicals_length'].describe())

# Show the unusually long entries (more than 2000 characters).
# The original also evaluated `long_chemicals.value_counts().sum()` here, but
# its result was discarded, so that dead computation has been removed.
long_chemicals = df[df['chemicals_length'] > 2000][['drug_name', 'chemicals', 'chemicals_length']]
print(long_chemicals)

count    45088.000000
mean        84.186812
std        127.032901
min          1.000000
25%         21.000000
50%         41.000000
75%        113.000000
max       6753.000000
Name: chemicals_length, dtype: float64
                           drug_name  \
3179                Epfepara codeine   
7008                   Meyersina 100   
10300        Cao ích mẫu Mediplantex   
10329  Terpin Codein - F Mediplantex   
15619         Diclofenac 50 Cửu Long   
17855                   Noclaud 50mg   
20073                         Ukapin   
21173                  Degodas 2,5mg   
26796                  Carmotop 50mg   
29085  Paracetamol 650mg Mediplantex   
29558                    Roscef 10mg   
31051                    Zaclid 20mg   
31797                 RICHSTATIN 5mg   
32953                          ADMED   
32956                       Penveril   
33587                  Degodas 2,5mg   
33770        Cao ích mẫu Mediplantex   
34107           Medikids Mediplantex   
35279                    

In [5]:
# Boolean Series marking rows that repeat an earlier row across all columns.
df.duplicated()

0        False
1        False
2        False
3        False
4        False
         ...  
45083     True
45084    False
45085    False
45086     True
45087    False
Length: 45088, dtype: bool

In [6]:
# Drop fully duplicated rows; the original index labels are kept (not reset).
df = df.drop_duplicates()
# Summarise column dtypes and non-null counts of the de-duplicated frame.
df.info()

<class 'pandas.core.frame.DataFrame'>
Index: 38471 entries, 0 to 45087
Data columns (total 6 columns):
 #   Column            Non-Null Count  Dtype 
---  ------            --------------  ----- 
 0   drug_name         38471 non-null  object
 1   link              38471 non-null  object
 2   chemicals         38471 non-null  object
 3   usage             38471 non-null  object
 4   side_effects      38471 non-null  object
 5   chemicals_length  38471 non-null  int64 
dtypes: int64(1), object(5)
memory usage: 2.1+ MB


In [7]:
# Independent copy holding only the two columns needed for ingredient extraction.
sub_df = df[['drug_name', 'chemicals']].copy()
print(sub_df.head())

                 drug_name                                          chemicals
0             Parafizz 650  Paracetamol 650 mg\n(Tá dược gồm: acid citric ...
1                   Damrin               Mỗi viên nang chứa:\nDiacerein 50 mg
2                   Gonesi                           Mỗi viên của Gonesi chứa
3              Bayer (Đức)  Trong mỗi viên Progynova 2mg chứa:\nHoạt chất:...
4  Bổ thận dương Nhất Nhất                        Mỗi viên nén bao phim chứa:


## LLM pretrained model

In [None]:
import pandas as pd
import numpy as np
import json
import re
import asyncio
from typing import List, Dict, Optional, Union, Any
import hashlib
import httpx
from dataclasses import dataclass, field
from pathlib import Path
import warnings
import logging
from functools import wraps
import time
import gc
import pickle
import os

# Suppress every Python warning for the remainder of the session.
warnings.filterwarnings('ignore')

# Log records go both to a file in the working directory and to the console.
_log_handlers = [
    logging.FileHandler('ner_processing.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=_log_handlers,
)
logger = logging.getLogger(__name__)

@dataclass
class ProductionNERConfig:
    """Settings for the LLM-backed ingredient/dose extraction pipeline.

    Creating an instance also creates ``cache_dir`` and ``checkpoint_dir``
    (including any missing parent directories).
    """
    # Sent as the Bearer token of every chat-completion request.
    api_key: str = ''
    # Base URL; "/chat/completions" is appended for each request.
    api_base_url: str = "https://api.openai.com/v1"
    # Request-body parameters passed through unchanged.
    model: str = "gpt-3.5-turbo"
    temperature: float = 0.0
    max_tokens: int = 600
    # Per-request HTTP timeout; a whole batch is given twice this long.
    timeout_seconds: int = 20
    # Size of the semaphore that bounds simultaneous requests.
    max_concurrent_requests: int = 12
    # Number of rows gathered concurrently per batch.
    batch_size: int = 20
    # Pause (seconds) between consecutive batches.
    inter_batch_delay: float = 0.3
    # Number of attempts per row (used as ``range(max_retries)``).
    max_retries: int = 2
    # Pause (seconds) between attempts for a single row.
    base_delay: float = 0.1
    # A checkpoint / cache pruning is triggered when the batch start index is a
    # multiple of the respective interval.
    checkpoint_interval: int = 100
    memory_cleanup_interval: int = 500
    # Time budget in seconds; processing stops once fewer than 300 s remain.
    max_processing_time: int = 1800
    cache_dir: Path = field(default_factory=lambda: Path("./cache"))
    checkpoint_dir: Path = field(default_factory=lambda: Path("./checkpoints"))

    def __post_init__(self):
        # Accept plain strings as well as Path objects.
        self.cache_dir = Path(self.cache_dir)
        self.checkpoint_dir = Path(self.checkpoint_dir)
        # parents=True so nested locations such as "runs/2025/cache" work too.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

class ProductionPharmaceuticalNER:
    def __init__(self, config: ProductionNERConfig = None):
        self.config = config or ProductionNERConfig()
        self.cache_file = self.config.cache_dir / "production_cache.json"
        self.cache = self._load_cache()
        self.semaphore = asyncio.Semaphore(self.config.max_concurrent_requests)
        self.start_time = None
        self.processed_count = 0
        self.error_count = 0
        
        self.system_prompt = """Extract drug ingredients and doses. Return JSON: [{"ingredient": "name", "dose": "amount"}]. Empty if none: [{"ingredient": "", "dose": null}]"""
        self.user_template = 'Drug: {drug_name} | Text: {text} | JSON:'
        self.default_result = [{'ingredient': '', 'dose': None}]

    def _load_cache(self) -> Dict:
        try:
            if self.cache_file.exists():
                with open(self.cache_file, 'r', encoding='utf-8') as f:
                    cache = json.load(f)
                    logger.info(f"Loaded cache: {len(cache)} entries")
                    return cache
        except Exception as e:
            logger.warning(f"Cache load failed: {e}")
        return {}

    def _save_cache(self):
        try:
            with open(self.cache_file, 'w', encoding='utf-8') as f:
                json.dump(self.cache, f, ensure_ascii=False)
        except Exception as e:
            logger.error(f"Cache save failed: {e}")

    def _save_checkpoint(self, results: List, index: int):
        try:
            checkpoint_file = self.config.checkpoint_dir / f"checkpoint_{index}.pkl"
            with open(checkpoint_file, 'wb') as f:
                pickle.dump({
                    'results': results,
                    'index': index,
                    'timestamp': time.time(),
                    'processed_count': self.processed_count
                }, f)
            logger.info(f"Checkpoint saved at index {index}")
        except Exception as e:
            logger.error(f"Checkpoint save failed: {e}")

    def _load_checkpoint(self) -> tuple:
        try:
            checkpoint_files = list(self.config.checkpoint_dir.glob("checkpoint_*.pkl"))
            if not checkpoint_files:
                return [], 0
            
            latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
            with open(latest_checkpoint, 'rb') as f:
                data = pickle.load(f)
                logger.info(f"Resumed from checkpoint: index {data['index']}")
                return data['results'], data['index']
        except Exception as e:
            logger.error(f"Checkpoint load failed: {e}")
            return [], 0

    def _cleanup_memory(self):
        try:
            gc.collect()
            if len(self.cache) > 10000:
                cache_items = list(self.cache.items())
                self.cache = dict(cache_items[-5000:])
                logger.info("Cache pruned for memory management")
        except Exception as e:
            logger.error(f"Memory cleanup failed: {e}")

    def _check_time_limit(self) -> bool:
        if self.start_time is None:
            return False
        
        elapsed = time.time() - self.start_time
        remaining = self.config.max_processing_time - elapsed
        
        if remaining < 300:
            logger.warning(f"Approaching time limit: {remaining:.0f}s remaining")
            return True
        return False

    async def _fast_api_call(self, text: str, drug_name: str) -> str:
        safe_text = str(text or "").strip()[:400]
        safe_drug_name = str(drug_name or "").strip()[:100]
        
        cache_key = hashlib.md5(f"{safe_drug_name}_{safe_text}".encode()).hexdigest()
        
        if cache_key in self.cache:
            return self.cache[cache_key]
        
        async with self.semaphore:
            for attempt in range(self.config.max_retries):
                try:
                    timeout = httpx.Timeout(self.config.timeout_seconds)
                    
                    async with httpx.AsyncClient(
                        timeout=timeout,
                        limits=httpx.Limits(max_keepalive_connections=10, max_connections=30)
                    ) as client:
                        
                        payload = {
                            "model": self.config.model,
                            "messages": [
                                {"role": "system", "content": self.system_prompt},
                                {"role": "user", "content": self.user_template.format(
                                    text=safe_text, drug_name=safe_drug_name)}
                            ],
                            "temperature": self.config.temperature,
                            "max_tokens": self.config.max_tokens
                        }
                        
                        response = await client.post(
                            f"{self.config.api_base_url}/chat/completions",
                            headers={
                                "Authorization": f"Bearer {self.config.api_key}",
                                "Content-Type": "application/json"
                            },
                            json=payload
                        )
                        
                        if response.status_code == 200:
                            result = response.json()
                            content = result['choices'][0]['message']['content']
                            if content:
                                self.cache[cache_key] = content.strip()
                                return content.strip()
                        elif response.status_code == 429:
                            await asyncio.sleep(self.config.base_delay * (1.2 ** attempt))
                        else:
                            await asyncio.sleep(self.config.base_delay)
                
                except Exception as e:
                    if attempt == self.config.max_retries - 1:
                        logger.error(f"All API attempts failed: {e}")
                        self.error_count += 1
                    await asyncio.sleep(self.config.base_delay)
        
        return self._instant_fallback(safe_text, safe_drug_name)

    def _instant_fallback(self, text: str, drug_name: str) -> str:
        combined = f"{drug_name} {text}".lower()
        
        patterns = [
            r'\b(?:paracetamol|ibuprofen|aspirin|metformin|omeprazole)\b',
            r'\b(?:lactose|cellulose|magnesium|calcium|sodium)\b'
        ]
        
        ingredients = set()
        for pattern in patterns:
            try:
                matches = re.findall(pattern, combined, re.IGNORECASE)
                ingredients.update([m for m in matches if len(m) > 2])
            except:
                continue
        
        doses = re.findall(r'\b\d+\s*(?:mg|ml|g|%)\b', combined, re.IGNORECASE)
        
        if not ingredients and drug_name:
            clean_name = re.sub(r'\d+', '', drug_name.lower()).strip()
            if len(clean_name) > 2:
                ingredients.add(clean_name)
        
        entities = []
        ingredients_list = list(ingredients)
        
        for i in range(max(len(ingredients_list), len(doses), 1)):
            ingredient = ingredients_list[i] if i < len(ingredients_list) else ''
            dose = doses[i] if i < len(doses) else None
            
            if ingredient or dose:
                entities.append({'ingredient': ingredient, 'dose': dose})
        
        return json.dumps(entities if entities else self.default_result, ensure_ascii=False)

    def _parse_response(self, response: str) -> List[Dict[str, Optional[str]]]:
        if not response:
            return self.default_result.copy()
        
        try:
            response = response.strip()
            
            if '[' in response and ']' in response:
                start = response.find('[')
                end = response.rfind(']') + 1
                json_str = response[start:end]
            elif '{' in response and '}' in response:
                start = response.find('{')
                end = response.rfind('}') + 1
                json_str = '[' + response[start:end] + ']'
            else:
                return self.default_result.copy()
            
            entities = json.loads(json_str)
            if not isinstance(entities, list):
                entities = [entities] if isinstance(entities, dict) else []
            
            validated = []
            for entity in entities:
                if isinstance(entity, dict):
                    ingredient = entity.get('ingredient', '').strip().lower()
                    dose = entity.get('dose')
                    
                    if ingredient and len(ingredient) > 1 and ingredient not in {'', 'unknown', 'null', 'none'}:
                        safe_dose = str(dose).strip() if dose and str(dose).lower() not in {'null', 'none', ''} else None
                        validated.append({'ingredient': ingredient, 'dose': safe_dose})
            
            return validated if validated else self.default_result.copy()
            
        except:
            return self.default_result.copy()

    async def process_production_batch(self, texts: List[str], names: List[str], start_index: int = 0) -> List[List[Dict]]:
        self.start_time = time.time()
        
        checkpoint_results, checkpoint_index = self._load_checkpoint()
        if checkpoint_index > start_index:
            logger.info(f"Resuming from checkpoint at index {checkpoint_index}")
            start_index = checkpoint_index
            results = checkpoint_results
        else:
            results = []
        
        total_items = len(texts)
        logger.info(f"Starting production processing: {total_items} items")
        
        for i in range(start_index, total_items, self.config.batch_size):
            if self._check_time_limit():
                logger.warning("Approaching time limit, saving checkpoint and stopping")
                self._save_checkpoint(results, i)
                break
            
            batch_start = time.time()
            batch_texts = texts[i:i + self.config.batch_size]
            batch_names = names[i:i + self.config.batch_size]
            
            batch_num = (i // self.config.batch_size) + 1
            total_batches = (total_items + self.config.batch_size - 1) // self.config.batch_size
            
            logger.info(f"Batch {batch_num}/{total_batches} - Items {i+1}-{min(i+len(batch_texts), total_items)}")
            
            try:
                tasks = [self._fast_api_call(text, name) for text, name in zip(batch_texts, batch_names)]
                responses = await asyncio.wait_for(
                    asyncio.gather(*tasks, return_exceptions=True),
                    timeout=self.config.timeout_seconds * 2
                )
                
                batch_results = []
                for response in responses:
                    if isinstance(response, Exception):
                        batch_results.append(self.default_result.copy())
                        self.error_count += 1
                    else:
                        parsed = self._parse_response(response)
                        batch_results.append(parsed)
                
                results.extend(batch_results)
                self.processed_count += len(batch_results)
                
                batch_time = time.time() - batch_start
                items_per_second = len(batch_texts) / batch_time if batch_time > 0 else 0
                
                elapsed = time.time() - self.start_time
                remaining_items = total_items - (i + len(batch_texts))
                estimated_remaining = remaining_items / items_per_second if items_per_second > 0 else 0
                
                logger.info(f"Batch completed: {batch_time:.1f}s, {items_per_second:.1f} items/s")
                logger.info(f"Progress: {self.processed_count}/{total_items} ({(self.processed_count/total_items)*100:.1f}%)")
                logger.info(f"Elapsed: {elapsed:.1f}s, Est. remaining: {estimated_remaining:.1f}s")
                logger.info(f"Errors: {self.error_count}, Cache size: {len(self.cache)}")
                
                if i % self.config.checkpoint_interval == 0:
                    self._save_checkpoint(results, i + len(batch_texts))
                
                if i % self.config.memory_cleanup_interval == 0:
                    self._cleanup_memory()
                
                if i + self.config.batch_size < total_items:
                    await asyncio.sleep(self.config.inter_batch_delay)
                    
            except asyncio.TimeoutError:
                logger.error(f"Batch {batch_num} timed out, using fallbacks")
                fallback_results = [self.default_result.copy() for _ in batch_texts]
                results.extend(fallback_results)
                self.error_count += len(batch_texts)
            
            except Exception as e:
                logger.error(f"Batch {batch_num} failed: {e}")
                fallback_results = [self.default_result.copy() for _ in batch_texts]
                results.extend(fallback_results)
                self.error_count += len(batch_texts)
        
        self._save_cache()
        
        total_time = time.time() - self.start_time
        logger.info(f"Production processing completed in {total_time:.1f}s")
        logger.info(f"Total processed: {self.processed_count}, Errors: {self.error_count}")
        logger.info(f"Average speed: {self.processed_count/total_time:.1f} items/second")
        
        return results

    def extract_production(self, df: pd.DataFrame) -> pd.DataFrame:
        if df is None or df.empty:
            return pd.DataFrame()
        
        required_cols = ['drug_name', 'chemicals']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing columns: {missing_cols}")
        
        texts = df['chemicals'].fillna('').astype(str).tolist()
        names = df['drug_name'].fillna('').astype(str).tolist()
        
        min_len = min(len(texts), len(names), len(df))
        texts, names = texts[:min_len], names[:min_len]
        
        logger.info(f"Starting production extraction: {len(texts)} pharmaceutical entries")
        logger.info(f"Target completion time: {self.config.max_processing_time/60:.1f} minutes")
        
        async def run():
            return await self.process_production_batch(texts, names)
        
        try:
            try:
                import nest_asyncio
                nest_asyncio.apply()
            except ImportError:
                pass
            
            results = asyncio.run(run())
            
        except Exception as e:
            logger.error(f"Production extraction failed: {e}")
            results = [self.default_result.copy() for _ in texts]
        
        result_df = df.copy()
        
        while len(results) < len(result_df):
            results.append(self.default_result.copy())
        results = results[:len(result_df)]
        
        result_df['extracted_entities'] = results
        
        return result_df

def production_extract(df: pd.DataFrame, api_key: str = None) -> pd.DataFrame:
    """Run the extraction pipeline on ``df`` with the standard settings.

    Args:
        df: Frame with ``drug_name`` and ``chemicals`` columns.
        api_key: API key for the chat-completion endpoint. When omitted or
            empty, the ``OPENAI_API_KEY`` environment variable is used, so the
            key does not have to be written into the notebook.

    Returns:
        The frame produced by ``ProductionPharmaceuticalNER.extract_production``.
    """
    resolved_key = api_key or os.environ.get('OPENAI_API_KEY', '')

    config = ProductionNERConfig(
        api_key=resolved_key,
        batch_size=20,
        max_concurrent_requests=12,
        inter_batch_delay=0.3,
        max_retries=2,
        timeout_seconds=20,
        max_processing_time=1800
    )

    extractor = ProductionPharmaceuticalNER(config)
    return extractor.extract_production(df)

if __name__ == "__main__":
    # Demo driver: run the pipeline on `df` (or on a small synthetic sample
    # when no `df` exists yet) and print timing plus a few extracted rows.
    try:
        print("PRODUCTION PHARMACEUTICAL NER PIPELINE")
        print("=" * 60)
        print("Target: Complete processing within 30 minutes")
        print("Features: Checkpointing, Progress tracking, Error recovery")
        print("=" * 60)

        if 'df' not in locals():
            print("Loading sample data for demonstration...")
            df = pd.DataFrame({
                'drug_name': ['Paracetamol 500mg', 'Omeprazole 20mg', 'Metformin HCl'] * 100,
                'chemicals': [
                    'Paracetamol 500mg, lactose monohydrate, magnesium stearate',
                    'Omeprazole 20mg, cellulose microcrystalline, titanium dioxide',
                    'Metformin hydrochloride 850mg, povidone, magnesium stearate'
                ] * 100
            })

        print(f"Dataset size: {len(df)} entries")

        start_time = time.time()
        # The key is read from the environment rather than typed into the notebook.
        result_df = production_extract(df, api_key=os.environ.get('OPENAI_API_KEY', ''))
        end_time = time.time()

        total_time = end_time - start_time
        empty_result = [{'ingredient': '', 'dose': None}]
        # An empty input frame comes back without the result column.
        extracted = list(result_df.get('extracted_entities', []))
        non_empty = sum(1 for r in extracted if r != empty_result)
        # Guard both ratios against an empty frame or a zero duration.
        speed = len(df) / total_time if total_time > 0 else 0.0
        success_rate = non_empty / len(df) * 100 if len(df) else 0.0

        print("\nPRODUCTION RESULTS:")
        print(f"Processing completed in {total_time:.1f} seconds ({total_time/60:.1f} minutes)")
        print(f"Average speed: {speed:.1f} entries/second")
        print(f"Success rate: {success_rate:.1f}%")

        print("\nSAMPLE EXTRACTION RESULTS:")
        for i in range(min(5, len(extracted))):
            row = result_df.iloc[i]
            entities = row['extracted_entities']
            print(f"\n{i+1}. {row['drug_name']}:")
            if entities and entities != empty_result:
                for entity in entities:
                    if entity.get('ingredient'):
                        # The 'dose' key exists with a None value, so a .get()
                        # default would never apply; use `or` instead.
                        print(f"   • {entity['ingredient'].title()}: {entity.get('dose') or 'No dose'}")
            else:
                print("   • No entities extracted")

    except Exception as e:
        print(f"Production pipeline failed: {e}")
        logger.error(f"Production execution error: {e}")


PRODUCTION PHARMACEUTICAL NER PIPELINE
Target: Complete processing within 30 minutes
Features: Checkpointing, Progress tracking, Error recovery
Dataset size: 38471 entries


2025-08-11 16:37:12,216 - INFO - Starting production extraction: 38471 pharmaceutical entries
2025-08-11 16:37:12,217 - INFO - Target completion time: 30.0 minutes
2025-08-11 16:37:12,222 - INFO - Starting production processing: 38471 items
2025-08-11 16:37:12,223 - INFO - Batch 1/1924 - Items 1-20
2025-08-11 16:37:16,248 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-08-11 16:37:16,505 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-08-11 16:37:16,507 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-08-11 16:37:16,513 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-08-11 16:37:16,515 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-08-11 16:37:16,519 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-08-11 16:37:16,52


PRODUCTION RESULTS:
Processing completed in 1503.8 seconds (25.1 minutes)
Average speed: 25.6 entries/second
Success rate: 9.4%

SAMPLE EXTRACTION RESULTS:

1. Parafizz 650:
   • Paracetamol: 650 mg

2. Damrin:
   • Diacerein: 50 mg

3. Gonesi:
   • Gonesi: mỗi viên

4. Bayer (Đức):
   • Estradiol Valerate: 2mg

5. Bổ thận dương Nhất Nhất:
   • No entities extracted


In [9]:
print(result_df.head(50))

                            drug_name  \
0                        Parafizz 650   
1                              Damrin   
2                              Gonesi   
3                         Bayer (Đức)   
4             Bổ thận dương Nhất Nhất   
5                   Propara 450mg/3ml   
6                             Acepron   
7                             Newtiam   
8                Piroxicam Stada 20mg   
9                             Fuspiro   
10                  Etoposide 'Ebewe'   
11                        Fexostad 60   
12                          Bivinadol   
13              Clazic MR 60mg United   
14                     Ceracept 0,75g   
15                        Trovem 20mg   
16                  Risabin injection   
17                             Bettam   
18                  Becaspira 3.0M UI   
19                       Idisten 20mg   
20                          Betamineo   
21                  Cephalexin 250 mg   
22                Zelfamox 875/125 DT   
23              

## Proposed NER model

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from sklearn.cluster import DBSCAN
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import re
from typing import List, Dict, Tuple, Optional
from dataclasses import dataclass
import json
from pathlib import Path

@dataclass
class HybridNERConfig:
    """Hyperparameters for the hybrid (supervised + discovery) NER model.

    Only some of these fields are read by the code visible in this cell; the
    remaining ones are presumably consumed by training/fusion code defined
    further down (TODO confirm).
    """
    # Core Supervised Model
    embedding_dim: int = 128  # width of the token embedding table
    hidden_dim: int = 256  # BiLSTM output width (each direction gets hidden_dim // 2)
    num_layers: int = 2  # number of stacked LSTM layers
    dropout: float = 0.3  # used for both the LSTM and the dropout layer after it
    lr: float = 1e-3  # presumably the optimizer learning rate (not used in the visible code)
    batch_size: int = 16  # presumably the training batch size (not used in the visible code)
    epochs: int = 50  # presumably the number of training epochs (not used in the visible code)
    
    # OWNER Discovery Component
    min_cluster_size: int = 3  # DBSCAN min_samples and minimum accepted cluster length
    eps: float = 0.3  # DBSCAN neighbourhood radius
    confidence_threshold: float = 0.7  # minimum cluster confidence to keep a cluster
    discovery_batch_size: int = 100  # not used in the visible code
    similarity_threshold: float = 0.8  # not used in the visible code
    
    # Hybrid Integration
    supervised_weight: float = 0.7  # not used in the visible code
    discovery_weight: float = 0.3  # not used in the visible code
    validation_threshold: float = 0.6  # not used in the visible code
    update_interval: int = 10  # not used in the visible code

class CRFLayer(nn.Module):
    """Linear-chain conditional random field over per-token tag scores.

    Shapes: ``emissions`` is (batch, seq_len, num_tags); ``tags`` and ``mask``
    are (batch, seq_len). ``mask`` marks real tokens with a truthy value and
    padding with a falsy one. Padding is expected at the end of each sequence
    and the first position must be a real token. ``mask=None`` treats every
    position as real.
    """

    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        # transitions[i, j] scores moving from tag i to tag j.
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions, tags=None, mask=None):
        """Return the mean negative log-likelihood if ``tags`` is given, else the decoded tags."""
        if tags is not None:
            return self._compute_loss(emissions, tags, mask)
        return self._viterbi_decode(emissions, mask)

    @staticmethod
    def _as_bool_mask(emissions, mask):
        """Return ``mask`` as a bool tensor on the emissions' device (all True when None)."""
        if mask is None:
            return torch.ones(emissions.shape[:2], dtype=torch.bool, device=emissions.device)
        return mask.to(device=emissions.device, dtype=torch.bool)

    def _compute_loss(self, emissions, tags, mask):
        """Mean over the batch of ``log Z - score(gold path)``, ignoring padded positions."""
        batch_size, seq_len = tags.shape
        mask = self._as_bool_mask(emissions, mask)
        weights = mask.to(emissions.dtype)
        rows = torch.arange(batch_size, device=emissions.device)
        # Padded positions may carry arbitrary label ids (e.g. -100); map them
        # to a valid index so look-ups succeed. Their scores are zeroed below.
        safe_tags = torch.where(mask, tags, torch.zeros_like(tags))

        score = self.start_transitions[safe_tags[:, 0]] + emissions[rows, 0, safe_tags[:, 0]]
        for i in range(1, seq_len):
            step = (self.transitions[safe_tags[:, i - 1], safe_tags[:, i]]
                    + emissions[rows, i, safe_tags[:, i]])
            score = score + step * weights[:, i]

        # The end transition belongs to the last real token, not the last column.
        last_index = (mask.long().sum(dim=1) - 1).clamp(min=0)
        score = score + self.end_transitions[safe_tags[rows, last_index]]

        partition = self._compute_partition(emissions, mask)
        return (partition - score).mean()

    def _compute_partition(self, emissions, mask):
        """Log partition function per sequence via the forward algorithm; shape (batch,)."""
        batch_size, seq_len, num_tags = emissions.shape
        mask = self._as_bool_mask(emissions, mask)
        alpha = self.start_transitions + emissions[:, 0]

        for i in range(1, seq_len):
            emit_score = emissions[:, i].unsqueeze(1)
            trans_score = self.transitions.unsqueeze(0)
            next_alpha = torch.logsumexp(alpha.unsqueeze(2) + trans_score + emit_score, dim=1)
            # Padded steps leave the running scores untouched.
            alpha = torch.where(mask[:, i].unsqueeze(1), next_alpha, alpha)

        return torch.logsumexp(alpha + self.end_transitions, dim=1)

    def _viterbi_decode(self, emissions, mask):
        """Best tag sequence per batch row as a (batch, seq_len) tensor.

        Padded positions repeat the tag of the last real token and should be
        ignored by the caller.
        """
        batch_size, seq_len, num_tags = emissions.shape
        mask = self._as_bool_mask(emissions, mask)
        score = self.start_transitions + emissions[:, 0]
        # Identity back-pointers let padded steps pass the tag through unchanged.
        keep_tag = torch.arange(num_tags, device=emissions.device).unsqueeze(0).expand(batch_size, -1)
        history = []

        for i in range(1, seq_len):
            emit_score = emissions[:, i].unsqueeze(1)
            trans_score = self.transitions.unsqueeze(0)
            next_score = score.unsqueeze(2) + trans_score + emit_score
            best_score, best_prev = next_score.max(dim=1)
            step_mask = mask[:, i].unsqueeze(1)
            score = torch.where(step_mask, best_score, score)
            history.append(torch.where(step_mask, best_prev, keep_tag))

        best_last_tags = torch.argmax(score + self.end_transitions, dim=1)
        best_paths = [best_last_tags]

        for hist in reversed(history):
            # squeeze(1), not squeeze(): a bare squeeze also drops the batch
            # dimension when batch_size == 1 and breaks the next step.
            best_last_tags = hist.gather(1, best_last_tags.unsqueeze(1)).squeeze(1)
            best_paths.append(best_last_tags)

        return torch.stack(list(reversed(best_paths)), dim=1)

class SupervisedNER(nn.Module):
    """Embedding -> BiLSTM -> linear emissions -> CRF sequence tagger.

    Args:
        vocab_size: Number of rows in the embedding table.
        config: Object providing ``embedding_dim``, ``hidden_dim``,
            ``num_layers`` and ``dropout``.
        num_tags: Size of the tag set. Defaults to 5
            (O, B-CHEM, I-CHEM, B-DOSE, I-DOSE).
    """

    def __init__(self, vocab_size, config, num_tags: int = 5):
        super().__init__()
        # Each LSTM direction gets half of hidden_dim. The real output width is
        # 2 * (hidden_dim // 2), which is hidden_dim - 1 when hidden_dim is odd.
        direction_dim = config.hidden_dim // 2
        self.embedding = nn.Embedding(vocab_size, config.embedding_dim)
        self.lstm = nn.LSTM(config.embedding_dim, direction_dim,
                           config.num_layers, bidirectional=True,
                           # Inter-layer dropout only exists with 2+ layers;
                           # passing it for one layer just triggers a warning.
                           dropout=config.dropout if config.num_layers > 1 else 0.0,
                           batch_first=True)
        self.dropout = nn.Dropout(config.dropout)
        # Sized from the actual LSTM output width so odd hidden_dim values work.
        self.classifier = nn.Linear(2 * direction_dim, num_tags)
        self.crf = CRFLayer(num_tags)

    def forward(self, x, tags=None, mask=None):
        """Return whatever the CRF returns: a loss when ``tags`` is given, else decoded tags.

        ``x`` holds token ids; it is presumably shaped (batch, seq_len) since
        the LSTM is batch-first (TODO confirm against the data pipeline).
        """
        embeddings = self.embedding(x)
        lstm_out, _ = self.lstm(embeddings)
        lstm_out = self.dropout(lstm_out)
        emissions = self.classifier(lstm_out)
        return self.crf(emissions, tags, mask)

class OWNERDiscovery:
    """Unsupervised discovery of chemical and dose mentions.

    Candidates are pulled from raw text with regular expressions, vectorized
    with TF-IDF, grouped with DBSCAN, and groups that are large enough are
    reported with a size-based confidence.
    """

    def __init__(self, config):
        self.config = config
        self.vectorizer = TfidfVectorizer(max_features=1000, ngram_range=(1, 3))
        self.discovered_entities = set()
        self.entity_patterns = []
        self.confidence_scores = {}

    def discover_entities(self, texts: List[str]) -> Dict[str, List[Dict]]:
        """Discover entity clusters in ``texts``.

        Returns:
            ``{'chemicals': [...], 'doses': [...]}`` where each item is a dict
            with ``entities``, ``type``, ``confidence`` and ``size``.
        """
        # Candidate extraction, one list per text.
        chemical_prompts = [self._extract_chemical_candidates(text) for text in texts]
        dose_prompts = [self._extract_dose_candidates(text) for text in texts]

        # Group similar candidates.
        chemical_clusters = self._cluster_candidates(chemical_prompts)
        dose_clusters = self._cluster_candidates(dose_prompts)

        # Keep only clusters that pass the size and confidence checks.
        discovered = {
            'chemicals': self._validate_clusters(chemical_clusters, 'chemical'),
            'doses': self._validate_clusters(dose_clusters, 'dose')
        }

        return discovered

    def _extract_chemical_candidates(self, text: str) -> List[str]:
        """Return unique chemical-looking substrings of ``text`` in order of first match."""
        patterns = [
            r'\b[A-Za-z]{3,}(?:ol|ate|ine|ide|ose)\b',
            r'\b[A-Za-z]+\s+[A-Za-z]+(?:ate|ine|ol)\b',
            r'\b(?:lactose|cellulose|stearate|oxide)\b',
        ]
        candidates = []
        for pattern in patterns:
            candidates.extend(re.findall(pattern, text, re.IGNORECASE))
        # dict.fromkeys de-duplicates with a stable order (a set would not).
        return list(dict.fromkeys(candidates))

    def _extract_dose_candidates(self, text: str) -> List[str]:
        """Return unique dose expressions (number + mg/mcg/ml/g/%) in order of appearance."""
        # One pattern covers both "500mg" and "500 mg". The decimal part accepts
        # "." or "," ("2,5mg"), and the look-ahead stops a unit from matching the
        # start of a longer word ("10 gói" is not "10 g").
        pattern = r'\d+(?:[.,]\d+)?\s*(?:mcg|mg|ml|g|%)(?!\w)'
        return list(dict.fromkeys(re.findall(pattern, text, re.IGNORECASE)))

    def _cluster_candidates(self, candidate_lists: List[List[str]]) -> List[List[str]]:
        """Group candidates with DBSCAN over TF-IDF vectors.

        Returns:
            One list of candidate strings per cluster; noise points (label -1)
            are dropped. Empty when there are fewer candidates than
            ``config.min_cluster_size`` or when vectorizing/clustering fails.
        """
        if not candidate_lists:
            return []

        all_candidates = [item for sublist in candidate_lists for item in sublist]
        if len(all_candidates) < self.config.min_cluster_size:
            return []

        try:
            vectors = self.vectorizer.fit_transform(all_candidates)
            clustering = DBSCAN(eps=self.config.eps, min_samples=self.config.min_cluster_size)
            # DBSCAN accepts the sparse matrix directly, avoiding a dense copy.
            labels = clustering.fit_predict(vectors)

            clusters = {}
            for candidate, label in zip(all_candidates, labels):
                if label != -1:
                    clusters.setdefault(label, []).append(candidate)

            return list(clusters.values())
        except Exception:
            # Best effort: e.g. TF-IDF raises ValueError on an empty vocabulary.
            return []

    def _validate_clusters(self, clusters: List[List[str]], entity_type: str) -> List[Dict]:
        """Keep clusters meeting ``min_cluster_size`` and ``confidence_threshold``.

        Confidence is ``n / (n + 1)`` for a cluster with ``n`` members.
        """
        validated = []
        for cluster in clusters:
            if len(cluster) >= self.config.min_cluster_size:
                confidence = len(cluster) / (len(cluster) + 1)
                if confidence >= self.config.confidence_threshold:
                    validated.append({
                        'entities': cluster,
                        'type': entity_type,
                        'confidence': confidence,
                        'size': len(cluster)
                    })
        return validated

class HybridPharmaceuticalNER:
    """Hybrid NER combining a supervised tagger with pattern-based discovery.

    ``fit`` auto-labels the input texts with regular expressions, trains a
    ``SupervisedNER`` model on them, runs ``OWNERDiscovery`` on the same texts
    and stores the validated discoveries. ``predict`` merges the supervised
    predictions with the discovery candidates for a single text.
    """

    # Tag <-> id mapping shared by training and decoding.
    _LABEL2ID = {'O': 0, 'B-CHEM': 1, 'I-CHEM': 2, 'B-DOSE': 3, 'I-DOSE': 4}
    _ID2LABEL = {0: 'O', 1: 'B-CHEM', 2: 'I-CHEM', 3: 'B-DOSE', 4: 'I-DOSE'}

    # Auto-labelling heuristics, applied with ``match`` (anchored at the start
    # of the already lower-cased token).
    _CHEMICAL_TOKEN_RE = re.compile(r'.*ol$|.*ate$|.*ine$|lactose|cellulose')
    # The optional ``[.,]\d+`` part lets decimal doses such as "2.5mg" or
    # "2,5mg" be labelled; a plain ``\d+mg`` cannot match them.
    _DOSE_TOKEN_RE = re.compile(r'\d+(?:[.,]\d+)?(?:mg|ml|mcg|g|%)')

    # Fixed confidence attached to each prediction source.
    _SUPERVISED_CONFIDENCE = 0.8
    _DISCOVERY_CONFIDENCE = 0.6

    def __init__(self, config: HybridNERConfig = None):
        """Create an untrained model.

        Args:
            config: Hyper-parameters; a default ``HybridNERConfig()`` is
                created when omitted.
        """
        self.config = config or HybridNERConfig()
        self.supervised_model = None
        self.owner_discovery = OWNERDiscovery(self.config)
        self.vocab = {'word2idx': {}, 'idx2word': {}}
        self.training_history = []
        self.discovery_history = []

    def fit(self, df: pd.DataFrame, text_col: str = 'chemicals'):
        """Train the supervised model and run entity discovery.

        Args:
            df: Frame holding the texts.
            text_col: Column of ``df`` to read; null values are dropped.

        Returns:
            Dict with ``'supervised_entities'`` (number of auto-labelled
            training texts), ``'discovered_chemicals'``, ``'discovered_doses'``
            (total entities over the returned clusters) and ``'vocab_size'``.
        """
        texts = df[text_col].dropna().tolist()

        # Phase 1: train the supervised model on auto-labelled texts.
        train_data = self._prepare_supervised_data(texts)
        self._build_vocab(train_data)

        vocab_size = len(self.vocab['word2idx'])
        self.supervised_model = SupervisedNER(vocab_size, self.config)
        self._train_supervised(train_data)

        # Phase 2: OWNER discovery.
        discoveries = self.owner_discovery.discover_entities(texts)
        self.discovery_history.append(discoveries)

        # Phase 3: store the validated discoveries.
        self._integrate_discoveries(discoveries, texts)

        return {
            'supervised_entities': len(train_data),
            'discovered_chemicals': sum(len(d['entities']) for d in discoveries['chemicals']),
            'discovered_doses': sum(len(d['entities']) for d in discoveries['doses']),
            'vocab_size': vocab_size
        }

    def predict(self, text: str) -> List[Dict]:
        """Extract entities from one text.

        Returns:
            Fused predictions sorted by descending confidence, or ``[]`` when
            ``fit`` has not been called yet.
        """
        # Explicit ``is None``: truthiness of a model object is not a reliable
        # "has been trained" signal.
        if self.supervised_model is None:
            return []

        supervised_entities = self._predict_supervised(text)
        discovered_entities = self._predict_discovery(text)
        return self._fuse_predictions(supervised_entities, discovered_entities)

    def _prepare_supervised_data(self, texts: List[str]) -> List[Tuple]:
        """Build ``(tokens, labels)`` pairs from raw texts.

        Texts are lower-cased and split on whitespace. Texts in which no token
        receives a non-``'O'`` label are skipped.
        """
        training_data = []
        for text in texts:
            tokens = text.lower().split()
            labels = self._auto_label(tokens)
            if any(label != 'O' for label in labels):
                training_data.append((tokens, labels))
        return training_data

    def _auto_label(self, tokens: List[str]) -> List[str]:
        """Assign ``'B-CHEM'``, ``'B-DOSE'`` or ``'O'`` to each token by regex.

        A token that matches both the dose and the chemical heuristics is
        labelled as a dose.
        """
        labels = []
        for token in tokens:
            if self._DOSE_TOKEN_RE.match(token):
                labels.append('B-DOSE')
            elif self._CHEMICAL_TOKEN_RE.match(token):
                labels.append('B-CHEM')
            else:
                labels.append('O')
        return labels

    def _build_vocab(self, train_data: List[Tuple]):
        """Index every training token; ids 0 and 1 are ``<PAD>`` and ``<UNK>``."""
        word2idx = {'<PAD>': 0, '<UNK>': 1}
        for tokens, _ in train_data:
            for token in tokens:
                if token not in word2idx:
                    word2idx[token] = len(word2idx)

        self.vocab = {
            'word2idx': word2idx,
            'idx2word': {v: k for k, v in word2idx.items()}
        }

    def _train_supervised(self, train_data: List[Tuple]):
        """Train ``self.supervised_model`` one text at a time with Adam.

        The mean loss of every tenth epoch (starting with epoch 0) is appended
        to ``self.training_history``. Nothing happens when ``train_data`` is
        empty.
        """
        if not train_data:
            # Nothing to learn from; this also avoids dividing by
            # len(train_data) == 0 when the epoch loss is averaged.
            return

        optimizer = torch.optim.Adam(self.supervised_model.parameters(), lr=self.config.lr)
        # _predict_supervised switches the model to eval mode; make sure
        # training always runs in train mode.
        self.supervised_model.train()
        word2idx = self.vocab['word2idx']

        for epoch in range(self.config.epochs):
            total_loss = 0.0
            for tokens, labels in train_data:
                token_ids = [word2idx.get(t, 1) for t in tokens]  # 1 == <UNK>
                label_ids = [self._LABEL2ID[label] for label in labels]

                x = torch.tensor([token_ids], dtype=torch.long)
                y = torch.tensor([label_ids], dtype=torch.long)

                optimizer.zero_grad()
                # Called with targets, the model returns the loss.
                loss = self.supervised_model(x, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()

            if epoch % 10 == 0:
                avg_loss = total_loss / len(train_data)
                self.training_history.append({'epoch': epoch, 'loss': avg_loss})

    def _predict_supervised(self, text: str) -> List[Dict]:
        """Tag ``text`` with the supervised model.

        Returns:
            One dict per entity with ``'text'``, ``'type'``, a fixed
            ``'confidence'`` and ``'source': 'supervised'``. Empty or
            whitespace-only text yields ``[]``.
        """
        tokens = text.lower().split()
        if not tokens:
            # No tokens means no entities; skip feeding a zero-length
            # sequence to the model.
            return []

        token_ids = [self.vocab['word2idx'].get(t, 1) for t in tokens]
        x = torch.tensor([token_ids], dtype=torch.long)

        self.supervised_model.eval()
        with torch.no_grad():
            # Called without targets, the model returns predictions.
            predictions = self.supervised_model(x)

        pred_labels = predictions[0].tolist()
        entities = self._extract_entities(tokens, pred_labels)

        return [{'text': e, 'type': t, 'confidence': self._SUPERVISED_CONFIDENCE, 'source': 'supervised'}
                for e, t in entities]

    def _predict_discovery(self, text: str) -> List[Dict]:
        """Return the discovery candidates of ``text`` as prediction dicts.

        Chemical candidates come first, then dose candidates; each carries a
        fixed ``'confidence'`` and ``'source': 'discovery'``.
        """
        chemical_candidates = self.owner_discovery._extract_chemical_candidates(text)
        dose_candidates = self.owner_discovery._extract_dose_candidates(text)

        entities = [
            {'text': chem, 'type': 'chemical', 'confidence': self._DISCOVERY_CONFIDENCE, 'source': 'discovery'}
            for chem in chemical_candidates
        ]
        entities.extend(
            {'text': dose, 'type': 'dose', 'confidence': self._DISCOVERY_CONFIDENCE, 'source': 'discovery'}
            for dose in dose_candidates
        )
        return entities

    def _fuse_predictions(self, supervised: List[Dict], discovered: List[Dict]) -> List[Dict]:
        """Merge supervised and discovery predictions.

        Every supervised prediction is kept and tagged with
        ``config.supervised_weight``. A discovery is added only when its text
        (case-insensitive) is not already a supervised prediction and its
        confidence, scaled by ``config.discovery_weight``, reaches
        ``config.validation_threshold``.

        The input dicts are not modified; the result holds copies.

        Returns:
            Predictions sorted by descending ``'confidence'``.
        """
        fused = [dict(pred, weight=self.config.supervised_weight) for pred in supervised]

        supervised_texts = {pred['text'].lower() for pred in supervised}
        for pred in discovered:
            if pred['text'].lower() in supervised_texts:
                continue
            scaled_confidence = pred['confidence'] * self.config.discovery_weight
            if scaled_confidence >= self.config.validation_threshold:
                fused.append(dict(pred, weight=self.config.discovery_weight,
                                  confidence=scaled_confidence))

        return sorted(fused, key=lambda x: x['confidence'], reverse=True)

    def _extract_entities(self, tokens: List[str], pred_labels: List[int]) -> List[Tuple]:
        """Decode tag ids into ``(entity_text, entity_type)`` tuples.

        A ``B-`` tag opens an entity and a following ``I-`` tag of the same
        type extends it. An ``O`` tag, an ``I-`` tag with no open entity, or an
        ``I-`` tag whose type differs from the open entity closes the current
        entity and is itself discarded.

        Returns:
            Entities in order of appearance; tokens are joined with a space
            and the type is ``'chemical'`` or ``'dose'``.
        """
        spans = []        # [tag_type, [tokens]] per entity, in order
        open_span = None  # span currently being extended, if any

        for token, label_id in zip(tokens, pred_labels):
            prefix, _, tag_type = self._ID2LABEL[label_id].partition('-')

            if prefix == 'B':
                open_span = [tag_type, [token]]
                spans.append(open_span)
            elif prefix == 'I' and open_span is not None and open_span[0] == tag_type:
                open_span[1].append(token)
            else:
                open_span = None

        return [(' '.join(words), 'chemical' if tag_type == 'CHEM' else 'dose')
                for tag_type, words in spans]

    def _integrate_discoveries(self, discoveries: Dict, texts: List[str]):
        """Store the entities of sufficiently confident clusters.

        Entities of every cluster whose ``'confidence'`` reaches
        ``config.validation_threshold`` are added to
        ``self.owner_discovery.discovered_entities``. ``texts`` is currently
        unused.
        """
        for discovery_type in discoveries:
            for cluster in discoveries[discovery_type]:
                if cluster['confidence'] >= self.config.validation_threshold:
                    self.owner_discovery.discovered_entities.update(cluster['entities'])

def hybrid_extract(df: pd.DataFrame, config: HybridNERConfig = None,
                   text_col: str = 'chemicals') -> pd.DataFrame:
    """Fit a hybrid NER model on ``df`` and annotate every row with its entities.

    Args:
        df: Frame holding the texts. It is modified in place: a
            ``'hybrid_entities'`` column is added (or overwritten).
        config: Optional model configuration, passed to
            ``HybridPharmaceuticalNER``.
        text_col: Column to train on and to extract entities from.

    Returns:
        The same ``df`` object. Rows whose text is null get an empty list.
    """
    model = HybridPharmaceuticalNER(config)
    model.fit(df, text_col=text_col)

    df['hybrid_entities'] = df[text_col].apply(
        lambda value: model.predict(value) if pd.notna(value) else []
    )

    return df

# Usage example with breakthrough analysis
def analyze_breakthrough(df: pd.DataFrame) -> Dict:
    """Fit a default-configured hybrid model and return it with its metrics.

    Args:
        df: Frame passed to ``HybridPharmaceuticalNER.fit``.

    Returns:
        Dict with the fitted ``'model'``, the ``'training_stats'`` returned by
        ``fit``, and the model's ``'training_history'`` and
        ``'discovery_history'`` lists.
    """
    config = HybridNERConfig()
    model = HybridPharmaceuticalNER(config)

    # Fit and collect metrics
    training_stats = model.fit(df)

    return {
        'model': model,
        # Previously computed but discarded; exposed so callers can plot them.
        'training_stats': training_stats,
        'training_history': model.training_history,
        'discovery_history': model.discovery_history,
    }
