In [1]:
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from unit_extractor import extract_value_unit

# Build the pre-trained Faster R-CNN detector in inference mode
def load_faster_rcnn():
    """Create torchvision's Faster R-CNN (ResNet-50 FPN) detector with pretrained weights.

    Returns:
        The detector after ``eval()`` has been called on it, i.e. ready for
        inference rather than training.
    """
    detector = fasterrcnn_resnet50_fpn(pretrained=True)
    # nn.Module.eval() switches the module to inference mode and returns the module itself.
    return detector.eval()

# Run the detector on one image and return its predicted bounding boxes
def get_region_proposals(model, image_path):
    """Detect candidate regions in the image stored at ``image_path``.

    Args:
        model: Detection model, called with a batch tensor of shape
            ``[1, C, H, W]``; its first result must expose a ``'boxes'`` entry.
        image_path: Path of the image file to load.

    Returns:
        tuple: ``(boxes, image)`` where ``boxes`` is ``predictions[0]['boxes']``
        exactly as produced by the model (no score filtering is applied here)
        and ``image`` is the RGB PIL image the boxes refer to.
    """
    # Image.open keeps the file handle open lazily; the context manager closes it,
    # while convert() returns an independent, fully loaded copy we can keep using.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    image_tensor = T.ToTensor()(image).unsqueeze(0)

    # Keep the input on the same device as the model so a GPU-resident detector works too.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        image_tensor = image_tensor.to(first_param.device)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        predictions = model(image_tensor)  # Outputs: boxes, labels, scores

    boxes = predictions[0]['boxes']
    return boxes, image

11 ounce

10 gram

10 gram

11 ounce

11 ounce

15 kilogram

4 fluid ounce

7 foot

3 cubic inch

30 watt

15 centimetre

20 fluid ounce

5000 volt



In [2]:
import clip
import torch
from PIL import Image

