In [1]:
from ri_ssim.rims_ssim import micro_MS_SSIM
from ri_ssim import micro_SSIM

In [2]:
from disentangle.core.tiff_reader import load_tiff

gt_path = '/group/jug/ashesh/ri_ssim/gt-N2V-2402-31.tif'
pred_path = '/group/jug/ashesh/ri_ssim/pred-N2V-2402-31.tif'
noisy_gt_path = '/group/jug/ashesh/data/ventura_gigascience/actin-60x-noise2-lowsnr.tif'


gt = load_tiff(gt_path)
pred = load_tiff(pred_path)
gt_lowsnr = load_tiff(noisy_gt_path)

In [3]:
gt.shape

(100, 2048, 2048)

In [4]:
idx = 0 
gt_tmp = gt[idx]
pred_tmp = pred[idx]
print(micro_SSIM(gt_tmp, pred_tmp, ri_factor=None, data_range=gt_tmp.max() - gt_tmp.min()))


0.8087280845276335


In [5]:
from ri_ssim.rims_ssim import * 
def micro_MS_SSIM(
    target_img,
    pred_img,
    *,
    betas=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
    win_size=None,
    data_range=None,
    channel_axis=None,
    gaussian_weights=False,
    ri_factor: Union[float, None] = None,
    return_individual_components: bool = False,
    **kwargs,
):
    mcs_list = []
    for _ in range(len(betas)):
        ssim_dict = structural_similarity_dict(
            target_img,
            pred_img,
            win_size=win_size,
            data_range=data_range,
            channel_axis=channel_axis,
            gaussian_weights=gaussian_weights,
            **kwargs,
        )
        if ri_factor is None:
            ri_factor = get_ri_factor(ssim_dict)
        
        ux, uy, vx, vy, vxy, C1, C2 = (
            ssim_dict["ux"],
            ssim_dict["uy"],
            ssim_dict["vx"],
            ssim_dict["vy"],
            ssim_dict["vxy"],
            ssim_dict["C1"],
            ssim_dict["C2"],
        )
        A1, A2, B1, B2 = (
            2 * ri_factor * ux * uy + C1,
            2 * ri_factor * vxy + C2,
            ux**2 + (ri_factor**2) * uy**2 + C1,
            vx + (ri_factor**2) * vy + C2,
        )
        assert A1.shape == A2.shape == B1.shape == B2.shape
        assert len(A1.shape) == 2
        sim = (A1/B1).mean().reshape(1,)
        contrast_sensitivity = (A2/B2).mean().reshape(1,)

        mcs_list.append(contrast_sensitivity)

        pred_img = block_reduce(pred_img, (2,2), np.mean)
        target_img = block_reduce(target_img, (2, 2), np.mean)

    mcs_list[-1] = sim
    mcs_stack = np.stack(mcs_list)

    betas = np.array(betas).reshape(-1,1)
    mcs_weighted = mcs_stack**betas
    return np.prod(mcs_weighted, axis=0)



In [9]:
from torchmetrics.image import StructuralSimilarityIndexMeasure
import torch
gt_torch = torch.Tensor(gt_tmp[None,None]*1.0)
pred_torch = torch.Tensor(pred_tmp[None,None] * 1.0)
ssim = StructuralSimilarityIndexMeasure(data_range=gt_torch.max() - gt_torch.min())
ssim(pred_torch, gt_torch)

tensor(0.8068)

In [10]:
print(micro_SSIM(gt_tmp, pred_tmp, ri_factor=1.0, data_range=gt_tmp.max() - gt_tmp.min(), gaussian_weights=True))


0.806612505838215


In [11]:
from skimage.metrics import structural_similarity as ssim
ssim(gt_tmp, pred_tmp, data_range=gt_tmp.max() - gt_tmp.min(), gaussian_weights=True)

0.806612505838215

In [12]:
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
import torch
gt_torch = torch.Tensor(gt_tmp[None,None]*1.0)
pred_torch = torch.Tensor(pred_tmp[None,None] * 1.0)
ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=gt_torch.max() - gt_torch.min())
ms_ssim(pred_torch, gt_torch)

tensor(0.8136)

In [13]:
print(micro_MS_SSIM(gt_tmp, pred_tmp, ri_factor=1.0, data_range=gt_tmp.max() - gt_tmp.min()))

[0.82303994]
