In [None]:
import pandas as pd
import numpy as np

from PIL import Image
from IPython.display import display, HTML, clear_output
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

import sys
import os
DIR = os.getcwd()
import torch
import torchvision

from facebook_hateful_memes_detector.utils.globals import set_global, get_global
# Register project-wide settings through set_global. The two directories are
# absolute, machine-specific paths; adjust them when running elsewhere.
set_global("cache_dir", "/home/ahemf/cache2/cache")
set_global("dataloader_workers", 8)
set_global("use_autocast", True)
set_global("models_dir", "/home/ahemf/cache/")

from facebook_hateful_memes_detector.preprocessing import DefinedRotation, QuadrantCut, DefinedAffine, DefinedColorJitter, DefinedRandomPerspective, ImageAugment
from PIL import Image
from facebook_hateful_memes_detector.utils import get_image_info_fn, set_device, get_device
from torchvision import transforms
import joblib
from tqdm.auto import tqdm, trange
from joblib import Parallel, delayed
from facebook_hateful_memes_detector.preprocessing import TextImageDataset, get_datasets, get_image2torchvision_transforms, TextAugment, get_transforms_for_bbox_methods
from facebook_hateful_memes_detector.models.external.detr import get_detr_model
from facebook_hateful_memes_detector.training import *

def hash(x):
    """Return the SHA-1 digest that joblib computes for ``x``.

    NOTE: this helper shadows the built-in ``hash`` for the rest of the notebook.
    """
    digest = joblib.hashing.hash(x, hash_name='sha1')
    return digest

def print_code(func):
    """Print the source code of ``func`` with terminal syntax highlighting."""
    import inspect
    from pygments import highlight
    from pygments.formatters import TerminalFormatter
    from pygments.lexers import PythonLexer

    source_text = inspect.getsource(func)
    colored = highlight(source_text, PythonLexer(), TerminalFormatter())
    print(colored)

# Prefer the GPU, but fall back to the CPU when CUDA is unavailable so the
# notebook still starts on machines without a GPU (same choice as the later
# device-selection cell).
set_device('cuda' if torch.cuda.is_available() else 'cpu')



In [None]:
!wget https://raw.githubusercontent.com/airsplay/py-bottom-up-attention/master/demo/data/images/input.jpg


In [None]:
# Load the dataset splits with every transform argument set to None and image
# caching off; original and processed images are kept, torchvision images are not.
dataset_options = dict(
    data_dir="/home/ahemf/cache/data/",
    train_text_transform=None,
    train_image_transform=None,
    test_text_transform=None,
    test_image_transform=None,
    train_torchvision_pre_image_transform=None,
    test_torchvision_pre_image_transform=None,
    cache_images=False,
    use_images=True,
    dev=False,
    test_dev=True,
    keep_original_text=True,
    keep_original_image=True,
    keep_processed_image=True,
    keep_torchvision_image=False,
    train_mixup_config=None,
)
data = get_datasets(**dataset_options)



# Stack all four splits and shuffle the rows. sample(frac=1.0) is unseeded, so
# the order differs from run to run.
splits = [data["train"], data['dev_unseen'], data["test"], data['test_unseen']]
df = pd.concat(splits).sample(frac=1.0)

dataset = convert_dataframe_to_dataset(df, data["metadata"], True)


In [None]:
from time import sleep
# Imported explicitly so this preview cell also runs on a fresh kernel; the only
# other explicit import of this name is in the cache-warmup cell further down.
from facebook_hateful_memes_detector.preprocessing import get_transforms_for_multiview

# Show the first sample under each bbox / multiview transform, pausing 2 s
# between images so each one can be inspected.
preview_transforms = list(get_transforms_for_bbox_methods().transforms + get_transforms_for_multiview())
for transformation in preview_transforms:
    display(transformation(dataset[0]['original_image']))
    sleep(2)

# Cache Warmup

In [None]:
from facebook_hateful_memes_detector.utils import get_image_info_fn
from facebook_hateful_memes_detector.preprocessing import get_transforms_for_bbox_methods
from facebook_hateful_memes_detector.preprocessing import get_transforms_for_multiview
from time import sleep
all_transforms = list(get_transforms_for_bbox_methods().transforms + get_transforms_for_multiview())
# Build the image-info functions once and take both entries from the same result,
# instead of calling get_image_info_fn twice with identical arguments.
image_info_fns = get_image_info_fn(enable_encoder_feats=False, device=get_device())
get_img_details = image_info_fns["get_img_details"]
get_lxmert_details = image_info_fns["get_lxmert_details"]