# Load the CLIP model together with its image preprocessing callable
def load_clip_model():
    """Load the CLIP ``ViT-B/32`` checkpoint on CUDA when available, else on CPU.

    Returns:
        tuple: ``(model, preprocess)`` as returned by ``clip.load``.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    clip_model, transform = clip.load("ViT-B/32", device=target_device)
    return clip_model, transform

# Match the region with the entity_name using CLIP
def get_best_region(boxes, image, entity_name, clip_model, preprocess):
    """Pick the box whose crop is most similar to ``entity_name`` according to CLIP.

    Args:
        boxes: Iterable of ``(x1, y1, x2, y2)`` boxes; coordinates are truncated to int.
        image: Image object exposing ``crop((x1, y1, x2, y2))`` (a PIL image in this notebook).
        entity_name: Text that the regions are compared against.
        clip_model: CLIP model providing ``encode_text`` / ``encode_image``.
        preprocess: Callable turning a crop into an image tensor for ``clip_model``.

    Returns:
        tuple: ``(best_region, best_similarity)``. ``best_region`` is the
        *preprocessed tensor* (with batch dimension) of the winning crop, not the
        crop itself; it is ``None`` and the similarity is ``-1`` when no usable
        box was given.
    """
    best_region = None
    best_similarity = -1  # Lower bound of cosine similarity

    # The device does not change while we iterate, so look it up once.
    device = next(clip_model.parameters()).device

    # Only scalar similarities leave this function, so building autograd graphs
    # for every crop would just waste memory.
    with torch.no_grad():
        text_input = clip.tokenize([entity_name]).to(device)
        text_features = clip_model.encode_text(text_input)

        for box in boxes:
            x1, y1, x2, y2 = map(int, box)

            # Truncating float coordinates can collapse tiny boxes to zero
            # width/height; such crops are empty and cannot be preprocessed.
            if x2 <= x1 or y2 <= y1:
                continue

            crop = image.crop((x1, y1, x2, y2))
            region = preprocess(crop).unsqueeze(0).to(device)

            region_features = clip_model.encode_image(region)
            similarity = torch.cosine_similarity(text_features, region_features).item()

            # Keep the region with the highest similarity seen so far
            if similarity > best_similarity:
                best_region = region
                best_similarity = similarity

    return best_region, best_similarity

In [24]:
import pytesseract
from PIL import Image
import numpy as np

# Set Tesseract command path: prefer the binary found on PATH and only fall back
# to the conventional Linux location when PATH lookup finds nothing.
import shutil
pytesseract.pytesseract.tesseract_cmd = shutil.which("tesseract") or '/usr/bin/tesseract'

def tensor_to_pil_image(tensor, mean=None, std=None):
    """Convert an image tensor to a PIL image.

    Args:
        tensor: Float image tensor laid out as ``[C, H, W]`` or ``[1, C, H, W]``.
            Without ``mean``/``std`` the values are interpreted as being in ``[0, 1]``.
        mean: Optional per-channel means used when the tensor was normalized.
        std: Optional per-channel standard deviations used when the tensor was
            normalized. De-normalization (``x * std + mean``) is applied only when
            both ``mean`` and ``std`` are given.

    Returns:
        PIL image built from the ``uint8`` ``[H, W, C]`` array.
    """
    # 1. Drop autograd tracking and the batch dimension (no-op if dim 0 is not 1)
    tensor = tensor.detach().squeeze(0).float()

    # 2. Undo normalization when the caller supplies the statistics
    if mean is not None and std is not None:
        mean_t = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        tensor = tensor * std_t + mean_t

    # 3. Clamp before scaling: values outside [0, 1] would otherwise wrap around
    #    when cast to uint8 and produce garbage pixels.
    tensor = tensor.clamp(0.0, 1.0).mul(255).byte()

    # 4. Permute the channels from [C, H, W] to [H, W, C] and move to NumPy
    array = tensor.permute(1, 2, 0).cpu().numpy()

    # 5. Convert NumPy array to PIL image
    return Image.fromarray(array)

def ocr_and_loss(best_region, ground_truth):
    """OCR a region and score the extracted value against the ground truth.

    Args:
        best_region: Region to read. A ``torch.Tensor`` is first converted with
            ``tensor_to_pil_image``; anything else is handed to Tesseract as is.
        ground_truth: Expected entity value as a string.

    Returns:
        int: ``0`` when the result of ``extract_value_unit`` on the stripped,
        lower-cased OCR text equals the stripped, lower-cased ground truth,
        otherwise ``1``.
    """
    region_image = best_region
    if isinstance(region_image, torch.Tensor):
        # Tesseract is not given the tensor directly; go through a PIL image
        region_image = tensor_to_pil_image(region_image)

    raw_text = pytesseract.image_to_string(region_image)
    # extract_value_unit is an imported helper; its output format is not visible here
    predicted = extract_value_unit(raw_text.strip().lower())

    expected = ground_truth.strip().lower()

    # 0/1 loss: no partial credit
    if predicted == expected:
        return 0
    return 1

# String-based 0/1 loss (despite the name, no cross-entropy is computed)
def cross_entropy_loss(pred, target):
    """Compare two strings ignoring surrounding whitespace and letter case.

    Args:
        pred: Predicted string.
        target: Expected string.

    Returns:
        int: ``0`` if the normalized strings are equal, ``1`` otherwise.
    """
    normalized_pred = pred.strip().lower()
    normalized_target = target.strip().lower()
    if normalized_pred == normalized_target:
        return 0
    return 1

In [25]:
def train_model(train_image_paths, train_entity_names, train_labels, num_epochs=10):
    """Run the detect -> match -> OCR pipeline over the training set for several epochs.

    For every sample the detector proposes boxes, CLIP picks the best matching
    region, and ``ocr_and_loss`` scores the OCR output. The per-epoch sum of the
    losses is printed.

    Args:
        train_image_paths: Image file paths.
        train_entity_names: Entity names, parallel to ``train_image_paths``.
        train_labels: Ground-truth entity values, parallel to ``train_image_paths``.
        num_epochs: Number of passes over the data.

    Returns:
        None. Progress is reported via ``print``.
    """
    # Load Faster R-CNN and CLIP models
    faster_rcnn_model = load_faster_rcnn()
    clip_model, preprocess = load_clip_model()

    optimizer = torch.optim.Adam(clip_model.parameters(), lr=0.0001)

    for epoch in range(num_epochs):
        total_loss = 0
        for image_path, entity_name, ground_truth in zip(train_image_paths, train_entity_names, train_labels):
            # 1. Get region proposals
            boxes, image = get_region_proposals(faster_rcnn_model, image_path)

            # 2. Find the best matching region
            best_region, best_similarity = get_best_region(boxes, image, entity_name, clip_model, preprocess)

            # 3. Apply OCR and calculate loss
            if best_region is not None:
                loss = ocr_and_loss(best_region, ground_truth)
                total_loss += loss

        # Backpropagation is only possible for a differentiable tensor loss.
        # ocr_and_loss in this notebook returns a plain int (0 or 1), for which
        # calling .backward() would raise AttributeError; in that case the epoch
        # is effectively an evaluation pass and no weights are updated.
        if isinstance(total_loss, torch.Tensor) and total_loss.requires_grad:
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss}")

In [26]:
import pandas as pd
# Read the training metadata
train_df = pd.read_csv("/DATA1/ai23mtech12001/Amazon/amazon-ml/dataset/train.csv")

# Build three parallel lists: local image path, entity name, ground-truth value.
# Each local path is "train_images/" plus the last segment of the image link.
train_entity_names = train_df['entity_name'].tolist()
train_labels = train_df['entity_value'].tolist()
train_image_paths = [
    "train_images/" + link.split('/')[-1]
    for link in train_df['image_link']
]

# Kick off training
train_model(train_image_paths, train_entity_names, train_labels, num_epochs=10)



TesseractNotFoundError: /usr/bin/tesseract is not installed or it's not in your PATH. See README file for more information.

In [None]:
def inference(image_path, entity_name):
    """Run the detect -> match -> OCR pipeline on one image and print the OCR text.

    Args:
        image_path: Path of the image to analyse.
        entity_name: Text used to select the most relevant region via CLIP.

    Returns:
        None. Prints the OCR result, or a message when no region was found.
    """
    # Load pre-trained models
    faster_rcnn_model = load_faster_rcnn()
    clip_model, preprocess = load_clip_model()

    # Get region proposals
    boxes, image = get_region_proposals(faster_rcnn_model, image_path)

    # Find the best region
    best_region, best_similarity = get_best_region(boxes, image, entity_name, clip_model, preprocess)

    # Extract text using OCR from the best region
    if best_region is not None:
        # get_best_region hands back the preprocessed tensor; Tesseract needs an
        # image, so convert it the same way ocr_and_loss does.
        if isinstance(best_region, torch.Tensor):
            best_region = tensor_to_pil_image(best_region)
        ocr_result = pytesseract.image_to_string(best_region)
        print(f"OCR Result: {ocr_result}")
    else:
        print("No matching region found.")

# Example inference call
# NOTE(review): the second argument is the literal string "entity_name", which looks
# like a placeholder — presumably an actual entity name from the dataset should be
# passed here; confirm before relying on the output.
inference("/DATA1/ai23mtech12001/Amazon/amazon-ml/test_images/21+i52HRW4L.jpg", "entity_name")