In [61]:
from IPython.display import clear_output
import os
import sys

def download_from_gdrive(gdrive_id, filename):
    """Download a Google Drive file to `filename` by shelling out to `wget`.

    The shell command below works in three steps:
      1. An inner `wget` requests the download page for `gdrive_id`, stores the
         session cookies in /tmp/cookies.txt, and `sed` extracts the value of
         the `confirm=` token from the returned page.
      2. The outer `wget` reuses those cookies and the extracted token to fetch
         the file, writing it to `filename` (`-O`).
      3. If the outer `wget` succeeds (`&&`), the cookie file is removed.

    Args:
        gdrive_id: Google Drive file id, placed into both download URLs.
        filename: path the downloaded content is written to.

    Nothing is returned and the command's exit status is not checked here.
    NOTE(review): both values are substituted into the command line unquoted,
    and the inner request uses `--no-check-certificate`.
    """
    !wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id='$gdrive_id -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id="$gdrive_id -O $filename && rm -rf /tmp/cookies.txt

import torch
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
print('DEVICE:', DEVICE)

DEVICE: cuda


In [62]:
# Google Drive id of the index file that lists every other downloadable asset.
GDRIVE_ID = '1S4QwcuznRxLlxkIT0Lb6vIuqDTib41B3'
FILE_IDS_NAME = 'file_ids.txt'

download_from_gdrive(GDRIVE_ID, FILE_IDS_NAME)

# Build the {file name: Google Drive id} lookup from the index file, which
# holds one tab-separated "name<TAB>id" pair per line.
FILE_IDS = {}
with open(FILE_IDS_NAME, 'r') as f:
    for line in f:
        line = line.strip()
        if not line:
            # Blank lines (e.g. a trailing empty line) carry no entry; without
            # this guard the two-value unpacking below raises ValueError.
            continue
        name, gid = line.split('\t')
        FILE_IDS[name] = gid

# Fail here, while the wget log is still visible, instead of letting a later
# FILE_IDS[...] lookup raise an unexplained KeyError.
if not FILE_IDS:
    raise RuntimeError(
        f'No entries parsed from {FILE_IDS_NAME}; '
        'the Google Drive download probably failed (see the wget output above).'
    )

clear_output()

In [63]:
print('Loading code from GitHub')
!git clone https://github.com/criticalH1T/deep_meme

Loading code from GitHub
fatal: destination path 'deep_meme' already exists and is not an empty directory.


In [64]:
# Work from the Colab content root.
os.chdir('/content')
# Put the cloned repository on the import path. The membership check stops
# repeated runs of this cell from stacking duplicate sys.path entries.
if '/content/deep_meme' not in sys.path:
    sys.path.append('/content/deep_meme')

In [65]:
# Dataset directory name; the same name (plus ".zip") is used for the archive
# downloaded in a later cell.
DATA_DIR = 'memes900k'
# Captions file inside DATA_DIR, used when vocabularies are built from scratch.
CAPTIONS_FILE = os.path.join(DATA_DIR, 'captions_train.txt')

In [66]:
print('Loading the dataset from Google Drive')
fname = f'{DATA_DIR}.zip'
download_from_gdrive(FILE_IDS[fname], fname)
!unzip -o {DATA_DIR}
clear_output()

In [67]:
from deep_meme.data.vocab import Vocab, build_vocab_from_file
from deep_meme.data.tokenizers import WordPunctTokenizer, CharTokenizer

# Colab form fields: LOAD_VOCABULARY selects between downloading prebuilt
# vocabularies and building them from CAPTIONS_FILE; MIN_DF is only passed to
# build_vocab_from_file in the "build" branch.
LOAD_VOCABULARY = True #@param {type:"boolean"}
MIN_DF = 5 #@param {type:"integer"}

# One tokenizer per granularity; both are created regardless of the branch taken.
tokenizer_words = WordPunctTokenizer()
tokenizer_chars = CharTokenizer()

if LOAD_VOCABULARY:
    print('Loading vocabularies from Google Drive')

    # Download and extract the vocabulary archive (-o overwrites without prompting).
    fname = 'vocab.zip'
    download_from_gdrive(FILE_IDS[fname], fname)
    !unzip -o {fname}

    vocab_words = Vocab.load('vocab/vocab_words.txt')
    vocab_chars = Vocab.load('vocab/vocab_chars.txt')
    # Wipes the download/unzip log; only the confirmation below stays visible.
    clear_output()

    print('Loaded vocabularies from Google Drive')
else:
    # Build both vocabularies from the training captions with the same min_df.
    print(f'Building WordPunct Vocabulary from {CAPTIONS_FILE}, min_df={MIN_DF}')
    vocab_words = build_vocab_from_file(CAPTIONS_FILE, tokenizer_words, min_df=MIN_DF)

    print(f'Building Character Vocabulary from {CAPTIONS_FILE}, min_df={MIN_DF}')
    vocab_chars = build_vocab_from_file(CAPTIONS_FILE, tokenizer_chars, min_df=MIN_DF)


