In [81]:
import os
# NOTE(review): hardcoded absolute path to one user's checkout. The relative
# "demo/static/videos/..." paths used later in this notebook are resolved against this directory.
os.chdir("/home/castrose/pycharm_project_394")

In [None]:
from typing import Any, Callable, Iterable, MutableMapping, Optional, Sequence, Union

import PIL.Image
import clip
import decord
import numpy as np
import seaborn as sns
import torch
from clip.model import CLIP
from matplotlib import pyplot as plt
from matplotlib.offsetbox import AnnotationBbox, OffsetImage
from spacy.tokens import Doc, Span


def get_video_info(path: str) -> MutableMapping[str, Any]:
    """Read sampled frames and key-frame thumbnails, with their timestamps, from a video file.

    Args:
        path: Path of a video file that `decord.VideoReader` can open.

    Returns:
        A dict with these keys:
        - "frames": one PIL image for every 10th frame, starting at frame 0.
        - "frame_times": for each sampled frame, the mean over the last axis of
          `get_frame_timestamp` (presumably the midpoint of the frame's start/end
          time in seconds -- confirm against the decord documentation).
        - "thumbnails": PIL images of the frames listed by `get_key_indices()`,
          shrunk so that they fit in a 64x64 box.
        - "thumbnail_times": same computation as "frame_times", for the thumbnails.
    """
    reader = decord.VideoReader(path)

    def read_images(indices):
        # Decode the requested frames in one batch and wrap each array as a PIL image.
        return [PIL.Image.fromarray(array) for array in reader.get_batch(indices).asnumpy()]

    sampled_indices = list(range(0, len(reader), 10))
    sampled_frames = read_images(sampled_indices)

    key_indices = reader.get_key_indices()
    small_images = []
    for key_frame in read_images(key_indices):
        # `Image.thumbnail` resizes in place, so work on a copy of the decoded image.
        small = key_frame.copy()
        small.thumbnail((64, 64))
        small_images.append(small)

    return {
        "frames": sampled_frames,
        "frame_times": reader.get_frame_timestamp(sampled_indices).mean(axis=-1),
        "thumbnails": small_images,
        "thumbnail_times": reader.get_frame_timestamp(key_indices).mean(axis=-1),
    }


def get_local_video_info(video_id: str) -> MutableMapping[str, Any]:
    """Load the frame/thumbnail information of a video stored under `demo/static/videos/`.

    The `.mp4` file is preferred; the `.webm` file is used when no `.mp4` exists.

    Args:
        video_id: File name of the video without its extension.

    Returns:
        The dict produced by `get_video_info`, with an extra "video_id" entry.

    Raises:
        FileNotFoundError: If neither `<video_id>.mp4` nor `<video_id>.webm` exists.
    """
    candidates = [f"demo/static/videos/{video_id}{extension}" for extension in (".mp4", ".webm")]
    path = next((candidate for candidate in candidates if os.path.isfile(candidate)), None)
    if path is None:
        # An explicit exception (instead of `assert`) still fires when Python runs with -O.
        raise FileNotFoundError(f"No video file found for {video_id!r}; tried: {', '.join(candidates)}")

    video_info = get_video_info(path)
    return {"video_id": video_id, **video_info}


def encode_visual(images: Iterable[PIL.Image.Image], clip_model: CLIP,
                  image_preprocessor: Callable[[PIL.Image.Image], torch.Tensor],
                  device: Optional[Any] = None) -> torch.Tensor:
    """Encode images with CLIP and scale every embedding to unit L2 norm.

    Args:
        images: The images to encode.
        clip_model: Model whose `encode_image` method is used.
        image_preprocessor: Callable that turns one PIL image into a tensor; the
            resulting tensors are stacked into a single batch.
        device: When not None, the stacked batch is moved there before encoding.

    Returns:
        The image embeddings, each divided by its norm along the last dimension.
    """
    batch = torch.stack(list(map(image_preprocessor, images)))

    if device is not None:
        batch = batch.to(device)

    with torch.inference_mode():
        features = clip_model.encode_image(batch)
        norms = features.norm(dim=-1, keepdim=True)
        return features / norms


