# Masked Autoencoders As Spatiotemporal Learners
This is a visualization demo showcasing a pre-trained MAE-3D model.

### Installation


In [None]:
# Clone the repository and install dependencies
!git clone https://github.com/cyrilzakka/MAE3D
%cd MAE3D
!pip install -r requirements.txt

### Visualization

In [None]:
import torch
import numpy as np

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

import models_mae3d


# Per-channel (3-element) offset and scale that show_sequence applies as
# `img * IMAGENET_STD + IMAGENET_MEAN` to undo input normalization before
# scaling frames to 0-255 for display.
# NOTE(review): the values look like the usual ImageNet RGB statistics;
# presumably the same normalization was used when preparing the clips - confirm.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def show_sequence(seq, title=''):
    '''Plot a sequence of frames side by side in a single row.

    Args:
        seq: frames laid out as [T, H, W, 3] (channels last). Converted with
            torch.as_tensor, detached and moved to CPU before plotting.
            Values are de-normalized with IMAGENET_STD / IMAGENET_MEAN,
            scaled by 255 and clipped to [0, 255] for display.
        title: text used as the figure title.

    Returns:
        None. The figure is displayed with plt.show(). Nothing is drawn for
        an empty sequence.
    '''
    frames = torch.as_tensor(seq).detach().cpu()
    # clip is [T, H, W, 3]
    assert frames.shape[-1] == 3, 'expected channels-last frames, i.e. [T, H, W, 3]'

    n_frames = frames.shape[0]
    if n_frames == 0:
        # ImageGrid cannot be built with zero columns; nothing to show anyway.
        return

    # Use explicit tensors for the statistics rather than relying on implicit
    # tensor * ndarray interop; the float64 constants keep the arithmetic in
    # float64 as before.
    std = torch.as_tensor(IMAGENET_STD)
    mean = torch.as_tensor(IMAGENET_MEAN)
    pixels = torch.clip((frames * std + mean) * 255, 0, 255).int()

    fig = plt.figure(figsize=(20, 2.5))
    fig.suptitle(title, fontsize=16)
    # One column per frame, so no frame is dropped and no empty axes are left
    # over when the clip length differs from 8.
    grid = ImageGrid(fig, 111, nrows_ncols=(1, n_frames), axes_pad=0.1)
    for ax, img in zip(grid, pixels):
        ax.imshow(img.numpy())
        ax.set_axis_off()
    plt.show()

def prepare_model(chkpt_dir, arch='mae3d_vit_large_patch16'):
    '''Build a model from models_mae3d and load checkpoint weights into it.

    Args:
        chkpt_dir: location of the checkpoint, passed straight to torch.load.
            The loaded object must have a 'model' entry holding the state dict.
        arch: name of the constructor to look up on the models_mae3d module.

    Returns:
        The constructed model with the checkpoint weights loaded (on CPU).
    '''
    # Look up the constructor by name and instantiate it with its defaults.
    build_fn = getattr(models_mae3d, arch)
    net = build_fn()

    # NOTE(review): torch.load may unpickle arbitrary objects; only point this
    # at checkpoints from a trusted source.
    state = torch.load(chkpt_dir, map_location='cpu')

    # strict=False tolerates key mismatches; the printed report lists them.
    load_report = net.load_state_dict(state['model'], strict=False)
    print(load_report)
    return net

def run_one_image(seq, model):
    '''Run the model on one video sequence and plot the results.

    Four rows are displayed via show_sequence: the original clip, the masked
    clip, the raw reconstruction, and the reconstruction pasted together with
    the visible patches.

    Args:
        seq: one clip laid out as [T, H, W, 3]; anything torch.as_tensor
            accepts. It is not moved to another device before the forward
            pass, so it must already live where the model does.
        model: callable as model(x, mask_ratio=...) returning
            (loss, prediction, mask), and providing unpatchify3D and
            patch_embed.patch_size.

    Returns:
        None.
    '''
    x = torch.as_tensor(seq)

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nthwc->ntchw', x)

    # run MAE; this is inference only, so skip building the autograd graph
    with torch.no_grad():
        _, y, mask = model(x.float(), mask_ratio=0.90)
        y = model.unpatchify3D(y)
        y = torch.einsum('ntchw->nthwc', y).detach().cpu()

        # visualize the mask
        mask = mask.detach()
        mask = mask.unsqueeze(-1).repeat(1, 1, model.patch_embed.patch_size[1]**2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify3D(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('ntchw->nthwc', mask).detach().cpu()

    # back to channels-last, on the CPU like y and mask
    x = torch.einsum('ntchw->nthwc', x).detach().cpu()

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # Default size for figures created without an explicit figsize. Kept so
    # later plotting behaves as before; show_sequence sets its own figsize.
    plt.rcParams['figure.figsize'] = [20, 10]

    # show_sequence opens and shows its own figure, so no plt.subplot calls
    # are made here: they only produced extra empty figures.
    show_sequence(x[0], "original")
    show_sequence(im_masked[0], "masked")
    show_sequence(y[0], "reconstruction")
    show_sequence(im_paste[0], "reconstruction + visible")