In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed the global RNG and build a standalone generator `g` with its own seed.
torch.manual_seed(1337)
g = torch.Generator().manual_seed(0)

# Allow TF32 for matmul / cuDNN and request deterministic algorithm
# implementations from PyTorch.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.use_deterministic_algorithms(True)

print(f"Using device: {device}", flush=True)

# CIFAR-10 test split: convert each image to a tensor, then normalise it
# channel-wise with the mean / std given below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2470, 0.2435, 0.2616),
    ),
])
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

In [None]:
# Utilities: per-channel statistics used by show_image to undo the
# normalisation before plotting.
cifar_mean = np.asarray((0.4914, 0.4822, 0.4465))
cifar_std = np.asarray((0.2470, 0.2435, 0.2616))

def show_image(image, title=''):
    """Draw one normalised RGB image on the current matplotlib axes.

    Args:
        image: channels-last image of shape [H, W, 3]. May be a torch tensor on
            any device or anything ``torch.as_tensor`` accepts. Its values are
            de-normalised with ``cifar_std`` / ``cifar_mean`` before display.
        title: text shown above the image.

    Raises:
        AssertionError: if ``image`` is not a 3-D array with 3 trailing channels.
    """
    # Detach and bring the data to the CPU first: the statistics live on the
    # host, and matplotlib cannot read CUDA tensors or tensors requiring grad.
    image = torch.as_tensor(image).detach().cpu()
    assert image.ndim == 3 and image.shape[2] == 3, (
        f"show_image expects an [H, W, 3] image, got shape {tuple(image.shape)}"
    )

    mean = torch.as_tensor(cifar_mean)
    std = torch.as_tensor(cifar_std)
    # Undo the normalisation, rescale to 0-255 and clamp before the integer cast.
    pixels = torch.clip((image * std + mean) * 255, 0, 255).int()

    plt.imshow(pixels.numpy())
    plt.title(title, fontsize=16)
    plt.axis('off')

def prepare_model(checkpoint):
    """Load the object stored in ``checkpoint`` with its tensors mapped to the CPU.

    Args:
        checkpoint: path (or other input accepted by ``torch.load``) of the
            saved checkpoint.

    Returns:
        The object that was saved in the checkpoint.
    """
    # SECURITY: weights_only=False goes through the full pickle machinery,
    # which can execute arbitrary code. Only load checkpoints you trust.
    # The flag is passed explicitly because PyTorch 2.6 switched the default to
    # weights_only=True, which refuses checkpoints that store a whole model
    # object instead of a plain state_dict.
    try:
        mae = torch.load(checkpoint, map_location='cpu', weights_only=False)
    except TypeError:
        # Older PyTorch releases do not know the weights_only keyword; their
        # default behaviour is the same full unpickling.
        mae = torch.load(checkpoint, map_location='cpu')
    return mae

def _as_nchw_batch(x):
    """Return ``x`` as a float32 [N, 3, H, W] tensor, whatever its input layout."""
    x = torch.as_tensor(x)
    if x.ndim == 3:
        # make it a batch-like
        x = x.unsqueeze(dim=0)
    if x.ndim != 4:
        raise ValueError(
            f"expected a 3-D image or a 4-D batch, got shape {tuple(x.shape)}"
        )
    if x.shape[-1] == 3:
        # Channels-last input, the layout this function originally required.
        x = torch.einsum('nhwc->nchw', x)
    elif x.shape[1] != 3:
        raise ValueError(
            f"could not find a 3-channel axis in input of shape {tuple(x.shape)}"
        )
    return x.float()


def _model_device(model):
    """Return the device of the model's first parameter, or None if unavailable."""
    parameters = getattr(model, 'parameters', None)
    if not callable(parameters):
        return None
    first = next(iter(parameters()), None)
    return None if first is None else first.device