def encode_text(text: str, clip_model: CLIP, device: Optional[Any] = None) -> torch.Tensor:
    """Encode a single text with CLIP and scale the embedding to unit L2 norm.

    Args:
        text: The text to encode.
        clip_model: Model whose `encode_text` method is used.
        device: When not None, the tokenized text is moved there before encoding.

    Returns:
        The text embedding divided by its norm along the last dimension. The
        leading dimension has size 1 because exactly one text is tokenized.
    """
    # Without `truncate=True`, `clip.tokenize` raises on texts longer than the model's context
    # window, which long transcribed sentences can exceed; over-long texts are cut instead.
    tokenized_texts = clip.tokenize([text], truncate=True)

    if device is not None:
        tokenized_texts = tokenized_texts.to(device)

    with torch.inference_mode():
        encoded_texts = clip_model.encode_text(tokenized_texts)
        return encoded_texts / encoded_texts.norm(dim=-1, keepdim=True)


def text_probs(encoded_images: torch.Tensor, encoded_texts: torch.Tensor, logit_scale: float = 100) -> np.ndarray:
    """Turn image/text similarities into a distribution over the images.

    The similarities are `encoded_images @ encoded_texts.T`, scaled by `logit_scale`
    and passed through a softmax over dimension 0 (the image dimension), so the
    values for each text sum to 1 across the images.

    Args:
        encoded_images: Image embeddings, one row per image.
        encoded_texts: Text embeddings, one row per text.
        logit_scale: Multiplier applied before the softmax (the inverse of the
            temperature). The default matches the previously hard-coded value.

    Returns:
        A NumPy array of probabilities. A trailing dimension of size 1 (a single
        text) is squeezed away, giving one value per image.
    """
    with torch.inference_mode():
        # clip_model.logit_scale.exp() == 100
        logits = logit_scale * encoded_images @ encoded_texts.T
        return logits.softmax(dim=0).squeeze(-1).cpu().numpy()


def create_figure(times: Sequence[float], probs: Sequence[float], thumbnail_times: Sequence[float],
                  thumbnails: Iterable[PIL.Image.Image], title: Union[Doc, Span, str]) -> plt.Axes:
    """Plot a per-frame score curve with video thumbnails placed under the time axis.

    Args:
        times: X coordinate (time) of every score.
        probs: The scores to plot, aligned with `times`.
        thumbnail_times: Times of the thumbnails; also used as the x tick positions
            and to pick the figure width (one unit per thumbnail).
        thumbnails: Images drawn below the axis, paired with `thumbnail_times`.
        title: Text shown as the title. For a spaCy `Doc`/`Span` whose first and
            last tokens carry `start_time`/`end_time` extension values, the interval
            between them is additionally highlighted in red.

    Returns:
        The axes the curve was drawn on.
    """
    # The figure size is passed through seaborn's rc settings.
    sns.set(rc={"figure.figsize": (1.0 * len(thumbnail_times), 1.5)})

    ax = sns.lineplot(x=times, y=probs)

    ax.set_xticks(thumbnail_times)

    is_spacy_text = isinstance(title, (Doc, Span))

    ax.set_title(title.text if is_spacy_text else title, fontsize=35, y=0.6)
    ax.set(xlabel="time", ylabel="probability")

    ax.fill_between(times, probs)

    if is_spacy_text and len(title) > 0:
        start_time = title[0]._.start_time
        end_time = title[-1]._.end_time

        # The extensions default to None (texts that were never aligned with a caption);
        # in that case there is no interval to highlight.
        if start_time is not None and end_time is not None:
            ax.axvspan(start_time, end_time, alpha=0.5, color="red")

    for time, thumbnail in zip(thumbnail_times, thumbnails):
        im = OffsetImage(thumbnail, axes=ax)
        # Anchor each thumbnail at (time, 0) and shift it 60 points down, below the axis.
        ab = AnnotationBbox(im, (time, 0), xybox=(0, -60), frameon=False, boxcoords="offset points", pad=0)
        ax.add_artist(ab)

    ax.margins(x=0, tight=True)
    ax.figure.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

    return ax


def create_figure_for_text(encoded_frames: torch.Tensor, text: Union[Doc, Span, str], clip_model: CLIP,
                           times: Sequence[float], thumbnail_times: Sequence[float],
                           thumbnails: Iterable[PIL.Image.Image]) -> plt.Axes:
    """Score every encoded frame against `text` and plot the result.

    Args:
        encoded_frames: Frame embeddings; the text is encoded on their device.
        text: The query, either a plain string or a spaCy `Doc`/`Span` (whose
            `.text` is what gets encoded).
        clip_model: Model used to encode the text.
        times: Time of each encoded frame.
        thumbnail_times: Time of each thumbnail.
        thumbnails: Thumbnails drawn under the plot.

    Returns:
        The axes returned by `create_figure`; `text` itself is passed on as the title.
    """
    if isinstance(text, (Doc, Span)):
        raw_text = text.text
    else:
        raw_text = text

    text_features = encode_text(raw_text, clip_model, device=encoded_frames.device)
    frame_probs = text_probs(encoded_frames, text_features)
    return create_figure(times, frame_probs, thumbnail_times, thumbnails, text)

