Check whether updating timm to the latest version (0.4.5 -> 0.9.12), which requires modifications to the code (qk_scale -> qk_norm in `models_mae_loss.py`), makes a difference.

In [1]:
import torch
from test_loss import get_dataloader
from util.datasets import standardize
from models_mae_loss import mae_vit_base_patch16


In [2]:
def evaluate_base(dataloader):
    """Return the MSE of a constant predictor that always outputs the global mean.

    All batches yielded by ``dataloader`` are concatenated along dim 0, the
    mean over every element is computed, and the mean squared deviation from
    that mean is returned. This is the baseline MSE used as the denominator
    when computing R2.

    Args:
        dataloader: iterable yielding tensors that can be concatenated along
            dim 0 (any feature width, not only 2048).

    Returns:
        float: mean squared deviation from the global mean, or ``nan`` when
        the dataloader yields no batches.
    """
    with torch.no_grad():
        # Collect first and concatenate once: calling torch.cat inside the
        # loop re-copies everything accumulated so far on every iteration.
        batches = list(dataloader)
        if not batches:
            return float('nan')
        data = torch.cat(batches, 0)
        # Promote to at least float32 so integer / half-precision inputs can
        # be averaged (float64 inputs stay float64).
        data = data.to(torch.promote_types(data.dtype, torch.float32))

        mean = data.mean()
        loss = ((data - mean) ** 2).mean()
    return loss.item()

def inverse_standardize(raws, pred_un):
    """Undo a per-row standardization using statistics taken from ``raws``.

    The mean and standard deviation of ``raws`` are computed along dim 1
    (kept as size-1 dims so they broadcast) and applied to ``pred_un``.

    Args:
        raws: tensor whose dim-1 statistics define the scale and offset.
        pred_un: tensor in standardized space to map back.

    Returns:
        ``pred_un`` rescaled by the row std and shifted by the row mean.
    """
    center = torch.mean(raws, dim=1, keepdim=True)
    scale = torch.std(raws, dim=1, keepdim=True)
    rescaled = pred_un * scale
    return rescaled + center

def inverse_log(raws, pred_un):
    """Map predictions back through ``exp(x) - 1``.

    Presumably this inverts a ``log(x + 1)`` transform applied to the data;
    confirm against the transform actually used.

    Args:
        raws: unused; accepted only so the signature matches the other
            inverse functions.
        pred_un: tensor of predictions in the transformed space.

    Returns:
        Tensor of ``exp(pred_un) - 1`` values.
    """
    # expm1 evaluates exp(x) - 1 without the cancellation that makes the
    # literal expression collapse to 0 for small |x| in float32.
    return torch.expm1(pred_un)

def get_mse(true_values, pred_values):
    """Mean squared error between two tensors, returned as a Python float.

    Args:
        true_values: tensor of reference values.
        pred_values: tensor of predictions (broadcastable to ``true_values``).

    Returns:
        float: mean of the element-wise squared differences.
    """
    residual = true_values - pred_values
    return torch.mean(residual.pow(2)).item()

def evaluate_inraw(model, dataloader_raw, dataloader_norm, inverse=None, device='cuda'):
    """Evaluate ``model`` on normalized inputs and report MSE against raw data.

    The two dataloaders are iterated in lockstep; each normalized batch is fed
    to the model, the prediction is unpatchified, optionally mapped back to
    raw space with ``inverse``, and compared with the matching raw batch.

    The error is averaged over all elements rather than over batches, so a
    smaller final batch is not over-weighted and the result is directly
    comparable with the element-wise variance from ``evaluate_base``.

    Args:
        model: module whose call returns a 3-tuple with the prediction as the
            second item, and which provides ``unpatchify``.
        dataloader_raw: iterable of raw-space batches.
        dataloader_norm: iterable of normalized batches, in the same order as
            ``dataloader_raw`` (NOTE: both loaders must not shuffle
            independently, otherwise batches will not correspond).
        inverse: optional callable ``inverse(raws, pred_un)`` mapping the
            prediction back to raw space; skipped when falsy.
        device: device the batches are moved to before the forward pass.

    Returns:
        float: mean squared error over every evaluated element, or ``nan``
        when no batches were evaluated.
    """
    sq_err_sum = 0.
    n_values = 0
    model.eval()  # turn on evaluation mode

    with torch.no_grad():
        for (raws, norms) in zip(dataloader_raw, dataloader_norm):
            raws = raws.to(device, non_blocking=True, dtype=torch.float)
            norms = norms.to(device, non_blocking=True, dtype=torch.float)

            _, pred, _ = model(norms)    # in normalized space
            pred_un = model.unpatchify(pred)
            if inverse:
                pred_un = inverse(raws, pred_un)    # in raw space

            # Accumulate the raw sum and the element count instead of
            # per-batch means, so unequal batch sizes are weighted correctly.
            sq_err = (raws - pred_un) ** 2
            sq_err_sum += sq_err.sum().item()
            n_values += sq_err.numel()

    if n_values == 0:
        return float('nan')
    return sq_err_sum / n_values

In [3]:
# the base model: MSE of always predicting the global mean of the raw
# validation data (denominator of R2 below)
dataloader_raw = get_dataloader(batch_size=64, transform=None)
mse_base = evaluate_base(dataloader_raw['val'])

# the model with standardization
dataloader_std = get_dataloader(batch_size=64, transform=standardize)
model = mae_vit_base_patch16().to('cuda')
# map_location='cpu' makes loading independent of the device the checkpoint
# was saved from; load_state_dict then copies the weights into the model's
# existing (CUDA) parameters.
# NOTE: torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(
    torch.load('models/mae_vit_base_patch16_update_1e-05_20231214.pth', map_location='cpu')
)
# predictions are mapped back to raw space before the error is computed, so
# mse_std is comparable with mse_base
mse_std = evaluate_inraw(model, dataloader_raw['val'], dataloader_std['val'], inverse=inverse_standardize)

print(f'MSE: {round(mse_std, 1)}')
print(f'R2: {round(1 - mse_std / mse_base, 2)}')

MSE: 839728.0
R2: 0.96


The result is comparable to that of the optimal model trained with the old version of timm, and is actually even slightly better. The difference is most likely due to the randomness of the training process, so I will use the latest version of timm to train the model.