In [None]:
%load_ext autoreload
%autoreload 2
%reload_ext autoreload
from src.utils import ROOT_DIR
from src.dataset import CustomDataset
from src.models import Flamingo0S
from src.utils import load_json
from open_flamingo import create_model_and_transforms
import torch
import os
import matplotlib.pyplot as plt
from torchvision.transforms.functional import to_pil_image as to_pil

In [None]:
# Build the "test" split of the dataset from the MMHS150K CSV and image folder.
mmhs_dir = os.path.join(ROOT_DIR, "data", "MMHS150K")
dataset = CustomDataset(
    csv_file=os.path.join(mmhs_dir, "MMHS150K_text_in_image.csv"),
    img_dir=os.path.join(mmhs_dir, "img_resized/"),
    split="test",
)

In [None]:
# Local locations of the pretrained weights (everything lives under CACHE_MODEL).
CACHE_MODEL = os.path.join(ROOT_DIR, 'data', 'pretrained_models')
LANG_MODEL_PATH = os.path.join(CACHE_MODEL, 'RedPajama-INCITE-Base-3B-v1')
FLAMINGO_MODEL_PATH = os.path.join(CACHE_MODEL, 'OpenFlamingo-3B-vitl-mpt1b', 'checkpoint.pt')

# Build the OpenFlamingo model together with its image preprocessing and tokenizer.
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path=LANG_MODEL_PATH,
    tokenizer_path=LANG_MODEL_PATH,
    cross_attn_every_n_layers=2,
    cache_dir=CACHE_MODEL,
)

# NOTE(review): the checkpoint directory is named after the MPT-1B variant while the
# language encoder above is RedPajama-INCITE-Base-3B-v1. With strict=False a mismatched
# checkpoint loads silently, so inspect the missing/unexpected keys shown as this
# cell's output and confirm the checkpoint matches the architecture.
#
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only machine;
# load_state_dict then copies the tensors into the model's own parameters.
# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load(FLAMINGO_MODEL_PATH, map_location="cpu"), strict=False)

In [None]:
# Preview the first six dataset images on a 2x3 grid with the axes hidden.
f = plt.figure()
for i in range(6):
    ax = f.add_subplot(2, 3, i + 1)
    ax.imshow(to_pil(dataset[i]['image']))
    ax.axis('off')

In [None]:
# Instantiate the project's Flamingo0S wrapper from its JSON configuration file.
# NOTE(review): this rebinds `model`, replacing the OpenFlamingo model created
# earlier in the notebook; every later `model.*` call targets this Flamingo0S
# instance when cells run top to bottom — confirm that is intended.
flamingo0s_config_path = os.path.join(ROOT_DIR, "data", "config", "config_Flamingo0S.json")
model = Flamingo0S(config_path=flamingo0s_config_path)

# Example from the src implementation

In [None]:
from PIL import Image
import requests

# Step 1: build the vision input from two dataset samples.
# The first sample is the in-context demonstration, the second is the query.
demo_image_one = dataset[0]['image']
query_image = dataset[1]['image']


def _ensure_pil(img):
    """Return ``img`` unchanged if it is a PIL image, otherwise convert it with ``to_pil``."""
    return img if isinstance(img, Image.Image) else to_pil(img)


# NOTE(review): the preview cell has to call `to_pil` on dataset images, so they are
# not PIL images, while the CLIP preprocessing returned by
# `create_model_and_transforms` presumably expects PIL input — confirm. The guard
# above leaves images that are already PIL untouched.
vision_x = [
    image_processor(_ensure_pil(sample)).unsqueeze(0)
    for sample in (demo_image_one, query_image)
]
# Stack the two images along dim 0, then insert singleton dims at positions 1 and 0.
# Presumably this is the (batch, num_images, frames, C, H, W) layout OpenFlamingo
# expects — TODO confirm.
vision_x = torch.cat(vision_x, dim=0).unsqueeze(1).unsqueeze(0)

# Step 2: tokenize the prompt. It holds one `<image>` marker per image in vision_x:
# the demonstration text is closed by `<|endofchunk|>`, the query text is left open.
tokenizer.padding_side = "left"  # For generation padding tokens should be on the left
lang_x = tokenizer(
    ["<image>This is a hateful meme. We consider it as hateful due to the fact that it targets a specific community<|endofchunk|><image>This meme is"],
    return_tensors="pt",
)

In [None]:
# Generate a continuation of the prompt with beam search, then decode the first sequence.
# NOTE(review): when the cells run top to bottom, `model` at this point is the
# Flamingo0S instance, not the OpenFlamingo model returned by
# `create_model_and_transforms` — confirm `generate` is called on the intended object.
generation_kwargs = {"max_new_tokens": 20, "num_beams": 3}
generated_text = model.generate(
    vision_x=vision_x,
    lang_x=lang_x["input_ids"],
    attention_mask=lang_x["attention_mask"],
    **generation_kwargs,
)
decoded_text = tokenizer.decode(generated_text[0])
print("Generated text: ", decoded_text)

In [None]:
# Pass the dataset to the model's `initialize_prompt`.
# NOTE(review): presumably a Flamingo0S method (the most recent `model` binding);
# what it builds is not visible from this notebook — confirm against src.models.
model.initialize_prompt(dataset)