# Deterministic Models

* Fourier Feature Networks (FFN)
* Siren
* Modulated Siren (ModSiren)
* Multiplicative Filter Networks (MFN)
    * Fourier
    * Gabor

In [None]:
import sys, os
from pyprojroot import here

# Search upward from the working directory for the ".root" marker file to
# locate the project root.
root = here(project_files=[".root"])

# Make the project root importable (needed for the `inr4ssh` imports below).
sys.path.append(os.fspath(root))

In [None]:
from typing import Dict, Any, cast
import tabulate
from IPython.display import display, HTML

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import ReLU

from tqdm.notebook import tqdm as tqdm
import os, imageio

from inr4ssh._src.models.mlp import MLP
from inr4ssh._src.models.activations import Swish
from inr4ssh._src.data.images import load_fox

# from inr4ssh._src.features import get_image_coordinates
from inr4ssh._src.datamodules.images import ImageFox, ImageCameraman

import pytorch_lightning as pl
from inr4ssh._src.models.image import ImageModel
from inr4ssh._src.models.siren import Siren, SirenNet, Modulator, ModulatedSirenNet
from inr4ssh._src.models.mfn import FourierNet, GaborNet
from inr4ssh._src.models.activations import get_activation

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.utilities.argparse import add_argparse_args
from pytorch_lightning.loggers import WandbLogger


pl.seed_everything(123)

import matplotlib.pyplot as plt
import seaborn as sns
import wandb

sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%load_ext autoreload
%autoreload 2

## Logger

In [None]:
# wandb_logger = WandbLogger(
#     mode="online", #"offline",
#     project="inr4ssh",
#     entity="ige",
#     dir="/Users/eman/code_projects/logs",
#     resume=False
# )

## Data

The input data is a coordinate vector, $\mathbf{x}_\phi$, of the image coordinates.

$$
\mathbf{x}_\phi \in \mathbb{R}^{D_\phi}
$$

where $D_\phi = 2$, since each coordinate vector holds the pixel location $[\text{x}, \text{y}]$. So we are interested in learning a function, $\boldsymbol{f}$, such that we can input a coordinate vector and output the (scalar or vector) pixel value.

$$
\mathbf{u} = \boldsymbol{f}(\mathbf{x}_\phi; \boldsymbol{\theta})
$$

In [None]:
img = load_fox()

In [None]:
plt.figure()
plt.imshow(img)
plt.show()

### Data Module


#### Train-Test Split

In this example, we only take every other pixel for training and validation. It is a very simple, well-defined problem that each of the neural networks should be able to solve. The final test image is the original full-resolution image.

In [None]:
# dm = ImageFox(batch_size=1024).setup()
dm = ImageFox(batch_size=4096, shuffle=True).setup()

In [None]:
len(dm.ds_train), len(dm.ds_valid), len(dm.ds_test)

Notice that we have `131_072` points for training and validation and `262_144` for testing. This is because we have *raveled* the image so that each sample is an `x,y` coordinate vector. So these are a lot of points...

In [None]:
# The first 32 training samples, used below to sanity-check model shapes.
init = dm.ds_train[:32]
x_init, y_init = init
(x_init.shape, y_init.shape)

### Optimizer

For this, we will use a simple Adam optimizer with a `learning_rate` of `1e-3`. From many studies, it appears that a lower learning rate works well with these methods because there is a lot of data. In addition, a bigger `batch_size` is also desirable. We will set `num_epochs` to `500`, which should be good enough for a single image. More epochs and a better learning rate scheduler would obviously give better results, but this will be sufficient for this demo.

In [None]:
# Training budget shared by every model below.
learning_rate = 1e-3  # Adam step size passed to each ImageModel
num_epochs = 500  # Trainer(max_epochs=...)

### Scheduler

<p align="center">
<img src="http://www.bdhammel.com/assets/learning-rate/resnet_loss.png" alt="drawing" width="300"/>
<figcaption align = "center">
  <b>Fig.1 - An example for learning rate reduction when the validation loss stagnates. Source: 
    <a href="http://www.bdhammel.com/assets/learning-rate/resnet_loss.png">Blog</a>
  </b>
  </figcaption>
</p>

We will use a simple learning rate scheduler - `reduce_lr_on_plateau`. It automatically reduces the learning rate when the validation loss stagnates, so that we squeeze as much performance as possible out of our models during training. We start with a (relatively) high `learning_rate` of `1e-3`, so we set the `patience` to 5 epochs: if the validation loss does not improve for 5 consecutive epochs, the learning rate is multiplied by `0.1`.

This is a rather crude (but effective) method but it tends to work well in some situations. A better method might be the `cosine_annealing` method or the `exponential_decay` method. See other [examples](https://www.kaggle.com/code/snnclsr/learning-rate-schedulers/).

### Loss

We are going with a very simple `loss` function: the *mean squared error* (MSE). This is given by:

$$
\mathcal{L}(\boldsymbol{\theta}) = \frac{1}{N} \sum_n^N \left( \mathbf{y}_n - \boldsymbol{f}_{\boldsymbol{\theta}}(\mathbf{x}_n) \right)^2
$$

We won't code this from scratch, we will just use the PyTorch function, `nn.MSELoss`, and we will use the `mean` reduction parameter.


### PSNR

We will also keep track of the peak signal-to-noise ratio (PSNR), which gives us an indication of how well we are learning. It is computed directly from the MSE:

$$
\text{PSNR} = -10 \log_{10} \left( 2 \cdot \text{MSE} \right)
$$

In [None]:
# Per-model test metrics and reconstructed images, keyed by model name.
results = {}
images = {}

## Experiment

In [None]:
class ImageModel(pl.LightningModule):
    """Lightning wrapper that fits a coordinate network to a single image.

    ``model`` maps a batch of pixel coordinates ``x`` to predicted pixel
    values, which are compared with the targets ``y`` through a
    mean-squared-error loss. Loss and PSNR are logged on the train,
    validation and test splits.

    Args:
        model: network mapping coordinates to pixel values; its output must
            have the same shape as the targets.
        **kwargs: extra hyperparameters, stored with
            ``save_hyperparameters``. Keys read by this class:

            * ``learning_rate`` (legacy alias ``lr``): Adam learning rate,
              default ``1e-4``.
            * ``lr_schedule_patience``: epochs without a ``val_loss``
              improvement before the learning rate is reduced, default ``5``.
            * ``lr_schedule_factor``: factor the learning rate is multiplied
              by on each reduction, default ``0.1``.
    """

    def __init__(self, model, **kwargs):
        super().__init__()

        # The network is registered as a submodule (and checkpointed that
        # way), so keep it out of the hyperparameters; only kwargs are stored.
        self.save_hyperparameters(ignore=["model"])
        self.model = model
        self.hyperparams = cast(Dict[str, Any], self.hparams)
        self.loss = nn.MSELoss(reduction="mean")

    @staticmethod
    def _psnr(mse):
        """Return the notebook's PSNR metric, ``-10 * log10(2 * mse)``."""
        return -10 * torch.log10(2.0 * mse)

    def _compute_loss(self, batch):
        """Return the MSE between the prediction for ``x`` and the target ``y``."""
        x, y = batch
        pred = self.forward(x)
        # nn.MSELoss takes (input, target).
        return self.loss(pred, y)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)

        self.log("train_loss", loss, prog_bar=True, logger=True)
        self.log("train_psnr", self._psnr(loss), prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)

        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_psnr", self._psnr(loss), prog_bar=True, logger=True)

        return loss

    def test_step(self, batch, batch_idx):
        loss = self._compute_loss(batch)

        self.log("test_loss", loss)
        self.log("test_psnr", self._psnr(loss))

        return loss

    def predict_step(self, batch, batch_idx):
        # Targets are not needed for prediction; only the coordinates are used.
        x, _ = batch
        return self.forward(x)

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler watching ``val_loss``."""
        # Callers pass `learning_rate=...`; `lr` is kept as a fallback key.
        lr = self.hyperparams.get("learning_rate", self.hyperparams.get("lr", 1e-4))
        optimizer = Adam(self.model.parameters(), lr=lr)
        scheduler = ReduceLROnPlateau(
            optimizer,
            factor=self.hyperparams.get("lr_schedule_factor", 0.1),
            patience=self.hyperparams.get("lr_schedule_patience", 5),
        )
        return {
            "optimizer": optimizer,
            # ReduceLROnPlateau steps on a metric; Lightning reads it from the
            # "monitor" entry of the scheduler config.
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }

### Callbacks

In [None]:
# Update the notebook progress bar every 10 batches.
progress_bar = TQDMProgressBar(refresh_rate=10)
callbacks = [progress_bar]

## Multi-layer Perceptron (MLP)


### Swish Activation Layer

In [None]:
# Sanity check: x * sigmoid(x) is elementwise, so the shape should be unchanged.
swish = Swish()
out = swish(x_init)

out.shape

### MLP Layer

$$
\mathbf{f}_\ell(\mathbf{x}) = \sigma\left(\mathbf{w}^{(\ell)}\mathbf{x} + \mathbf{b}^{(\ell)} \right)
$$

where $\sigma$ is the *swish* activation function.

$$
\sigma(\mathbf{x}) = \mathbf{x} \odot \text{Sigmoid}(\mathbf{x})
$$

In [None]:
dim_in = x_init.shape[1]
dim_hidden = 256
dim_out = y_init.shape[1]
num_layers = 5
activation = "swish"  # Swish()  # nn.ReLU()#
final_activation = "sigmoid"

mlp_net = MLP(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    activation=get_activation(activation),
    final_activation=get_activation(final_activation),
)

In [None]:
out = mlp_net(x_init)

In [None]:
out.shape

In [None]:
learn = ImageModel(mlp_net, learning_rate=learning_rate)

In [None]:
out = learn.forward(x_init)

out.shape

In [None]:
trainer = Trainer(
    min_epochs=1,
    max_epochs=num_epochs,
    accelerator="mps",
    # devices=1,
    enable_progress_bar=True,
    # logger=wandb_logger,
    callbacks=callbacks,
)

In [None]:
trainer.fit(
    learn,
    datamodule=dm,
)

In [None]:
res = trainer.test(learn, dataloaders=dm.test_dataloader())

results["mlp"] = res[0]

In [None]:
table = [
    [
        key,
        f"{results[key]['test_loss']:4.4f}",
        f"{results[key]['test_psnr']:4.3f}",
        # "{:,}".format(sum([np.prod(p.shape) for p in flow_dict[key]["model"].parameters()]))
    ]
    for key in results
]
display(
    HTML(
        tabulate.tabulate(
            table,
            tablefmt="html",
            headers=[
                "Model",
                "MSE",
                "PSNR",  # "Num Parameters"
            ],
        )
    )
)

In [None]:
# wandb_logger.log_metrics(
#     {"mse": results["mlp"][0]["test_loss"],
#      "pnsr": results["mlp"][0]["test_psnr"],
#     }
# )

In [None]:
# t0 = time.time()
predictions = trainer.predict(learn, dataloaders=dm, return_predictions=True)
predictions = torch.cat(predictions)
# t1 = time.time() - t0

In [None]:
img_pred = dm.coordinates_2_image(predictions)

images["mlp"] = img_pred

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(img_pred, cmap="gray")
plt.tight_layout()

# wandb_logger.log_image("reconstruction", [wandb.Image(fig)])
plt.savefig(
    f"./results/img_reg_nn.png",
    bbox_inches="tight",
)
plt.close()

In [None]:
fig, ax = plt.subplots(ncols=3, figsize=(8, 6))
ax[0].imshow(img, cmap="gray")
ax[1].imshow(img_pred, cmap="gray")
ax[2].imshow(np.abs(img_pred - img), cmap="gray")
plt.tight_layout()
plt.show()

## Fourier Feature Networks

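These methods first pass the input coordinates through an encoder, $\boldsymbol{\gamma}$, that maps them to a richer feature space, $\mathbf{z}$, and then feed the encoded features to a standard MLP.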

### Encoders

$$
\mathbf{z} = \boldsymbol{\gamma}(\mathbf{x})
$$

**Identity**

$$
\boldsymbol{\gamma}(\mathbf{x}) = \mathbf{x}
$$

**Basic Mapping**

$$
\boldsymbol{\gamma}(\mathbf{x}) = [\sin(2\pi \mathbf{x}), \cos(2\pi\mathbf{x})]^\top
$$

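**Gaussian Fourier Features**


$$
\boldsymbol{\gamma}(\mathbf{x}) = [\sin(2\pi \boldsymbol{\Omega}\mathbf{x}), \cos(2\pi \boldsymbol{\Omega}\mathbf{x})]^\top
$$

where $\boldsymbol{\Omega} \in \mathbb{R}^{m \times d}$ is a random matrix whose entries are drawn from $\mathcal{N}\left(0, \sigma^2\right)$. Here $m$ is the number of features (the mapping size), $d$ is the input dimension and $\sigma$ controls the frequency scale (cf. the `mapping_size` and `sigma` arguments below).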

In [None]:
from inr4ssh._src.models.encoders import (
    IdentityPositionalEncoding,
    NeRFPositionalEncoding,
    GaussianFourierFeatureTransform,
)
from inr4ssh._src.models.ffn import FourierFeatureMLP

In [None]:
# encoder = IdentityPositionalEncoding(in_dim=dim_in)
# encoder = NeRFPositionalEncoding(in_dim=dim_in, mapping_size=50)
encoder = GaussianFourierFeatureTransform(
    in_dim=x_init.shape[1], mapping_size=256, sigma=10.0
)

In [None]:
out = encoder(x_init)
x_init.shape, out.shape

In [None]:
dim_in = x_init.shape[1]
dim_hidden = 256
dim_out = y_init.shape[1]
num_layers = 4
activation = "swish"  # Swish() # nn.ReLU() #
final_activation = "sigmoid"  # nn.Sigmoid()

ffn_net = FourierFeatureMLP(
    encoder=encoder,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    activation=get_activation(activation),
    final_activation=get_activation(final_activation),
)

In [None]:
out = ffn_net(x_init)
x_init.shape, out.shape

In [None]:
learn = ImageModel(ffn_net, learning_rate=learning_rate)

In [None]:
trainer = Trainer(
    min_epochs=1,
    max_epochs=num_epochs,
    accelerator="mps",
    enable_progress_bar=True,
    logger=None,
    callbacks=callbacks,
)

In [None]:
trainer.fit(learn, datamodule=dm)

In [None]:
res = trainer.test(learn, dataloaders=dm.test_dataloader())

results["ffn"] = res[0]

In [None]:
table = [
    [
        key,
        f"{results[key]['test_loss']:4.4f}",
        f"{results[key]['test_psnr']:4.3f}",
        # "{:,}".format(sum([np.prod(p.shape) for p in flow_dict[key]["model"].parameters()]))
    ]
    for key in results
]
display(
    HTML(
        tabulate.tabulate(
            table,
            tablefmt="html",
            headers=[
                "Model",
                "MSE",
                "PSNR",  # "Num Parameters"
            ],
        )
    )
)

In [None]:
# t0 = time.time()
predictions = trainer.predict(learn, dataloaders=dm, return_predictions=True)
predictions = torch.cat(predictions)
# t1 = time.time() - t0

In [None]:
img_pred = dm.coordinates_2_image(predictions)
images["ffn"] = img_pred

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(img_pred, cmap="gray")
plt.tight_layout()
plt.savefig(
    f"./results/img_reg_ffn.png",
    bbox_inches="tight",
)
plt.close()

In [None]:
fig, ax = plt.subplots(ncols=3, figsize=(8, 6))
ax[0].imshow(img, cmap="gray")
ax[1].imshow(img_pred, cmap="gray")
ax[2].imshow(np.abs(img_pred - img), cmap="gray")
plt.tight_layout()
plt.show()

## Siren

In [None]:
dim_in = x_init.shape[1]
dim_hidden = 256
dim_out = y_init.shape[1]
num_layers = 4
w0 = 1.0
w0_initial = 30.0
c = 6.0
final_activation = "sigmoid"  # nn.Sigmoid()

siren_net = SirenNet(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    w0=w0,
    w0_initial=w0_initial,
    c=c,
    final_activation=get_activation(final_activation),
)

In [None]:
out = siren_net(x_init)
x_init.shape, out.shape

In [None]:
learn = ImageModel(siren_net, learning_rate=learning_rate)

In [None]:
trainer = Trainer(
    min_epochs=1,
    max_epochs=num_epochs,
    accelerator="mps",
    enable_progress_bar=True,
    logger=None,
    callbacks=callbacks,
)

In [None]:
trainer.fit(learn, datamodule=dm)

In [None]:
res = trainer.test(learn, dataloaders=dm.test_dataloader())

results["siren"] = res[0]

In [None]:
table = [
    [
        key,
        f"{results[key]['test_loss']:4.4f}",
        f"{results[key]['test_psnr']:4.3f}",
        # "{:,}".format(sum([np.prod(p.shape) for p in flow_dict[key]["model"].parameters()]))
    ]
    for key in results
]
display(
    HTML(
        tabulate.tabulate(
            table,
            tablefmt="html",
            headers=[
                "Model",
                "MSE",
                "PSNR",  # "Num Parameters"
            ],
        )
    )
)

In [None]:
# t0 = time.time()
predictions = trainer.predict(learn, dataloaders=dm, return_predictions=True)
predictions = torch.cat(predictions)
# t1 = time.time() - t0

In [None]:
img_pred = dm.coordinates_2_image(predictions)

images["siren"] = img_pred

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(img_pred, cmap="gray")
plt.tight_layout()
plt.savefig(
    f"./results/img_reg_siren.png",
    bbox_inches="tight",
)
plt.close()

In [None]:
fig, ax = plt.subplots(ncols=3, figsize=(8, 6))
ax[0].imshow(img, cmap="gray")
ax[1].imshow(img_pred, cmap="gray")
ax[2].imshow(np.abs(img_pred - img), cmap="gray")
plt.tight_layout()
plt.show()

## Modulated Siren

In [None]:
dim_in = x_init.shape[1]
dim_hidden = 256
dim_out = y_init.shape[1]
num_layers = 4
w0 = 1.0
w0_initial = 30.0
c = 6.0
final_activation = "sigmoid"
latent_dim = 256
num_layers_latent = 4
latent_operation = "sum"


modsiren_net = ModulatedSirenNet(
    dim_in=dim_in,
    dim_hidden=dim_hidden,
    dim_out=dim_out,
    num_layers=num_layers,
    w0=w0,
    w0_initial=w0_initial,
    c=c,
    final_activation=get_activation(final_activation),
    latent_dim=latent_dim,
    num_layers_latent=num_layers_latent,
    operation=latent_operation,
)

In [None]:
out = modsiren_net(x_init)
x_init.shape, out.shape

In [None]:
learn = ImageModel(modsiren_net, learning_rate=learning_rate)

In [None]:
trainer = Trainer(
    min_epochs=1,
    max_epochs=num_epochs,
    accelerator="cpu",
    enable_progress_bar=True,
    logger=None,
    callbacks=callbacks,
)

In [None]:
trainer.fit(learn, datamodule=dm)

In [None]:
res = trainer.test(learn, dataloaders=dm.test_dataloader())

results["modsiren"] = res[0]

In [None]:
table = [
    [
        key,
        f"{results[key]['test_loss']:4.4f}",
        f"{results[key]['test_psnr']:4.3f}",
        # "{:,}".format(sum([np.prod(p.shape) for p in flow_dict[key]["model"].parameters()]))
    ]
    for key in results
]
display(
    HTML(
        tabulate.tabulate(
            table,
            tablefmt="html",
            headers=[
                "Model",
                "MSE",
                "PSNR",  # "Num Parameters"
            ],
        )
    )
)

In [None]:
# t0 = time.time()
predictions = trainer.predict(learn, dataloaders=dm, return_predictions=True)
predictions = torch.cat(predictions)
# t1 = time.time() - t0

In [None]:
img_pred = dm.coordinates_2_image(predictions)

images["modsiren"] = img_pred

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(img_pred, cmap="gray")
plt.tight_layout()
plt.savefig(
    f"./results/img_reg_modsiren.png",
    bbox_inches="tight",
)
plt.close()

In [None]:
fig, ax = plt.subplots(ncols=3, figsize=(8, 6))
ax[0].imshow(img, cmap="gray")
ax[1].imshow(img_pred, cmap="gray")
ax[2].imshow(np.abs(img_pred - img), cmap="gray")
plt.tight_layout()
plt.show()

## Multiplicative Filter Networks

### Fourier Net

In [None]:
dim_in = x_init.shape[1]
dim_out = y_init.shape[1]
dim_hidden = 256
num_layers = 4
use_bias = True
input_scale = 256.0
weight_scale = 1.0
final_activation = "sigmoid"

fourier_net = FourierNet(
    dim_in=dim_in,
    dim_out=dim_out,
    dim_hidden=dim_hidden,
    num_layers=num_layers,
    input_scale=input_scale,
    weight_scale=weight_scale,
    use_bias=use_bias,
    final_activation=get_activation(final_activation),
)

In [None]:
out = fourier_net(x_init)
x_init.shape, out.shape

In [None]:
learn = ImageModel(fourier_net, learning_rate=learning_rate)

In [None]:
trainer = Trainer(
    min_epochs=1,
    max_epochs=num_epochs,
    accelerator="mps",
    enable_progress_bar=True,
    logger=None,
    callbacks=callbacks,
)

In [None]:
trainer.fit(learn, datamodule=dm)

In [None]:
res = trainer.test(learn, dataloaders=dm.test_dataloader())

results["fouriernet"] = res[0]

In [None]:
table = [
    [
        key,
        f"{results[key]['test_loss']:4.4f}",
        f"{results[key]['test_psnr']:4.3f}",
        # "{:,}".format(sum([np.prod(p.shape) for p in flow_dict[key]["model"].parameters()]))
    ]
    for key in results
]
display(
    HTML(
        tabulate.tabulate(
            table,
            tablefmt="html",
            headers=[
                "Model",
                "MSE",
                "PSNR",  # "Num Parameters"
            ],
        )
    )
)

In [None]:
# t0 = time.time()
predictions = trainer.predict(learn, dataloaders=dm, return_predictions=True)
predictions = torch.cat(predictions)
# t1 = time.time() - t0

In [None]:
img_pred = dm.coordinates_2_image(predictions)

images["fouriernet"] = img_pred

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(img_pred, cmap="gray")
plt.tight_layout()
plt.savefig(
    f"./results/img_reg_fouriernet.png",
    bbox_inches="tight",
)
plt.close()

In [None]:
fig, ax = plt.subplots(ncols=3, figsize=(8, 6))
ax[0].imshow(img, cmap="gray")
ax[1].imshow(img_pred, cmap="gray")
ax[2].imshow(np.abs(img_pred - img), cmap="gray")
plt.tight_layout()
plt.show()

### GaborNet

In [None]:
dim_in = x_init.shape[1]
dim_out = y_init.shape[1]
dim_hidden = 256
num_layers = 4
use_bias = True
input_scale = 256.0
weight_scale = 1.0
alpha = 6.0
beta = 1.0
final_activation = "sigmoid"


gabor_net = GaborNet(
    dim_in=dim_in,
    dim_out=dim_out,
    dim_hidden=dim_hidden,
    num_layers=num_layers,
    input_scale=input_scale,
    weight_scale=weight_scale,
    alpha=alpha,
    beta=beta,
    use_bias=use_bias,
    final_activation=get_activation(final_activation),
)

In [None]:
out = gabor_net(x_init)
x_init.shape, out.shape

In [None]:
learn = ImageModel(gabor_net, learning_rate=learning_rate)

In [None]:
trainer = Trainer(
    min_epochs=1,
    max_epochs=num_epochs,
    accelerator="mps",
    enable_progress_bar=True,
    logger=None,
    callbacks=callbacks,
)

In [None]:
trainer.fit(learn, datamodule=dm)

In [None]:
res = trainer.test(learn, dataloaders=dm.test_dataloader())

results["gabornet"] = res[0]

In [None]:
table = [
    [
        key,
        f"{results[key]['test_loss']:4.4f}",
        f"{results[key]['test_psnr']:4.3f}",
        # "{:,}".format(sum([np.prod(p.shape) for p in flow_dict[key]["model"].parameters()]))
    ]
    for key in results
]
display(
    HTML(
        tabulate.tabulate(
            table,
            tablefmt="html",
            headers=[
                "Model",
                "MSE",
                "PSNR",  # "Num Parameters"
            ],
        )
    )
)

In [None]:
# t0 = time.time()
predictions = trainer.predict(learn, dataloaders=dm, return_predictions=True)
predictions = torch.cat(predictions)
# t1 = time.time() - t0

In [None]:
img_pred = dm.coordinates_2_image(predictions)

images["gabornet"] = img_pred

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
ax.imshow(img_pred, cmap="gray")
plt.tight_layout()
plt.savefig(
    f"./results/img_reg_gabornet.png",
    bbox_inches="tight",
)
plt.close()

In [None]:
fig, ax = plt.subplots(ncols=3, figsize=(8, 6))
ax[0].imshow(img, cmap="gray")
ax[1].imshow(img_pred, cmap="gray")
ax[2].imshow(np.abs(img_pred - img), cmap="gray")
plt.tight_layout()
plt.show()

In [None]:
results

In [None]:
results_ = pd.DataFrame(results).T

In [None]:
results_.to_csv("./results/metrics.csv")