In [1]:
import glob
import clip
import os
from torch import nn
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as nnf
import sys
from typing import Tuple, List, Union, Optional
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm, trange
# from google.colab import files
import skimage.io as io
import PIL.Image
from IPython.display import Image
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

import requests

from clip_prefix_captioning import MLP, ClipCaptionModel, ClipCaptionPrefix, generate_beam, generate2
from nltk.sentiment import SentimentIntensityAnalyzer

In [8]:
# Show DataFrames without truncation: every column, every row, and the full
# contents of each cell.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
# None means "no limit". The former -1 sentinel is deprecated (it emits a
# FutureWarning) and is rejected with a ValueError by newer pandas releases.
pd.set_option('display.max_colwidth', None)

  pd.set_option('display.max_colwidth', -1)  # or 199


In [14]:
# Short aliases used by the rest of the notebook.
D = torch.device
CPU = torch.device('cpu')

# Toggle consulted when the notebook picks the device to run on.
is_gpu = True


def get_device(device_id: int) -> D:
    """Return the CUDA device for ``device_id``, or the CPU device without CUDA.

    The requested index is capped at the highest available CUDA device index,
    so asking for a device beyond the installed ones yields the last one.
    """
    if torch.cuda.is_available():
        highest_index = torch.cuda.device_count() - 1
        device_id = min(highest_index, device_id)
        return torch.device(f'cuda:{device_id}')
    return CPU


# Alias so call sites can read as ``CUDA(0)``.
CUDA = get_device

# The checkpoint is expected at ../pretrained_models/model_weights.pt, relative
# to the directory the notebook is run from. The folder is created if missing.
current_directory = os.getcwd()
save_path = os.path.join(os.path.dirname(current_directory), "pretrained_models")
os.makedirs(save_path, exist_ok=True)
model_path = os.path.join(save_path, 'model_weights.pt')

# Resolve the target device once and reuse it for both CLIP and the caption
# model. CUDA(0) already falls back to the CPU device when CUDA is unavailable.
device = CUDA(0) if is_gpu else CPU
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Passed to the caption model's constructor and later used to reshape the
# projected embedding into (1, prefix_length, -1).
# NOTE(review): presumably must match the value used when the checkpoint was
# trained; confirm.
prefix_length = 10

model = ClipCaptionModel(prefix_length)

# Load the weights onto the CPU first, then move the whole model to `device`.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load(model_path, map_location=CPU))
model = model.eval().to(device)

In [15]:
def generate_caption(input_img, use_clip_emb=False, display_img=False, use_beam_search=False, return_all_beam=False, return_clip_emb=False):
    """Generate a caption from an image or from a precomputed embedding.

    Args:
        input_img: If ``use_clip_emb`` is False, an image location that
            ``io.imread`` can read. If True, a precomputed embedding (tensor
            or NumPy array) that is projected directly by ``model.clip_project``.
        use_clip_emb: Treat ``input_img`` as an embedding instead of an image.
        display_img: Show the image resized to 512x512 before captioning
            (only used when an image is loaded).
        use_beam_search: Decode with ``generate_beam`` instead of ``generate2``.
        return_all_beam: With beam search, return every beam result instead of
            only the first one.
        return_clip_emb: Also return the embedding that was captioned.

    Returns:
        The generated caption (or all beam results when ``return_all_beam`` is
        set), or a ``(caption, embedding)`` tuple when ``return_clip_emb`` is
        True. The embedding is a float32 tensor on ``device``.
    """
    if use_clip_emb:
        # Bring the embedding to the model's device and dtype. Without this a
        # CPU or float64 embedding fails against a GPU or float32 model.
        # as_tensor avoids a copy when nothing has to change.
        prefix = torch.as_tensor(input_img, dtype=torch.float32, device=device)
    else:
        image = io.imread(input_img)
        pil_image = PIL.Image.fromarray(image)
        if display_img:
            display(pil_image.resize((512, 512)))
        image = preprocess(pil_image).unsqueeze(0).to(device)
        with torch.no_grad():
            prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)

    # Project the embedding into the prefix sequence consumed by the decoder.
    with torch.no_grad():
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)

    if use_beam_search:
        beams = generate_beam(model, tokenizer, embed=prefix_embed)
        generated_text_prefix = beams if return_all_beam else beams[0]
    else:
        generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)

    if return_clip_emb:
        return generated_text_prefix, prefix
    return generated_text_prefix

In [2]:
# Load the pickled Flickr table from the working directory.
# NOTE(review): read_pickle deserializes arbitrary pickle data; only load
# files from a trusted source.
flickr_pickle_path = 'flickr_data_all.pkl'
flickr_data = pd.read_pickle(flickr_pickle_path)

In [3]:
# Keep only rows that have a CLIP embedding. The explicit copy makes the
# filtered frame independent, so the assignment below does not raise
# SettingWithCopyWarning.
flickr_data = flickr_data[flickr_data['clip_emb'].notnull()].copy()
# Decode each stored buffer into a float32 array. Only this column is read,
# so map over it instead of applying row-wise.
flickr_data['clip_emb'] = flickr_data['clip_emb'].map(lambda raw: np.frombuffer(raw, dtype='float32'))

In [4]:
# Keep only rows that have a style embedding. The explicit copy makes the
# filtered frame independent, so the assignment below does not raise
# SettingWithCopyWarning.
flickr_data = flickr_data[flickr_data['style_emb'].notnull()].copy()
# Decode each stored buffer into a float32 array. Only this column is read,
# so map over it instead of applying row-wise.
flickr_data['style_emb'] = flickr_data['style_emb'].map(lambda raw: np.frombuffer(raw, dtype='float32'))

In [5]:
# Keep only rows that have a caption embedding. The explicit copy makes the
# filtered frame independent, so the assignment below does not raise
# SettingWithCopyWarning.
flickr_data = flickr_data[flickr_data['caption_emb'].notnull()].copy()
# Decode each stored buffer into a float32 array. Only this column is read,
# so map over it instead of applying row-wise.
flickr_data['caption_emb'] = flickr_data['caption_emb'].map(lambda raw: np.frombuffer(raw, dtype='float32'))

# Captioning the style embedding

In [22]:
# Row label to inspect; the following cell reads `index` as well.
index = 7

# Caption the stored style embedding directly (no image is loaded) and show
# every beam-search result.
style_embedding = torch.from_numpy(flickr_data.loc[index, 'style_emb'])
generate_caption(
    style_embedding,
    use_clip_emb=True,
    display_img=False,
    use_beam_search=True,
    return_all_beam=True,
    return_clip_emb=False,
)

['A couple of men standing next to each other on a street.',
 'A couple of men standing next to each other.',
 'A couple of men standing next to each other on a  street.',
 'A couple of men standing next to each other on a road.',
 'A couple of men standing next to each other near a motorcycle.']

In [23]:
# Display every column of the row that was just captioned.
flickr_data.loc[index, :]

Unnamed: 0     13292286774                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                              