In [None]:
# Apply seaborn's default theme to the plots of this notebook.
sns.set_theme()

# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# `clip.load` returns the model together with the image preprocessing callable
# that is later handed to `encode_visual`.
clip_model, image_preprocessor = clip.load("ViT-B/16", device=device)

In [None]:
# Load the frames/thumbnails of the first example video and encode its sampled frames with CLIP.
video_info = get_local_video_info("1v2PRuxoMp8")
encoded_frames = encode_visual(video_info["frames"], clip_model, image_preprocessor, device=device)

In [None]:
# Score curve over the frames for the single-word query "oil".
create_figure_for_text(encoded_frames, "oil", clip_model, video_info["frame_times"], video_info["thumbnail_times"], video_info["thumbnails"])
plt.show()

In [None]:
# Same plot for the full-sentence query "I am a human being.".
create_figure_for_text(encoded_frames, "I am a human being.", clip_model, video_info["frame_times"], video_info["thumbnail_times"], video_info["thumbnails"])
plt.show()

In [None]:
# Same plot for the short imperative query "Shake it".
create_figure_for_text(encoded_frames, "Shake it", clip_model, video_info["frame_times"], video_info["thumbnail_times"], video_info["thumbnails"])
plt.show()

## Alternatives to softmax

In [None]:
# Baseline for this section: the default scoring (softmax with scale 100, see `text_probs`)
# for the sentence that the following cells re-score in other ways.
create_figure_for_text(encoded_frames, "Shake it really well by putting your finger on top.", clip_model, video_info["frame_times"], video_info["thumbnail_times"], video_info["thumbnails"])
plt.show()

### Dot product

Note that applying a temperature to the raw dot product only rescales every value by the same factor, so it does not change the shape of the curve and is omitted here.

In [None]:
text = "Shake it really well by putting your finger on top."
# `encode_text` requires the CLIP model as its second argument.
encoded_texts = encode_text(text.text if isinstance(text, (Doc, Span)) else text, clip_model,
                            device=encoded_frames.device)

with torch.inference_mode():
    # Raw similarity between every frame and the text; no softmax is applied here.
    probs = (encoded_frames @ encoded_texts.T).squeeze(-1).cpu().numpy()

create_figure(video_info["frame_times"], probs, video_info["thumbnail_times"], video_info["thumbnails"], text)
plt.show()

### Softmax with different temperatures

In [None]:
text = "Shake it really well by putting your finger on top."
# `encode_text` requires the CLIP model as its second argument.
encoded_texts = encode_text(text.text if isinstance(text, (Doc, Span)) else text, clip_model,
                            device=encoded_frames.device)

with torch.inference_mode():
    # Softmax over the frames with scale 1 (instead of the default 100).
    probs = (1 * encoded_frames @ encoded_texts.T).softmax(dim=0).squeeze(-1).cpu().numpy()

create_figure(video_info["frame_times"], probs, video_info["thumbnail_times"], video_info["thumbnails"], text)
plt.show()

In [None]:
text = "Shake it really well by putting your finger on top."
# `encode_text` requires the CLIP model as its second argument.
encoded_texts = encode_text(text.text if isinstance(text, (Doc, Span)) else text, clip_model,
                            device=encoded_frames.device)

with torch.inference_mode():
    # Softmax over the frames with scale 10.
    probs = (10 * encoded_frames @ encoded_texts.T).softmax(dim=0).squeeze(-1).cpu().numpy()

create_figure(video_info["frame_times"], probs, video_info["thumbnail_times"], video_info["thumbnails"], text)
plt.show()

In [None]:
text = "Shake it really well by putting your finger on top."
# `encode_text` requires the CLIP model as its second argument.
encoded_texts = encode_text(text.text if isinstance(text, (Doc, Span)) else text, clip_model,
                            device=encoded_frames.device)

with torch.inference_mode():
    # Softmax over the frames with scale 50.
    probs = (50 * encoded_frames @ encoded_texts.T).softmax(dim=0).squeeze(-1).cpu().numpy()