def lxmert_faster_rcnn_fn(img):
    """Run ``get_img_details`` and ``get_lxmert_details`` on ``img``, discarding results.

    A 0.1 s pause is taken before each call and once more at the end.
    Presumably the calls are made only to populate the feature cache.
    """
    for extract in (get_img_details, get_lxmert_details):
        sleep(0.1)
        extract(img)
    sleep(0.1)

# Run the detectors on every sample under every transform. The first transform is
# evaluated four times per image (once plus three extra passes), the others once.
for elem in tqdm(iter(dataset), total=len(dataset)):
    image = elem["original_image"]
    for idx, transformation in enumerate(all_transforms):
        passes = 4 if idx == 0 else 1
        for _ in range(passes):
            lxmert_faster_rcnn_fn(transformation(image.copy()))
        
        


In [None]:
# Inspect the statistics stored under the "cache_stats" global for the two
# feature functions. Both bare expressions are displayed because
# ast_node_interactivity is set to "all" in the first cell.
cache_stats = get_global("cache_stats")
cache_stats['get_img_details']
cache_stats['get_lxmert_details']


In [None]:
# NOTE(review): `caches` is only created in the "Cache Sync" section below, so
# this timing cell fails on a fresh top-to-bottom run unless that cell ran first.
%timeit rows = caches[0]._sql('SELECT rowid, expire_time, tag, mode, filename, value from Cache limit 10').fetchall()


# Cache Sync

In [None]:
import gc
import os
import sys
from collections import defaultdict, Counter
from random import random, shuffle
from time import sleep
from typing import List, Callable
from tqdm.auto import tqdm, trange

import cv2
import numpy as np
import requests
import torch
from PIL import Image

from diskcache import Cache

# Cache directories to open ("/home/ahemf/cache2/cache" is currently disabled).
cache_dirs = [
    "/home/ahemf/cache/cache",
    "/home/ahemf/cache/cache2",
    "/home/ahemf/cache3/cache",
    "/home/ahemf/cache4/cache",
]
# Settings shared by every Cache instance.
args = {
    "eviction_policy": 'none',
    "sqlite_cache_size": 2 ** 16,
    "sqlite_mmap_size": 2 ** 28,
    "disk_min_file_size": 2 ** 18,
}
caches = [Cache(cd, **args) for cd in cache_dirs]

# Run the consistency check on each cache with fix=True.
for cache in caches:
    cache.check(fix=True)


In [None]:
# Pairwise sync: copy into every cache the keys it is missing from each other cache.
# The key set of every cache is read once; each snapshot is then updated as keys
# are copied, so an entry is never written to the same cache twice.
cache2keys = [set(cache.iterkeys()) for cache in caches]

try:
    for cn1, ck1 in enumerate(cache2keys):
        c1 = caches[cn1]
        for cn2, ck2 in enumerate(cache2keys):
            if cn1 == cn2:
                continue
            c2 = caches[cn2]
            cn1_cn2 = ck1 - ck2
            print((cn1, cn2), len(cn1_cn2))
            for citem in tqdm(cn1_cn2):
                c2[citem] = c1[citem]
            # Record what the target now holds so later pairs skip these keys.
            ck2 |= cn1_cn2
finally:
    # Close the caches even when a copy fails part-way through.
    for cc in caches:
        cc.close()
        
        



In [None]:
# Broadcast every entry of the first cache to all the other caches, overwriting
# whatever they already hold under the same key.
cache2keys = [set(cache.iterkeys()) for cache in caches]

source_cache = caches[0]
target_caches = caches[1:]
for key in tqdm(cache2keys[0]):
    value = source_cache[key]
    for target in target_caches:
        target[key] = value

for opened in caches:
    opened.close()
        
        



# Get image Captions and BBoxes

In [None]:
from facebook_hateful_memes_detector.utils import get_image_info_fn
from facebook_hateful_memes_detector.preprocessing import get_transforms_for_bbox_methods
from facebook_hateful_memes_detector.preprocessing import get_transforms_for_multiview
from time import sleep
from collections import OrderedDict

# All four splits in their original (unshuffled) order.
split_names = ("train", "dev_unseen", "test", "test_unseen")
df = pd.concat([data[name] for name in split_names])

dataset = convert_dataframe_to_dataset(df, data["metadata"], False)

