In [3]:
import os
import subprocess
import json
from pdf2image import convert_from_path
from PIL import Image
from multiprocessing import Pool
import uuid

# 全局路径变量，将在主程序中初始化
table_dir = ""
output_pdf_dir = ""
output_png_dir = ""
extracted_tables_dir = ""
base_dir = ""

# Updated LaTeX preamble with UTF-8 support
latex_preamble = r"""
\documentclass[standalone]{article}
\usepackage[utf8]{inputenc}
\usepackage[T1]{fontenc}
\usepackage{amsmath, amsthm, amssymb, graphicx, geometry, array}
\usepackage{booktabs, multirow,natbib,tabularx, multicol, bm}
\pagenumbering{gobble}
\begin{document}
\begin{table}[htp]
\centering
\resizebox*{0.5\columnwidth}{!}{ 
"""
latex_end = r"}\end{table}\end{document}"

def render_tex_to_pdf(tex_path, output_pdf_path, timeout=20):
    """Render a LaTeX file to PDF with unique temporary files for parallel processing"""
    with open(tex_path, "r") as file:
        tex_content = file.read()
    
    full_tex_content = latex_preamble + tex_content + latex_end
    
    # Generate unique temporary filename using UUID
    temp_tex_filename = f"temp_{uuid.uuid4().hex}.tex"
    temp_tex_path = os.path.join(os.path.dirname(output_pdf_path), temp_tex_filename)
    
    with open(temp_tex_path, "w") as temp_file:
        temp_file.write(full_tex_content)
    
    try:
        result = subprocess.run(
            ["pdflatex", "-interaction=nonstopmode", "-output-directory", os.path.dirname(output_pdf_path), temp_tex_path],
            check=False,
            capture_output=True,
            text=True,
            encoding='latin-1',
            timeout=timeout
        )
        
        # Get the actual output PDF path from LaTeX compilation
        temp_pdf_path = temp_tex_path.replace(".tex", ".pdf")
        
        if os.path.exists(temp_pdf_path):
            os.rename(temp_pdf_path, output_pdf_path)
            print(f"Successfully rendered {tex_path} to PDF at {output_pdf_path}.")
        else:
            print(f"Error: PDF not generated for {tex_path}. LaTeX output:\n{result.stdout}")
    except subprocess.TimeoutExpired:
        print(f"Timeout expired while rendering {tex_path}. Skipping this file.")
    except Exception as e:
        print(f"Unexpected error rendering {tex_path}: {str(e)}")
    finally:
        # Cleanup temporary files
        for ext in [".aux", ".log", ".out", ".tex"]:
            file_path = temp_tex_path.replace(".tex", ext)
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    print(f"Warning: Could not delete temporary file {file_path}: {str(e)}")

def convert_pdf_to_png(pdf_path, png_path, dpi=300):
    """Convert PDF to PNG with error handling"""
    try:
        images = convert_from_path(pdf_path, dpi=dpi)
        if images:
            images[0].save(png_path, "PNG")
            print(f"Converted {pdf_path} to {png_path}")
        else:
            print(f"Error: No pages found in {pdf_path}")
    except Exception as e:
        print(f"Error converting {pdf_path} to PNG: {str(e)}")

def process_entry(entry):
    """Process a single JSON entry with full error handling"""
    try:
        question_id = entry["id"]
        image_path = entry["image"]
        
        # 从image路径中提取原始文件名（带扩展名）
        original_filename = os.path.basename(image_path)
        
        print(f"Processing {question_id} with image {original_filename}...")
        
        # 只处理prediction，忽略reference
        # 使用原始图片名作为基础文件名（不带路径）
        base_name = original_filename.replace('.', '_')
        tex_path = os.path.join(table_dir, f"{base_name}.tex")
        
        # 创建LaTeX文件
        with open(tex_path, "w") as f:
            f.write(entry["prediction"])
        
        # 生成PDF
        pdf_path = os.path.join(output_pdf_dir, f"{base_name}.pdf")
        render_tex_to_pdf(tex_path, pdf_path)
        
        # 转换为PNG并使用原始文件名（带扩展名）
        if os.path.exists(pdf_path):
            png_path = os.path.join(output_png_dir, original_filename)
            convert_pdf_to_png(pdf_path, png_path)
        else:
            print(f"Skipping PNG conversion for {original_filename} - PDF not found")
        
        print(f"Completed processing {question_id}")
        return True
    except Exception as e:
        print(f"Error processing entry {entry.get('id', 'unknown')}: {str(e)}")
        return False

