In [1]:
import open_clip
from PIL import Image
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Build the CLIP model together with its image preprocessing transform.
# The middle element of the returned triple is not used in this notebook.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-H-14-378-quickgelu',
    pretrained='dfn5b',
)

# Move the weights to the GPU and switch to inference behaviour; both calls
# return the module itself, so the cell still displays the model as its output.
model.to('cuda').eval()

In [2]:
import torch

def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))

    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())

    return True


def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126

    The image is first halved repeatedly with a BOX filter while its short
    side is at least ``2 * image_size``, then resized with BICUBIC so that the
    short side is ``image_size`` (after rounding), and finally a centered
    ``image_size x image_size`` square is cut out.

    Args:
        pil_image (PIL.Image.Image): Input image.
        image_size (int): Side length of the square output.

    Returns:
        PIL.Image.Image: The center-cropped square image.
    """
    # Imported locally because no visible cell of this notebook imports numpy;
    # without this the function raises NameError on a fresh kernel.
    import numpy as np

    # Cheap successive 2x downsampling before the final high-quality resize.
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # Scale so that the short side matches image_size.
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    # The array is indexed (row, column), i.e. (y, x).
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])


def crop(clip, i, j, h, w):
    """
    Cut a spatial window out of a video clip.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): Index of the first row of the window.
        j (int): Index of the first column of the window.
        h (int): Height of the window.
        w (int): Width of the window.
    Returns:
        torch.tensor: The window taken from the last two dimensions of ``clip``.
    Raises:
        ValueError: If ``clip`` does not have exactly 4 dimensions.
    """
    if clip.dim() != 4:
        raise ValueError("clip should be a 4D tensor")

    rows = slice(i, i + h)
    cols = slice(j, j + w)
    return clip[..., rows, cols]


def resize(clip, target_size, interpolation_mode):
    """
    Resize a clip to an explicit (height, width).

    Args:
        clip (torch.tensor): Clip to resize; passed straight to
            ``torch.nn.functional.interpolate``.
        target_size (tuple(int, int)): Output (height, width).
        interpolation_mode (str): Interpolation mode understood by
            ``torch.nn.functional.interpolate``.
    Returns:
        torch.tensor: The resized clip.
    Raises:
        ValueError: If ``target_size`` does not have exactly two entries.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")

    # F.interpolate only accepts align_corners/antialias together for the
    # bilinear and bicubic modes; passing them for e.g. "nearest" or "area"
    # raises, so they are supplied only where they are valid.
    extra_kwargs = {}
    if interpolation_mode in ("bilinear", "bicubic"):
        extra_kwargs = {"align_corners": True, "antialias": True}

    return torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode, **extra_kwargs
    )


