In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
import torch
import csv
import os
# Seed PyTorch's random number generator with a fixed value (1234).
# NOTE(review): presumably meant to make sampled model output repeatable
# between runs — confirm that generation in this notebook actually samples.
torch.manual_seed(1234)

In [None]:
# Variables and directories
# Output directory: one "<object id>.txt" file is written here per QA pair.
qa_output_dir = "qa_output/"

# Table inputs: PNG images, LaTeX sources (.tex) and the metadata CSV.
table_dir = "extracted_tables/"
table_image_dir = table_dir + "table_images/"
table_code_dir = table_dir + "table_code/"
table_metadata = table_dir + "tables.csv"

# Figure inputs: PNG images and the metadata CSV.
figure_dir = "extracted_figures/"
# NOTE(review): this constant was missing although the generation loop builds
# figure image paths from it (NameError for figures). The sub-directory name
# mirrors "table_images/" and must be confirmed against the real layout.
figure_image_dir = figure_dir + "figure_images/"
figure_metadata = figure_dir + "figures.csv"

# All paths above end with "/" because file names are appended by plain
# string concatenation elsewhere in the notebook.
os.makedirs(qa_output_dir, exist_ok=True)

In [None]:
# Define QA-Pair-Generation Prompts
# Both prompts are single-line strings (no embedded newlines). The placeholders
# {table_code}, {caption} and {text_mentions} are filled in later with
# str.replace, so any other curly braces in the inserted text are harmless.
# The pieces are joined by implicit string concatenation instead of
# backslash continuations; the former stray "\ " inside the text (an invalid
# escape sequence that leaked a literal backslash into the prompt) is removed.
table_prompt = (
    "Generate an open-ended question and its corresponding answer based on a scientific table. "
    "Use its caption and text mentions from the scientific paper to create a question that tests "
    "the understanding of this specific table. Also, include a difficulty level, either "
    "“easy” or “hard”, where easy indicates that little reasoning is required, and hard indicates "
    "that complex reasoning is required to answer the question. "
    "Table: {table_code} "
    "Caption: {caption} "
    "Text mentions: {text_mentions} "
    "Question: "
    "Answer: "
    "Difficulty: "
)

figure_prompt = (
    "Generate an open-ended question and its corresponding answer based on a scientific figure. "
    "Use its caption and text mentions from the scientific paper to create a question that tests "
    "the understanding of this specific figure. Also, include a difficulty level, either "
    "“easy” or “hard”, where easy indicates that little reasoning is required, and hard indicates "
    "that complex reasoning is required to answer the question. "
    "Caption: {caption} "
    "Text mentions: {text_mentions} "
    "Question: "
    "Answer: "
    "Difficulty: "
)

In [None]:
# Smoke test: load the Qwen-VL-Chat tokenizer and model, then ask the model
# about a publicly hosted demo image and print its answer.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-VL-Chat",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-VL-Chat",
    device_map="cuda",
    trust_remote_code=True,
)
model = model.eval()

# The query is a list of parts: one image reference followed by one text part.
demo_image = {'image': 'https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg'}
demo_question = {'text': 'Explain me this picture'}
query = tokenizer.from_list_format([demo_image, demo_question])

# history=None starts a fresh conversation.
response, history = model.chat(tokenizer, query=query, history=None)
print(response)

In [None]:
# Generate a qa_pair, either for a figure or a table
def generate_qa_pair(object_id, image_file, caption, text_mentions, table_code=None):
    """Ask the vision-language model for a QA pair about one figure or table.

    Uses the module-level ``tokenizer``, ``model``, ``table_prompt`` and
    ``figure_prompt``.

    Args:
        object_id: Identifier of the figure/table. Not used inside this
            function; kept so existing callers keep working.
        image_file: Location of the image showing the figure/table; passed to
            the tokenizer as the ``image`` part of the query.
        caption: Caption text inserted into the prompt.
        text_mentions: Text mentions of the object, inserted into the prompt.
        table_code: Source code of the table. When truthy, the table prompt is
            used; otherwise the figure prompt is used.

    Returns:
        The model's response as returned by ``model.chat``.
    """
    # Fill in the placeholders with str.replace so that any other curly braces
    # contained in the caption, mentions or table code are left untouched.
    if table_code:
        prompt = (
            table_prompt
            .replace("{caption}", caption)
            .replace("{text_mentions}", text_mentions)
            .replace("{table_code}", table_code)
        )
    else:
        prompt = (
            figure_prompt
            .replace("{caption}", caption)
            .replace("{text_mentions}", text_mentions)
        )

    # Query the model with this object's own image and the filled-in prompt
    # (previously a hard-coded demo image and demo question were sent, so the
    # prompt and image_file were silently ignored).
    query = tokenizer.from_list_format([
        {'image': image_file},
        {'text': prompt},
    ])

    # Every object is asked in a fresh conversation; the returned history is
    # not needed.
    response, _history = model.chat(tokenizer, query=query, history=None)

    return response

