In [2]:
import sys
sys.path.append('..')

import torch
import numpy as np
from torchinfo import summary

from dataio import ImageFitting, EncodedImageFitting
from models import Siren, HybridSiren, MLP
from training import train_inr

import os
from argparse import ArgumentParser
from tqdm.notebook import tqdm

In [None]:
# Experiment configuration. The options are declared through argparse so the
# same definitions can be reused from a script; inside the notebook they are
# parsed from an empty argv, so every option takes its default value.


def _parse_bool(text):
    """Convert a command-line string into a bool.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    accepts the usual spellings and rejects everything else.

    Raises:
        ValueError: if ``text`` is not a recognised boolean spelling
            (argparse reports this as an invalid argument value).
    """
    lowered = text.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f'expected a boolean value, got {text!r}')


parser = ArgumentParser()

parser.add_argument('--input', type=str, default='data/kodak/kodim24.png')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--epochs_til_summary', type=int, default=25)
parser.add_argument('--batch_size', type=int, default=64*64)
parser.add_argument('--normalization', type=_parse_bool, default=False)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--model', type=str, default='siren')
# One integer per token (e.g. ``--siren_layers 0 1``). ``type=list[int]``
# would split the raw string into single characters instead of parsing ints.
parser.add_argument('--siren_layers', type=int, nargs='*', default=[0, 1])

# Explicit empty argv: ignore the Jupyter kernel's own command line.
args = parser.parse_args(args=[])

In [None]:
# Seed the random number generators so that a Restart & Run All reproduces
# the same run. NumPy's global RNG is seeded alongside torch's in case any
# project code draws from it (not visible from this notebook).
torch.manual_seed(args.seed)
np.random.seed(args.seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Registry of experiments: run label -> (model, dataset) pair.
model_datasets = {}

# Output size, hidden width and hidden depth shared by every network below.
shared_arch = dict(out_features=3, hidden_features=256, num_hidden_layers=4)

# --- SIREN baseline: 2 input features, paired with a copy of the dataset ---
img_fitting = ImageFitting(args.input, args.normalization)
siren = Siren(in_features=2, **shared_arch).to(device)
model_datasets['siren'] = (siren, img_fitting.copy())

# --- Hybrid SIRENs: siren_layers is np.arange(n) for n = 0..3, i.e. the
# index sets [], [0], [0 1] and [0 1 2]; each gets its own dataset copy ---
for num_siren_layers in range(4):
    siren_layers = np.arange(num_siren_layers)
    hybrid_siren = HybridSiren(in_features=2,
                               siren_layers=siren_layers,
                               **shared_arch).to(device)
    hybrid_key = f'hybrid_{str(siren_layers)}'
    model_datasets[hybrid_key] = (hybrid_siren, img_fitting.copy())

# --- MLP on the 'gaussian' encoding; input width comes from the dataset ---
gaussian_encoded_img_fitting = EncodedImageFitting(
    args.input,
    args.normalization,
    encoding_type='gaussian',
)
mlp_gaussian = MLP(
    in_features=gaussian_encoded_img_fitting.encoding_dim,
    **shared_arch,
).to(device)
model_datasets['gaussian_ff_mlp'] = (mlp_gaussian, gaussian_encoded_img_fitting)

# --- MLP on the 'basic' encoding with include_original=True ---
fourier_encoded_img_fitting = EncodedImageFitting(
    args.input,
    args.normalization,
    encoding_type='basic',
    include_original=True,
)
mlp_basic_fourier = MLP(
    in_features=fourier_encoded_img_fitting.encoding_dim,
    **shared_arch,
).to(device)
model_datasets['basic_ff_mlp'] = (mlp_basic_fourier, fourier_encoded_img_fitting)




In [None]:
# Train every registered (model, dataset) pair with a fresh Adam optimizer
# and collect whatever train_inr returns, keyed by the run label.
training_results = {}

# The loop variable is named `dataset` (not `img_fitting`) so that the
# `img_fitting` object created in the model-construction cell is not
# silently overwritten by the last dataset in the registry.
for model_name, (model, dataset) in model_datasets.items():
    print(f"Training {model_name}...")
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # The MLPs are built with in_features=<dataset>.encoding_dim, while the
    # SIREN variables are built with in_features=2. Datasets without an
    # `encoding_dim` attribute are therefore summarised with 2 input features.
    in_features = getattr(dataset, 'encoding_dim', 2)
    print(summary(model, input_size=(1, in_features), device=device))

    training_results[model_name] = train_inr(
        model=model,
        dataset=dataset,
        optimizer=optimizer,
        epochs=args.epochs,
        batch_size=args.batch_size,
        epochs_til_summary=args.epochs_til_summary,
        device=device
    )