def process_jsonl(json_file):
    """Main processing function with parallel execution"""
    # 读取 JSONL 文件
    data = []
    with open(json_file, "r") as file:
        for line in file:
            if line.strip():  # 跳过空行
                data.append(json.loads(line))
    
    # 创建必要的目录
    os.makedirs(table_dir, exist_ok=True)
    os.makedirs(output_pdf_dir, exist_ok=True)
    os.makedirs(output_png_dir, exist_ok=True)
    
    # 使用半数可用 CPU 进行资源平衡
    num_workers = max(1, os.cpu_count() // 2)
    
    with Pool(processes=num_workers) as pool:
        results = pool.imap_unordered(process_entry, data)
        
        # 监控进度
        success_count = 0
        for i, result in enumerate(results, 1):
            if result:
                success_count += 1
            print(f"Processed {i}/{len(data)} entries ({success_count} successes)")
    
    print(f"\nProcessing complete. Success rate: {success_count}/{len(data)}")

# 用户只需修改这个JSON文件路径
json_file = "/home/lingjun/code/InternVL/internvl_chat/dataset_fintabnet/output_fintabnet_SFT.jsonl"  # 修改此处为你的JSON文件路径

# 动态生成所有路径
base_dir = os.path.dirname(json_file)
extracted_tables_dir = os.path.join(base_dir, "extracted_tables_sft")
table_dir = os.path.join(extracted_tables_dir, "latex_tables")
output_pdf_dir = os.path.join(extracted_tables_dir, "pdf_tables_full")
output_png_dir = os.path.join(extracted_tables_dir, "png_tables_full")

# 执行处理流程
process_jsonl(json_file)
print("Rendering and conversion complete!")

Processing 7 with image image_18.png...Processing 0 with image image_1002.png...Processing 14 with image image_224.png...Processing 28 with image image_285.png...Processing 35 with image image_308.png...Processing 56 with image image_369.png...Processing 70 with image image_462.png...Processing 49 with image image_346.png...Processing 21 with image image_271.png...Processing 77 with image image_521.png...Processing 63 with image image_422.png...
Processing 84 with image image_589.png...Processing 112 with image image_905.png...
Processing 91 with image image_671.png...


Processing 105 with image image_826.png...
Processing 98 with image image_749.png...


Processing 1 with image image_1014.png...
Processing 8 with image image_19.png...Processing 42 with image image_320.png...
Processing 119 with image image_986.png...Processing 15 with image image_247.png...Processing 36 with image image_309.png...


Processing 43 with image image_334.png...Processing 50 with image image_347.png...Pro

In [2]:
import os
import subprocess
import json
from pdf2image import convert_from_path
from PIL import Image
from multiprocessing import Pool
import uuid


# 全局路径变量，将在主程序中初始化
table_dir = ""
output_pdf_dir = ""
output_png_dir = ""
extracted_tables_dir = ""
base_dir = ""

# Updated LaTeX preamble with UTF-8 support
latex_preamble = r"""
\documentclass[standalone]{article}
\usepackage[utf8]{inputenc}
\usepackage[T1]{fontenc}
\usepackage{amsmath, amsthm, amssymb, graphicx, geometry, array}
\usepackage{booktabs, multirow,natbib,tabularx, multicol, bm}
\pagenumbering{gobble}
\begin{document}
\begin{table}[htp]
\centering
\resizebox*{0.5\columnwidth}{!}{ 
"""
latex_end = r"}\end{table}\end{document}"

def render_tex_to_pdf(tex_path, output_pdf_path, timeout=20):
    """Render a LaTeX file to PDF with unique temporary files for parallel processing"""
    with open(tex_path, "r") as file:
        tex_content = file.read()
    
    full_tex_content = latex_preamble + tex_content + latex_end
    
    # Generate unique temporary filename using UUID
    temp_tex_filename = f"temp_{uuid.uuid4().hex}.tex"
    temp_tex_path = os.path.join(os.path.dirname(output_pdf_path), temp_tex_filename)
    
    with open(temp_tex_path, "w") as temp_file:
        temp_file.write(full_tex_content)
    
    try:
        result = subprocess.run(
            ["pdflatex", "-interaction=nonstopmode", "-output-directory", os.path.dirname(output_pdf_path), temp_tex_path],
            check=False,
            capture_output=True,
            text=True,
            encoding='latin-1',
            timeout=timeout
        )
        
        # Get the actual output PDF path from LaTeX compilation
        temp_pdf_path = temp_tex_path.replace(".tex", ".pdf")
        
        if os.path.exists(temp_pdf_path):
            os.rename(temp_pdf_path, output_pdf_path)
            print(f"Successfully rendered {tex_path} to PDF at {output_pdf_path}.")
        else:
            print(f"Error: PDF not generated for {tex_path}. LaTeX output:\n{result.stdout}")
    except subprocess.TimeoutExpired:
        print(f"Timeout expired while rendering {tex_path}. Skipping this file.")
    except Exception as e:
        print(f"Unexpected error rendering {tex_path}: {str(e)}")
    finally:
        # Cleanup temporary files
        for ext in [".aux", ".log", ".out", ".tex"]:
            file_path = temp_tex_path.replace(".tex", ext)
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    print(f"Warning: Could not delete temporary file {file_path}: {str(e)}")

def convert_pdf_to_png(pdf_path, png_path, dpi=300):
    """Convert PDF to PNG with error handling"""
    try:
        images = convert_from_path(pdf_path, dpi=dpi)
        if images:
            images[0].save(png_path, "PNG")
            print(f"Converted {pdf_path} to {png_path}")
        else:
            print(f"Error: No pages found in {pdf_path}")
    except Exception as e:
        print(f"Error converting {pdf_path} to PNG: {str(e)}")

def process_entry(entry):
    """Process a single JSON entry with full error handling"""
    try:
        question_id = entry["id"]
        print(f"Processing {question_id}...")
        
        # Create output filenames
        base_files = {
            "prediction": (entry["prediction"], f"{question_id}_prediction"),
            "reference": (entry["reference"], f"{question_id}_reference")
        }
        
        for content_type, (content, base_name) in base_files.items():
            # Generate LaTeX file
            tex_path = os.path.join(table_dir, f"{base_name}.tex")
            with open(tex_path, "w") as f:
                f.write(content)
            
            # Generate PDF
            pdf_path = os.path.join(output_pdf_dir, f"{base_name}.pdf")
            render_tex_to_pdf(tex_path, pdf_path)
            
            # Convert to PNG
            if os.path.exists(pdf_path):
                png_path = os.path.join(output_png_dir, f"{base_name}.png")
                convert_pdf_to_png(pdf_path, png_path)
            else:
                print(f"Skipping PNG conversion for {base_name} - PDF not found")
        
        print(f"Completed processing {question_id}")
        return True
    except Exception as e:
        print(f"Error processing {entry.get('questionId', 'unknown')}: {str(e)}")
        return False

def process_json(json_file):
    """Main processing function with parallel execution"""
    with open(json_file, "r") as file:
        data = json.load(file)
    
    # Create necessary directories
    os.makedirs(table_dir, exist_ok=True)
    os.makedirs(output_pdf_dir, exist_ok=True)
    os.makedirs(output_png_dir, exist_ok=True)
    
    # Use half of available CPUs for resource balancing
    num_workers = max(1, os.cpu_count() // 2)
    
    with Pool(processes=num_workers) as pool:
        results = pool.imap_unordered(process_entry, data)
        
        # Monitor progress
        success_count = 0
        for i, result in enumerate(results, 1):
            if result:
                success_count += 1
            print(f"Processed {i}/{len(data)} entries ({success_count} successes)")
    
    print(f"\nProcessing complete. Success rate: {success_count}/{len(data)}")

import concurrent
import numpy as np

def dwt2_simple(img_array):
    """Performs a simple 2D discrete wavelet transform with even-size check"""
    rows, cols = img_array.shape
    rows = rows - rows % 2
    cols = cols - cols % 2
    img_array = img_array[:rows, :cols]
    
    blocks = img_array.reshape(rows//2, 2, cols//2, 2).transpose(0, 2, 1, 3)
    a = blocks[..., 0, 0]
    b = blocks[..., 0, 1]
    c = blocks[..., 1, 0]
    d = blocks[..., 1, 1]
    
    cA = (a + b + c + d) * 0.25
    cH = (a - c) * 0.5
    cV = (a - b) * 0.5
    cD = (a - d) * 0.5
    return cA, cH, cV, cD

def calculate_ssim(img1, img2):
    img1_flat = img1.ravel()
    img2_flat = img2.ravel()
    n = img1_flat.size
    
    sum1 = img1_flat.sum()
    sum2 = img2_flat.sum()
    sum12 = (img1_flat * img2_flat).sum()
    sum1_sq = (img1_flat** 2).sum()
    sum2_sq = (img2_flat** 2).sum()
    
    mean1 = sum1 / n
    mean2 = sum2 / n
    var1 = (sum1_sq - sum1**2 / n) / n
    var2 = (sum2_sq - sum2**2 / n) / n
    
    covar = (sum12 - sum1 * sum2 / n) / (n - 1) if n > 1 else 0.0
    
    C1, C2 = 6.5025, 58.5225
    numerator = (2 * mean1 * mean2 + C1) * (2 * covar + C2)
    denominator = (mean1**2 + mean2**2 + C1) * (var1 + var2 + C2)
    return numerator / denominator

def calculate_cwssim(image1_path, image2_path):
    """Compares two images using Complex Wavelet Structural Similarity Index (CW-SSIM)"""
    try:
        image1 = Image.open(image1_path).convert('L')
        image2 = Image.open(image2_path).convert('L')
        
        image1 = image1.resize(image2.size)
        
        img1_array = np.array(image1)
        img2_array = np.array(image2)
        
        cA1, cH1, cV1, cD1 = dwt2_simple(img1_array)
        cA2, cH2, cV2, cD2 = dwt2_simple(img2_array)
        
        ssim_cA = calculate_ssim(cA1, cA2)
        ssim_cH = calculate_ssim(cH1, cH2)
        ssim_cV = calculate_ssim(cV1, cV2)
        ssim_cD = calculate_ssim(cD1, cD2)
        cwssim_score = (ssim_cA + ssim_cH + ssim_cV + ssim_cD) / 4
        
        return cwssim_score
    
    except Exception as e:
        print(f"Error comparing images {image1_path} and {image2_path}: {e}")
        return 0.0

def process_prefix(prefix, output_png_dir, annotation_images, answer_images):
    annotation_image_path = os.path.join(output_png_dir, annotation_images[prefix])
    if prefix in answer_images:
        answer_image_path = os.path.join(output_png_dir, answer_images[prefix])
        cwssim_score = calculate_cwssim(answer_image_path, annotation_image_path)
        return f"{prefix}: {cwssim_score:.4f}\n"
    else:
        return f"{prefix}: 0.0000\n"
    
def calculate_cwssim_wrapper(task):
    """包装函数处理单个对比任务"""
    prefix, answer_path, annotation_path = task
    try:
        if os.path.exists(answer_path) and os.path.exists(annotation_path):
            return prefix, calculate_cwssim(answer_path, annotation_path)
        return prefix, 0.0
    except Exception as e:
        print(f"Error processing {prefix}: {str(e)}")
        return prefix, 0.0
    
def compare_images_and_save_results(output_png_dir, result_file_path, max_workers=None):
    """优化后的图像对比函数，支持显式指定进程数并批量写入结果"""
    image_files = os.listdir(output_png_dir)
    
    # 构建文件映射关系
    answer_map = {f.split('_')[0]: f for f in image_files if 'prediction.png' in f}
    annotation_map = {f.split('_')[0]: f for f in image_files if 'reference.png' in f}

    # 生成有序任务列表
    tasks = []
    for prefix in sorted(annotation_map.keys(), key=lambda x: int(x)):
        answer_file = answer_map.get(prefix, "")
        task = (
            prefix,
            os.path.join(output_png_dir, answer_file) if answer_file else "",
            os.path.join(output_png_dir, annotation_map[prefix])
        )
        tasks.append(task)

    # 设置默认进程数（CPU核心数75%）
    if max_workers is None:
        max_workers = 16

    # 使用带进度显示的结果收集
    results = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(calculate_cwssim_wrapper, task): task[0] for task in tasks}
        
        # 添加进度显示
        completed = 0
        total = len(tasks)
        print(f"使用 {max_workers} 个进程进行图像对比...")
        
        for future in concurrent.futures.as_completed(futures):
            completed += 1
            prefix = futures[future]
            try:
                _, score = future.result(timeout=60)  # 60秒超时
                results.append((prefix, score))
                # 更新进度显示
                print(f"\r处理进度: {completed}/{total} ({completed/total:.1%})", end="", flush=True)
            except Exception as e:
                print(f"\n处理 {prefix} 时发生错误: {str(e)}")
                results.append((prefix, 0.0))

    # 按原始顺序排序后写入
    with open(result_file_path, "w") as result_file:
        for prefix, score in sorted(results,
            key=lambda x: int(''.join(filter(str.isdigit, x[0])))):
            result_line = f"{prefix}: {score:.4f}\n"
            result_file.write(result_line)
    
    # 空结果检查（新增）
    if len(results) == 0:
        print("严重警告：未生成任何对比结果！")
    else:
        print(f"结果已保存至 {result_file_path} (共 {len(results)} 条记录)")
    

def sort_txt_file(input_file, output_file=None):
    """通用排序函数，路径由调用方提供"""
    try:
        with open(input_file, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]

        def sort_key(line):
            left_num = line.split(':', 1)[0].strip()
            return int(left_num)

        sorted_lines = sorted(lines, key=sort_key)
        output = output_file or input_file.replace('.txt', '_sorted.txt')

        with open(output, 'w') as f:
            f.write('\n'.join(sorted_lines))
        print(f"文件已排序并保存为：{output}")
        return True
    except Exception as e:
        print(f"处理时发生错误：{str(e)}")
        return False

def calculate_average_score(file_path):
    """Calculates the average score from dynamic file path"""
    try:
        with open(file_path, 'r') as file:
            total_score = 0
            count = 0
            for line in file:
                parts = line.split(":")
                if len(parts) == 2:
                    score = float(parts[1].strip())
                    total_score += score
                    count += 1
            return total_score / count if count > 0 else 0.0
    except Exception as e:
        print(f"Error reading the file: {e}")
        return None

def calculate_average_and_ratio(file_path):
    """动态路径版本的非零分数统计"""
    try:
        total_score = 0.0
        valid_count = 0
        non_zero_count = 0
        
        with open(file_path, 'r') as file:
            for line in file:
                if not line.strip():
                    continue
                parts = line.rsplit(':', 1)
                if len(parts) != 2:
                    continue
                try:
                    score = float(parts[1].strip())
                except ValueError:
                    continue
                
                valid_count += 1
                if not np.isclose(score, 0.0):
                    non_zero_count += 1
                    total_score += score

        avg = total_score / non_zero_count if non_zero_count > 0 else 0.0
        ratio = non_zero_count / valid_count if valid_count > 0 else 0.0
        return avg, ratio
    except Exception as e:
        print(f"处理文件时发生错误: {e}")
        return None, None

def process_jsonl(json_file):
    """Main processing function with parallel execution"""
    # 读取 JSONL 文件
    data = []
    with open(json_file, "r") as file:
        for line in file:
            if line.strip():  # 跳过空行
                data.append(json.loads(line))
    
    # 创建必要的目录
    os.makedirs(table_dir, exist_ok=True)
    os.makedirs(output_pdf_dir, exist_ok=True)
    os.makedirs(output_png_dir, exist_ok=True)
    
    # 使用半数可用 CPU 进行资源平衡
    num_workers = max(1, os.cpu_count() // 2)
    
    with Pool(processes=num_workers) as pool:
        results = pool.imap_unordered(process_entry, data)
        
        # 监控进度
        success_count = 0
        for i, result in enumerate(results, 1):
            if result:
                success_count += 1
            print(f"Processed {i}/{len(data)} entries ({success_count} successes)")
    
    print(f"\nProcessing complete. Success rate: {success_count}/{len(data)}")
# 用户只需修改这个JSON文件路径
json_file = "/home/lingjun/code/InternVL/internvl_chat/html_table_latex_results/output.jsonl"  # 修改此处为你的JSON文件路径

# 动态生成所有路径
base_dir = os.path.dirname(json_file)
extracted_tables_dir = os.path.join(base_dir, "extracted_tables")
table_dir = os.path.join(extracted_tables_dir, "latex_tables")
output_pdf_dir = os.path.join(extracted_tables_dir, "pdf_tables_full")
output_png_dir = os.path.join(extracted_tables_dir, "png_tables_full")
result_file_path = os.path.join(base_dir, "ssim_comparison_results.txt")
sorted_result_path = os.path.join(base_dir, "ssim_comparison_results_sorted.txt")

# 执行处理流程
process_jsonl(json_file)
print("Rendering and conversion complete!")

compare_images_and_save_results(output_png_dir, result_file_path)
sort_txt_file(result_file_path, sorted_result_path)
print("Image comparison complete!")

# 计算统计结果
if average_score := calculate_average_score(sorted_result_path):
    avg_message = f"平均值: {average_score:.4f}"
    print(avg_message)
    
avg, ratio = calculate_average_and_ratio(sorted_result_path)
if avg is not None and ratio is not None:
    stats_message = f"非零分数平均值: {avg:.4f}\n非零分数占比: {ratio:.2%}"
    print(stats_message)
    
    # 将统计结果写入文件末尾
    with open(sorted_result_path, 'a') as f:
        f.write('\n' + '-' * 50 + '\n')  # 添加分隔线
        f.write(avg_message + '\n')
        f.write(stats_message + '\n')

Processing 28...Processing 14...Processing 0...Processing 7...Processing 21...

Processing 56...
Processing 35...
Processing 91...Processing 63...Processing 49...Processing 70...Processing 42...Processing 77...Processing 84...Processing 98...Processing 126...Processing 112...Processing 105...

Processing 119...Processing 161...
Processing 189...
Processing 133...Processing 154...Processing 196...


Processing 175...Processing 168...

Processing 182...
Processing 203...Processing 210...
Processing 217...Processing 224...Processing 231...
Processing 245...
Processing 273...Processing 238...
Processing 259...
Processing 315...Processing 294...Processing 252...Processing 280...Processing 266...
Processing 308...
Processing 322...Processing 336...Processing 287...Processing 329...
Processing 8...Processing 15...
Processing 301...
Processing 343...Processing 1...Processing 22...
Processing 350...

Processing 357...


























Processing 140...Processing 147...

Error: PDF not 

In [3]:
###########Teds_structure_evaluation###########
import re
from table_recognition_metric import TEDS
import multiprocessing
import json

def remove_grid_lines(latex_table):
    cleaned_table = re.sub(r'\\cmidrule{\s*}|\\cdashline\{[0-9]+(-[0-9]+)?\}\s*|\\cmidrule$(?:lr|r|l)?$\{[0-9]+\-[0-9]+\}\s*|\\arrayrulecolor{.*?}\s*|\\caption{.*?}\s*|\\centering\s*|\\hline\s*|\\cline{.*?}\s*|\\toprule\s*|\\midrule\s*|\\bottomrule\s*', '', latex_table)
    cleaned_table = re.sub(r'\\tabularnewline', r'\\\\', cleaned_table)
    cleaned_table = re.sub(r'\n\s*\n', '\n', cleaned_table)
    return cleaned_table.strip(' \n')

def fix_multi(cell):
    multirow_pattern = r'\\multirow{(\d+)}{.*?}{(.*?)}'
    multicol_pattern = r'\\multicolumn{(\d+)}{.*?}{(.*?)}'
    
    match = re.search(multirow_pattern, cell['content'])
    if match:
        cell['rowspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()

    match = re.search(multicol_pattern, cell['content'])
    if match:
        cell['colspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()
    return cell    

def grid2html(grid):
    def to_td(grid, r, c):
        if grid[r][c] in ['<<', '^^', '..']:
            return ''
        td = {'text': grid[r][c], 'rowspan':1, 'colspan': 1}
        for i in range(r + 1, len(grid)):
            if grid[i][c] == '^^':
                td['rowspan'] += 1
            else:
                break   
        for j in range(c + 1, len(grid[r])):
            if grid[r][j] == '<<':
                td['colspan'] += 1
            else:
                break
        return f'<td rowspan={td["rowspan"]} colspan={td["colspan"]}> {td["text"]} </td>'.replace('rowspan=1', '').replace('colspan=1', '')
        
    html = []
    for r in range(len(grid)):
        row = []
        for c in range(len(grid[0])):
            row.append(to_td(grid, r, c))
        html.append(f'<tr> {"".join(row)} </tr>')
    return '<html><body><table>' + '\n'.join(html) + '</table></body></html>'

def qylatex_to_grid(latex):
    if not re.search(r'\\end{tabular[x]*\*?\}', latex):
        return
    pattern = r'\\begin\{tabular[x]*\*?\}.*?\\end\{tabular[x]*\*?\}'
    matches = re.findall(pattern, latex, re.DOTALL)
    if not matches:
        return
    content = remove_grid_lines(matches[0])
    rows = content.strip(' \n').split(r'\\')
    processed_rows = []
    for row in rows:
        if not row.strip():
            continue
        columns = re.split(r'(?<!\\)&', row)
        columns = [fix_multi({'content': c.strip(' \n'), 'rowspan': 1, 'colspan': 1}) for c in columns]
        processed_rows.append(columns)
    max_cols = max([sum([it['colspan'] for it in r]) for r in processed_rows]) if processed_rows else 0
    grid = [[None for _ in range(max_cols)] for _ in range(len(processed_rows))]
    r_idx_bias = 0
    for r_idx, row in enumerate(processed_rows):
        r_idx += r_idx_bias
        while r_idx >= len(grid):
            grid.append([None for _ in range(max_cols)])
        c_idx = 0
        current_row_bias = 10000
        for cell in row:
            while c_idx < len(grid[r_idx]) and grid[r_idx][c_idx] is not None:
                c_idx += 1
            if c_idx >= len(grid[r_idx]):
                break
            current_row_bias = min(current_row_bias, cell['rowspan'])
            grid[r_idx][c_idx] = cell['content']
            for r in range(cell['rowspan']):
                for c in range(cell['colspan']):
                    if r == 0 and c == 0:
                        continue
                    target_r = r_idx + r
                    target_c = c_idx + c
                    if target_r >= len(grid):
                        grid.append([None for _ in range(max_cols)])
                    if target_c < len(grid[target_r]):
                        grid[target_r][target_c] = '^^' if r > 0 else ('<<' if c > 0 else '..')
            c_idx += cell['colspan']
        r_idx_bias += current_row_bias - 1
    grid = [[c if c is not None else '' for c in r] for r in grid]
    return grid

def latex2html(latex_str):
    latex_str = re.sub(r'(?<!\\)%.*$', '', latex_str, flags=re.MULTILINE)
    latex_str = re.sub(r'(?<!\\)\\\\$$.*?$$', '', latex_str, flags=re.DOTALL)
    latex_str = latex_str.replace('\n', '').replace('\t', '')
    try:
        grid = qylatex_to_grid(latex_str)
    except IndexError as e:
        print(f"IndexError: {str(e)}")
        return 
    if not grid:
        return
    return grid2html(grid)

def teds_structure(gt, pred):
    gt_html = latex2html(gt)
    pred_html = latex2html(pred)
    if not pred_html:
        return 0, 0
    structure_teds = TEDS(structure_only=True)
    structure_score = structure_teds(gt_html, pred_html)
    all_teds = TEDS()
    teds_score = all_teds(gt_html, pred_html)
    return structure_score, teds_score

def process_item(args):
    idx, item = args
    try:
        if not isinstance(item, dict):
            print(f"Item {idx}: Invalid data type, skipping")
            return (None, None, None)
        # Check for different possible key combinations
        if 'questionId' in item and 'answer' in item and 'annotation' in item:
            qid = item['questionId']
            pred = item['answer']
            gt = item['annotation']
        elif 'id' in item and 'prediction' in item and 'reference' in item:
            qid = item['id']
            pred = item['prediction']
            gt = item['reference']
        else:
            print(f"Item {idx}: Missing required keys")
            return (None, None, None)
        structure_score, teds_score = teds_structure(gt, pred)
        return (qid, structure_score, teds_score)
    except Exception as e:
        print(f"Error processing item {idx}: {str(e)}")
        return (None, None, None)

def calculate_stats(scores):
    avg_all = sum(scores) / len(scores) if scores else 0
    non_zero_scores = [s for s in scores if s != 0]
    avg_non_zero = sum(non_zero_scores)/len(non_zero_scores) if non_zero_scores else 0
    non_zero_ratio = len(non_zero_scores)/len(scores) if scores else 0
    return avg_all, avg_non_zero, non_zero_ratio

def process_file(input_path):
    output_path = input_path.rsplit('.', 1)[0] + '_eval.txt'
    data = []
    # Determine file type and read data
    if input_path.endswith('.json'):
        try:
            with open(input_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                if not isinstance(data, list):
                    raise ValueError("JSON root is not an array")
        except Exception as e:
            print(f"Error reading JSON: {str(e)}")
            return
    elif input_path.endswith('.jsonl'):
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data.append(json.loads(line.strip()))
                except json.JSONDecodeError as e:
                    print(f"JSONL parse error: {str(e)}")
    else:
        print("Unsupported file format. Use .json or .jsonl")
        return
    
    # Process data
    structure_results = {}
    teds_results = {}
    cpu_num = multiprocessing.cpu_count()//2
    pool = multiprocessing.Pool(processes=cpu_num)
    task_args = [(i+1, item) for i, item in enumerate(data)]
    
    processed = 0
    for qid, structure_score, teds_score in pool.imap(process_item, task_args):
        if qid is not None:
            structure_results[qid] = structure_score
            teds_results[qid] = teds_score
        processed += 1
        if processed % 10 == 0:
            print(f"Processed {processed}/{len(data)}")
    
    pool.close()
    pool.join()
    
    # Write results
    structure_stats = calculate_stats(list(structure_results.values()))
    teds_stats = calculate_stats(list(teds_results.values()))
    
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("ID:Structure_Score,TEDS_Score\n")
        for qid in sorted(structure_results):
            f.write(f"{qid}:{structure_results[qid]:.4f},{teds_results[qid]:.4f}\n")
        
        f.write("\n=== Structure Score ===\n")
        f.write(f"Total: {len(data)}\nValid: {len(structure_results)}\n")
        f.write(f"Non-zero: {structure_stats[2]:.2%}\n")
        f.write(f"Average (All): {structure_stats[0]:.4f}\nAverage (Non-zero): {structure_stats[1]:.4f}\n")
        
        f.write("\n=== TEDS Score ===\n")
        f.write(f"Total: {len(data)}\nValid: {len(teds_results)}\n")
        f.write(f"Non-zero: {teds_stats[2]:.2%}\n")
        f.write(f"Average (All): {teds_stats[0]:.4f}\nAverage (Non-zero): {teds_stats[1]:.4f}\n")

# Example usage
process_file("/home/lingjun/code/InternVL/internvl_chat/html_table_latex_results/output.jsonl")  # Replace with your file path

Processed 10/361
Processed 20/361
Processed 30/361
Processed 40/361
Processed 50/361
Processed 60/361
Processed 70/361
Processed 80/361
Processed 90/361
Processed 100/361
Processed 110/361
Processed 120/361
Processed 130/361
Processed 140/361
Processed 150/361
Processed 160/361
Processed 170/361
Processed 180/361
Processed 190/361
Processed 200/361
Processed 210/361


Process ForkPoolWorker-169:
Process ForkPoolWorker-172:
Process ForkPoolWorker-168:
Process ForkPoolWorker-183:
Process ForkPoolWorker-152:
Process ForkPoolWorker-173:
Process ForkPoolWorker-155:
Process ForkPoolWorker-171:
Process ForkPoolWorker-158:
Process ForkPoolWorker-184:
Process ForkPoolWorker-181:
Process ForkPoolWorker-148:
Process ForkPoolWorker-146:
Process ForkPoolWorker-153:
Process ForkPoolWorker-139:
Process ForkPoolWorker-147:
Process ForkPoolWorker-162:
Process ForkPoolWorker-175:
Process ForkPoolWorker-136:
Process ForkPoolWorker-160:
Process ForkPoolWorker-179:
Process ForkPoolWorker-182:
Process ForkPoolWorker-174:
Process ForkPoolWorker-176:
Process ForkPoolWorker-138:
Process ForkPoolWorker-135:
Process ForkPoolWorker-177:
Process ForkPoolWorker-129:
Process ForkPoolWorker-178:
Process ForkPoolWorker-131:
Process ForkPoolWorker-166:
Process ForkPoolWorker-130:
Process ForkPoolWorker-157:
Process ForkPoolWorker-151:
Process ForkPoolWorker-154:
Process ForkPoolWork

KeyboardInterrupt: 

In [None]:
import os
import subprocess
import json
from pdf2image import convert_from_path
from PIL import Image
from multiprocessing import Pool, cpu_count
import uuid
import concurrent.futures
import numpy as np

# 全局路径变量
table_dir = ""
output_pdf_dir = ""
output_png_dir = ""
extracted_tables_dir = ""
base_dir = ""

# 增强的LaTeX模板
latex_preamble = r"""
\documentclass[standalone]{article}
\usepackage[utf8]{inputenc}
\usepackage[T1]{fontenc}
\usepackage{textcomp}
\usepackage{amsmath, amsthm, amssymb, graphicx, geometry, array}
\usepackage{booktabs, multirow, tabularx, multicol, bm}
\pagenumbering{gobble}
\begin{document}
\begin{table}[htp]
\centering
\resizebox*{0.5\columnwidth}{!}{
"""
latex_end = r"}\end{table}\end{document}"

def render_tex_to_pdf(tex_path, output_pdf_path, timeout=30):
    """增强的LaTeX渲染函数"""
    with open(tex_path, "r") as file:
        tex_content = file.read()
    
    full_tex = latex_preamble + tex_content + latex_end
    
    temp_id = uuid.uuid4().hex
    temp_tex = os.path.join(os.path.dirname(output_pdf_path), f"temp_{temp_id}.tex")
    
    with open(temp_tex, "w") as f:
        f.write(full_tex)
    
    try:
        result = subprocess.run(
            ["pdflatex", "-interaction=nonstopmode", 
             "-output-directory", os.path.dirname(output_pdf_path), temp_tex],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout,
            encoding='latin-1'
        )
        
        # 处理输出文件
        temp_pdf = temp_tex.replace(".tex", ".pdf")
        if os.path.exists(temp_pdf):
            os.rename(temp_pdf, output_pdf_path)
            print(f"✓ Rendered {os.path.basename(tex_path)}")
        else:
            print(f"× Failed {os.path.basename(tex_path)}\nLog: {result.stdout[:500]}")
            
    except subprocess.TimeoutExpired:
        print(f"⌛ Timeout {os.path.basename(tex_path)}")
    finally:
        # 清理临时文件
        for ext in [".aux", ".log", ".out", ".tex"]:
            temp_file = temp_tex.replace(".tex", ext)
            if os.path.exists(temp_file):
                try:
                    os.remove(temp_file)
                except:
                    pass

def convert_pdf_to_png(pdf_path, png_path, dpi=300):
    """增强的PDF转PNG函数"""
    try:
        images = convert_from_path(pdf_path, dpi=dpi)
        if images:
            images[0].save(png_path, "PNG", optimize=True, quality=95)
            print(f"✓ Converted {os.path.basename(pdf_path)}")
        else:
            print(f"× Empty PDF {os.path.basename(pdf_path)}")
    except Exception as e:
        print(f"× Conversion failed {os.path.basename(pdf_path)}: {str(e)}")

def process_entry(entry):
    """处理单个JSON条目"""
    try:
        qid = entry["id"]
        print(f"\n▶ Processing {qid}")
        
        # 生成LaTeX文件
        for field in ["gt", "latte_2"]:
            content = entry[field]
            base_name = f"{qid}_{field}"
            
            # 生成LaTeX文件
            tex_path = os.path.join(table_dir, f"{base_name}.tex")
            with open(tex_path, "w") as f:
                f.write(content)
            
            # 生成PDF
            pdf_path = os.path.join(output_pdf_dir, f"{base_name}.pdf")
            render_tex_to_pdf(tex_path, pdf_path)
            
            # 转换为PNG
            if os.path.exists(pdf_path):
                png_path = os.path.join(output_png_dir, f"{base_name}.png")
                convert_pdf_to_png(pdf_path, png_path)
                
        return True
    except Exception as e:
        print(f"⊗ Failed {entry.get('id', 'unknown')}: {str(e)}")
        return False

def process_json(json_path):
    """处理整个JSON文件"""
    with open(json_path, "r") as f:
        data = json.load(f)
    
    # 转换数据结构
    entries = [{"id": k, "gt": v["gt"], "latte_2": v["latte_2"]} 
               for k, v in data.items()]
    
    # 创建目录
    os.makedirs(table_dir, exist_ok=True)
    os.makedirs(output_pdf_dir, exist_ok=True)
    os.makedirs(output_png_dir, exist_ok=True)
    
    # 并行处理
    workers = max(1, cpu_count()//2)
    with Pool(processes=workers) as pool:
        results = pool.imap_unordered(process_entry, entries)
        
        success = 0
        for i, result in enumerate(results, 1):
            if result:
                success += 1
            print(f"▷ Progress: {i}/{len(entries)} ({success} successes)")
    
    print(f"\n★ Completed: {success}/{len(entries)} succeeded")

# CW-SSIM计算核心
def dwt2_simple(img):
    """简化版二维离散小波变换"""
    h, w = img.shape
    h -= h % 2
    w -= w % 2
    img = img[:h, :w]
    
    blocks = img.reshape(h//2, 2, w//2, 2).transpose(0, 2, 1, 3)
    a = blocks[..., 0, 0]
    b = blocks[..., 0, 1]
    c = blocks[..., 1, 0]
    d = blocks[..., 1, 1]
    
    return (a + b + c + d)/4, (a - c)/2, (a - b)/2, (a - d)/2

def calculate_ssim(patch1, patch2):
    """结构相似性计算"""
    C1, C2 = 6.5025, 58.5225
    
    mu1 = np.mean(patch1)
    mu2 = np.mean(patch2)
    
    sigma1_sq = np.var(patch1)
    sigma2_sq = np.var(patch2)
    sigma12 = np.cov(patch1.ravel(), patch2.ravel())[0,1]
    
    num = (2*mu1*mu2 + C1) * (2*sigma12 + C2)
    den = (mu1**2 + mu2**2 + C1) * (sigma1_sq + sigma2_sq + C2)
    
    return np.clip(num / den, 0.0, 1.0)



# def calculate_cwssim(img1_path, img2_path):
#     """完整的CW-SSIM计算"""
#     try:
#         # 图像预处理
#         img1 = Image.open(img1_path).convert('L')
#         img2 = Image.open(img2_path).convert('L').resize(img1.size)
        
#         arr1 = np.array(img1, dtype=np.float32)
#         arr2 = np.array(img2, dtype=np.float32)
        
#         # 小波分解
#         cA1, cH1, cV1, cD1 = dwt2_simple(arr1)
#         cA2, cH2, cV2, cD2 = dwt2_simple(arr2)
        
#         # 多尺度SSIM计算
#         ssims = [
#             calculate_ssim(cA1, cA2),
#             calculate_ssim(cH1, cH2),
#             calculate_ssim(cV1, cV2),
#             calculate_ssim(cD1, cD2)
#         ]
        
#         return np.mean(ssims)
#     except Exception as e:
#         print(f"⚠ Error computing CW-SSIM: {str(e)}")
#         return 0.0

import numpy as np
from pyrtools.pyramids import SCFpyr
from scipy.signal import convolve2d
import warnings

def cwssim_index(img1, img2):
    level=4
    ori=8
    guardb=0
    K=0.01
    # 确保图像是二维灰度图，转换为float类型
    if img1.ndim > 2:
        img1 = img1.mean(axis=2)
    if img2.ndim > 2:
        img2 = img2.mean(axis=2)
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    
    # 构建复数可控金字塔
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")  # 忽略可能的警告
        pyr1 = SCFpyr(img1, height=level, nbands=ori)
        pyr2 = SCFpyr(img2, height=level, nbands=ori)
    
    winsize = 7
    window = np.ones((winsize, winsize)) / (winsize ** 2)
    
    # 计算边界裁剪量
    gb = guardb // (2 ** (level - 1))
    
    # 获取最高层级的子带尺寸以生成高斯权重
    band_key = (level-1, 0)  # 假设使用第一个方向子带
    band_coeff = pyr1.pyr_coeffs.get(band_key, None)
    if band_coeff is None:
        raise ValueError("无法获取指定层级的子带。")
    s = band_coeff.shape
    s_cropped = (s[0] - 2*gb, s[1] - 2*gb)
    size = (s_cropped[0] - winsize + 1, s_cropped[1] - winsize + 1)
    
    # 生成高斯权重矩阵
    sigma = size[0] / 4.0
    x = np.linspace(-(size[1]-1)/2, (size[1]-1)/2, size[1])
    y = np.linspace(-(size[0]-1)/2, (size[0]-1)/2, size[0])
    X, Y = np.meshgrid(x, y)
    w = np.exp(-(X**2 + Y**2) / (2 * sigma**2))
    w /= w.sum()
    
    band_cssim = np.zeros(ori)
    
    for i in range(ori):
        band_key = (level-1, i)
        band1 = pyr1.pyr_coeffs.get(band_key, None)
        band2 = pyr2.pyr_coeffs.get(band_key, None)
        if band1 is None or band2 is None:
            raise ValueError(f"方向子带 {i} 不存在于金字塔中。")
        
        # 裁剪边界
        band1 = band1[gb:-gb, gb:-gb]
        band2 = band2[gb:-gb, gb:-gb]
        
        # 计算相关系数和方差
        corr = band1 * np.conj(band2)
        varr = np.abs(band1)**2 + np.abs(band2)**2
        
        # 滑动窗口平均
        corr_band = convolve2d(corr, window, mode='valid')
        varr_band = convolve2d(varr, window, mode='valid')
        
        # 计算CSSIM映射
        cssim_map = (2 * np.abs(corr_band) + K) / (varr_band + K + 1e-8)  # 避免除零
        
        # 确保尺寸匹配
        if cssim_map.shape != w.shape:
            raise ValueError("CSSIM映射与权重矩阵尺寸不匹配。")
        
        band_cssim[i] = np.sum(cssim_map * w)
    
    cwssim = np.mean(band_cssim)
    return cwssim


def compare_images(output_dir, result_path):
    """批量图像对比"""
    # 构建文件映射
    file_pairs = {}
    for f in os.listdir(output_dir):
        if f.endswith('_gt.png'):
            qid = f[:-7]
            gt_path = os.path.join(output_dir, f)
            latte_path = os.path.join(output_dir, f"{qid}_latte_2.png")
            
            # 只有当两个文件都存在时才添加
            if os.path.exists(gt_path) and os.path.exists(latte_path):
                file_pairs[qid] = {
                    'gt': gt_path,
                    'latte_2': latte_path
                }
    
    # 准备任务
    tasks = [(qid, pd['latte_2'], pd['gt']) 
            for qid, pd in file_pairs.items() 
            if os.path.exists(pd['latte_2'])]
    
    print(f"Found {len(tasks)} valid pairs for comparison")
    
    # 并行计算
    results = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=8) as executor:
        futures = {executor.submit(calculate_cwssim, p1, p2): qid 
                  for qid, p1, p2 in tasks}
        
        for future in concurrent.futures.as_completed(futures):
            qid = futures[future]
            try:
                score = future.result()
                results.append((qid, score))
                print(f"✓ {qid}: {score:.4f}")
            except Exception as e:
                print(f"⊗ {qid}: {str(e)}")
                results.append((qid, 0.0))
    
    # 保存结果
    with open(result_path, "w") as f:
        for qid, score in sorted(results, key=lambda x: x[0]):
            f.write(f"{qid}: {score:.4f}\n")
    
    print(f"Results saved to {result_path}")

# 主流程

# 配置路径
json_file = "/home/lingjun/code/InternVL/internvl_chat/results_latte/Latte_2.json"  # 修改为实际路径
base_dir = os.path.dirname(json_file)

extracted_tables_dir = os.path.join(base_dir, "extracted_tables")
table_dir = os.path.join(extracted_tables_dir, "latex")
output_pdf_dir = os.path.join(extracted_tables_dir, "pdf")
output_png_dir = os.path.join(extracted_tables_dir, "png")
result_path = os.path.join(base_dir, "cwssim_results_test.txt")

# 执行处理流程
# process_json(json_file)
compare_images(output_png_dir, result_path)

# 计算统计结果
scores = []
with open(result_path, "r") as f:
    for line in f:
        if ":" in line:
            score = float(line.split(":")[1].strip())
            scores.append(score)

if scores:
    avg = sum(scores) / len(scores)
    non_zero = len([s for s in scores if s > 0.01])
    sum_non_zero = sum(s for s in scores if s > 0.01)
    avg_non_zero = sum_non_zero / non_zero if non_zero > 0 else 0
    print(f"\n★ Final Results ★\n"
            f"Average CW-SSIM: {avg:.4f}\n"
            f"Average CW-SSIM (non-zero): {avg_non_zero:.4f}\n"
            f"Non-zero rate: {non_zero}/{len(scores)} ({non_zero/len(scores):.1%})")

ImportError: cannot import name 'SCFpyr' from 'pyrtools.pyramids' (/home/lingjun/miniconda3/envs/vlm-r1/lib/python3.10/site-packages/pyrtools/pyramids/__init__.py)

In [None]:
###########Teds_structure_evaluation###########
import re
from table_recognition_metric import TEDS
import multiprocessing
import json

def remove_grid_lines(latex_table):
    cleaned_table = re.sub(r'\\cmidrule{\s*}|\\cdashline\{[0-9]+(-[0-9]+)?\}\s*|\\cmidrule$(?:lr|r|l)?$\{[0-9]+\-[0-9]+\}\s*|\\arrayrulecolor{.*?}\s*|\\caption{.*?}\s*|\\centering\s*|\\hline\s*|\\cline{.*?}\s*|\\toprule\s*|\\midrule\s*|\\bottomrule\s*', '', latex_table)
    cleaned_table = re.sub(r'\\tabularnewline', r'\\\\', cleaned_table)
    cleaned_table = re.sub(r'\n\s*\n', '\n', cleaned_table)
    return cleaned_table.strip(' \n')

def fix_multi(cell):
    multirow_pattern = r'\\multirow{(\d+)}{.*?}{(.*?)}'
    multicol_pattern = r'\\multicolumn{(\d+)}{.*?}{(.*?)}'
    
    match = re.search(multirow_pattern, cell['content'])
    if match:
        cell['rowspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()

    match = re.search(multicol_pattern, cell['content'])
    if match:
        cell['colspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()
    return cell    

def grid2html(grid):
    def to_td(grid, r, c):
        if grid[r][c] in ['<<', '^^', '..']:
            return ''
        td = {'text': grid[r][c], 'rowspan':1, 'colspan': 1}
        for i in range(r + 1, len(grid)):
            if grid[i][c] == '^^':
                td['rowspan'] += 1
            else:
                break   
        for j in range(c + 1, len(grid[r])):
            if grid[r][j] == '<<':
                td['colspan'] += 1
            else:
                break
        return f'<td rowspan={td["rowspan"]} colspan={td["colspan"]}> {td["text"]} </td>'.replace('rowspan=1', '').replace('colspan=1', '')
        
    html = []
    for r in range(len(grid)):
        row = []
        for c in range(len(grid[0])):
            row.append(to_td(grid, r, c))
        html.append(f'<tr> {"".join(row)} </tr>')
    return '<html><body><table>' + '\n'.join(html) + '</table></body></html>'

def qylatex_to_grid(latex):
    if not re.search(r'\\end{tabular[x]*\*?\}', latex):
        return
    pattern = r'\\begin\{tabular[x]*\*?\}.*?\\end\{tabular[x]*\*?\}'
    matches = re.findall(pattern, latex, re.DOTALL)
    if not matches:
        return
    content = remove_grid_lines(matches[0])
    rows = content.strip(' \n').split(r'\\')
    processed_rows = []
    for row in rows:
        if not row.strip():
            continue
        columns = re.split(r'(?<!\\)&', row)
        columns = [fix_multi({'content': c.strip(' \n'), 'rowspan': 1, 'colspan': 1}) for c in columns]
        processed_rows.append(columns)
    max_cols = max([sum([it['colspan'] for it in r]) for r in processed_rows]) if processed_rows else 0
    grid = [[None for _ in range(max_cols)] for _ in range(len(processed_rows))]
    r_idx_bias = 0
    for r_idx, row in enumerate(processed_rows):
        r_idx += r_idx_bias
        while r_idx >= len(grid):
            grid.append([None for _ in range(max_cols)])
        c_idx = 0
        current_row_bias = 10000
        for cell in row:
            while c_idx < len(grid[r_idx]) and grid[r_idx][c_idx] is not None:
                c_idx += 1
            if c_idx >= len(grid[r_idx]):
                break
            current_row_bias = min(current_row_bias, cell['rowspan'])
            grid[r_idx][c_idx] = cell['content']
            for r in range(cell['rowspan']):
                for c in range(cell['colspan']):
                    if r == 0 and c == 0:
                        continue
                    target_r = r_idx + r
                    target_c = c_idx + c
                    if target_r >= len(grid):
                        grid.append([None for _ in range(max_cols)])
                    if target_c < len(grid[target_r]):
                        grid[target_r][target_c] = '^^' if r > 0 else ('<<' if c > 0 else '..')
            c_idx += cell['colspan']
        r_idx_bias += current_row_bias - 1
    grid = [[c if c is not None else '' for c in r] for r in grid]
    return grid

def latex2html(latex_str):
    latex_str = re.sub(r'(?<!\\)%.*$', '', latex_str, flags=re.MULTILINE)
    latex_str = re.sub(r'(?<!\\)\\\\$$.*?$$', '', latex_str, flags=re.DOTALL)
    latex_str = latex_str.replace('\n', '').replace('\t', '')
    try:
        grid = qylatex_to_grid(latex_str)
    except IndexError as e:
        print(f"IndexError: {str(e)}")
        return 
    if not grid:
        return
    return grid2html(grid)

def teds_structure(gt, pred):
    gt_html = latex2html(gt)
    pred_html = latex2html(pred)
    if not pred_html:
        return 0, 0
    structure_teds = TEDS(structure_only=True)
    structure_score = structure_teds(gt_html, pred_html)
    all_teds = TEDS()
    teds_score = all_teds(gt_html, pred_html)
    return structure_score, teds_score

def process_item(args):
    idx, item = args
    try:
        if not isinstance(item, dict):
            print(f"Item {idx}: Invalid data type, skipping")
            return (None, None, None)
        # Modified to handle new format
        qid = item['id']
        pred = item['prediction']
        gt = item['reference']
        structure_score, teds_score = teds_structure(gt, pred)
        return (qid, structure_score, teds_score)
    except Exception as e:
        print(f"Error processing item {idx}: {str(e)}")
        return (None, None, None)

def calculate_stats(scores):
    avg_all = sum(scores) / len(scores) if scores else 0
    non_zero_scores = [s for s in scores if s != 0]
    avg_non_zero = sum(non_zero_scores)/len(non_zero_scores) if non_zero_scores else 0
    non_zero_ratio = len(non_zero_scores)/len(scores) if scores else 0
    return avg_all, avg_non_zero, non_zero_ratio

def process_file(input_path):
    output_path = input_path.rsplit('.', 1)[0] + '_eval.txt'
    data = []
    
    if input_path.endswith('.json'):
        try:
            with open(input_path, 'r', encoding='utf-8') as f:
                # Load as dictionary and convert to list format
                data_dict = json.load(f)
                data = [
                    {
                        'id': key,
                        'reference': value['gt'],
                        'prediction': value['latte_2']
                    }
                    for key, value in data_dict.items()
                ]
        except Exception as e:
            print(f"Error reading JSON: {str(e)}")
            return
    
    structure_results = {}
    teds_results = {}
    cpu_num = multiprocessing.cpu_count()//2
    pool = multiprocessing.Pool(processes=cpu_num)
    task_args = [(i+1, item) for i, item in enumerate(data)]
    
    processed = 0
    for qid, structure_score, teds_score in pool.imap(process_item, task_args):
        if qid is not None:
            structure_results[qid] = structure_score
            teds_results[qid] = teds_score
        processed += 1
        if processed % 10 == 0:
            print(f"Processed {processed}/{len(data)}")
    
    pool.close()
    pool.join()
    
    structure_stats = calculate_stats(list(structure_results.values()))
    teds_stats = calculate_stats(list(teds_results.values()))
    
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("ID:Structure_Score,TEDS_Score\n")
        for qid in sorted(structure_results):
            f.write(f"{qid}:{structure_results[qid]:.4f},{teds_results[qid]:.4f}\n")
        
        f.write("\n=== Structure Score ===\n")
        f.write(f"Total: {len(data)}\nValid: {len(structure_results)}\n")
        f.write(f"Non-zero: {structure_stats[2]:.2%}\n")
        f.write(f"Average (All): {structure_stats[0]:.4f}\nAverage (Non-zero): {structure_stats[1]:.4f}\n")
        
        f.write("\n=== TEDS Score ===\n")
        f.write(f"Total: {len(data)}\nValid: {len(teds_results)}\n")
        f.write(f"Non-zero: {teds_stats[2]:.2%}\n")
        f.write(f"Average (All): {teds_stats[0]:.4f}\nAverage (Non-zero): {teds_stats[1]:.4f}\n")

# Example usage
process_file("/home/lingjun/code/InternVL/internvl_chat/results_latte/Latte_2.json")  # Replace with your file path

Processed 10/2740
Processed 20/2740
Processed 30/2740
Processed 40/2740
Processed 50/2740
Processed 60/2740
Processed 70/2740
Processed 80/2740
Processed 90/2740
Processed 100/2740
Processed 110/2740
Processed 120/2740
Processed 130/2740
Processed 140/2740
Processed 150/2740
Processed 160/2740
Processed 170/2740
Processed 180/2740
Processed 190/2740
Processed 200/2740
Processed 210/2740
Processed 220/2740
Processed 230/2740
Processed 240/2740
Processed 250/2740
Processed 260/2740
Processed 270/2740
Processed 280/2740
Processed 290/2740
Processed 300/2740
Processed 310/2740
Processed 320/2740
Processed 330/2740
Processed 340/2740
Processed 350/2740
Processed 360/2740
Processed 370/2740
Processed 380/2740
Processed 390/2740
Processed 400/2740
Processed 410/2740
Processed 420/2740
Processed 430/2740
Processed 440/2740
Processed 450/2740
Processed 460/2740
Processed 470/2740
Processed 480/2740
Processed 490/2740
Processed 500/2740
Processed 510/2740
Processed 520/2740
Processed 530/2740
Pr

In [None]:
import os
import subprocess
import json
from pdf2image import convert_from_path
from PIL import Image
from multiprocessing import Pool
import uuid
from skimage.metrics import structural_similarity as cw_ssim
import re

def get_sort_key(prefix):
    """从前缀中提取数字部分用于排序"""
    digits = re.sub(r'\D', '', prefix)
    return int(digits) if digits else 0
# 全局路径变量，将在主程序中初始化
table_dir = ""
output_pdf_dir = ""
output_png_dir = ""
extracted_tables_dir = ""
base_dir = ""

# Updated LaTeX preamble with UTF-8 support
latex_preamble = r"""
\documentclass[standalone]{article}
\usepackage[utf8]{inputenc}
\usepackage[T1]{fontenc}
\usepackage{amsmath, amsthm, amssymb, graphicx, geometry, array}
\usepackage{booktabs, multirow,natbib,tabularx, multicol, bm}
\pagenumbering{gobble}
\begin{document}
\begin{table}[htp]
\centering
\resizebox*{0.5\columnwidth}{!}{ 
"""
latex_end = r"}\end{table}\end{document}"

def render_tex_to_pdf(tex_path, output_pdf_path, timeout=20):
    """Render a LaTeX file to PDF with unique temporary files for parallel processing"""
    with open(tex_path, "r") as file:
        tex_content = file.read()
    
    full_tex_content = latex_preamble + tex_content + latex_end
    
    # Generate unique temporary filename using UUID
    temp_tex_filename = f"temp_{uuid.uuid4().hex}.tex"
    temp_tex_path = os.path.join(os.path.dirname(output_pdf_path), temp_tex_filename)
    
    with open(temp_tex_path, "w") as temp_file:
        temp_file.write(full_tex_content)
    
    try:
        result = subprocess.run(
            ["pdflatex", "-interaction=nonstopmode", "-output-directory", os.path.dirname(output_pdf_path), temp_tex_path],
            check=False,
            capture_output=True,
            text=True,
            encoding='latin-1',
            timeout=timeout
        )
        
        # Get the actual output PDF path from LaTeX compilation
        temp_pdf_path = temp_tex_path.replace(".tex", ".pdf")
        
        if os.path.exists(temp_pdf_path):
            os.rename(temp_pdf_path, output_pdf_path)
            print(f"Successfully rendered {tex_path} to PDF at {output_pdf_path}.")
        else:
            print(f"Error: PDF not generated for {tex_path}. LaTeX output:\n{result.stdout}")
    except subprocess.TimeoutExpired:
        print(f"Timeout expired while rendering {tex_path}. Skipping this file.")
    except Exception as e:
        print(f"Unexpected error rendering {tex_path}: {str(e)}")
    finally:
        # Cleanup temporary files
        for ext in [".aux", ".log", ".out", ".tex"]:
            file_path = temp_tex_path.replace(".tex", ext)
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    print(f"Warning: Could not delete temporary file {file_path}: {str(e)}")

def convert_pdf_to_png(pdf_path, png_path, dpi=300):
    """Convert PDF to PNG with error handling"""
    try:
        images = convert_from_path(pdf_path, dpi=dpi)
        if images:
            images[0].save(png_path, "PNG")
            print(f"Converted {pdf_path} to {png_path}")
        else:
            print(f"Error: No pages found in {pdf_path}")
    except Exception as e:
        print(f"Error converting {pdf_path} to PNG: {str(e)}")

def process_entry(entry):
    """Process a single JSON entry with full error handling"""
    try:
        question_id = entry["id"]
        print(f"Processing {question_id}...")
        
        # Create output filenames
        base_files = {
            "prediction": (entry["prediction"], f"{question_id}_prediction"),
            "reference": (entry["reference"], f"{question_id}_reference")
        }
        
        for content_type, (content, base_name) in base_files.items():
            # Generate LaTeX file
            tex_path = os.path.join(table_dir, f"{base_name}.tex")
            with open(tex_path, "w") as f:
                f.write(content)
            
            # Generate PDF
            pdf_path = os.path.join(output_pdf_dir, f"{base_name}.pdf")
            render_tex_to_pdf(tex_path, pdf_path)
            
            # Convert to PNG
            if os.path.exists(pdf_path):
                png_path = os.path.join(output_png_dir, f"{base_name}.png")
                convert_pdf_to_png(pdf_path, png_path)
            else:
                print(f"Skipping PNG conversion for {base_name} - PDF not found")
        
        print(f"Completed processing {question_id}")
        return True
    except Exception as e:
        print(f"Error processing {entry.get('questionId', 'unknown')}: {str(e)}")
        return False

def process_json(json_file):
    """Main processing function with parallel execution"""
    with open(json_file, "r") as file:
        data = json.load(file)
    
    # Create necessary directories
    os.makedirs(table_dir, exist_ok=True)
    os.makedirs(output_pdf_dir, exist_ok=True)
    os.makedirs(output_png_dir, exist_ok=True)
    
    # Use half of available CPUs for resource balancing
    num_workers = max(1, os.cpu_count() // 2)
    
    with Pool(processes=num_workers) as pool:
        results = pool.imap_unordered(process_entry, data)
        
        # Monitor progress
        success_count = 0
        for i, result in enumerate(results, 1):
            if result:
                success_count += 1
            print(f"Processed {i}/{len(data)} entries ({success_count} successes)")
    
    print(f"\nProcessing complete. Success rate: {success_count}/{len(data)}")

import concurrent
import numpy as np

def dwt2_simple(img_array):
    """Performs a simple 2D discrete wavelet transform with even-size check"""
    rows, cols = img_array.shape
    rows = rows - rows % 2
    cols = cols - cols % 2
    img_array = img_array[:rows, :cols]
    
    blocks = img_array.reshape(rows//2, 2, cols//2, 2).transpose(0, 2, 1, 3)
    a = blocks[..., 0, 0]
    b = blocks[..., 0, 1]
    c = blocks[..., 1, 0]
    d = blocks[..., 1, 1]
    
    cA = (a + b + c + d) * 0.25
    cH = (a - c) * 0.5
    cV = (a - b) * 0.5
    cD = (a - d) * 0.5
    return cA, cH, cV, cD

def calculate_ssim(img1, img2):
    img1_flat = img1.ravel()
    img2_flat = img2.ravel()
    n = img1_flat.size
    
    sum1 = img1_flat.sum()
    sum2 = img2_flat.sum()
    sum12 = (img1_flat * img2_flat).sum()
    sum1_sq = (img1_flat** 2).sum()
    sum2_sq = (img2_flat** 2).sum()
    
    mean1 = sum1 / n
    mean2 = sum2 / n
    var1 = (sum1_sq - sum1**2 / n) / n
    var2 = (sum2_sq - sum2**2 / n) / n
    
    covar = (sum12 - sum1 * sum2 / n) / (n - 1) if n > 1 else 0.0
    
    C1, C2 = 6.5025, 58.5225
    numerator = (2 * mean1 * mean2 + C1) * (2 * covar + C2)
    denominator = (mean1**2 + mean2**2 + C1) * (var1 + var2 + C2)
    return numerator / denominator

def calculate_cwssim(image1_path, image2_path):
    """Compares two images using Complex Wavelet Structural Similarity Index (CW-SSIM)"""
    try:
        image1 = Image.open(image1_path).convert('L')
        image2 = Image.open(image2_path).convert('L')
        
        image1 = image1.resize(image2.size)
        
        img1_array = np.array(image1)
        img2_array = np.array(image2)
        
        cA1, cH1, cV1, cD1 = dwt2_simple(img1_array)
        cA2, cH2, cV2, cD2 = dwt2_simple(img2_array)
        
        ssim_cA = calculate_ssim(cA1, cA2)
        ssim_cH = calculate_ssim(cH1, cH2)
        ssim_cV = calculate_ssim(cV1, cV2)
        ssim_cD = calculate_ssim(cD1, cD2)
        cwssim_score = (ssim_cA + ssim_cH + ssim_cV + ssim_cD) / 4
        
        return cwssim_score
    
    except Exception as e:
        print(f"Error comparing images {image1_path} and {image2_path}: {e}")
        return 0.0

def process_prefix(prefix, output_png_dir, annotation_images, answer_images):
    annotation_image_path = os.path.join(output_png_dir, annotation_images[prefix])
    if prefix in answer_images:
        answer_image_path = os.path.join(output_png_dir, answer_images[prefix])
        cwssim_score = calculate_cwssim(answer_image_path, annotation_image_path)
        return f"{prefix}: {cwssim_score:.4f}\n"
    else:
        return f"{prefix}: 0.0000\n"
    
def calculate_cwssim_wrapper(task):
    """包装函数处理单个对比任务"""
    prefix, answer_path, annotation_path = task
    try:
        if os.path.exists(answer_path) and os.path.exists(annotation_path):
            return prefix, calculate_cwssim(answer_path, annotation_path)
        return prefix, 0.0
    except Exception as e:
        print(f"Error processing {prefix}: {str(e)}")
        return prefix, 0.0
    
def compare_images_and_save_results(output_png_dir, result_file_path, max_workers=None):
    """优化后的图像对比函数，支持显式指定进程数并批量写入结果"""
    image_files = os.listdir(output_png_dir)
    
    # 构建文件映射关系
    answer_map = {f.split('_')[0]: f for f in image_files if 'prediction.png' in f}
    annotation_map = {f.split('_')[0]: f for f in image_files if 'reference.png' in f}

    # 生成有序任务列表
    tasks = []
    # 修改排序方式：按提取的数字排序
    for prefix in sorted(annotation_map.keys(), key=get_sort_key):
        answer_file = answer_map.get(prefix, "")
        task = (
            prefix,
            os.path.join(output_png_dir, answer_file) if answer_file else "",
            os.path.join(output_png_dir, annotation_map[prefix])
        )
        tasks.append(task)

    # 设置默认进程数（CPU核心数75%）
    if max_workers is None:
        max_workers = 16

    # 使用带进度显示的结果收集
    results = []
    with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(calculate_cwssim_wrapper, task): task[0] for task in tasks}
        
        # 添加进度显示
        completed = 0
        total = len(tasks)
        print(f"使用 {max_workers} 个进程进行图像对比...")
        
        for future in concurrent.futures.as_completed(futures):
            completed += 1
            prefix = futures[future]
            try:
                _, score = future.result(timeout=60)  # 60秒超时
                results.append((prefix, score))
                # 更新进度显示
                print(f"\r处理进度: {completed}/{total} ({completed/total:.1%})", end="", flush=True)
            except Exception as e:
                print(f"\n处理 {prefix} 时发生错误: {str(e)}")
                results.append((prefix, 0.0))

    # 按原始顺序排序后写入
    with open(result_file_path, "w") as result_file:
        for prefix, score in sorted(results, key=lambda x: get_sort_key(x[0])):
            result_line = f"{prefix}: {score:.4f}\n"
            result_file.write(result_line)
    
    # 空结果检查（新增）
    if len(results) == 0:
        print("严重警告：未生成任何对比结果！")
    else:
        print(f"结果已保存至 {result_file_path} (共 {len(results)} 条记录)")
    

def sort_txt_file(input_file, output_file=None):
    """通用排序函数，路径由调用方提供"""
    try:
        with open(input_file, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]

        def sort_key(line):
            left_num = line.split(':', 1)[0].strip()
            return int(left_num)

        sorted_lines = sorted(lines, key=sort_key)
        output = output_file or input_file.replace('.txt', '_sorted.txt')

        with open(output, 'w') as f:
            f.write('\n'.join(sorted_lines))
        print(f"文件已排序并保存为：{output}")
        return True
    except Exception as e:
        print(f"处理时发生错误：{str(e)}")
        return False

def calculate_average_score(file_path):
    """精确计算所有有效行的平均值"""
    try:
        total = 0.0
        count = 0
        with open(file_path, 'r') as f:
            for line in f:
                line = line.strip()
                # 跳过统计信息分隔行
                if not line or line.startswith('-') or '平均值' in line or '占比' in line:
                    continue
                if ':' in line:
                    # 使用rsplit处理可能包含多个冒号的情况
                    parts = line.rsplit(':', 1)
                    if len(parts) != 2:
                        continue
                    _, score_str = parts
                    try:
                        score = float(score_str.strip())
                        total += score
                        count += 1
                    except ValueError:
                        continue
        return total / count if count > 0 else 0.0
    except Exception as e:
        print(f"计算平均值时发生错误: {str(e)}")
        return 0.0

def calculate_average_and_ratio(file_path):
    """精确计算非零分数平均值和占比"""
    try:
        total_score = 0.0
        valid_count = 0
        non_zero_count = 0
        
        with open(file_path, 'r') as f:
            for line in f:
                line = line.strip()
                # 跳过分隔行和统计信息
                if not line or line.startswith('-') or '平均值' in line or '占比' in line:
                    continue
                if ':' in line:
                    parts = line.rsplit(':', 1)
                    if len(parts) != 2:
                        continue
                    _, score_str = parts
                    try:
                        score = float(score_str.strip())
                        valid_count += 1
                        if score > 1e-6:  # 更精确的浮点比较阈值
                            non_zero_count += 1
                            total_score += score
                    except ValueError:
                        continue

        avg = total_score / non_zero_count if non_zero_count > 0 else 0.0
        ratio = non_zero_count / valid_count if valid_count > 0 else 0.0
        return avg, ratio
    except Exception as e:
        print(f"计算统计值时发生错误: {str(e)}")
        return 0.0, 0.0

# 修改后的结果处理流程
def process_statistics(sorted_result_path):
    """综合统计结果处理"""
    # 计算全局平均值
    global_avg = calculate_average_score(sorted_result_path)
    
    # 计算非零平均值和占比
    non_zero_avg, non_zero_ratio = calculate_average_and_ratio(sorted_result_path)
    
    # 生成统计信息
    stats = [
        f"{'-'*50}",
        f"全局平均值（含零）: {global_avg:.4f}",
        f"非零分数平均值: {non_zero_avg:.4f}",
        f"有效对比占比: {non_zero_ratio:.2%}",
        f"统计样本总数: {_get_total_count(sorted_result_path)}"
    ]
    
    # 写入文件
    with open(sorted_result_path, 'a') as f:
        f.write('\n' + '\n'.join(stats))
    
    # 控制台输出
    print('\n'.join(stats))

def _get_total_count(file_path):
    """辅助函数：获取有效对比总数"""
    count = 0
    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if line and ':' in line and not line.startswith('-') and '平均值' not in line:
                count += 1
    return count
def process_jsonl(json_file):
    """Main processing function with parallel execution"""
    # 读取 JSONL 文件
    data = []
    with open(json_file, "r") as file:
        for line in file:
            if line.strip():  # 跳过空行
                data.append(json.loads(line))
    
    # 创建必要的目录
    os.makedirs(table_dir, exist_ok=True)
    os.makedirs(output_pdf_dir, exist_ok=True)
    os.makedirs(output_png_dir, exist_ok=True)
    
    # 使用半数可用 CPU 进行资源平衡
    num_workers = max(1, os.cpu_count() // 2)
    
    with Pool(processes=num_workers) as pool:
        results = pool.imap_unordered(process_entry, data)
        
        # 监控进度
        success_count = 0
        for i, result in enumerate(results, 1):
            if result:
                success_count += 1
            print(f"Processed {i}/{len(data)} entries ({success_count} successes)")
    
    print(f"\nProcessing complete. Success rate: {success_count}/{len(data)}")
# 用户只需修改这个JSON文件路径
json_file = "/home/lingjun/code/InternVL/internvl_chat/test_2/output.jsonl"  # 修改此处为你的JSON文件路径

# 动态生成所有路径
base_dir = os.path.dirname(json_file)
extracted_tables_dir = os.path.join(base_dir, "extracted_tables")
table_dir = os.path.join(extracted_tables_dir, "latex_tables")
output_pdf_dir = os.path.join(extracted_tables_dir, "pdf_tables_full")
output_png_dir = os.path.join(extracted_tables_dir, "png_tables_full")
result_file_path = os.path.join(base_dir, "ssim_comparison_results.txt")
sorted_result_path = os.path.join(base_dir, "ssim_comparison_results_sorted.txt")

# 执行处理流程
process_jsonl(json_file)
print("Rendering and conversion complete!")

compare_images_and_save_results(output_png_dir, result_file_path)
process_statistics(result_file_path)
print("Image comparison complete!")

# 计算统计结果
if average_score := calculate_average_score(sorted_result_path):
    avg_message = f"平均值: {average_score:.4f}"
    print(avg_message)
    
avg, ratio = calculate_average_and_ratio(sorted_result_path)
if avg is not None and ratio is not None:
    stats_message = f"非零分数平均值: {avg:.4f}\n非零分数占比: {ratio:.2%}"
    print(stats_message)
    
    # 将统计结果写入文件末尾
    with open(sorted_result_path, 'a') as f:
        f.write('\n' + '-' * 50 + '\n')  # 添加分隔线
        f.write(avg_message + '\n')
        f.write(stats_message + '\n')

Processing AI-11184...Processing AI-10162...Processing AI-10792...Processing AI-11194...Processing AI-10650...Processing AI-11277...Processing AI-10070...Processing AI-11152...Processing AI-11128...Processing AI-10956...Processing AI-11027...Processing AI-10944...Processing AI-10717...Processing AI-11493...

Processing AI-10884...

Processing AI-1092...Processing AI-10328...Processing AI-11419...




Processing AI-10282...Processing AI-1046...Processing AI-10469...Processing AI-10177...

Processing AI-10865...Processing AI-10490...Processing AI-10799...Processing AI-10157...

Processing AI-10858...
Error processing unknown: 'prediction'Processing AI-1159...Processing AI-11689...Error processing unknown: 'prediction'Processing AI-10357...
Processing AI-12210...Processing AI-11148...Processing AI-10040...Processing AI-12293...Error processing unknown: 'prediction'
Error processing unknown: 'prediction'Processing AI-12288...Processing AI-11744...Processing AI-11869...Processing AI-11754..

In [10]:
###########Teds_structure_evaluation###########
import re
from table_recognition_metric import TEDS
import multiprocessing
import json

def remove_grid_lines(latex_table):
    cleaned_table = re.sub(r'\\cmidrule{\s*}|\\cdashline\{[0-9]+(-[0-9]+)?\}\s*|\\cmidrule$(?:lr|r|l)?$\{[0-9]+\-[0-9]+\}\s*|\\arrayrulecolor{.*?}\s*|\\caption{.*?}\s*|\\centering\s*|\\hline\s*|\\cline{.*?}\s*|\\toprule\s*|\\midrule\s*|\\bottomrule\s*', '', latex_table)
    cleaned_table = re.sub(r'\\tabularnewline', r'\\\\', cleaned_table)
    cleaned_table = re.sub(r'\n\s*\n', '\n', cleaned_table)
    return cleaned_table.strip(' \n')

def fix_multi(cell):
    multirow_pattern = r'\\multirow{(\d+)}{.*?}{(.*?)}'
    multicol_pattern = r'\\multicolumn{(\d+)}{.*?}{(.*?)}'
    
    match = re.search(multirow_pattern, cell['content'])
    if match:
        cell['rowspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()

    match = re.search(multicol_pattern, cell['content'])
    if match:
        cell['colspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()
    return cell    

def grid2html(grid):
    def to_td(grid, r, c):
        if grid[r][c] in ['<<', '^^', '..']:
            return ''
        td = {'text': grid[r][c], 'rowspan':1, 'colspan': 1}
        for i in range(r + 1, len(grid)):
            if grid[i][c] == '^^':
                td['rowspan'] += 1
            else:
                break   
        for j in range(c + 1, len(grid[r])):
            if grid[r][j] == '<<':
                td['colspan'] += 1
            else:
                break
        return f'<td rowspan={td["rowspan"]} colspan={td["colspan"]}> {td["text"]} </td>'.replace('rowspan=1', '').replace('colspan=1', '')
        
    html = []
    for r in range(len(grid)):
        row = []
        for c in range(len(grid[0])):
            row.append(to_td(grid, r, c))
        html.append(f'<tr> {"".join(row)} </tr>')
    return '<html><body><table>' + '\n'.join(html) + '</table></body></html>'

def qylatex_to_grid(latex):
    if not re.search(r'\\end{tabular[x]*\*?\}', latex):
        return
    pattern = r'\\begin\{tabular[x]*\*?\}.*?\\end\{tabular[x]*\*?\}'
    matches = re.findall(pattern, latex, re.DOTALL)
    if not matches:
        return
    content = remove_grid_lines(matches[0])
    rows = content.strip(' \n').split(r'\\')
    processed_rows = []
    for row in rows:
        if not row.strip():
            continue
        columns = re.split(r'(?<!\\)&', row)
        columns = [fix_multi({'content': c.strip(' \n'), 'rowspan': 1, 'colspan': 1}) for c in columns]
        processed_rows.append(columns)
    max_cols = max([sum([it['colspan'] for it in r]) for r in processed_rows]) if processed_rows else 0
    grid = [[None for _ in range(max_cols)] for _ in range(len(processed_rows))]
    r_idx_bias = 0
    for r_idx, row in enumerate(processed_rows):
        r_idx += r_idx_bias
        while r_idx >= len(grid):
            grid.append([None for _ in range(max_cols)])
        c_idx = 0
        current_row_bias = 10000
        for cell in row:
            while c_idx < len(grid[r_idx]) and grid[r_idx][c_idx] is not None:
                c_idx += 1
            if c_idx >= len(grid[r_idx]):
                break
            current_row_bias = min(current_row_bias, cell['rowspan'])
            grid[r_idx][c_idx] = cell['content']
            for r in range(cell['rowspan']):
                for c in range(cell['colspan']):
                    if r == 0 and c == 0:
                        continue
                    target_r = r_idx + r
                    target_c = c_idx + c
                    if target_r >= len(grid):
                        grid.append([None for _ in range(max_cols)])
                    if target_c < len(grid[target_r]):
                        grid[target_r][target_c] = '^^' if r > 0 else ('<<' if c > 0 else '..')
            c_idx += cell['colspan']
        r_idx_bias += current_row_bias - 1
    grid = [[c if c is not None else '' for c in r] for r in grid]
    return grid

def latex2html(latex_str):
    latex_str = re.sub(r'(?<!\\)%.*$', '', latex_str, flags=re.MULTILINE)
    latex_str = re.sub(r'(?<!\\)\\\\$$.*?$$', '', latex_str, flags=re.DOTALL)
    latex_str = latex_str.replace('\n', '').replace('\t', '')
    try:
        grid = qylatex_to_grid(latex_str)
    except IndexError as e:
        print(f"IndexError: {str(e)}")
        return 
    if not grid:
        return
    return grid2html(grid)

def teds_structure(gt, pred):
    gt_html = latex2html(gt)
    pred_html = latex2html(pred)
    if not pred_html:
        return 0, 0
    structure_teds = TEDS(structure_only=True)
    structure_score = structure_teds(gt_html, pred_html)
    all_teds = TEDS()
    teds_score = all_teds(gt_html, pred_html)
    return structure_score, teds_score

def process_item(args):
    idx, item = args
    try:
        if not isinstance(item, dict):
            print(f"Item {idx}: Invalid data type, skipping")
            return (None, None, None)
        # Check for different possible key combinations
        if 'questionId' in item and 'answer' in item and 'annotation' in item:
            qid = item['questionId']
            pred = item['answer']
            gt = item['annotation']
        elif 'id' in item and 'prediction' in item and 'reference' in item:
            qid = item['id']
            pred = item['prediction']
            gt = item['reference']
        else:
            print(f"Item {idx}: Missing required keys")
            return (None, None, None)
        structure_score, teds_score = teds_structure(gt, pred)
        return (qid, structure_score, teds_score)
    except Exception as e:
        print(f"Error processing item {idx}: {str(e)}")
        return (None, None, None)

def calculate_stats(scores):
    avg_all = sum(scores) / len(scores) if scores else 0
    non_zero_scores = [s for s in scores if s != 0]
    avg_non_zero = sum(non_zero_scores)/len(non_zero_scores) if non_zero_scores else 0
    non_zero_ratio = len(non_zero_scores)/len(scores) if scores else 0
    return avg_all, avg_non_zero, non_zero_ratio

def process_file(input_path):
    output_path = input_path.rsplit('.', 1)[0] + '_eval.txt'
    data = []
    # Determine file type and read data
    if input_path.endswith('.json'):
        try:
            with open(input_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                if not isinstance(data, list):
                    raise ValueError("JSON root is not an array")
        except Exception as e:
            print(f"Error reading JSON: {str(e)}")
            return
    elif input_path.endswith('.jsonl'):
        with open(input_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data.append(json.loads(line.strip()))
                except json.JSONDecodeError as e:
                    print(f"JSONL parse error: {str(e)}")
    else:
        print("Unsupported file format. Use .json or .jsonl")
        return
    
    # Process data
    structure_results = {}
    teds_results = {}
    cpu_num = multiprocessing.cpu_count()//2
    pool = multiprocessing.Pool(processes=cpu_num)
    task_args = [(i+1, item) for i, item in enumerate(data)]
    
    processed = 0
    for qid, structure_score, teds_score in pool.imap(process_item, task_args):
        if qid is not None:
            structure_results[qid] = structure_score
            teds_results[qid] = teds_score
        processed += 1
        if processed % 10 == 0:
            print(f"Processed {processed}/{len(data)}")
    
    pool.close()
    pool.join()
    
    # Write results
    structure_stats = calculate_stats(list(structure_results.values()))
    teds_stats = calculate_stats(list(teds_results.values()))
    
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write("ID:Structure_Score,TEDS_Score\n")
        for qid in sorted(structure_results):
            f.write(f"{qid}:{structure_results[qid]:.4f},{teds_results[qid]:.4f}\n")
        
        f.write("\n=== Structure Score ===\n")
        f.write(f"Total: {len(data)}\nValid: {len(structure_results)}\n")
        f.write(f"Non-zero: {structure_stats[2]:.2%}\n")
        f.write(f"Average (All): {structure_stats[0]:.4f}\nAverage (Non-zero): {structure_stats[1]:.4f}\n")
        
        f.write("\n=== TEDS Score ===\n")
        f.write(f"Total: {len(data)}\nValid: {len(teds_results)}\n")
        f.write(f"Non-zero: {teds_stats[2]:.2%}\n")
        f.write(f"Average (All): {teds_stats[0]:.4f}\nAverage (Non-zero): {teds_stats[1]:.4f}\n")

# Example usage
process_file("/home/lingjun/code/InternVL/internvl_chat/results_qwenvl3b_latte/output.jsonl")  # Replace with your file path

Processed 10/2725
Processed 20/2725
Processed 30/2725
Processed 40/2725
Processed 50/2725
Processed 60/2725
Processed 70/2725
Processed 80/2725
Processed 90/2725
Processed 100/2725
Processed 110/2725
Processed 120/2725
Processed 130/2725
Processed 140/2725
Processed 150/2725
Processed 160/2725
Processed 170/2725
Processed 180/2725
Processed 190/2725
Processed 200/2725
Processed 210/2725
Processed 220/2725
Processed 230/2725
Processed 240/2725
Processed 250/2725
Processed 260/2725
Processed 270/2725
Processed 280/2725
Processed 290/2725
Processed 300/2725
Processed 310/2725
Processed 320/2725
Processed 330/2725
Processed 340/2725
Processed 350/2725
Processed 360/2725
Processed 370/2725
Processed 380/2725
Processed 390/2725
Processed 400/2725
Processed 410/2725
Processed 420/2725
Processed 430/2725
Processed 440/2725
Processed 450/2725
Processed 460/2725
Processed 470/2725
Processed 480/2725
Processed 490/2725
Processed 500/2725
Processed 510/2725
Processed 520/2725
Processed 530/2725
Pr

In [29]:
###########Teds_structure_evaluation###########
import re
from table_recognition_metric import TEDS
import multiprocessing
import json
def remove_grid_lines(latex_table):
    # 去除 \hline, \cline, \toprule, \midrule, \bottomrule
    cleaned_table = re.sub(r'\\cmidrule{\s*}|\\cdashline\{[0-9]+(-[0-9]+)?\}\s*|\\cmidrule\((?:lr|r|l)?\)\{[0-9]+\-[0-9]+\}\s*|\\arrayrulecolor{.*?}\s*|\\caption{.*?}\s*|\\centering\s*|\\hline\s*|\\cline{.*?}\s*|\\toprule\s*|\\midrule\s*|\\bottomrule\s*', '', latex_table)
    
    cleaned_table = re.sub(r'\\tabularnewline', r'\\\\', cleaned_table)
    # # 去除注释
    # cleaned_table = re.sub(r'(?<!\\)%.*$', '', cleaned_table, flags=re.MULTILINE)
    # cleaned_table = re.sub(r'(?<!\\)\[[^\[\]]*\]', '', cleaned_table)
    # 合并连续的空行
    cleaned_table = re.sub(r'\n\s*\n', '\n', cleaned_table)
    
    return cleaned_table.strip(' \n')  # 去除首尾空格

def fix_multi(cell):
    multirow_pattern = r'\\multirow{(\d+)}{.*?}{(.*?)}'
    multicol_pattern = r'\\multicolumn{(\d+)}{.*?}{(.*?)}'
    
    match = re.search(multirow_pattern, cell['content'])
    if match:
        cell['rowspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()

    match = re.search(multicol_pattern, cell['content'])
    if match:
        cell['colspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()

    return cell    

def grid2html(grid):
    def to_td(grid, r, c):
        if grid[r][c] == '<<' or grid[r][c] == '^^' or grid[r][c] == '..':
            return ''
        td = {'text': grid[r][c], 'rowspan':1, 'colspan': 1}
        #计算行跨度
        for i in range(r + 1, len(grid)):
            if grid[i][c] == '^^':
                td['rowspan'] += 1
            else:
                break
        #就算列跨度    
        for j in range(c + 1, len(grid[r])):
            if grid[r][j] == '<<':
                td['colspan'] += 1
            else:
                break
        return f'<td rowspan={td["rowspan"]} colspan={td["colspan"]}> {td["text"]} </td>'.replace('rowspan=1', '').replace('colspan=1', '')
        
    
    html = []
    for r in range(len(grid)):
        row = []
        for c in range(len(grid[0])):
            row.append(to_td(grid, r, c))
        html.append(f'<tr> {"".join(row)} </tr>')
    # for row in grid:
    #     html.append('<tr>' + ''.join([to_td(c) for c in row]) + '</tr>')
    
    return '<html><body><table>' + '\n'.join(html) + '</table></body></html>'


def qylatex_to_grid(latex):
    # 提取表格内容
    if not re.search(r'\\end{tabular[x]*\*?\}', latex):
        return  # 如果不匹配，返回
    pattern = r'\\begin\{tabular[x]*\*?\}.*?\\end\{tabular[x]*\*?\}'

    matches = re.findall(pattern, latex, re.DOTALL)
    if matches:
        table_content = matches[0]
    else:
        return
    # 提取表格内容

    content = remove_grid_lines(table_content)
    # 获取表格内部的内容
    
    # 将行和列分割
    rows = content.strip(' \n').split(r'\\')

    processed_rows = []

    for row in rows:
        # 去除空行
        if not row.strip():
            continue

        # 用 & 分割列
        columns = re.split(r'(?<!\\)&', row)
         # 去除多余的空格并确定每一列的跨度以构建行
        columns = [fix_multi({'content': c.strip(' \n'), 'rowspan': 1, 'colspan': 1}) for c in columns]
       
        processed_rows.append(columns)
    # # 如果最后一行为空, 删除
    # while len(processed_rows) > 0 and len(processed_rows[-1]) == 0:
    #     processed_rows.pop()
    rows = processed_rows
    max_cols = max([sum([it['colspan'] for it in r]) for r in rows])
    # 创建一个空白网格
    grid = [[None for _ in range(max_cols)] for _ in range(len(rows))]
    col_char_num = [[1] for _ in range(max_cols)]
    # 填充网格，处理 rowspan 和 colspan
    r_idx_bias = 0
    for r_idx, row in enumerate(rows):
        r_idx += r_idx_bias
        while r_idx >= len(grid):
            grid.append([None for _ in range(max_cols)])
        c_idx = 0
        current_row_bias = 10000
        for cell in row:
            # 找到第一个未填充的单元格
            if grid[r_idx][c_idx] is not None:
                if cell['content']:
                    while grid[r_idx][c_idx] == '..':
                        c_idx += 1
                else:
                    c_idx += 1
                    continue

            current_row_bias = min(current_row_bias, cell['rowspan'])
            # 填充内容
            grid[r_idx][c_idx] = cell['content']
            col_char_num[c_idx].append(len(cell['content']))
            
            # 处理 rowspan 和 colspan
            for r in range(cell['rowspan']):
                for c in range(cell['colspan']):
                    if r == 0 and c == 0:
                        continue
                    if r == 0:
                        grid[r_idx][c_idx + c] = '<<'
                    elif c == 0:
                    
                        while r_idx+r >= len(grid):
                            grid.append([None for _ in range(max_cols)])
                        grid[r_idx + r][c_idx] = '^^'
                    else:
                        grid[r_idx + r][c_idx + c] = '..'                  
            c_idx += cell['colspan']
        r_idx_bias += current_row_bias - 1
    grid = [[c if c is not None else '' for c in r] for r in grid]
    return grid


def latex2html(latex_str):
    # 去除注释
    latex_str = re.sub(r'(?<!\\)%.*$', '', latex_str, flags=re.MULTILINE)
    # 去除"\\\\[...]"
    latex_str = re.sub(r'(?<!\\)\\\\\[.*?\]', '', latex_str, flags=re.DOTALL)

    latex_str = latex_str.replace('\n', '').replace('\t', '')
    try:
        grid = qylatex_to_grid(latex_str)
    except IndexError as e:
        print(f"IndexError: {str(e)}")
        return 
    if not grid:
        return
    html = grid2html(grid)
    return html

def teds_structure(gt, pred):
    """计算TEDS"""
    gt_html = latex2html(gt)
    pred_html = latex2html(pred)
    if not pred_html:
        # print("Prediction LaTeX to HTML conversion failed.")
        return 0, 0
    structure_teds = TEDS(structure_only=True)
    structure_score = structure_teds(gt_html, pred_html)
    all_teds = TEDS()
    teds_score = all_teds(gt_html, pred_html)
    return structure_score, teds_score

def process_item(args):
    """包装处理函数用于多进程"""
    idx, item = args
    try:
        if not isinstance(item, dict):
            print(f"Item {idx}: Invalid data type ({type(item)}), skipping")
            return (None, None, None)
            
        required_keys = ['id', 'prediction', 'reference']
        for key in required_keys:
            if key not in item:
                print(f"Item {idx}: Missing key '{key}', skipping")
                return (None, None, None)

        qid = item['id']
        gt = item['reference']
        pred = item['prediction']
        structure_score, teds_score = teds_structure(gt, pred)  # 正确解构两个返回值
        return (qid, structure_score, teds_score)
        
    except Exception as e:
        print(f"Error processing item {idx}: {str(e)}")
        return (None, None, None)

def process_item_json(args):
    """包装处理函数用于多进程"""
    idx, item = args
    try:
        if not isinstance(item, dict):
            print(f"Item {idx}: Invalid data type ({type(item)}), skipping")
            return (None, None, None)
            
        required_keys = ['questionId', 'answer', 'annotation']
        if not all(key in item for key in required_keys):
            print(f"Item {idx}: Missing required keys")
            return None


        qid = item['questionId']
        gt = item['annotation']
        pred = item['answer']
        structure_score, teds_score = teds_structure(gt, pred)  # 正确解构两个返回值
        return (qid, structure_score, teds_score)
        
    except Exception as e:
        print(f"Error processing item {idx}: {str(e)}")
        return (None, None, None)
    
def process_json(input_path):
    output_path = input_path.rsplit('.', 1)[0] + '_eval.txt'
    structure_results = {}
    teds_results = {}
    # 读取整个JSON文件
    with open(input_path, 'r', encoding='utf-8') as f:
        try:
            data = json.load(f)
            if not isinstance(data, list):
                raise ValueError("Root element is not an array")
        except Exception as e:
            print(f"Failed to parse JSON: {str(e)}")
            return

    # 创建进程池
    cpu_num = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(processes=cpu_num)
    
    # 准备参数 (添加索引)
    task_args = [(i+1, item) for i, item in enumerate(data)]
    
    # 使用imap获取结果
    processed = 0
    for qid, structure_score, teds_score in pool.imap(process_item_json, task_args):
        if qid is not None:
            structure_results[qid] = structure_score
            teds_results[qid] = teds_score
        processed += 1
        if processed % 10 == 0:
            print(f"Processed {processed}/{len(data)} items")
    
    pool.close()
    pool.join()

    # 统计计算和结果写入
    structure_stats = calculate_stats(list(structure_results.values()))
    teds_stats = calculate_stats(list(teds_results.values()))

    with open(output_path, 'w', encoding='utf-8') as f:
        # 写入明细结果
        f.write("ID:Structure_Score, TEDS_Score\n")
        for qid in sorted(structure_results.keys()):
            f.write(f"{qid}:{structure_results[qid]:.4f}, {teds_results[qid]:.4f}\n")
        
        # 写入统计信息
        f.write("\n=== Structure Score Statistics ===\n")
        f.write(f"Total samples: {len(data)}\n")
        f.write(f"Valid samples: {len(structure_results)}\n")
        f.write(f"Non-zero ratio: {structure_stats[2]:.4f}\n")
        f.write(f"Average Score (All): {structure_stats[0]:.4f}\n")
        f.write(f"Average Score (Non-zero): {structure_stats[1]:.4f}\n")
        
        f.write("\n=== TEDS Score Statistics ===\n")
        f.write(f"Total samples: {len(data)}\n")
        f.write(f"Valid samples: {len(teds_results)}\n")
        f.write(f"Non-zero ratio: {teds_stats[2]:.4f}\n")
        f.write(f"Average Score (All): {teds_stats[0]:.4f}\n")
        f.write(f"Average Score (Non-zero): {teds_stats[1]:.4f}\n")

def calculate_stats(scores):
    avg_all = sum(scores) / len(scores) if scores else 0
    non_zero_scores = [s for s in scores if s != 0]
    avg_non_zero = sum(non_zero_scores)/len(non_zero_scores) if non_zero_scores else 0
    non_zero_ratio = len(non_zero_scores) / len(scores) if scores else 0
    return avg_all, avg_non_zero, non_zero_ratio

def process_jsonl(input_path):
    """处理jsonl文件并自动生成输出路径"""
    # 自动生成输出路径
    output_path = input_path.rsplit('.', 1)[0] + '_eval.txt'
    
    structure_results = {}
    teds_results = {}
    
    # 读取JSONL文件
    data = []
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                item = json.loads(line.strip())
                data.append(item)
            except json.JSONDecodeError as e:
                print(f"Failed to parse line: {str(e)}")
                continue

    print(f"Processing {len(data)} items from {input_path}")
    print(f"Results will be saved to {output_path}")

    # 创建进程池
    cpu_num = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(processes=cpu_num)
    
    # 准备参数 (添加索引)
    task_args = [(i+1, item) for i, item in enumerate(data)]
    
    # 使用imap获取结果 
    processed = 0
    for qid, structure_score, teds_score in pool.imap(process_item, task_args):
        if qid is not None:
            structure_results[qid] = structure_score
            teds_results[qid] = teds_score
        processed += 1
        if processed % 10 == 0:
            print(f"Processed {processed}/{len(data)} items")
    
    pool.close()
    pool.join()

    # 统计计算和结果写入
    structure_stats = calculate_stats(list(structure_results.values()))
    teds_stats = calculate_stats(list(teds_results.values()))

    with open(output_path, 'w', encoding='utf-8') as f:
        # 写入明细结果
        f.write("ID:Structure_Score, TEDS_Score\n")
        for qid in sorted(structure_results.keys()):
            f.write(f"{qid}:{structure_results[qid]:.4f}, {teds_results[qid]:.4f}\n")
        
        # 写入统计信息
        f.write("\n=== Structure Score Statistics ===\n")
        f.write(f"Total samples: {len(data)}\n")
        f.write(f"Valid samples: {len(structure_results)}\n")
        f.write(f"Non-zero ratio: {structure_stats[2]:.4f}\n")
        f.write(f"Average Score (All): {structure_stats[0]:.4f}\n")
        f.write(f"Average Score (Non-zero): {structure_stats[1]:.4f}\n")
        
        f.write("\n=== TEDS Score Statistics ===\n")
        f.write(f"Total samples: {len(data)}\n")
        f.write(f"Valid samples: {len(teds_results)}\n")
        f.write(f"Non-zero ratio: {teds_stats[2]:.4f}\n")
        f.write(f"Average Score (All): {teds_stats[0]:.4f}\n")
        f.write(f"Average Score (Non-zero): {teds_stats[1]:.4f}\n")
        
process_jsonl("/home/lingjun/code/InternVL/internvl_chat/results_qwenvl3b_simple_grpo_1epoch/output.jsonl")  # 修改此处为你的JSONL文件路径

Processing 500 items from /home/lingjun/code/InternVL/internvl_chat/results_qwenvl3b_simple_grpo_1epoch/output.jsonl
Results will be saved to /home/lingjun/code/InternVL/internvl_chat/results_qwenvl3b_simple_grpo_1epoch/output_eval.txt
Processed 10/500 items
Processed 20/500 items
Processed 30/500 items
Processed 40/500 items
Processed 50/500 items
Processed 60/500 items
Processed 70/500 items
Processed 80/500 items
Processed 90/500 items
Processed 100/500 items
Processed 110/500 items
Processed 120/500 items
Processed 130/500 items
Processed 140/500 items
Processed 150/500 items
Processed 160/500 items
Processed 170/500 items
Processed 180/500 items
Processed 190/500 items
Processed 200/500 items
Processed 210/500 items
Processed 220/500 items
Processed 230/500 items
Processed 240/500 items
Processed 250/500 items
Processed 260/500 items
Processed 270/500 items
Processed 280/500 items
Processed 290/500 items
Processed 300/500 items
Processed 310/500 items
Processed 320/500 items
Proce

In [None]:
import os
import subprocess
from pdf2image import convert_from_path
from PIL import Image
from multiprocessing import Pool
import uuid
import glob
import concurrent.futures
import numpy as np

# LaTeX preamble with UTF-8 support and table formatting
latex_preamble = r"""
\documentclass[standalone]{article}
\usepackage[utf8]{inputenc}
\usepackage[T1]{fontenc}
\usepackage{amsmath, amsthm, amssymb, graphicx, geometry, array}
\usepackage{booktabs, multirow,natbib,tabularx, multicol, bm}
\pagenumbering{gobble}
\begin{document}
\begin{table}[htp]
\centering
\resizebox*{0.5\columnwidth}{!}{ 
"""
latex_end = r"}\end{table}\end{document}"

def render_content_to_pdf(content, output_pdf_path, timeout=20):
    """Render LaTeX content to PDF with error handling and cleanup"""
    full_tex_content = latex_preamble + content + latex_end
    
    # Generate unique temporary filename
    temp_tex_filename = f"temp_{uuid.uuid4().hex}.tex"
    temp_tex_path = os.path.join(os.path.dirname(output_pdf_path), temp_tex_filename)
    
    try:
        with open(temp_tex_path, "w") as temp_file:
            temp_file.write(full_tex_content)

        result = subprocess.run(
            ["pdflatex", "-interaction=nonstopmode", "-output-directory", os.path.dirname(output_pdf_path), temp_tex_path],
            check=False,
            capture_output=True,
            text=True,
            encoding='latin-1',
            timeout=timeout
        )

        temp_pdf_path = temp_tex_path.replace(".tex", ".pdf")
        if os.path.exists(temp_pdf_path):
            os.rename(temp_pdf_path, output_pdf_path)
            print(f"Rendered PDF saved to {output_pdf_path}")
        else:
            print(f"PDF generation failed for {output_pdf_path}. Errors:\n{result.stdout}")
    except subprocess.TimeoutExpired:
        print(f"Timeout expired for {output_pdf_path}")
    except Exception as e:
        print(f"Error processing {output_pdf_path}: {str(e)}")
    finally:
        # Cleanup temporary files
        for ext in [".aux", ".log", ".out", ".tex"]:
            file_path = temp_tex_path.replace(".tex", ext)
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except:
                    pass

def convert_pdf_to_png(pdf_path, png_path, dpi=300):
    """Convert PDF to PNG with error handling"""
    try:
        images = convert_from_path(pdf_path, dpi=dpi)
        if images:
            images[0].save(png_path, "PNG")
            print(f"Converted {pdf_path} to {png_path}")
        else:
            print(f"Error: Empty PDF {pdf_path}")
    except Exception as e:
        print(f"Conversion failed for {pdf_path}: {str(e)}")

def find_matching_pairs(gt_dir, gen_dir):
    """Find matching file pairs between groundtruth and generated directories"""
    gt_files = glob.glob(os.path.join(gt_dir, "*.tex"))
    pairs = []
    
    for gt_path in gt_files:
        base_name = os.path.splitext(os.path.basename(gt_path))[0]
        gen_path = os.path.join(gen_dir, f"{base_name}.txt")
        
        if os.path.exists(gen_path):
            pairs.append((gt_path, gen_path))
        else:
            print(f"Warning: Missing generated file for {base_name}")
    
    print(f"Found {len(pairs)} valid file pairs")
    return pairs

def process_pair(pair):
    """Process a matched file pair with error handling"""
    gt_path, gen_path = pair
    base_name = os.path.splitext(os.path.basename(gt_path))[0]
    
    try:
        # Process groundtruth
        with open(gt_path, "r") as f:
            gt_content = f.read()
        
        gt_pdf = os.path.join(pdf_output_dir, f"{base_name}_gt.pdf")
        gt_png = os.path.join(png_output_dir, f"{base_name}_gt.png")
        render_content_to_pdf(gt_content, gt_pdf)
        if os.path.exists(gt_pdf):
            convert_pdf_to_png(gt_pdf, gt_png)

        # Process generated
        with open(gen_path, "r") as f:
            gen_content = f.read()
        
        gen_pdf = os.path.join(pdf_output_dir, f"{base_name}_gen.pdf")
        gen_png = os.path.join(png_output_dir, f"{base_name}_gen.png")
        render_content_to_pdf(gen_content, gen_pdf)
        if os.path.exists(gen_pdf):
            convert_pdf_to_png(gen_pdf, gen_png)

        return True
    except Exception as e:
        print(f"Failed to process {base_name}: {str(e)}")
        return False

def dwt2_simple(img_array):
    """Performs a simple 2D discrete wavelet transform with even-size check"""
    rows, cols = img_array.shape
    rows = rows - rows % 2
    cols = cols - cols % 2
    img_array = img_array[:rows, :cols]
    
    blocks = img_array.reshape(rows//2, 2, cols//2, 2).transpose(0, 2, 1, 3)
    a = blocks[..., 0, 0]
    b = blocks[..., 0, 1]
    c = blocks[..., 1, 0]
    d = blocks[..., 1, 1]
    
    cA = (a + b + c + d) * 0.25
    cH = (a - c) * 0.5
    cV = (a - b) * 0.5
    cD = (a - d) * 0.5
    return cA, cH, cV, cD

def calculate_ssim(img1, img2):
    img1_flat = img1.ravel()
    img2_flat = img2.ravel()
    n = img1_flat.size
    
    sum1 = img1_flat.sum()
    sum2 = img2_flat.sum()
    sum12 = (img1_flat * img2_flat).sum()
    sum1_sq = (img1_flat** 2).sum()
    sum2_sq = (img2_flat** 2).sum()
    
    mean1 = sum1 / n
    mean2 = sum2 / n
    var1 = (sum1_sq - sum1**2 / n) / n
    var2 = (sum2_sq - sum2**2 / n) / n
    
    covar = (sum12 - sum1 * sum2 / n) / (n - 1) if n > 1 else 0.0
    
    C1, C2 = 6.5025, 58.5225
    numerator = (2 * mean1 * mean2 + C1) * (2 * covar + C2)
    denominator = (mean1**2 + mean2**2 + C1) * (var1 + var2 + C2)
    return numerator / denominator

def calculate_cwssim(image1_path, image2_path):
    """Compares two images using Complex Wavelet Structural Similarity Index (CW-SSIM)"""
    try:
        image1 = Image.open(image1_path).convert('L')
        image2 = Image.open(image2_path).convert('L')
        image1 = image1.resize(image2.size)
        
        img1_array = np.array(image1)
        img2_array = np.array(image2)
        
        cA1, cH1, cV1, cD1 = dwt2_simple(img1_array)
        cA2, cH2, cV2, cD2 = dwt2_simple(img2_array)
        
        ssim_cA = calculate_ssim(cA1, cA2)
        ssim_cH = calculate_ssim(cH1, cH2)
        ssim_cV = calculate_ssim(cV1, cV2)
        ssim_cD = calculate_ssim(cD1, cD2)
        cwssim_score = (ssim_cA + ssim_cH + ssim_cV + ssim_cD) / 4
        
        return cwssim_score
    except Exception as e:
        print(f"Error comparing images {image1_path} and {image2_path}: {e}")
        return 0.0

def process_prefix(prefix, output_png_dir, gt_images, gen_images):
    gt_image_path = os.path.join(output_png_dir, gt_images[prefix])
    if prefix in gen_images:
        gen_image_path = os.path.join(output_png_dir, gen_images[prefix])
        cwssim_score = calculate_cwssim(gen_image_path, gt_image_path)
        return f"{prefix}: {cwssim_score:.4f}\n"
    else:
        return f"{prefix}: 0.0000\n"

def compare_images_and_save_results(output_dir):
    output_png_dir = os.path.join(output_dir, "png_results")
    result_file_path = os.path.join(output_dir, "cwssim_results.txt")
    
    image_files = os.listdir(output_png_dir)
    gt_images = {f.split('_')[0]: f for f in image_files if f.endswith('_gt.png')}
    gen_images = {f.split('_')[0]: f for f in image_files if f.endswith('_gen.png')}
    
    with open(result_file_path, "w") as result_file:
        with concurrent.futures.ProcessPoolExecutor() as executor:
            futures = [
                executor.submit(process_prefix, prefix, output_png_dir, gt_images, gen_images)
                for prefix in gt_images
            ]
            for future in concurrent.futures.as_completed(futures):
                result_line = future.result()
                result_file.write(result_line)
                print(result_line.strip())
    
    print(f"Results saved to {result_file_path}")
    return result_file_path

def sort_txt_file(input_file):
    try:
        with open(input_file, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]
        
        sorted_lines = sorted(lines, key=lambda x: int(x.split(':')[0]))
        output_file = input_file.replace('.txt', '_sorted.txt')
        
        with open(output_file, 'w') as f:
            f.write('\n'.join(sorted_lines))
        
        print(f"Sorted results saved to {output_file}")
        return output_file
    except Exception as e:
        print(f"Sorting failed: {str(e)}")
        return input_file

def calculate_average_and_ratio(file_path):
    try:
        total_score = 0.0
        valid_count = 0
        non_zero_count = 0
        
        with open(file_path, 'r') as file:
            for line in file:
                if not line.strip():
                    continue
                
                parts = line.rsplit(':', 1)
                if len(parts) != 2:
                    continue
                
                try:
                    score = float(parts[1].strip())
                except ValueError:
                    continue
                
                valid_count += 1
                if not np.isclose(score, 0.0):
                    non_zero_count += 1
                    total_score += score

        avg = total_score / non_zero_count if non_zero_count > 0 else 0.0
        ratio = non_zero_count / valid_count if valid_count > 0 else 0.0
        
        return avg, ratio
    except Exception as e:
        print(f"处理文件时发生错误: {e}")
        return None, None

def calculate_average_score(file_path):
    """Calculates the average score from a file with id:score format."""
    try:
        with open(file_path, 'r') as file:
            total_score = 0
            count = 0
            for line in file:
                # Split by ':' and strip whitespace, convert the score to float
                parts = line.split(":")
                if len(parts) == 2:
                    score = float(parts[1].strip())
                    total_score += score
                    count += 1

            # Calculate and return the average score
            if count > 0:
                average_score = total_score / count
                return average_score
            else:
                return 0.0
    except Exception as e:
        print(f"Error reading the file: {e}")
        return None
    
    # 用户只需设置以下两个路径
gt_dir = "/home/lingjun/code/InternVL/internvl_chat/val_data/complex/label"  # 修改为你的GT路径
gen_dir = "/home/lingjun/code/InternVL/internvl_chat/val_data/easy/pred_1B-v5-2570"   # 修改为你的生成文件路径

# 自动生成输出路径
output_dir = os.path.join(os.path.dirname(gen_dir), "output_" + os.path.basename(gen_dir))
os.makedirs(output_dir, exist_ok=True)

# 创建子目录
global pdf_output_dir, png_output_dir
pdf_output_dir = os.path.join(output_dir, "pdf_results")
png_output_dir = os.path.join(output_dir, "png_results")
os.makedirs(pdf_output_dir, exist_ok=True)
os.makedirs(png_output_dir, exist_ok=True)

# 处理文件对
file_pairs = find_matching_pairs(gt_dir, gen_dir)
num_workers = max(1, os.cpu_count() // 2)

with Pool(processes=num_workers) as pool:
    results = pool.imap_unordered(process_pair, file_pairs)
    success_count = 0
    for i, result in enumerate(results, 1):
        if result:
            success_count += 1
        print(f"Processed {i}/{len(file_pairs)} pairs ({success_count} successful)")

# 图像比较和分数计算
result_file = compare_images_and_save_results(output_dir)
sorted_file = sort_txt_file(result_file)

# 计算统计指标
result_file_path = os.path.join(output_dir, "cwssim_results.txt")
average_score, non_zero_ratio = calculate_average_and_ratio(result_file_path)

overall_avg = calculate_average_score(result_file)
if overall_avg is not None:
    print(f"\n整体平均分: {overall_avg:.4f}")

# 非零分数统计
avg_non_zero, non_zero_ratio = calculate_average_and_ratio(result_file)
if avg_non_zero is not None and non_zero_ratio is not None:
    print(f"非零分数平均分: {avg_non_zero:.4f}")
    print(f"有效生成占比: {non_zero_ratio:.2%}")


In [None]:
import os
import subprocess
import json
from pdf2image import convert_from_path
from PIL import Image
from multiprocessing import Pool
import uuid
from skimage.metrics import structural_similarity as cw_ssim
# Updated LaTeX preamble with UTF-8 support
latex_preamble = r"""
\documentclass[standalone]{article}
\usepackage[utf8]{inputenc}
\usepackage[T1]{fontenc}
\usepackage{amsmath, amsthm, amssymb, graphicx, geometry, array}
\usepackage{booktabs, multirow,natbib,tabularx, multicol, bm}
\pagenumbering{gobble}
\begin{document}
\begin{table}[htp]
\centering
\resizebox*{0.5\columnwidth}{!}{ 
"""
latex_end = r"}\end{table}\end{document}"

def render_tex_to_pdf(tex_path, output_pdf_path, timeout=20):
    """Render a LaTeX file to PDF with unique temporary files for parallel processing"""
    with open(tex_path, "r") as file:
        tex_content = file.read()
    
    full_tex_content = latex_preamble + tex_content + latex_end
    
    # Generate unique temporary filename using UUID
    temp_tex_filename = f"temp_{uuid.uuid4().hex}.tex"
    temp_tex_path = os.path.join(os.path.dirname(output_pdf_path), temp_tex_filename)
    
    with open(temp_tex_path, "w") as temp_file:
        temp_file.write(full_tex_content)
    
    try:
        result = subprocess.run(
            ["pdflatex", "-interaction=nonstopmode", "-output-directory", os.path.dirname(output_pdf_path), temp_tex_path],
            check=False,
            capture_output=True,
            text=True,
            encoding='latin-1',
            timeout=timeout
        )
        
        # Get the actual output PDF path from LaTeX compilation
        temp_pdf_path = temp_tex_path.replace(".tex", ".pdf")
        
        if os.path.exists(temp_pdf_path):
            os.rename(temp_pdf_path, output_pdf_path)
            print(f"Successfully rendered {tex_path} to PDF at {output_pdf_path}.")
        else:
            print(f"Error: PDF not generated for {tex_path}. LaTeX output:\n{result.stdout}")
    except subprocess.TimeoutExpired:
        print(f"Timeout expired while rendering {tex_path}. Skipping this file.")
    except Exception as e:
        print(f"Unexpected error rendering {tex_path}: {str(e)}")
    finally:
        # Cleanup temporary files
        for ext in [".aux", ".log", ".out", ".tex"]:
            file_path = temp_tex_path.replace(".tex", ext)
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    print(f"Warning: Could not delete temporary file {file_path}: {str(e)}")

def convert_pdf_to_png(pdf_path, png_path, dpi=300):
    """Convert PDF to PNG with error handling"""
    try:
        images = convert_from_path(pdf_path, dpi=dpi)
        if images:
            images[0].save(png_path, "PNG")
            print(f"Converted {pdf_path} to {png_path}")
        else:
            print(f"Error: No pages found in {pdf_path}")
    except Exception as e:
        print(f"Error converting {pdf_path} to PNG: {str(e)}")

def process_entry(entry):
    """Process a single JSON entry with full error handling"""
    try:
        question_id = entry["questionId"]
        print(f"Processing {question_id}...")
        
        # Create output filenames
        base_files = {
            "answer": (entry["answer"], f"{question_id}_answer"),
            "annotation": (entry["annotation"], f"{question_id}_annotation")
        }
        
        for content_type, (content, base_name) in base_files.items():
            # Generate LaTeX file
            tex_path = os.path.join(table_dir, f"{base_name}.tex")
            with open(tex_path, "w") as f:
                f.write(content)
            
            # Generate PDF
            pdf_path = os.path.join(output_pdf_dir, f"{base_name}.pdf")
            render_tex_to_pdf(tex_path, pdf_path)
            
            # Convert to PNG
            if os.path.exists(pdf_path):
                png_path = os.path.join(output_png_dir, f"{base_name}.png")
                convert_pdf_to_png(pdf_path, png_path)
            else:
                print(f"Skipping PNG conversion for {base_name} - PDF not found")
        
        print(f"Completed processing {question_id}")
        return True
    except Exception as e:
        print(f"Error processing {entry.get('questionId', 'unknown')}: {str(e)}")
        return False

def process_json(json_file):
    """Main processing function with parallel execution"""
    with open(json_file, "r") as file:
        data = json.load(file)
    
    # Create necessary directories
    os.makedirs(table_dir, exist_ok=True)
    os.makedirs(output_pdf_dir, exist_ok=True)
    os.makedirs(output_png_dir, exist_ok=True)
    
    # Use half of available CPUs for resource balancing
    num_workers = max(1, os.cpu_count() // 2)
    
    with Pool(processes=num_workers) as pool:
        results = pool.imap_unordered(process_entry, data)
        
        # Monitor progress
        success_count = 0
        for i, result in enumerate(results, 1):
            if result:
                success_count += 1
            print(f"Processed {i}/{len(data)} entries ({success_count} successes)")
    
    print(f"\nProcessing complete. Success rate: {success_count}/{len(data)}")

if __name__ == "__main__":
    # Configuration paths
    json_file = "/home/lingjun/code/InternVL/internvl_chat/results_v1_5epoch/Table2Latex_250401040709.json"
    base_dir = os.path.dirname(json_file)
    extracted_tables_dir = os.path.join(base_dir, "extracted_tables")
    
    # Initialize directory paths
    table_dir = os.path.join(extracted_tables_dir, "latex_tables")
    output_pdf_dir = os.path.join(extracted_tables_dir, "pdf_tables_full")
    output_png_dir = os.path.join(extracted_tables_dir, "png_tables_full")
    
    # Start processing
    process_json(json_file)
    print("Rendering and conversion complete!")

In [None]:
#####################文件形式
import os
import subprocess
from pdf2image import convert_from_path
from PIL import Image
from multiprocessing import Pool
import uuid
import glob

# LaTeX preamble with UTF-8 support and table formatting
latex_preamble = r"""
\documentclass[standalone]{article}
\usepackage[utf8]{inputenc}
\usepackage[T1]{fontenc}
\usepackage{amsmath, amsthm, amssymb, graphicx, geometry, array}
\usepackage{booktabs, multirow,natbib,tabularx, multicol, bm}
\pagenumbering{gobble}
\begin{document}
\begin{table}[htp]
\centering
\resizebox*{0.5\columnwidth}{!}{ 
"""
latex_end = r"}\end{table}\end{document}"

def render_content_to_pdf(content, output_pdf_path, timeout=20):
    """Render LaTeX content to PDF with error handling and cleanup"""
    full_tex_content = latex_preamble + content + latex_end
    
    # Generate unique temporary filename
    temp_tex_filename = f"temp_{uuid.uuid4().hex}.tex"
    temp_tex_path = os.path.join(os.path.dirname(output_pdf_path), temp_tex_filename)
    
    try:
        with open(temp_tex_path, "w") as temp_file:
            temp_file.write(full_tex_content)

        result = subprocess.run(
            ["pdflatex", "-interaction=nonstopmode", "-output-directory", os.path.dirname(output_pdf_path), temp_tex_path],
            check=False,
            capture_output=True,
            text=True,
            encoding='latin-1',
            timeout=timeout
        )

        temp_pdf_path = temp_tex_path.replace(".tex", ".pdf")
        if os.path.exists(temp_pdf_path):
            os.rename(temp_pdf_path, output_pdf_path)
            print(f"Rendered PDF saved to {output_pdf_path}")
        else:
            print(f"PDF generation failed for {output_pdf_path}. Errors:\n{result.stdout}")
    except subprocess.TimeoutExpired:
        print(f"Timeout expired for {output_pdf_path}")
    except Exception as e:
        print(f"Error processing {output_pdf_path}: {str(e)}")
    finally:
        # Cleanup temporary files
        for ext in [".aux", ".log", ".out", ".tex"]:
            file_path = temp_tex_path.replace(".tex", ext)
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except:
                    pass

def convert_pdf_to_png(pdf_path, png_path, dpi=160):
    """Convert PDF to PNG with error handling"""
    try:
        images = convert_from_path(pdf_path, dpi=dpi)
        if images:
            images[0].save(png_path, "PNG")
            print(f"Converted {pdf_path} to {png_path}")
        else:
            print(f"Error: Empty PDF {pdf_path}")
    except Exception as e:
        print(f"Conversion failed for {pdf_path}: {str(e)}")

def find_matching_pairs(gt_dir, gen_dir):
    """Find matching file pairs between groundtruth and generated directories"""
    gt_files = glob.glob(os.path.join(gt_dir, "*.tex"))
    pairs = []
    
    for gt_path in gt_files:
        base_name = os.path.splitext(os.path.basename(gt_path))[0]
        gen_path = os.path.join(gen_dir, f"{base_name}.txt")
        
        if os.path.exists(gen_path):
            pairs.append((gt_path, gen_path))
        else:
            print(f"Warning: Missing generated file for {base_name}")
    
    print(f"Found {len(pairs)} valid file pairs")
    return pairs

def process_pair(pair):
    """Process a matched file pair with error handling"""
    gt_path, gen_path = pair
    base_name = os.path.splitext(os.path.basename(gt_path))[0]
    
    try:
        # Process groundtruth
        with open(gt_path, "r") as f:
            gt_content = f.read()
        
        gt_pdf = os.path.join(pdf_output_dir, f"{base_name}_gt.pdf")
        gt_png = os.path.join(png_output_dir, f"{base_name}_gt.png")
        render_content_to_pdf(gt_content, gt_pdf)
        if os.path.exists(gt_pdf):
            convert_pdf_to_png(gt_pdf, gt_png)

        # Process generated
        with open(gen_path, "r") as f:
            gen_content = f.read()
        
        gen_pdf = os.path.join(pdf_output_dir, f"{base_name}_gen.pdf")
        gen_png = os.path.join(png_output_dir, f"{base_name}_gen.png")
        render_content_to_pdf(gen_content, gen_pdf)
        if os.path.exists(gen_pdf):
            convert_pdf_to_png(gen_pdf, gen_png)

        return True
    except Exception as e:
        print(f"Failed to process {base_name}: {str(e)}")
        return False

if __name__ == "__main__":
    # Configuration paths
    gt_dir = "/home/lingjun/code/InternVL/internvl_chat/val_data/complex/label"          # Groundtruth .tex files
    gen_dir = "/home/lingjun/code/InternVL/internvl_chat/val_data/complex/pred_1B-v5-38555"         # Generated .txt files
    output_dir = "/home/lingjun/code/InternVL/internvl_chat/val_data/complex/output_pred_1B-v5-38555" # Main output directory
    
    # Create output directories
    pdf_output_dir = os.path.join(output_dir, "pdf_results")
    png_output_dir = os.path.join(output_dir, "png_results")
    os.makedirs(pdf_output_dir, exist_ok=True)
    os.makedirs(png_output_dir, exist_ok=True)

    # Find and process matching pairs
    file_pairs = find_matching_pairs(gt_dir, gen_dir)
    
    # Use half of available CPUs
    num_workers = max(1, os.cpu_count() // 2)
    
    # Process pairs in parallel
    with Pool(processes=num_workers) as pool:
        results = pool.imap_unordered(process_pair, file_pairs)
        
        success_count = 0
        for i, result in enumerate(results, 1):
            if result:
                success_count += 1
            print(f"Processed {i}/{len(file_pairs)} pairs ({success_count} successful)")

    print(f"\nProcessing completed. Success rate: {success_count}/{len(file_pairs)}")
    print(f"PDF outputs: {pdf_output_dir}")
    print(f"PNG outputs: {png_output_dir}")

In [2]:
from numpy.lib.stride_tricks import sliding_window_view
import os
import concurrent
import numpy as np
from PIL import Image
import cv2

def dwt2_simple(img_array):
    """Performs a simple 2D discrete wavelet transform with even-size check"""
    # Ensure rows and cols are even by cropping
    rows, cols = img_array.shape
    rows = rows - rows % 2  # 裁剪到最近的偶数
    cols = cols - cols % 2
    img_array = img_array[:rows, :cols]  # 保留能被2整除的部分
    
    # 重塑为块结构 (rows//2, cols//2, 2, 2)
    blocks = img_array.reshape(rows//2, 2, cols//2, 2).transpose(0, 2, 1, 3)
    a = blocks[..., 0, 0]
    b = blocks[..., 0, 1]
    c = blocks[..., 1, 0]
    d = blocks[..., 1, 1]
    
    # 计算系数
    cA = (a + b + c + d) * 0.25
    cH = (a - c) * 0.5
    cV = (a - b) * 0.5
    cD = (a - d) * 0.5
    return cA, cH, cV, cD

# Function to calculate Structural Similarity Index (SSIM)
def calculate_ssim(img1, img2):
    img1_flat = img1.ravel()
    img2_flat = img2.ravel()
    n = img1_flat.size
    
    # 一次性计算所有总和
    sum1 = img1_flat.sum()
    sum2 = img2_flat.sum()
    sum12 = (img1_flat * img2_flat).sum()
    sum1_sq = (img1_flat** 2).sum()
    sum2_sq = (img2_flat** 2).sum()
    
    # 均值和方差
    mean1 = sum1 / n
    mean2 = sum2 / n
    var1 = (sum1_sq - sum1**2 / n) / n  # ddof=0
    var2 = (sum2_sq - sum2**2 / n) / n
    
    # 协方差（样本协方差，ddof=1）
    covar = (sum12 - sum1 * sum2 / n) / (n - 1) if n > 1 else 0.0
    
    # SSIM公式
    C1, C2 = 6.5025, 58.5225
    numerator = (2 * mean1 * mean2 + C1) * (2 * covar + C2)
    denominator = (mean1**2 + mean2**2 + C1) * (var1 + var2 + C2)
    return numerator / denominator

# Function to calculate Complex Wavelet Structural Similarity Index (CW-SSIM)
def calculate_cwssim(image1_path, image2_path):
    """Compares two images using Complex Wavelet Structural Similarity Index (CW-SSIM)"""
    try:
        # Open both images
        image1 = Image.open(image1_path).convert('L')  # Convert to grayscale
        image2 = Image.open(image2_path).convert('L')  # Convert to grayscale
        
        # Resize the images to the same size if necessary
        image1 = image1.resize(image2.size)
        
        # Convert images to numpy arrays
        img1_array = np.array(image1)
        img2_array = np.array(image2)
        
        # Perform simple 2D DWT (discrete wavelet transform)
        cA1, cH1, cV1, cD1 = dwt2_simple(img1_array)
        cA2, cH2, cV2, cD2 = dwt2_simple(img2_array)
        
        # Compute SSIM for the approximation and detail coefficients
        ssim_cA = calculate_ssim(cA1, cA2)
        ssim_cH = calculate_ssim(cH1, cH2)
        ssim_cV = calculate_ssim(cV1, cV2)
        ssim_cD = calculate_ssim(cD1, cD2)
        cwssim_score = (ssim_cA + ssim_cH + ssim_cV + ssim_cD) / 4
        
        return cwssim_score
    
    except Exception as e:
        print(f"Error comparing images {image1_path} and {image2_path}: {e}")
        return 0.0

def process_prefix(prefix, output_png_dir, gt_images, gen_images):
    gt_image_path = os.path.join(output_png_dir, gt_images[prefix])
    if prefix in gen_images:
        gen_image_path = os.path.join(output_png_dir, gen_images[prefix])
        cwssim_score = calculate_cwssim(gen_image_path, gt_image_path)
        return f"{prefix}: {cwssim_score:.4f}\n"
    else:
        return f"{prefix}: 0.0000\n"

def compare_images_and_save_results():
    # 更新路径到新的图片存储位置
    output_png_dir = "/home/lingjun/code/InternVL/internvl_chat/results_qwenvl3b-complex-checkpoint-9450/extracted_tables/pdf_tables_full"
    result_file_path = "/home/lingjun/code/InternVL/internvl_chat/results_qwenvl3b-complex-checkpoint-9450/cwssim_results.txt"
    
    # 获取图片文件列表
    image_files = os.listdir(output_png_dir)
    
    # 按新格式过滤文件：前缀_gt.png 和 前缀_gen.png
    gt_images = {f.split('_')[0]: f for f in image_files if f.endswith('_reference.png')}
    gen_images = {f.split('_')[0]: f for f in image_files if f.endswith('_prediction.png')}
    
    with open(result_file_path, "w") as result_file:
        with concurrent.futures.ProcessPoolExecutor() as executor:
            futures = [
                executor.submit(process_prefix, prefix, output_png_dir, gt_images, gen_images)
                for prefix in gt_images
            ]
            for future in concurrent.futures.as_completed(futures):
                result_line = future.result()
                result_file.write(result_line)
                print(result_line.strip())
    
    print(f"Results saved to {result_file_path}")
    return result_file_path



def sort_txt_file(input_file):
    """排序结果文件"""
    try:
        with open(input_file, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]
        
        # 按前缀数字排序
        sorted_lines = sorted(lines, key=lambda x: int(x.split(':')[0]))
        
        # 生成排序后的文件名
        output_file = input_file.replace('.txt', '_sorted.txt')
        
        with open(output_file, 'w') as f:
            f.write('\n'.join(sorted_lines))
        
        print(f"Sorted results saved to {output_file}")
        return output_file
    except Exception as e:
        print(f"Sorting failed: {str(e)}")
        return input_file

# 执行流程
if __name__ == "__main__":
    result_file = compare_images_and_save_results()
    sorted_file = sort_txt_file(result_file)
    print(f"Final sorted results: {sorted_file}")
    print("Image comparison complete!")

Results saved to /home/lingjun/code/InternVL/internvl_chat/results_qwenvl3b-complex-checkpoint-9450/cwssim_results.txt
Sorted results saved to /home/lingjun/code/InternVL/internvl_chat/results_qwenvl3b-complex-checkpoint-9450/cwssim_results_sorted.txt
Final sorted results: /home/lingjun/code/InternVL/internvl_chat/results_qwenvl3b-complex-checkpoint-9450/cwssim_results_sorted.txt
Image comparison complete!


In [None]:

import os
import concurrent
import numpy as np
from PIL import Image


def dwt2_simple(img_array):
    """Performs a simple 2D discrete wavelet transform with even-size check"""
    # Ensure rows and cols are even by cropping
    rows, cols = img_array.shape
    rows = rows - rows % 2  # 裁剪到最近的偶数
    cols = cols - cols % 2
    img_array = img_array[:rows, :cols]  # 保留能被2整除的部分
    
    # 重塑为块结构 (rows//2, cols//2, 2, 2)
    blocks = img_array.reshape(rows//2, 2, cols//2, 2).transpose(0, 2, 1, 3)
    a = blocks[..., 0, 0]
    b = blocks[..., 0, 1]
    c = blocks[..., 1, 0]
    d = blocks[..., 1, 1]
    
    # 计算系数
    cA = (a + b + c + d) * 0.25
    cH = (a - c) * 0.5
    cV = (a - b) * 0.5
    cD = (a - d) * 0.5
    return cA, cH, cV, cD

# Function to calculate Structural Similarity Index (SSIM)
def calculate_ssim(img1, img2):
    img1_flat = img1.ravel()
    img2_flat = img2.ravel()
    n = img1_flat.size
    
    # 一次性计算所有总和
    sum1 = img1_flat.sum()
    sum2 = img2_flat.sum()
    sum12 = (img1_flat * img2_flat).sum()
    sum1_sq = (img1_flat** 2).sum()
    sum2_sq = (img2_flat** 2).sum()
    
    # 均值和方差
    mean1 = sum1 / n
    mean2 = sum2 / n
    var1 = (sum1_sq - sum1**2 / n) / n  # ddof=0
    var2 = (sum2_sq - sum2**2 / n) / n
    
    # 协方差（样本协方差，ddof=1）
    covar = (sum12 - sum1 * sum2 / n) / (n - 1) if n > 1 else 0.0
    
    # SSIM公式
    C1, C2 = 6.5025, 58.5225
    numerator = (2 * mean1 * mean2 + C1) * (2 * covar + C2)
    denominator = (mean1**2 + mean2**2 + C1) * (var1 + var2 + C2)
    return numerator / denominator

# Function to calculate Complex Wavelet Structural Similarity Index (CW-SSIM)
def calculate_cwssim(image1_path, image2_path):
    """Compares two images using Complex Wavelet Structural Similarity Index (CW-SSIM)"""
    try:
        # Open both images
        image1 = Image.open(image1_path).convert('L')  # Convert to grayscale
        image2 = Image.open(image2_path).convert('L')  # Convert to grayscale
        
        # Resize the images to the same size if necessary
        image1 = image1.resize(image2.size)
        
        # Convert images to numpy arrays
        img1_array = np.array(image1)
        img2_array = np.array(image2)
        
        # Perform simple 2D DWT (discrete wavelet transform)
        cA1, cH1, cV1, cD1 = dwt2_simple(img1_array)
        cA2, cH2, cV2, cD2 = dwt2_simple(img2_array)
        
        # Compute SSIM for the approximation and detail coefficients
        ssim_cA = calculate_ssim(cA1, cA2)
        ssim_cH = calculate_ssim(cH1, cH2)
        ssim_cV = calculate_ssim(cV1, cV2)
        ssim_cD = calculate_ssim(cD1, cD2)
        cwssim_score = (ssim_cA + ssim_cH + ssim_cV + ssim_cD) / 4
        
        return cwssim_score
    
    except Exception as e:
        print(f"Error comparing images {image1_path} and {image2_path}: {e}")
        return 0.0

def process_prefix(prefix, output_png_dir, annotation_images, answer_images):
    annotation_image_path = os.path.join(output_png_dir, annotation_images[prefix])
    if prefix in answer_images:
        answer_image_path = os.path.join(output_png_dir, answer_images[prefix])
        cwssim_score = calculate_cwssim(answer_image_path, annotation_image_path)
        return f"{prefix}: {cwssim_score:.4f}\n"
    else:
        return f"{prefix}: 0.0000\n"

def compare_images_and_save_results():
    output_png_dir = "/home/lingjun/code/InternVL/internvl_chat/results_qwenvl_1epoch_complex/extracted_tables/png_tables_full"
    result_file_path = "/home/lingjun/code/InternVL/internvl_chat/results_qwenvl_1epoch_complex/extracted_tables/png_tables_full/cwssim_results.txt"
    image_files = os.listdir(output_png_dir)
    answer_images = {f.split('_')[0]: f for f in image_files if '_answer.png' in f}
    annotation_images = {f.split('_')[0]: f for f in image_files if '_annotation.png' in f}
    
    with open(result_file_path, "w") as result_file:
        with concurrent.futures.ProcessPoolExecutor() as executor:
            futures = [executor.submit(process_prefix, prefix, output_png_dir, annotation_images, answer_images) 
                       for prefix in annotation_images]
            for future in concurrent.futures.as_completed(futures):
                result_line = future.result()
                result_file.write(result_line)
                print(result_line.strip())
    print(f"Results saved to {result_file_path}")



def sort_txt_file(input_file, output_file=None):
    """
    读取并排序文本文件中的数字对
    :param input_file: 输入文件名（如"data.txt"）
    :param output_file: 输出文件名，默认添加"_sorted"后缀
    """
    try:
        # 读取原始文件
        with open(input_file, 'r') as f:
            lines = [line.strip() for line in f if line.strip()]  # 移除空行和空白字符

        # 定义排序规则：按冒号左侧的数字升序排列
        def sort_key(line):
            left_num = line.split(':', 1)[0].strip()  # 分割并取左边数字
            return int(left_num)

        # 执行排序
        sorted_lines = sorted(lines, key=sort_key)

        # 生成输出文件名
        output = output_file or input_file.replace('.txt', '_sorted.txt')

        # 写入新文件
        with open(output, 'w') as f:
            f.write('\n'.join(sorted_lines))

        print(f"文件已排序并保存为：{output}")
        return True

    except FileNotFoundError:
        print(f"错误：文件 {input_file} 不存在")
    except Exception as e:
        print(f"处理时发生错误：{str(e)}")
    return False

compare_images_and_save_results()
sort_txt_file("/home/lingjun/code/InternVL/internvl_chat/results_qwenvl_1epoch_complex/extracted_tables/png_tables_full/cwssim_results_sorted.txt")
print("Image comparison complete!")

In [21]:
def calculate_average_score(file_path):
    """Calculates the average score from a file with id:score format."""
    try:
        with open(file_path, 'r') as file:
            total_score = 0
            count = 0
            for line in file:
                # Split by ':' and strip whitespace, convert the score to float
                parts = line.split(":")
                if len(parts) == 2:
                    score = float(parts[1].strip())
                    total_score += score
                    count += 1

            # Calculate and return the average score
            if count > 0:
                average_score = total_score / count
                return average_score
            else:
                return 0.0
    except Exception as e:
        print(f"Error reading the file: {e}")
        return None

# Example usage
file_path = "/home/lingjun/code/InternVL/internvl_chat/results_v1_5epoch/ssim_comparison_results_complex_sorted.txt"
average_score = calculate_average_score(file_path)

if average_score is not None:
    print(f"The average score is: {average_score:.4f}")

The average score is: 0.5582


In [22]:
def calculate_average_and_ratio(file_path):
    """计算非零分数的平均值和非零分数占比
    Args:
        file_path: 包含评分结果的文本文件路径，每行格式为 "前缀: 分数"
    
    Returns:
        tuple: (平均分, 非零分数占比) 或 (None, None) 当出现错误时
    """
    try:
        total_score = 0.0
        valid_count = 0
        non_zero_count = 0
        
        with open(file_path, 'r') as file:
            for line in file:
                # 跳过空行
                if not line.strip():
                    continue
                
                # 安全分割数据
                parts = line.rsplit(':', 1)  # 从右边最多分割1次
                if len(parts) != 2:
                    continue  # 跳过格式错误行
                
                try:
                    score = float(parts[1].strip())
                except ValueError:
                    continue  # 跳过无法转换为浮点数的行
                
                valid_count += 1
                if not np.isclose(score, 0.0):  # 使用近似零判断
                    non_zero_count += 1
                    total_score += score

        # 计算统计值
        avg = total_score / non_zero_count if non_zero_count > 0 else 0.0
        ratio = non_zero_count / valid_count if valid_count > 0 else 0.0
        
        return avg, ratio
    
    except Exception as e:
        print(f"处理文件时发生错误: {e}")
        return None, None

file_path = "/home/lingjun/code/InternVL/internvl_chat/results_v1_5epoch/ssim_comparison_results_complex_sorted.txt"
average_score, non_zero_ratio = calculate_average_and_ratio(file_path)

if average_score is not None and non_zero_ratio is not None:
    print(f"非零分数平均值: {average_score:.4f}")
    print(f"非零分数占比: {non_zero_ratio:.2%}")
else:
    print("无法计算结果，请检查文件路径和格式")

非零分数平均值: 0.6051
非零分数占比: 92.24%


In [None]:
###########Teds_structure_evaluation###########
import re
from table_recognition_metric import TEDS
import multiprocessing
import json
def remove_grid_lines(latex_table):
    # 去除 \hline, \cline, \toprule, \midrule, \bottomrule
    cleaned_table = re.sub(r'\\cmidrule{\s*}|\\cdashline\{[0-9]+(-[0-9]+)?\}\s*|\\cmidrule\((?:lr|r|l)?\)\{[0-9]+\-[0-9]+\}\s*|\\arrayrulecolor{.*?}\s*|\\caption{.*?}\s*|\\centering\s*|\\hline\s*|\\cline{.*?}\s*|\\toprule\s*|\\midrule\s*|\\bottomrule\s*', '', latex_table)
    
    cleaned_table = re.sub(r'\\tabularnewline', r'\\\\', cleaned_table)
    # # 去除注释
    # cleaned_table = re.sub(r'(?<!\\)%.*$', '', cleaned_table, flags=re.MULTILINE)
    # cleaned_table = re.sub(r'(?<!\\)\[[^\[\]]*\]', '', cleaned_table)
    # 合并连续的空行
    cleaned_table = re.sub(r'\n\s*\n', '\n', cleaned_table)
    
    return cleaned_table.strip(' \n')  # 去除首尾空格

def fix_multi(cell):
    multirow_pattern = r'\\multirow{(\d+)}{.*?}{(.*?)}'
    multicol_pattern = r'\\multicolumn{(\d+)}{.*?}{(.*?)}'
    
    match = re.search(multirow_pattern, cell['content'])
    if match:
        cell['rowspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()

    match = re.search(multicol_pattern, cell['content'])
    if match:
        cell['colspan'] = int(match.group(1))
        cell['content'] = cell['content'].replace(match.group(0), match.group(2).strip(), 1).strip()

    return cell    

def grid2html(grid):
    def to_td(grid, r, c):
        if grid[r][c] == '<<' or grid[r][c] == '^^' or grid[r][c] == '..':
            return ''
        td = {'text': grid[r][c], 'rowspan':1, 'colspan': 1}

        for i in range(r + 1, len(grid)):
            if grid[i][c] == '^^':
                td['rowspan'] += 1
            else:
                break
        
        for j in range(c + 1, len(grid[r])):
            if grid[r][j] == '<<':
                td['colspan'] += 1
            else:
                break
        return f'<td rowspan={td["rowspan"]} colspan={td["colspan"]}> {td["text"]} </td>'.replace('rowspan=1', '').replace('colspan=1', '')
        
    
    html = []
    for r in range(len(grid)):
        row = []
        for c in range(len(grid[0])):
            row.append(to_td(grid, r, c))
        html.append(f'<tr> {"".join(row)} </tr>')
    # for row in grid:
    #     html.append('<tr>' + ''.join([to_td(c) for c in row]) + '</tr>')
    
    return '<html><body><table>' + '\n'.join(html) + '</table></body></html>'


def qylatex_to_grid(latex):
    # 提取表格内容
    if not latex.endswith('\\end{tabular}'):
        return 
    pattern = r'\\begin\{tabular\}\s*\{.*?\}(.*?)\\end\{tabular\}'
    matches = re.findall(pattern, latex, re.DOTALL)
    if matches:
        table_content = matches[0]
    else:
        return
    # 提取表格内容

    content = remove_grid_lines(table_content)
    # 获取表格内部的内容
    
    # 将行和列分割
    rows = content.strip(' \n').split(r'\\')

    processed_rows = []

    for row in rows:
        # 去除空行
        if not row.strip():
            continue
        # 处理 multirow 和 multicolumn
        # row = re.sub(r'\\multirow{(\d+)}{.*?}{(.*?)}', lambda m: f"{m.group(2)}" + ("<< " * (int(m.group(1)) - 1)), row)
        # row = re.sub(r'\\multicolumn{(\d+)}{.*?}{(.*?)}', lambda m: f"{m.group(2)}" + ("| " * (int(m.group(1)) - 1)), row)

        # 用 & 分割列
        columns = re.split(r'(?<!\\)&', row)
        columns = [fix_multi({'content': c.strip(' \n'), 'rowspan': 1, 'colspan': 1}) for c in columns]
        # 去除多余的空格并构建行
        processed_rows.append(columns)
    # # 如果最后一行为空, 删除
    # while len(processed_rows) > 0 and len(processed_rows[-1]) == 0:
    #     processed_rows.pop()
    rows = processed_rows
    max_cols = max([sum([it['colspan'] for it in r]) for r in rows])
    # 创建一个空白网格
    grid = [[None for _ in range(max_cols)] for _ in range(len(rows))]
    col_char_num = [[1] for _ in range(max_cols)]
    # 填充网格，处理 rowspan 和 colspan
    r_idx_bias = 0
    for r_idx, row in enumerate(rows):
        r_idx += r_idx_bias
        if r_idx >= len(grid):
            grid.append([None for _ in range(max_cols)])
        c_idx = 0
        current_row_bias = 10000
        for cell in row:
            # 找到第一个未填充的单元格
            if grid[r_idx][c_idx] is not None:
                if cell['content']:
                    while grid[r_idx][c_idx] == '..':
                        c_idx += 1
                else:
                    c_idx += 1
                    continue

            current_row_bias = min(current_row_bias, cell['rowspan'])
            # 填充内容
            grid[r_idx][c_idx] = cell['content']
            col_char_num[c_idx].append(len(cell['content']))
            
            # 处理 rowspan 和 colspan
            for r in range(cell['rowspan']):
                for c in range(cell['colspan']):
                    if r == 0 and c == 0:
                        continue
                    if r == 0:
                        grid[r_idx][c_idx + c] = '<<'
                    elif c == 0:
                        grid[r_idx + r][c_idx] = '^^'
                    else:
                        grid[r_idx + r][c_idx + c] = '..'
            c_idx += cell['colspan']
        r_idx_bias += current_row_bias - 1
    grid = [[c if c is not None else '' for c in r] for r in grid]
    
    return grid


def latex2html(latex_str):
    # 去除注释
    latex_str = re.sub(r'(?<!\\)%.*$', '', latex_str, flags=re.MULTILINE)
    # 去除"\\\\[...]"
    latex_str = re.sub(r'(?<!\\)\\\\\[.*?\]', '', latex_str, flags=re.DOTALL)

    latex_str = latex_str.replace('\n', '').replace('\t', '')
    try:
        grid = qylatex_to_grid(latex_str)
    except IndexError:
        return 
    if not grid:
        return
    html = grid2html(grid)
    return html

def teds_structure(gt, pred):
    """计算结构相似度"""
    gt_html = latex2html(gt)
    pred_html = latex2html(pred)
    if not pred_html:
        # print("Prediction LaTeX to HTML conversion failed.")
        return 0
    structure_teds = TEDS(structure_only=True)
    structure_score = structure_teds(gt_html, pred_html)

    return structure_score

def process_item(args):
    """包装处理函数用于多进程"""
    idx, item = args
    try:
        # 有效性检查
        if not isinstance(item, dict):
            print(f"Item {idx}: Invalid data type ({type(item)}), skipping")
            return (None, None)
            
        required_keys = ['id', 'prediction', 'reference']
        for key in required_keys:
            if key not in item:
                print(f"Item {idx}: Missing key '{key}', skipping")
                return (None, None)

        # 处理数据
        qid = item['id']
        gt = item['reference']
        pred = item['prediction']
        score = teds_structure(gt, pred)
        return (qid, score)
        
    except Exception as e:
        print(f"Error processing item {idx}: {str(e)}")
        return (None, None)

def process_json(input_path, output_path):
    results = {}
    
    # 读取整个JSON文件
    with open(input_path, 'r', encoding='utf-8') as f:
        try:
            data = json.load(f)
            if not isinstance(data, list):
                raise ValueError("Root element is not an array")
        except Exception as e:
            print(f"Failed to parse JSON: {str(e)}")
            return

    # 创建进程池
    cpu_num = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(processes=cpu_num)
    
    # 准备参数 (添加索引)
    task_args = [(i+1, item) for i, item in enumerate(data)]
    
    # 使用imap获取结果
    processed = 0
    for qid, score in pool.imap(process_item, task_args):
        if qid is not None and score is not None:
            results[qid] = score
        processed += 1
        if processed % 10 == 0:  # 每处理10个显示进度
            print(f"Processed {processed}/{len(data)} items")
    
    pool.close()
    pool.join()

    # 新增统计计算和结果写入部分
    scores = list(results.values())
    avg_all = sum(scores) / len(scores) if scores else 0
    non_zero_scores = [s for s in scores if s != 0]
    avg_non_zero = sum(non_zero_scores)/len(non_zero_scores) if non_zero_scores else 0

    with open(output_path, 'w', encoding='utf-8') as f:
        # 写入明细结果
        for qid in sorted(results.keys()):
            f.write(f"{qid}:{results[qid]:.4f}\n")
        
        # 写入统计信息
        f.write("\n=== Statistical Results ===\n")
        f.write(f"Total samples: {len(data)}\n")
        f.write(f"Valid samples: {len(scores)}\n")
        f.write(f"Average TEDS (All): {avg_all:.4f}\n")
        f.write(f"Average TEDS (Non-zero): {avg_non_zero:.4f}\n")

def process_jsonl(input_path, output_path):
    results = {}
    
    # 读取JSONL文件
    data = []
    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                item = json.loads(line.strip())
                data.append(item)
            except json.JSONDecodeError as e:
                print(f"Failed to parse line: {str(e)}")
                continue

    # 创建进程池
    cpu_num = multiprocessing.cpu_count()
    pool = multiprocessing.Pool(processes=cpu_num)
    
    # 准备参数 (添加索引)
    task_args = [(i+1, item) for i, item in enumerate(data)]
    
    # 使用imap获取结果 
    processed = 0
    for qid, score in pool.imap(process_item, task_args):
        if qid is not None and score is not None:
            results[qid] = score
        processed += 1
        if processed % 10 == 0:  # 每处理10个显示进度
            print(f"Processed {processed}/{len(data)} items")
    
    pool.close()
    pool.join()

    # 统计计算和结果写入
    scores = list(results.values())
    avg_all = sum(scores) / len(scores) if scores else 0
    non_zero_scores = [s for s in scores if s != 0]
    avg_non_zero = sum(non_zero_scores)/len(non_zero_scores) if non_zero_scores else 0
    non_zero_ratio = len(non_zero_scores) / len(scores) if scores else 0  # 计算非0比例

    with open(output_path, 'w', encoding='utf-8') as f:
        # 写入明细结果
        for qid in sorted(results.keys()):
            f.write(f"{qid}:{results[qid]:.4f}\n")
        
        # 写入统计信息
        f.write("\n=== Statistical Results ===\n")
        f.write(f"Total samples: {len(data)}\n")
        f.write(f"Valid samples: {len(scores)}\n")
        f.write(f"Non-zero ratio: {non_zero_ratio:.4f}\n")  # 添加非0比例输出
        f.write(f"Average TEDS (All): {avg_all:.4f}\n")
        f.write(f"Average TEDS (Non-zero): {avg_non_zero:.4f}\n")
process_jsonl("/home/lingjun/code/InternVL/internvl_chat/results_qwenvl_1epoch_complex/output.jsonl", "/home/lingjun/code/InternVL/internvl_chat/results_qwenvl_1epoch_complex/output_output.txt")

In [None]:
import json

def filter_by_length(input_path, output_path, max_length=3500):
    """筛选gpt回复长度小于指定值的数据
    
    :param input_path: 输入文件路径
    :param output_path: 输出文件路径
    :param max_length: 最大允许长度（默认3500）
    """
    with open(input_path, 'r') as infile, \
         open(output_path, 'w', encoding='utf-8') as outfile:

        for i, line in enumerate(infile, 1):
            try:
                # 解析JSON对象
                data = json.loads(line)
                valid = True

                # 遍历对话检查长度
                for conv in data.get('conversations', []):
                    if conv.get('from') == 'gpt':
                        if len(conv.get('value', '')) >= max_length:
                            valid = False
                            break  # 发现超长立即终止检查

                # 符合条件则写入新文件
                if valid:
                    outfile.write(line)

            except Exception as e:
                print(f"Error processing line {i}: {str(e)}")
                continue

if __name__ == '__main__':
    input_file = "/home/Datasets/Table2Latex/Table2Latex/train/table2latex_grpo_hard.jsonl"   # 输入文件路径
    output_file = "/home/Datasets/Table2Latex/Table2Latex/train/table2latex_grpo_hard_select.jsonl" # 输出文件路径
    
    filter_by_length(input_file, output_file)
    print("数据筛选完成，结果已保存至", output_file)

In [None]:
import json

def meets_criteria(value):
    # 统计LaTeX多行/多列命令数量
    multirow = value.count('\\multirow')
    multicolumn = value.count('\\multicolumn')
    # 统计&符号数量
    ampersands = value.count('&')
    # 判断条件：多行命令总数≥2 且 &≥100
    return (multirow + multicolumn) >= 2 and ampersands >= 160

def filter_jsonl(input_file, output_file):
    with open(input_file, 'r', encoding='utf-8') as infile, \
         open(output_file, 'w', encoding='utf-8') as outfile:
        
        for line in infile:
            try:
                item = json.loads(line.strip())
                # 检查所有gpt对话内容
                for conv in item.get('conversations', []):
                    if conv.get('from') == 'gpt':
                        if meets_criteria(conv.get('value', '')):
                            # 满足条件则写入新文件
                            outfile.write(json.dumps(item, ensure_ascii=False) + '\n')
                            break  # 找到即停止检查当前item
            except json.JSONDecodeError:
                print(f"Warning: 解析失败的行: {line}")

# 使用示例
filter_jsonl('/home/Datasets/Table2Latex/Table2Latex/train/table2latex.jsonl', '/home/Datasets/Table2Latex/Table2Latex/train/table2latex_grpo_hard.jsonl')

In [None]:
import json

def load_image_set(jsonl2_path):
    """加载第二个JSONL文件中的所有图片路径到集合"""
    image_set = set()
    with open(jsonl2_path, 'r', encoding='utf-8') as f:
        for line in f:
            data = json.loads(line.strip())
            image_set.add(data['image'])
    return image_set

def clean_jsonl(jsonl1_path, jsonl2_path, output_path):
    """清理第一个JSONL文件中重复的图片项"""
    # 读取第二个文件的图片路径
    image_set = load_image_set(jsonl2_path)
    
    # 处理第一个文件并写入结果
    with open(jsonl1_path, 'r', encoding='utf-8') as f_in, \
         open(output_path, 'w', encoding='utf-8') as f_out:
        for line in f_in:
            data = json.loads(line.strip())
            if data['image'] not in image_set:
                f_out.write(json.dumps(data, ensure_ascii=False) + '\n')

# 示例用法
clean_jsonl('/home/Datasets/Table2Latex/Table2Latex/train/table2latex_grpo_hard_select.jsonl', '/home/Datasets/Table2Latex/Table2Latex/separated/complex/table2latex.jsonl', '/home/Datasets/Table2Latex/Table2Latex/train/table2latex_grpo_hard_final.jsonl')

In [None]:
import json
import os
import shutil

def copy_images_from_jsonl(jsonl_path, output_dir):
    """
    从JSONL文件中提取图片路径并复制到目标文件夹
    
    参数:
        jsonl_path (str): JSONL文件路径
        output_dir (str): 图片输出目录
    """
    # 创建输出目录（如果不存在）
    os.makedirs(output_dir, exist_ok=True)
    
    # 统计变量
    copied_count = 0
    error_count = 0
    
    with open(jsonl_path, 'r') as f:
        for line_num, line in enumerate(f, 1):
            try:
                # 解析JSON数据
                data = json.loads(line.strip())
                
                # 获取原始图片路径
                src_path = data['image']
                
                # 生成目标路径
                filename = os.path.basename(src_path)
                dest_path = os.path.join(output_dir, filename)
                
                # 复制文件
                shutil.copy2(src_path, dest_path)
                copied_count += 1
                print(f"行 {line_num}: 成功复制 {filename}")
                
            except KeyError:
                print(f"行 {line_num}: 错误 - 缺少 'image' 字段")
                error_count += 1
            except FileNotFoundError:
                print(f"行 {line_num}: 错误 - 文件不存在 {src_path}")
                error_count += 1
            except json.JSONDecodeError:
                print(f"行 {line_num}: 错误 - JSON解析失败")
                error_count += 1
            except Exception as e:
                print(f"行 {line_num}: 发生意外错误 - {str(e)}")
                error_count += 1
    
    print("\n操作完成:")
    print(f"成功复制文件数: {copied_count}")
    print(f"失败数量: {error_count}")

# 使用示例
if __name__ == "__main__":
    # 配置路径（根据实际情况修改）
    jsonl_file = "/home/Datasets/Table2Latex/Table2Latex/train/table2latex_grpo_hard_select.jsonl"   # 输入的JSONL文件路径
    output_folder = "/home/Datasets/Table2Latex/Table2Latex/train/output_images"  # 输出目录
    
    copy_images_from_jsonl(jsonl_file, output_folder)

In [None]:
from PIL import Image
import cv2
import os
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # 添加缺失的导入
import cv2
import numpy as np
from skimage.metrics import structural_similarity as ssim
import matplotlib.pyplot as plt



def dwt2_simple(img_array):
    """Performs a simple 2D discrete wavelet transform using averaging and differencing"""
    rows, cols = img_array.shape
    # 确保图像尺寸为偶数，避免索引错误
    rows = rows // 2 * 2
    cols = cols // 2 * 2
    cA = np.zeros((rows//2, cols//2))
    cH = np.zeros((rows//2, cols//2))
    cV = np.zeros((rows//2, cols//2))
    cD = np.zeros((rows//2, cols//2))

    for i in range(0, rows, 2):
        for j in range(0, cols, 2):
            avg = (img_array[i,j] + img_array[i+1,j] + img_array[i,j+1] + img_array[i+1,j+1]) / 4
            diff_h = (img_array[i,j] + img_array[i,j+1] - img_array[i+1,j] - img_array[i+1,j+1]) / 4
            diff_v = (img_array[i,j] + img_array[i+1,j] - img_array[i,j+1] - img_array[i+1,j+1]) / 4
            diff_d = (img_array[i,j] - img_array[i,j+1] - img_array[i+1,j] + img_array[i+1,j+1]) / 4

            cA[i//2, j//2] = avg
            cH[i//2, j//2] = diff_h
            cV[i//2, j//2] = diff_v
            cD[i//2, j//2] = diff_d

    return cA, cH, cV, cD

def calculate_ssim(img1, img2):
    """Calculate SSIM between two images"""
    mean1, mean2 = np.mean(img1), np.mean(img2)
    var1, var2 = np.var(img1), np.var(img2)
    covar = np.cov(img1.flatten(), img2.flatten())[0, 1]

    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255) ** 2

    numerator = (2 * mean1 * mean2 + C1) * (2 * covar + C2)
    denominator = (mean1**2 + mean2**2 + C1) * (var1 + var2 + C2)
    return numerator / denominator if denominator != 0 else 0.0

def calculate_cwssim(image1_path, image2_path):
    """Compares two images using Complex Wavelet Structural Similarity Index (CW-SSIM)"""
    # 统一调整为256x256并转为灰度
    image1 = Image.open(image1_path).convert('L').resize((256, 256))
    image2 = Image.open(image2_path).convert('L').resize((256, 256))
    
    img1_array = np.array(image1, dtype=np.float64)
    img2_array = np.array(image2, dtype=np.float64)
    
    # 执行小波变换
    cA1, cH1, cV1, cD1 = dwt2_simple(img1_array)
    cA2, cH2, cV2, cD2 = dwt2_simple(img2_array)
    
    # 计算各子带的SSIM
    ssim_scores = [
        calculate_ssim(cA1, cA2),
        calculate_ssim(cH1, cH2),
        calculate_ssim(cV1, cV2),
        calculate_ssim(cD1, cD2)
    ]
    
    # 取平均值作为CW-SSIM结果
    cwssim_score = np.mean(ssim_scores)
    return max(0.0, min(cwssim_score, 1.0))
    
def calculate_similarity(image1_path, image2_path):
    # 读取图像
    image1 = cv2.imread(image1_path, cv2.IMREAD_GRAYSCALE)
    image2 = cv2.imread(image2_path, cv2.IMREAD_GRAYSCALE)

    # 检查图像是否加载成功
    if image1 is None or image2 is None:
        raise ValueError("无法加载图像，请检查路径是否正确。")

    # 确保图像大小相同
    if image1.shape != image2.shape:
        # 调整图像大小到较小的那个
        min_height = min(image1.shape[0], image2.shape[0])
        min_width = min(image1.shape[1], image2.shape[1])
        image1 = cv2.resize(image1, (min_width, min_height))
        image2 = cv2.resize(image2, (min_width, min_height))

    # 计算均方误差 (MSE)
    mse = np.mean((image1 - image2) ** 2)

    # 计算峰值信噪比 (PSNR)
    max_pixel = 255.0
    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))

    # 计算结构相似性指数 (SSIM)
    ssim_index, ssim_image = ssim(image1, image2, full=True)
    ssim_image = (ssim_image * 255).astype("uint8")

    # 边缘检测
    edges1 = cv2.Canny(image1, 100, 200)
    edges2 = cv2.Canny(image2, 100, 200)
    edge_mse = np.mean((edges1 - edges2) ** 2)
    edge_psnr = 20 * np.log10(max_pixel / np.sqrt(edge_mse))

    # 综合评估
    similarity_score = (ssim_index * 0.5) + (psnr / 100 * 0.3) + (edge_psnr / 100 * 0.2)

    # 显示结果
    plt.figure(figsize=(12, 6))
    plt.subplot(2, 3, 1), plt.imshow(image1, cmap='gray')
    plt.title('Image 1'), plt.axis('off')
    plt.subplot(2, 3, 2), plt.imshow(image2, cmap='gray')
    plt.title('Image 2'), plt.axis('off')
    plt.subplot(2, 3, 3), plt.imshow(ssim_image, cmap='gray')
    plt.title('SSIM Map'), plt.axis('off')
    plt.subplot(2, 3, 4), plt.imshow(edges1, cmap='gray')
    plt.title('Edges 1'), plt.axis('off')
    plt.subplot(2, 3, 5), plt.imshow(edges2, cmap='gray')
    plt.title('Edges 2'), plt.axis('off')
    plt.subplot(2, 3, 6), plt.imshow(cv2.absdiff(edges1, edges2), cmap='gray')
    plt.title('Edge Difference'), plt.axis('off')
    plt.tight_layout()
    plt.show()

    # 打印结果
    print(f"均方误差 (MSE): {mse:.2f}")
    print(f"峰值信噪比 (PSNR): {psnr:.2f} dB")
    print(f"结构相似性指数 (SSIM): {ssim_index:.4f}")
    print(f"边缘均方误差 (Edge MSE): {edge_mse:.2f}")
    print(f"边缘峰值信噪比 (Edge PSNR): {edge_psnr:.2f} dB")
    print(f"综合相似性分数: {similarity_score:.4f}")

    return similarity_score
def run_tests():
    
    try:


        # 测试1: 相同图像
        calculate_similarity("/home/lingjun/code/InternVL/internvl_chat/test_images/6_annotation.png", "/home/lingjun/code/InternVL/internvl_chat/test_images/6_annotation.png")
        score = calculate_cwssim("/home/lingjun/code/InternVL/internvl_chat/test_images/6_annotation.png", "/home/lingjun/code/InternVL/internvl_chat/test_images/6_annotation.png")
        print(f"测试1（相同图像）: 得分为{score:.4f}")
        try:
            assert abs(score - 1.0) < 0.01
            print("测试1（相同图像）: 通过")
        except AssertionError:
            print(f"测试1（相同图像）: 失败，得分为{score:.4f}")

        # 测试2: 不同图像
        try:
            calculate_similarity("/home/lingjun/code/InternVL/internvl_chat/test_images/6_annotation.png", "/home/lingjun/code/InternVL/internvl_chat/test_images/Table conv 1.png")
            score = calculate_cwssim("/home/lingjun/code/InternVL/internvl_chat/test_images/6_annotation.png", "/home/lingjun/code/InternVL/internvl_chat/test_images/Table conv 1.png")
            print(f"测试2（不同图像）: 得分为{score:.4f}")
        except AssertionError:
            print(f"测试2（不同图像）: 失败，得分为{score:.4f}")

    finally:

        print("清理完成")


run_tests()