# (Prompt-based + image caption) model


In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Quiet installs of the notebook dependencies: openprompt, ftfy/regex/tqdm,
# and the CLIP and BLIP packages straight from their Git repositories.
!pip install -q openprompt
!pip install -q ftfy regex tqdm
!pip install -q git+https://github.com/openai/CLIP.git
!pip install -q git+https://github.com/jasonnoy/BLIP.git
# The BLIP repository is also cloned so that its files exist under ./BLIP
# (the model-loading cell reads ./BLIP/configs/med_config.json).
!git clone https://github.com/jasonnoy/BLIP.git

In [None]:
import torch

# Pick the compute device once: CUDA when a GPU is visible, otherwise the CPU.
_gpu_present = torch.cuda.is_available()
device = torch.device("cuda" if _gpu_present else "cpu")
if not _gpu_present:
    print('No GPU available, using the CPU instead.')
else:
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))

from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score

In [None]:
import clip
import hashlib
import math
import numpy as np
import os
import pickle
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

from models.blip import blip_decoder
from PIL import Image
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

# --- BLIP captioning model -------------------------------------------------
print("Loading BLIP model...")
blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
blip_model = blip_decoder(
    pretrained=blip_model_url,
    image_size=blip_image_eval_size,
    vit='large',
    med_config='./BLIP/configs/med_config.json',
).to(device)
blip_model.eval()  # inference mode

# --- CLIP model used to score prompts against the image ---------------------
print("Loading CLIP model...")
clip_model_name = 'ViT-L/14' # https://huggingface.co/openai/clip-vit-large-patch14
clip_model, clip_preprocess = clip.load(clip_model_name, device=device)
clip_model = clip_model.to(device)
clip_model.eval()

# Batch size for label embedding/ranking, and how many "flavor" labels the
# interrogation keeps as candidates.
chunk_size = 2048
flavor_intermediate_count = 256

In [None]:
# NOTE(review): this cell repeats the BLIP loading statements of the previous
# cell verbatim. Running both cells instantiates the model a second time and
# rebinds `blip_model`; only one of the two is needed.
print("Loading BLIP model...")
blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_caption.pth'
blip_model = blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='large', med_config='./BLIP/configs/med_config.json')
blip_model.eval()  # inference mode
blip_model = blip_model.to(device)

# Load data

In [None]:
# Code to download file into Colaboratory:
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials
import csv
import torch
import pandas as pd
from PIL import Image
# Authenticate
# Module-level PyDrive client; stays None until authenticate() has run.
# Note: this rebinds the name `drive` that the first cell imported from
# google.colab.
drive = None
def authenticate():
    """Authenticate the Colab user and store a GoogleDrive client in `drive`.

    Uses the application-default Google credentials for the PyDrive session.
    """
    global drive
    auth.authenticate_user()
    google_auth = GoogleAuth()
    google_auth.credentials = GoogleCredentials.get_application_default()
    drive = GoogleDrive(google_auth)
#Download files
def downloadFiles(fileIds):
    """Download Google Drive files to local paths.

    Args:
        fileIds: iterable of entries where item 0 is the local file name to
            write and item 1 is the Google Drive file id.
    """
    authenticate()
    for entry in fileIds:
        local_name, drive_file_id = entry[0], entry[1]
        remote_file = drive.CreateFile({"id": drive_file_id})
        remote_file.GetContentFile(local_name)
#Download file if not existing

# Fetch each archive from Google Drive only when it is not already on disk.
# An explicit existence check replaces the previous open()/bare-except probe,
# which leaked a file handle and also swallowed unrelated errors.
if not os.path.isfile("harmc_images.zip"):
    downloadFiles([["harmc_images.zip", "11hvQxSQSqAPekKwHoE4WlC1KgXgeYWru"]])
if not os.path.isfile("parts.zip"):
    downloadFiles([["parts.zip", "1ydtbt8jIA0PFQHIV8Ee4L93Vt1FxmC-d"]])
# https://drive.google.com/file/d/1ydtbt8jIA0PFQHIV8Ee4L93Vt1FxmC-d/view?usp=share_link

In [None]:
try:
  _ = open("/content/images/covid_memes_2.png", "r")
except:
  !unzip -q harmc_images.zip

try:
  _ = open("/content/flavors.txt", "r")
except:
  !unzip -q parts.zip

