In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from src.data import NumpyDatasetAllBands
from src.model import Siren
from src.utils import *

# Pick the compute device in priority order: CUDA, then Apple MPS, then CPU.
# The branches must form a single if/elif/else chain; with two independent
# `if` statements the trailing `else` would overwrite a CUDA selection with
# CPU on machines that have CUDA but no MPS backend.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA device")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("Using CPU device")


Using MPS device


In [5]:
# Per-sample preprocessing pipeline handed to the dataset; currently only ToTensor().
transform = transforms.Compose([transforms.ToTensor()])

# Dataset backed by the files under data/np_downscaled_and_cropped.
dataset = NumpyDatasetAllBands(
    "data/np_downscaled_and_cropped",
    transform,
)

# One sample per batch, shuffled, loaded in the main process (num_workers=0).
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
)

In [6]:
def train_siren(model, model_input, ground_truth,
                total_steps=10_000, patience=50, min_delta=1e-4, lr=1e-3):
    """Fit `model` to `ground_truth` in place with Adam and an MSE loss.

    Training stops after `total_steps` optimisation steps, or earlier once the
    loss has failed to improve on the best value seen by more than `min_delta`
    for `patience` consecutive steps.

    Args:
        model: Module called as `model(model_input)`; it must return a
            2-tuple whose first element is the prediction compared against
            `ground_truth`.
        model_input: Input passed to the model on every step.
        ground_truth: Target tensor; must broadcast against the prediction.
        total_steps: Maximum number of optimisation steps.
        patience: Number of consecutive non-improving steps tolerated before
            stopping early.
        min_delta: Minimum decrease of the loss (relative to the best loss so
            far) that counts as an improvement.
        lr: Learning rate for the Adam optimiser.

    Returns:
        None. The model's parameters are updated in place.
    """
    optim = torch.optim.Adam(lr=lr, params=model.parameters())

    best_loss = float('inf')
    steps_since_improvement = 0

    for step in range(total_steps):
        model_output, _ = model(model_input)
        loss = ((model_output - ground_truth) ** 2).mean()

        # Track the best loss as a plain Python float so that no reference to
        # the loss tensor (and its autograd node) outlives the step.
        loss_value = loss.item()
        if loss_value + min_delta < best_loss:
            best_loss = loss_value
            steps_since_improvement = 0
        else:
            steps_since_improvement += 1

        # The check runs before the update, so the step that triggers early
        # stopping does not modify the parameters.
        if steps_since_improvement >= patience:
            print(f"Stopping early at step {step} due to no improvement.")
            break

        optim.zero_grad()
        loss.backward()
        optim.step()


In [7]:
# For every image yielded by the dataloader, train one Siren per (depth, width)
# combination and record its PSNR and bits-per-pixel.
# all_qualities[i][j] / all_sizes[i][j] hold the result of the j-th model
# configuration on the i-th image.
all_qualities = []
all_sizes = []

# Swept architecture grid; the progress total below is derived from it so the
# two cannot drift apart.
layer_options = [1, 2, 3]
hidden_options = range(5, 201, 5)
n_models = len(layer_options) * len(hidden_options)

image_counter = 1
for x, y in dataloader:
    qualities = []
    sizes = []

    # The batch is the same for every model trained on this image, so move it
    # to the device once instead of once per model.
    model_input, ground_truth = x.to(device), y.to(device)

    counter = 1
    for n_layers in layer_options:
        for n_hidden in hidden_options:
            model = Siren(in_features=2,
                          out_features=13,
                          hidden_features=n_hidden,
                          hidden_layers=n_layers,
                          outermost_linear=True).to(device)

            print(f"image {image_counter}; model {counter}/{n_models}", end='\t')
            train_siren(model, model_input, ground_truth)

            # Evaluation only: no autograd graph is needed for the final
            # forward pass.
            # NOTE(review): assumes Siren.forward and postprocess_model_output
            # do not rely on gradients being enabled - confirm.
            with torch.no_grad():
                compressed = postprocess_model_output(model(model_input)[0])
            real = postprocess_model_output(ground_truth)

            quality = psnr(compressed, real).item()
            bits_per_pixel = bpp(model)

            qualities.append(quality)
            sizes.append(bits_per_pixel)
            counter += 1

    image_counter += 1
    all_qualities.append(qualities)
    all_sizes.append(sizes)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


image 1; model 1/120	Stopping early at step 316 due to no improvement.
image 1; model 2/120	

KeyboardInterrupt: 

In [None]:
# Persist the sweep results. Each file must receive its own list: writing
# all_qualities to both paths would silently discard the size measurements.
np.save("artifacts/sizes_all_bands.npy", all_sizes)
np.save("artifacts/qualities_all_bands.npy", all_qualities)