In [None]:
import os.path as osp
import os

# Work from the parent of the directory the notebook kernel started in, presumably so
# that the repository-local modules imported below (modeling_pretrain, datasets, ...)
# resolve. The starting directory is remembered on first execution so that re-running
# this cell does not keep climbing one more level up each time.
try:
    NOTEBOOK_DIR
except NameError:
    NOTEBOOK_DIR = os.getcwd()
os.chdir(osp.join(NOTEBOOK_DIR, os.pardir))
print(os.getcwd())

import pandas as pd
from torchvision import transforms
from torchvision.io import read_video
# from transformers import VideoMAEImageProcessor, VideoMAEForPreTraining
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import display, HTML

from einops import rearrange
from modeling_pretrain import (
    pretrain_videomae_small_patch16_224,
    pretrain_videomae_base_patch16_224,
    pretrain_videomae_huge_patch16_224
)

import glob
from types import SimpleNamespace
from datasets import build_dataset
import video_transforms as video_transforms
import volume_transforms as volume_transforms

# Per-channel RGB statistics (the standard ImageNet values). They are used to normalize
# the model input (data_transform) and to undo that normalization for display.
IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]


def pad_frames(vid, out_len, method='reflect'):
    """Pad a sequence of frames to exactly ``out_len`` frames by repeating its edge frames.

    Despite its name, ``'reflect'`` does not mirror the clip. It repeats the first
    frame at the start and the last frame at the end (edge/replicate padding). The
    padding is split as evenly as possible, and an odd extra frame goes to the end.

    Args:
        vid: Indexable, iterable sequence of frames (e.g. a list of arrays, or a
            tensor whose first dimension is time).
        out_len: Desired number of frames.
        method: Padding method; only ``'reflect'`` is supported.

    Returns:
        ``vid`` unchanged if it already has more than ``out_len`` frames, otherwise a
        new list of exactly ``out_len`` frames.

    Raises:
        ValueError: If ``method`` is not supported, or ``vid`` is empty but padding is needed.
    """
    len_vid = len(vid)
    if out_len < len_vid:
        return vid
    if method != 'reflect':
        raise ValueError(f"unsupported padding method {method!r}; only 'reflect' is implemented")

    pad_size = out_len - len_vid
    if pad_size > 0 and len_vid == 0:
        raise ValueError('cannot pad an empty clip')
    pad_left = pad_size // 2
    pad_right = pad_size - pad_left

    # Left padding must use pad_left; using pad_right on both sides overshoots
    # out_len by one frame whenever pad_size is odd.
    ret = [vid[0]] * pad_left
    ret.extend(vid)
    ret.extend([vid[-1]] * pad_right)
    return ret


num_frames = 16  # default clip length (in frames) shared by the helpers in this notebook


def load_video(video_path, num_frames=16):
    """Decode a video file and return its frames, padded up to ``num_frames``.

    Frames are read with torchvision's ``read_video`` in ``THWC`` layout; the
    audio track and the metadata dict are discarded. The result goes through
    ``pad_frames``: a clip with more than ``num_frames`` frames comes back as
    the decoded tensor, and a shorter (or equal-length) clip comes back as a list
    of per-frame tensors padded to ``num_frames``.
    """
    decoded, _audio, _info = read_video(video_path, output_format="THWC")
    return pad_frames(decoded, num_frames, method='reflect')


# def load_model_hf(model_name='MCG-NJU/videomae-base-short-ssv2'):
#     feature_extractor = VideoMAEImageProcessor.from_pretrained(model_name)
#     model = VideoMAEForPreTraining.from_pretrained(model_name)
#     return model, feature_extractor


# def run_model_hf(pixel_values, model, p=0.5):
#     bool_masked_pos = get_mask(model, p)
#     outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
#     return outputs, pixel_values


