## Masked Autoencoders: Visualization Demo

This is a visualization demo using our pre-trained MAE models. No GPU is needed.

### Prepare
Check environment. Install packages if in Colab.


In [1]:
import sys
import os
import requests

import torch
import torchvision
import numpy as np

import matplotlib.pyplot as plt
from PIL import Image,ImageFilter
import torchvision.transforms as T

# check whether run in Colab
if 'google.colab' in sys.modules:
    print('Running in Colab.')
    # install the pinned timm version and fetch the MAE sources
    !pip3 install timm==0.4.5  # 0.3.2 does not work in Colab
    !git clone https://github.com/facebookresearch/mae.git
    # make the freshly cloned repository importable
    sys.path.append('./mae')
else:
    # assumes the MAE sources live in the parent directory of this notebook -- TODO confirm
    sys.path.append('..')
# presumably provided by the MAE repository located through the sys.path entry added above
import models_mae

In [2]:
# Report the interpreter version (`sys` is already imported in the first cell).
print(sys.version)

3.9.13 (main, Aug 25 2022, 18:29:29) 
[Clang 12.0.0 ]


### Define utils

In [3]:
# define the utils

# Per-channel mean and std (three values each). load_image uses them to normalize
# images and show_image uses them to undo that normalization for display.
# NOTE(review): these look like the usual ImageNet RGB statistics -- confirm.
imagenet_mean = np.array([0.485, 0.456, 0.406])
imagenet_std = np.array([0.229, 0.224, 0.225])

