# Inference with a Variational Lossy Autoencoder

In [None]:
import os

import tifffile
import torch
import pytorch_lightning as pl
from tqdm import tqdm

from dvlae import DVLAE

In [None]:
# True when a CUDA device is visible to torch; used to pick the Trainer accelerator.
use_cuda: bool = torch.cuda.is_available()

### Load test data
For the *Convallaria majalis* (C. majalis) dataset, we follow previous work and denoise only the top-left 512×512 quadrant of each frame.

In [None]:
data_path = "data/flower.tif"

# Read the image stack from disk.
frames = tifffile.imread(data_path)
if frames.ndim == 2:
    # A single 2-D image: treat it as a stack containing one frame so the
    # stack-indexing below still applies.
    frames = frames[None]

# Crop every frame to its top-left (up to) 512x512 region and insert a
# singleton channel axis -> (n_images, 1, H', W'). Converting directly to
# float32 avoids materialising an intermediate float64 copy.
test_data = torch.from_numpy(frames[:, None, :512, :512].astype("float32"))

# Drop the reference to the full, uncropped stack so it can be freed.
del frames

### Create prediction dataloader

In [None]:
class PredictDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves the items of an indexable collection as-is.

    It feeds the test image stack to the prediction DataLoader with no labels
    and no transforms.

    Args:
        images: any object supporting ``len()`` and integer indexing; each
            index yields one item (here, presumably one image of the stack).
    """

    def __init__(self, images):
        self.images = images

    def __len__(self):
        n_items = len(self.images)
        return n_items

    def __getitem__(self, idx):
        return self.images[idx]

<code>n_samples</code>: number of denoised samples to average for the MMSE estimate.<br>
<code>predict_batch_size</code>: number of images to denoise at a time. Reduce it if needed to fit in memory.

In [None]:
# Number of denoised samples averaged into the MMSE estimate.
n_samples = 100
# Images per prediction batch; reduce if a batch does not fit in memory.
predict_batch_size = 10

predict_set = PredictDataset(test_data)
predict_loader = torch.utils.data.DataLoader(
    predict_set,
    batch_size=predict_batch_size,
    # Keep order fixed so concatenated predictions line up with test_data.
    shuffle=False,
    # Pinned host memory only speeds up host-to-GPU transfers; on a CPU-only
    # machine it is unused and DataLoader warns, so enable it only with CUDA.
    pin_memory=use_cuda,
)

### Load trained model

In [None]:
model_name = "convallaria"
# Checkpoints for this model are looked up under checkpoints/<model_name>/.
checkpoint_path = os.path.join("checkpoints", model_name)
checkpoint_file = os.path.join(checkpoint_path, "final_model.ckpt")

# Restore the trained DVLAE from its final checkpoint.
dvlae = DVLAE.load_from_checkpoint(checkpoint_file)

### Denoise <br>
<code>results_path</code>: directory where the denoised results are stored.

In [None]:
# Directory that receives the MMSE estimate and the individual samples.
results_path = os.path.join("results", model_name)

# Lightning's own progress bar is disabled: predict() is called n_samples
# times and the outer tqdm bar below already reports progress.
trainer = pl.Trainer(
    accelerator="gpu" if use_cuda else "cpu",
    devices=1,
    enable_progress_bar=False,
)

# Produce n_samples denoised versions of the whole test stack, keeping every
# sample and a running sum for the MMSE (mean) estimate.
samples = []
MMSEs = torch.zeros_like(test_data)
for _ in tqdm(range(n_samples)):
    # predict() returns one output per batch; concatenate them back into a
    # stack aligned with test_data (the loader does not shuffle).
    out = torch.cat(trainer.predict(dvlae, predict_loader), dim=0)
    samples.append(out)
    MMSEs += out

MMSEs = MMSEs / n_samples

# Stack along a new axis 1 -> (n_images, n_samples, ...): the same layout as
# stacking on axis 0 and then moving that axis, but produced contiguously.
samples = torch.stack(samples, dim=1)

# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(results_path, exist_ok=True)

tifffile.imwrite(os.path.join(results_path, 'MMSEs.tif'), MMSEs.numpy())
tifffile.imwrite(os.path.join(results_path, 'samples.tif'), samples.numpy())