In [None]:
# File names under ./images and their ids (file name without extension).
# Sorted because os.listdir returns entries in arbitrary order, which would
# otherwise make the later 10-way split differ between runs; splitext is
# used so extensions of any length (e.g. ".jpeg") are stripped correctly.
image_dirs = sorted(os.listdir("./images"))
image_ids = [os.path.splitext(f)[0] for f in image_dirs]

In [None]:
# Load every image into memory as RGB, keyed by its image id.
imgs = {
    image_id: Image.open(os.path.join("./images/", file_name)).convert("RGB")
    for image_id, file_name in zip(image_ids, image_dirs)
}
print("image number:", len(imgs))

## CLIP-Interrogator
By pharmapsychotic.
Open-source implementation: https://huggingface.co/spaces/pharma/CLIP-Interrogator/tree/main

In [None]:
class LabelTable():
    """Table of text labels with their normalized CLIP text embeddings.

    Embeddings are computed with the module-level ``clip_model`` in batches of
    about ``chunk_size`` labels and, when ``desc`` is given, cached on disk in
    ``./cache/{desc}.pkl`` together with a hash of the label list.
    """

    def __init__(self, labels, desc):
        """Load the embeddings from cache or compute them.

        Args:
            labels: list of label strings.
            desc: cache/progress-bar name for this table, or None to disable
                the on-disk cache.
        """
        self.labels = labels
        self.embeds = []

        # Fingerprint of the label list; a cache file is reused only when its
        # stored hash matches.
        labels_hash = hashlib.sha256(",".join(labels).encode()).hexdigest()

        os.makedirs('./cache', exist_ok=True)
        cache_filepath = f"./cache/{desc}.pkl" if desc is not None else None
        if cache_filepath is not None and os.path.exists(cache_filepath):
            with open(cache_filepath, 'rb') as f:
                # NOTE: pickle.load can execute arbitrary code; ./cache must
                # only ever contain files written by this class.
                data = pickle.load(f)
            if data['hash'] == labels_hash:
                self.labels = data['labels']
                self.embeds = data['embeds']

        if len(self.labels) != len(self.embeds):
            self.embeds = []
            # Explicit integer ceil so that no chunk exceeds chunk_size
            # (a float section count is silently truncated by array_split).
            num_chunks = max(1, int(math.ceil(len(self.labels) / chunk_size)))
            chunks = np.array_split(self.labels, num_chunks)
            for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None):
                text_tokens = clip.tokenize(chunk).to(device)
                with torch.no_grad():
                    text_features = clip_model.encode_text(text_tokens).float()
                text_features /= text_features.norm(dim=-1, keepdim=True)
                # Stored as float16 numpy rows to keep the cache small.
                text_features = text_features.half().cpu().numpy()
                self.embeds.extend(text_features[i] for i in range(text_features.shape[0]))

            # Only persist when a cache name was given; previously desc=None
            # wrote a "./cache/None.pkl" file that was never read back.
            if cache_filepath is not None:
                with open(cache_filepath, 'wb') as f:
                    pickle.dump({"labels": self.labels, "embeds": self.embeds, "hash": labels_hash}, f)

    def _rank(self, image_features, text_embeds, top_count=1):
        """Return the indices of the best-matching embeddings.

        Args:
            image_features: tensor whose rows are image embeddings on
                ``device``.
            text_embeds: sequence of numpy embedding vectors.
            top_count: number of indices wanted (capped at the number of
                embeddings).

        Returns:
            List of ``int`` positions into ``text_embeds``, best first; empty
            when there is nothing to rank.
        """
        top_count = min(top_count, len(text_embeds))
        if top_count <= 0:
            return []
        embed_count = len(text_embeds)
        text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).float().to(device)
        # Sum the softmax-normalized similarities over all image rows.
        similarity = torch.zeros((1, embed_count)).to(device)
        for i in range(image_features.shape[0]):
            similarity += (image_features[i].unsqueeze(0) @ text_embeds.T).softmax(dim=-1)
        _, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return top_labels[0].tolist()

    def rank(self, image_features, top_count=1):
        """Return the ``top_count`` labels that best match the image.

        Tables larger than ``chunk_size`` are ranked in two stages: the best
        candidates of every chunk are collected first, then those candidates
        are ranked against each other.
        """
        if len(self.labels) <= chunk_size:
            tops = self._rank(image_features, self.embeds, top_count=top_count)
            return [self.labels[i] for i in tops]

        num_chunks = int(math.ceil(len(self.labels) / chunk_size))
        # At least one survivor per chunk, so the second stage is never empty.
        keep_per_chunk = max(1, chunk_size // num_chunks)

        top_labels, top_embeds = [], []
        for chunk_idx in range(num_chunks):
            start = chunk_idx * chunk_size
            stop = min(start + chunk_size, len(self.embeds))
            tops = self._rank(image_features, self.embeds[start:stop], top_count=keep_per_chunk)
            top_labels.extend([self.labels[start + i] for i in tops])
            top_embeds.extend([self.embeds[start + i] for i in tops])

        tops = self._rank(image_features, top_embeds, top_count=top_count)
        return [top_labels[i] for i in tops]

def generate_caption(pil_image):
    """Return a BLIP caption for a PIL image.

    The image is resized to the square BLIP evaluation size, normalized,
    batched and moved to `device`; the caption is decoded with beam search
    (3 beams, 5 to 20 tokens).
    """
    side = blip_image_eval_size
    channel_mean = (0.48145466, 0.4578275, 0.40821073)
    channel_std = (0.26862954, 0.26130258, 0.27577711)
    to_model_input = T.Compose([
        T.Resize((side, side), interpolation=TF.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std)
    ])
    batch = to_model_input(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        captions = blip_model.generate(batch, sample=False, num_beams=3, max_length=20, min_length=5)
    return captions[0]

def load_list(filename):
    """Read a UTF-8 text file and return its lines, whitespace-stripped.

    Undecodable bytes are replaced rather than raising; blank lines are kept
    as empty strings.
    """
    with open(filename, 'r', encoding='utf-8', errors='replace') as handle:
        return [raw_line.strip() for raw_line in handle]

def rank_top(image_features, text_array):
    """Return the entry of `text_array` that best matches the image.

    Every text is encoded with CLIP and normalized; the softmax-normalized
    similarities are summed over the rows of `image_features` and the
    highest-scoring text is returned.
    """
    tokens = clip.tokenize(list(text_array)).to(device)
    with torch.no_grad():
        encoded = clip_model.encode_text(tokens).float()
    encoded /= encoded.norm(dim=-1, keepdim=True)

    scores = torch.zeros((1, len(text_array)), device=device)
    for row in range(image_features.shape[0]):
        logits = image_features[row].unsqueeze(0) @ encoded.T
        scores += logits.softmax(dim=-1)

    _, winner = scores.cpu().topk(1, dim=-1)
    best_index = winner[0][0].numpy()
    return text_array[best_index]

def similarity(image_features, text):
    """Return the dot product of the normalized CLIP embedding of `text`
    with the first row of `image_features`."""
    tokens = clip.tokenize([text]).to(device)
    with torch.no_grad():
        text_vec = clip_model.encode_text(tokens).float()
    text_vec /= text_vec.norm(dim=-1, keepdim=True)
    text_np = text_vec.cpu().numpy()
    image_np = image_features.cpu().numpy()
    scores = text_np @ image_np.T
    return scores[0][0]

def interrogate(image):
    """Build a descriptive text prompt for a PIL image.

    Starts from the BLIP caption, optionally appends the best-matching
    entity label, then greedily appends "flavor" labels (at most 20) for as
    long as each addition raises the CLIP similarity between prompt and image.

    Args:
        image: PIL image to describe.

    Returns:
        The highest-scoring prompt found, as a string.
    """
    caption = generate_caption(image)

    images = clip_preprocess(image).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = clip_model.encode_image(images).float()
    image_features /= image_features.norm(dim=-1, keepdim=True)

    flaves = flavors.rank(image_features, flavor_intermediate_count)
    best_entity = entities.rank(image_features, 20)[0]
    # best_celebrity = celebrities.rank(image_features, 3)[0]
    # best_event = events.rank(image_features, 3)[0]
    # best_medium = mediums.rank(image_features, 3)[0]
    # best_artist = artists.rank(image_features, 3)[0]
    # best_trending = trendings.rank(image_features, 3)[0]
    # best_movement = movements.rank(image_features, 3)[0]

    best_prompt = caption
    best_sim = similarity(image_features, best_prompt)

    def check(addition):
        """Append `addition` to the prompt if that improves the similarity."""
        nonlocal best_prompt, best_sim
        prompt = best_prompt + ", " + addition
        sim = similarity(image_features, prompt)
        if sim > best_sim:
            best_sim = sim
            best_prompt = prompt
            return True
        return False

    def check_multi_batch(opts):
        """Try every subset of `opts` appended to the prompt; keep the best
        one if it beats the current similarity."""
        nonlocal best_prompt, best_sim
        prompts = []
        for i in range(2**len(opts)):
            prompt = best_prompt
            for bit, opt in enumerate(opts):
                if i & (1 << bit):
                    prompt += ", " + opt
            prompts.append(prompt)

        prompt = rank_top(image_features, prompts)
        sim = similarity(image_features, prompt)
        if sim > best_sim:
            best_sim = sim
            best_prompt = prompt

    # check_multi_batch([best_medium, best_artist, best_trending, best_movement])
    check_multi_batch([best_entity])
    extended_flavors = set(flaves)
    for _ in range(20): # Flavor chain
        if not extended_flavors:
            # nothing left to try
            break
        try:
            best = rank_top(image_features, [f"{best_prompt}, {f}" for f in extended_flavors])
            # Strip the "<best_prompt>, " prefix to recover the flavor itself.
            flave = best[len(best_prompt)+2:]
            if not check(flave):
                break
            extended_flavors.remove(flave)
        except RuntimeError:
            # exceeded max prompt length (clip.tokenize raises RuntimeError).
            # Narrowed from a bare except so KeyboardInterrupt and genuine
            # bugs are no longer silently swallowed.
            break
    return best_prompt

def inference(image):
    """Entry point used by the batch loop: return the prompt for `image`."""
    prompt = interrogate(image)
    return prompt

# sites = ['Artstation', 'behance', 'cg society', 'cgsociety', 'deviantart', 'dribble', 'flickr', 'instagram', 'pexels', 'pinterest', 'pixabay', 'pixiv', 'polycount', 'reddit', 'shutterstock', 'tumblr', 'unsplash', 'zbrush central']
# trending_list = [site for site in sites]
# trending_list.extend(["trending on "+site for site in sites])
# trending_list.extend(["featured on "+site for site in sites])
# trending_list.extend([site+" contest winner" for site in sites])

# artists = [f"by {a}" for a in raw_artists]
# artists.extend([f"inspired by {a}" for a in raw_artists])

# Label tables read by interrogate(). Only `flavors` and `entities` are
# active; the remaining tables are left disabled.
# celebrities = LabelTable(load_list('./parts/celebrities.txt'), "celebrities")
flavors = LabelTable(load_list('./flavors.txt'), "flavors")
entities = LabelTable(load_list('./entities.txt'), "entities")
# mediums = LabelTable(load_list('mediums.txt'), "mediums")
# movements = LabelTable(load_list('movements.txt'), "movements")
# trendings = LabelTable(trending_list, "trendings")

In [None]:
# Interrogate all images in 10 parts, writing one CSV per part so that
# finished parts survive a later failure. An image whose interrogation raises
# is reported and skipped (best effort).
res_parts = np.array_split(image_ids, 10)
for i, part in enumerate(res_parts):
    print("part", i)
    res = {}
    # `image_id` rather than `id`: the loop variable lives in the notebook's
    # global namespace and must not shadow the builtin id().
    for image_id in tqdm(part):
        try:
            res_text = inference(imgs[image_id])
            # print("id:{}, text:{}".format(image_id, res_text))
            res[image_id] = res_text
        except Exception as e:
            print("id:{}, error:{}".format(image_id, e))
    # Two columns, 'id' and 'text', with a default integer index.
    res_df = pd.DataFrame({'id': list(res.keys()), 'text': list(res.values())})
    res_df.to_csv("harmc_image_interrogations_20_entity_{}.csv".format(i))



In [None]:
# Merge the 10 per-part CSV files into a single table with columns id/text.
# pd.concat replaces DataFrame.append, which was removed in pandas 2.0.
part_frames = [
    pd.read_csv("harmc_image_interrogations_20_entity_{}.csv".format(i))
    for i in range(10)
]
final = pd.concat(part_frames)
final = final.reset_index()[['id','text']]
final.to_csv("harmc_image_interrogations_20_entity_total.csv")