In [1]:
import os.path as osp
from torchvision import transforms
from torchvision.io import read_video
from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML

from einops import rearrange

In [2]:
from modeling_pretrain import  pretrain_videomae_small_patch16_224

             This can result in unexpected behavior including runtime errors.
             Reinstall Horovod using `pip install --no-cache-dir` to build with the new version.


In [50]:
def pad_frames(vid, out_len, method='reflect'):
    len_vid = len(vid)
    if out_len < len_vid:
        return vid
    ret = []
    if method == 'reflect':
        pad_size = out_len-len_vid
        pad_left = pad_size//2
        pad_right = pad_size-pad_left
        for il in range(pad_right):
            ret.append(vid[0])
        ret += vid
        for ir in range(pad_right):
            ret.append(vid[-1])
    return ret

num_frames = 16

def load_video(video_path, num_frames=16):
    # Load the video
    # video, audio, info = read_video(video_path, output_format="TCHW")
    video, audio, info = read_video(video_path, output_format="THWC")
    video_shape = video.shape
   
    
      
    
    # Apply the transformation to each frame in the video
    # frames_resized = [transform(frame) for frame in video]
    # frames_permuted = [frame.permute(1, 2, 0) for frame in frames_resized]
    # frames_permuted = [frame.permute(1, 2, 0) for frame in video]
    frames_permuted = video
    frames_permuted = pad_frames(frames_permuted, num_frames, method='reflect')
    
    return frames_permuted

def load_model_hf(model_name='MCG-NJU/videomae-base-short-ssv2'):
    feature_extractor = VideoMAEImageProcessor.from_pretrained(model_name)
    model = VideoMAEForPreTraining.from_pretrained(model_name)
    return model, feature_extractor
    
def run_model_hf(pixel_values, model,p=0.5):
    bool_masked_pos = get_mask(model, p)
    outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
    return outputs, pixel_values

def get_mask(config, p=0.5, method='random'):
    try:
        img_size = config.input_size
    except:
        img_size = 224
    
    try:
        patch_size = config.patch_size[0]
    except:
        patch_size = 16    
        
    tubelet_size = 2
    num_patches_per_frame = (img_size // patch_size) ** 2
    seq_length = (num_frames // tubelet_size) * num_patches_per_frame
    if method == 'random':
        return torch.rand(1, seq_length) < p
    else:
        num_true = int(p * 14)
        mask = torch.zeros(1, seq_length, dtype=torch.bool)
        mask = rearrange(mask, 'b (t h w) -> b t h w',t=8, h=14, w=14)
        if method=='left':
            mask[:, :,:,:num_true] = True
        elif method == 'right':
            mask[:, :,:,-num_true:] = True
        elif method == "top":
            mask[:, :,:num_true,:] = True
        elif method == 'bottom':
            mask[:, :,-num_true:,:] = True
        else:
            raise f'unknown method {method}'  

        mask = rearrange(mask, 'b t h w -> b (t h w)',t=8, h=14, w=14)

    return mask
    
def run_model_mine(pixel_values, model, config, p=0.9, method='random'):
    if p == 0:
        bool_masked_pos = None
    else:
        bool_masked_pos = get_mask(config, p, method)
    outputs = model(pixel_values, mask=bool_masked_pos)
    return outputs, bool_masked_pos

def plot_video(image_list):
    fig, ax = plt.subplots()

    # Function to update the plot with each frame
    def update(frame):
        ax.clear()
        frame_to_show = image_list[frame]
        # 
        # ax.imshow(np.transpose(frame_to_show, [1,2,0]))
        ax.imshow(frame_to_show)
        ax.axis('off')
    
    # Create an animation
    ani = FuncAnimation(fig, update, frames=len(image_list), repeat=True)
    
    # Display the animation as a GIF in the Jupyter Notebook
    display(HTML(ani.to_jshtml()))
    
    # Close the figure to avoid a double display
    plt.close()


In [4]:
model_path = '/home/ubuntu/efs/trained_models/lsfb_isol_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint-1364.pth'
model = pretrain_videomae_small_patch16_224(decoder_depth=4)
checkpoint = torch.load(
            model_path, map_location="cpu"
        )
model.load_state_dict(checkpoint['model'])


<All keys matched successfully>

In [36]:
# import os
import glob
# checkpoint_paths = glob.glob('/home/ubuntu/efs/trained_models/lsfb_isol_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/*')
# checkpoint_paths

In [37]:
# Define the file path to your video
root_folder = '/data/lsfb_dataset'
all_videos = glob.glob(osp.join(root_folder,'isol','videos','*.mp4'))
# video_path = osp.join(root_folder,'isol','videos','CLSFBI0103A_S001_B_286621_286862.mp4')
video_path = all_videos[10]
video = load_video(video_path)
# plot_video(video)



In [38]:
# model, feature_extractor = load_model(model_path)
# vars(checkpoint['args'])

In [39]:
# video[0].shape
model_hf, feature_extractor = load_model_hf()

In [40]:
# outputs, pixel_values = run_model(video, feature_extractor, model)
# _, feature_extractor = load_model()
pixel_values_raw = feature_extractor(video, return_tensors="pt").pixel_values

pixel_values = rearrange(pixel_values_raw, 'b t c h w -> b c t h w')
# outputs = run_model2(pixel_values, model, checkpoint['args'])
# B, C, T, H, W
# pixel_values.shape

In [41]:
# outputs_hf, pixels_hf = run_model_hf(pixel_values_raw, model_hf,p=0.65)

In [42]:
image_std = np.array(feature_extractor.image_std)[None, :, None, None,  None]
image_mean = np.array(feature_extractor.image_mean)[None, :, None, None,  None]

In [60]:
outputs, mask = run_model_mine(pixel_values, model, checkpoint['args'], p=0.5, method='bottom')
outputs = outputs.detach().numpy()
try:
    mask = mask.detach().numpy()
except:
    pass

In [61]:
output_mean = np.mean(outputs)
output_min = np.min(outputs)
output_max = np.max(outputs)

pixel_values_unnorm = pixel_values*image_std + image_mean
videos_patch = rearrange(pixel_values_unnorm, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)', p0=2, p1=16, p2=16)

outputs_unnorm = (outputs - output_min)/(output_max-output_min)
if mask is None:
   B,num_patches,patch_pix = outputs.shape  
   outputs_reconstruction = outputs_unnorm      
else:
    B, num_patches = mask.shape
    _,_,patch_pix = outputs.shape # 3 c * 16 p * 16 p
    outputs_reconstruction = np.zeros((B, num_patches, patch_pix))
    
    outputs_reconstruction[~mask,:] = videos_patch[~mask,:]
    outputs_reconstruction[mask,:] = outputs_unnorm

# outputs_reconstruction[:,:] = outputs_unnorm
patch_size = model_hf.config.patch_size
video_reconstruction = rearrange(outputs_reconstruction, 'b (t h w) (p0 p1 p2 c) -> b (t p0) c (h p1) (w p2)',t=8, h=14, w=14,  p0=2, p1=patch_size, p2=patch_size)
video_reconstruction = video_reconstruction.squeeze()
video_reconstruction_transposed = video_reconstruction.transpose(0,2,3,1)

In [62]:
plot_video(video_reconstruction_transposed)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Clipping i