# Report the size of each vocabulary, whichever branch produced it.
print('\nVocabulary sizes:')
print('WordVocab:', len(vocab_words))
print('CharVocab:', len(vocab_chars))

Loaded vocabularies from Google Drive

Vocabulary sizes:
WordVocab: 36541
CharVocab: 71


In [68]:
from deep_meme.data import MemeDataset

# use this to limit the dataset size (300 classes in total)
NUM_CLASSES = 200 #@param {type:"slider", min:1, max:300, step:1}
PAD_IDX = vocab_words.stoi['<pad>']

from torchvision import transforms

# Image preprocessing shared by every dataset: resize to 224x224, convert to a
# tensor, then normalise each channel (the values match the commonly used
# ImageNet channel statistics).
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

splits = ['train', 'val', 'test']


def _build_split_datasets(vocab, tokenizer):
    """Return {split name: MemeDataset} for every entry of `splits`.

    All datasets read from DATA_DIR, share `image_transform`, are limited to
    NUM_CLASSES classes and are created with preload_images=True; only the
    vocabulary and tokenizer differ between the word and character variants.
    """
    per_split = {}
    for split_name in splits:
        per_split[split_name] = MemeDataset(
            DATA_DIR, vocab, tokenizer,
            image_transform=image_transform,
            num_classes=NUM_CLASSES,
            split=split_name,
            preload_images=True,
        )
    return per_split


# WORD-LEVEL
datasets_words = _build_split_datasets(vocab_words, tokenizer_words)

# CHAR-LEVEL
datasets_chars = _build_split_datasets(vocab_chars, tokenizer_chars)

In [69]:
from deep_meme.models import (
    CaptioningLSTMWithLabels, 
)

def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Only parameters with `requires_grad` set are counted; frozen ones are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def load_and_build_model(gdrive_id, ckpt_path, model_class):
    """Download a checkpoint from Google Drive and build the model on DEVICE.

    Args:
        gdrive_id: Google Drive id of the checkpoint file.
        ckpt_path: local path the checkpoint is downloaded to and loaded from.
        model_class: class providing a `from_pretrained(path)` constructor.

    Returns:
        The model returned by `model_class.from_pretrained`, moved to DEVICE.
    """
    # Fetch the weights; the download log is cleared before progress is reported.
    print('Downloading model weights from Google Drive')
    download_from_gdrive(gdrive_id, ckpt_path)
    clear_output()
    print('Downloaded model weights')

    class_name = model_class.__name__
    print(f'Building {class_name} model')
    loaded = model_class.from_pretrained(ckpt_path)
    model = loaded.to(DEVICE)
    print(f'Built and loaded {class_name} model from {ckpt_path}')

    n_trainable = count_parameters(model)
    print('# parameters:', n_trainable)

    return model

# Checkpoint file name -> model class used to load it. Both the word-level and
# the character-level checkpoints map to the same class.
FILE_TO_CLASS = {
    'LSTMDecoderWithLabelsWords.best.pth': CaptioningLSTMWithLabels,
    'LSTMDecoderWithLabelsChars.best.pth': CaptioningLSTMWithLabels,
}

In [71]:
# Pick the word-level checkpoint, then look up its Google Drive id and the
# model class registered for that file name.
ckpt_path = 'LSTMDecoderWithLabelsWords.best.pth'
gdrive_id = FILE_IDS[ckpt_path]
model_class = FILE_TO_CLASS[ckpt_path]

# Download the weights and build the model on DEVICE (see load_and_build_model).
w_lstm_model_labels = load_and_build_model(gdrive_id, ckpt_path, model_class)

Downloaded model weights
Building CaptioningLSTMWithLabels model
Built and loaded CaptioningLSTMWithLabels model from LSTMDecoderWithLabelsWords.best.pth
# parameters: 45333181


In [74]:
# NOTE(review): IMG_DIR is not referenced anywhere else in the visible cells;
# presumably it names the directory the archive below extracts to - verify.
IMG_DIR = 'images_inference'

# Download and extract the inference archive (-o overwrites without prompting),
# then clear the download/unzip log from the cell output.
fname = 'inference.zip'
download_from_gdrive(FILE_IDS[fname], fname)
!unzip -o {fname}
clear_output()

In [75]:
from PIL import Image
from deep_meme.experiments import text_to_seq, seq_to_text, split_caption
from deep_meme.imaging import memeify_image
FONT_PATH = 'deep_meme/fonts/impact.ttf'

