In [1]:
from itertools import islice

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

from src.data import Sentinel2DatasetRGB
from src.model import Siren
from src.utils import postprocess_model_output, psnr, show_image, bpp

# Pick the best available backend in priority order: CUDA, then Apple MPS (Metal), then CPU.
# The chain must be if/elif/else: with two independent `if`s, a CUDA machine without
# MPS would fall into the final `else` and silently run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("Using CPU device")

# Seed torch's global RNG so the DataLoader shuffle order and model initialisation
# are repeatable across Restart & Run All.
SEED = 0
_ = torch.manual_seed(SEED)

IMAGE_SIZE = 64                # images are resized to IMAGE_SIZE x IMAGE_SIZE
DATA_DIR = "archive/EuroSAT"

Using MPS device


In [2]:
# Resize each image to IMAGE_SIZE x IMAGE_SIZE, convert it to a tensor, then apply
# Normalize(mean=0.5, std=0.5) per RGB channel, i.e. v -> (v - 0.5) / 0.5.
transform = Compose([
    Resize((IMAGE_SIZE, IMAGE_SIZE)),
    ToTensor(),
    Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dataset = Sentinel2DatasetRGB(root_dir=DATA_DIR, transform=transform)

# batch_size=1: the sweep fits a separate model to every single image.
# Pinned host memory only speeds up host-to-CUDA copies; on MPS PyTorch warns that it
# is unsupported, so enable it only when training on CUDA.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=device.type == "cuda",
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
def train_siren(model, model_input, ground_truth, *, total_steps=10_000, lr=1e-3,
                patience=10, min_delta=1e-4):
    """Fit ``model`` to ``ground_truth`` with full-batch MSE and early stopping.

    Every step feeds the same ``model_input`` through ``model``, computes the mean
    squared error against ``ground_truth`` and takes one Adam step. Training stops
    after ``total_steps`` steps, or earlier once the loss has failed to beat the best
    loss seen so far by more than ``min_delta`` for ``patience`` consecutive steps.

    Args:
        model: module whose forward returns a tuple; its first element is the
            prediction (the second element is ignored).
        model_input: input passed to ``model`` on every step.
        ground_truth: target compared element-wise against the prediction.
        total_steps: maximum number of optimisation steps.
        lr: Adam learning rate.
        patience: number of consecutive non-improving steps before stopping.
        min_delta: minimum loss decrease that counts as an improvement.

    Returns:
        The lowest loss observed, as a Python float (``inf`` if no step ran).
    """
    optim = torch.optim.Adam(model.parameters(), lr=lr)

    best_loss = float("inf")
    steps_since_improvement = 0

    for step in range(total_steps):
        model_output, _ = model(model_input)
        loss = ((model_output - ground_truth) ** 2).mean()

        # Compare plain floats: storing the loss tensor as `best_loss` would keep it
        # (and its autograd history) alive, and `loss + min_delta` would add an
        # autograd node on every step.
        loss_value = loss.item()
        if loss_value + min_delta < best_loss:
            best_loss = loss_value
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1

        if steps_since_improvement >= patience:
            print(f"Stopping early at step {step} due to no improvement.")
            break

        optim.zero_grad()
        loss.backward()
        optim.step()

    return best_loss


In [None]:
N_IMAGES = 100                   # how many images from the (shuffled) dataloader to sweep
HIDDEN_SIZES = range(2, 101, 2)  # hidden-layer widths tried for every image

# all_qualities[i][j] / all_sizes[i][j]: PSNR and bits-per-pixel of the SIREN with
# HIDDEN_SIZES[j] hidden units fitted to the i-th image.
all_qualities = []
all_sizes = []

for x, y in islice(dataloader, N_IMAGES):
    # The image tensors do not change across widths, so move them to the device once.
    model_input, ground_truth = x.to(device), y.to(device)
    qualities = []
    sizes = []

    for n_hidden in HIDDEN_SIZES:
        model = Siren(in_features=2,
                      out_features=3,
                      hidden_features=n_hidden,
                      hidden_layers=1,
                      outermost_linear=True).to(device)
        train_siren(model, model_input, ground_truth)

        # Evaluation only: no autograd graph is needed for the reconstruction.
        with torch.no_grad():
            compressed = postprocess_model_output(model(model_input)[0])
            real = postprocess_model_output(ground_truth)
            quality = psnr(compressed, real).item()

        qualities.append(quality)
        sizes.append(bpp(model))

    all_qualities.append(qualities)
    all_sizes.append(sizes)