create_figure(video_info["frame_times"], probs, video_info["thumbnail_times"], video_info["thumbnails"], text)
plt.show()

### Exponential w/o normalization, temperature 0.1

In [None]:
text = "Shake it really well by putting your finger on top."
# `encode_text` requires the CLIP model as its second argument.
encoded_texts = encode_text(text.text if isinstance(text, (Doc, Span)) else text, clip_model,
                            device=encoded_frames.device)

with torch.inference_mode():
    # Exponential of the scaled similarities (scale 10, i.e. temperature 0.1) without
    # dividing by their sum, so the values are not normalized across frames.
    probs = (10 * encoded_frames @ encoded_texts.T).exp().squeeze(-1).cpu().numpy()

create_figure(video_info["frame_times"], probs, video_info["thumbnail_times"], video_info["thumbnails"], text)
plt.show()

## Preparing to visualize multiple captions

In [None]:
import json
import os
import re
from typing import Mapping

import spacy
import spacy_alignments
from spacy.tokens import Token


# Matches runs of two or more spaces; used to collapse them when joining transcripts.
RE_MULTIPLE_SPACES = re.compile(r" {2,}")

# Caption JSON files are read from "$SCRATCH_DIR/captions"; the SCRATCH_DIR environment
# variable must be set, otherwise this line raises KeyError.
CAPTIONS_DIR = os.path.join(os.environ["SCRATCH_DIR"], "captions")

# Ask spaCy to use the GPU if possible, then load the transformer-based English pipeline.
spacy.prefer_gpu()
nlp = spacy.load("en_core_web_trf")


def _captions_to_text(caption_full_dict: Mapping[str, Any]) -> str:
    """Join the transcripts of a caption file into one space-normalized string.

    The first alternative's transcript of every result except the last one is
    stripped and joined with single spaces; runs of several spaces are then
    collapsed and the final string is stripped.
    """
    transcripts = (result["alternatives"][0]["transcript"].strip()
                   for result in caption_full_dict["results"][:-1])
    joined = " ".join(transcripts)
    return RE_MULTIPLE_SPACES.sub(" ", joined).strip()


def _parse_caption_time(s: str) -> float:
    return float(s[:-1])


def _load_caption(path: str) -> Optional[Mapping[str, Any]]:
    with open(path) as file:
        caption_full_dict = json.load(file)

        if results := caption_full_dict["results"]:
            tokens_info = results[-1]["alternatives"][0]["words"]
        else:
            tokens_info = None

        if tokens_info:
            return {  # Save some memory by just keeping what we actually use.
                "text": _captions_to_text(caption_full_dict),
                "video_id": os.path.basename(path).rsplit(".", maxsplit=1)[0],
                "tokens_info": [{
                    "word": wi["word"],
                    "start_time": _parse_caption_time(wi["startTime"]),
                    "end_time": _parse_caption_time(wi["endTime"]),
                } for wi in tokens_info],
            }
        else:
            return None  # There are around 750/150k that fall here for different reasons.

        
def _add_caption_info_to_doc(doc: Doc, tokens_info: Sequence[Mapping[str, Any]]) -> Doc:
    """Copy caption word timings onto the tokens of `doc`.

    The spaCy tokens are aligned with the caption words via `spacy_alignments`.
    Every token gets the start time of the first caption word it is aligned to
    and the end time of the last one, stored in the `start_time`/`end_time`
    token extensions.

    NOTE(review): a token with an empty alignment would raise IndexError here --
    confirm whether that can happen for these captions.

    Returns:
        The same `doc`, modified in place.
    """
    spacy_words = [token.text for token in doc]
    caption_words = [word_info["word"] for word_info in tokens_info]
    spacy2caption = spacy_alignments.get_alignments(spacy_words, caption_words)[0]

    for token, aligned_indices in zip(doc, spacy2caption):
        first_word = tokens_info[aligned_indices[0]]
        token._.start_time = first_word["start_time"]
        last_word = tokens_info[aligned_indices[-1]]
        token._.end_time = last_word["end_time"]

    return doc

In [None]:
# Custom token attributes that hold the caption timing of each token (None until aligned).
# `force=True` lets this cell be re-run: without it spaCy raises because the extension already exists.
Token.set_extension("start_time", default=None, force=True)
Token.set_extension("end_time", default=None, force=True)