def get_a_meme(model, img_torch, img_pil, caption, T=1., beam_size=7, top_k=50,
               labels=None, mode='word', device=DEVICE):
    """Generate a caption with `model` and render it onto `img_pil`.

    Args:
        model: captioning model exposing `eval()` and `generate(...)`.
        img_torch: image input forwarded unchanged to `model.generate`.
        img_pil: image passed to `memeify_image` for rendering.
        caption: optional text to start generation from, or None.
        T: forwarded to `generate` as `temperature`.
        beam_size: forwarded to `generate`.
        top_k: forwarded to `generate`.
        labels: optional value forwarded as `label`; when None the `label`
            argument is not passed to `generate` at all.
        mode: 'word' selects the word-level vocabulary/datasets; any other
            value selects the character-level ones.
        device: device the caption tensor is moved to.

    Returns:
        The result of `memeify_image` for the generated top/bottom text.
    """
    # Per-mode settings: vocabulary, datasets, token delimiter and max length.
    if mode == 'word':
        vocabulary, datasets, delimiter, max_len = vocab_words, datasets_words, ' ', 32
    else:
        vocabulary, datasets, delimiter, max_len = vocab_chars, datasets_chars, '', 128

    model.eval()

    caption_tensor = None
    if caption is not None:
        # The last element of the preprocessed sequence is dropped - presumably
        # the end-of-sequence token so generation can continue; TODO confirm.
        token_ids = datasets['train']._preprocess_text(caption)[:-1]
        caption_tensor = torch.tensor(token_ids).unsqueeze(0).to(device)

    # Assemble the generate() arguments once; `label` is only included when
    # the caller supplied labels.
    generate_kwargs = {'image': img_torch}
    if labels is not None:
        generate_kwargs['label'] = labels
    generate_kwargs.update(
        caption=caption_tensor,
        max_len=max_len,
        beam_size=beam_size,
        temperature=T,
        top_k=top_k,
    )

    with torch.no_grad():
        output_seq = model.generate(**generate_kwargs)

    text = seq_to_text(output_seq, vocab=vocabulary, delimiter=delimiter)

    # Split the generated text into two blocks, used as top and bottom captions.
    top, bottom = split_caption(text, num_blocks=2)

    return memeify_image(img_pil, top, bottom, font_path=FONT_PATH)

**LSTM MODEL**

In [None]:
!pip install ColabCode
!pip install FastAPI

In [77]:
from colabcode import ColabCode
from fastapi import FastAPI
from fastapi import Response
import requests
import json

In [78]:
# Create the ColabCode server wrapper on port 12000.
# NOTE(review): `code=False` presumably disables the bundled code-server so that
# only the API app is served - confirm against the colabcode documentation.
cc = ColabCode(port=12000, code=False)

In [79]:
def get_result(image_id = "Bad-Luck-Brian"):
    """Generate a meme for one template using the word-level LSTM model.

    Args:
        image_id: hyphen-separated template name (e.g. "Bad-Luck-Brian").
            Hyphens are replaced by spaces to form the key looked up in the
            training set's `images` and `templates` mappings.

    Returns:
        Whatever `get_a_meme` returns for this template (the rendered meme).

    Raises:
        Whatever the `images` / `templates` lookups raise for an unknown name
        (a KeyError if they are plain dicts - not visible from here).
    """
    # "Bad-Luck-Brian" -> "Bad Luck Brian"
    label = image_id.replace('-', ' ')

    train_set = datasets_words['train']
    # Add a batch dimension and move the image to DEVICE, the same device the
    # model is loaded onto, rather than hard-coding CUDA. This keeps the CPU
    # fallback working.
    img_torch = train_set.images[label].unsqueeze(0).to(DEVICE)
    img_pil = Image.open(train_set.templates[label])

    # No caption prefix: generation starts without user-provided text.
    caption = None

    # NOTE(review): generation runs with labels=None, exactly as before. The
    # label tensor that used to be built here was never passed on, so it has
    # been removed. Confirm whether label conditioning was intended.
    return get_a_meme(
        model=w_lstm_model_labels,
        T=1.3,
        beam_size=10,
        top_k=100,
        img_torch=img_torch,
        img_pil=img_pil,
        caption=caption,
        labels=None,
        mode='word',
        device=DEVICE,
    )

**MODEL DEPLOYMENT**

In [80]:
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse

import io

app = FastAPI()

# Allow cross-origin calls from a page served on http://localhost, with
# credentials and any method or header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.post("/")
async def read_root():
    """Answer POST / with a fixed string."""
    return "Hi"

@app.post("/generate")
async def get_image(image_id, response: Response):
    """Generate a meme for `image_id` and return it as a JPEG response.

    Args:
        image_id: template identifier forwarded to `get_result`.
        response: injected response object; kept so the handler signature is
            unchanged, but no longer used.

    Returns:
        A `Response` whose body is the JPEG-encoded meme (media type image/jpeg).
    """
    print(image_id)
    result = get_result(image_id)

    # Encode into an in-memory buffer instead of the shared "out.jpg" file.
    # The file was read only after the handler returned, so an overlapping
    # request could overwrite it while an earlier response was still being sent.
    buffer = io.BytesIO()
    result.save(buffer, "JPEG", quality=100, optimize=True)

    # media_type already sets the Content-Type header, so no header needs to be
    # set on the injected `response`.
    return Response(content=buffer.getvalue(), media_type="image/jpeg")

In [None]:
# Serve the FastAPI `app` through the ColabCode instance created earlier.
# NOTE(review): presumably this call blocks while the server is running - confirm.
cc.run_app(app=app)