In [None]:
import torch
from hdn.lib.gaussianMixtureNoiseModel import GaussianMixtureNoiseModel
from hdn.lib import histNoiseModel 
from hdn.lib.utils import plotProbabilityDistribution
from datasets import load_datasets_yml
from hdn.lib import histNoiseModel
import tifffile
import numpy as np
import matplotlib.pyplot as plt
import os
import argparse
import logging as log

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from PIL import Image

In [None]:
dataset_name = 'support'
gt_name = 'n2v'
dataset_yml = 'datasets.yml'

# Fall back to the CPU when no GPU is present instead of failing later when
# tensors are moved to a non-existent CUDA device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Pick the dataset entry by name; fail with a readable message rather than a
# bare IndexError when the name is not listed in the YAML file.
dset = next(
    (d for d in load_datasets_yml(dataset_yml=dataset_yml) if d['name'] == dataset_name),
    None,
)
if dset is None:
    raise ValueError(f"dataset {dataset_name!r} not found in {dataset_yml!r}")

# Signal estimate produced by `gt_name`, and the folder the noise model is written to.
gt_path = f'predictions/{dataset_name}/{gt_name}.tiff'
out_folder = os.path.join('noise_models', dataset_name, gt_name)
os.makedirs(out_folder, exist_ok=True)

# NOTE: the root logger defaults to WARNING, so this message is only shown if
# logging has been configured (e.g. log.basicConfig(level=log.INFO)).
log.info("Loading signals...")
observations = tifffile.imread(dset['path'])
# Raw-data range; used below as the histogram range on both axes.
datamin, datamax = observations.min(), observations.max()
signal = tifffile.imread(gt_path).squeeze()

In [None]:
# Signal range the GMM is fitted over: from the darkest pixel up to the 99th
# percentile (presumably so a few very bright pixels do not stretch the range).
min_signal = np.percentile(signal, 0)  # 0th percentile == signal.min()
max_signal = np.percentile(signal, 99)

# Histogram bins per axis; also the number of frames in the animation below.
bins = 256

# Last expression of the cell, so the fitted range is actually displayed.
min_signal, max_signal

In [None]:


# Joint signal/observation histogram over the raw-data range [datamin, datamax].
# Element 0 is the 2-D (signal bin x observation bin) histogram plotted below.
histogram = histNoiseModel.createHistogram(
    bins=bins,
    minVal=datamin,
    maxVal=datamax,
    observation=observations,
    signal=signal,
)
histogramFD = histogram[0]

# Gaussian-mixture noise model settings; `path` is presumably where the trained
# parameters are saved (confirm in GaussianMixtureNoiseModel).
gmm_settings = dict(
    min_signal=min_signal,
    max_signal=max_signal,
    path=out_folder,
    weight=None,
    n_gaussian=3,
    n_coeff=2,
    device=device,
    min_sigma=50,
)
gaussianMixtureNoiseModel = GaussianMixtureNoiseModel(**gmm_settings)

# Optimisation settings for fitting the model to (signal, observation) pairs.
train_settings = dict(
    batchSize=250_000,
    n_epochs=50,
    learning_rate=0.1,
    name='GMM',
    lowerClip=0.5,
    upperClip=99.5,
)
gaussianMixtureNoiseModel.train(signal, observations, **train_settings)