In [None]:
def caption_to_doc(video_id: str) -> Doc:
    """Parse the caption of a video with spaCy and attach word timings to its tokens.

    Args:
        video_id: ID of the video; the caption is read from
            `<CAPTIONS_DIR>/<video_id>.json`.

    Returns:
        The parsed `Doc`, whose tokens carry `start_time`/`end_time` extensions.

    Raises:
        ValueError: If the caption file contains no usable word-level information
            (`_load_caption` returned None).
    """
    caption_path = os.path.join(CAPTIONS_DIR, f"{video_id}.json")
    caption = _load_caption(caption_path)
    if caption is None:
        # Fail with a clear message instead of "'NoneType' object is not subscriptable".
        raise ValueError(f"Caption file {caption_path!r} has no usable word-level information")

    doc = nlp(caption["text"])
    return _add_caption_info_to_doc(doc, caption["tokens_info"])

In [None]:
# Parsed, time-aligned caption of the video loaded above.
doc = caption_to_doc(video_info["video_id"])

In [None]:
from typing import Iterator


def get_sents(doc: Doc) -> Iterator[Union[Span, str]]:
    """Return the sentences of `doc` (its `sents` attribute) unchanged."""
    return doc.sents


def get_noun_chunks(doc: Doc) -> Iterator[Union[Span, str]]:
    """Yield one prompt string of the form "A photo of <chunk>." per noun chunk of `doc`."""
    yield from (f"A photo of {noun_chunk}." for noun_chunk in doc.noun_chunks)


def get_verb_phrases(doc: Doc) -> Iterator[Union[Span, str]]:
    """Yield, for every token whose coarse part of speech is VERB, the span covering its subtree.

    The span runs from the first to the last token of the verb's syntactic
    subtree, inclusive.
    """
    verbs = (token for token in doc if token.pos_ == "VERB")
    for verb in verbs:
        members = list(verb.subtree)
        first, last = members[0], members[-1]
        yield doc[first.i:last.i + 1]


def get_orders(doc: Doc) -> Iterator[Union[Span, str]]:
    """Yield subtree spans of verbs that look like imperatives ("orders").

    Sentences whose last token is "?" are skipped. Within the other sentences a
    token qualifies when it is tagged `VB`, is not one of "know"/"let"/"try"
    (lower-cased), has no child with dependency `aux`, and is not itself an
    `auxpass` or `xcomp` dependent. For each such token the span from the first
    to the last token of its subtree is yielded.
    """
    excluded_verbs = {"know", "let", "try"}
    excluded_deps = {"auxpass", "xcomp"}

    def looks_like_order(token) -> bool:
        if token.tag_ != "VB":
            return False
        if token.lower_ in excluded_verbs:
            return False
        if any(child.dep_ == "aux" for child in token.children):
            return False
        return token.dep_ not in excluded_deps

    for sent in doc.sents:
        if sent[-1].text == "?":
            continue
        for token in filter(looks_like_order, sent):
            members = list(token.subtree)
            yield doc[members[0].i:members[-1].i + 1]

In [104]:
from matplotlib.backends.backend_pdf import PdfPages


def show_caption_figures_and_pdf(video_id: str, doc: Doc, encoded_frames: torch.Tensor, clip_model: CLIP,
                                 times: Sequence[float], thumbnail_times: Sequence[float],
                                 thumbnails: Iterable[PIL.Image.Image], text_mode: str = "sents") -> None:
    """Plot one figure per text extracted from `doc`, showing each and saving all to a PDF.

    Args:
        video_id: Used for the output file name, `<video_id>.pdf`.
        doc: Parsed caption from which the texts are extracted.
        encoded_frames: Frame embeddings scored against every text.
        clip_model: Model used to encode the texts.
        times: Time of each encoded frame.
        thumbnail_times: Time of each thumbnail.
        thumbnails: Thumbnails drawn under every plot.
        text_mode: Which texts to extract: "sents", "nouns", "verb_phrases" or "orders".

    Raises:
        KeyError: If `text_mode` is not one of the supported values.
    """
    # Imported here because the notebook only imports tqdm in a later cell; relying on
    # that import makes this function raise NameError on a top-to-bottom run.
    from tqdm.auto import tqdm

    text_getters = {
        "sents": get_sents,
        "nouns": get_noun_chunks,
        "verb_phrases": get_verb_phrases,
        "orders": get_orders,
    }
    # Materialize the texts so that tqdm knows the total.
    texts = list(text_getters[text_mode](doc))

    with PdfPages(f"{video_id}.pdf") as pdf_pages:
        for text in tqdm(texts):
            create_figure_for_text(encoded_frames, text, clip_model, times, thumbnail_times, thumbnails)
            # Save the current figure as a PDF page before showing it.
            pdf_pages.savefig(bbox_inches="tight")
            plt.show()

