In [14]:
import os
import sys
from typing import List, Tuple, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import json
import yaml
from pathlib import Path
import rasterio
import random
from glob import glob
import math

from util.datasets import SentinelIndividualImageDataset, SentinelNormalize
from util.misc import load_model
from models_cae import cae_vit_base_patch16
from models_mae import mae_vit_base_patch16
import util.misc as misc

device = torch.cuda.is_available() and 'cuda' or 'cpu'
device

'cuda'

In [49]:
def test_cae(model: torch.nn.Module,
         data_loader,
         device: torch.device,
         args=None):
    """
    Run model evaluation on test dataset.
    Args:
        model: The model to evaluate
        data_loader: DataLoader containing test data
        device: Device to run testing on
        args: Arguments containing mask_ratio
    Returns:
        dict: Dictionary containing average test losses
    """
    model.eval()
    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'

    # Collect all losses for final statistics
    all_losses = []
    all_losses_main = []
    all_losses_align = []

    with torch.no_grad():
        for samples, *_ in metric_logger.log_every(data_loader, 20, header):
            samples = samples.to(device, non_blocking=True)
            
            with torch.cuda.amp.autocast():
                loss, loss_main, loss_align, *_ = model(samples, mask_ratio=0.75)
            
            loss_value = loss.item()
            loss_main_value = loss_main.item()
            loss_align_value = loss_align.item()
            
            
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping testing".format(loss_value))
                raise ValueError(f"Loss is {loss_value}, stopping testing")
           
            # Collect losses
            all_losses.append(loss_value)
            all_losses_main.append(loss_main_value)
            all_losses_align.append(loss_align_value)
            
            # Update metrics
            metric_logger.update(loss=loss_value)
            metric_logger.update(loss_main=loss_main_value)
            metric_logger.update(loss_align=loss_align_value)

    # Calculate final statistics
    avg_loss = sum(all_losses) / len(all_losses)
    avg_loss_main = sum(all_losses_main) / len(all_losses_main)
    avg_loss_align = sum(all_losses_align) / len(all_losses_align)

    # Print results
    print('=' * 80)
    print(f'Test Results:')
    print(f'Average Loss: {avg_loss:.4f}')
    print(f'Average Main Loss: {avg_loss_main:.4f}')
    print(f'Average Align Loss: {avg_loss_align:.4f}')
    print('=' * 80)

    return {
        'test_loss': avg_loss,
        'test_loss_main': avg_loss_main,
        'test_loss_align': avg_loss_align
    }


def test_mae(model: torch.nn.Module,
         data_loader,
         device: torch.device,
         args=None):
    """
    Run model evaluation on test dataset.
    Args:
        model: The model to evaluate
        data_loader: DataLoader containing test data
        device: Device to run testing on
        args: Arguments containing mask_ratio
    Returns:
        dict: Dictionary containing average test losses
    """
    model.eval()
    metric_logger = misc.MetricLogger(delimiter="  ")
    header = 'Test:'

    # Collect all losses for final statistics
    all_losses = []

    with torch.no_grad():
        for samples, *_ in metric_logger.log_every(data_loader, 20, header):
            samples = samples.to(device, non_blocking=True)
            
            with torch.cuda.amp.autocast():
                loss, *_ = model(samples, mask_ratio=0.75)
            
            loss_value = loss.item()
            
            if not math.isfinite(loss_value):
                print("Loss is {}, stopping testing".format(loss_value))
                raise ValueError(f"Loss is {loss_value}, stopping testing")
           
            # Collect losses
            all_losses.append(loss_value)
            
            # Update metrics
            metric_logger.update(loss=loss_value)

    # Calculate final statistics
    avg_loss = sum(all_losses) / len(all_losses)

    # Print results
    print('=' * 80)
    print(f'Test Results:')
    print(f'Average Loss: {avg_loss:.4f}')
    print('=' * 80)

    return {
        'test_loss': avg_loss,
    }

In [42]:
csv_path = "/home/ubuntu/satellite-cae/SatMAE/data/test_.csv"

