<a href="https://colab.research.google.com/github/nyanta012/ERNIE-ViLG/blob/main/ERNIE_ViLG_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install paddlehub paddlepaddle gradio

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting paddlehub
  Downloading paddlehub-2.3.0-py3-none-any.whl (212 kB)
[K     |████████████████████████████████| 212 kB 29.0 MB/s 
[?25hCollecting paddlepaddle
  Downloading paddlepaddle-2.3.2-cp37-cp37m-manylinux1_x86_64.whl (112.5 MB)
[K     |████████████████████████████████| 112.5 MB 47 kB/s 
[?25hCollecting gradio
  Downloading gradio-3.2-py3-none-any.whl (6.1 MB)
[K     |████████████████████████████████| 6.1 MB 41.0 MB/s 
Collecting paddlenlp>=2.0.0
  Downloading paddlenlp-2.3.7-py3-none-any.whl (1.6 MB)
[K     |████████████████████████████████| 1.6 MB 46.4 MB/s 
[?25hCollecting paddle2onnx>=0.5.1
  Downloading paddle2onnx-1.0.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (3.0 MB)
[K     |████████████████████████████████| 3.0 MB 42.9 MB/s 
[?25hCollecting colorlog
  Downloading colorlog-6.7.0-py2.py3-none-any.whl (11 kB)
Collecting colorama
  Downloading 

In [None]:
import gradio as gr
import paddlehub as hub

class Predictor:
    """Generate images from a text prompt with the PaddleHub ``ernie_vilg`` module.

    Prompts that are not Chinese are first translated to Chinese with the
    PaddleHub ``baidu_translate`` module. Language and style are selected by
    *index* into ``language_list`` / ``style_list``; the UI dropdowns in this
    notebook use ``type="index"``, so their choice order must stay aligned
    with these lists.
    """

    def __init__(self):
        """Set up the option lists and load both PaddleHub modules."""
        # Style names handed to the model as-is (Chinese).
        self.style_list = ["水彩", "油画", "粉笔画", "卡通", "蜡笔画", "儿童画", "探索无限"]
        # Language codes handed to the translator as the source language.
        self.language_list = ["jp", "en", "zh"]
        self.model = hub.Module(name="ernie_vilg")
        self.language_model = hub.Module(name="baidu_translate")

    def predict(self, prompt, index_lang, index_style):
        """Translate the prompt if needed and generate images for it.

        Args:
            prompt: Text prompt entered by the user.
            index_lang: Index into ``self.language_list``; may be ``None``
                when nothing is selected (depends on the Gradio version).
            index_style: Index into ``self.style_list``; may be ``None``
                under the same condition.

        Returns:
            A ``(status, images)`` tuple. On success ``status`` is
            ``"Success"`` and ``images`` is the first six items returned by
            ``generate_image``. On failure ``status`` is the error message
            (translated to Japanese when possible) and ``images`` is ``None``.
            This method does not raise; all errors are reported via ``status``.
        """
        # Without a selection there is no valid index; report it clearly
        # instead of surfacing a cryptic list-indexing error.
        if index_lang is None or index_style is None:
            return "言語と作品スタイルを選択してください。", None

        try:
            language = self.language_list[index_lang]
            if language != "zh":
                prompt = self.language_model.translate(prompt, language, "zh")
            results = self.model.generate_image(
                text_prompts=prompt,
                style=self.style_list[index_style],
                visualization=False,
            )
            return "Success", results[:6]
        except Exception as e:
            error_text = str(e)
            # The message is translated zh -> jp, presumably because the upstream
            # module reports errors in Chinese (TODO confirm). The translator may
            # itself be the component that failed, so never let this second call
            # escape the handler: fall back to the untranslated message.
            try:
                error_text = self.language_model.translate(error_text, "zh", "jp")
            except Exception:
                pass
            return error_text, None

# Custom stylesheet passed to `gr.Blocks(css=...)` below. The `#gallery` rules
# apply to the Gallery component, which is created with `elem_id="gallery"`.
css = """
        .gradio-container {
            font-family: 'IBM Plex Sans', sans-serif;
        }
        .container {
            max-width: 730px;
            margin: auto;
            padding-top: 1.5rem;
        }
        #gallery {
            min-height: 22rem;
            margin-bottom: 15px;
            margin-left: auto;
            margin-right: auto;
            border-bottom-right-radius: .5rem !important;
            border-bottom-left-radius: .5rem !important;
        }
        #gallery>div>.h-full {
            min-height: 20rem;
        }

"""


# Build the Gradio UI: a prompt box with a button, language/style dropdowns,
# an image gallery and a read-only status line.
block = gr.Blocks(css=css)


with block:
    gr.Markdown("ERNIE-ViLG Demo")
    # Instantiating the predictor loads both PaddleHub modules.
    runner = Predictor()
    with gr.Group():
        with gr.Box():
            with gr.Row(mobile_collapse=False, equal_height=True):
                prompt = gr.Textbox(
                    label="prompt",
                    show_label=False,
                    placeholder="prompt(日本語or英語or中国語)",
                    max_lines=1,
                ).style(
                    border=(True, False, True, True),
                    rounded=(True, False, False, True),
                    container=False,
                )
                text_button = gr.Button("Generate Image").style(
                    margin=False,
                    rounded=(False, True, True, False),
                )
        # Both dropdowns use type="index": the handler receives the position of
        # the selected choice, so the choice order must match
        # Predictor.language_list / Predictor.style_list.
        # An explicit default `value` guarantees a valid index even when the
        # user clicks "Generate Image" without touching the dropdowns.
        language = gr.Dropdown(
            label="言語",
            choices=["日本語", "英語", "中国語"],
            value="日本語",
            type="index",
        )
        style = gr.Dropdown(
            label="作品スタイル",
            choices=["水彩画", "油絵", "チョーク画", "漫画", "クレヨン画", "児童画", "探索"],
            value="水彩画",
            type="index",
        )
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery"
        ).style(grid=[2, 3], height="auto")

        status_text = gr.Textbox(
            label="Process status", show_label=True, max_lines=1, interactive=False
        )

        # predict returns (status, images), matching the order of `outputs`.
        text_button.click(
            runner.predict,
            inputs=[prompt, language, style],
            outputs=[status_text, gallery],
        )

# debug=True keeps the cell running so errors and logs stay visible in the notebook.
block.launch(debug=True)



Download https://bj.bcebos.com/paddlehub/paddlehub_dev/ernie_vilg.tar.gz
[##################################################] 100.00%
Decompress /root/.paddlehub/tmp/tmp_byqwwbv/ernie_vilg.tar.gz
[##################################################] 100.00%






[32m[2022-09-02 13:06:36,029] [    INFO][0m - Successfully installed dependent packages.[0m
[32m[2022-09-02 13:06:36,337] [    INFO][0m - Successfully installed ernie_vilg-1.0.0[0m


Download https://bj.bcebos.com/paddlehub/paddlehub_dev/baidu_translate.tar.gz
[##################################################] 100.00%
Decompress /root/.paddlehub/tmp/tmpu6pqcpr2/baidu_translate.tar.gz
[##################################################] 100.00%


[32m[2022-09-02 13:06:41,424] [    INFO][0m - Successfully installed baidu_translate-1.0.0[0m


Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://39151.gradio.app

This share link expires in 72 hours. For free permanent hosting, check out Spaces: https://huggingface.co/spaces


  0%|          | 0/100 [00:00<?, ?%/s]

Saving Images...
Done