def get_mask(p=0.5, method='random', config=None, n_frames=None):
    """Build a boolean token mask for the VideoMAE encoder.

    Tokens form a (time, row, column) grid. Each token covers ``tubelet_size``
    frames x ``patch_size`` x ``patch_size`` pixels. The mask is returned
    flattened in C order to shape ``(1, t * h * w)``; ``True`` marks a masked token.

    Args:
        p: Masking ratio; ``p == 0`` disables masking.
            - ``'random'``: each token is masked independently with probability ``p``.
            - ``'left'``/``'right'``: ``int(p * w)`` columns are hidden.
            - ``'top'``/``'bottom'``: ``int(p * h)`` rows are hidden.
            - All other methods use a fixed pattern.
        method: One of 'random', 'left', 'right', 'top', 'bottom', 'horizontal'
            (every other row), 'vertical' (every other column), 'last' (last
            time step), 'mid' (time step 3, all columns but the central one),
            'grid' (every other row and column), 'inv_grid' (complement of 'grid').
        config: Optional object with ``input_size`` and an indexable ``patch_size``.
            When omitted, 224-pixel inputs with 16-pixel patches are assumed.
        n_frames: Number of input frames; defaults to the module-level ``num_frames``.

    Returns:
        A ``torch.bool`` tensor of shape ``(1, seq_length)``, or ``None`` when ``p == 0``.

    Raises:
        ValueError: If ``method`` is not recognised.
    """
    if p == 0:
        return None

    if config is None:
        img_size = 224
        patch_size = 16
    else:
        img_size = config.input_size
        patch_size = config.patch_size[0]
    if n_frames is None:
        n_frames = num_frames

    tubelet_size = 2
    # Token grid; 8 x 14 x 14 for the default 16 x 224 x 224 input.
    grid_t = n_frames // tubelet_size
    grid_h = grid_w = img_size // patch_size
    seq_length = grid_t * grid_h * grid_w

    if method == 'random':
        return torch.rand(1, seq_length) < p

    mask = torch.zeros(1, grid_t, grid_h, grid_w, dtype=torch.bool)
    num_cols = int(p * grid_w)
    num_rows = int(p * grid_h)
    if method == 'left':
        mask[:, :, :, :num_cols] = True
    elif method == 'right':
        # `-0:` would select every column, so mask nothing when num_cols == 0.
        if num_cols > 0:
            mask[:, :, :, -num_cols:] = True
    elif method == 'top':
        mask[:, :, :num_rows, :] = True
    elif method == 'bottom':
        # Same `-0:` pitfall as 'right'.
        if num_rows > 0:
            mask[:, :, -num_rows:, :] = True
    elif method == 'horizontal':
        mask[:, :, ::2, :] = True
    elif method == 'vertical':
        mask[:, :, :, ::2] = True
    elif method == 'last':
        mask[:, -1, :, :] = True
    elif method == 'mid':
        # Hide every column except the central one, in time step 3 only.
        inds = [c for c in range(grid_w) if c != grid_w // 2]
        mask[:, 3, :, inds] = True
    elif method == 'grid':
        mask[:, :, ::2, ::2] = True
    elif method == 'inv_grid':
        mask[:, :, :, :] = True
        mask[:, :, ::2, ::2] = False
    else:
        raise ValueError(f'unknown method {method}')

    # Equivalent to rearrange('b t h w -> b (t h w)').
    return mask.reshape(mask.shape[0], -1)


def run_model_mine(pixel_values, model, config, p=0.9, method='random', reconstruction_mode=False):
    """Run ``model`` on ``pixel_values`` with a freshly generated token mask.

    Args:
        pixel_values: Model input batch.
        model: Callable accepting ``(pixel_values, mask=..., reconstruction_mode=...)``.
        config: Forwarded to ``get_mask`` to size the token grid.
        p: Masking ratio; 0 means no mask is passed to the model.
        method: Masking pattern name understood by ``get_mask``.
        reconstruction_mode: Forwarded unchanged to the model.

    Returns:
        ``(outputs, mask)``; ``mask`` is ``None`` when ``p == 0``.
    """
    bool_masked_pos = None if p == 0 else get_mask(p, method, config=config)
    model_outputs = model(pixel_values, mask=bool_masked_pos, reconstruction_mode=reconstruction_mode)
    return model_outputs, bool_masked_pos


def create_directory_if_not_exists(directory):
    """Create ``directory`` and any missing parents; do nothing if it already exists.

    ``exist_ok=True`` avoids the check-then-create race of testing existence first.
    An empty string is what ``os.path.dirname`` returns for a bare file name (i.e.
    the current directory), so it is treated as a no-op instead of crashing.
    """
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)


def plot_video(image_list, save_path=None, override_save=False):
    """Animate a sequence of frames inline in the notebook and optionally save it.

    Args:
        image_list: Sequence of frames accepted by ``Axes.imshow`` (e.g. HxWxC arrays).
        save_path: Optional path of a ``.gif`` or ``.mp4`` file to write (30 fps).
            Missing parent directories are created.
        override_save: Allow overwriting an existing ``save_path``.

    Raises:
        ValueError: If ``save_path`` does not end in gif or mp4.
        FileExistsError: If ``save_path`` exists and ``override_save`` is False.
    """
    # Validate the output path before rendering anything so bad input fails fast.
    # (Raising plain strings, as before, is itself a TypeError in Python 3.)
    writer = None
    if save_path is not None:
        file_type = osp.splitext(save_path)[-1][1:].lower()
        writers_by_type = {'gif': 'imagemagick', 'mp4': 'ffmpeg'}
        if file_type not in writers_by_type:
            raise ValueError(f'Save video type must be gif or mp4 but it was {file_type} instead')
        if osp.exists(save_path) and not override_save:
            raise FileExistsError(f'{save_path} already exists, set override_save=True to override')
        writer = writers_by_type[file_type]

    fig, ax = plt.subplots()

    def update(frame):
        # Redraw the axes with frame number `frame`.
        ax.clear()
        ax.imshow(image_list[frame])
        ax.axis('off')

    ani = FuncAnimation(fig, update, frames=len(image_list), repeat=True)

    if save_path is not None:
        directory = osp.dirname(save_path)
        if directory:  # a bare file name is written to the current directory
            create_directory_if_not_exists(directory=directory)
        print(f'Saving file {save_path}')
        ani.save(save_path, writer=writer, fps=30)  # Adjust the fps as needed

    # Display the animation as an interactive HTML/JS player in the notebook.
    display(HTML(ani.to_jshtml()))

    # Close this figure so the notebook does not also render a static copy of it.
    plt.close(fig)


