In [None]:
import torch

from ..dataio import ImageFitting, EncodedImageFitting
from models import Siren, HybridSiren, MLP
from ..training import train_inr

import os
from argparse import ArgumentParser
from tqdm.notebook import tqdm

In [None]:
def parse_bool(value):
    """Convert a command-line string such as 'true' / 'false' into a bool.

    ``type=bool`` must not be used for flags like ``--normalization``:
    argparse hands the raw string to the converter, and ``bool('False')`` is
    ``True`` because every non-empty string is truthy.

    Args:
        value: The raw string from the command line, or an already-boolean
            value, which is returned unchanged.

    Returns:
        ``True`` for 'true'/'t'/'yes'/'y'/'1' and ``False`` for
        'false'/'f'/'no'/'n'/'0'. Matching ignores case and surrounding
        whitespace.

    Raises:
        ValueError: If ``value`` is not a recognised boolean spelling.
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ValueError(f"expected a boolean value, got {value!r}")


parser = ArgumentParser()

parser.add_argument('--input', type=str, default='data/kodak/kodim24.png')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--epochs_til_summary', type=int, default=25)
parser.add_argument('--batch_size', type=int, default=64*64)
# parse_bool rather than type=bool, so that '--normalization False' really means False.
parser.add_argument('--normalization', type=parse_bool, default=False)
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--model', type=str, default='siren')

# Parse an explicit empty list so only the defaults above are used; in a notebook
# this also keeps the kernel's own command-line arguments from being parsed.
args = parser.parse_args(args=[])

In [None]:
# Seed torch's global RNG from the configured seed (torch.manual_seed also seeds CUDA devices).
torch.manual_seed(args.seed)

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the image-fitting dataset from the configured input image path.
img_fitting = ImageFitting(args.input, normalization=args.normalization)

In [None]:
# Initialize the model and optimizer.

# Width and depth used for every architecture; these were previously repeated in each branch.
HIDDEN_FEATURES = 256
NUM_HIDDEN_LAYERS = 3

# Maps the --model choice to its class. Every architecture here is built with the
# same arguments, so one dispatch table replaces the copy-pasted if/elif branches.
MODEL_CLASSES = {
    'siren': Siren,
    'hybrid_siren': HybridSiren,
    'mlp': MLP,
}

if args.model not in MODEL_CLASSES:
    raise ValueError(f"Unknown model type: {args.model}")

# Input/output sizes come from the dataset, so the model matches the image being fitted.
model = MODEL_CLASSES[args.model](
    img_fitting.input_dim,
    img_fitting.output_dim,
    hidden_features=HIDDEN_FEATURES,
    num_hidden_layers=NUM_HIDDEN_LAYERS,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

In [None]:
# Training-loop hyperparameters, all taken from the parsed arguments.
train_config = dict(
    epochs=args.epochs,
    batch_size=args.batch_size,
    epochs_til_summary=args.epochs_til_summary,
)

# Fit the model to the image. Whatever train_inr returns is kept in `results`
# (presumably training metrics/outputs; confirm against train_inr).
results = train_inr(
    model=model,
    dataset=img_fitting,
    optimizer=optimizer,
    device=device,
    **train_config,
)