## Masked Autoencoders: Visualization Demo

This is a visualization demo using our pre-trained MAE models. No GPU is needed.

### Prepare
Check environment. Install packages if in Colab.


In [None]:
# Placeholder cell; presumably just confirms the kernel executes.
message = 'meow'
print(message)

In [None]:
import sys
import os
import requests
import random

import torch
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image

# check whether run in Colab
# Make the MAE sources importable. Colab is detected by its `google.colab`
# module having been imported. In that case the `!` IPython shell escapes pin
# timm and clone the upstream MAE repo, which is then put on the import path.
# Otherwise the parent directory is added to the path, presumably because this
# notebook lives one level below the repo root (confirm).
# NOTE(review): `uncertainty_mae` (and `dataset_generation`, imported later) are
# presumably project-specific and not part of the upstream facebookresearch/mae
# clone, so the Colab branch alone may not satisfy those imports. Confirm.
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    !pip3 install timm==0.4.5  # 0.3.2 does not work in Colab
    !git clone https://github.com/facebookresearch/mae.git
    sys.path.append('./mae')
else:
    sys.path.append('..')
import models_mae
from uncertainty_mae import UncertaintyMAE

### Define utils

In [None]:
# Shared helpers for the demo.

# Identity statistics (handy when images are not normalised):
# imagenet_mean = np.array([0, 0, 0])
# imagenet_std = np.array([1, 1, 1])

# ImageNet per-channel statistics on the 0-255 pixel scale, so that
# `image * std + mean` maps a normalised image back to displayable pixels.
imagenet_mean = np.array([0.485, 0.456, 0.406]) * 255
imagenet_std = np.array([0.229, 0.224, 0.225]) * 255