def transform_video(video, transform=None):
    """Apply the clip transform to ``video`` and prepend a batch axis of size 1.

    Args:
        video: Clip in the format ``transform`` accepts (in this notebook, a list of
            HxWxC numpy frames).
        transform: Callable applied to the whole clip. Defaults to the notebook-level
            ``data_transform``, which must already be defined when this is omitted.

    Returns:
        The transformed clip with a leading axis of size 1. Downstream code
        (``post_processing``) treats it as ``(batch, channels, time, height, width)``.
    """
    if transform is None:
        transform = data_transform
    pixel_values_raw = transform(video)
    # `[None]` prepends a batch axis for both torch tensors and numpy arrays without
    # assuming the rank or axis order of the transform output.
    return pixel_values_raw[None]


def post_processing(outputs, mask, pixel_values, image_std, image_mean, reconstruction_mode,
                    patch_size=16, tubelet_size=2):
    """Turn VideoMAE decoder outputs back into a displayable video array.

    Args:
        outputs: Decoder output tensor of shape ``(B, N_out, patch_pix)`` with
            ``patch_pix = tubelet_size * patch_size**2 * C``.
        mask: Boolean token mask of shape ``(B, N)`` as a tensor or array (``True`` =
            masked), or ``None`` when no mask was used.
        pixel_values: Normalized model input of shape ``(B, C, T, H, W)``.
        image_std, image_mean: Per-channel statistics used to normalize the input;
            used here to undo that normalization.
        reconstruction_mode: If True, ``outputs`` contains every token, ordered as the
            visible tokens first and then the masked ones. If False, ``outputs``
            contains only the masked tokens, and the visible tokens are copied from
            the un-normalized input.
        patch_size: Spatial patch size in pixels.
        tubelet_size: Number of frames per token.

    Returns:
        numpy array of shape ``(T, H, W, C)`` (the batch axis is squeezed away).
    """
    outputs = outputs.detach().cpu().numpy()
    # `mask` may be a tensor, an array or None; normalise it to a numpy bool array.
    if mask is not None:
        if hasattr(mask, 'detach'):
            mask = mask.detach().cpu().numpy()
        mask = np.asarray(mask, dtype=bool)
    # Work purely in numpy instead of mixing torch tensors with numpy arrays.
    if hasattr(pixel_values, 'detach'):
        pixel_values = pixel_values.detach().cpu().numpy()
    pixel_values = np.asarray(pixel_values)

    std = np.asarray(image_std).reshape(1, -1, 1, 1, 1)
    mean = np.asarray(image_mean).reshape(1, -1, 1, 1, 1)
    pixel_values_unnorm = pixel_values * std + mean

    # Token grid derived from the input (8 x 14 x 14 for a 16 x 224 x 224 clip).
    grid_t = pixel_values.shape[2] // tubelet_size
    grid_h = pixel_values.shape[3] // patch_size
    grid_w = pixel_values.shape[4] // patch_size

    videos_patch = rearrange(pixel_values_unnorm, 'b c (t p0) (h p1) (w p2) -> b (t h w) (p0 p1 p2 c)',
                             p0=tubelet_size, p1=patch_size, p2=patch_size)

    if mask is None:
        outputs_reconstruction = outputs
    else:
        batch, num_patches = mask.shape
        patch_pix = outputs.shape[-1]
        outputs_reconstruction = np.zeros((batch, num_patches, patch_pix))

        if reconstruction_mode:
            num_visible = int(np.sum(~mask))
            outputs_reconstruction[~mask, :] = outputs[:, :num_visible, :]
            outputs_reconstruction[mask, :] = outputs[:, num_visible:, :]
        else:
            # NOTE(review): visible patches are un-normalized input pixels while the
            # masked ones are raw decoder outputs; they are only on the same scale if
            # the decoder predicts un-normalized pixels — confirm.
            outputs_reconstruction[~mask, :] = videos_patch[~mask, :]
            outputs_reconstruction[mask, :] = outputs

    video_reconstruction = rearrange(outputs_reconstruction,
                                     'b (t h w) (p0 p1 p2 c) -> b (t p0) c (h p1) (w p2)',
                                     t=grid_t, h=grid_h, w=grid_w,
                                     p0=tubelet_size, p1=patch_size, p2=patch_size)
    video_reconstruction = video_reconstruction.squeeze()
    # (T, C, H, W) -> (T, H, W, C) for imshow.
    return video_reconstruction.transpose(0, 2, 3, 1)