# Object vocabulary, one class name per line; an "N/A" entry is prepended at index 0.
vocab_frame = pd.read_csv(
    "https://raw.githubusercontent.com/peteanderson80/bottom-up-attention/master/data/genome/1600-400-20/objects_vocab.txt",
    engine="python",
    header=None,
    sep="\t",
    names=["classes"],
)
vg_classes = np.array(["N/A"] + list(vocab_frame.classes))

# Captioning and detection functions used by the extraction loop.
captioning_fns = get_image_info_fn(enable_encoder_feats=True, enable_image_captions=True)
get_captioning_fn = captioning_fns["get_image_captions"]
detail_fns = get_image_info_fn(enable_encoder_feats=False, device=get_device())
get_img_details = detail_fns["get_img_details"]

# Collect [id, detected objects, longest caption] for every sample.
# Assumes `dataset` and `df` iterate in the same order, so each image is paired
# with its own row.
image_classes = []
for elem, (_row_index, row) in tqdm(zip(iter(dataset), df.iterrows()), total=len(dataset)):
    identifier = row["id"]
    image = elem["original_image"]
    _feats, info = get_img_details(image)
    top_class_ids = info["cls_prob"].argmax(axis=1)
    # De-duplicate the detected labels while keeping first-seen order.
    unique_labels = OrderedDict.fromkeys(vg_classes[top_class_ids])
    object_names = [label.replace(",", ' ') for label in unique_labels if label not in ["N/A", "background"]]
    candidate_captions = get_captioning_fn(image)
    # Keep the caption with the most words (the first one wins on ties).
    longest_caption = sorted(candidate_captions, key=lambda text: len(text.split()), reverse=True)[0]
    image_classes.append([identifier, " ".join(object_names[:10]), longest_caption])
    
    
    
    
    
    
    
    

In [None]:
# Spot-check one collected [id, objects, caption] record.
image_classes[5]

In [None]:
# Tabulate the collected records, one row per processed sample.
imdf = pd.DataFrame(image_classes, columns=["id","objects","caption"])

In [None]:
# Persist the table in the dataset directory (absolute, machine-specific path).
imdf.to_csv("/home/ahemf/cache/data/objects_captions.csv", index=False)

In [None]:
# Eyeball a few random rows of the result.
imdf.sample(5)

# Other

In [None]:
from facebook_hateful_memes_detector.preprocessing import HalfSwap, QuadrantCut, DefinedRotation, DefinedAffine

img = Image.open("input.jpg")
im_transform = get_transforms_for_bbox_methods()
# im_transform = transforms.RandomAffine(0, scale=(1.25, 1.25))


def _hash_transformed(picture):
    """Hash the result of applying ``im_transform`` to ``picture``."""
    return hash(im_transform(picture))


# Count how many distinct outputs the transform yields over 10000 applications
# to fresh copies of the same image.
worker_pool = Parallel(n_jobs=8, backend='threading')
hashes = worker_pool(delayed(_hash_transformed)(img.copy()) for _ in trange(10000))
hashes = set(hashes)
len(hashes)




In [None]:
# Reload the sample image and show one transformed version of it.
im = Image.open("input.jpg")
im_transform(im)

In [None]:
# Time a single application of the transform to the sample image.
%timeit im_transform(im)

In [None]:
# Functions run on every augmented image; the trailing identity lambda is a no-op entry.
fns = get_image_info_fn(enable_encoder_feats=True, enable_image_captions=False)
cache_fns = [fns["get_img_details"], fns["get_encoder_feats"],
             fns["get_lxmert_details"],
             get_detr_model(get_device(), "detr_resnet50")["detr_fn"],
             get_detr_model(get_device(), "detr_resnet50_panoptic")["detr_fn"], lambda x: x]

images = list(data["train"].img.values) + list(data["test"].img.values)

# Mixed precision is optional: fall back to plain execution when torch.cuda.amp
# cannot be imported, instead of swallowing the error and failing later with a
# NameError on `autocast`.
try:
    from torch.cuda.amp import GradScaler, autocast
    scaler = GradScaler()  # not used by this loop; kept in case another cell reads it
    use_autocast = "cuda" in str(get_device())
except ImportError:
    autocast = None
    use_autocast = False

