# Visualize Results of a Patch-Based Model
This notebook allows you to see the results of a patch-based model.<br>
You can look at just sinograms, or both sinograms and reconstructions.<br>
You can also choose which sample you want to see results from, and which slices you want to see.<br>
By default, the slices are chosen at random, while the sample number is fixed by `sample_no` (default `0`).<br>

### Imports

In [None]:
from pathlib import Path
import wandb
import numpy as np
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP, DataParallel as DP
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (12, 8)
plt.rcParams["figure.dpi"] = 300
%matplotlib inline

from network.patch_visualizer import PatchVisualizer
from network.models import MaskedGAN
from network.models.generators import PatchUNet
from network.models.discriminators import PatchDiscriminator

from utils.tomography import reconstruct

# Number of threads for OpenMP. If too high, may cause error
%env OMP_NUM_THREADS=16

# Auto-reload code from disk
%load_ext autoreload
%autoreload 2

## Parameters
Specify the parameters you want to run with here.<br>

`data_dir` is the directory containing the input data.<br>
`model_file` is the path to the model you want to visualize.<br>
`mask_file` is the path to the binary mask indicating locations of stripes.<br>
`sample_no` is the sample number to load data from.<br>
`cor` is the centre of rotation passed to `reconstruct` when reconstructing sinograms.<br>

In [None]:
# Root of the NoStripesNet processing area on the i12 file system
i12 = Path('/dls/i12/data/2022/nt33730-1/processing/NoStripesNet')
data_dir = i12.joinpath('data', 'wider_stripes')
model_file = i12.joinpath('pretrained_models', 'five_sample', '4x4', 'val', 'five_sample_4x4_100.tar')
mask_file = i12.joinpath('stripe_masks.npz')

# Unseeded generator: every run picks different slices
rng = np.random.default_rng()
# Centre of rotation handed to `reconstruct` in the cells below
cor = 1253
# Which sample in `data_dir` to visualise
sample_no = 0
print(f"Sample No.: {sample_no}")

### Setup

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU
d = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Saved training checkpoint; its tensors are mapped directly onto `d`
checkpoint = torch.load(model_file, map_location=d)

# Wrap each network in DataParallel, then restore its weights from the checkpoint
gen = DP(PatchUNet())
gen.load_state_dict(checkpoint['gen_state_dict'])

disc = DP(PatchDiscriminator())
disc.load_state_dict(checkpoint['disc_state_dict'])

# Combine both networks into a GAN in test mode on the chosen device
model = MaskedGAN(gen, disc, mode='test', device=d)

# Visualizer that serves sinograms (raw, clean, striped, model output) for one sample
v = PatchVisualizer(data_dir, model, sample_no=sample_no, mask_file=mask_file)

# Figures to upload to W&B, keyed by a title string
fig_dict = dict()

## Plot Fake Artifacts

`fake_artifact_idx` is the slice index of the sinogram you want to plot.<br>
Be sure to choose a sinogram that has **no** real-life artifacts, as otherwise you'll see black boxes where the real artifact is.<br>
The `clean_idxs` attribute of `PatchVisualizer` is a list of all sinogram slice indexes that don't contain any real-life artifacts.<br>
By default, `fake_artifact_idx` is randomly chosen from `clean_idxs`.<br>

In [None]:
# Pick a slice free of real-life stripes, so the only artifacts are synthetic
fake_artifact_idx = rng.choice(v.clean_idxs)
print(f"Index: {fake_artifact_idx}")

# Clean sinogram, its synthetically-striped version, and the model output in 'fake' mode
clean = v.get_sinogram(fake_artifact_idx, 'clean')
stripe = v.get_sinogram(fake_artifact_idx, 'stripe')
gen_out = v.get_model_sinogram(fake_artifact_idx, 'fake')

# Reconstruct all three with identical settings so they can be compared directly
clean_r, stripe_r, gen_out_r = (
    reconstruct(sino, rot_center=cor, ncore=16) for sino in (clean, stripe, gen_out)
)

In [None]:
# Top row: sinograms; bottom row: their reconstructions on a shared grey scale
fig, axs = plt.subplots(2, 3)
fig.suptitle("Synthetic Stripes", size='xx-large')

sino_panels = (
    (f"Clean {fake_artifact_idx}", clean),
    (f"Stripe {fake_artifact_idx}", stripe),
    (f"Model Output {fake_artifact_idx}", gen_out),
)
for ax, (ax_title, sino) in zip(axs[0], sino_panels):
    ax.set_title(ax_title)
    ax.axis('off')
    ax.imshow(sino, cmap='gray')

# Fixed display window so the three reconstructions are directly comparable
for ax, recon in zip(axs[1], (clean_r, stripe_r, gen_out_r)):
    ax.axis('off')
    ax.imshow(recon, cmap='gray', vmin=-0.03, vmax=0.15)

fig_dict[f"Sinogram {fake_artifact_idx}"] = fig

## Plot Real-life Artifacts
`real_artifact_idx` is the slice index of the sinogram you want to plot.<br>
Be sure to choose a sinogram that has **at least one** real-life artifact, as otherwise you won't be able to see the effect of the model.<br>
The `stripe_idxs` attribute of `PatchVisualizer` is a list of all sinogram slice indexes that contain at least one real-life artifact.<br>
By default, `real_artifact_idx` is randomly chosen from `stripe_idxs`.<br>

In [None]:
# Pick a slice that contains at least one real-life stripe
real_artifact_idx = rng.choice(v.stripe_idxs)
print(f"Index: {real_artifact_idx}")

# NOTE: `stripe`, `gen_out`, `stripe_r` and `gen_out_r` are also set by the
# "Synthetic Stripes" section and are overwritten here. Re-run that section's
# data cell before re-plotting the synthetic figure.
stripe = v.get_sinogram(real_artifact_idx, 'raw')
gen_out = v.get_model_sinogram(real_artifact_idx, 'real')

# Reconstruct both with identical settings
stripe_r, gen_out_r = (
    reconstruct(sino, rot_center=cor, ncore=16) for sino in (stripe, gen_out)
)

In [None]:
# Top row: raw vs. model sinogram; bottom row: their reconstructions
fig, axs = plt.subplots(2, 2)
fig.suptitle("Real-life Stripes", size='xx-large')

sino_panels = (
    (f"Stripe {real_artifact_idx}", stripe),
    (f"Model Output {real_artifact_idx}", gen_out),
)
for ax, (ax_title, sino) in zip(axs[0], sino_panels):
    ax.set_title(ax_title)
    ax.axis('off')
    ax.imshow(sino, cmap='gray')

# Same fixed display window as the synthetic-stripe figure
for ax, recon in zip(axs[1], (stripe_r, gen_out_r)):
    ax.axis('off')
    ax.imshow(recon, cmap='gray', vmin=-0.03, vmax=0.15)

fig_dict[f"Sinogram {real_artifact_idx}"] = fig

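## Upload to Weights & Biases
Set `run_name` to the name of an existing run in the `nostripesnet/NoStripesNet` project; the figures collected in `fig_dict` are logged to that run.<br>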

In [None]:
# Preview the figures (keyed by "Sinogram <index>") that will be logged to W&B
fig_dict

In [None]:
# Attach the figures in `fig_dict` to an existing W&B run, found by its name.
# Leave `run_name` empty to skip the upload; previously an empty or unknown
# name silently logged nothing.
run_name = ""

if not run_name:
    # Explicitly report the skip instead of silently doing nothing.
    print("`run_name` is empty; skipping upload to Weights & Biases.")
else:
    api = wandb.Api()
    runs = api.runs("nostripesnet/NoStripesNet")
    matching_ids = [r.id for r in runs if r.name == run_name]
    if not matching_ids:
        raise ValueError(f"No run named {run_name!r} found in nostripesnet/NoStripesNet")
    if len(matching_ids) > 1:
        # Resuming several runs in a row would call wandb.init repeatedly
        # without finishing the previous run; refuse instead of guessing.
        raise ValueError(
            f"{len(matching_ids)} runs are named {run_name!r}; cannot tell which one to resume"
        )
    wandb.init(project='NoStripesNet', entity='nostripesnet', id=matching_ids[0], resume='must')
    wandb.log(fig_dict)

In [None]:
# Finish the W&B run started by the upload cell (if any) so the logged figures finish uploading
wandb.finish()