# Caption Transformer (CATR)
This notebook uses the CATR model developed by GitHub user saahiluppal to create captions for provided images.

In [1]:
# `predict_v2` provides the `predict(...)` function used by the captioning cells;
# `Config` is the configuration class from the `catr` package.
from catr import predict_v2
from catr.configuration import Config
# Single configuration object, passed as `config=` to every predict_v2.predict call.
config = Config()

To use data.metrics please install scikit-learn. See https://scikit-learn.org/stable/index.html


In [2]:
# Load captioning model
# torch.hub resolves the 'v3' entry point of the saahiluppal/catr repository
# and requests its pretrained weights.
import torch
model = torch.hub.load('saahiluppal/catr', 'v3', pretrained=True)

# Load tokenizer
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create start token
# Read the id of the [CLS] token through the tokenizer's public `cls_token_id`
# attribute rather than the private `_cls_token` field, which is an
# implementation detail and not a stable API across transformers releases.
start_token = tokenizer.cls_token_id

Using cache found in C:\Users\dtylutki3/.cache\torch\hub\saahiluppal_catr_master


In [3]:
# Test captioning
# Smoke test: caption one local image and report how long the call took.
from time import time

sample_image = "./data/img/01235.png"

start = time()
caption = predict_v2.predict(
    sample_image,
    model=model,
    tokenizer=tokenizer,
    start_token=start_token,
    config=config,
)
end = time()

elapsed_seconds = end - start
print(caption)
print(elapsed_seconds)

  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


A man with a beard is holding a pair of scissors
6.3809778690338135


# Conduct Captioning (Reading Images from Google Cloud Storage)

In [7]:
# Connect to bucket
from os import environ
from google.cloud import storage

# Google's client libraries locate credentials through the
# GOOGLE_APPLICATION_CREDENTIALS environment variable, so it must be set
# before the client is created.
environ.update(GOOGLE_APPLICATION_CREDENTIALS="./credentials.json")

# Open a client for the project, then fetch the bucket that holds the images.
gcp_project = 'deep-learning-project-347210'
client = storage.Client(project=gcp_project)
bucket = client.get_bucket('hateful_memes')

In [26]:
# Get all image file names
# Keep only the final path component of each blob name. An entry whose name
# ends with '/' (e.g. a folder placeholder object for the prefix itself) has
# an empty final component and is filtered out explicitly. This replaces
# dropping the first listing entry by position, which would silently discard
# a real image whenever no placeholder object exists.
images = [
    file_name
    for file_name in (
        blob.name.split('/')[-1]
        for blob in bucket.list_blobs(prefix='hateful_memes/img/')
    )
    if file_name
]
print(images)