def reconstruct_video(model, video, mask_prob=0.5, mask_method='bottom', reconstruction_mode=False, save_path=None,
                      override_save=False, args=None, image_std=IMG_STD, image_mean=IMG_MEAN):
    '''
    Run masked inference on a video, rebuild the frames from the model output,
    display them as an animation and optionally save them as an mp4 or gif file.

    Args:
        model: VideoMAE pretraining model called as ``model(x, mask=..., reconstruction_mode=...)``.
        video: Clip accepted by ``transform_video``.
        mask_prob, mask_method: Forwarded to ``get_mask`` (``args`` is its ``config``).
            NOTE(review): get_mask sizes the mask for ``num_frames`` frames, so the
            clip presumably has to have exactly that many — confirm.
        reconstruction_mode: Forwarded to the model and to ``post_processing``.
        save_path, override_save: Forwarded to ``plot_video``.
        image_std, image_mean: Statistics used to undo input normalization.
    '''
    pixel_values = transform_video(video)
    mask = get_mask(p=mask_prob, method=mask_method, config=args)
    # Inference only: skip building the autograd graph to save memory.
    with torch.no_grad():
        outputs = model(pixel_values, mask=mask, reconstruction_mode=reconstruction_mode)

    video_reconstruction_transposed = post_processing(outputs, mask, pixel_values, image_std, image_mean,
                                                      reconstruction_mode)

    plot_video(video_reconstruction_transposed, save_path=save_path, override_save=override_save)


def load_model(model_path, model_type=None, decoder_depth=4):
    """Build a VideoMAE pretraining model and load its weights from a checkpoint.

    Args:
        model_path: Path to a checkpoint whose ``'model'`` entry is a state dict.
        model_type: Model constructor called as ``model_type(decoder_depth=...)``.
            Defaults to ``pretrain_videomae_small_patch16_224``; it must match the
            architecture stored in the checkpoint.
        decoder_depth: Decoder depth passed to the constructor (previously fixed at 4).

    Returns:
        The model with the checkpoint weights loaded. The weights are mapped to CPU.
    """
    if model_type is None:
        model_type = pretrain_videomae_small_patch16_224

    model = model_type(decoder_depth=decoder_depth)
    checkpoint = torch.load(model_path, map_location="cpu")
    model.load_state_dict(checkpoint['model'])
    return model


def get_name_from_path(path):
    """Return the file name of ``path`` without its final extension.

    Example: ``'/a/checkpoint-1364.pth'`` -> ``'checkpoint-1364'``. Only the last
    extension is stripped, so ``'clip.v2.mp4'`` gives ``'clip.v2'``. Splitting on the
    first dot would truncate it to ``'clip'`` and let distinct files collide.
    """
    return osp.splitext(osp.basename(path))[0]

In [None]:
# model_hf, feature_extractor = load_model_hf()
# Checkpoints tried with this notebook — uncomment exactly one `model_path`.
# model_path = '/home/ubuntu/efs/trained_models/lsfb_isol_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint-1364.pth'
# model_path = '/home/ubuntu/efs/trained_models/lsfb_isol_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint-1414.pth'
# model_path = '/home/ubuntu/efs/trained_models/lsfb_isol_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint-1569.pth'

# model_path = '/home/ubuntu/efs/trained_models/Kinetics-400_finetune_lsfb_isol_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e400/checkpoint-400.pth'
# model_path = '/home/ubuntu/efs/videoMAE/pretrained/VideoMAE_ViT-S_checkpoint_Kinetics-400.pth'
model_path = '/home/ubuntu/efs/videoMAE/pretrained/VideoMAE_ViT-B_checkpoint_Kinetics-400.pth'
# model_path = '/home/ubuntu/efs/videoMAE/pretrained/VideoMAE _ViT-H_checkpoint_Kinetics-400.pth'

# model_path = '/home/ubuntu/efs/videoMAE/pretrained/VideoMAE_ViT-S_checkpoint_ssv2.pth'

# The constructor must match the checkpoint architecture (S / B / H).
# model_type = pretrain_videomae_small_patch16_224
model_type = pretrain_videomae_base_patch16_224
# model_type = pretrain_videomae_huge_patch16_224