def plotProbabilityDistribution(ax1, ax2, signalBinIndex, histogram, gaussianMixtureNoiseModel, min_signal, max_signal, n_bin, device):
    """Plot P(x|s) for one signal bin: empirical histogram row vs. GMM likelihood.

    Note: this shadows the ``plotProbabilityDistribution`` imported from
    ``hdn.lib.utils`` at the top of the notebook. It draws onto caller-supplied
    axes (clearing them first) so it can serve as a ``FuncAnimation`` frame callback.

    Args:
        ax1: Axes for the 2-D histogram image; the selected signal row is
            highlighted with a horizontal line.
        ax2: Axes for the 1-D comparison of histogram row and GMM density.
        signalBinIndex: Row (signal bin) of ``histogram`` to compare.
        histogram: 2-D array indexed ``[signal_bin, observation_bin]`` with
            ``n_bin`` columns, built over ``[min_signal, max_signal]``.
        gaussianMixtureNoiseModel: Object whose ``likelihood(observations, signal)``
            returns a torch tensor of densities, one per observation.
        min_signal, max_signal: Value range spanned by the histogram bins.
        n_bin: Number of bins per histogram axis.
        device: torch device the query tensors are created on.
    """
    # Work in Python floats so integer-typed data ranges (e.g. uint16 min/max)
    # cannot cause integer arithmetic surprises.
    lo = float(min_signal)
    hi = float(max_signal)
    bin_width = (hi - lo) / n_bin

    # Query the GMM at the centre of the selected signal bin.
    query_signal = lo + (signalBinIndex + 0.5) * bin_width
    query_signal_torch = torch.tensor(query_signal, dtype=torch.float32, device=device)

    # Observation bin centres. Built from an integer arange so there are always
    # exactly n_bin points; np.arange with a float step may yield n_bin + 1
    # values due to rounding, which would not match the histogram row length.
    query_obs = lo + (np.arange(n_bin) + 0.5) * bin_width
    query_obs_torch = torch.from_numpy(query_obs).float().to(device)

    # Inference only: no autograd graph is needed for plotting.
    with torch.no_grad():
        likelihood = gaussianMixtureNoiseModel.likelihood(query_obs_torch, query_signal_torch)
    likelihood_np = likelihood.detach().cpu().numpy()

    ax1.clear()
    ax2.clear()

    ax1.set_xlabel('Observation Bin')
    ax1.set_ylabel('Signal Bin')
    # Fourth root compresses the dynamic range so sparsely populated bins stay visible.
    ax1.imshow(histogram**0.25, cmap='gray')
    # imshow centres row i at y = i, so y = signalBinIndex sits on the selected
    # row (y = i + 0.5 would land on the boundary with the next row).
    ax1.axhline(y=signalBinIndex, linewidth=5, color='blue', alpha=0.5)

    # Dividing by the bin width turns the histogram row into a density comparable
    # to the GMM (assumes rows hold per-bin probabilities -- confirm in createHistogram).
    ax2.plot(query_obs, histogram[signalBinIndex, :] / bin_width, label='GT Hist: bin =' + str(signalBinIndex), color='blue', linewidth=2)
    ax2.plot(query_obs, likelihood_np, label='GMM : ' + ' signal = ' + str(np.round(query_signal, 2)), color='red', linewidth=2)
    ax2.set_xlabel('Observations (x) for signal s = ' + str(query_signal))
    ax2.set_ylabel('Probability Density')
    ax2.set_title("Probability Distribution P(x|s) at signal =" + str(query_signal))
    ax2.legend()

# Two panels: joint histogram (left) and the P(x|s) comparison for one signal bin (right).
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))


def animate(i):
    """Redraw both panels for signal bin ``i``; one call per animation frame.

    The plot range is [datamin, datamax] because the histogram was built over
    that range, so both panels share the histogram's bin grid.
    """
    plotProbabilityDistribution(
        ax1, ax2,
        signalBinIndex=i,
        histogram=histogramFD,
        gaussianMixtureNoiseModel=gaussianMixtureNoiseModel,
        min_signal=datamin,
        max_signal=datamax,
        n_bin=bins,
        device=device,
    )
    return ax1, ax2


# One frame per signal bin, 200 ms apart.
ani = animation.FuncAnimation(fig, animate, frames=bins, interval=200, blit=False)

gif_path = os.path.join(out_folder, 'animation.gif')
ani.save(gif_path, writer='pillow')
# Close the live figure: the GIF is the intended output, and an open figure
# would also be rendered by Jupyter as a static copy of the last frame.
plt.close(fig)

# Display the saved GIF as the cell output.
from IPython.display import Image as IPImage
IPImage(filename=gif_path)

In [None]:
# Joint histogram on its own, fourth-root scaled for contrast. 'gray' is the
# canonical colormap name; the 'grey' spelling is rejected by older matplotlib.
fig_hist, ax_hist = plt.subplots()
ax_hist.imshow(histogram[0] ** 0.25, cmap='gray')
ax_hist.set_xlabel('Observation Bin')
ax_hist.set_ylabel('Signal Bin')
ax_hist.set_title('Joint signal/observation histogram (4th root)');