In [2]:
import pandas as pd
import numpy as np
import os
import re
from PIL import Image
import requests
import pytesseract
import sys
import cv2
sys.path.append('/Users/harsh/Desktop/AMAZON ML/student_resource 3/src/')


from utils import download_images, parse_string
from constants import entity_unit_map, allowed_units
from sklearn.metrics import f1_score, precision_score, recall_score
from tqdm import tqdm

# Step 1: Read the datasets
# Absolute paths into the local challenge checkout.
train_csv_path = '/Users/harsh/Desktop/AMAZON ML/student_resource 3/dataset/train.csv'
test_csv_path = '/Users/harsh/Desktop/AMAZON ML/student_resource 3/dataset/test.csv'
train_df = pd.read_csv(train_csv_path)
test_df = pd.read_csv(test_csv_path)

# Step 2: Download images using utils.py
# The image folders are relative to the current working directory.
for image_dir in ('images/train', 'images/test'):
    os.makedirs(image_dir, exist_ok=True)

print("Downloading training images...")
train_image_links = train_df['image_link'].to_list()
download_images(train_image_links, 'images/train', allow_multiprocessing=True)
# Expected local file = folder + the URL's last path component.
# Assumes download_images saves under that name -- TODO confirm in utils.py.
train_image_paths = [f"images/train/{os.path.basename(link)}" for link in train_image_links]
train_df['image_path'] = train_image_paths

print("Downloading test images...")
test_image_links = test_df['image_link'].to_list()
download_images(test_image_links, 'images/test', allow_multiprocessing=True)
test_image_paths = [f"images/test/{os.path.basename(link)}" for link in test_image_links]
test_df['image_path'] = test_image_paths

