In [1]:
from pathlib import Path
import os
import sys
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
import torch
from datasets import load_from_disk
sys.path.append(str(Path.cwd().parent))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.utils.dirutils import get_data_dir, get_models_dir
from src.models.multiclassification.predict_model import ViTForMultiClassificationPredictor

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
multiclassification_model = ViTForMultiClassificationPredictor(
    get_models_dir() / "multiclassification" / "full" / "model-20230513_121917-35.pt",
    DEVICE
)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTModel: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
multiclassification_model.model.log_vars

Parameter containing:
tensor([-2.0949, -2.0840, -2.4154, -5.1525, -5.7846], device='cuda:0',
       requires_grad=True)

In [5]:
dataset = load_from_disk(get_data_dir() / "processed" / "captioning_dataset_augmented_processed")

In [6]:
def capitalize_artist(s):
    return " ".join([word.capitalize() for word in s.split("-")])

In [7]:
def multiclassification_prediction_to_prompt(prediction):
    prompt = ""
    multiclass_preds = prediction[0]
    multilabel_preds = prediction[1]
    for feature, pred in multiclass_preds.items():
        if feature == "artist":
            if pred[0] == "other":
                continue
            else:
                # remove dash and capitalize first letter of each word
                pred = capitalize_artist(pred[0])
                prompt += f"{pred} "
        else:
            prompt += f"{pred[0].capitalize()} "
    # for feature, pred in multilabel_preds.items():
    #     for label in pred:
    #         prompt += f"{label} "
    return prompt.strip()

def multiclassification_prediction_to_caption(prediction):
    caption = ""
    multiclass_preds = prediction[0]
    multilabel_preds = prediction[1]

    caption += f"The artwork is a {multiclass_preds['genre'][0].capitalize()}"
    caption += f", in the style of {multiclass_preds['style'][0].capitalize()}."

    if multiclass_preds["artist"][0] != "other":
        artist_pred = capitalize_artist(multiclass_preds["artist"][0])
        caption += f" It could be by {artist_pred}."
    
    if multilabel_preds["media"]:
        caption += f" The used media are {', '.join(multilabel_preds['media'])}."

    if multilabel_preds["tags"]:
        caption += f" It is about {', '.join(multilabel_preds['tags'])}."

    return caption.strip()

In [8]:
MODEL_NAME = "microsoft/git-base"
PROCESSOR = AutoProcessor.from_pretrained(MODEL_NAME)
MODEL = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16)
OUTPUT_DIR = "microsoft_git-base_artgraph"

checkpoint = torch.load(get_models_dir() / "captioning" / "microsoft-git-base-good-samples" / "2.pt")
MODEL.load_state_dict(checkpoint["model_state_dict"])
MODEL.to(DEVICE)
MODEL.train(False)

GitForCausalLM(
  (git): GitModel(
    (embeddings): GitEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(1024, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (image_encoder): GitVisionModel(
      (vision_model): GitVisionTransformer(
        (embeddings): GitVisionEmbeddings(
          (patch_embedding): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16), bias=False)
          (position_embedding): Embedding(197, 768)
        )
        (pre_layrnorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (encoder): GitVisionEncoder(
          (layers): ModuleList(
            (0-11): 12 x GitVisionEncoderLayer(
              (self_attn): GitVisionAttention(
                (k_proj): Linear(in_features=768, out_features=768, bias=True)
                (v_proj): Linear(in_features=768, out_features=768, bias=True)
             

In [9]:
# create a gradio interface with image input and text output
import gradio as gr

def captioning_pipeline(image):
    prediction = multiclassification_model.predict(image)
    caption = multiclassification_prediction_to_caption(prediction)
    pixel_values = PROCESSOR(images=image, return_tensors="pt").pixel_values.to(DEVICE, torch.float16)
    generated_ids = MODEL.generate(pixel_values=pixel_values, min_length=12, max_length=50, num_beams=4, no_repeat_ngram_size=2, do_sample=False, temperature=1.5)
    generated_caption = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
    caption += f" The artwork depicts {generated_caption}."
    return caption

def prompt_pipeline(prompt):
    prediction = multiclassification_model.predict(prompt)
    prompt = multiclassification_prediction_to_prompt(prediction)
    return prompt

image = gr.components.Image()
caption = gr.components.Textbox()

gr.Interface(
    captioning_pipeline,
    image,
    caption,
    title="Artwork Captioning",
    description="Generate a caption for an artwork.",
    examples=[
        ["https://uploads2.wikiart.org/images/ivan-aivazovsky/moonlit-night-beside-the-sea-1847.jpg!Large.jpg"],
        ["https://uploads4.wikiart.org/images/georges-seurat/sunday-afternoon-on-the-island-of-la-grande-jatte-1886.jpg!Large.jpg"],
        ["https://uploads2.wikiart.org/images/edvard-munch/the-scream-1893(2).jpg!Large.jpg"]
    ]
    ).launch(share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://850a3165d124aaa807.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades (NEW!), check out Spaces: https://huggingface.co/spaces