# Reuse load_model instead of duplicating its construct / torch.load / load_state_dict steps.
model = load_model(model_path, model_type=model_type)
# This notebook only runs inference: switch layers such as dropout to evaluation behaviour.
model.eval();



In [None]:
# checkpoint = torch.load(
#     model_path, map_location="'cuda'"
# )
# model.load_state_dict(checkpoint['model'])

In [None]:
# Clip preprocessing pipeline:
# 1. Zero-pad frames to a square.
# 2. Resize (size 244, bilinear).
# 3. Center-crop to 224x224.
# 4. Convert the clip to a tensor.
# 5. Normalize with IMG_MEAN / IMG_STD.
# NOTE(review): resizing to 244 before a 224 crop is unusual (256 is more common) — confirm it is intended.
data_transform = video_transforms.Compose(
    [
        video_transforms.PadToSquare(pad_value=0),
        video_transforms.Resize(244, interpolation='bilinear'),
        video_transforms.CenterCrop(size=(224, 224)),
        volume_transforms.ClipToTensor(),
        video_transforms.Normalize(mean=IMG_MEAN, std=IMG_STD),
    ]
)

In [None]:
# root_folder = '/data/lsfb_dataset'
# all_videos_paths = glob.glob(osp.join(root_folder, 'isol', 'videos', '*.mp4'))

root_folder = '/videos/mpi_data/2Itzik/dyadic_communication/PIS_ID_000_SPLIT/Cam3_segmented_split/'
# glob() returns files in arbitrary filesystem order. Sort the list so that a fixed
# index selects the same video on every run and machine.
all_videos_paths = sorted(glob.glob(osp.join(root_folder, '*', '*.mp4')))
assert len(all_videos_paths) > 9, f'expected at least 10 videos under {root_folder}'

video_path = all_videos_paths[9]
# load_video yields torch frames (THWC); convert each one to a numpy HxWxC array.
video = [frame.numpy() for frame in load_video(video_path)]

In [None]:
# len(video)

In [None]:
# reconstruct_video(model, video, mask_prob=0.5, mask_method='inv_grid',save_path=None, reconstruction_mode=False)

In [None]:
mask_prob = 0.5
mask_method = 'horizontal'
override_save = False
save_path = None
reconstruction_mode = True
args = None
n_frames = 16

# Cover the whole video with n_segments clips of n_frames each (padding at the ends).
vid_len = len(video)
n_segments = int(np.ceil(vid_len / n_frames))
vid_len_final = int(n_segments * n_frames)

padded_video = pad_frames(video, out_len=vid_len_final, method='reflect')
pixel_values = transform_video(padded_video)

# NOTE(review): with '(t s)', segment k gets frames k, k + s, k + 2s, ... (a strided
# subsample), not a contiguous chunk; use '(s t)' for contiguous clips. The
# reconstruction cell reverses this with the same '(t s)' order, so the
# round trip is consistent either way.
pixel_values = rearrange(pixel_values, 'b c (t s) h w -> (b s) c t h w', t=n_frames, s=n_segments)

# mask_prob / mask_method are not used here: the model runs unmasked.
# no_grad avoids keeping activations for every segment in the batch.
with torch.no_grad():
    outputs = model(pixel_values, mask=None, reconstruction_mode=reconstruction_mode).detach().numpy()






In [None]:
# Shape of the batched model input: one entry per segment, laid out as
# (n_segments, channels, n_frames, height, width) by the '(b s) c t h w' rearrange.
pixel_values.shape

In [None]:

patch_size = 16
tubelet_size = 2
# The token grid follows from the model input (8 x 14 x 14 for 16 x 224 x 224 clips).
# Deriving it avoids hard-coding the grid, and patch_size is actually used below.
grid_t = pixel_values.shape[2] // tubelet_size
grid_h = pixel_values.shape[3] // patch_size
grid_w = pixel_values.shape[4] // patch_size

# Fold the per-token predictions back into clips: (b * s, c, t, h, w)
outputs_re = rearrange(outputs, 'b (t h w) (p0 p1 p2 c) -> b c (t p0) (h p1) (w p2)',
                       t=grid_t, h=grid_h, w=grid_w, p0=tubelet_size, p1=patch_size, p2=patch_size)

# # Normalize the output using the channels
# std_re = np.nanstd(outputs_re, axis=(0,2,3,4), keepdims=True)
# mean_re =  np.nanmean(outputs_re, axis=(0,2,3,4), keepdims=True)
# outputs_re  = (outputs_re - mean_re ) / std_re

# Per-channel input statistics, shaped to broadcast over (B, C, T, H, W).
image_std_torch = np.array(IMG_STD)[None, :, None, None, None]
image_mean_torch = np.array(IMG_MEAN)[None, :, None, None, None]