def run_one_image(x, model, mask_ratio=0.75):
    """Run the model on one image and plot original, masked and reconstruction.

    Args:
        x: image as [H, W, 3] (the original contract) or [3, H, W], optionally
            with a leading batch axis. Only the first image of a batch is shown.
        model: callable invoked as ``model(batch, mask_ratio=...)`` that returns
            ``(loss, prediction, mask)`` and exposes ``unpatchify`` and
            ``patch_embed.patch_size``.
        mask_ratio: value forwarded to the model; defaults to 0.75.

    Raises:
        ValueError: if ``x`` is not 3-D/4-D or has no 3-channel axis.

    Side effects:
        Sets ``plt.rcParams['figure.figsize']`` to ``[24, 24]`` and shows a
        four-panel figure.
    """
    x = _as_nchw_batch(x)

    # Run the forward pass where the model's parameters live, so an input that
    # was moved to a different device does not cause a device mismatch.
    target = _model_device(model)
    if target is not None:
        x = x.to(target)

    patch_dim = model.patch_embed.patch_size[0] ** 2 * 3

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        # run MAE
        _, y, mask = model(x, mask_ratio=mask_ratio)
        y = model.unpatchify(y)

        # visualize the mask
        mask = mask.unsqueeze(-1).repeat(1, 1, patch_dim)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping

    # Bring everything back to channels-last CPU tensors for plotting.
    x = torch.einsum('nchw->nhwc', x).detach().cpu()
    y = torch.einsum('nchw->nhwc', y).detach().cpu()
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 24]

    panels = [
        (x, "original"),
        (im_masked, "masked"),
        (y, "reconstruction"),
        (im_paste, "reconstruction + visible"),
    ]
    for position, (batch, title) in enumerate(panels, start=1):
        plt.subplot(1, 4, position)
        show_image(batch[0], title)

    plt.show()

In [None]:
# load an image
# ToTensor() yields a channels-first [3, H, W] tensor, while show_image and
# run_one_image take a single channels-last [H, W, 3] image, so reorder the
# axes here. The image stays on the CPU: prepare_model loads the model with
# map_location='cpu' and matplotlib needs host memory.
img_1, _ = test_dataset[0]  # first test image (label unused)
img_1 = img_1.permute(1, 2, 0)

plt.rcParams['figure.figsize'] = [5, 5]
show_image(img_1)

In [None]:
# Load the MAE model object from the checkpoint file through prepare_model.
checkpoint = 'checkpoint.pth'
model_mae = prepare_model(checkpoint=checkpoint)
print('Model loaded.')

In [None]:
# make random mask reproducible (comment out to make it change)
# Re-seeding here keeps the cell idempotent when it is re-run on its own;
# this assumes the model draws its mask from torch's global RNG.
torch.manual_seed(1337)
print('MAE with pixel reconstruction:')
run_one_image(img_1, model_mae)

In [None]:
print('MAE with pixel reconstruction:')
img_2, _ = test_dataset[1]  # second test image (label unused)
# ToTensor() yields [3, H, W]; run_one_image takes one channels-last
# [H, W, 3] image and adds the batch axis itself. Keep it on the CPU, where
# prepare_model loads the model.
img_2 = img_2.permute(1, 2, 0)
run_one_image(img_2, model_mae)

In [None]:
print('MAE with pixel reconstruction:')
img_3, _ = test_dataset[2]  # third test image (label unused)
# ToTensor() yields [3, H, W]; run_one_image takes one channels-last
# [H, W, 3] image and adds the batch axis itself. Keep it on the CPU, where
# prepare_model loads the model.
img_3 = img_3.permute(1, 2, 0)
run_one_image(img_3, model_mae)

In [None]:
print('MAE with pixel reconstruction:')
img_4, _ = test_dataset[3]  # fourth test image (label unused)
# ToTensor() yields [3, H, W]; run_one_image takes one channels-last
# [H, W, 3] image and adds the batch axis itself. Keep it on the CPU, where
# prepare_model loads the model.
img_4 = img_4.permute(1, 2, 0)
run_one_image(img_4, model_mae)