In [1]:
# !pip install git+https://github.com/huggingface/transformers
!pip uninstall -y transformers
# !pip install --no-deps /kaggle/input/certifi/certifi-2022.12.7-py3-none-any.whl
!mkdir temp && cp -r /kaggle/input/transformers-main-09082023 temp/transformers && cd temp/transformers && python setup.py develop --no-deps

Found existing installation: transformers 4.27.4
Uninstalling transformers-4.27.4:
  Successfully uninstalled transformers-4.27.4
[0mrunning develop
running egg_info
creating src/transformers.egg-info
writing src/transformers.egg-info/PKG-INFO
writing dependency_links to src/transformers.egg-info/dependency_links.txt
writing entry points to src/transformers.egg-info/entry_points.txt
writing requirements to src/transformers.egg-info/requires.txt
writing top-level names to src/transformers.egg-info/top_level.txt
writing manifest file 'src/transformers.egg-info/SOURCES.txt'
reading manifest file 'src/transformers.egg-info/SOURCES.txt'
reading manifest template 'MANIFEST.in'
adding license file 'LICENSE'
writing manifest file 'src/transformers.egg-info/SOURCES.txt'
running build_ext
Creating /opt/conda/lib/python3.7/site-packages/transformers.egg-link (link to src)
Adding transformers 4.28.0.dev0 to easy-install.pth file
Installing transformers-cli script to /opt/conda/

In [2]:
!pip show transformers

Name: transformers
Version: 4.28.0.dev0
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /kaggle/working/temp/transformers/src
Requires: filelock, huggingface-hub, importlib_metadata, numpy, packaging, pyyaml, regex, requests, tokenizers, tqdm
Required-by: 


In [3]:
import sys
sys.path.insert(0, "/kaggle/working/temp/transformers/src/")

In [4]:
import pandas as pd
import io
import torch
from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor
import requests
from PIL import Image
import re

### Deplot model

In [5]:
def display_deplot_output(deplot_output, visualize=True):
    '''
    The raw output of deplot
    TITLE | Rural population (%) long-run with 2050 projections<0x0A>(OWID) in Greece<0x0A>Years | Rural population<0x0A>1940 | 47.38<0x0A>1960 | 43.68<0x0A>1980 | 30.28<0x0A>...
    '''

    # x1,x2 <> y1 y2
    # x1 | y1 <0x0A> x2 | y2 <0x0A>
    # label xxxxyyy -> xyxyxyxy
    deplot_output = deplot_output.replace("<0x0A>", "\n").replace(" | ", "\t")

    second_a_index = [m.start() for m in re.finditer('\t', deplot_output)][1]
    last_newline_index = deplot_output.rfind('\n', 0, second_a_index) 

    title = deplot_output[:last_newline_index]
    table = deplot_output#[last_newline_index+1:]

    data = io.StringIO(table)
    df = pd.read_csv(data, sep='\t', names=['x', 'y'])
    if visualize:
        display(df)
    
    return df
    

def deplot(path, model, processor, device, visualize=True):

    image = Image.open(path)
    if visualize:
        display(image)
    inputs = processor(images=image, text="Generate underlying data table of the figure below:", return_tensors="pt", is_vqa=False)

    # Move inputs to GPU
    inputs = {key: value.to(device) for key, value in inputs.items()}

    predictions = model.generate(**inputs, max_new_tokens=512)
    return processor.decode(predictions[0], skip_special_tokens=True)


In [6]:
deplot_weights_path = '/kaggle/input/matcha-base/matcha-base'
model- path = '你自己上传的模型权重。bin文件的路径'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
deplot_model = Pix2StructForConditionalGeneration.from_pretrained(deplot_weights_path).to(device)
deplot_model.load_state_dict(torch.load(model- path))
processor = Pix2StructProcessor.from_pretrained(deplot_weights_path)
# processor.is_vqa = False

In [7]:
def deplot_inference(image_path, visualize):
    deplot_output=deplot(image_path, deplot_model, processor, device, visualize)
    return display_deplot_output(deplot_output, visualize)

In [8]:
# deplot_inference("/kaggle/input/benetech-making-graphs-accessible/test/images/00dcf883a459.jpg", visualize=True)

### Classification model

In [9]:
import torch
import torchvision
import torch.utils.data
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

import cv2
import numpy as np

In [10]:
label_map = {'dot': 0, 'horizontal_bar' : 1, 'vertical_bar': 2, 'line': 3, 'scatter': 4}
label_idx_to_classname = {v: k for k, v in label_map.items()}


In [11]:
classification_model = torchvision.models.resnet50(pretrained=False)

num_features = classification_model.fc.in_features
classification_model.fc = nn.Linear(num_features, 5)
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
classification_model = classification_model.to(device)
classification_model.eval()

state_dict = torch.load("/kaggle/input/deplot/Benetech _ResNet50_fold0.pth")
classification_model.load_state_dict(state_dict)

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


<All keys matched successfully>

In [12]:
val_transforms = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

def classification_inference(image_path):
    img = cv2.imread(image_path)
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    img = cv2.resize(img,(500,300))
    img = img.astype(np.float32)/255.0
    img = val_transforms(img)
    
#     print(img.shape)
    
#     inp = torch.stack([img.permute(1, 2, 0)]).cuda()
    inp = torch.stack([img]).cuda()
    out = classification_model(inp).detach().cpu().numpy()
    return label_idx_to_classname[np.argmax(out)]

### Inference pipeline

In [13]:
import os
from tqdm import tqdm
import math

IMAGE_FOLDER = "/kaggle/input/benetech-making-graphs-accessible/test/images"

all_ids = []
all_values = []
all_chart_types = []

for image_name in tqdm(os.listdir(IMAGE_FOLDER)):
    if ".jpg" not in image_name:
        continue
    try:
        image_path = os.path.join(IMAGE_FOLDER, image_name)

        graph_type = classification_inference(image_path)
        inference_df = deplot_inference(image_path, False)
        
        x_values = inference_df[inference_df.columns[0]].values
        y_values = inference_df[inference_df.columns[1]].values
        
        # check value type for each graph type
        if graph_type in ["horizontal_bar", "vertical_bar", "line", "dot", "scatter"]:
            # x is categorical, y is numerical
            formatted_y_values = []
            for v in y_values:
                try:
                    float(v)
                    if math.isnan(float(v)):
                        formatted_y_values.append(0)
                    else:
                        formatted_y_values.append(v)
                except:
                    formatted_y_values.append(0)
                    
            y_values = formatted_y_values
        
        if graph_type in ["dot", "scatter"]:
            # for dot graph, x could be categorical as well as numerical, how to handle this?
            # just leave all the value be numerical for now

            formatted_x_values = []
            for v in x_values:
                try:
                    float(v)
                    if math.isnan(float(v)):
                        formatted_x_values.append(0)
                    else:
                        formatted_x_values.append(v)
                except:
                    formatted_x_values.append(0)
                    
            x_values = formatted_x_values
            
        
        length = min(len(x_values), len(y_values))
        x_values = ";".join([str(v).strip() for v in x_values][:length])
        y_values = ";".join([str(v).strip() for v in y_values][:length])
    except Exception as e:
        print("Exception", e)
        graph_type = "line"
        x_values = "0;0"
        y_values = "0;0"
    
    image_id = image_name.split(".")[0]

    all_ids.append(image_id + "_x")
    all_values.append(x_values)
    
    all_ids.append(image_id + "_y")
    all_values.append(y_values)
    
    all_chart_types.extend([graph_type, graph_type])

  0%|          | 0/5 [00:00<?, ?it/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
 20%|██        | 1/5 [00:06<00:25,  6.39s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
 40%|████      | 2/5 [00:14<00:21,  7.18s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
 60%|██████    | 3/5 [00:16<00:09,  4.93s/it]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
 80%|████████  | 4/5 [00:17<00:03,  3.41s/it]A decoder-only architecture is being used, but right-padding was detected! For correct gene

In [14]:
!rm -rf temp

In [15]:
submission_df = pd.DataFrame({
    "id": all_ids,
    "data_series": all_values,
    "chart_type": all_chart_types
})

submission_df.to_csv("submission.csv", index=False)
submission_df

Unnamed: 0,id,data_series,chart_type
0,000b92c3b098_x,0;6;12;18;24,line
1,000b92c3b098_y,0.0;1.32;2.62;1.94;3.24,line
2,01b45b831589_x,21-Feb;22-Feb;23-Feb;24-Feb;25-Feb;26-Feb;27-F...,vertical_bar
3,01b45b831589_y,89000;151000;172000;177000;137000;99000;0;4150...,vertical_bar
4,00f5404753cf_x,3;4;5;6;7;8;9;10;11,line
5,00f5404753cf_y,14.0;13.8;22.0;26.0;25.9;27.0;22.0;13.6;13.0,line
6,00dcf883a459_x,Group 1;Group 2,vertical_bar
7,00dcf883a459_y,3.6;8.4,vertical_bar
8,007a18eb4e09_x,2013;2014;2015;2016;2017;2018;2019;2020;2021;2...,line
9,007a18eb4e09_y,0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0;0.0,line