# Map the outputs into pixel range assuming they live in the normalized input space.
outputs_re_unnorm = outputs_re * image_std_torch + image_mean_torch
# Stitch segments back into one (T, H, W, C) video; '(t s)' mirrors the split used for pixel_values.
video_reconstruction_transposed = rearrange(outputs_re_unnorm, '(b s) c t h w -> 1 (b t s) h w c',
                                            t=n_frames, s=n_segments).squeeze()

plot_video(video_reconstruction_transposed, save_path=save_path, override_save=override_save)

In [None]:
# Per-channel statistics reshaped to (1, C, 1, 1, 1) so they broadcast over
# (B, C, T, H, W); this maps the normalized input back to pixel range.
image_std_torch = np.asarray(IMG_STD).reshape(1, -1, 1, 1, 1)
image_mean_torch = np.asarray(IMG_MEAN).reshape(1, -1, 1, 1, 1)
pixel_values_unnorm = pixel_values * image_std_torch + image_mean_torch


In [None]:
# Shape of the folded predictions: (n_segments, channels, time, height, width),
# per the 'b c (t p0) (h p1) (w p2)' rearrange in the reconstruction cell.
outputs_re.shape

In [None]:
# Per-channel std of the un-normalized input (reduces every axis except channels).
pixel_values_unnorm.std(dim=(0, 2, 3, 4), keepdim=True)

In [None]:
# Per-channel mean / std of the reconstruction. outputs_re is laid out as
# (b, c, t, h, w), so reduce every axis except 1 (channels); this matches the
# per-channel input statistics above, which reduce dims (0, 2, 3, 4).
# numpy expects multiple reduction axes as a tuple, not a list.
mean_re = np.nanmean(outputs_re, axis=(0, 2, 3, 4))
std_re = np.nanstd(outputs_re, axis=(0, 2, 3, 4))

In [None]:
# Map outputs back with hand-picked statistics.
# NOTE(review): 0.48 / 0.08 look like empirical estimates — confirm where they come from.
output_mean = 0.48
output_std = 0.08
outputs_unnorm = output_mean + output_std * outputs
# outputs_unnorm = (outputs - output_mean) * output_std

In [None]:
# Per-channel std of the un-normalized input, as a flat vector of length C.
pixel_values_unnorm.std(dim=(0, 2, 3, 4))

In [None]:
# Sanity check on the stitched reconstruction: a minimum well below 0 may mean the
# un-normalization does not match the scale of the model outputs.
np.min(video_reconstruction_transposed)

In [None]:
# Masked reconstruction of the loaded clip with a 'grid' mask.
# `video` already holds numpy frames, and calling `.numpy()` on them raises
# AttributeError. np.asarray passes numpy frames through and also accepts CPU torch frames.
reconstruct_video(model, [np.asarray(frame) for frame in video], mask_prob=0.5, mask_method='grid',
                  save_path=None, reconstruction_mode=True)

In [None]:
# reconstruct_video(model, video, mask_prob=0.5, mask_method='random',save_path=None)
# reconstruct_video(model, video, mask_prob=0.5, mask_method='horizontal',save_path=None, reconstruction_mode=True)
# `video` already holds numpy frames; np.asarray passes them through (calling `.numpy()` would fail).
# The 'horizontal' pattern (every other token row) does not depend on mask_prob.
reconstruct_video(model, [np.asarray(frame) for frame in video], mask_prob=0.99, mask_method='horizontal',
                  save_path=None, reconstruction_mode=True)
# reconstruct_video(model, video, mask_prob=0, mask_method='random',save_path=None)

In [None]:
trained_models_paths_all = glob.glob(
    '/home/ubuntu/efs/trained_models/lsfb_isol_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/*.pth'
)
# Checkpoint files are named '.../checkpoint-<number>.pth'; extract <number> as an int.
trained_models_checkpoint_numbers = [
    int(ckpt_path.split('checkpoint-')[-1].split('.pth')[0])
    for ckpt_path in trained_models_paths_all
]
df_model_ckpt = pd.DataFrame(
    {'paths': trained_models_paths_all, 'number': trained_models_checkpoint_numbers}
)


In [None]:
df_model_ckpt = df_model_ckpt.sort_values('number').reset_index(drop=True)
n_ckpt = len(df_model_ckpt)
# Pick a subset of checkpoints: every 50th of the first 300, every 5th after that,
# and always the latest one. Indices are clipped to the available rows and
# de-duplicated (the last row may already be in the every-5th run), so shorter
# runs no longer raise IndexError or load the same checkpoint twice.
selected_rows = set(range(0, min(300, n_ckpt), 50)) | set(range(300, n_ckpt, 5))
if n_ckpt:
    selected_rows.add(n_ckpt - 1)
