In [None]:
import json
import os

import numpy as np
from PIL import Image
import pandas as pd

In [None]:
from tqdm.auto import tqdm

In [None]:
from operator import itemgetter
import random
from collections import defaultdict

In [None]:
#!/usr/bin/env python3
"""
taken from
https://github.com/marcel-dancak/lz-string-python/blob/master/lzstring.py
"""
import math

# lz-string "URI-safe" alphabet: the index of a character in this string is the
# integer value it encodes (looked up via get_base_value).
KEYSTRURISAFE = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-$"
# Cache mapping an alphabet string to its {character: index} lookup table.
base_reverse_dict = {}


class Object:
    """Minimal attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)


def get_base_value(alphabet, character):
    """Return the index of ``character`` in ``alphabet``.

    The {character: index} table for each alphabet is built once and cached in
    the module-level ``base_reverse_dict``; later lookups reuse it instead of
    rescanning the alphabet on every call.

    Raises:
        KeyError: if ``character`` does not occur in ``alphabet``.
    """
    lookup = base_reverse_dict.get(alphabet)
    if lookup is None:
        # Later duplicates win, matching a plain index-order assignment loop.
        lookup = {ch: i for i, ch in enumerate(alphabet)}
        base_reverse_dict[alphabet] = lookup
    return lookup[character]


def decompress(length, reset_value, get_next_value):
    """Decode an lz-string bit stream.

    Args:
        length: number of input characters available through ``get_next_value``.
        reset_value: bit mask the read position restarts at for each new input
            character; bits are read from this bit down to bit 0, so every input
            character carries ``log2(reset_value) + 1`` bits (32 -> 6 bits).
        get_next_value: callable mapping an input index to that character's
            integer value.

    Returns:
        The decompressed string; ``""`` for an empty stream or once the read
        index has moved past ``length``; ``None`` when the stream contains an
        invalid code (corrupt input).
    """
    # Bit-reader state: current input value, current bit mask, next input index.
    val = get_next_value(0)
    position = reset_value
    index = 1

    def read_bits(num_bits):
        """Read ``num_bits`` bits, assembling them least-significant bit first."""
        nonlocal val, position, index
        bits = 0
        for shift in range(num_bits):
            resb = val & position
            position >>= 1
            if position == 0:
                # Current input character exhausted: load the next one.
                position = reset_value
                val = get_next_value(index)
                index += 1
            if resb > 0:
                bits |= 1 << shift
        return bits

    # Codes 0, 1 and 2 are reserved: 8-bit literal, 16-bit literal, end of stream.
    dictionary = {0: 0, 1: 1, 2: 2}
    dict_size = 4
    num_bits = 3
    enlarge_in = 4

    first = read_bits(2)
    if first == 0:
        c = chr(read_bits(8))
    elif first == 1:
        c = chr(read_bits(16))
    elif first == 2:
        return ""
    else:
        # Only 0/1/2 are valid first codes; anything else is corrupt input.
        return None

    dictionary[3] = c
    w = c
    result = [c]
    while True:
        if index > length:
            return ""

        code = read_bits(num_bits)
        if code in (0, 1):
            # New literal character: register it, then emit it as a dictionary entry.
            dictionary[dict_size] = chr(read_bits(8 if code == 0 else 16))
            dict_size += 1
            code = dict_size - 1
            enlarge_in -= 1
        elif code == 2:
            return "".join(result)

        # The code width grows by one bit once the current width is used up.
        if enlarge_in == 0:
            enlarge_in = 1 << num_bits
            num_bits += 1

        if code in dictionary:
            entry = dictionary[code]
        elif code == dict_size:
            # LZW special case: the code refers to the entry being created now.
            entry = w + w[0]
        else:
            return None
        result.append(entry)

        # Add w + entry[0] to the dictionary.
        dictionary[dict_size] = w + entry[0]
        dict_size += 1
        enlarge_in -= 1

        w = entry
        if enlarge_in == 0:
            enlarge_in = 1 << num_bits
            num_bits += 1


def decompress_from_encoded_uri(compressed):
    """Decompress a string produced with lz-string's URI-safe alphabet.

    Mirrors the upstream edge cases: ``None`` yields ``""`` and ``""`` yields
    ``None``. Otherwise returns whatever ``decompress`` returns.
    """
    if compressed is None:
        return ""
    if compressed == "":
        return None

    # Spaces are mapped back to "+" (presumably undoing form/URL decoding of "+").
    normalized = compressed.replace(" ", "+")

    def char_value(idx):
        return get_base_value(KEYSTRURISAFE, normalized[idx])

    return decompress(len(normalized), 32, char_value)


In [None]:
from pycocotools import mask as mask_utils

In [None]:
from ego4d.research.readers import TorchAudioStreamReader, PyAvReader
# Video reader implementation used further down; PyAvReader is also imported and
# could presumably be swapped in here.
VideoReader = TorchAudioStreamReader

In [None]:
import spacy

In [None]:
# spaCy English pipeline; its part-of-speech tags are used to pick nouns out of narrations.
nlp = spacy.load("en_core_web_md")

In [None]:
from sentence_transformers import SentenceTransformer, util
# Sentence-embedding model used by txt_simm to compare narration nouns with object names.
model = SentenceTransformer('all-mpnet-base-v2')

def txt_simm(txt1, txt2):
    """Dot-product similarity of the sentence embeddings of two texts.

    ``txt1`` is encoded on its own and ``txt2`` as a one-element batch; the
    return value is the tensor produced by ``util.dot_score`` (callers squeeze
    it down to a scalar).
    """
    return util.dot_score(model.encode(txt1), model.encode([txt2]))

In [None]:
def decode_mask(mask):
    """Decode an lz-string-compressed RLE mask annotation.

    Args:
        mask: dict with "width", "height" and "encodedMask" keys, where
            "encodedMask" is the URI-safe lz-string compression of the RLE
            "counts" string.

    Returns:
        The array returned by pycocotools ``mask.decode`` for an RLE of size
        ``[height, width]``.

    Raises:
        ValueError: if "encodedMask" cannot be decompressed.
    """
    w = mask["width"]
    h = mask["height"]
    encoded_mask = mask["encodedMask"]

    decomp_string = decompress_from_encoded_uri(encoded_mask)
    # The decompressor reports empty/corrupt input as None rather than raising;
    # fail with a clear message instead of an AttributeError on `.encode()`.
    if decomp_string is None:
        raise ValueError(f"could not decompress encodedMask for a {w}x{h} mask")

    rle_obj = {
        "size": [h, w],
        "counts": decomp_string.encode(),
    }
    return mask_utils.decode(rle_obj)


def blend_mask(input_img, binary_mask, alpha=0.5):
    """Tint the masked pixels of a channels-last image.

    Every pixel where ``binary_mask > 0`` becomes
    ``alpha * pixel + (1 - alpha) * overlay``, where the overlay is 255 in
    channel 1 (green for RGB input) and 0 in every other channel. All other
    pixels are unchanged. The blend is written into a copy, so ``input_img``
    (e.g. a decoded video frame buffer) is never modified.

    Args:
        input_img: image array indexed as ``[row, col, channel]``; a 2-D array
            is returned unchanged.
        binary_mask: array indexed as ``[row, col]`` matching the image's first
            two dimensions; any value > 0 marks a masked pixel.
        alpha: weight of the original pixel in the blend.

    Returns:
        A new array with the same shape and dtype as ``input_img``, or
        ``input_img`` itself when it is 2-D.
    """
    if input_img.ndim == 2:
        return input_img

    num_channels = input_img.shape[2]
    overlay = np.zeros(num_channels, dtype=np.float64)
    overlay[1:2] = 255  # channel 1 only; no-op for single-channel images

    pos_idx = binary_mask > 0
    blend_image = input_img.copy()
    # Blended values are cast back to the image dtype (truncated for integer images).
    blend_image[pos_idx] = alpha * input_img[pos_idx] + (1 - alpha) * overlay
    return blend_image

In [None]:
# Root of the EgoExo data release; override with the EGOEXO_RELEASE_DIR environment
# variable instead of editing this hard-coded path.
RELEASE_DIR = os.environ.get("EGOEXO_RELEASE_DIR", "/checkpoint/miguelmartin/egoexo_data/dev")

egoexo_paths = {
    "takes": os.path.join(RELEASE_DIR, "takes.json"),
    "captures": os.path.join(RELEASE_DIR, "captures.json"),
    "physical_setting": os.path.join(RELEASE_DIR, "physical_setting.json"),
    "participants": os.path.join(RELEASE_DIR, "participants.json"),
    "visual_objects": os.path.join(RELEASE_DIR, "visual_objects.json"),
}

# Load each metadata file, closing the handle as soon as it has been parsed.
egoexo = {}
for k, path in egoexo_paths.items():
    with open(path) as f:
        egoexo[k] = json.load(f)

takes = egoexo["takes"]
captures = egoexo["captures"]
takes_by_uid = {x["take_uid"]: x for x in takes}

In [None]:
# Directory holding the annotation JSON files (narrations, object relations, ...).
annotation_dir = os.path.join(RELEASE_DIR, "annotations/")

In [None]:
# Narrations and object-relation (mask) annotations, keyed by take uid; the files
# are closed as soon as they have been parsed.
with open(os.path.join(annotation_dir, "narrations_latest.json")) as f:
    narrs = json.load(f)
with open(os.path.join(annotation_dir, "relations_objects_latest.json")) as f:
    relation_objs = json.load(f)

In [None]:
# Number of takes that have both narrations and relation annotations.
narr_takes = set(narrs)
relation_takes = set(relation_objs)
len(narr_takes.intersection(relation_takes))

In [None]:
# Candidate takes in a stable (sorted) order.
overlap = sorted(relation_takes & narr_takes)

In [None]:
# Pick one take that has both annotation types (use e.g. `overlap[0]` for a fixed choice).
assert overlap, "no take has both narration and relation annotations"
take_uid = random.choice(overlap)
take_uid

In [None]:
annotation = relation_objs[take_uid]

object_masks = annotation["object_masks"]
# (mask key, object name) pairs; the name is the key up to its first "_".
# NOTE(review): a name that itself contains "_" would be truncated — confirm the key format.
object_names = [(mask_key, mask_key.split("_")[0]) for mask_key in object_masks]
object_names

In [None]:
# Flatten every narration pass for this take into one list of narration dicts.
narrs_for_take = [n for narr_pass in narrs[take_uid] for n in narr_pass["narrations"]]

# Noun / proper-noun tokens of each narration, kept per narration so that every
# narration is scored against its own nouns.
nouns_per_narr = [
    [tok for tok in doc if tok.pos_ in ("NOUN", "PROPN")]
    for doc in nlp.pipe(n["text"] for n in narrs_for_take)
]
all_nouns = [tok for nouns in nouns_per_narr for tok in nouns]

# Similarity of each distinct (noun text, object name) pair, computed only once.
txt_simm_cache = {}
for tok in tqdm(all_nouns):
    for _key, name in object_names:
        simm_key = (tok.text, name)
        if simm_key not in txt_simm_cache:
            txt_simm_cache[simm_key] = txt_simm(tok.text, name)

# Per narration: for each of its nouns, a mapping object mask key -> similarity score.
narrs_with_simms = []
for n, nouns in tqdm(zip(narrs_for_take, nouns_per_narr), total=len(narrs_for_take)):
    matching_objs_per_tok = []
    for tok in nouns:
        matching_objs = {
            key: txt_simm_cache[(tok.text, name)].squeeze().cpu().item()
            for key, name in object_names
        }
        if matching_objs:
            matching_objs_per_tok.append(
                {"tok_txt": tok.text, "tok_idx": tok.i, "matches": matching_objs}
            )
    narrs_with_simms.append({"narration": n, "matches": matching_objs_per_tok})

In [None]:
# narrs_with_simms[0]

In [None]:
# TODO: add pass information
# Keep narrations where some noun reaches SIMM_THRESHOLD against at least one object.
SIMM_THRESHOLD = 0.6
narrs_with_matches = [
    entry
    for entry in narrs_with_simms
    if any(
        score >= SIMM_THRESHOLD
        for tok_match in entry["matches"]
        for score in tok_match["matches"].values()
    )
]
len(narrs_with_matches)

In [None]:
# narrs_with_matches[0]

In [None]:
def split_camera_name(camera_name):
    """Split an annotation camera name into ``(cam_id, stream_id)``.

    "<cam>" -> (cam, "0"); "<cam>_<stream>" -> (cam, stream), where stream
    "214-1" is renamed to "rgb". Names with more than one "_" raise ValueError.
    """
    parts = camera_name.split("_")
    if len(parts) == 1:
        return parts[0], "0"
    cam_id, stream_id = parts
    if stream_id == "214-1":
        stream_id = "rgb"
    return cam_id, stream_id


# NOTE(review): no live cell before this one defines `camera_name` (it is only set in the
# commented-out sampling cell further down), so on a fresh kernel it is undefined here.
# Fall back to the first camera annotated for the first object of this take; assumes
# object_masks maps object -> {camera_name: ...} as in that commented-out cell.
if "camera_name" not in globals():
    camera_name = next(iter(next(iter(object_masks.values()))))
cam_id, stream_id = split_camera_name(camera_name)
cam_id, stream_id

In [None]:
# Cameras with frame-aligned videos for this take (indexed below as [cam_id][stream_id]).
takes_by_uid[take_uid]["frame_aligned_videos"].keys()

In [None]:
# Resolve the frame-aligned video file for the chosen camera stream of this take.
rel_path = takes_by_uid[take_uid]["frame_aligned_videos"][cam_id][stream_id]["relative_path"]
video_path = os.path.join(RELEASE_DIR, "takes", takes_by_uid[take_uid]["root_dir"], rel_path)
assert os.path.exists(video_path), f"video not found: {video_path}"

# No resizing/normalisation, window of 1 frame with stride 1; gpu_idx=-1 presumably
# selects CPU decoding — confirm against the ego4d reader.
reader = VideoReader(path=video_path, resize=None, mean=None, frame_window_size=1, stride=1, gpu_idx=-1)

video_path

In [None]:
# Sample one narration that matched at least one object above the threshold.
assert narrs_with_matches, f"no narration matched an object with similarity >= {SIMM_THRESHOLD}"
narr_viz = random.choice(narrs_with_matches)
narr_viz

In [None]:
# TODO
narr_txt = narr_viz["narration"]["text"]
# Object mask keys scoring >= SIMM_THRESHOLD against some noun of the narration
# (a key is listed once per matching noun).
matched_words = [
    object_name
    for tok_match in narr_viz["matches"]
    for object_name, prob in tok_match["matches"].items()
    if prob >= SIMM_THRESHOLD
]
matched_words
# narr_viz["matches"]

In [None]:
# TODO
# object_name, object_annotations = random.sample(list(object_masks.items()), 1)[0]
# camera_name, mask_annotations = random.sample(list(object_annotations.items()), 1)[0]
# frame_number, annotation_obj = random.sample(list(mask_annotations['annotation'].items()), 1)[0]
# width, height, encodedMask = itemgetter('width', 'height', 'encodedMask')(annotation_obj)

# take_uid, object_name, camera_name, frame_number, width, height, encodedMask

In [None]:
# Overlay the selected object's mask on the corresponding video frame.
# NOTE(review): `frame_number`, `encodedMask`, `width` and `height` are only assigned in the
# commented-out sampling cell above, so this cell fails on a fresh kernel until that
# selection is restored.
frame = reader[int(frame_number)]
mask = decode_mask({"encodedMask": encodedMask, "width": width, "height": height})
# frame[0]: presumably the single frame of the 1-frame window — confirm reader output layout.
input_img = frame[0].numpy()
pil_img = Image.fromarray(blend_mask(input_img, mask, alpha=0.7))
pil_img

In [None]:
# TODO: cross check with CLIP embeddings

In [None]:
import torch
import clip
from PIL import Image

In [None]:

# Load CLIP ViT-B/32 on the GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)


In [None]:
# def clip_txt_simm(txt1, txt2):
#     t1 = clip.tokenize([txt1]).to(device)
#     t2 = clip.tokenize([txt2, "abc"]).to(device)
#     with torch.no_grad():
#         t1_features = clip_model.encode_text(t1)
#         t2_features = clip_model.encode_text(t2)
#         t1_features = t1_features / t1_features.norm(dim=1, keepdim=True)
#         t2_features = t2_features / t2_features.norm(dim=1, keepdim=True)
#         logit_scale = clip_model.logit_scale.exp()
#         logits_pt = logit_scale * t1_features @ t2_features.t()
#     probs = logits_pt.softmax(dim=-1).cpu().numpy()
#     return probs[0]

In [None]:
# clip_txt_simm("basketball", "def")

In [None]:
# clip_txt_simm("cat", "cat")

In [None]:

# Sanity check of the CLIP model loaded above on a fixed image and three captions.
# Uses clip_model / clip_preprocess: `model` is the SentenceTransformer and
# `preprocess` is never defined in this notebook.
# NOTE(review): expects a "CLIP.png" in the working directory.
with Image.open("CLIP.png") as pil_image:
    image = clip_preprocess(pil_image).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = clip_model.encode_image(image)
    text_features = clip_model.encode_text(text)

    logits_per_image, logits_per_text = clip_model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Label probs:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]