In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import PowerNorm, LogNorm

from bm3d import bm3d
from scipy.ndimage import median_filter
from scipy.ndimage import convolve
from pyhdf.SD import SD, SDC

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from scipy.signal import windows
from tqdm import tqdm

In [2]:
def load_hdf_as_numpy(file_path, var_name):
    """Read one dataset from an HDF4 file and return it as a float32 array.

    Parameters
    ----------
    file_path : str
        Path of the HDF file, opened read-only through ``pyhdf.SD``.
    var_name : str
        Name of the dataset to select inside the file.

    Returns
    -------
    numpy.ndarray
        The dataset contents converted to ``np.float32``.
    """
    f = SD(file_path, SDC.READ)
    try:
        sds = f.select(var_name)
        try:
            data = sds.get()
        finally:
            # Release the dataset handle even if the read fails.
            sds.endaccess()
    finally:
        # Always close the file so repeated calls do not leak HDF handles.
        f.end()
    # np.array copies, so the result does not depend on the closed file.
    return np.array(data, dtype=np.float32)

def print_results(data_filt, 
                  data_unf, 
                  lower_percentile=2,
                  upper_percentile=98,
                  cmap="inferno",
                  norm_type="linear",
                  gamma=0.5):
    """Show the unfiltered and filtered images side by side on a shared color scale.

    The color limits are percentiles computed over both images together, so
    the two panels are directly comparable.

    Parameters
    ----------
    data_filt : array-like
        Filtered image, drawn in the right panel.
    data_unf : array-like
        Unfiltered image, drawn in the left panel.
    lower_percentile, upper_percentile : float
        Percentiles of the combined pixel values used as color limits.
    cmap : str
        Matplotlib colormap name.
    norm_type : str
        ``"log"`` for ``LogNorm``, ``"power"`` for ``PowerNorm``; any other
        value gives a linear scale.
    gamma : float
        Exponent for ``PowerNorm``; only used when ``norm_type == "power"``.
    """
    # --- percentile-based vmin / vmax ---
    all_data = np.concatenate([np.ravel(data_filt), np.ravel(data_unf)])
    # NaN-aware percentiles: a single NaN pixel would otherwise turn both
    # limits into NaN. Identical to np.percentile on NaN-free data.
    vmin = np.nanpercentile(all_data, lower_percentile)
    vmax = np.nanpercentile(all_data, upper_percentile)

    # --- choose normalization ---
    # Matplotlib does not accept a Normalize instance together with explicit
    # vmin/vmax (recent releases raise ValueError), so pass exactly one of
    # the two: the norm already carries the limits.
    if norm_type == "log":
        # LogNorm needs a strictly positive lower limit.
        color_kwargs = {"norm": LogNorm(vmin=max(vmin, 1e-6), vmax=vmax)}
    elif norm_type == "power":
        color_kwargs = {"norm": PowerNorm(gamma=gamma, vmin=vmin, vmax=vmax)}
    else:  # linear
        color_kwargs = {"vmin": vmin, "vmax": vmax}

    # --- plot ---
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    im0 = axes[0].imshow(data_unf, cmap=cmap, aspect="auto", **color_kwargs)
    axes[0].set_title("UNFILTERED")

    axes[1].imshow(data_filt, cmap=cmap, aspect="auto", **color_kwargs)
    axes[1].set_title("FILTERED")

    # Both panels share the same scale, so one colorbar serves both.
    fig.colorbar(im0, ax=axes, fraction=0.025)
    plt.tight_layout()
    plt.show()

### Baseline: Median Filter