inds = sorted(selected_rows)
lsfb_paths = df_model_ckpt['paths'].iloc[inds].to_numpy()

In [None]:
# load videos, models, and create examples
models_folder = '/home/ubuntu/efs/videoMAE/pretrained'
model_checkpoint_paths_pretraind = ['VideoMAE_ViT-S_checkpoint_ssv2.pth','VideoMAE_ViT-S_checkpoint_Kinetics-400.pth']
model_checkpoint_paths_pretraind = [osp.join(models_folder,p) for p in model_checkpoint_paths_pretraind]

save_folder = '/home/ubuntu/efs/videoMAE/generated_videos'

all_model_paths = list(lsfb_paths) + model_checkpoint_paths_pretraind




In [None]:
# Short model names (checkpoint file stems), also used as the per-model output folder names.
all_model_names = list(map(get_name_from_path, all_model_paths))
all_model_names

In [None]:
# 
# # load all videos
# videos_paths = all_videos_paths[1:5]
# all_videos_loaded = {}
# for video_path in videos_paths:
#     video_name =  get_name_from_path(video_path)
#     all_videos_loaded[video_name] = load_video(video_path)
#     
# 
# for model_path in all_model_paths:
#     model = load_model(model_path)
#     model_name = get_name_from_path(osp.split(model_path)[-1].split('.')[0])
#     for video_name, video in all_videos_loaded.items():
#         video_name += '_reconstructed.gif'        
#         save_path = osp.join(save_folder, model_name, video_name)
#         try:
#             reconstruct_video(model, video, mask_prob=0.5, mask_method='horizontal',save_path=save_path, reconstruction_mode=True)
#         except:
#             print(f'Skipping file {save_path}')


In [None]:
import os
import imageio
import matplotlib.pyplot as plt
from itertools import product

def compare_gifs_in_grid_old(models_folder, gif_names, all_model_names=None):
    """Show the first frame of each GIF in a (model x gif) grid of still images.

    Args:
        models_folder: Folder containing one sub-folder per model.
        gif_names: GIF file names looked up inside every model sub-folder (one column each).
        all_model_names: Sub-folder names to compare (one row each). Defaults to
            every entry of ``models_folder``; names without an existing folder are dropped.

    Raises:
        ValueError: If no model folder or no GIF name is left to plot.
    """
    if all_model_names is None:
        models_list = os.listdir(models_folder)
    else:
        models_list = all_model_names

    models_list = [a for a in models_list if osp.exists(osp.join(models_folder, a))]

    num_models = len(models_list)
    num_gifs = len(gif_names)
    if num_models == 0 or num_gifs == 0:
        raise ValueError(f'nothing to plot: {num_models} model folder(s) found in {models_folder}, '
                         f'{num_gifs} gif name(s) given')

    # squeeze=False keeps `axes` 2-D even for a single row or column, so axes[i, j] always works.
    fig, axes = plt.subplots(num_models, num_gifs, figsize=(num_gifs * 5, num_models * 5), squeeze=False)

    for i, model_folder in enumerate(models_list):
        for j, gif_name in enumerate(gif_names):
            ax = axes[i, j]
            ax.axis('off')
            gif_path = os.path.join(models_folder, model_folder, gif_name)
            try:
                gif = imageio.mimread(gif_path)
            except Exception as err:  # best effort: an unreadable GIF leaves an empty cell
                print(f'Skipping {gif_path}: {err}')
                continue
            ax.imshow(gif[0])  # first frame only
            ax.set_title(f"{model_folder}")

    plt.tight_layout()
    plt.show()

# Example usage:
# Example usage: compare the first frame of two reconstructed clips across all models.
models_folder = save_folder = '/home/ubuntu/efs/videoMAE/generated_videos'
gif_names = [
    'CLSFBI2314A_S048_B_110920_111100_reconstructed.gif',
    'CLSFBI3306A_S067_B_68737_68857_reconstructed.gif',
]
compare_gifs_in_grid_old(models_folder=models_folder, gif_names=gif_names, all_model_names=all_model_names)


In [None]:
# Shape of the first segment: (channels, n_frames, height, width).
pixel_values[0].shape


In [None]:
import os
import imageio
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
from itertools import product