# NOTE(review): `augs_dict` is not defined in any visible cell of this notebook;
# it must be created before this cell is run.
for image_path in tqdm(images):
    # The context manager closes the file handle once all copies have been taken.
    with Image.open(image_path) as img:
        for aug in augs_dict.values():
            for _ in range(16):
                img_copy = aug(img.copy())
                for fn in cache_fns:
                    if use_autocast:
                        with autocast():
                            fn(img_copy)
                    else:
                        fn(img_copy)
    



# We need QuadrantCut (Qcut) for robustness.
# We need DefinedAffine with translation, since in this problem we care about the presence of an object, not its position.
# We need image models trained on classification, which only care about presence, not position.


In [None]:
from transformers import AutoModelWithLMHead, AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")
texts = pd.read_csv("text.csv", header=None)[0].values


def m(x):
    """Encode one text with special tokens, without padding or truncation."""
    return tokenizer.encode_plus(x, add_special_tokens=True, pad_to_max_length=False, truncation=False)


# Token count of every text, then the upper percentiles of that distribution.
tlens = []
for text in texts:
    tlens.append(len(m(text)['input_ids']))

np.percentile(tlens, [97, 99, 99.5, 99.9, 100])


In [None]:
# Use the GPU when one is available, otherwise the CPU, and register the choice.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
set_device(device)

from facebook_hateful_memes_detector.utils import get_image_info_fn



In [None]:
# Preview RandomPerspective, applied with probability 1.0, on the sample image.
torchvision.transforms.RandomPerspective(p=1.0)(Image.open("input.jpg"))

In [None]:
# Rebuild the image-info functions with encoder features and captioning enabled.
fns = get_image_info_fn(enable_encoder_feats=True, enable_image_captions=True)


In [None]:
# Caption the first dataset image.
fns["get_image_captions"](dataset[0]['original_image'])
# fns["get_image_captions"]("../data/img/08291.png")

In [None]:
from facebook_hateful_memes_detector.utils.detectron_v1_object_detector import LXMERTFeatureExtractor, persistent_caching_fn
# Feature extractor built with do_autocast=False.
lxmert_feature_extractor = LXMERTFeatureExtractor(get_device(), do_autocast=False)
def fn(x):
    """Identity function, used only to exercise persistent_caching_fn below."""
    return x
# Wrap the identity function with persistent_caching_fn and call it once.
fn = persistent_caching_fn(fn, "random_2323", False)
fn(2)
img = Image.open("input.jpg")
# NOTE(review): autocast is imported here but not used in this cell.
from torch.cuda.amp import autocast

feats = lxmert_feature_extractor(img)  

# Inspect the output: feats[0] exposes scores and pred_boxes, and feats[1] is
# summarised by its mean.
feats[0].scores.mean()
feats[1].mean()
feats[0].pred_boxes
feats[0].scores









In [None]:
from facebook_hateful_memes_detector.utils.detectron_v1_object_detector import LXMERTFeatureExtractor
# Same extraction with the extractor's default constructor arguments, this time
# run inside an autocast() context.
lxmert_feature_extractor = LXMERTFeatureExtractor(get_device())
img = Image.open("input.jpg")
from torch.cuda.amp import autocast
with autocast():
    feats = lxmert_feature_extractor(img)  
# Inspect the same fields as the non-autocast cell, plus the number of outputs.
len(feats)
feats[0].scores.mean()
feats[1].mean()
feats[0].pred_boxes
feats[0].scores









In [None]:
# Run the "feature_extractor" entry on one meme image and peek at the outputs.
# NOTE(review): the image path is relative to the notebook's working directory.
res, info = fns["feature_extractor"](Image.open("../data/img/08291.png"))
res[:2, :8]
info["boxes"][:4]
info["objects"]


In [None]:
# Same inspection through the "get_img_details" entry, on the same image.
# NOTE(review): the image path is relative to the notebook's working directory.
res, info = fns["get_img_details"](Image.open("../data/img/08291.png"))
res[:2, :8]
info["boxes"][:4]
info["objects"]


In [None]:
# Apply the bbox-method transform to the sample image and display the result.
im = im_transform(Image.open("input.jpg"))
im

In [None]:
# Run "get_lxmert_details" on the transformed image; it returns a pair that is
# unpacked as (instances, roi_features).
instances, roi_features = fns["get_lxmert_details"](im)
instances.pred_boxes.tensor # boxes
roi_features # feats
# (feats, boxes)

In [None]:
# Encoder features for one meme image, passed here as a relative file path.
fns["get_encoder_feats"]("../data/img/08291.png")
