<a href="https://colab.research.google.com/github/hallpaz/udl/blob/main/code/mr-imaging.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Understanding Deep Learning 2024
# Multiresolution Representation of Images

#### by Hallison Paz
#### October 15th, 2024

In [None]:
from IPython.display import HTML

# Embedded YouTube video, rendered as a raw HTML iframe.
_lecture_iframe = '''<iframe width="560" height="315"
        src="https://www.youtube.com/embed/WmaWecH0ThU?si=IlFcrojnGMReHle5"
        frameborder="0" allow="accelerometer; autoplay; encrypted-media;
        gyroscope; picture-in-picture" allowfullscreen></iframe>'''
HTML(_lecture_iframe)

In [None]:
# %pip (not !pip) installs into the environment of the running kernel;
# !pip runs in a subshell whose `pip` may belong to a different interpreter.
# NOTE(review): mrnet is installed from the moving `dev` branch — pin a commit
# hash (…/mrnet.git@<sha>) if exact reproducibility matters.
%pip install git+https://github.com/visgraf/mrnet.git@dev
%pip install wandb
%pip install trimesh

In [None]:
import os
from pathlib import Path
import torch

from mrnet.training.trainer import MRTrainer
from mrnet.datasets.signals import ImageSignal
from mrnet.networks.mrnet import MRFactory
from mrnet.datasets.pyramids import create_MR_structure
from mrnet.training.listener import TrainingListener

from mrnet.training.utils import load_hyperparameters, get_optim_handler

In [None]:
# Fetch the training config and the target image.
# -nc (no-clobber) makes re-running this cell idempotent: without it wget keeps
# the existing file and saves a duplicate as image.yml.1 / masp.jpg.1.
# Trade-off: an existing local copy is not refreshed; delete it to re-download.
!mkdir -p configs/
!mkdir -p data/
!wget -nc -P configs https://raw.githubusercontent.com/hallpaz/udl/refs/heads/main/code/image.yml
!wget -nc -P data https://raw.githubusercontent.com/hallpaz/udl/refs/heads/main/data/masp.jpg

## Training an MR-Net Model

In [None]:
# Directory holding image.yml (the download cell saves it into configs/).
CONFIG_PATH = "configs"

In [None]:
# Seed torch's global RNG so torch-based random draws repeat across runs.
torch.manual_seed(777)

# Read the hyperparameters from the YAML config; the run is logged under the
# project "framework-tests" unless the config names its own project.
hyper = load_hyperparameters(
    os.path.join(CONFIG_PATH, "image.yml"),
)
project_name = hyper.get("project_name", "framework-tests")

In [None]:
# Load the image named in the config as a signal, with the configured domain,
# sampling scheme, batch size and color space.
base_signal = ImageSignal.init_fromfile(
                    hyper['data_path'],
                    domain=hyper['domain'],
                    channels=hyper['channels'],
                    sampling_scheme=hyper['sampling_scheme'],
                    width=hyper['width'], height=hyper['height'],
                    batch_size=hyper['batch_size'],
                    color_space=hyper['color_space'])

# Multiresolution structures built from the same signal. The training one uses
# the configured decimation; the test one passes False in that slot
# (presumably to keep every level at full resolution — confirm in mrnet).
train_dataset = create_MR_structure(base_signal,
                                    hyper['max_stages'],
                                    hyper['filter'],
                                    hyper['decimation'],
                                    hyper['pmode'])
test_dataset = create_MR_structure(base_signal,
                                    hyper['max_stages'],
                                    hyper['filter'],
                                    False,
                                    hyper['pmode'])

# A width/height of 0 in the config is replaced by the loaded image's size
# before the model is built from `hyper`. Width is the last axis of
# base_signal.shape, so height is the second-to-last one; reading shape[-1]
# for both would only be correct for square images.
if hyper['width'] == 0:
    hyper['width'] = base_signal.shape[-1]
if hyper['height'] == 0:
    hyper['height'] = base_signal.shape[-2]

# You can replace this line with your own optimization-handler class.
optim_handler = get_optim_handler(hyper.get('optim_handler', 'regular'))

mrmodel = MRFactory.from_dict(hyper)
print("Model: ", type(mrmodel))

# Run name = model name + filter initial + first 7 chars of the image file name
# + color-space initial; logs go under hyper["log_path"] (default "runs").
name = os.path.basename(hyper['data_path'])
logger = TrainingListener(project_name,
                            f"{hyper['model']}{hyper['filter'][0].upper()}{name[0:7]}{hyper['color_space'][0]}",
                            hyper,
                            Path(hyper.get("log_path", "runs")))
mrtrainer = MRTrainer.init_from_dict(mrmodel,
                                    train_dataset,
                                    test_dataset,
                                    logger,
                                    hyper,
                                    optim_handler=optim_handler)
mrtrainer.train(hyper['device'])

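## Exploring the model scales

Load a saved MR-Net parameter file, then use the slider to choose how many stages contribute to the rendered image: an integer value k turns the first k stages fully on, and a fractional value blends in the next stage partially.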

In [None]:
import ipywidgets as widgets
import numpy as np
from PIL import Image
from ipywidgets import interact, interactive, Box, interact_manual
from mrnet.datasets.sampler import make_grid_coords

In [None]:
# Path to a saved MR-Net parameter file (presumably one written by the training
# run above — check its log directory). Setting the MRNET_MODEL_PATH environment
# variable skips the interactive prompt, which otherwise blocks
# "Restart & Run All" and non-interactive execution.
modelpath = os.environ.get("MRNET_MODEL_PATH") or input("Path to the saved model parameters file: ")
# Pasted paths often carry stray whitespace; also expand a leading "~".
modelpath = os.path.expanduser(modelpath.strip())
if not os.path.isfile(modelpath):
    raise FileNotFoundError(f"No model parameters file at {modelpath!r}")
mrmodel = MRFactory.load_state_dict(modelpath)
print('modelpath:', modelpath)

In [None]:
# Display settings taken from the training config. The viewer reshapes the
# output to (res, res, channels), so a square image of side hyper['width']
# is assumed.
res = hyper['width']
channels = hyper['channels']
model = mrmodel

# Slider over the level of detail: 0 turns every stage off, an integer k turns
# the first k stages fully on, and a fractional value partially weights the
# next stage.
level_slider = widgets.FloatSlider(
        value=1.0,
        min=0.0,
        max=float(model.n_stages()),
        step=0.05,
        description='Multilevel',
        disabled=False,
        continuous_update=True,
        readout=True,
        orientation='horizontal',
        readout_format='.2f',
        layout=widgets.Layout(width='50%')
)
def plot_model(level):
    """Render the model's output image at a (possibly fractional) level.

    Stage ``s`` (0-based) is weighted by ``min(max(level - s, 0), 1)``: stages
    below ``floor(level)`` contribute fully, the next stage contributes the
    fractional part of ``level``, and later stages are switched off.

    Uses the notebook globals ``model``, ``res`` and ``channels``; all three
    are set together in the viewer-setup code, so the stage count always
    matches the model being evaluated.

    Args:
        level: slider value, expected in ``[0, model.n_stages()]``.

    Returns:
        A ``PIL.Image`` of ``res x res`` pixels; single-channel output is
        replicated to three channels.
    """
    # Evaluate on the device holding the model's parameters so a model that
    # was loaded onto a GPU does not receive CPU inputs.
    # NOTE(review): assumes `model` is a torch.nn.Module with parameters.
    device = next(model.parameters()).device
    grid = make_grid_coords(res, -1.0, 1.0, dim=2).to(device)
    weights = torch.tensor(
        [min(max(level - s, 0.0), 1.0) for s in range(model.n_stages())],
        dtype=torch.float32,
        device=device,
    )

    # Inference only: skip building an autograd graph on every slider move.
    with torch.no_grad():
        output = model(grid, mrweights=weights)
    model_out = torch.clamp(output['model_out'], 0.0, 1.0)

    pixels = model_out.cpu().view(res, res, channels).numpy()
    pixels = (pixels * 255).astype(np.uint8)
    if channels == 1:
        # Image.fromarray does not accept (H, W, 1) arrays; expand to RGB.
        pixels = np.repeat(pixels, 3, axis=-1)
    return Image.fromarray(pixels)

# interact() displays the slider widget and returns plot_model itself; the
# trailing semicolon keeps that function's repr out of the cell output.
interact(plot_model, level=level_slider);