# Step 3: Define OCR and extraction functions
def preprocess_image(image_path):
    """Load an image from disk and binarise it for OCR.

    Args:
        image_path: path of the image file to read with OpenCV.

    Returns:
        None if ``cv2.imread`` cannot read the file. Otherwise a grayscale
        image that has been Otsu-thresholded and then smoothed with a 3x3
        median filter.
    """
    bgr = cv2.imread(image_path)
    if bgr is not None:
        grey = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        # With THRESH_OTSU set, OpenCV picks the threshold itself; the 150
        # passed here is not used. Element [1] is the thresholded image.
        binary = cv2.threshold(grey, 150, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
        # Median blur removes isolated speckle pixels left by thresholding.
        return cv2.medianBlur(binary, 3)
    return None

def extract_text_from_image(image_path):
    """Run Tesseract OCR on the preprocessed version of an image.

    Args:
        image_path: path of the image file.

    Returns:
        The text recognised by ``pytesseract.image_to_string``, or ``''``
        when ``preprocess_image`` could not load the file.
    """
    binary = preprocess_image(image_path)
    return '' if binary is None else pytesseract.image_to_string(binary)

# Standard unit name -> spellings that may appear in OCR text. The keys are
# the unit names that extract_entity_value gets from entity_unit_map.
#
# '\u00b5' (MICRO SIGN) and '\u03bc' (GREEK SMALL LETTER MU) look identical
# but are distinct characters, and str.lower() does not map one onto the
# other. Both are listed so that either form produced by OCR resolves
# through unit_variations_reverse.
unit_variations = {
    'centimetre': ['centimetre', 'centimeter', 'cm', 'cms', 'centimeters'],
    'foot': ['foot', 'ft', 'feet'],
    'millimetre': ['millimetre', 'millimeter', 'mm', 'mms', 'millimeters'],
    'metre': ['metre', 'meter', 'm', 'meters'],
    'inch': ['inch', 'inches', 'in'],
    'yard': ['yard', 'yd', 'yards'],
    'milligram': ['milligram', 'mg', 'milligrams'],
    'kilogram': ['kilogram', 'kg', 'kgs', 'kilograms'],
    'microgram': ['microgram', '\u00b5g', '\u03bcg', 'mcg', 'micrograms'],
    'gram': ['gram', 'g', 'grams'],
    'ounce': ['ounce', 'oz', 'ounces'],
    'ton': ['ton', 'tons'],
    'pound': ['pound', 'lb', 'lbs', 'pounds'],
    'millivolt': ['millivolt', 'mv', 'millivolts'],
    'kilovolt': ['kilovolt', 'kv', 'kilovolts'],
    'volt': ['volt', 'v', 'volts'],
    'kilowatt': ['kilowatt', 'kw', 'kilowatts'],
    'watt': ['watt', 'w', 'watts'],
    'cubic foot': ['cubic foot', 'ft³', 'cubic feet'],
    'microlitre': ['microlitre', '\u03bcl', '\u00b5l', 'microliter', 'microliters'],
    'cup': ['cup', 'cups'],
    'fluid ounce': ['fluid ounce', 'fl oz', 'fluid ounces'],
    'centilitre': ['centilitre', 'cl', 'centiliter', 'centiliters'],
    'imperial gallon': ['imperial gallon', 'imp gal', 'imperial gallons'],
    'pint': ['pint', 'pt', 'pints'],
    'decilitre': ['decilitre', 'dl', 'deciliter', 'deciliters'],
    'litre': ['litre', 'l', 'liter', 'liters'],
    'millilitre': ['millilitre', 'ml', 'milliliter', 'milliliters'],
    'quart': ['quart', 'qt', 'quarts'],
    'cubic inch': ['cubic inch', 'in³', 'cubic inches'],
    'gallon': ['gallon', 'gal', 'gallons'],
}


# Lower-cased spelling -> standard unit. Used to turn a regex match (found
# case-insensitively) back into its standard unit name.
unit_variations_reverse = {
    variation.lower(): standard_unit
    for standard_unit, variations in unit_variations.items()
    for variation in variations
}

def extract_entity_value(text, entity_name):
    """Return the first "<number> <unit>" mention in *text* valid for *entity_name*.

    Args:
        text: OCR output to search.
        entity_name: key into ``entity_unit_map``. It selects which standard
            units (keys of ``unit_variations``) are accepted.

    Returns:
        ``"<number> <standard unit>"`` for the first acceptable match. The
        number is rendered with ``format(float(x), '.10g')``, so ``"10.0"``
        becomes ``"10"``. Returns ``""`` when nothing matches.

    Raises:
        KeyError: if *entity_name* is not a key of ``entity_unit_map``.
    """
    # Local name chosen so it does not shadow the module-level
    # ``allowed_units`` imported from constants.
    units_for_entity = entity_unit_map[entity_name]

    spellings = {
        variation
        for unit in units_for_entity
        for variation in unit_variations.get(unit, [])
    }
    if not spellings:
        # An empty alternation could only match an empty unit, which never
        # maps to a standard unit.
        return ""

    # Python's regex alternation takes the FIRST alternative that matches, not
    # the longest. If 'm' (metre) came before 'mm' (millimetre), "10 mm" would
    # be read as metres. The iteration order of units_for_entity is also not
    # guaranteed stable (if it is a set, string hashing is randomised per
    # process -- TODO confirm its type in constants.py). Sorting longest-first,
    # then alphabetically, gives a deterministic, most-specific match.
    unit_alternation = '|'.join(
        re.escape(spelling)
        for spelling in sorted(spellings, key=lambda s: (-len(s), s))
    )
    pattern = re.compile(
        r'(\d+(?:\.\d+)?)\s*(' + unit_alternation + r')', re.IGNORECASE
    )

    for number, unit_text in pattern.findall(text):
        standard_unit = unit_variations_reverse.get(unit_text.lower())
        if standard_unit in units_for_entity:
            return f"{format(float(number), '.10g')} {standard_unit}"
    return ""

# Step 4: Process training data and calculate F1 score
print("Processing training data...")
# Only the submission (test side) is known to need an 'index' column. If
# train.csv has none, use the DataFrame row label. Otherwise row['index']
# would raise KeyError after the whole image download has finished.
# TODO confirm the train.csv columns.
has_index_column = 'index' in train_df.columns
train_predictions = []
for idx, row in tqdm(train_df.iterrows(), total=train_df.shape[0]):
    text = extract_text_from_image(row['image_path'])
    prediction = extract_entity_value(text, row['entity_name'])
    train_predictions.append({
        'index': row['index'] if has_index_column else idx,
        'prediction': prediction,
        'ground_truth': row['entity_value'],
    })

def _label_text(value):
    """Return a ground-truth cell as a string, with missing values as ""."""
    # pd.read_csv loads blank cells as float NaN. Because NaN != "", a missing
    # label would otherwise be scored as a real label that never matches.
    if isinstance(value, str):
        return value
    return "" if pd.isna(value) else str(value)


# Scoring rules (OUT = prediction, GT = ground truth, "" = no value):
#   both non-empty and equal      -> TP
#   both non-empty but different  -> FP
#   OUT non-empty, GT empty       -> FP
#   OUT empty, GT non-empty       -> FN
#   both empty                    -> TN
tp = 0
fp = 0
fn = 0
tn = 0

for pred in train_predictions:
    OUT = pred['prediction']
    GT = _label_text(pred['ground_truth'])
    if OUT != "" and GT != "":
        if OUT == GT:
            tp += 1
        else:
            fp += 1
    elif OUT != "":
        fp += 1
    elif GT != "":
        fn += 1
    else:
        tn += 1

precision = tp / (tp + fp) if (tp + fp) > 0 else 0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

print(f"True Positives: {tp}")
print(f"False Positives: {fp}")
print(f"False Negatives: {fn}")
print(f"True Negatives: {tn}")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1}")

