# Mutual Information 구하기

-----------------------------
- 방법1. sklearn.metrics 이용

- 방법2. 코드 짜서 이용 (출처: https://github.com/connorlee77/pytorch-mutual-information)

웬만하면 방법1을 사용하자. 혹시나 해서 방법2 코드도 남겨둠.

## 방법1

In [1]:
import numpy as np 
from PIL import Image
from sklearn.metrics import normalized_mutual_info_score

In [11]:
# Alternative image pair (fake_B vs. real_B from the test results); uncomment to use:
# path1 = "/SSD3_8TB/Daniel/06_pGAN/pGAN-cGAN/results/1_pGAN_run_align/test_latest/images/IXI0_fake_B.png"
# path2 = "/SSD3_8TB/Daniel/06_pGAN/pGAN-cGAN/results/1_pGAN_run_align/test_latest/images/IXI0_real_B.png"
path1 = "/SSD3_8TB/Daniel/06_pGAN/pGAN-cGAN/datasets/MI_FID/IXI0_real_B.png"
path2 = "/SSD3_8TB/Daniel/06_pGAN/pGAN-cGAN/datasets/MI_FID/IXI0_real_A (3).png"

### Create test cases ###
# Open the files inside a context manager so the underlying file handles are
# closed as soon as the pixel data has been read; convert('L') returns a new
# single-channel (grayscale) image that stays valid after the files are closed.
with Image.open(path1) as im1, Image.open(path2) as im2:
    img1 = im1.convert('L')
    img2 = im2.convert('L')

arr1 = np.array(img1)
arr2 = np.array(img2)

# Compute MI: the flattened pixel intensities of the two images are passed as
# the two label vectors. Kept as the last expression so the notebook displays it.
normalized_mutual_info_score(arr1.ravel(), arr2.ravel())

0.22438758557365795

## 방법2

In [7]:
import os
import numpy as np 

import torch
import torch.nn as nn

import skimage.io
import matplotlib.pyplot as plt

from PIL import Image
from torchvision import transforms




class MutualInformation(nn.Module):
    """Kernel-based (optionally normalized) mutual information of two image batches.

    Pixel intensities are soft-assigned to ``num_bins`` evenly spaced bin
    centres in ``[0, 255]`` using a Gaussian kernel of width ``sigma``. The
    resulting marginal and joint distributions are turned into entropies, and
    ``MI = H(x1) + H(x2) - H(x1, x2)`` is returned per batch element.

    Args:
        sigma: Width of the Gaussian kernel, in intensity units (0-255 scale).
        num_bins: Number of bin centres spread linearly over ``[0, 255]``.
        normalize: If true, return ``2 * MI / (H(x1) + H(x2))`` instead of MI.

    Note:
        The broadcasting in ``marginalPdf`` assumes single-channel input
        (``C == 1``), which is how the class is used in this file.
    """

    def __init__(self, sigma=0.1, num_bins=256, normalize=True):
        super().__init__()

        self.sigma = sigma
        self.num_bins = num_bins
        self.normalize = normalize
        self.epsilon = 1e-10  # guards divisions and logs against zero

        # The bin centres are created without an explicit device so the class
        # does not depend on any module-level `device` variable; calling
        # `.to(device)` on the module moves this parameter with it.
        self.bins = nn.Parameter(torch.linspace(0, 255, num_bins).float(), requires_grad=False)

    def marginalPdf(self, values):
        """Estimate the per-image intensity distribution.

        Args:
            values: Tensor of shape ``(B, N, 1)`` with intensities on the
                0-255 scale (N = number of pixels).

        Returns:
            Tuple ``(pdf, kernel_values)`` where ``pdf`` has shape
            ``(B, num_bins)`` and sums to ~1 along dim 1, and
            ``kernel_values`` has shape ``(B, N, num_bins)`` holding each
            pixel's Gaussian weight for each bin.
        """
        # (B, N, 1) - (1, 1, num_bins) broadcasts to (B, N, num_bins).
        residuals = values - self.bins.unsqueeze(0).unsqueeze(0)
        kernel_values = torch.exp(-0.5 * (residuals / self.sigma).pow(2))

        # Average the soft assignments over pixels, then renormalize to sum to 1.
        pdf = torch.mean(kernel_values, dim=1)
        normalization = torch.sum(pdf, dim=1).unsqueeze(1) + self.epsilon
        pdf = pdf / normalization

        return pdf, kernel_values

    def jointPdf(self, kernel_values1, kernel_values2):
        """Estimate the joint intensity distribution of two images.

        Args:
            kernel_values1: ``(B, N, num_bins)`` kernel weights of image 1.
            kernel_values2: ``(B, N, num_bins)`` kernel weights of image 2.

        Returns:
            Tensor of shape ``(B, num_bins, num_bins)`` whose entries sum to
            ~1 for every batch element.
        """
        # Summing the outer products of per-pixel weights over all pixels:
        # (B, num_bins, N) @ (B, N, num_bins) -> (B, num_bins, num_bins).
        joint_kernel_values = torch.matmul(kernel_values1.transpose(1, 2), kernel_values2)
        normalization = torch.sum(joint_kernel_values, dim=(1, 2)).view(-1, 1, 1) + self.epsilon
        pdf = joint_kernel_values / normalization

        return pdf

    def getMutualInformation(self, input1, input2):
        """Compute the mutual information for each pair in the batch.

        Args:
            input1: Tensor of shape ``(B, C, H, W)`` with values in ``[0, 1]``.
            input2: Tensor with the same shape as ``input1``.

        Returns:
            Tensor of shape ``(B,)``: one (normalized, if ``self.normalize``)
            mutual-information value per batch element.
        """
        # Torch tensors for images between (0, 1)
        input1 = input1 * 255
        input2 = input2 * 255

        B, C, H, W = input1.shape
        assert input1.shape == input2.shape, "input1 and input2 must have the same shape"

        # reshape (rather than view) gives the same result for contiguous
        # tensors but also accepts non-contiguous inputs such as transposed
        # or sliced images, for which view raises an error.
        x1 = input1.reshape(B, H * W, C)
        x2 = input2.reshape(B, H * W, C)

        pdf_x1, kernel_values1 = self.marginalPdf(x1)
        pdf_x2, kernel_values2 = self.marginalPdf(x2)
        pdf_x1x2 = self.jointPdf(kernel_values1, kernel_values2)

        # Entropies in bits; epsilon keeps log2 finite for empty bins.
        H_x1 = -torch.sum(pdf_x1 * torch.log2(pdf_x1 + self.epsilon), dim=1)
        H_x2 = -torch.sum(pdf_x2 * torch.log2(pdf_x2 + self.epsilon), dim=1)
        H_x1x2 = -torch.sum(pdf_x1x2 * torch.log2(pdf_x1x2 + self.epsilon), dim=(1, 2))

        mutual_information = H_x1 + H_x2 - H_x1x2

        if self.normalize:
            # NOTE(review): for (near-)constant images both entropies are ~0
            # and this ratio is numerically ill-conditioned.
            mutual_information = 2 * mutual_information / (H_x1 + H_x2)

        return mutual_information

    def forward(self, input1, input2):
        """Return ``getMutualInformation(input1, input2)``.

        Args:
            input1: Tensor of shape ``(B, C, H, W)`` with values in ``[0, 1]``.
            input2: Tensor with the same shape as ``input1``.

        Returns:
            Tensor of shape ``(B,)``.
        """
        return self.getMutualInformation(input1, input2)


In [8]:
# Imported here as well so this cell does not depend on the method-1 cell
# having been executed first.
from sklearn.metrics import normalized_mutual_info_score

# Use the first GPU when available, otherwise fall back to the CPU so the
# check can also run on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

path_fake = "/SSD3_8TB/Daniel/06_pGAN/pGAN-cGAN/results/1_pGAN_run_align/test_latest/images/IXI0_fake_B.png"
path_real = "/SSD3_8TB/Daniel/06_pGAN/pGAN-cGAN/results/1_pGAN_run_align/test_latest/images/IXI0_real_B.png"

### Create test cases ###
# Context manager closes the file handles once the grayscale copies exist.
with Image.open(path_fake) as fake_file, Image.open(path_real) as real_file:
    img1 = fake_file.convert('L')
    img2 = real_file.convert('L')

arr1 = np.array(img1)
arr2 = np.array(img2)

# Reference values from scikit-learn: (fake, real) and (real, real).
mi_true_1 = normalized_mutual_info_score(arr1.ravel(), arr2.ravel())
mi_true_2 = normalized_mutual_info_score(arr2.ravel(), arr2.ravel())

# ToTensor output gets a leading batch dimension before moving to the device.
img1 = transforms.ToTensor()(img1).unsqueeze(dim=0).to(device)
img2 = transforms.ToTensor()(img2).unsqueeze(dim=0).to(device)

# Pair of different images, pair of same images
input1 = torch.cat([img1, img2])
input2 = torch.cat([img2, img2])

MI = MutualInformation(num_bins=256, sigma=0.1, normalize=True).to(device)
mi_test = MI(input1, input2)

# One value per batch element: index 0 is the (fake, real) pair, 1 is (real, real).
mi_test_1 = mi_test[0].cpu().numpy()
mi_test_2 = mi_test[1].cpu().numpy()

print('Image Pair 1 | sklearn MI: {}, this MI: {}'.format(mi_true_1, mi_test_1))
print('Image Pair 2 | sklearn MI: {}, this MI: {}'.format(mi_true_2, mi_test_2))

# The kernel-based estimate must agree with scikit-learn to within 0.05.
assert(np.abs(mi_test_1 - mi_true_1) < 0.05)
assert(np.abs(mi_test_2 - mi_true_2) < 0.05)

Image Pair 1 | sklearn MI: 0.378635048312357, this MI: 0.3786352872848511
Image Pair 2 | sklearn MI: 1.0, this MI: 1.0
