In [19]:
import os
import re
import shutil
from collections import defaultdict

import cv2
import numpy as np
import pandas as pd
import pytesseract

# Use the Tesseract binary found on PATH; fall back to the usual Linux install
# location so the notebook is not tied to one machine's layout.
pytesseract.pytesseract.tesseract_cmd = shutil.which('tesseract') or r'/usr/bin/tesseract'


In [17]:

class LayoutAwareOCR:
    """Layout-aware OCR for prescription-style forms.

    Pipeline (see ``process``):
      1. ``preprocess``: resize, grayscale, denoise, adaptive threshold.
      2. ``detect_layout``: Tesseract automatic page segmentation -> word boxes.
      3. ``detect_form_structure``: patient name / date from those word boxes.
      4. ``extract_medications_from_table``: medication rows from a ruled table.
      5. ``fallback_medication_extraction``: line-based OCR of a fixed page band,
         used only when step 4 found nothing.

    Attributes:
        original: image as returned by ``cv2.imread`` (BGR).
        preprocessed: resized BGR image, set by ``preprocess``.
        binary: adaptive-threshold image of ``preprocessed``, set by ``preprocess``.
        layout_elements: list of word dicts with keys ``text``, ``conf``,
            ``bbox`` (x, y, w, h), ``block``, ``par``, ``line``, ``area``.
        form_fields: detected form values keyed by field name
            (``patient_name``, ``date``).
        medications: list of dicts with keys ``medicine``, ``dosage``, ``frequency``.
    """

    # Table rows whose first cell starts with one of these words are labels.
    _HEADER_RE = re.compile(r'^(inscription|subscription|signa|rx)', re.I)
    # A single token carrying a dose and its unit, e.g. "500mg", "5ml".
    _DOSE_WITH_UNIT_RE = re.compile(r'\d+\s*(mg|ml|g|mcg)', re.I)
    # A bare number token and a separate unit token, e.g. "500" "mg".
    _NUMBER_RE = re.compile(r'\d+(?:\.\d+)?')
    _UNIT_RE = re.compile(r'(mg|ml|g|mcg)\.?', re.I)

    def __init__(self, image_path):
        """Load the image at ``image_path`` (str or path-like).

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        # cv2.imread expects a str; str() also lets callers pass pathlib.Path.
        self.original = cv2.imread(str(image_path))
        if self.original is None:
            raise FileNotFoundError(f"Cannot load image: {image_path}")

        self.preprocessed = None
        self.binary = None
        self.layout_elements = []
        self.form_fields = defaultdict(str)
        self.medications = []

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _light_background(img):
        """Return ``img`` with a light background, inverting it when its mean is dark."""
        if np.mean(img) < 127:
            return cv2.bitwise_not(img)
        return img

    def _ensure_binary(self):
        """Run ``preprocess`` first if no binary image exists yet."""
        if self.binary is None:
            self.preprocess()

    @staticmethod
    def _reading_order(elem):
        """Sort key: block, paragraph, line, then left-to-right."""
        return (elem['block'], elem['par'], elem['line'], elem['bbox'][0])

    # --------------------------------------------------------------- pipeline

    def preprocess(self, target_width=2000):
        """Resize, grayscale, denoise and binarize the original image.

        Args:
            target_width: width in pixels to rescale to (aspect ratio kept), so
                the pixel thresholds used later behave consistently.

        Returns:
            The adaptive-threshold image, also stored in ``self.binary``.
            ``self.preprocessed`` receives the resized colour image.
        """
        h, w = self.original.shape[:2]
        if w != target_width:
            scale = target_width / w
            new_h = max(1, int(h * scale))
            # INTER_AREA is the recommended filter for shrinking, cubic for enlarging.
            interpolation = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
            img = cv2.resize(self.original, (target_width, new_h),
                             interpolation=interpolation)
        else:
            img = self.original.copy()

        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Light non-local-means denoising before thresholding.
        gray = cv2.fastNlMeansDenoising(gray, h=8)

        # Adaptive threshold copes with uneven lighting
        # (21 px Gaussian neighbourhood, constant offset 10).
        binary = cv2.adaptiveThreshold(
            gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY, 21, 10
        )

        self.preprocessed = img
        self.binary = binary
        return binary

    def detect_layout(self):
        """Run Tesseract's automatic page segmentation and collect word boxes.

        Returns:
            List of word dicts (also stored in ``self.layout_elements``).
            Empty words and entries with negative confidence are skipped.
        """
        self._ensure_binary()
        img_to_analyze = self._light_background(self.binary)

        # PSM 3: fully automatic page segmentation (provides block/par/line ids).
        data = pytesseract.image_to_data(
            img_to_analyze,
            config='--oem 3 --psm 3',
            output_type=pytesseract.Output.DICT
        )

        elements = []
        for i, raw_text in enumerate(data['text']):
            text = str(raw_text).strip()
            if not text:
                continue

            # Depending on the pytesseract/Tesseract version, conf may be an int,
            # a float or a numeric string; -1 marks non-word rows.
            conf = int(float(data['conf'][i]))
            if conf < 0:
                continue

            x, y = data['left'][i], data['top'][i]
            w, h = data['width'][i], data['height'][i]
            elements.append({
                'text': text,
                'conf': conf,
                'bbox': (x, y, w, h),
                'block': data['block_num'][i],
                'par': data['par_num'][i],
                'line': data['line_num'][i],
                'area': w * h
            })

        self.layout_elements = elements
        return elements

    def detect_form_structure(self):
        """Fill ``self.form_fields`` with patient name and date from the layout.

        * patient_name: every word of the block after a line matching
          "FOR ... name ... address", in reading order.
        * date: the first token on a line starting a ``DATE`` word that looks
          like ``d..dd`` (loose digit pattern).

        Returns:
            ``self.form_fields``.
        """
        if not self.layout_elements:
            self.detect_layout()

        # Group words into Tesseract lines and order each line left-to-right.
        lines = defaultdict(list)
        for elem in self.layout_elements:
            lines[(elem['block'], elem['par'], elem['line'])].append(elem)
        for line_elems in lines.values():
            line_elems.sort(key=lambda e: e['bbox'][0])

        for (block, _par, _line), line_elems in lines.items():
            line_text = ' '.join(e['text'] for e in line_elems)

            # The patient details are expected in the block after the label line.
            if re.search(r'FOR.*name.*address', line_text, re.I):
                patient_elems = [e for e in self.layout_elements
                                 if e['block'] == block + 1]
                if patient_elems:
                    # Sort by line first so multi-line blocks are not interleaved.
                    patient_elems.sort(key=self._reading_order)
                    self.form_fields['patient_name'] = ' '.join(
                        e['text'] for e in patient_elems
                    )

            # \b keeps words such as "UPDATED" or "VALIDATE" from counting as a label.
            if re.search(r'\bDATE', line_text, re.I):
                date_candidates = [e['text'] for e in line_elems
                                   if re.search(r'\d{1,2}.*\d{2,4}', e['text'])]
                if date_candidates:
                    self.form_fields['date'] = date_candidates[0]

        return self.form_fields

    def detect_table_structure(self, min_cell_width=50, min_cell_height=20):
        """Find ruled-table cells from long horizontal and vertical strokes.

        Args:
            min_cell_width: cells must be wider than this (pixels).
            min_cell_height: cells must be taller than this (pixels).

        Returns:
            ``(cells, horizontal_lines, vertical_lines)``. ``cells`` is a list of
            ``{'bbox': (x, y, w, h), 'content': []}`` sorted top-to-bottom then
            left-to-right. The two masks contain the ruling lines in white on black.
        """
        self._ensure_binary()
        # Morphological opening keeps *white* structures, so work on an
        # ink-white / background-black image (self.binary is dark ink on white).
        ink = cv2.bitwise_not(self._light_background(self.binary))
        h, w = ink.shape[:2]

        horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (max(1, w // 30), 1))
        horizontal_lines = cv2.morphologyEx(ink, cv2.MORPH_OPEN,
                                            horizontal_kernel, iterations=2)

        vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, max(1, h // 30)))
        vertical_lines = cv2.morphologyEx(ink, cv2.MORPH_OPEN,
                                          vertical_kernel, iterations=2)

        # Cells are the holes enclosed by the grid of ruling lines; the AND of the
        # two masks would only hold the tiny line intersections.
        grid = cv2.bitwise_or(horizontal_lines, vertical_lines)
        # [-2:] works with both the OpenCV 3 and OpenCV 4 return signatures.
        contours, hierarchy = cv2.findContours(grid, cv2.RETR_CCOMP,
                                               cv2.CHAIN_APPROX_SIMPLE)[-2:]

        cells = []
        if hierarchy is not None:
            for cnt, (_next, _prev, _child, parent) in zip(contours, hierarchy[0]):
                # With RETR_CCOMP, contours that have a parent bound holes (cells).
                if parent < 0:
                    continue
                x, y, cw, ch = cv2.boundingRect(cnt)
                if cw > min_cell_width and ch > min_cell_height:
                    cells.append({'bbox': (x, y, cw, ch), 'content': []})

        cells.sort(key=lambda c: (c['bbox'][1], c['bbox'][0]))
        return cells, horizontal_lines, vertical_lines

    @staticmethod
    def _group_cells_into_rows(cells):
        """Group cells into table rows, tolerating a few pixels of y jitter.

        A cell joins the current row when its vertical centre lies within the
        vertical span of that row's first cell. Rows are returned top-to-bottom,
        each sorted left-to-right.
        """
        rows = []
        for cell in sorted(cells, key=lambda c: (c['bbox'][1], c['bbox'][0])):
            _, y, _, h = cell['bbox']
            center_y = y + h / 2
            if rows:
                _, anchor_y, _, anchor_h = rows[-1][0]['bbox']
                if anchor_y <= center_y <= anchor_y + anchor_h:
                    rows[-1].append(cell)
                    continue
            rows.append([cell])
        return [sorted(row, key=lambda c: c['bbox'][0]) for row in rows]

    def extract_medications_from_table(self):
        """Extract medication rows from the detected ruled table.

        Each word is assigned to the cell containing its box centre. Rows with
        at least two non-empty cells become ``{'medicine', 'dosage', 'frequency'}``
        (columns 1-3), except rows whose first cell is a label such as
        "Inscription"/"Rx".

        Returns:
            The medication list, also stored in ``self.medications``.
        """
        if not self.layout_elements:
            self.detect_layout()
        cells, _h_lines, _v_lines = self.detect_table_structure()

        for elem in self.layout_elements:
            ex, ey, ew, eh = elem['bbox']
            elem_center_x = ex + ew / 2
            elem_center_y = ey + eh / 2
            for cell in cells:
                cx, cy, cw, ch = cell['bbox']
                if cx <= elem_center_x <= cx + cw and cy <= elem_center_y <= cy + ch:
                    cell['content'].append(elem)
                    break

        medications = []
        for row_cells in self._group_cells_into_rows([c for c in cells if c['content']]):
            row_texts = [
                ' '.join(e['text'] for e in sorted(cell['content'], key=self._reading_order))
                for cell in row_cells
            ]

            # Expect at least two columns (medicine, dosage).
            if len(row_texts) < 2:
                continue

            med_name = row_texts[0]
            dosage = row_texts[1]
            frequency = row_texts[2] if len(row_texts) > 2 else ''

            if med_name and not self._HEADER_RE.search(med_name):
                medications.append({
                    'medicine': self.normalize_text(med_name),
                    'dosage': dosage,
                    'frequency': frequency
                })

        self.medications = medications
        return medications

    def normalize_text(self, text):
        """Collapse whitespace and drop characters other than word chars, spaces, '-', '.', '/'."""
        if not text:
            return ''
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'[^\w\s\-\./]', '', text)
        return text.strip()

    @classmethod
    def _find_dosage_index(cls, parts):
        """Index of the first token starting a dose ("500mg" or "500" "mg"), else -1."""
        for i, part in enumerate(parts):
            if cls._DOSE_WITH_UNIT_RE.search(part):
                return i
            if (cls._NUMBER_RE.fullmatch(part) and i + 1 < len(parts)
                    and cls._UNIT_RE.fullmatch(parts[i + 1])):
                return i
        return -1

    def _parse_medication_lines(self, lines):
        """Turn OCR text lines into medication dicts.

        Lines with fewer than two tokens are dropped. The first dose token splits
        a line into name and dosage. Without a dose after the first token, the
        whole line becomes the medicine name.
        """
        medications = []
        for line in lines:
            parts = line.split()
            if len(parts) < 2:
                continue
            dosage_idx = self._find_dosage_index(parts)
            if dosage_idx > 0:
                medications.append({
                    'medicine': self.normalize_text(' '.join(parts[:dosage_idx])),
                    'dosage': ' '.join(parts[dosage_idx:]),
                    'frequency': ''
                })
            else:
                medications.append({
                    'medicine': self.normalize_text(line),
                    'dosage': '',
                    'frequency': ''
                })
        return medications

    def fallback_medication_extraction(self, top_frac=0.3, bottom_frac=0.7):
        """Line-based OCR of a horizontal band of the page when table detection fails.

        Args:
            top_frac, bottom_frac: the band as fractions of the page height
                (default: 30%-70% of the document).

        Returns:
            The parsed medication list. It is stored in ``self.medications`` only
            if that list is still empty.
        """
        self._ensure_binary()
        h = self.binary.shape[0]
        rx_region = self._light_background(
            self.binary[int(top_frac * h):int(bottom_frac * h), :]
        )

        # Try several page-segmentation modes (13 = raw line) and keep the one
        # that yields the most non-empty lines.
        best_result = []
        for psm in (4, 6, 11, 13):
            try:
                text = pytesseract.image_to_string(
                    rx_region,
                    config=f'--oem 3 --psm {psm}'
                )
            except Exception as exc:  # best effort: one failing mode must not abort the rest
                print(f"  PSM {psm} failed: {exc}")
                continue
            lines = [ln.strip() for ln in text.split('\n') if ln.strip()]
            if len(lines) > len(best_result):
                best_result = lines

        medications = self._parse_medication_lines(best_result)

        if not self.medications:  # only use fallback if the main method failed
            self.medications = medications

        return medications

    def process(self):
        """Run the full pipeline and return the medications as a DataFrame."""
        print("Step 1: Preprocessing...")
        self.preprocess()

        print("Step 2: Detecting layout...")
        self.detect_layout()

        print("Step 3: Identifying form fields...")
        self.detect_form_structure()

        print("Step 4: Extracting medications from table...")
        self.extract_medications_from_table()

        print("Step 5: Fallback extraction if needed...")
        if not self.medications:
            print("  (Using fallback method)")
            self.fallback_medication_extraction()

        return self.create_dataframe()

    def create_dataframe(self):
        """Build a ``Medicine``/``Dosage``/``Frequency`` DataFrame from ``self.medications``.

        Also prints the detected patient name and date.
        """
        columns = ['Medicine', 'Dosage', 'Frequency']

        print(f"\nPatient: {self.form_fields.get('patient_name', 'Not detected')}")
        print(f"Date: {self.form_fields.get('date', 'Not detected')}")

        if not self.medications:
            print("\nNo medication entries extracted.")
            return pd.DataFrame(columns=columns)

        # Select columns by key so the labels never depend on dict insertion order.
        df = pd.DataFrame(self.medications, columns=['medicine', 'dosage', 'frequency'])
        df.columns = columns

        print(f"\nExtracted {len(df)} medication entries:")
        return df

    def save_debug_images(self, output_dir='debug_outputs'):
        """Write the preprocessed, binary and word-box overlay images to ``output_dir``."""
        import os
        os.makedirs(output_dir, exist_ok=True)

        if self.preprocessed is not None:
            cv2.imwrite(os.path.join(output_dir, '1_preprocessed.png'), self.preprocessed)
        if self.binary is not None:
            cv2.imwrite(os.path.join(output_dir, '2_binary.png'), self.binary)

        # Overlay word boxes (green) and truncated labels (blue in BGR).
        if self.layout_elements and self.preprocessed is not None:
            layout_viz = self.preprocessed.copy()
            for elem in self.layout_elements:
                x, y, w, h = elem['bbox']
                cv2.rectangle(layout_viz, (x, y), (x + w, y + h), (0, 255, 0), 2)
                # Keep the label inside the image when a box touches the top edge.
                cv2.putText(layout_viz, elem['text'][:10], (x, max(y - 5, 10)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
            cv2.imwrite(os.path.join(output_dir, '3_layout_detection.png'), layout_viz)

        print(f"Debug images saved to {output_dir}/")



In [18]:

if __name__ == "__main__":
    # Input image; set PRESCRIPTION_IMAGE to process a different file without
    # editing the hard-coded workspace path.
    image_path = os.environ.get(
        "PRESCRIPTION_IMAGE",
        "/workspaces/new_ocr_attempt2/sample_prescription.png",
    )

    ocr = LayoutAwareOCR(image_path)
    df = ocr.process()

    print("\n" + "=" * 60)
    print(df)
    print("=" * 60)

    # Derive the CSV name from the file stem. str.replace('.png', ...) would leave
    # a .jpg/.tif path unchanged and make to_csv overwrite the input image.
    output_csv = os.path.splitext(image_path)[0] + '_output.csv'
    df.to_csv(output_csv, index=False)
    print(f"\nSaved to: {output_csv}")

    # Save intermediate images for visual inspection.
    ocr.save_debug_images()

Step 1: Preprocessing...
Step 2: Detecting layout...
Step 3: Identifying form fields...
Step 4: Extracting medications from table...
Step 5: Fallback extraction if needed...
  (Using fallback method)

Patient: John K Doe, HMB, VSN
Date: Not detected

Extracted 11 medication entries:

                  Medicine Dosage Frequency
0                    V. O-                 
1   SVYOVOCF FUE Guess wre                 
2                 CYY FIOF                 
3         R Superscription                 
4                 gm or ml                 
5          mn Bellideywrne                 
6                  15 Vane                 
7                 120 lane                 
8            Amphegee goed                 
9        WW te FL Pobetior                 
10           Se Sm tid ac.                 

Saved to: /workspaces/new_ocr_attempt2/sample_prescription_output.csv
Debug images saved to debug_outputs/