In [None]:
# Plot the first sentence of the caption. Because a Span is passed (not a str),
# create_figure also highlights the sentence's start/end time interval.
create_figure_for_text(encoded_frames, next(iter(doc.sents)), clip_model, video_info["frame_times"], video_info["thumbnail_times"], video_info["thumbnails"])
plt.show()

## Changing the text

In [None]:
# Parse a long caption excerpt with the spaCy pipeline so that its structure can be inspected below.
doc_test = nlp("It penetrates all seven layers of skin and takes other nutrients deeper into hair follicles because it closely resembles human, skin composition emu, oil blocks DHT dihydrotestosterone, a male hormone, which loves to shrink hair, follicles, 90% of cases of male pattern, baldness occur due to the effect of DHT on hair follicles.")

In [None]:
# Tokens whose dependency label is "cc" (coordinating conjunction in spaCy's English label scheme).
[t for t in doc_test if t.dep_ == "cc"]

In [None]:
# First sentence of the test document, inspected in the following cells.
sent = list(doc_test.sents)[0]

In [None]:
# Text, dependency label, coarse POS and fine-grained tag of the sentence root.
(sent.root.text, sent.root.dep_, sent.root.pos_, sent.root.tag_)

In [None]:
# Text of the subtree under each direct child of the root.
[" ".join(t.text for t in c.subtree) for c in sent.root.children]

In [None]:
# Text, dependency label, coarse POS and fine-grained tag of each direct child of the root.
[(c.text, c.dep_, c.pos_, c.tag_) for c in sent.root.children]

In [None]:
# Text of the subtree under every token of the sentence whose coarse POS is VERB.
[" ".join(t.text for t in c.subtree) for c in sent if c.pos_ == "VERB"]

In [None]:
from spacy import displacy

# Visualize the parse of the test document with displaCy (default options).
displacy.render(doc_test)

In [None]:
# Second test excerpt, used below to inspect noun chunks.
doc_test2 = nlp("Almond oil has oleic acid omega-9 68% and vitamin K oleic acid opens pores and hair follicles to receive nutrients.")

In [None]:
# Noun chunks that spaCy finds in the second excerpt.
list(doc_test2.noun_chunks)

In [None]:
# Parse of a plain imperative, rendered with the `fine_grained` option to show the fine-grained tags.
displacy.render(nlp("Shake it really well."), options={"fine_grained": True})

In [None]:
# Parse of a non-imperative sentence with two coordinated verbs, for comparison.
displacy.render(nlp("They wake up and eat breakfast."), options={"fine_grained": True})

In [None]:
# Parse of a "have to <verb>" construction.
displacy.render(nlp("They have to do it."), options={"fine_grained": True})

In [None]:
# Same construction with the auxiliary "will".
displacy.render(nlp("You will have to do it."), options={"fine_grained": True})

In [None]:
# Same construction with the contracted auxiliary "'ll".
displacy.render(nlp("You'll have to do it."), options={"fine_grained": True})

In [None]:
# First attempt at finding imperatives: caption sentences whose root token is tagged VB.
[sent for sent in doc.sents if sent.root.tag_ == "VB"]

In [None]:
# Parse of a sentence where one verb follows another ("helps regulate").
displacy.render(nlp("Vitamin helps regulate good."), options={"fine_grained": True})

In [None]:
# Sentences containing a VB token that has no "to"/"will"/"'ll" child and is not an xcomp dependent.
# Note: spaCy's token for the contraction is "'ll" (no trailing apostrophe).
[sent
 for sent in doc.sents
 if any(t.tag_ == "VB"
        and all(c.lower_ not in {"to", "will", "'ll"} for c in t.children)
        and t.dep_ != "xcomp"
        for t in sent)]

In [None]:
# Parse of a sentence containing a modal plus passive ("must be mixed").
displacy.render(nlp("It is highly concentrated so it can clog pores and must be mixed with other oils to be beneficial."), options={"fine_grained": True})

