In [None]:
import os
# Must be set before TensorFlow is imported in the next cell.
# NOTE(review): the single-space device list is presumably meant to hide
# every GPU so the notebook runs on CPU - confirm.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": " ",
})

In [None]:
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
# Graph-mode session; used further down (sess.run) to evaluate the
# mutual-information tensors into NumPy arrays.
sess = tf.Session()

In [None]:
# Largest intensity value used by the test: the random tensors are drawn with
# maxval = vmax + 1 and the NumPy reference histogram uses vmax + 1 bins over
# the range [0, vmax].
vmax = 255

In [None]:
# Two separately sampled int32 test tensors, laid out as
# [batch, height, width, channel] and drawn with maxval = vmax + 1.
a, b = (
    tf.random.uniform([6, 320, 256, 12], maxval=vmax + 1, dtype='int32')
    for _ in range(2)
)

In [None]:
def mutual_information_np(hist2d):
    """
    Mutual information (in nats) of a 2-D joint histogram, computed with NumPy.

    hist2d : 2-D array of joint counts.

    The counts are normalised to a joint distribution, the two marginals are
    obtained by summing over each axis, and sum(p * log(p / (p_row * p_col)))
    is taken over the cells whose joint probability is strictly positive.
    """
    joint = hist2d / np.sum(hist2d)
    row_marginal = np.sum(joint, axis=1)
    col_marginal = np.sum(joint, axis=0)
    # Outer product of the marginals: the joint distribution under independence.
    independent = row_marginal[:, np.newaxis] * col_marginal[np.newaxis, :]

    # Empty cells are skipped so that log() never sees a zero.
    support = joint > 0
    joint_nz = joint[support]
    independent_nz = independent[support]
    return np.sum(joint_nz * np.log(joint_nz / independent_nz))

In [None]:
def mutual_information_single(hist2d):
    """
    Mutual information (in nats) of one 2-D joint histogram, as a TF tensor.

    hist2d : 2-D tensor of joint counts; it is cast to float64 first.

    Mirrors the NumPy version: normalise to a joint distribution, take the
    marginals along each axis, and sum p * log(p / (p_row * p_col)) over the
    cells whose joint probability is strictly positive.
    """
    counts = tf.cast(hist2d, dtype='float64')
    joint = counts / tf.reduce_sum(counts)
    row_marginal = tf.reduce_sum(joint, axis=1)
    col_marginal = tf.reduce_sum(joint, axis=0)
    # Outer product of the marginals: the joint distribution under independence.
    independent = row_marginal[:, None] * col_marginal[None, :]

    # Keep only non-empty cells so that log() never sees a zero.
    support = tf.greater(joint, 0)
    joint_nz = tf.boolean_mask(joint, support)
    independent_nz = tf.boolean_mask(independent, support)
    return tf.reduce_sum(joint_nz * tf.log(joint_nz / independent_nz))


In [None]:
def tf_joint_histogram(y_true, y_pred, vmax=255):
    """
    Joint intensity histogram of two image tensors, per batch item and channel.

    y_true : [batch, height, width, channel]
    y_pred : [batch, height, width, channel]
    vmax   : largest intensity value covered; the histogram has (vmax + 1)
             bins along each intensity axis. Defaults to 255.

    The static shape of y_true is used for every reshape, so it must be fully
    defined. The pair encoding below is only one-to-one when both tensors hold
    integer values in [0, vmax] - assumed, not checked here.

    Returns (hist, y_true, y_pred). hist has shape
    [batch, vmax + 1, vmax + 1, channel], with the y_pred value on axis 1 and
    the y_true value on axis 2. The two inputs are returned unchanged.
    """
    levels = vmax + 1
    nbins = levels ** 2

    b, h, w, c = y_true.get_shape()

    def _to_rows(images):
        # [batch, height, width, channel]
        # -> [batch, height * width, channel]
        # -> [batch, channel, height * width]
        # -> [batch * channel, height * width]
        channel_first = tf.transpose(tf.reshape(images, [b, h*w, c]), [0, 2, 1])
        return tf.reshape(channel_first, [b*c, h*w])

    flat_true = _to_rows(y_true)
    flat_pred = _to_rows(y_pred)

    # Encode every (pred, true) pair as a single code in [1, nbins].
    codes = (flat_pred * levels) + (flat_true + 1)

    # One histogram per row -> [b*c, nbins]. The range [1, nbins] is kept as in
    # the original: each code k gets its own bin k - 1, and the top code equals
    # value_range[1], which histogram_fixed_width maps to the last bin.
    output = tf.map_fn(
        lambda x: tf.histogram_fixed_width(x, value_range=[1, nbins], nbins=nbins),
        codes)

    # [b*c, nbins] -> [b, c, levels, levels] -> [b, levels, levels, c]
    output = tf.transpose(tf.reshape(output, [b, c, levels, levels]), [0, 2, 3, 1])
    return output, y_true, y_pred

In [None]:
def mutual_information(y_true, y_pred):
    """
    Mutual information between y_true and y_pred per batch item and channel.

    y_true, y_pred : tensors passed straight to tf_joint_histogram.

    Returns (mi, y_true, y_pred, joint_histogram): mi is a float64 tensor of
    shape [b, c]; the inputs are returned unchanged; joint_histogram is the
    joint histogram after being rearranged to [b*c, 256, 256].
    """
    # [b, 256, 256, c]
    hist, _, _ = tf_joint_histogram(y_true, y_pred)
    n_batch, n_rows, n_cols, n_chan = hist.get_shape()

    # [b, 256, 256, c] -> [b, c, 256, 256] -> [b*c, 256, 256]
    channel_first = tf.transpose(hist, [0, 3, 1, 2])
    per_image_hist = tf.reshape(channel_first, [n_batch*n_chan, n_rows, n_cols])

    # One scalar per 2-D histogram, then back to [b, c].
    flat_mi = tf.map_fn(mutual_information_single, per_image_hist, dtype=tf.float64)
    mi = tf.reshape(flat_mi, [n_batch, n_chan])
    return mi, y_true, y_pred, per_image_hist

In [None]:
# Evaluate everything in a single run so that test_a / test_b are the same
# random samples the TensorFlow mutual information was computed from.
test, test_a, test_b, test_hist2d = sess.run(mutual_information(a, b))

# NumPy reference: joint histogram of batch item 0, channel 0.
hist2d, _, _ = np.histogram2d(
    test_a[0, ..., 0].ravel(),
    test_b[0, ..., 0].ravel(),
    bins=vmax + 1,
    range=[[0, vmax], [0, vmax]],
)


In [None]:
# TensorFlow result for batch item 0 / channel 0, followed by the NumPy
# reference computed from the histogram of that same slice.
print(test[0,0])
print(mutual_information_np(hist2d))