# Series mean intensity histogram

In [None]:
# Batch index in the test dataloader to plot; set to None to pick a random batch.
SAMPLE_IDX = 1

## Series plot functions

In [None]:
import os
import sys

sys.path.append('../')

import matplotlib.pyplot as plt
import numpy as np
import torch

from reg.data import LungDataModule

In [None]:
def available_worker_count():
    """Return how many DataLoader workers to use: usable CPU cores minus one, at least 1.

    ``os.sched_getaffinity`` (the set of cores this process may run on) is not
    available on every platform, so fall back to ``os.cpu_count()`` there.
    """
    try:
        n_cores = len(os.sched_getaffinity(0))
    except AttributeError:
        # os.cpu_count() may return None when the count cannot be determined.
        n_cores = os.cpu_count() or 1
    # Leave one core for the main process, but never drop below one worker.
    return max(1, n_cores - 1)


def setup_data_module(
    root_dir="/media/agjvc_rad3/_TESTKOLLEKTIV/Daten/Daten",
    split=(0.7, 0.1, 0.2),
    seed=42,
    pin_memory=True,
    num_workers=None,
):
    """Create a ``LungDataModule``, call its ``setup()`` and return it.

    All arguments are forwarded to ``LungDataModule``; the defaults reproduce
    the previously hard-coded configuration.

    Args:
        root_dir: Dataset root directory.
        split: Split fractions forwarded as ``split``
            (presumably train/val/test — confirm against LungDataModule).
        seed: Seed forwarded to the data module.
        pin_memory: Forwarded as ``pin_memory``.
        num_workers: Number of loader workers; ``None`` means
            ``available_worker_count()``.

    Returns:
        The data module after ``setup()`` has been called.
    """
    if num_workers is None:
        num_workers = available_worker_count()
    data_module = LungDataModule(
        root_dir=root_dir,
        split=split,
        seed=seed,
        pin_memory=pin_memory,
        num_workers=num_workers,
    )
    data_module.setup()
    return data_module

def fetch_sample_from_dataloader(dataloader, sample_idx):
    """Return the batch at position ``sample_idx`` of ``dataloader``.

    The dataloader is iterated from the start, so every batch before the
    requested one is produced (and loaded) as well.

    Args:
        dataloader: Any iterable of batches (e.g. a torch ``DataLoader``).
        sample_idx: Zero-based batch position, or ``None`` for the first batch.

    Returns:
        The selected batch.

    Raises:
        IndexError: If ``sample_idx`` is negative or the dataloader yields
            fewer than ``sample_idx + 1`` batches. Previously this case
            silently returned ``None`` and the caller failed later.
    """
    target = 0 if sample_idx is None else sample_idx
    if target < 0:
        raise IndexError(f"sample_idx must be non-negative, got {sample_idx}")

    n_batches = 0
    for i, batch in enumerate(dataloader):
        if i == target:
            return batch
        n_batches = i + 1

    raise IndexError(
        f"sample_idx {sample_idx} is out of range; "
        f"the dataloader yielded only {n_batches} batch(es)"
    )

In [None]:
def histogram_main(sample_idx=None, *, cut_off=30):
    """Plot the per-image mean intensity along one series from the test split.

    Loads the test dataloader, takes the batch at ``sample_idx``, averages it
    over dims 2 and 3 and plots the resulting curve (from position
    ``cut_off`` on). Reference lines are drawn for:

    * the position whose mean deviates most from the mean of means;
    * the last mean;
    * the mean of means;
    * the maximum mean;
    * the ±1σ and ±2σ bands around the mean of means.

    Args:
        sample_idx: Batch index in the test dataloader, or ``None`` to pick a
            random batch.
        cut_off: Number of leading series positions to skip (default 30, the
            previously hard-coded value).

    Raises:
        IndexError: If ``sample_idx`` is outside the test dataloader
            (raised by ``fetch_sample_from_dataloader``).
        ValueError: If fewer than two series positions remain after
            ``cut_off``; the standard deviation would be undefined.
    """
    data_module = setup_data_module()
    dataloader = data_module.test_dataloader()

    if sample_idx is None:
        # Draw from the real number of batches; the old hard-coded upper bound
        # of 64 could select a batch past the end of the dataloader.
        # NOTE(review): assumes the test dataloader is sized (map-style dataset).
        sample_idx = np.random.randint(0, len(dataloader))

    moving_series = fetch_sample_from_dataloader(dataloader, sample_idx)

    # Average over dims 2 and 3, then take batch element 0 / channel 0.
    # NOTE(review): this assumes a layout like (B, C, H, W, T), so the result is
    # a 1-D curve over the series dimension — confirm against LungDataModule.
    image_means = moving_series.mean(dim=(2, 3))[0, 0][cut_off:].detach().cpu()
    if image_means.numel() < 2:
        raise ValueError(
            f"cut_off={cut_off} leaves {image_means.numel()} series position(s); "
            "at least 2 are needed to compute a standard deviation"
        )

    # Plain Python floats/ints avoid handing 0-d / 1-element tensors to matplotlib.
    mean_of_means = image_means.mean().item()
    std_of_means = image_means.std().item()
    peak_idx = int((image_means - mean_of_means).abs().argmax()) + cut_off
    max_mean = image_means.max().item()
    last_mean = image_means[-1].item()

    image_indices = np.arange(cut_off, cut_off + image_means.numel())

    fig, ax = plt.subplots(1, 1, figsize=(16, 5), tight_layout=True)

    ax.set_title("Mean of Image Series")
    ax.set_xlabel("Image Index")
    ax.set_ylabel("Mean Value")

    ax.plot(image_indices, image_means.numpy(), "-", color='b', lw=2, label="Image Means")
    ax.axvline(x=peak_idx, color='r', linestyle='-', lw=2, label=f"Peak at idx = {peak_idx}")
    ax.axhline(y=last_mean, color='purple', linestyle="dashed", lw=2, label="Last Mean")
    # Plot the actual mean of means. The ±sigma bands below are centred on this
    # value; previously the image mean closest to it was drawn instead.
    ax.axhline(y=mean_of_means, color='green', linestyle="dashdot", lw=2, label="Mean of Means")
    # Cyan so this line is not identical in colour and style to the ±2 sigma lines.
    ax.axhline(y=max_mean, color='c', linestyle="dotted", lw=2, label="Max Mean")

    # Sigma bands around the mean of means.
    for n_sigma, color in ((1, 'y'), (2, 'orange')):
        ax.axhline(y=mean_of_means + n_sigma * std_of_means, color=color, linestyle="dotted", lw=2,
                   label=f"Mean + {n_sigma} Sigma")
        ax.axhline(y=mean_of_means - n_sigma * std_of_means, color=color, linestyle="dotted", lw=2,
                   label=f"Mean - {n_sigma} Sigma")

    # X-axis ticks every 10 positions.
    ax.set_xticks(np.arange(image_indices[0], image_indices[-1] + 1, 10))

    ax.legend(loc='lower right', fontsize='small')
    ax.grid(True, which='both', linestyle='--', lw=0.5)

    plt.show()
    plt.close(fig)

## Main

In [None]:
# Render the mean-intensity plot for the configured test batch.
histogram_main(sample_idx=SAMPLE_IDX)