In [None]:
# Read either figure or table data from csv file
def get_object_data(meta_file, start_index, end_index, table=False):
    """Collect caption/text-mention data for a range of rows of a metadata CSV.

    The file is parsed with ';' as delimiter and '|' as quote character. Rows
    are numbered from 0 in file order (every row the CSV reader yields is
    counted), and rows ``start_index`` through ``end_index``, both inclusive,
    are collected.

    Args:
        meta_file: Path of the metadata CSV file.
        start_index: Index of the first row to collect.
        end_index: Index of the last row to collect (inclusive).
        table: If true, additionally load each object's LaTeX code from
            ``table_code_dir + <object id> + ".tex"``.

    Returns:
        A dict mapping the object id (column 0) to ``(caption, text_mentions)``
        (columns 4 and 5) for figures, or to
        ``(caption, text_mentions, table_code)`` for tables. Tables whose code
        cannot be loaded are reported on stdout and left out of the result.
    """
    object_data = {}
    with open(meta_file, "r", newline='', encoding='utf-8') as csv_file:
        reader = csv.reader(csv_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)

        for index, row in enumerate(reader):
            if index > end_index:
                # Past the requested range; no need to read the rest.
                break
            if index < start_index:
                continue

            object_id = row[0]
            caption = row[4]
            text_mentions = row[5]

            if not table:  # For figures
                object_data[object_id] = (caption, text_mentions)
                continue

            # For tables: best effort — a missing or malformed .tex file skips
            # only this table. Only the errors the loader is expected to raise
            # are caught, so programming errors are no longer hidden.
            try:
                table_code = get_table_code(table_code_dir + object_id + ".tex")
            except (OSError, ValueError) as e:
                print(f"Error occurred for index {index}: {e}")
            else:
                object_data[object_id] = (caption, text_mentions, table_code)

    return object_data

# Return table code
def get_table_code(code_file):
    r"""Return the table portion of a LaTeX table file.

    The file must contain ``\pagenumbering{gobble}`` exactly once. Everything
    after that command is returned, with every ``\end{document}`` removed.

    Args:
        code_file: Path of the LaTeX file to read (UTF-8).

    Returns:
        The extracted LaTeX code as a string.

    Raises:
        FileNotFoundError: If ``code_file`` is not an existing file.
        ValueError: If ``\pagenumbering{gobble}`` does not occur exactly once.
    """
    if not os.path.isfile(code_file):
        raise FileNotFoundError(f"{code_file} was not found.")

    # The handle gets its own name so the path in `code_file` stays available
    # for the error message below.
    with open(code_file, "r", encoding='utf-8') as tex_file:
        latex_source = tex_file.read()

    # Raw strings: "\p" and "\e" are not valid escape sequences.
    parts = latex_source.split(r"\pagenumbering{gobble}")
    if len(parts) != 2:
        raise ValueError(f"Unexpected occurrence of pagenumbering. Please check manually {code_file}")

    return parts[-1].replace(r"\end{document}", "")

In [None]:
# Execute whole QA generation for either figures or tables, following a range of indexes in the metadata file
def execute_generation(start_index, end_index, table=False):
    """Generate QA pairs for a range of figures or tables and save them.

    Reads the objects in rows ``start_index``..``end_index`` of the figure or
    table metadata file, asks the model for a QA pair per object, and writes
    each response to ``qa_output_dir + <object id> + ".txt"``.

    Args:
        start_index: First metadata row index to process.
        end_index: Last metadata row index to process.
        table: If true, process tables (table metadata, table images and
            table code); otherwise process figures.
    """
    print("Data extraction from csv file started.")
    if table:
        object_data = get_object_data(table_metadata, start_index, end_index, True)
    else:
        object_data = get_object_data(figure_metadata, start_index, end_index)

    print("QA-pair generation started.")
    # Report progress about ten times per run; max(1, ...) keeps the step
    # non-zero when fewer than ten objects are processed.
    progress_step = max(1, len(object_data) // 10)

    for counter, (obj, data) in enumerate(object_data.items(), start=1):
        caption = data[0]
        text_mentions = data[1]
        if table:
            image_file = table_image_dir + obj + ".png"
            table_code = data[2]
        else:
            # NOTE(review): figure_image_dir must be defined at module level
            # next to the other directory constants.
            image_file = figure_image_dir + obj + ".png"
            table_code = None

        response = generate_qa_pair(obj, image_file, caption, text_mentions, table_code)
        output_file = qa_output_dir + obj + ".txt"
        with open(output_file, "w", encoding='utf-8') as output:
            output.write(response)

        if counter % progress_step == 0:
            print(f"{counter} objects have been processed.")

    print("Process complete.")

In [None]:
# Specify model and function parameters
# Run settings: checkpoint name, metadata row range and object type
model_id = "Qwen/Qwen-VL-Chat"
s_index, e_index = 0, 15
is_table = True

# Instantiate the model (put into eval mode) and its tokenizer
# NOTE(review): an earlier cell already binds `model` and `tokenizer` to this
# same checkpoint; loading it again here may be redundant — confirm.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cuda",
    trust_remote_code=True,
)
model = model.eval()
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    trust_remote_code=True,
)

# Run the QA generation over the configured range
execute_generation(start_index=s_index, end_index=e_index, table=is_table)