def show_image(image, title=''):
    """Draw a normalized image on the current matplotlib axes.

    The normalization is undone with ``imagenet_std`` / ``imagenet_mean``; the
    result is scaled by 255, clipped to [0, 255] and cast to int before it is
    handed to ``plt.imshow``. The axes are hidden.

    Args:
        image: [H, W, 3] data accepted by ``torch.clip`` after the arithmetic
            above (presumably a float tensor -- confirm against callers).
        title: text shown above the image.
    """
    # image is [H, W, 3]
    assert image.shape[2] == 3
    denormalized = (image * imagenet_std + imagenet_mean) * 255
    pixels = torch.clip(denormalized, 0, 255).int()
    plt.imshow(pixels)
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Build an MAE model and load pre-trained weights into it.

    Args:
        chkpt_dir: path of the checkpoint file (despite the name it is passed
            straight to ``torch.load``); the weights must be stored under the
            ``'model'`` key.
        arch: name of the model constructor looked up on ``models_mae``.

    Returns:
        The model with the checkpoint weights loaded, switched to eval mode.

    Raises:
        AttributeError: if ``models_mae`` has no attribute named ``arch``.
        KeyError: if the checkpoint has no ``'model'`` entry.
    """
    # build model
    model = getattr(models_mae, arch)()
    # load model
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    # strict=False tolerates key mismatches; the printed message reports them
    msg = model.load_state_dict(checkpoint['model'], strict=False)
    print(msg)
    # the notebook only runs inference, so leave training mode
    model.eval()
    return model

def get_original_image_in_tensor(img,model):
    """Return ``img`` as a torch tensor in its original [H, W, C] layout.

    The previous implementation added a batch axis, permuted to NCHW, permuted
    straight back to NHWC and dropped the batch axis again, which is a no-op on
    the values. The conversion is now done directly.

    Args:
        img: array-like image with three dimensions ([H, W, C]).
        model: unused; kept so existing callers keep working.

    Returns:
        A new tensor holding a copy of ``img`` (``torch.tensor`` copies).

    Raises:
        ValueError: if ``img`` does not have exactly three dimensions. (The old
            einsum round trip rejected such input with a RuntimeError.)
    """
    x = torch.tensor(img)
    if x.ndim != 3:
        raise ValueError(
            f"expected an image with 3 dimensions [H, W, C], got shape {tuple(x.shape)}"
        )
    return x
    

def run_one_image(img, model, mask_ratio=0.75):
    """Mask an image, reconstruct it with an MAE model and plot the result.

    Four panels are drawn side by side: the input, the masked input, the raw
    reconstruction, and the reconstruction pasted together with the visible
    patches.

    Args:
        img: normalized image array of shape [H, W, 3].
        model: MAE model; it is called as ``model(x, mask_ratio=...)`` and must
            return ``(loss, prediction, mask)``. It must also provide
            ``unpatchify`` and ``patch_embed.patch_size``.
        mask_ratio: fraction of patches to mask, forwarded to the model.
            Defaults to 0.75, the value that used to be hard-coded.

    Returns:
        Tuple ``(original, pasted)`` of [H, W, 3] tensors: the input image and
        the reconstruction pasted with the visible patches.
    """
    x = torch.tensor(img)

    # make it a batch-like
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)

    # run MAE; this is inference only, so skip building the autograd graph
    with torch.no_grad():
        _loss, y, mask = model(x.float(), mask_ratio=mask_ratio)
        y = model.unpatchify(y)
        y = torch.einsum('nchw->nhwc', y).detach().cpu()

        # visualize the mask: expand each per-patch flag to every pixel of the patch
        values_per_patch = model.patch_embed.patch_size[0]**2 * 3
        mask = mask.detach()
        mask = mask.unsqueeze(-1).repeat(1, 1, values_per_patch)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    x = torch.einsum('nchw->nhwc', x)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    panels = [
        (x[0], "original"),
        (im_masked[0], "masked"),
        (y[0], "reconstruction"),
        (im_paste[0], "reconstruction + visible"),
    ]
    for position, (panel, title) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(panel, title)

    plt.show()
    return x[0], im_paste[0]

### Load an image

In [4]:
# load an image
def load_image(image_name, apply_filter=False,filter_size=3):
    """Load an image file, resize it to 224x224 and normalize it.

    Args:
        image_name: path of the image file to open.
        apply_filter: when True, also produce a blurred version of the image
            through ``apply_median_filter`` (blurred before resizing).
        filter_size: forwarded to ``apply_median_filter`` as ``window_size``.

    Returns:
        ``img`` when ``apply_filter`` is False, ``(img, filtered)`` when it is
        True. Each array is the resized image divided by 255 and normalized
        with ``imagenet_mean`` / ``imagenet_std``.

        If the decoded array is not [224, 224, 3], the array(s) are returned
        divided by 255 but NOT normalized, and nothing is plotted; callers are
        expected to check ``.shape`` and skip such images.

    Notes:
        The old code ended in ``return img,im`` unconditionally, which raised
        UnboundLocalError whenever ``apply_filter`` was False. It also returned
        a single array from the filter branch for non-RGB input, which broke
        tuple unpacking in callers.
    """
    expected_shape = (224, 224, 3)

    def _to_unit_array(pil_image):
        # resize to the model input size and scale pixel values by 1/255
        return np.array(pil_image.resize((224, 224))) / 255.

    def _normalize(array):
        # normalize by ImageNet mean and std
        return (array - imagenet_mean) / imagenet_std

    im = None
    # the context manager closes the underlying file once the pixels are read
    with Image.open(image_name) as source:
        if apply_filter:
            im = _to_unit_array(apply_median_filter(source, window_size=filter_size))
        img = _to_unit_array(source)

    wrong_shape = img.shape != expected_shape or (
        apply_filter and im.shape != expected_shape
    )
    if wrong_shape:
        return (img, im) if apply_filter else img

    img = _normalize(img)
    if apply_filter:
        im = _normalize(im)

    plt.rcParams['figure.figsize'] = [5, 5]
    show_image(torch.tensor(img))
    return (img, im) if apply_filter else img

def apply_median_filter(image,window_size=3):
    """Return the result of blurring ``image``.

    NOTE(review): despite the name, this applies ``ImageFilter.GaussianBlur``
    with ``radius=window_size``; it is not a median filter.

    Args:
        image: object exposing ``filter`` (a PIL image in this notebook).
        window_size: blur radius handed to ``GaussianBlur``.
    """
    blur = ImageFilter.GaussianBlur(radius=window_size)
    return image.filter(blur)


### Load a pre-trained MAE model

In [5]:
# This is an MAE model trained with pixels as targets for visualization (ViT-Large, training mask ratio=0.75)

# download the checkpoint if it does not exist yet (-nc skips files that are already there)
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth

# chkpt_dir holds the checkpoint *file* name; prepare_model passes it to torch.load
chkpt_dir = 'mae_visualize_vit_large.pth'
model_mae = prepare_model(chkpt_dir, 'mae_vit_large_patch16')
print('Model loaded.')


File 'mae_visualize_vit_large.pth' already there; not retrieving.

> [0;32m/var/folders/yq/h_8tpsgn7mg08_j54v167d_r0000gn/T/ipykernel_19289/238310871.py[0m(19)[0;36mprepare_model[0;34m()[0m
[0;32m     17 [0;31m    [0;31m# load model[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     18 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m[0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 19 [0;31m    [0mcheckpoint[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mload[0m[0;34m([0m[0mchkpt_dir[0m[0;34m,[0m [0mmap_location[0m[0;34m=[0m[0;34m'cpu'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     20 [0;31m    [0mmsg[0m [0;34m=[0m [0mmodel[0m[0;34m.[0m[0mload_state_dict[0m[0;34m([0m[0mcheckpoint[0m[0;34m[[0m[0;34m'model'[0m[0;34m][0m[0;34m,[0m [0mstrict[0m[0;34m=[0m[0;32mFalse[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     21 [0;31m    [0mprint[0m[0;34m([0m[0mmsg[0m[0;

### Run MAE on the image

In [None]:
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
print('MAE with pixel reconstruction:')
# NOTE(review): no visible cell assigns `img`; on a fresh kernel this raises NameError.
# Assign it first (e.g. from load_image(<path>)) -- confirm the intended image.
run_one_image(img, model_mae)

### Load another pre-trained MAE model

In [None]:
# This is an MAE model trained with an extra GAN loss for more realistic generation (ViT-Large, training mask ratio=0.75)

# download checkpoint if not exist
!wget -nc https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large_ganloss.pth

chkpt_dir = 'mae_visualize_vit_large_ganloss.pth'
model_mae_gan = prepare_model('mae_visualize_vit_large_ganloss.pth', 'mae_vit_large_patch16')
print('Model loaded.')

### Run MAE on the image

In [None]:
# make random mask reproducible (comment out to make it change)
torch.manual_seed(2)
print('MAE with extra GAN loss:')
# NOTE(review): no visible cell assigns `img`; on a fresh kernel this raises NameError.
# Assign it first (e.g. from load_image(<path>)) -- confirm the intended image.
run_one_image(img, model_mae_gan)

# Processing high- and low-frequency images from ImageNet and COCO

In [None]:
# NOTE(review): `hf_rmse` is not assigned by any visible cell; presumably
# imagenet_hf_rmse or coco_hf_rmse was meant -- confirm.
hf_rmse

In [None]:
# Scratch check: `original1` is the flattened pixel tensor left over from the last
# iteration of one of the RMSE loop cells, so this depends on execution order.
original1.ndim

In [None]:
# RMSE between `output` and `original` as they were left by the last iteration of an
# RMSE loop cell, computed directly on the tensors (no conversion to 0-255 pixels).

# element-wise squared error
squared_diff = torch.square(output - original)

# average over every element
mean_squared_diff = squared_diff.mean()

# root of the mean squared error, displayed as a plain Python float
rmse = mean_squared_diff.sqrt()
float(rmse)

In [None]:
# Scratch check: first element of `original1` (left over from an RMSE loop cell).
original1.flatten()[0]

In [None]:
# RMSE (in 0-255 pixel units) between every ImageNetWorst image and its MAE
# reconstruction pasted with the visible patches.
# NOTE(review): requires load_image to return a single array when apply_filter is False.
imagenet_lf_rmse = []
low_frequency_images = os.listdir("ImageNetWorst/")
for image_name in low_frequency_images:
    # separate names for the file name and the loaded array (the loop variable
    # used to be overwritten by the array)
    print(f"ImageNetWorst/{image_name}")
    image = load_image(f"ImageNetWorst/{image_name}")
    if image.shape != (224, 224, 3):
        # skip anything that is not a 224x224 RGB array
        continue
    original, output = run_one_image(image, model_mae)
    # undo the normalization and compare as clipped integer pixel values
    output1 = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    original1 = torch.clip((original * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    squared_diff = (output1 - original1) ** 2
    # Calculate the mean of the squared differences (* 1.0 converts the int tensor to float)
    mean_squared_diff = torch.mean(squared_diff* 1.0)

    # Calculate the RMSE value
    rmse = float(torch.sqrt(mean_squared_diff))
    imagenet_lf_rmse.append(rmse)

In [None]:
# RMSE (in 0-255 pixel units) between every original ImageNetBest image and the MAE
# reconstruction of its blurred version (blur radius 4).
imagenet_hf_rmse = []
high_frequency_images = os.listdir("ImageNetBest/")
for image_name in high_frequency_images:
    print(f"ImageNetBest/{image_name}")
    try:
        image, filtered = load_image(f"ImageNetBest/{image_name}", True, 4)
    except Exception as err:
        # best effort: skip files that cannot be loaded, but say why; unlike a bare
        # `except:` this does not swallow KeyboardInterrupt
        print(f"  skipped: {err!r}")
        continue
    if image.shape != (224, 224, 3):
        continue
    # only the reconstruction of the blurred input is needed here; the reference
    # is the unblurred image
    output = run_one_image(filtered, model_mae)[1]
    original = get_original_image_in_tensor(image,model_mae)
    # undo the normalization and compare as clipped integer pixel values
    output1 = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    original1 = torch.clip((original * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    squared_diff = (output1 - original1) ** 2
    # Calculate the mean of the squared differences (* 1.0 converts the int tensor to float)
    mean_squared_diff = torch.mean(squared_diff* 1.0)

    # Calculate the RMSE value
    rmse = float(torch.sqrt(mean_squared_diff))
    imagenet_hf_rmse.append(rmse)

In [None]:
# RMSE (in 0-255 pixel units) between every original ImageNetBest image and the MAE
# reconstruction of its blurred version, using load_image's default filter_size.
# NOTE(review): this rebinds imagenet_hf_rmse, replacing the list built with blur radius 4.
imagenet_hf_rmse = []
high_frequency_images = os.listdir("ImageNetBest/")
for image_name in high_frequency_images:
    print(f"ImageNetBest/{image_name}")
    try:
        # the leftover pdb.set_trace() that stopped every iteration was removed
        image, filtered = load_image(f"ImageNetBest/{image_name}", True)
    except Exception as err:
        # best effort: skip files that cannot be loaded, but say why; unlike a bare
        # `except:` this does not swallow KeyboardInterrupt
        print(f"  skipped: {err!r}")
        continue
    if image.shape != (224, 224, 3):
        continue
    # only the reconstruction of the blurred input is needed here; the reference
    # is the unblurred image
    output = run_one_image(filtered, model_mae)[1]
    original = get_original_image_in_tensor(image,model_mae)
    # undo the normalization and compare as clipped integer pixel values
    output1 = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    original1 = torch.clip((original * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    squared_diff = (output1 - original1) ** 2
    # Calculate the mean of the squared differences (* 1.0 converts the int tensor to float)
    mean_squared_diff = torch.mean(squared_diff* 1.0)

    # Calculate the RMSE value
    rmse = float(torch.sqrt(mean_squared_diff))
    imagenet_hf_rmse.append(rmse)

In [None]:
# RMSE (in 0-255 pixel units) between every original CocoBest image and the MAE
# reconstruction of its blurred version (blur radius 4).
coco_hf_rmse = []
high_frequency_images = os.listdir("CocoBest/")
for image_name in high_frequency_images:
    print(f"CocoBest/{image_name}")
    try:
        image, filtered = load_image(f"CocoBest/{image_name}", True, 4)
    except Exception as err:
        # best effort: skip files that cannot be loaded, but say why; unlike a bare
        # `except:` this does not swallow KeyboardInterrupt
        print(f"  skipped: {err!r}")
        continue
    if image.shape != (224, 224, 3):
        continue
    # only the reconstruction of the blurred input is needed here; the reference
    # is the unblurred image
    output = run_one_image(filtered, model_mae)[1]
    original = get_original_image_in_tensor(image,model_mae)
    # undo the normalization and compare as clipped integer pixel values
    output1 = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    original1 = torch.clip((original * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    squared_diff = (output1 - original1) ** 2
    # Calculate the mean of the squared differences (* 1.0 converts the int tensor to float)
    mean_squared_diff = torch.mean(squared_diff* 1.0)

    # Calculate the RMSE value
    rmse = float(torch.sqrt(mean_squared_diff))
    coco_hf_rmse.append(rmse)

In [20]:
# Per-image RMSE values collected for the CocoBest images.
coco_hf_rmse

[39.72161102294922,
 41.466243743896484,
 49.20736312866211,
 34.625274658203125,
 34.31605529785156,
 42.78290939331055,
 32.17943572998047,
 38.8576545715332,
 38.84165573120117,
 31.60117530822754,
 37.012290954589844,
 43.367919921875,
 50.04998016357422,
 44.34191131591797,
 41.67820358276367,
 41.114418029785156,
 42.621883392333984,
 41.17630386352539,
 33.495262145996094,
 34.17367172241211,
 44.73710250854492,
 48.44806671142578,
 33.37726974487305,
 48.968231201171875,
 46.04912185668945,
 60.63768768310547,
 34.939998626708984,
 33.500667572021484,
 35.64263153076172,
 37.16334533691406,
 52.74317169189453,
 37.7349967956543,
 36.653839111328125,
 37.263710021972656,
 45.838558197021484,
 36.1551513671875,
 43.28950119018555,
 46.8079833984375,
 32.292118072509766,
 44.36004638671875,
 38.27259826660156,
 46.55219268798828,
 44.02777099609375,
 46.89068603515625,
 36.84990692138672,
 36.53355026245117,
 41.98127365112305,
 37.630489349365234,
 43.445945739746094,
 41.4486198

In [None]:
# RMSE (in 0-255 pixel units) between every CocoWorst image and its MAE
# reconstruction pasted with the visible patches.
# NOTE(review): requires load_image to return a single array when apply_filter is False.
coco_lf_rmse = []
# these are the low-frequency images; the list used to be called high_frequency_images
low_frequency_images = os.listdir("CocoWorst/")
for image_name in low_frequency_images:
    print(f"CocoWorst/{image_name}")
    image = load_image(f"CocoWorst/{image_name}")
    if image.shape != (224, 224, 3):
        # same guard as the ImageNetWorst loop: skip anything that is not a
        # 224x224 RGB array instead of failing inside run_one_image
        continue
    original, output = run_one_image(image, model_mae)
    # undo the normalization and compare as clipped integer pixel values
    output1 = torch.clip((output * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    original1 = torch.clip((original * imagenet_std + imagenet_mean) * 255, 0, 255).int().flatten()
    squared_diff = (output1 - original1) ** 2
    # Calculate the mean of the squared differences (* 1.0 converts the int tensor to float)
    mean_squared_diff = torch.mean(squared_diff* 1.0)

    # Calculate the RMSE value
    rmse = float(torch.sqrt(mean_squared_diff))
    coco_lf_rmse.append(rmse)

In [None]:
# Number of CocoWorst images that received an RMSE score.
len(coco_lf_rmse)

In [None]:
# Mean RMSE over the CocoWorst images.
np.mean(coco_lf_rmse)

In [None]:
# Mean RMSE over the CocoBest images.
np.mean(coco_hf_rmse)

In [10]:
# Mean RMSE over the ImageNetBest images.
# NOTE(review): imagenet_hf_rmse is assigned by two loop cells (blur radius 4 and the
# default filter_size); which run this value reflects depends on execution order.
np.mean(imagenet_hf_rmse)

39.29011320141439

In [None]:
# NOTE(review): same expression as the cell just above; presumably re-evaluated after
# imagenet_hf_rmse was recomputed -- confirm, or drop one of the two cells.
np.mean(imagenet_hf_rmse)

In [None]:
# Mean RMSE over the ImageNetWorst images.
np.mean(imagenet_lf_rmse)

In [None]:
# Per-image RMSE values collected for the CocoWorst images.
coco_lf_rmse

In [None]:
# Per-image RMSE values collected for the CocoBest images.
coco_hf_rmse

In [11]:
# Per-image RMSE values collected for the ImageNetBest images.
# NOTE(review): two loop cells assign imagenet_hf_rmse (blur radius 4 and the default
# filter_size); which run this output belongs to depends on execution order.
imagenet_hf_rmse

[38.00617218017578,
 44.72846221923828,
 47.58067321777344,
 31.17874526977539,
 52.757381439208984,
 33.3055534362793,
 40.90459060668945,
 33.5127067565918,
 25.105764389038086,
 39.203948974609375,
 43.28631591796875,
 37.98569869995117,
 40.017723083496094,
 39.04658508300781,
 38.80532455444336,
 50.775177001953125,
 37.70596694946289,
 36.92179870605469,
 32.62538528442383,
 47.406253814697266,
 34.771240234375,
 38.84030532836914,
 37.69122314453125,
 37.46937561035156,
 36.242733001708984,
 32.4649772644043,
 46.479522705078125,
 38.48707962036133,
 33.972259521484375,
 33.89411544799805,
 33.723323822021484,
 56.39878845214844,
 30.673311233520508,
 35.6621208190918,
 46.14383316040039,
 39.3895149230957,
 35.711395263671875,
 45.52936935424805,
 48.2234992980957,
 30.73318099975586,
 32.512264251708984,
 37.612369537353516,
 41.90407180786133,
 38.00260925292969,
 44.22209167480469,
 43.98411560058594,
 52.90557098388672,
 36.197471618652344,
 37.87428665161133,
 43.722206115

In [14]:
# Per-image RMSE values for the ImageNetBest images, displayed again.
# NOTE(review): the recorded values differ from the earlier display of the same name,
# so the list was presumably recomputed in between -- confirm which run this is.
imagenet_hf_rmse

[36.67558670043945,
 45.23104476928711,
 48.557640075683594,
 32.59537124633789,
 51.95182800292969,
 32.01007080078125,
 42.541748046875,
 33.944061279296875,
 26.233562469482422,
 40.38662338256836,
 43.71321105957031,
 39.997154235839844,
 43.79865646362305,
 38.77368927001953,
 37.98261260986328,
 51.60393524169922,
 38.42317199707031,
 38.07696533203125,
 33.22703552246094,
 46.6870002746582,
 37.189640045166016,
 39.236873626708984,
 38.029296875,
 36.33623504638672,
 36.82571029663086,
 34.44379806518555,
 45.041378021240234,
 38.40098190307617,
 37.82810974121094,
 36.02855682373047,
 36.9271125793457,
 55.975284576416016,
 31.395111083984375,
 36.05349349975586,
 49.40426254272461,
 43.091453552246094,
 35.16608810424805,
 43.902347564697266,
 51.12691879272461,
 31.813955307006836,
 33.319053649902344,
 37.47641372680664,
 41.97075271606445,
 40.3996696472168,
 46.55083465576172,
 46.281959533691406,
 53.51401138305664,
 36.19816589355469,
 37.94409942626953,
 42.573596954345

In [None]:
# Number of ImageNetWorst images that received an RMSE score.
len(imagenet_lf_rmse)

In [None]:
# Keep the current ImageNetBest RMSE values under a separate name. list(...) makes a
# copy, so later in-place changes to imagenet_hf_rmse cannot alter this snapshot.
# NOTE(review): the name says "median filter", but apply_median_filter applies a
# Gaussian blur.
imagenet_3pt_median_filter = list(imagenet_hf_rmse)

In [None]:
# Scratch check of name binding: x and y start out bound to the same list, and
# rebinding y afterwards leaves x pointing at the original list.
x = y = [1, 1, 3]
y = [1, 2]

In [18]:
# Per-image RMSE values collected for the CocoBest images.
# NOTE(review): the recorded values differ from the other display of this name, so the
# list was presumably recomputed between the two runs -- confirm.
coco_hf_rmse

[40.01418685913086,
 40.29827880859375,
 53.954341888427734,
 34.639686584472656,
 32.52113723754883,
 45.94272232055664,
 29.128734588623047,
 38.93730545043945,
 40.75054931640625,
 31.59794044494629,
 35.709007263183594,
 44.123748779296875,
 46.83586120605469,
 43.69535827636719,
 40.53071975708008,
 41.57402801513672,
 40.61527633666992,
 42.07217025756836,
 34.30278396606445,
 34.05022430419922,
 44.54413604736328,
 44.46985626220703,
 30.66412925720215,
 49.282222747802734,
 42.437522888183594,
 57.29024124145508,
 33.632259368896484,
 33.17685317993164,
 36.3046760559082,
 37.867271423339844,
 49.42519760131836,
 36.00853729248047,
 37.1890983581543,
 36.53473663330078,
 47.59408187866211,
 36.302223205566406,
 43.15013885498047,
 42.762001037597656,
 30.401147842407227,
 45.54239273071289,
 36.25474166870117,
 46.61229705810547,
 41.95122528076172,
 46.80423355102539,
 34.146812438964844,
 38.709312438964844,
 39.633548736572266,
 38.435543060302734,
 39.09422302246094,
 39.96