In [None]:
# Same filter as before, additionally excluding VB tokens that are auxpass dependents.
# Note: spaCy's token for the contraction is "'ll" (no trailing apostrophe).
[sent
 for sent in doc.sents
 if any(t.tag_ == "VB"
        and all(c.lower_ not in {"to", "will", "'ll"} for c in t.children)
        and t.dep_ not in {"auxpass", "xcomp"}
        for t in sent)]

In [None]:
# More radical: instead of listing specific words, reject a VB token when any of its
# children has the dependency label "aux".
[sent
 for sent in doc.sents
 if any(t.tag_ == "VB"
        and all(c.dep_ != "aux" for c in t.children)
        and t.dep_ not in {"auxpass", "xcomp"}
        for t in sent)]

In [None]:
# Parse of a (badly transcribed) sentence that ends with a question mark.
displacy.render(nlp("Is it makes skin on scalp grow thicker and stronger which holds the hair tightly in place?"), options={"fine_grained": True})

In [82]:
# Last variant explored here: the previous filter, additionally skipping the verb "try"
# and any sentence whose last token is "?".
[sent
 for sent in doc.sents
 if any(t.tag_ == "VB"
        and t.lower_ != "try"
        and all(c.dep_ != "aux" for c in t.children)
        and t.dep_ not in {"auxpass", "xcomp"}
        for t in sent)
 and sent[-1].text != "?"]