def resize_scale(clip, target_size, interpolation_mode):
    """
    Rescale a clip by a single factor so its short side matches ``target_size[0]``.

    The factor is ``target_size[0] / min(H, W)`` and is applied to both
    spatial dimensions, so the aspect ratio is kept.

    Args:
        clip (torch.tensor): Clip to rescale; the last two dimensions are
            read as height and width.
        target_size (tuple(int, int)): Only the first entry is used, as the
            target length of the short side.
        interpolation_mode (str): Interpolation mode understood by
            ``torch.nn.functional.interpolate``.
    Returns:
        torch.tensor: The rescaled clip.
    Raises:
        ValueError: If ``target_size`` does not have exactly two entries.
    """
    if len(target_size) != 2:
        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
    H, W = clip.size(-2), clip.size(-1)
    scale_ = target_size[0] / min(H, W)

    # F.interpolate only accepts align_corners/antialias together for the
    # bilinear and bicubic modes; passing them for e.g. "nearest" or "area"
    # raises, so they are supplied only where they are valid.
    extra_kwargs = {}
    if interpolation_mode in ("bilinear", "bicubic"):
        extra_kwargs = {"align_corners": True, "antialias": True}

    return torch.nn.functional.interpolate(
        clip, scale_factor=scale_, mode=interpolation_mode, **extra_kwargs
    )


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Crop a spatial window from the video clip, then resize the window.

    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        i (int): Row of the upper left corner of the window.
        j (int): Column of the upper left corner of the window.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
        interpolation_mode (str): Mode forwarded to ``resize``.
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
    """
    is_video_clip = _is_tensor_video_clip(clip)
    if not is_video_clip:
        raise ValueError("clip should be a 4D torch.tensor")

    window = crop(clip, i, j, h, w)
    return resize(window, size, interpolation_mode)


def center_crop(clip, crop_size):
    """
    Cut a centered window of size ``crop_size`` out of a video clip.

    Args:
        clip (torch.tensor): Video clip; the last two dimensions are read as
            height and width.
        crop_size (tuple(int, int)): (height, width) of the window.
    Returns:
        torch.tensor: The centered window.
    Raises:
        ValueError: If the clip is smaller than ``crop_size`` in either
            spatial dimension.
    """
    is_video_clip = _is_tensor_video_clip(clip)
    if not is_video_clip:
        raise ValueError("clip should be a 4D torch.tensor")

    height = clip.size(-2)
    width = clip.size(-1)
    target_h, target_w = crop_size
    if height < target_h or width < target_w:
        raise ValueError("height and width must be no smaller than crop_size")

    # Offsets that leave equal margins on both sides of the window.
    top = int(round((height - target_h) / 2.0))
    left = int(round((width - target_w) / 2.0))
    return crop(clip, top, left, target_h, target_w)


def center_crop_using_short_edge(clip):
    """
    Cut the largest centered square out of a video clip.

    The side of the square equals the shorter of the clip's last two
    dimensions; the longer dimension is trimmed around its center.

    Args:
        clip (torch.tensor): Video clip; the last two dimensions are read as
            height and width.
    Returns:
        torch.tensor: The centered square crop.
    """
    is_video_clip = _is_tensor_video_clip(clip)
    if not is_video_clip:
        raise ValueError("clip should be a 4D torch.tensor")

    height = clip.size(-2)
    width = clip.size(-1)
    side = min(height, width)

    if height < width:
        # Wider than tall: keep every row, trim columns.
        top, left = 0, int(round((width - side) / 2.0))
    else:
        # Taller than wide (or square): keep every column, trim rows.
        top, left = int(round((height - side) / 2.0)), 0

    return crop(clip, top, left, side, side)

class CenterCropResizeVideo:
    '''
    First use the short side for cropping length,
    center crop video, then resize to the specified size
    '''

    def __init__(
            self,
            size,
            interpolation_mode="bilinear",
    ):
        """
        Args:
            size (int | tuple | list): Output size. A two-element tuple or
                list is used as (height, width); any other value is used for
                both height and width.
            interpolation_mode (str): Interpolation mode forwarded to ``resize``.
        Raises:
            ValueError: If ``size`` is a tuple or list that does not have
                exactly two entries.
        """
        # Lists are accepted alongside tuples; previously a list fell through
        # to the scalar branch and produced a nested, unusable size.
        if isinstance(size, (tuple, list)):
            if len(size) != 2:
                raise ValueError(f"size should be tuple (height, width), instead got {size}")
            self.size = tuple(size)
        else:
            self.size = (size, size)

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        """
        Args:
            clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
        Returns:
            torch.tensor: scale resized / center cropped video clip.
                size is (T, C, crop_size, crop_size)
        """
        clip_center_crop = center_crop_using_short_edge(clip)
        clip_center_crop_resize = resize(clip_center_crop, target_size=self.size,
                                         interpolation_mode=self.interpolation_mode)
        return clip_center_crop_resize

    def __repr__(self) -> str:
        # The closing parenthesis was missing from the original representation.
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode})"

In [7]:
from decord import VideoReader, cpu
from decord import bridge

# Ask decord to hand back torch tensors. The call is made for its side effect
# only; assigning its result to `bridge` would shadow the imported module.
bridge.set_bridge("torch")

video_path = "/cvdata1/jihwan/mixkit/Cats/mixkit-a-white-cat-sits-in-front-of-a-white-wall-1535.mp4"

# Decode every frame of the video on the CPU in a single batch.
vr = VideoReader(video_path, ctx=cpu(0))
video = vr.get_batch(range(len(vr)))
# Move the last axis to position 1; presumably (T, H, W, C) -> (T, C, H, W),
# the layout the clip helpers in this notebook document — TODO confirm.
video = video.permute(0, 3, 1, 2)
video.shape

torch.Size([266, 3, 1080, 1920])

In [8]:
# Square-crop every frame around its center using the short edge, then resize
# the result to 512 x 512.
center_crop_resize = CenterCropResizeVideo(size=512)
resized_video = center_crop_resize(video)
# Show the shape of the processed clip as the cell output.
resized_video.shape

torch.Size([266, 3, 512, 512])

In [15]:

# Encode the frames one at a time with the CLIP image encoder and keep each
# embedding (still on the GPU) in a list.
# NOTE(review): the loop starts at index 1, so frame 0 is never embedded and
# embeddings[0] corresponds to resized_video[1] — confirm this is intended.
embeddings = []
with torch.no_grad():
    for i in range(1, len(resized_video)):
        # Channel-first frame -> channel-last array -> PIL image for `preprocess`.
        image = Image.fromarray(resized_video[i].permute(1, 2, 0).numpy())
        # Apply the model's transform, add a batch dimension, move to the GPU.
        image = preprocess(image).unsqueeze(0).to('cuda')
        embedding = model.encode_image(image)
        embeddings.append(embedding)


In [17]:
# Cosine similarity computed along dim 1; presumably each embedding is shaped
# (1, D) so this reduces over the feature axis — TODO confirm.
cosine_sim = torch.nn.CosineSimilarity(dim=1)

print("consecutive frames")
# Compare the first stored embedding against each of the 17 that follow it.
for i in range(1, 18):
    print(cosine_sim(embeddings[0], embeddings[i]))

consecutive frames
tensor([1.0000], device='cuda:0')
tensor([0.9961], device='cuda:0')
tensor([0.9968], device='cuda:0')
tensor([0.9944], device='cuda:0')
tensor([0.9931], device='cuda:0')
tensor([0.9878], device='cuda:0')
tensor([0.9916], device='cuda:0')
tensor([0.9914], device='cuda:0')
tensor([0.9919], device='cuda:0')
tensor([0.9923], device='cuda:0')
tensor([0.9913], device='cuda:0')
tensor([0.9919], device='cuda:0')
tensor([0.9919], device='cuda:0')
tensor([0.9899], device='cuda:0')
tensor([0.9904], device='cuda:0')
tensor([0.9895], device='cuda:0')
tensor([0.9920], device='cuda:0')


In [18]:
print("consecutive chunks")
# Sample every 17th embedding (indices 17, 34, ...) and compare each one
# against the first stored embedding.
for i in range(17, len(embeddings), 17):
    print(cosine_sim(embeddings[0], embeddings[i]))

consecutive chunks
tensor([0.9920], device='cuda:0')
tensor([0.9717], device='cuda:0')
tensor([0.9576], device='cuda:0')
tensor([0.9438], device='cuda:0')
tensor([0.9116], device='cuda:0')
tensor([0.9328], device='cuda:0')
tensor([0.9587], device='cuda:0')
tensor([0.9631], device='cuda:0')
tensor([0.9648], device='cuda:0')
tensor([0.9556], device='cuda:0')
tensor([0.8523], device='cuda:0')
tensor([0.8327], device='cuda:0')
tensor([0.8380], device='cuda:0')
tensor([0.8539], device='cuda:0')
tensor([0.8566], device='cuda:0')


open_clip.model.CLIP