['01235.png', '01236.png', '01243.png', '01245.png', '01247.png', '01256.png', '01258.png', '01264.png', '01268.png', '01269.png', '01274.png', '01275.png', '01276.png', '01278.png', '01284.png', '01285.png', '01293.png', '01295.png', '01324.png', '01325.png', '01327.png', '01329.png', '01348.png', '01349.png', '01358.png', '01359.png', '01364.png', '01368.png', '01379.png', '01382.png', '01389.png', '01392.png', '01395.png', '01397.png', '01423.png', '01425.png', '01436.png', '01439.png', '01452.png', '01456.png', '01459.png', '01465.png', '01467.png', '01468.png', '01469.png', '01472.png', '01475.png', '01476.png', '01483.png', '01487.png', '01492.png', '01497.png', '01498.png', '01524.png', '01526.png', '01527.png', '01529.png', '01546.png', '01547.png', '01548.png', '01562.png', '01564.png', '01567.png', '01568.png', '01569.png', '01576.png', '01578.png', '01579.png', '01589.png', '01594.png', '01598.png', '01627.png', '01634.png', '01637.png', '01642.png', '01643.png', '01649.png'

In [None]:
# For each image, download the image from google cloud, create caption, and write caption to csv
import csv
from tqdm import tqdm
from os import remove

for image in tqdm(images):
    blob = bucket.get_blob(f'hateful_memes/img/{image}')
    # The image is downloaded into the working directory under its own name.
    blob.download_to_filename(image)
    try:
        caption = predict_v2.predict(image, model=model, tokenizer=tokenizer, start_token=start_token, config=config)
    finally:
        # Delete the temporary local copy even when captioning fails, so a
        # crash part-way through the run does not leave stray images behind.
        remove(image)
    # Append (and close) per image so every finished caption is already on
    # disk if this long-running loop is interrupted.
    with open('./data/captions.csv', 'a') as f:
        # csv.writer quotes fields containing commas, quotes or newlines, so
        # such captions no longer corrupt the row. lineterminator='\n' keeps
        # the same line endings the plain f.write(...) version produced.
        csv.writer(f, lineterminator='\n').writerow([image, str(caption)])

# Conduct Captioning (Reading Images from Local Machine)

In [7]:
from os import listdir

# Folder that holds the image files (switch the active line per machine).
IMG_DIR = 'C:/Users/dtylutki3/OneDrive - Georgia Institute of Technology/data/hateful_memes/img'  # on GT VM
#IMG_DIR = 'C:/Users/danty/Data/hateful_memes/img'  # local machine

# Number of leading entries to skip.
# NOTE(review): presumably the images already captioned by an earlier run — confirm.
RESUME_FROM = 9851

# os.listdir returns entries in arbitrary order, so sort first; otherwise the
# positional offset above would not select the same files on every run.
images = sorted(listdir(IMG_DIR))[RESUME_FROM:]
print(images)

['80734.png', '80735.png', '80736.png', '80742.png', '80759.png', '80764.png', '80765.png', '80769.png', '80792.png', '80912.png', '80914.png', '80915.png', '80916.png', '80921.png', '80924.png', '80925.png', '80926.png', '80927.png', '80932.png', '80935.png', '80941.png', '80942.png', '80943.png', '80945.png', '80947.png', '80954.png', '80957.png', '80965.png', '80967.png', '80971.png', '80972.png', '80974.png', '80976.png', '81026.png', '81027.png', '81035.png', '81036.png', '81043.png', '81047.png', '81054.png', '81056.png', '81059.png', '81063.png', '81064.png', '81067.png', '81069.png', '81075.png', '81079.png', '81092.png', '81093.png', '81094.png', '81095.png', '81096.png', '81097.png', '81203.png', '81205.png', '81206.png', '81207.png', '81239.png', '81243.png', '81245.png', '81249.png', '81250.png', '81254.png', '81256.png', '81257.png', '81259.png', '81260.png', '81263.png', '81265.png', '81273.png', '81276.png', '81279.png', '81293.png', '81294.png', '81295.png', '81296.png'

In [8]:
# For each image, open the image from local folder, create caption, and write caption to csv
import csv
from tqdm import tqdm
from os import remove  # not used by this cell; kept so the notebook namespace is unchanged

# Folder that holds the image files (switch the active line per machine).
IMG_DIR = 'C:/Users/dtylutki3/OneDrive - Georgia Institute of Technology/data/hateful_memes/img'  # on GT VM
#IMG_DIR = 'C:/Users/danty/Data/hateful_memes/img'

for image in tqdm(images):
    img_path = f'{IMG_DIR}/{image}'
    caption = predict_v2.predict(img_path, model=model, tokenizer=tokenizer, start_token=start_token, config=config)
    # Append (and close) per image so every finished caption is already on
    # disk if this multi-hour loop is interrupted.
    with open('./data/captions.csv', 'a') as f:
        # csv.writer quotes fields containing commas, quotes or newlines, so
        # such captions no longer corrupt the row. lineterminator='\n' keeps
        # the same line endings the plain f.write(...) version produced.
        csv.writer(f, lineterminator='\n').writerow([image, str(caption)])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2289/2289 [4:34:23<00:00,  7.19s/it]
