# Structural Similarity of Model

In [4]:
from pathlib import Path
import numpy as np

import torch
from torch.nn.parallel import DataParallel as DP

from network.patch_visualizer import PatchVisualizer
from network.models import MaskedGAN
from network.models.generators import PatchUNet
from network.models.discriminators import PatchDiscriminator
from utils.metrics import structural_similarity

%load_ext autoreload
%autoreload 2

ImportError: /home/algol/miniconda3/envs/pytorch/lib/python3.11/site-packages/zmq/backend/cython/../../../../.././libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/algol/miniconda3/envs/pytorch/lib/python3.11/site-packages/scipy/fft/_pocketfft/pypocketfft.cpython-311-x86_64-linux-gnu.so)

## Setup

### Helper Functions

In [5]:
def load_model(path, device=None):
    # Load model state dict from disk
    checkpoint = torch.load(path, map_location=device)
    # Initialize Generator and Discriminator
    gen = DP(PatchUNet())
    gen.load_state_dict(checkpoint['gen_state_dict'])
    disc = DP(PatchDiscriminator())
    disc.load_state_dict(checkpoint['disc_state_dict'])
    # Initialize Model
    model = MaskedGAN(gen, disc, mode='test', device=device)
    return model

### Parameters

In [6]:
i12 = Path('/media/algol/HD-LXU3/No_stripes_net_data')
data_dir = i12/'data'/'wider_stripes'
model_file = i12/'pretrained_models'/'five_sample'/'4x4'/'val'/'five_sample_4x4_100.tar'
mask_file = i12/'stripe_masks.npz'

if torch.cuda.is_available():
    d = torch.device('cuda')
else:
    d = torch.device('cpu')

rng = np.random.default_rng()
cor = 1253
sample_no = 0
print(f"Sample No.: {sample_no}")

Sample No.: 0


## Load Model & Visualizer

In [15]:
from network.patch_visualizer import PatchVisualizer
from network.models import MaskedGAN

ImportError: /home/algol/miniconda3/envs/pytorch/lib/python3.11/site-packages/zmq/backend/cython/../../../../.././libstdc++.so.6: version `GLIBCXX_3.4.30' not found (required by /home/algol/miniconda3/envs/pytorch/lib/python3.11/site-packages/scipy/fft/_pocketfft/pypocketfft.cpython-311-x86_64-linux-gnu.so)

In [14]:
len(v.clean_idxs)

NameError: name 'v' is not defined

In [None]:
model = load_model(model_file, device=d)
v = PatchVisualizer(data_dir, model, sample_no=sample_no, mask_file=mask_file,
                    full_sino_size=(1801, 2560), patch_size=(1801, 256))

# `v.clean_idxs` contains naturally clean sinograms
# `v.stripe_idxs` contains naturally stripey sinograms
# Can only do SSIM for artificial stripes, as no ground truth for natural stripes

# 'sro' means stripe regions only
clean_sinos = []
clean_sro = []
stripe_sinos = []
stripe_sro = []
model_sinos = []
model_sro = []
for i in v.clean_idxs:
    clean_sino = v.get_sinogram(i, 'clean')
    stripe_sino = v.get_sinogram(i, 'stripe')
    model_sino = v.get_model_sinogram(i)
    mask = v.mask[:, i, :].asype(np.bool_)
    
    clean_sinos.append(clean_sino)
    stripe_sinos.append(stripe_sino)
    model_sinos.append(model_sino)
    clean_sro.append(clean_sino[mask])
    stripe_sro.append(stripe_sino[mask])
    model_sro.append(model_sino[mask])

## Calculate Structural Similarity

### SSIM from Clean to Model Outputs

#### Whole Sinograms

In [None]:
model_sino_ssim = structural_similarity(clean_sinos, model_sinos)
print(f"SSIM from Clean to Model Outputs, Whole Sinograms: {model_sino_ssim}")

#### Stripe Regions Only

In [None]:
model_sro_ssim = structural_similarity(clean_sro, model_sro)
print(f"SSIM from Clean to Model Outputs, Stripe Regions Only: {model_sro_ssim}")

### SSIM from Clean to Stripe Data

#### Whole Sinograms

In [None]:
control_sino_ssim = structural_similarity(clean_sinos, stripe_sinos)
print(f"SSIM from Clean to Stripe Data, Whole Sinograms: {control_sino_ssim}")

#### Stripe Regions Only

In [None]:
control_sro_ssim = structural_similarity(clean_sro, model_sro)
print(f"SSIM from Clean to Stripe Data, Stripe Regions Only: {control_sro_ssim}")