In [1]:
!pip install --no-index --find-links=/kaggle/input/ultralytics-package/ultralytics_package ultralytics

Looking in links: /kaggle/input/ultralytics-package/ultralytics_package
Processing /kaggle/input/ultralytics-package/ultralytics_package/ultralytics-8.3.75-py3-none-any.whl
Processing /kaggle/input/ultralytics-package/ultralytics_package/ultralytics_thop-2.0.14-py3-none-any.whl (from ultralytics)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.75 ultralytics-thop-2.0.14


In [2]:
import io
import re
import os
import cv2
import sys
import json
import torch
import logging
import requests
import numpy as np
import pandas as pd
from PIL import Image
from ultralytics import YOLO
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor



logging.getLogger("transformers").setLevel(logging.ERROR)

chart_types = {
    0: "horizontal_bar",
    1: "vertical_bar",
    2: "dot",
    3: "line",
    4: "scatter"
}

# Models
model_charttype_path = "/kaggle/input/chart-type-classification/pytorch/default/1/fine_tuned_chart_classification_model_mid_v3_5.h5"
model_Charttype = load_model(model_charttype_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_Deplot = Pix2StructForConditionalGeneration.from_pretrained('/kaggle/input/deplot/pytorch/deplot/1').to(device)
processor_Deplot = Pix2StructProcessor.from_pretrained('/kaggle/input/deplot/pytorch/deplot/1')

model_Deplot_horizontal = Pix2StructForConditionalGeneration.from_pretrained('/kaggle/input/deplot-finetuned-horizontal/deplot_finetuned_v2').to(device)
processor_horizontal = Pix2StructProcessor.from_pretrained('/kaggle/input/deplot-finetuned-horizontal/deplot_finetuned_v2')

model_Deplot_line = Pix2StructForConditionalGeneration.from_pretrained('/kaggle/input/deplot-finetuned-line-version-2/deplot_finetuned').to(device)
processor_line = Pix2StructProcessor.from_pretrained('/kaggle/input/deplot-finetuned-line-version-2/deplot_finetuned')

model_Deplot_dot = Pix2StructForConditionalGeneration.from_pretrained('/kaggle/input/deplot-finetuned-dot/deplot_finetuned').to(device)
processor_dot = Pix2StructProcessor.from_pretrained('/kaggle/input/deplot-finetuned-dot/deplot_finetuned')


# Load YOLO models
tick_model = YOLO("/kaggle/input/tick-model-yolo/other/default/2/tick_best.pt")  # Detects ticks & text boxes
scatter_model = YOLO("/kaggle/input/scatter-dot-yolo-03/other/default/8/best_scatter_29_date_10_03_2025.pt")  # Detects scatter points
plot_model = YOLO("/kaggle/input/plot-model-yolo/other/default/1/plot_best.pt")  # Detects plot area

Creating new Ultralytics Settings v0.0.6 file âœ… 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


In [3]:
# Display function for deplot output
def display_deplot_output(deplot_output):
    deplot_output = deplot_output.replace("<0x0A>", "\n").replace(" | ", "\t")
    second_a_index = [m.start() for m in re.finditer('\t', deplot_output)][1]
    last_newline_index = deplot_output.rfind('\n', 0, second_a_index)
    title = deplot_output[:last_newline_index]
    table = deplot_output[last_newline_index + 1:]
    data = io.StringIO(table)
    df = pd.read_csv(data, sep='\t')
    return df

# Function to handle deplot with model and processor
def deplot_with_model(path, model, processor, device, font_path, chart_type):
    try:
        image = Image.open(path)
        #display(image)
        if chart_type in ["dot", "horizontal_bar", "line"]:
            inputs = processor(
                images=image,
                text="Generate data series:",
                return_tensors="pt",
                font_path=font_path
            )   
        else:
            inputs = processor(
                images=image,
                text="Generate underlying data table of the figure below:",
                return_tensors="pt",
                font_path=font_path
            )
        inputs = {key: value.to(device) for key, value in inputs.items()}
        predictions = model.generate(**inputs, max_new_tokens=512)
        output = processor.decode(predictions[0], skip_special_tokens=True)

        if chart_type in ["dot", "horizontal_bar", "line"]:
            # Parse JSON output
            data = json.loads(output)
            print(data)
            x_data_series = ";".join([str(item['x']) for item in data])
            y_data_series = ";".join([str(item['y']) for item in data])
        else:
            # Handle default output format
            df = display_deplot_output(output)
            # Drop rows with NaN values
            df = df.dropna()
            x_data_series = ";".join(df[df.columns[0]].astype(str))
            y_data_series = ";".join(df[df.columns[1]].astype(str))
    except:
        x_data_series = "0"
        y_data_series = "0"

    return x_data_series, y_data_series

# Helper function to preprocess the image
def preprocess_image(image_path, target_size=(224, 224)):
    image = load_img(image_path, target_size=target_size)
    image = img_to_array(image) / 255.0
    return np.expand_dims(image, axis=0)

# Updated function to check if a series contains only floats
def is_float_series(series):
    try:
        series = [float(item) for item in series.split(";")]
        return True
    except ValueError:
        return False

In [4]:
import pandas as pd
import numpy as np

def is_invalid_number(value):
    try:
        n = float(value)
        return np.isnan(n) or np.isinf(n)
    except ValueError:
        return True

def clean_series(x_series, y_series, chart_type):
    try:
        x_values = x_series.split(";")
        y_values = y_series.split(";")

        # Check lengths match
        if len(x_values) != len(y_values):
            return "0", "0"

        # Handle invalid values in x and y
        for i in range(len(x_values)):
            if chart_type in ["vertical_bar", "line", "dot"]:
                # x is categorical, y is numeric
                if is_invalid_number(y_values[i]):
                    x_values[i] = "0"
                    y_values[i] = "0"
            elif chart_type == "horizontal_bar":
                # x is numeric, y is categorical
                if is_invalid_number(x_values[i]):
                    x_values[i] = "0"
                    y_values[i] = "0"
            elif chart_type == "scatter":
                # x and y are both numeric
                if is_invalid_number(x_values[i]) or is_invalid_number(y_values[i]):
                    x_values[i] = "0"
                    y_values[i] = "0"

        # Rejoin the series
        cleaned_x_series = ";".join(x_values)
        cleaned_y_series = ";".join(y_values)
        return cleaned_x_series, cleaned_y_series

    except Exception as e:
        # If any unexpected error occurs, set both x and y to "0"
        return "0", "0"


In [5]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Use a working tokenizer
processor = TrOCRProcessor.from_pretrained("/kaggle/input/trocr_model/other/default/1/trocr_model")

# Load the model from converted SafeTensors
model_text = VisionEncoderDecoderModel.from_pretrained("/kaggle/input/trocr_model/other/default/1/trocr_model").to(device)

# Path to the folder with images and to font
images_folder = "/kaggle/input/benetech-making-graphs-accessible/test/images"
output_csv_path = "submission.csv"
font_path = "/kaggle/input/huruff/Arial.ttf"

# Dataframe to store results
results = []

# Iterate through images in the folder
for img_name in os.listdir(images_folder):
    img_path = os.path.join(images_folder, img_name)
    if img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
        base_name = os.path.splitext(img_name)[0]
        
        # Preprocess and predict chart type
        image = preprocess_image(img_path)
        predictions = model_Charttype.predict(image, verbose=0)
        predicted_chart_type_index = np.argmax(predictions, axis=1)[0]
        predicted_chart_type = chart_types.get(predicted_chart_type_index, "unknown")

        # Select model and processor based on chart type
        if predicted_chart_type == "horizontal_bar":
            selected_model = model_Deplot_horizontal
            selected_processor = processor_horizontal
        elif predicted_chart_type == "line":
            selected_model = model_Deplot_line
            selected_processor = processor_line
        elif predicted_chart_type == "dot":
            selected_model = model_Deplot_dot
            selected_processor = processor_dot
        elif predicted_chart_type=="scatter":
            selected_model = model_text
            selected_processor = processor
        else:
            selected_model = model_Deplot
            selected_processor = processor_Deplot
        
        if(predicted_chart_type!="scatter"):
            # Use deplot to extract data
            x_data_series, y_data_series = deplot_with_model(
                img_path, selected_model, selected_processor, device, font_path, predicted_chart_type
            )
        else:
            # **Run YOLO inference for plot area**
            plot_results = plot_model(img_path)
            plot_x1, plot_y1, plot_x2, plot_y2 = None, None, None, None
            for box in plot_results[0].boxes:
                plot_x1, plot_y1, plot_x2, plot_y2 = map(int, box.xyxy[0])  # Get bounding box of the detected plot
            
            # Load image
            image = cv2.imread(img_path)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            img_h, img_w,_ = image.shape
            
            # Run YOLO inference for ticks
            tick_results = tick_model(img_path)
            
            # Store detected elements
            tick_positions_x = []  # (image_x, plot_x_value)
            tick_positions_y = []  # (image_y, plot_y_value)
            
            # Text box positions & recognized values
            text_boxes_x = []  # X-axis text positions
            text_boxes_y = []  # Y-axis text positions
            
            for box in tick_results[0].boxes:
                x1, y1, x2, y2 = map(int, box.xyxy[0])
                class_id = int(box.cls[0])  # 0=x-tick, 1=y-tick, 2=x-text, 3=y-text, 4=xy-tick
            
                if class_id in [2, 3]:  # If it's a text box, recognize text
                    cropped_text = image[y1:y2, x1:x2]
                    pil_crop = Image.fromarray(cropped_text)
            
                    # Process for TrOCR
                    pixel_values = selected_processor(pil_crop, return_tensors="pt").pixel_values.to(device)
            
                    # Predict text
                    with torch.no_grad():
                        generated_ids = model_text.generate(pixel_values)
                        predicted_text = selected_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            
                    if class_id == 2:  # X-axis text
                        text_boxes_x.append((x1, y1, x2, y2, predicted_text))
                    elif class_id == 3:  # Y-axis text
                        text_boxes_y.append((x1, y1, x2, y2, predicted_text))
            
                elif class_id in [0, 1, 4]:  # Tick marks
                    center_x = (x1 + x2) // 2
                    center_y = (y1 + y2) // 2
                    if class_id == 0:  # X-axis tick
                        tick_positions_x.append((center_x, center_y))
                    elif class_id == 1: # Y-axis tick
                        tick_positions_y.append((center_x, center_y))
                    else:  
                        tick_positions_x.append((center_x, center_y))
                        tick_positions_y.append((center_x, center_y))

            
            # **Sort tick positions before matching**
            tick_positions_x = sorted(tick_positions_x, key=lambda t: t[0])  # Sort X-ticks by center_x
            tick_positions_y = sorted(tick_positions_y, key=lambda t: t[1], reverse=True)  # Sort Y-ticks by center_y

            # **Sort text boxes before matching**
            text_boxes_x = sorted(text_boxes_x, key=lambda t: t[0])  # Sort by X-position (left to right)
            text_boxes_y = sorted(text_boxes_y, key=lambda t: t[1], reverse=True)  # Sort by Y-position (bottom to top)
            
            # **Match text to nearest tick positions**
            def match_ticks_to_text(tick_positions, text_boxes):
                matched_ticks = []
                used_ticks = set()  # To prevent assigning the same tick multiple times
            
                for text_box in text_boxes:
                    x1, y1, x2, y2, text_value = text_box
                    text_center_x = (x1 + x2) // 2
                    text_center_y = (y1 + y2) // 2

                    # Skip if text_value is missing or invalid
                    if not text_value.strip():  # If text is empty or only whitespace
                        continue
            
                    # Find the closest tick using Euclidean distance
                    closest_tick = None
                    min_distance = float("inf")
            
                    for tick_pos in tick_positions:
                        tick_x, tick_y = tick_pos
                        distance = np.sqrt((text_center_x - tick_x) ** 2 + (text_center_y - tick_y) ** 2)
            
                        if distance < min_distance and tick_pos not in used_ticks:
                            min_distance = distance
                            closest_tick = tick_pos
            
                    if closest_tick:
                        try:
                            plot_value = float(text_value)  # Convert to float
                            matched_ticks.append((closest_tick, plot_value))
                            used_ticks.add(closest_tick)  # Mark tick as used
                        except ValueError:
                            continue  # Skip invalid text values
                        
                return matched_ticks
            
            # Get matched tick values
            x_axis_values = match_ticks_to_text(tick_positions_x, text_boxes_x)
            y_axis_values = match_ticks_to_text(tick_positions_y, text_boxes_y)

            if(len(x_axis_values)==0 or len(y_axis_values)==0):
                x_data_series = "0"
                y_data_series = "0"
                x_data_series, y_data_series = clean_series(x_data_series, y_data_series, predicted_chart_type)
                results.append({"id": f"{base_name}_x", "data_series": x_data_series, "chart_type": predicted_chart_type})
                results.append({"id": f"{base_name}_y", "data_series": y_data_series, "chart_type": predicted_chart_type})
                continue

            
            # Keep only the first number from the inner tuple in x_axis_values
            x_axis_values = [(t[0][0], t[1]) for t in x_axis_values]
            
            # Keep only the second number from the inner tuple in y_axis_values
            y_axis_values = [(t[0][1], t[1]) for t in y_axis_values]
            
            # Define a margin (in pixels) for tolerance
            MARGIN = 5  # You can adjust this value based on your specific requirements
            
            # Extract only the plot values (with margin for error)
            x_img, x_plot = zip(*x_axis_values) if x_axis_values else ([], [])
            y_img, y_plot = zip(*y_axis_values) if y_axis_values else ([], [])

            # Apply margin for tolerance when extracting plot values
            if x_img and x_plot:
                x_min, x_max = min(x_img) - MARGIN, max(x_img) + MARGIN
                x_img = [x for x in x_img if x_min <= x <= x_max]  # Apply margin filter
                x_plot = [x_plot[i] for i in range(len(x_plot)) if x_min <= x_img[i] <= x_max]  # Filter matching values
            
            if y_img and y_plot:
                y_min, y_max = min(y_img) - MARGIN, max(y_img) + MARGIN
                y_img = [y for y in y_img if y_min <= y <= y_max]  # Apply margin filter
                y_plot = [y_plot[i] for i in range(len(y_plot)) if y_min <= y_img[i] <= y_max]  # Filter matching values


            def fix_incorrect_values(axis_values, tolerance=1.5, max_iterations=20):
                """
                Fix incorrect values in axis data by checking for regular spacing and applying a tolerance range.
                
                Args:
                axis_values (list): List of values for X or Y axis.
                tolerance (float): Tolerance range for how much a value can deviate from regular spacing.
                max_iterations (int): Maximum iterations to prevent infinite loops.
                
                Returns:
                list: Corrected axis values.
                """
                if len(axis_values) < 3:
                    return axis_values  # Can't fix with less than 3 values
            
                diffs = [axis_values[i+1] - axis_values[i] for i in range(len(axis_values)-1)]
                spacing = np.median(diffs)  # Use median to handle outliers in spacing
                corrected_values = axis_values.copy()
                
                for _ in range(max_iterations):
                    corrected = False
                    for i in range(1, len(corrected_values) - 1):
                        expected_value = corrected_values[i - 1] + spacing
                        if abs(corrected_values[i] - expected_value) > spacing * tolerance:
                            corrected_values[i] = expected_value  # Fix the value
                            corrected = True
                    
                    if not corrected:
                        break  # Stop if no changes were made in this iteration
                
                return corrected_values
            
            # Fix the x and y axis values if needed
            x_plot_corrected = fix_incorrect_values(x_plot, tolerance=1.5)
            y_plot_corrected = fix_incorrect_values(y_plot, tolerance=1.5)

            # **Fit linear transformation for x and y**
            #x_transform = np.poly1d(np.polyfit(x_img, x_plot, 1)) if len(x_img) > 1 else lambda x: x
            #y_transform = np.poly1d(np.polyfit(y_img, y_plot, 1)) if len(y_img) > 1 else lambda y: y

            import numpy as np
            from sklearn.linear_model import RANSACRegressor
            from sklearn.preprocessing import PolynomialFeatures
            from sklearn.pipeline import make_pipeline
            
            # --- Robust regression for the x-axis transformation ---
            try:
                if len(x_img) > 1:
                    model_x = make_pipeline(PolynomialFeatures(degree=1), RANSACRegressor())
                    model_x.fit(np.array(x_img).reshape(-1, 1), np.array(x_plot_corrected))
                    transform_x = lambda x: model_x.predict(np.array([[x]]))[0]
                else:
                    transform_x = lambda x: x
            except Exception as e:
                print("Error in RANSAC for x, falling back to polyfit: ", e)
                if len(x_img) > 1:
                    coeffs = np.polyfit(x_img, x_plot_corrected, 1)
                    transform_x = np.poly1d(coeffs)
                else:
                    transform_x = lambda x: x
            
            # --- Robust regression for the y-axis transformation ---
            try:
                if len(y_img) > 1:
                    model_y = make_pipeline(PolynomialFeatures(degree=1), RANSACRegressor())
                    model_y.fit(np.array(y_img).reshape(-1, 1), np.array(y_plot_corrected))
                    transform_y = lambda y: model_y.predict(np.array([[y]]))[0]
                else:
                    transform_y = lambda y: y
            except Exception as e:
                print("Error in RANSAC for y, falling back to polyfit: ", e)
                if len(y_img) > 1:
                    coeffs = np.polyfit(y_img, y_plot_corrected, 1)
                    transform_y = np.poly1d(coeffs)
                else:
                    transform_y = lambda y: y
            
            def transform_point(x, y):
                return transform_x(x), transform_y(y)
            
            # --- Run YOLO inference for scatter points ---
            scatter_results = scatter_model(img_path)  # Ensure this call is made!
            
            # --- Extract scatter positions & transform them to plot space ---
            scatter_plot_coords = []
            for box in scatter_results[0].boxes:
                try:
                    x1, y1, x2, y2 = map(int, box.xyxy[0])
                except Exception as e:
                    print("Error processing box coordinates: ", e)
                    continue
                center_x = (x1 + x2) // 2
                center_y = (y1 + y2) // 2
            
                # Remove scatter points outside the detected plot area
                if plot_x1 and plot_y1 and plot_x2 and plot_y2:
                    if not (plot_x1 <= center_x <= plot_x2 and plot_y1 <= center_y <= plot_y2):
                        continue
            
                try:
                    plot_x, plot_y = transform_point(center_x, center_y)
                    scatter_plot_coords.append((plot_x, plot_y))
                except Exception as e:
                    print("Error transforming scatter point: ", e)
                    continue
            
            # --- Sort scatter points: first by x-value, then by y-value if x is the same ---
            scatter_plot_coords = sorted(scatter_plot_coords, key=lambda p: (p[0], p[1]))
            
            # Extract x and y values
            #x_data_series = ";".join(str(x) for x, _ in scatter_plot_coords)
            #y_data_series = ";".join(str(y) for _, y in scatter_plot_coords)

            # Extract x and y values with 4 number after point
            x_data_series = ";".join(f"{x:.4f}" for x, _ in scatter_plot_coords)
            y_data_series = ";".join(f"{y:.4f}" for _, y in scatter_plot_coords)

        
        if(x_data_series==None or len(x_data_series)==0):
            x_data_series = "0"
        if(y_data_series==None or len(y_data_series)==0):
            y_data_series = "0"
        print(x_data_series)
        print(y_data_series)
        x_data_series,y_data_series = clean_series(x_data_series, y_data_series, predicted_chart_type)
        results.append({
            "id": f"{base_name}_x",
            "data_series": x_data_series,
            "chart_type": predicted_chart_type
        })
        results.append({
            "id": f"{base_name}_y",
            "data_series": y_data_series,
            "chart_type": predicted_chart_type
        })



        

#Save results to CSV
submission = pd.DataFrame(results)
submission.to_csv(output_csv_path, index=False)
print(f"Results saved to {output_csv_path}")



[{'x': '0', 'y': 0.0}, {'x': '6', 'y': -1.3353}, {'x': '12', 'y': -2.6353}, {'x': '18', 'y': -1.9653}, {'x': '24', 'y': -3.2753}]
0;6;12;18;24
0.0;-1.3353;-2.6353;-1.9653;-3.2753
 21-Feb; 22-Feb; 23-Feb; 24-Feb; 25-Feb; 26-Feb; 27-Feb; 28-Feb; 29-Feb; 01-Mar; 02-Mar; 03-Mar; 04-Mar; 05-Mar; 06-Mar; 07-Mar; 08-Mar; 09-Mar; 10-Mar
89000;151192;172700;177800;137500;99168;17242;41422;60168;66027;53941;44475;64653;79171;82392;102623;130650;101611;8283

image 1/1 /kaggle/input/benetech-making-graphs-accessible/test/images/00f5404753cf.jpg: 480x640 1 plot, 75.3ms
Speed: 10.2ms preprocess, 75.3ms inference, 90.3ms postprocess per image at shape (1, 3, 480, 640)

image 1/1 /kaggle/input/benetech-making-graphs-accessible/test/images/00f5404753cf.jpg: 480x640 5 x-ticks, 6 y-ticks, 6 x-tick-texts, 6 y-tick-texts, 1 xy-tick, 62.2ms
Speed: 1.8ms preprocess, 62.2ms inference, 1.4ms postprocess per image at shape (1, 3, 480, 640)





image 1/1 /kaggle/input/benetech-making-graphs-accessible/test/images/00f5404753cf.jpg: 608x800 39 scatter_points, 93.6ms
Speed: 2.5ms preprocess, 93.6ms inference, 1.6ms postprocess per image at shape (1, 3, 608, 800)
4.9586;4.9911;4.9911;5.9642;5.9642;5.9642;6.9697;6.9697;6.9697;7.9753;7.9753;7.9753;8.9809;8.9809;8.9809;9.9864;9.9864;9.9864;10.9595;10.9920;10.9920;10.9920;11.9651;11.9651;11.9975;11.9975;11.9975;12.9706;12.9706;12.9706;13.9762;13.9762;13.9762;14.9818;14.9818;14.9818;15.9873;15.9873;15.9873
14.1284;11.0092;12.0848;12.0848;13.1604;14.1284;14.1284;16.0645;17.0325;17.0325;18.1081;19.0761;20.0441;21.1197;22.0877;21.1197;22.0877;23.0557;24.1313;21.1197;22.0877;23.0557;25.2069;26.0673;23.0557;24.0237;25.0993;24.0237;26.0673;27.0354;25.0993;27.0354;28.1109;26.0673;27.0354;29.0790;29.0790;30.0470;31.0150
 Group 1; Group 2
3.6;8.4
[{'x': '0.0', 'y': 0.0132}, {'x': '0.4', 'y': 0.0132}, {'x': '0.8', 'y': 0.0132}, {'x': '1.2', 'y': 0.0132}, {'x': '1.6', 'y': 0.0132}, {'x': '2.0',

In [6]:
submission

Unnamed: 0,id,data_series,chart_type
0,000b92c3b098_x,0;6;12;18;24,line
1,000b92c3b098_y,0.0;-1.3353;-2.6353;-1.9653;-3.2753,line
2,01b45b831589_x,21-Feb; 22-Feb; 23-Feb; 24-Feb; 25-Feb; 26-Fe...,vertical_bar
3,01b45b831589_y,89000;151192;172700;177800;137500;99168;17242;...,vertical_bar
4,00f5404753cf_x,4.9586;4.9911;4.9911;5.9642;5.9642;5.9642;6.96...,scatter
5,00f5404753cf_y,14.1284;11.0092;12.0848;12.0848;13.1604;14.128...,scatter
6,00dcf883a459_x,Group 1; Group 2,vertical_bar
7,00dcf883a459_y,3.6;8.4,vertical_bar
8,007a18eb4e09_x,0.0;0.4;0.8;1.2;1.6;2.0;2.4;2.8,line
9,007a18eb4e09_y,0.0132;0.0132;0.0132;0.0132;0.0132;0.0132;0.01...,line