def compare_gifs_in_grid(models_folder, gif_names, all_model_names=None, save_path=None, display_flag=False):
    """Animate a (model x gif) grid of GIFs in sync; optionally display and/or save it.

    Args:
        models_folder: Folder containing one sub-folder per model.
        gif_names: GIF file names looked up inside every model sub-folder (one column each).
        all_model_names: Sub-folder names to compare (one row each). Defaults to
            every entry of ``models_folder``; names without an existing folder are dropped.
        save_path: If given, the animation is written there with the Pillow writer at 30 fps.
        display_flag: If True, render the animation inline as an HTML/JS player.

    Returns:
        The ``FuncAnimation`` object.

    Raises:
        ValueError: If nothing can be plotted (no model folders, no gif names, or
            none of the GIFs could be read).
    """
    import warnings

    if save_path is None and not display_flag:
        # A bare `Warning(...)` only constructs an exception object; warn for real.
        warnings.warn('This function will neither save nor display the images')

    if all_model_names is None:
        models_list = os.listdir(models_folder)
    else:
        models_list = all_model_names

    models_list = [a for a in models_list if os.path.exists(os.path.join(models_folder, a))]

    num_models = len(models_list)
    num_gifs = len(gif_names)
    if num_models == 0 or num_gifs == 0:
        raise ValueError(f'nothing to plot: {num_models} model folder(s) found in {models_folder}, '
                         f'{num_gifs} gif name(s) given')

    # Decode every GIF once up front instead of re-reading it from disk on every animation frame.
    gifs = {}
    for i, model_folder in enumerate(models_list):
        for j, gif_name in enumerate(gif_names):
            gif_path = os.path.join(models_folder, model_folder, gif_name)
            try:
                frames = imageio.mimread(gif_path)
            except Exception as err:  # best effort: an unreadable GIF leaves an empty cell
                print(f'Skipping {gif_path}: {err}')
                continue
            if len(frames) > 0:
                gifs[i, j] = frames
    if not gifs:
        raise ValueError(f'none of the requested GIFs could be read from {models_folder}')

    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(num_models, num_gifs, figsize=(num_gifs * 5, num_models * 5), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')

    # One image artist per cell. The animation only swaps its pixel data rather than
    # stacking a new image onto the axes every frame.
    images = {}
    for (i, j), frames in gifs.items():
        images[i, j] = axes[i, j].imshow(frames[0])
        axes[i, j].set_title(f"{models_list[i]}")
    plt.tight_layout()

    # Run as long as the longest GIF; shorter GIFs hold their last frame.
    max_frames = max(len(frames) for frames in gifs.values())

    def update(frame):
        for key, image in images.items():
            frames = gifs[key]
            image.set_data(frames[min(frame, len(frames) - 1)])
        return list(images.values())

    ani = FuncAnimation(fig, update, frames=max_frames, repeat=True)

    if display_flag:
        # Display the animation in the Jupyter Notebook
        display(HTML(ani.to_jshtml()))
    if save_path is not None:
        print(f'saving animation at {save_path}')
        # 'pillow' is matplotlib's registered writer name ('Pillow' is not).
        ani.save(save_path, writer='pillow', fps=30)  # Adjust the fps as needed
        print('Done')

    # Close the figure so the notebook does not also show a static copy of it.
    plt.close(fig)
    return ani



In [None]:
# Checkpoints to compare. Each one's reconstructed GIFs live in a sub-folder of the
# generated-videos folder named after the checkpoint file stem (see get_name_from_path).
selected_model_paths = [
    '/home/ubuntu/efs/trained_models/lsfb_isol_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e1600/checkpoint-1569.pth',
    '/home/ubuntu/efs/videoMAE/pretrained/VideoMAE_ViT-S_checkpoint_Kinetics-400.pth',
    '/home/ubuntu/efs/trained_models/Kinetics-400_finetune_lsfb_isol_videomae_pretrain_small_patch16_224_frame_16x4_tube_mask_ratio_0.9_e400/checkpoint-400.pth'
]
# compare_gifs_in_grid expects sub-folder *names*. Joining an absolute checkpoint path
# onto models_folder just yields the .pth path itself, so no GIF would be found.
all_model_names = [get_name_from_path(ckpt_path) for ckpt_path in selected_model_paths]
gif_names = ['CLSFBI2314A_S048_B_110920_111100_reconstructed.gif', 'CLSFBI3306A_S067_B_68737_68857_reconstructed.gif']
compare_gifs_in_grid(models_folder='/home/ubuntu/efs/videoMAE/generated_videos', gif_names=gif_names,
                     all_model_names=all_model_names, save_path=None, display_flag=True)

In [None]:
# Show the model (sub-folder) names used by the comparison grids.
all_model_names

In [None]:
# Example usage:
# Example usage: render the comparison grid straight to Results.gif next to the generated videos.
models_folder = save_folder = '/home/ubuntu/efs/videoMAE/generated_videos'
save_path = osp.join(models_folder, 'Results.gif')
gif_names = [
    'CLSFBI2314A_S048_B_110920_111100_reconstructed.gif',
    'CLSFBI3306A_S067_B_68737_68857_reconstructed.gif',
]
ani = compare_gifs_in_grid(models_folder, gif_names, all_model_names=all_model_names,
                           save_path=save_path, display_flag=False)