def show_image(image, title='', mean=imagenet_mean, std=imagenet_std):
    """Plot one normalised image on the current matplotlib axes.

    Args:
        image: `[H, W, 3]` tensor or array of normalised values. It is
            converted to a CPU tensor first, so CUDA tensors and numpy arrays
            are both accepted.
        title: axes title.
        mean, std: per-channel statistics on the 0-255 scale. The image is
            un-normalised as `image * std + mean` and clipped to [0, 255].
    """
    # Move to the CPU so CUDA inputs can be combined with the (CPU) statistics.
    image = torch.as_tensor(image).detach().cpu()
    assert image.shape[2] == 3
    pixels = image * torch.as_tensor(std) + torch.as_tensor(mean)
    plt.imshow(torch.clip(pixels, 0, 255).int())
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Build a VAE-enabled MAE of type `arch` and load weights from a checkpoint.

    Args:
        chkpt_dir: path to a checkpoint file (despite the name). It holds either
            a raw state dict or a dict with the state dict under `'model'`.
        arch: name of a model constructor in `models_mae`.

    Returns:
        The model, loaded non-strictly. The load report and the model's `vae`
        flag are printed.
    """
    build_fn = models_mae.__dict__[arch]
    model = build_fn(norm_pix_loss=False, quantile=None, vae=True, kld_beta=1)

    state = torch.load(chkpt_dir, map_location='cpu')
    # Training checkpoints nest the weights under 'model'; bare state dicts do not.
    state = state['model'] if 'model' in state else state
    load_report = model.load_state_dict(state, strict=False)
    print(load_report)
    print('is vae:', model.vae)
    return model

def prepare_uncertainty_model(chkpt_dir, arch='mae_vit_base_patch16', same_encoder=True):
    """Build an `UncertaintyMAE` and load its weights from a checkpoint file.

    The wrapper combines two models:
      * an "invisible" MAE built with `vae=True` and `num_vae_blocks=1`;
      * optionally, a plain "visible" MAE built with `vae=False`.
    The visible model is only constructed when `same_encoder` is false.
    Otherwise `UncertaintyMAE` receives `visible_mae=None`.

    Args:
        chkpt_dir: path to a checkpoint file (despite the name). It holds either
            a raw state dict or a dict with the state dict under `'model'`.
        arch: name of a model constructor in `models_mae`.
        same_encoder: forwarded to `UncertaintyMAE`.

    Returns:
        The model with weights loaded. The load report is printed.

    Raises:
        RuntimeError: re-raised from strict loading when the checkpoint already
            contains any of the `invisible_mae.*_zero_conv_*` parameters. In
            that case their absence cannot explain the failure.
    """
    build = models_mae.__dict__[arch]
    # A separate visible-patch MAE is only needed when the encoder is not shared.
    visible_model = None
    if not same_encoder:
        visible_model = build(norm_pix_loss=False, quantile=None, vae=False, kld_beta=0)
    invisible_model = build(norm_pix_loss=False, quantile=None,
                            vae=True, kld_beta=0, num_vae_blocks=1)
    model = UncertaintyMAE(visible_mae=visible_model, invisible_mae=invisible_model,
                           same_encoder=same_encoder)

    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']

    zero_conv_keys = (
        'invisible_mae.logVar_zero_conv_weight',
        'invisible_mae.logVar_zero_conv_bias',
        'invisible_mae.mean_zero_conv_weight',
        'invisible_mae.mean_zero_conv_bias',
    )
    try:
        msg = model.load_state_dict(checkpoint, strict=True)
    except RuntimeError as err:
        print(err)
        # The fallback below only handles checkpoints that lack the zero-conv
        # parameters (presumably saved before they were added). Any other
        # mismatch is a genuine error, so surface it unchanged.
        if any(key in checkpoint for key in zero_conv_keys):
            raise
        msg = model.load_state_dict(checkpoint, strict=False)

        # Re-create the missing parameters as ones (weights) and zeros (biases).
        # The biases have one element, matching the weights; a zero-element
        # tensor would not broadcast.
        invisible_mae = model.invisible_mae
        invisible_mae.logVar_zero_conv_weight = torch.nn.Parameter(torch.ones(1))
        invisible_mae.logVar_zero_conv_bias = torch.nn.Parameter(torch.zeros(1))
        invisible_mae.mean_zero_conv_weight = torch.nn.Parameter(torch.ones(1))
        invisible_mae.mean_zero_conv_bias = torch.nn.Parameter(torch.zeros(1))

    print(msg)

    return model

def run_one_image(img, model, mask_ratio=0.75, force_mask=None, mean=imagenet_mean, std=imagenet_std):
    """Reconstruct one image with `model` and plot four panels side by side.

    The panels are:
      1. the original image;
      2. the masked input;
      3. the raw reconstruction;
      4. the reconstruction with the visible patches pasted back in.

    Args:
        img: one image (tensor or array) laid out channels-first. A batch
            dimension is added here, and outputs are converted with
            `'nchw->nhwc'`.
        model: a `models_mae` model or an `UncertaintyMAE` wrapper. For the
            wrapper, `model.visible_mae` supplies `unpatchify` and the patch size.
        mask_ratio: forwarded to the model's forward call.
        force_mask: forwarded to the model's forward call (e.g. explicit
            keep/mask index tensors).
        mean, std: per-channel statistics forwarded to `show_image` to
            un-normalise every panel.
    """
    # as_tensor avoids the warning torch.tensor() emits when copy-constructing
    # from an existing tensor. unsqueeze makes it batch-like.
    x = torch.as_tensor(img).unsqueeze(dim=0)
    #x = torch.einsum('nhwc->nchw', x)

    # Inference only: no autograd graph is needed for visualisation.
    with torch.no_grad():
        loss, y, mask = model(x.float(), mask_ratio=mask_ratio, force_mask=force_mask)

    # For UncertaintyMAE, the patch helpers live on the inner visible MAE.
    mae = model.visible_mae if isinstance(model, UncertaintyMAE) else model

    y = mae.unpatchify(y)
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # Expand the per-patch mask to pixel resolution (1 is removing, 0 is keeping).
    patch_size = mae.patch_embed.patch_size[0]
    mask = mask.detach().unsqueeze(-1).repeat(1, 1, patch_size ** 2 * 3)  # (N, H*W, p*p*3)
    mask = mae.unpatchify(mask)
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', x).detach().cpu()

    print(x.mean())

    # Masked image, and MAE reconstruction pasted with the visible patches.
    im_masked = x * (1 - mask)
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    panels = [
        (x[0], "original"),
        (im_masked[0], "masked"),
        (y[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]
    for position, (panel, title) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(panel, title, mean=mean, std=std)

    plt.show()

### Load the test dataset

In [None]:
from dataset_generation.emoji_dataset import EmojiDataset
from torchvision import datasets, transforms
# Deterministic evaluation transform for RGB datasets: bicubic resize to the
# ViT input size, then ImageNet normalisation (no augmentation).
transform_test = transforms.Compose([
    transforms.Resize((224, 224), interpolation=3),  # 3 is bicubic
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# CelebA transform. NOTE: it applies random crop/flip augmentation, even though
# the commented-out CelebA loader below pairs it with the test split.
transform_celeba = transforms.Compose([
    transforms.RandomResizedCrop((224, 224), scale=(0.6, 1.0), interpolation=3),  # 3 is bicubic
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# transform_train = transforms.Compose([
#         transforms.RandomResizedCrop((224, 224), scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#     ])

# EMNIST normalisation statistics. The value is identical on all three channels
# because the grayscale characters are replicated to 3 channels below.
emnist_mean = np.array([0.176] * 3)
emnist_std = np.array([0.328] * 3)


def _emnist_rotate(img):
    """Rotate by -90 degrees (clockwise)."""
    return transforms.functional.rotate(img, -90)


def _emnist_hflip(img):
    """Mirror the image left-to-right."""
    return transforms.functional.hflip(img)


# Rotating by -90 degrees and then mirroring transposes the image. This is
# presumably meant to undo EMNIST's transposed storage (confirm).
emnist_transform = transforms.Compose([
    _emnist_rotate,
    _emnist_hflip,
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224), interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize(emnist_mean, emnist_std),
])

# Alternative evaluation datasets (swap in together with a matching transform):
# dataset2 = datasets.CIFAR100('../data', train=False, download=True,
#                        transform=transform_test)
# dataset2 = datasets.CelebA('/local/zemel/gzg2104/datasets', split='test', target_type='attr', 
#                            transform=transform_celeba, download=True)
dataset2 = datasets.EMNIST('../data', split='balanced', train=False, download=True,
                           transform=emnist_transform)
test_kwargs = {'batch_size': 1}
test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

In [None]:
# Preview the first 12 test images, each in its own figure, to sanity-check the transform.
plt.rcParams['figure.figsize'] = [5, 5]
for idx, (img, label) in enumerate(test_loader):
    print(idx)
    print(img.shape)
    assert img.shape == (1, 3, 224, 224)
    img = img.squeeze()
    print(img.mean())
    # The loader normalises with the EMNIST statistics, so undo those rather
    # than show_image's ImageNet defaults. This matches the main demo loop.
    show_image(torch.einsum('chw->hwc', img), mean=255 * emnist_mean, std=255 * emnist_std)
    # Flush each image to its own figure. Otherwise every imshow call draws
    # onto the same axes and only the last image is visible.
    plt.show()
    if idx == 11:
        break

In [None]:
# from datasets import load_dataset
# from tqdm import tqdm

# sketch_mean = [0.857, 0.857, 0.857]
# sketch_std = [0.254, 0.254, 0.254]

# transform_test = transforms.Compose([
#         transforms.RandomResizedCrop((224, 224), scale=(0.2, 1.0), interpolation=3),  # 3 is bicubic
#         transforms.RandomHorizontalFlip(),
#         transforms.Grayscale(num_output_channels=3),
#         # transforms.Resize((224, 224), interpolation=3),
#         transforms.ToTensor(),
#         transforms.Normalize(sketch_mean, sketch_std)
#     ])

# def transform_wrapper(examples):
#     examples["image"] = [transform_test(image) for image in examples["image"]]
#     return examples

# dataset = load_dataset("imagenet_sketch", split='train', 
#                        cache_dir='/local/zemel/gzg2104/datasets')

# dataset.set_transform(transform_wrapper)

# dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

# # for idx, item in tqdm(enumerate(dataloader)):
# #     print('mean', torch.mean(item['image'], dim=(0, 2, 3)))
# #     print('std', torch.std(item['image'], dim=(0, 2, 3)))
# #     break

# for idx, item in enumerate(dataloader):
#     print(item['image'].shape)
#     print(item['label'].shape)

#     print('mean', torch.mean(item['image'], dim=(0, 2, 3)))
#     print('std', torch.std(item['image'], dim=(0, 2, 3)))

#     img = item['image'].squeeze()

#     show_image(torch.einsum('chw->hwc', img), 
#                mean = 255 * np.array(sketch_mean), 
#                std = 255 * np.array(sketch_std))

#     if idx == 5:
#         break

### Load a pre-trained MAE model

In [None]:
# Thanks ChatGPT!

def load_decoder_state_dict(model, chkpt_dir):
    """Load only the MAE decoder weights from a checkpoint into `model`.

    Only keys starting with one of the decoder prefixes (`decoder_embed`,
    `mask_token`, `decoder_pos_embed`, `decoder_blocks`, `decoder_norm`,
    `decoder_pred`) are kept. They are loaded with `strict=False`, so every
    other parameter of `model` is left untouched.

    Args:
        model: the `torch.nn.Module` to update in place.
        chkpt_dir: path to a checkpoint file (despite the name). It holds either
            a raw state dict or a dict with the state dict under `'model'`.

    Returns:
        The result of `model.load_state_dict` (missing / unexpected keys).
    """
    # Map to the CPU so GPU-saved checkpoints also load on CPU-only machines.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    decoder_prefixes = (
        'decoder_embed',
        'mask_token',
        'decoder_pos_embed',
        'decoder_blocks',
        'decoder_norm',
        'decoder_pred',
    )
    filtered_state_dict = {
        key: value for key, value in state_dict.items() if key.startswith(decoder_prefixes)
    }

    # strict=False: non-decoder parameters of `model` are expected to be missing.
    result = model.load_state_dict(filtered_state_dict, strict=False)

    print('loaded decoder')
    return result

In [None]:
# Load the Uncertainty-MAE checkpoint used by the visualisations below.
# NOTE(review): this comment previously described the upstream pixel-target
# visualisation MAE (ViT-Large, training mask ratio=0.75). That is not the model
# loaded here: `chkpt_dir` points at a local EMNIST run, built as
# 'mae_vit_base_patch16' with a shared encoder.

# download checkpoint if not exist
# NOTE(review): this fetches the upstream ViT-Base visualisation weights, but the
# downloaded file is not referenced afterwards. It is presumably a leftover from
# the upstream demo.
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_base.pth

# Alternative checkpoints from earlier CIFAR / CelebA / EMNIST runs. Uncomment one
# to switch; the matching dataset and transform presumably need to change too.
# chkpt_dir = '/local/zemel/gzg2104/_cifar_models/fromScratch_06_21_24_zeroConv_eps_1e-4/checkpoint-600.pth'
# chkpt_dir = '/local/zemel/gzg2104/_cifar_models/06_12_24_batchSize_384/checkpoint-799.pth'
# chkpt_dir = '/local/zemel/gzg2104/_cifar_models/REDO_06_12_24_batchSize_384/checkpoint-700.pth'
# chkpt_dir = '/local/zemel/gzg2104/_celeba_models/initialTry_06_20_24/checkpoint-200.pth'
# chkpt_dir = '/local/zemel/gzg2104/_emnist_models/06_21_24_noZeroConv/checkpoint-40.pth'
# Active checkpoint: EMNIST run with a common (shared) encoder, hence same_encoder=True.
chkpt_dir = '/local/zemel/gzg2104/_emnist_models/06_24_24/common_encoder/beta5_blr1e-4_eps1e-8/warmup20_total400/checkpoint-80.pth'
model_mae = prepare_uncertainty_model(chkpt_dir, 'mae_vit_base_patch16', same_encoder=True)
print('Model loaded.')

### Run MAE on the test images

In [None]:
num_test_images = len(test_loader.dataset)  # size of the evaluation split
print(num_test_images)

In [None]:
def randomize_mask_layout(mask_layout, mask_ratio=0.75):
    """Zero a random `mask_ratio` fraction of the cells of a 2-D mask grid, in place.

    Cells set to 0 are the masked ones; all other cells keep their value. The
    number of zeroed cells is `int(mask_ratio * rows * cols)`, clamped to
    `[0, rows * cols]`.

    Sampling uses `torch.randperm` on the default CPU generator. A preceding
    `torch.manual_seed(...)` therefore makes the layout reproducible, which the
    stdlib `random` module would not be.

    Args:
        mask_layout: 2-D tensor or array indexable as `mask_layout[i, j]`.
            It is modified in place.
        mask_ratio: fraction of cells to zero.

    Returns:
        None. `mask_layout` is mutated.
    """
    num_rows, num_cols = mask_layout.shape[0], mask_layout.shape[1]
    num_cells = num_rows * num_cols
    num_masked = max(0, min(int(mask_ratio * num_cells), num_cells))
    # Sample on the CPU regardless of the layout's device, then write each cell.
    for flat_idx in torch.randperm(num_cells)[:num_masked].tolist():
        row, col = divmod(flat_idx, num_cols)
        mask_layout[row, col] = 0
    return


In [None]:
# Run on the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_mae = model_mae.to(device)
model_mae.eval()
random_mask = True

print(model_mae)
# Visualisation only: no autograd graph is needed for any forward pass below.
with torch.no_grad():
    for idx, img_tuple in enumerate(test_loader):
        print(idx)
        plt.rcParams['figure.figsize'] = [5, 5]
        img, label = img_tuple
        print(label)
        assert img.shape == (1, 3, 224, 224)
        img = img.to(device)
        img = img.squeeze()
        #show_image(torch.einsum('chw->hwc', img))

        # Per-image seed so the sampled mask and any model sampling are reproducible.
        torch.manual_seed(idx)
        print('MAE with pixel reconstruction:')
        # 14 x 14 patch grid (presumably 224 / 16 for the patch16 model). 1 = keep, 0 = mask.
        mask_layout = torch.ones(14, 14, device=img.device)
        if random_mask:
            randomize_mask_layout(mask_layout, mask_ratio=0.9)
        else:
            # mask_layout[0:7, 0:14] = 0
            # mask_layout[7:14, 7:14] = 0
            mask_layout[0:14, 0:14] = 0

        mask_layout = mask_layout.flatten()
        keep_indices = torch.where(mask_layout == 1)[0].reshape(1, -1)
        mask_indices = torch.where(mask_layout == 0)[0].reshape(1, -1)
        ids_shuffle = torch.cat((keep_indices, mask_indices), dim=1)
        print('run regular')
        mask_ratio = 1 - keep_indices.shape[1] / ids_shuffle.shape[1]
        # Three generations per image with the same mask.
        for j in range(3):
            print(f'Generate {j}')
            if isinstance(model_mae, UncertaintyMAE):
                run_one_image(img, model_mae, mask_ratio=mask_ratio, force_mask=(keep_indices, mask_indices),
                              mean=255*emnist_mean, std=255*emnist_std)
            # else:
            #     run_one_image(img, model_mae, mask_ratio=mask_ratio, force_mask=ids_shuffle)
        if idx == 20:
            break