In [None]:
import csv
import os
import easyocr
import re
import random
import torch
from tqdm import tqdm 
class TextFeatureExtractor:
    """Predict entity values such as "12 volt" for images listed in a CSV, via OCR.

    Each image is read with EasyOCR. "<number> <unit>" mentions are pulled out
    of the recognised text with ``pattern``, and each unit spelling is mapped
    to a canonical unit name via ``normalization``. The prediction is the first
    mention whose canonical unit is listed in ``entity_unit_map`` for the row's
    entity name.

    The unit tables are class-level and shared by all instances. Treat them
    as read-only.
    """

    # Entity name -> set of canonical units accepted for that entity.
    entity_unit_map = {
        'width': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
        'depth': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
        'height': {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'},
        'item_weight': {'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'},
        'maximum_weight_recommendation': {
            'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton',
        },
        'voltage': {'kilovolt', 'millivolt', 'volt'},
        'wattage': {'kilowatt', 'watt'},
        'item_volume': {
            'centilitre', 'cubic foot', 'cubic inch', 'cup', 'decilitre', 'fluid ounce',
            'gallon', 'imperial gallon', 'litre', 'microlitre', 'millilitre', 'pint', 'quart',
        },
    }

    # Entity type -> {unit spelling exactly as matched in the text: canonical unit}.
    # Lookups are case-sensitive, so upper-case spellings ('V', 'mL', ...) are
    # listed explicitly. Every spelling here is also what ``pattern`` matches.
    normalization = {
        'length': {
            'centimetre': 'centimetre', 'centimetres': 'centimetre',
            'centimeter': 'centimetre', 'centimeters': 'centimetre', 'cm': 'centimetre',
            'foot': 'foot', 'feet': 'foot', 'ft': 'foot',
            'inch': 'inch', 'inches': 'inch', 'in': 'inch', '"': 'inch',
            'metre': 'metre', 'metres': 'metre', 'meter': 'metre', 'meters': 'metre',
            'm': 'metre', 'mtr': 'metre', 'mtrs': 'metre',
            'millimetre': 'millimetre', 'millimetres': 'millimetre',
            'millimeter': 'millimetre', 'millimeters': 'millimetre', 'mm': 'millimetre',
            'yard': 'yard', 'yards': 'yard', 'yd': 'yard',
        },
        'weight': {
            'gram': 'gram', 'grams': 'gram', 'g': 'gram', 'grm': 'gram',
            'kilogram': 'kilogram', 'kilograms': 'kilogram', 'kg': 'kilogram', 'kgs': 'kilogram',
            'microgram': 'microgram', 'micrograms': 'microgram',
            'mcg': 'microgram', 'µg': 'microgram',
            'milligram': 'milligram', 'milligrams': 'milligram', 'mg': 'milligram',
            'ounce': 'ounce', 'ounces': 'ounce', 'oz': 'ounce',
            'pound': 'pound', 'pounds': 'pound', 'lb': 'pound', 'lbs': 'pound',
            'ton': 'ton', 'tons': 'ton', 'T': 'ton', 'tn': 'ton',
        },
        'voltage': {
            'kilovolt': 'kilovolt', 'kilovolts': 'kilovolt', 'kv': 'kilovolt', 'kV': 'kilovolt',
            'millivolt': 'millivolt', 'millivolts': 'millivolt',
            'mv': 'millivolt', 'mV': 'millivolt',
            'volt': 'volt', 'volts': 'volt', 'v': 'volt', 'V': 'volt',
        },
        'wattage': {
            'kilowatt': 'kilowatt', 'kilowatts': 'kilowatt', 'kw': 'kilowatt', 'kW': 'kilowatt',
            'watt': 'watt', 'watts': 'watt', 'w': 'watt', 'W': 'watt',
        },
        'volume': {
            'centilitre': 'centilitre', 'centilitres': 'centilitre',
            'cl': 'centilitre', 'cL': 'centilitre',
            'cubic foot': 'cubic foot', 'cubic feet': 'cubic foot', 'ft³': 'cubic foot',
            'ft^3': 'cubic foot', 'cu ft': 'cubic foot',
            'cubic inch': 'cubic inch', 'cubic inches': 'cubic inch', 'in³': 'cubic inch',
            'in^3': 'cubic inch', 'cu in': 'cubic inch',
            'cup': 'cup', 'cups': 'cup', 'c.': 'cup',
            'decilitre': 'decilitre', 'decilitres': 'decilitre',
            'dl': 'decilitre', 'dL': 'decilitre',
            'fluid ounce': 'fluid ounce', 'fluid ounces': 'fluid ounce',
            'fl oz': 'fluid ounce', 'fl. oz.': 'fluid ounce',
            'gallon': 'gallon', 'gallons': 'gallon', 'gal': 'gallon',
            'imperial gallon': 'imperial gallon', 'imperial gallons': 'imperial gallon',
            'imp gal': 'imperial gallon',
            'litre': 'litre', 'litres': 'litre', 'liter': 'litre', 'liters': 'litre',
            'l': 'litre', 'L': 'litre', 'ltr': 'litre', 'ltrs': 'litre',
            'microlitre': 'microlitre', 'microlitres': 'microlitre',
            'µl': 'microlitre', 'mcL': 'microlitre',
            'millilitre': 'millilitre', 'millilitres': 'millilitre',
            'ml': 'millilitre', 'mL': 'millilitre',
            'pint': 'pint', 'pints': 'pint', 'pt': 'pint',
            'quart': 'quart', 'quarts': 'quart', 'qt': 'quart',
        },
    }

    # Every recognised unit spelling, longest first (ties broken alphabetically
    # so the pattern string is deterministic). Regex alternation takes the first
    # alternative that fits. Without this ordering, "mm", "mg" and "ml" would all
    # be captured as the one-letter "m" (metre).
    _UNIT_FORMS = sorted(
        {form for table in normalization.values() for form in table},
        key=lambda form: (-len(form), form),
    )

    # Capture groups:
    #   (1) the value,
    #   (2) the optional upper bound of an "a - b" range,
    #   (3) the unit spelling.
    # The trailing lookahead stops a unit from matching the start of a longer
    # word (e.g. the "m" in "3 months").
    pattern = (
        r'(\d*\.?\d+)(?:\s*-\s*(\d*\.?\d+))?\s*('
        + '|'.join(map(re.escape, _UNIT_FORMS))
        + r')(?![A-Za-z])'
    )

    # Entity name -> key into ``normalization``.
    _ENTITY_TYPES = {
        'width': 'length',
        'depth': 'length',
        'height': 'length',
        'item_weight': 'weight',
        'maximum_weight_recommendation': 'weight',
        'voltage': 'voltage',
        'wattage': 'wattage',
        'item_volume': 'volume',
    }

    def __init__(self, csv_file, dataset_folder):
        """Store the input locations and load the EasyOCR reader.

        Args:
            csv_file: path of the input CSV (see ``process_csv_in_batches``).
            dataset_folder: directory that holds the downloaded images.
        """
        self.csv_file = csv_file
        self.dataset_folder = dataset_folder
        self.batch_size = 2000  # rows per flush / CUDA cache release
        # English-only OCR; runs on the GPU when CUDA is available.
        self.reader = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

    def process_csv_in_batches(self, output_file='test_out1.csv'):
        """OCR every image listed in ``self.csv_file`` and write predictions.

        The first input row is a header and is skipped. For each data row:
        - column 1 is the image link; only its last "/"-separated component is
          used, as a file name inside ``self.dataset_folder``;
        - column 3 is the entity name.

        The output CSV has the columns ``index, prediction``:
        - ``index`` is the 0-based position of the data row;
        - ``prediction`` is "" when the row is too short, the image is missing,
          OCR fails, or no mention with an allowed unit is found.

        Args:
            output_file: destination path (defaults to the previously
                hard-coded 'test_out1.csv').
        """
        with open(self.csv_file, mode='r', newline='') as infile:
            rows = list(csv.reader(infile))[1:]  # skip header row

        with open(output_file, mode='w', newline='') as outfile:
            csv_writer = csv.writer(outfile)
            csv_writer.writerow(['index', 'prediction'])

            total_rows = len(rows)
            for batch_start in tqdm(range(0, total_rows, self.batch_size), desc="Processing Batches"):
                batch_rows = rows[batch_start:batch_start + self.batch_size]
                for idx, row in enumerate(batch_rows, start=batch_start):
                    csv_writer.writerow([idx, self._predict_row(idx, row)])

                # Persist progress after each batch. Presumably empty_cache()
                # releases GPU memory cached by OCR between batches.
                outfile.flush()
                torch.cuda.empty_cache()

    def _predict_row(self, idx, row):
        """Return the prediction string for one CSV row ("" when none)."""
        if len(row) < 4:
            # Malformed row: record an empty prediction instead of aborting the run.
            return ""
        image_name = row[1].split("/")[-1]
        entity_name = row[3]
        image_path = os.path.join(self.dataset_folder, image_name)
        if not os.path.exists(image_path):
            return ""
        try:
            result = self.reader.readtext(image_path, detail=0)
            return self._select_prediction(entity_name, ' '.join(result))
        except Exception as e:
            # Best effort per image: report the failure and keep processing.
            print(f"Error processing image {image_name} at index {idx} : {e}")
            return ""

    def _select_prediction(self, entity_name, text):
        """Return "<value> <unit>" for the first allowed mention in *text*, else ""."""
        features = self.extract_features(text)
        normalized = self.normalize_units(entity_name, features)
        candidates = self.filter_features(entity_name, normalized)
        if not candidates:
            return ""
        value, unit = candidates[0]
        return value + " " + unit

    def extract_features(self, text):
        """Return all ``(value, range_upper, unit)`` tuples matched by ``pattern``.

        ``range_upper`` is "" when the mention is not an "a - b" range.
        Returns [] for empty or ``None`` text.
        """
        return re.findall(self.pattern, text or "")

    def normalize_units(self, entity_name, features):
        """Map each feature to ``(value, canonical_unit)`` for the entity's type.

        Units not known for the entity's type are kept as matched. Features
        that are not 2- or 3-tuples are skipped. Returns [] when the entity
        name has no known type.
        """
        normalization_dict = self.normalization.get(self.get_entity_type(entity_name))
        if normalization_dict is None:
            return []

        normalized = []
        for feature in features:
            if len(feature) == 3:
                value, _range_upper, unit = feature  # the range's upper bound is dropped
            elif len(feature) == 2:
                value, unit = feature
            else:
                continue
            normalized.append((value, normalization_dict.get(unit, unit)))
        return normalized

    def get_entity_type(self, entity_name):
        """Return the ``normalization`` key for *entity_name*, or None if unknown."""
        return self._ENTITY_TYPES.get(entity_name)

    def filter_features(self, entity_name, features):
        """Keep only ``(value, unit)`` pairs whose unit is allowed for *entity_name*."""
        if entity_name in self.entity_unit_map:
            allowed_units = self.entity_unit_map[entity_name]
            return [(value, unit) for value, unit in features if unit in allowed_units]
        return []
# Guarded so that importing this module does not start a full OCR run.
# In a notebook __name__ is "__main__", so the cell still executes.
if __name__ == "__main__":
    csv_file_path = 'dataset/test.csv'  # Replace with the path to your CSV file
    dataset_folder_path = 'test_images'  # Replace with the path to your image folder

    # Create an instance of TextFeatureExtractor (this loads the EasyOCR reader).
    extractor = TextFeatureExtractor(csv_file=csv_file_path, dataset_folder=dataset_folder_path)

    # OCR every listed image and write the predictions to test_out1.csv.
    extractor.process_csv_in_batches()