mean = SentinelIndividualImageDataset.mean
std = SentinelIndividualImageDataset.std
transform = SentinelIndividualImageDataset.build_transform(
    is_train=False, input_size=224, mean=mean, std=std
)
dataset = SentinelIndividualImageDataset(
    csv_path=csv_path,
    transform=transform
)

batch_size = 64
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=True, num_workers=8)

# Test CAE

In [45]:
cae = cae_vit_base_patch16(in_chans=13)
cae.to(device)

checkpoint_path = "/home/ubuntu/checkpoint-98.pth"
# checkpoint_path = "/home/ubuntu/satellite-cae/SatMAE/output_dir/checkpoint-0.pth"
checkpoint = torch.load(checkpoint_path, map_location='cpu')
cae.load_state_dict(checkpoint['model'], strict=False)

  checkpoint = torch.load(checkpoint_path, map_location='cpu')


_IncompatibleKeys(missing_keys=['teacher.0.norm1.weight', 'teacher.0.norm1.bias', 'teacher.0.attn.qkv.weight', 'teacher.0.attn.qkv.bias', 'teacher.0.attn.proj.weight', 'teacher.0.attn.proj.bias', 'teacher.0.norm2.weight', 'teacher.0.norm2.bias', 'teacher.0.mlp.fc1.weight', 'teacher.0.mlp.fc1.bias', 'teacher.0.mlp.fc2.weight', 'teacher.0.mlp.fc2.bias', 'teacher.1.norm1.weight', 'teacher.1.norm1.bias', 'teacher.1.attn.qkv.weight', 'teacher.1.attn.qkv.bias', 'teacher.1.attn.proj.weight', 'teacher.1.attn.proj.bias', 'teacher.1.norm2.weight', 'teacher.1.norm2.bias', 'teacher.1.mlp.fc1.weight', 'teacher.1.mlp.fc1.bias', 'teacher.1.mlp.fc2.weight', 'teacher.1.mlp.fc2.bias', 'teacher.2.norm1.weight', 'teacher.2.norm1.bias', 'teacher.2.attn.qkv.weight', 'teacher.2.attn.qkv.bias', 'teacher.2.attn.proj.weight', 'teacher.2.attn.proj.bias', 'teacher.2.norm2.weight', 'teacher.2.norm2.bias', 'teacher.2.mlp.fc1.weight', 'teacher.2.mlp.fc1.bias', 'teacher.2.mlp.fc2.weight', 'teacher.2.mlp.fc2.bias', 't

In [46]:
test(cae, dataloader, device)

  with torch.cuda.amp.autocast():


Test:  [   0/1327]  eta: 4:27:53  loss: 8.6131 (8.6131)  loss_main: 0.9407 (0.9407)  loss_align: 7.6724 (7.6724)  time: 12.1125  data: 11.9344  max mem: 5627
Test:  [  20/1327]  eta: 0:18:36  loss: 67.2100 (63.9049)  loss_main: 1.0811 (1.0624)  loss_align: 66.0713 (62.8425)  time: 0.2911  data: 0.1202  max mem: 5627
Test:  [  40/1327]  eta: 0:12:36  loss: 67.6484 (65.6110)  loss_main: 1.0833 (1.0756)  loss_align: 66.5650 (64.5353)  time: 0.3080  data: 0.1357  max mem: 5627
Test:  [  60/1327]  eta: 0:10:23  loss: 68.3950 (66.8223)  loss_main: 0.8738 (1.0081)  loss_align: 67.4525 (65.8143)  time: 0.2956  data: 0.1273  max mem: 5627
Test:  [  80/1327]  eta: 0:09:15  loss: 68.7400 (67.2066)  loss_main: 0.9595 (1.0099)  loss_align: 67.6276 (66.1967)  time: 0.3026  data: 0.1400  max mem: 5627
Test:  [ 100/1327]  eta: 0:08:31  loss: 67.1465 (67.3012)  loss_main: 1.0416 (1.0188)  loss_align: 66.1099 (66.2824)  time: 0.3038  data: 0.1410  max mem: 5627
Test:  [ 120/1327]  eta: 0:07:58  loss: 66

Test:  [1040/1327]  eta: 0:01:35  loss: 64.8262 (68.8015)  loss_main: 1.1994 (1.0428)  loss_align: 63.6579 (67.7587)  time: 0.3244  data: 0.1612  max mem: 5627
Test:  [1060/1327]  eta: 0:01:28  loss: 70.2621 (68.8271)  loss_main: 1.0239 (1.0425)  loss_align: 69.2353 (67.7845)  time: 0.2848  data: 0.1211  max mem: 5627
Test:  [1080/1327]  eta: 0:01:22  loss: 70.0705 (68.8591)  loss_main: 1.0603 (1.0429)  loss_align: 68.9998 (67.8163)  time: 0.3219  data: 0.1595  max mem: 5627
Test:  [1100/1327]  eta: 0:01:15  loss: 69.3203 (68.8660)  loss_main: 1.0754 (1.0434)  loss_align: 68.2572 (67.8226)  time: 0.2823  data: 0.1186  max mem: 5627
Test:  [1120/1327]  eta: 0:01:08  loss: 69.5305 (68.8760)  loss_main: 1.0151 (1.0430)  loss_align: 68.5643 (67.8331)  time: 0.3322  data: 0.1682  max mem: 5627
Test:  [1140/1327]  eta: 0:01:01  loss: 68.1279 (68.8737)  loss_main: 1.0642 (1.0434)  loss_align: 67.0311 (67.8303)  time: 0.2839  data: 0.1208  max mem: 5627
Test:  [1160/1327]  eta: 0:00:55  loss: 

{'test_loss': 68.93226244005837,
 'test_loss_main': 1.045291825713479,
 'test_loss_align': 67.88697067956414}

# Test MAE

In [47]:
mae = mae_vit_base_patch16(in_chans=13)
mae.to(device)

checkpoint_path = "/home/ubuntu/checkpoint-100.pth"
# checkpoint_path = "/home/ubuntu/satellite-cae/SatMAE/output_dir/checkpoint-satmae-99.pth"
checkpoint = torch.load(checkpoint_path, map_location='cpu')
mae.load_state_dict(checkpoint['model'], strict=False)

  checkpoint = torch.load(checkpoint_path, map_location='cpu')


<All keys matched successfully>

In [50]:
test_mae(mae, dataloader, device)

  with torch.cuda.amp.autocast():


Test:  [   0/1327]  eta: 4:25:27  loss: 0.9452 (0.9452)  time: 12.0028  data: 11.9024  max mem: 5627
Test:  [  20/1327]  eta: 0:18:24  loss: 1.0780 (1.0610)  time: 0.2873  data: 0.1954  max mem: 5627
Test:  [  40/1327]  eta: 0:12:43  loss: 1.0798 (1.0739)  time: 0.3280  data: 0.2309  max mem: 5627
Test:  [  60/1327]  eta: 0:10:26  loss: 0.8634 (1.0035)  time: 0.2919  data: 0.1893  max mem: 5627
Test:  [  80/1327]  eta: 0:09:16  loss: 0.9618 (1.0064)  time: 0.2994  data: 0.2095  max mem: 5627
Test:  [ 100/1327]  eta: 0:08:38  loss: 1.0424 (1.0161)  time: 0.3275  data: 0.2358  max mem: 5627
Test:  [ 120/1327]  eta: 0:07:59  loss: 1.0493 (1.0238)  time: 0.2696  data: 0.1770  max mem: 5627
Test:  [ 140/1327]  eta: 0:07:41  loss: 1.0450 (1.0274)  time: 0.3400  data: 0.2518  max mem: 5627
Test:  [ 160/1327]  eta: 0:07:17  loss: 0.9753 (1.0217)  time: 0.2772  data: 0.1895  max mem: 5627
Test:  [ 180/1327]  eta: 0:07:07  loss: 1.0610 (1.0249)  time: 0.3536  data: 0.2676  max mem: 5627
Test:  [

{'test_loss': 1.0426678094274582}