[Add Brewster Home Fashions, we design wallpapers that are easy to live with.,
 Thanks to paste the wall technology, it is surprisingly easy to install and remove modern wallpaper in just a matter of hours.,
 Paste paste brush or roller knife or snap off blade level sponge and bucket of Clean.,
 Water, smoothing brush or plastic smoother tape measure 4 inch to 6 inch.,
 And or straight edge a pencil step stool and finally turn off your cell phone and turn on the music.,
 First, you need to make sure the walls are clean of all debris. And that the surface is smooth, spackle and smooth out any holes or rough areas on the walls as they may affect the final finish of your wallpaper.,
 So it is important to take your time and do this first step correctly to begin.,
 You want to start by marking a guideline on the wall for the proper placement of your first strip select, your starting point and measure over one inch less than the width of your wall paper roll.,
 Mark the Spot with pencil in 

## Visualizing Multiple Captions

In [None]:
# One figure per caption sentence. `clip_model` is the fourth positional argument of
# show_caption_figures_and_pdf and must be passed before the times.
show_caption_figures_and_pdf(video_info["video_id"], doc, encoded_frames, clip_model, video_info["frame_times"], video_info["thumbnail_times"], video_info["thumbnails"])

Just some of them, namely the imperative sentences (`text_mode="orders"`):

In [None]:
# Same as above, restricted to the spans found by get_orders. `clip_model` is the fourth
# positional argument and must be passed before the times.
show_caption_figures_and_pdf(video_info["video_id"], doc, encoded_frames, clip_model, video_info["frame_times"], video_info["thumbnail_times"], video_info["thumbnails"], text_mode="orders")

## Another Video

In [None]:
# Load and encode a second video, kept in separate variables so that the first one stays available.
video2_info = get_local_video_info("2xVpyPnxg9c")
encoded_frames2 = encode_visual(video2_info["frames"], clip_model, image_preprocessor, device=device)

# Parsed, time-aligned caption of the second video.
doc2 = caption_to_doc(video2_info["video_id"])

In [None]:
# One figure per caption sentence of the second video. `clip_model` is the fourth
# positional argument and must be passed before the times.
show_caption_figures_and_pdf(video2_info["video_id"], doc2, encoded_frames2, clip_model, video2_info["frame_times"], video2_info["thumbnail_times"], video2_info["thumbnails"])

In [None]:
# Only the spans found by get_orders for the second video. `clip_model` is the fourth
# positional argument and must be passed before the times.
show_caption_figures_and_pdf(video2_info["video_id"], doc2, encoded_frames2, clip_model, video2_info["frame_times"], video2_info["thumbnail_times"], video2_info["thumbnails"], text_mode="orders")

## 1st vs 3rd person in the Text

In [None]:
# First-person phrasing of the query.
create_figure_for_text(encoded_frames, "I'm pouring this liquid into the container.", clip_model, video_info["frame_times"], video_info["thumbnail_times"], video_info["thumbnails"])
plt.show()

In [None]:
# Third-person phrasing of a similar query, to compare with the previous plot.
create_figure_for_text(encoded_frames, "He's pouring the liquid into a container.", clip_model, video_info["frame_times"], video_info["thumbnail_times"], video_info["thumbnails"])
plt.show()

In [None]:
import random
import os


# Collect the IDs (file names without extension) of all .mp4/.webm files in the videos directory.
video_ids = []
with os.scandir("demo/static/videos/") as entries:
    for entry in entries:
        if entry.is_file() and entry.name.endswith((".mp4", ".webm")):
            video_ids.append(entry.name.rsplit(".", maxsplit=1)[0])

# random.sample raises ValueError when asked for more items than are available,
# so cap the sample size at the number of videos found.
selected_video_ids = random.sample(video_ids, min(100, len(video_ids)))

In [None]:
from tqdm.auto import tqdm


# Index-aligned lists: entry i of each list belongs to the same video.
video_infos = []
encoded_frames_list = []
doc_list = []

for video_id in tqdm(selected_video_ids):
    # Loop-local names carry a `sampled_` prefix so that this cell does not overwrite the
    # notebook-level `video_info`, `encoded_frames` and `doc` that earlier cells read.
    try:
        sampled_video_info = get_local_video_info(video_id)
        # Load the caption before running CLIP so that a video without a usable caption
        # is skipped (like an unreadable video) instead of aborting the whole loop.
        sampled_doc = caption_to_doc(sampled_video_info["video_id"])
    except Exception as e:
        print(e)
        continue

    sampled_encoded_frames = encode_visual(sampled_video_info["frames"], clip_model, image_preprocessor,
                                           device=device)

    video_infos.append(sampled_video_info)
    encoded_frames_list.append(sampled_encoded_frames)
    doc_list.append(sampled_doc)

In [None]:
# Positions into video_infos / encoded_frames_list / doc_list (the three lists are index-aligned).
indices = list(range(len(video_infos)))

# Draw 200 positions with replacement, so the same video can be picked more than once.
selected = random.choices(indices, k=200)

In [None]:
# One random caption sentence per selected video, shown inline and saved to random.pdf.
with PdfPages("random.pdf") as pdf_pages:
    for i in tqdm(selected):
        # Distinct local names keep the notebook-level `video_info`, `encoded_frames`,
        # `doc` and `sent` untouched.
        picked_video_info = video_infos[i]

        picked_sents = list(doc_list[i].sents)
        if not picked_sents:
            continue  # random.choice raises IndexError on an empty list.
        random_sent = random.choice(picked_sents)

        create_figure_for_text(encoded_frames_list[i], random_sent, clip_model, picked_video_info["frame_times"],
                               picked_video_info["thumbnail_times"], picked_video_info["thumbnails"])
        pdf_pages.savefig(bbox_inches="tight")
        plt.show()

In [None]:
# Imperative spans of every loaded video, index-aligned with video_infos.
orders_list = [list(get_orders(doc)) for doc in doc_list]

# One random imperative span per selected video, shown inline and saved to orders.pdf.
with PdfPages("orders.pdf") as pdf_pages:
    for i in tqdm(selected):
        # Distinct local names keep the notebook-level `video_info` and `encoded_frames` untouched.
        picked_video_info = video_infos[i]

        orders = orders_list[i]
        if not orders:
            continue  # This video has no imperative span; random.choice would raise IndexError.
        order = random.choice(orders)

        create_figure_for_text(encoded_frames_list[i], order, clip_model, picked_video_info["frame_times"],
                               picked_video_info["thumbnail_times"], picked_video_info["thumbnails"])
        pdf_pages.savefig(bbox_inches="tight")
        plt.show()

In [None]:
# Verb-phrase spans of every loaded video, index-aligned with video_infos.
verb_phrases_list = [list(get_verb_phrases(doc)) for doc in doc_list]

# One random verb phrase per selected video, shown inline and saved to clauses.pdf.
with PdfPages("clauses.pdf") as pdf_pages:
    for i in tqdm(selected):
        # Distinct local names keep the notebook-level `video_info` and `encoded_frames` untouched.
        picked_video_info = video_infos[i]

        verb_phrases = verb_phrases_list[i]
        if not verb_phrases:
            continue  # This video has no verb phrase; random.choice would raise IndexError.
        vp = random.choice(verb_phrases)

        create_figure_for_text(encoded_frames_list[i], vp, clip_model, picked_video_info["frame_times"],
                               picked_video_info["thumbnail_times"], picked_video_info["thumbnails"])
        pdf_pages.savefig(bbox_inches="tight")
        plt.show()