# Visualize Results of a Patch-Based Model
This notebook allows you to see the results of a patch-based model.<br>
You can look at just sinograms, or both sinograms and reconstructions.<br>
You can also choose which sample you want to see results from, and which slices you want to see.<br>
The default behaviour is to randomly choose a sample number and slices.<br>

### Imports

In [None]:
# Number of threads for OpenMP. Set it before importing numpy/torch, which read it
# when their thread pools initialize. If it is set too high, errors may occur.
%env OMP_NUM_THREADS=8

from pathlib import Path
import numpy as np
import torch
import matplotlib.pyplot as plt
# Enable the inline backend before customizing rcParams so it can't override them
%matplotlib inline
plt.rcParams["figure.figsize"] = (12, 8)
plt.rcParams["figure.dpi"] = 300

from network.patch_visualizer import PatchVisualizer
from network.models import MaskedGAN
from network.models.generators import PatchUNet
from network.models.discriminators import PatchDiscriminator

## Parameters
Specify the parameters you want to run with here.<br>

`data_dir` is the directory containing the input data.<br>
`model_file` is the path to the model you want to visualize.<br>
`mask_file` is the path to the binary mask indicating locations of stripes.<br>
`recon` is a bool indicating whether reconstructions should be plotted as well.<br>
`sample_no` is the sample number to load data from:<br>
 - `0`    :    119617.nxs (i12 Data)
 - `1`    :    68067.nxs  (Nghia's data)
 - `2`    :    123893.nxs (?)
 - `3`    :    123951.nxs (?)
 - `4`    :    124272.nxs (?)
 - `5`    :    124273.nxs (?)
 
By default, `sample_no` is randomly generated.<br>
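Because `sample_no` comes from an unseeded generator, each run picks a different sample. If you want a repeatable run, a minimal sketch is to seed the generator. `SEED` and `SAMPLE_FILES` are illustrative names, not part of the project. The file names are copied from the list above:

```python
import numpy as np

# Illustrative lookup built from the sample list above (not a project API)
SAMPLE_FILES = {
    0: "119617.nxs",
    1: "68067.nxs",
    2: "123893.nxs",
    3: "123951.nxs",
    4: "124272.nxs",
    5: "124273.nxs",
}

SEED = 42  # hypothetical seed; any fixed integer gives a repeatable sequence

# A seeded Generator produces the same sequence of draws on every run
rng = np.random.default_rng(SEED)
sample_no = int(rng.integers(len(SAMPLE_FILES)))  # uniform over 0..5
print(f"Sample No.: {sample_no} ({SAMPLE_FILES[sample_no]})")
```

Reusing the same `rng` for the slice choices later in the notebook keeps the whole run repeatable. Calling `np.random.default_rng()` with no seed restores the default random behaviour.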

In [None]:
data_dir = Path('/path/to/data')
model_file = Path('/path/to/model.tar')
mask_file = Path('/path/to/stripe_mask.npz')
recon = True

rng = np.random.default_rng()
sample_no = rng.integers(6)
print(f"Sample No.: {sample_no}")

### Setup

In [None]:
if torch.cuda.is_available():
    d = torch.device('cuda')
else:
    d = torch.device('cpu')

# Load the checkpoint (holding the generator and discriminator state dicts) from disk
checkpoint = torch.load(model_file, map_location=d)
# Initialize Generator and Discriminator
gen = torch.nn.DataParallel(PatchUNet())
gen.load_state_dict(checkpoint['gen_state_dict'])
disc = torch.nn.DataParallel(PatchDiscriminator())
disc.load_state_dict(checkpoint['disc_state_dict'])
# Initialize Model
model = MaskedGAN(gen, disc, mode='test', device=d)

# Initialize Visualizer
v = PatchVisualizer(data_dir, model, sample_no=sample_no, mask_file=mask_file)

## Plot Fake Artifacts

`fake_artifact_idx` is the slice index of the sinogram you want to plot.<br>
Be sure to choose a sinogram that has **no** real-life artifacts; otherwise you'll see black boxes where the real artifacts are.<br>
The `clean_idxs` attribute of `PatchVisualizer` is a list of all sinogram slice indexes that don't contain any real-life artifacts.<br>
By default, `fake_artifact_idx` is randomly chosen from `clean_idxs`.<br>
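To make the split between `clean_idxs` and `stripe_idxs` concrete, here is a minimal sketch of how such lists *could* be derived from a binary stripe mask. This is a hypothetical helper, not `PatchVisualizer`'s actual implementation. It assumes the mask is nonzero wherever a real-life stripe is and holds one sinogram per index along `slice_axis`:

```python
import numpy as np

def split_slice_idxs(mask: np.ndarray, slice_axis: int = 0):
    """Split slice indexes into (clean, stripe) lists.

    Hypothetical helper: assumes `mask` is nonzero wherever a real-life
    stripe is, with one sinogram per index along `slice_axis`.
    """
    # Move the slice axis to the front so each row is one sinogram
    mask = np.moveaxis(np.asarray(mask, dtype=bool), slice_axis, 0)
    # A slice contains a stripe if any pixel in it is flagged
    has_stripe = mask.reshape(mask.shape[0], -1).any(axis=1)
    clean_idxs = np.flatnonzero(~has_stripe).tolist()
    stripe_idxs = np.flatnonzero(has_stripe).tolist()
    return clean_idxs, stripe_idxs

# Toy mask: 4 slices of 3 angles x 5 detector pixels, stripes in slices 1 and 3
toy_mask = np.zeros((4, 3, 5), dtype=bool)
toy_mask[1, :, 2] = True  # vertical stripe at detector pixel 2
toy_mask[3, :, 0] = True  # vertical stripe at detector pixel 0
clean, stripe = split_slice_idxs(toy_mask)
print(clean, stripe)  # [0, 2] [1, 3]
```

With the real data you would load the mask with `np.load(mask_file)` and select the array by name from the archive's `.files` list. The array name and axis order inside `stripe_mask.npz` are assumptions you'd need to check.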

In [None]:
fake_artifact_idx = rng.choice(v.clean_idxs)
v.plot_all(fake_artifact_idx, recon=recon)
print(f"Index: {fake_artifact_idx}")

## Plot Real-life Artifacts
`real_artifact_idx` is the slice index of the sinogram you want to plot.<br>
Be sure to choose a sinogram that has **at least one** real-life artifact, as otherwise you won't be able to see the effect of the model.<br>
The `stripe_idxs` attribute of `PatchVisualizer` is a list of all sinogram slice indexes that contain at least one real-life artifact.<br>
By default, `real_artifact_idx` is randomly chosen from `stripe_idxs`.<br>
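If you override the random default with a hand-picked index, it's easy to take one from the wrong list. A small guard like the one below fails fast before anything is plotted. `check_idx` is a hypothetical helper, not part of `PatchVisualizer`, and works with either list:

```python
def check_idx(idx, allowed, name="idx"):
    """Return `idx` as an int, or raise if it isn't in `allowed`.

    Hypothetical helper; `allowed` would be `v.clean_idxs` for fake
    artifacts or `v.stripe_idxs` for real-life artifacts.
    """
    allowed = {int(i) for i in allowed}  # int() also accepts numpy integers
    if int(idx) not in allowed:
        raise ValueError(
            f"{name}={idx} is not in the allowed slice indexes "
            f"({len(allowed)} available)"
        )
    return int(idx)

# Example with a toy index list standing in for v.stripe_idxs
stripe_idxs = [1, 3, 7]
real_artifact_idx = check_idx(3, stripe_idxs, name="real_artifact_idx")
```

In the notebook this would sit just before plotting, e.g. `real_artifact_idx = check_idx(real_artifact_idx, v.stripe_idxs, "real_artifact_idx")`.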

In [None]:
real_artifact_idx = rng.choice(v.stripe_idxs)
v.plot_all_raw(real_artifact_idx, recon=recon)
print(f"Index: {real_artifact_idx}")