# MaRINeR Demo

[[Project Page](https://boelukas.github.io/mariner/)] [[GitHub](https://github.com/boelukas/mariner)]

This notebook demonstrates MaRINeR. It covers project setup, prediction, evaluation, and training of the model. A GPU runtime is recommended.
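Before running the heavier cells, it can help to confirm that the runtime actually exposes a GPU. The sketch below is one way to do that with the standard library only; the helper name `gpu_names` is hypothetical, and it assumes an NVIDIA runtime where `nvidia-smi` is on the `PATH`.

```python
import shutil
import subprocess


def gpu_names():
    """Return the GPU names reported by nvidia-smi, or [] if none are visible."""
    # Assumption: an NVIDIA GPU runtime ships the nvidia-smi command-line tool.
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
    )
    if result.returncode != 0:
        return []
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


names = gpu_names()
print("GPU(s) found: " + ", ".join(names) if names else "No GPU visible; consider switching the runtime type.")
```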


## Project Setup

In [None]:
!git clone https://github.com/boelukas/mariner.git
%cd mariner

### Download Demo Dataset and Pre-Trained Weights



In [None]:
!mkdir -p pretrained_weights data
!gdown 1zb90JWtX5-Si7MklJMqWn1Kwnqsi6mhF
!gdown 1VmhgXL1IFRwDlCSPZcwTt9ZsKorSknKk
!unzip demo_data.zip -d data/demo_data
!unzip mariner.zip -d pretrained_weights/
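The later cells read the checkpoint from `pretrained_weights/mariner.ckpt` and the demo images from `data/demo_data/{ref,input,gt}`. A minimal sketch for checking that the downloads unpacked into that layout, assuming the current directory is the repository root (as after `%cd mariner`); the helper name `missing_paths` is hypothetical:

```python
from pathlib import Path

# Paths this notebook reads later, relative to the repository root.
EXPECTED = [
    "data/demo_data/ref",
    "data/demo_data/input",
    "data/demo_data/gt",
    "pretrained_weights/mariner.ckpt",
]


def missing_paths(repo_root="."):
    """Return the expected paths that do not exist under repo_root."""
    root = Path(repo_root)
    return [rel for rel in EXPECTED if not (root / rel).exists()]


missing = missing_paths(".")
print("All demo files found." if not missing else f"Missing: {missing}")
```

If something is reported missing, listing the unzipped folders usually shows whether the archive added an extra directory level.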

### Install Dependencies

In [6]:
!pip install torchvision==0.16.0 tensorboard==2.17.0 lightning==2.2.1 opencv-python==4.10.0.84 "jsonargparse[signatures]==4.31.0" erqa==1.1.2 lpips==0.1.4 tabulate==0.9.0 rich==13.7.1 numpy==1.26.4

## Predict

In [None]:
!python mariner/main.py predict \
    -c configs/MaRINeR.yml \
    --ckpt_path /content/mariner/pretrained_weights/mariner.ckpt \
    --data_dir /content/mariner/data/demo_data

### Visualize Results

In [None]:
!pip install ipympl --quiet
%matplotlib widget
from google.colab import output
output.enable_custom_widget_manager()

In [None]:

import matplotlib.pyplot as plt
import os
import ipywidgets as widgets
from PIL import Image

ref_dir = "/content/mariner/data/demo_data/ref"
input_dir = "/content/mariner/data/demo_data/input"
out_dir = "/content/mariner/data/demo_data/out"
gt_dir = "/content/mariner/data/demo_data/gt"

input_files = sorted(os.listdir(input_dir))
out_files = sorted(os.listdir(out_dir))
gt_files = sorted(os.listdir(gt_dir))
ref_files = sorted(os.listdir(ref_dir))
scale_factor = 1.0

def update_fig(fig, axs, input_img, gt_img, ref_img, out_img):
    img_width, img_height = input_img.size

    # Size the figure from the image dimensions (pixels / 100 -> inches), scaled by scale_factor.
    fig.set_size_inches(scale_factor * img_width * 3 / 100, (scale_factor * img_height / 100) + 0.5)
    fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0.05, hspace=0.05)

    # Reset every panel before drawing the new images.
    for ax in axs:
        ax.clear()
        ax.axis('off')

    axs[0].imshow(ref_img)
    axs[0].set_title("Reference")
    axs[1].imshow(input_img)
    axs[1].set_title("Input")
    axs[2].imshow(out_img)
    axs[2].set_title("Output")
    axs[3].imshow(gt_img)
    axs[3].set_title("Ground Truth")

    fig.canvas.draw_idle()


def update_scale(change):
    global scale_factor
    scale_factor = change.new
    update_fig(fig, axs, Image.open(os.path.join(input_dir, input_files[slider.value])),
               Image.open(os.path.join(gt_dir, gt_files[slider.value])),
               Image.open(os.path.join(ref_dir, ref_files[slider.value])),
               Image.open(os.path.join(out_dir, out_files[slider.value])))


def update(change):
    new_input_img = Image.open(os.path.join(input_dir, input_files[change.new]))
    new_gt_img = Image.open(os.path.join(gt_dir, gt_files[change.new]))
    new_out_img = Image.open(os.path.join(out_dir, out_files[change.new]))

    new_ref_img = Image.open(os.path.join(ref_dir, ref_files[change.new]))

    update_fig(fig, axs, new_input_img, new_gt_img, new_ref_img, new_out_img)

with plt.ioff():
    fig, axs = plt.subplots(1, 4)
    for ax in axs:
        ax.axis('off')

    # Draw the first image set once; the sliders redraw it on change.
    input_image = Image.open(os.path.join(input_dir, input_files[0]))
    gt_image = Image.open(os.path.join(gt_dir, gt_files[0]))
    out_image = Image.open(os.path.join(out_dir, out_files[0]))
    ref_image = Image.open(os.path.join(ref_dir, ref_files[0]))

    update_fig(fig, axs, input_image, gt_image, ref_image, out_image)

slider = widgets.IntSlider(value=0, min=0, max=len(input_files)-1, description='Image Index:')
slider.observe(update, names='value')

# A scale of 0 would produce a zero-sized figure, so start the range at 0.1.
scale_slider = widgets.FloatSlider(value=1.0, min=0.1, max=10.0, step=0.1, description='Scale:')
scale_slider.observe(update_scale, names='value')
scale_text = widgets.FloatText(value=1.0, description='Scale Factor:', layout=widgets.Layout(width='150px'))
scale_text.observe(update_scale, names='value')

widgets.VBox([fig.canvas, slider, scale_slider])
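The viewer above pairs images across the four folders by their position in the sorted file listing, so a missing or extra file in one folder shifts every later pair. A minimal sketch for checking the counts first; the helper names `count_images` and `is_aligned` are hypothetical, and the check only compares file counts, not file names.

```python
from pathlib import Path


def count_images(data_dir, subdirs=("ref", "input", "out", "gt")):
    """Return {subdir: number of files} for each expected subfolder."""
    root = Path(data_dir)
    counts = {}
    for name in subdirs:
        folder = root / name
        # A missing folder counts as zero files.
        counts[name] = sum(1 for p in folder.iterdir() if p.is_file()) if folder.is_dir() else 0
    return counts


def is_aligned(counts):
    """True when every folder holds the same, non-zero number of files."""
    values = set(counts.values())
    return len(values) == 1 and 0 not in values


counts = count_images("/content/mariner/data/demo_data")
print(counts, "aligned" if is_aligned(counts) else "NOT aligned")
```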

## Evaluate

### Download Evaluation Datasets

In [None]:
# https://drive.google.com/file/d/1fkajRAyxsaOsCPxZLDU1iUMo8BYZNGej/view?usp=drive_link
!gdown 1fkajRAyxsaOsCPxZLDU1iUMo8BYZNGej
!unzip test_data.zip -d data/test_data

### Predict Test Data


In [None]:
dataset = "NeRF" # @param ["CAB_ref_gt", "CAB_ref_lvl_1", "LIN_ref_lvl_1", "HGE_ref_lvl_1", "IBRnet", "12SCENES_apt_1_living_ref_lvl_10", "NeRF"] {allow-input: false}
!python mariner/main.py predict \
    -c configs/MaRINeR.yml \
    --ckpt_path /content/mariner/pretrained_weights/mariner.ckpt \
    --data_dir /content/mariner/data/test_data/{dataset}/


### Evaluate Test Data

In [None]:
!python scripts/eval_metrics.py \
    --images /content/mariner/data/test_data/{dataset}/out \
            /content/mariner/data/test_data/{dataset}/gt
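To evaluate several datasets in one go, the two commands above can be generated per dataset. The sketch below only builds and prints the command lines, reusing the paths and dataset names from the cells above; the helper name `build_commands` is hypothetical, and actually running the commands (for example with `subprocess.run(cmd, check=True)`) is left to the reader.

```python
import shlex

# Dataset names taken from the dropdown in the prediction cell.
DATASETS = [
    "CAB_ref_gt",
    "CAB_ref_lvl_1",
    "LIN_ref_lvl_1",
    "HGE_ref_lvl_1",
    "IBRnet",
    "12SCENES_apt_1_living_ref_lvl_10",
    "NeRF",
]


def build_commands(dataset, root="/content/mariner"):
    """Return the (predict, evaluate) argument lists for one test dataset."""
    data_dir = f"{root}/data/test_data/{dataset}"
    predict = [
        "python", "mariner/main.py", "predict",
        "-c", "configs/MaRINeR.yml",
        "--ckpt_path", f"{root}/pretrained_weights/mariner.ckpt",
        "--data_dir", f"{data_dir}/",
    ]
    evaluate = [
        "python", "scripts/eval_metrics.py",
        "--images", f"{data_dir}/out", f"{data_dir}/gt",
    ]
    return predict, evaluate


for name in DATASETS:
    for cmd in build_commands(name):
        print(shlex.join(cmd))
```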

## Training

### Download Training Dataset

In [None]:
# https://drive.google.com/file/d/1x9Q6np6VklEthr5f3Ne15pUzfcc7Megk/view?usp=drive_link
!gdown 1x9Q6np6VklEthr5f3Ne15pUzfcc7Megk
!unzip train_data.zip -d data/train_data

### Train

Adjust `batch_size` to the memory available on the runtime. Training with the default `batch_size = 9` requires 32 GB of VRAM.
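As a rough starting point, the 32 GB figure for `batch_size = 9` can be scaled down to the VRAM at hand. The sketch below assumes memory use grows linearly with batch size, which ignores the fixed cost of weights and optimizer state, so treat the result as an upper bound to try rather than a guarantee; the helper name `suggest_batch_size` is hypothetical.

```python
import math


def suggest_batch_size(vram_gb, ref_batch_size=9, ref_vram_gb=32.0):
    """Scale the reference batch size linearly with available VRAM (rough heuristic)."""
    if vram_gb <= 0:
        raise ValueError("vram_gb must be positive")
    # Round down so the estimate stays below the linear budget; never go below 1.
    return max(1, math.floor(vram_gb * ref_batch_size / ref_vram_gb))


print(suggest_batch_size(16))  # 16 GB under the linear assumption
```

If training still runs out of memory at the suggested value, lower `batch_size` further.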

In [None]:
batch_size = 2 # @param {type:"number"}

!python mariner/main.py fit \
     -c configs/MaRINeR.yml \
     --train_data_dir /content/mariner/data/train_data/CAB_merged_LIN/train \
     --test_data_dir /content/mariner/data/train_data/CAB_merged_LIN/test \
     --batch_size {batch_size}