# Step 5: Generate predictions for test data
print("Processing test data...")
# One {'index', 'prediction'} record per test row, in file order.
test_predictions = []
progress = tqdm(test_df.iterrows(), total=len(test_df))
for _, record in progress:
    ocr_text = extract_text_from_image(record['image_path'])
    test_predictions.append(
        {
            'index': record['index'],
            'prediction': extract_entity_value(ocr_text, record['entity_name']),
        }
    )

# Step 6: Create submission file and validate
# The submission holds exactly the 'index' and 'prediction' columns, in that
# order. Passing columns= also yields a valid (empty) frame when there are no
# predictions; selecting columns from an empty frame would raise KeyError.
test_out_df = pd.DataFrame(test_predictions, columns=['index', 'prediction'])
test_out_df.to_csv('test_out.csv', index=False)
print("Predictions saved to test_out.csv")

# Validate using sanity.py
from sanity import sanity_check

# Check against the same test.csv that Step 1 read, so the result does not
# depend on which directory the notebook was started from.
TEST_CSV_PATH = '/Users/harsh/Desktop/AMAZON ML/student_resource 3/dataset/test.csv'

try:
    sanity_check(TEST_CSV_PATH, 'test_out.csv')
except Exception as e:
    # Top-level boundary: report the failure instead of aborting the notebook.
    print(f"Sanity check failed: {e}")
else:
    print("Sanity check passed.")



Downloading training images...


 32%|███▏      | 84427/263859 [10:19<20:37, 144.97it/s] Process SpawnPoolWorker-63:
Process SpawnPoolWorker-62:
Process SpawnPoolWorker-60:
Process SpawnPoolWorker-42:
Process SpawnPoolWorker-40:
Process SpawnPoolWorker-39:
Process SpawnPoolWorker-49:
Process SpawnPoolWorker-50:
Process SpawnPoolWorker-55:
Process SpawnPoolWorker-56:
Process SpawnPoolWorker-59:
Process SpawnPoolWorker-52:
Process SpawnPoolWorker-47:
Process SpawnPoolWorker-44:
Process SpawnPoolWorker-27:
Process SpawnPoolWorker-57:
Process SpawnPoolWorker-64:
Process SpawnPoolWorker-58:
Process SpawnPoolWorker-46:
Process SpawnPoolWorker-37:
Process SpawnPoolWorker-8:
Process SpawnPoolWorker-15:
Process SpawnPoolWorker-29:
Traceback (most recent call last):
Traceback (most recent call last):
Process SpawnPoolWorker-24:
  File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/Library/Frameworks/Python.framework/Versions/3.11/li

KeyboardInterrupt: 