In [1]:
import numpy as np
import matplotlib.pyplot as plt
import numpy.random as rnd
from matplotlib.patches import Ellipse
import seaborn as sns
import sys
sys.path.insert(1, '../../')
from SPD_SURE_pytorch import *
from tweedie_ncF import *

In [2]:
def plot_dti_2d(X, save = True, figname = "fig.pdf"):
    """Draw a 2-D field of 2x2 symmetric matrices as a grid of ellipse glyphs.

    Parameters
    ----------
    X : array_like, shape (h, w, 2, 2)
        Field of symmetric matrices. ``X[i, j]`` is drawn as an ellipse
        centred at plot coordinates ``(x, y) = (i, j)``.
    save : bool, optional
        When true, the figure is written to ``figname`` before being shown.
    figname : str, optional
        Output path handed to ``plt.savefig``.

    Notes
    -----
    The ellipse's full width is the larger eigenvalue and its full height is
    the smaller one; the width axis is rotated to point along the eigenvector
    of the larger eigenvalue.
    """
    # eigh returns eigenvalues in ascending order; column k of the last two
    # axes of `eigvec` is the eigenvector belonging to eigenvalue k.
    eigv, eigvec = np.linalg.eigh(X)
    n_first, n_second = eigv.shape[:2]
    width = eigv[..., 1]
    height = eigv[..., 0]

    # Orientation of the major axis: angle of the principal eigenvector
    # (x = component 0, y = component 1). Using both components of the same
    # eigenvector keeps the orientation independent of the arbitrary sign
    # that eigh assigns to each eigenvector.
    principal = eigvec[..., :, 1]
    angle = np.degrees(np.arctan2(principal[..., 1], principal[..., 0]))

    # The first array axis is laid out along x and the second along y, so the
    # loop bounds follow the array shape (works for non-square grids too).
    ells = [Ellipse(xy=(i, j), width=width[i, j], height=height[i, j], angle=angle[i, j])
            for i in range(n_first) for j in range(n_second)]

    fig = plt.figure(0, figsize = (6, 6))
    ax = fig.add_subplot(111, aspect='equal')
    for e in ells:
        ax.add_artist(e)
        e.set_clip_box(ax.bbox)

    # Leave a one-cell margin on the low side so the glyphs at index 0 are
    # not cut off.
    ax.set_xlim(0-1, n_first)
    ax.set_ylim(0-1, n_second)
    ax.axis('off')

    if save:
        plt.tight_layout()
        plt.savefig(figname, dpi=100)

    plt.show()


In [3]:
# Ground-truth mean fields on an h x w grid of 2x2 SPD matrices.
# Group 1 is `normal` everywhere; group 2 equals group 1 except for the
# upper-index quadrant, where the two diagonal entries are swapped.
h = w = 20
normal = np.diag([0.3, 1.0])[np.newaxis, np.newaxis]      # shape (1, 1, 2, 2)
abnormal = np.diag([1.0, 0.3])[np.newaxis, np.newaxis]    # shape (1, 1, 2, 2)

grid_reps = (h, w, 1, 1)
group1_mean = np.tile(normal, grid_reps)
group2_mean = np.tile(normal, grid_reps)

half_h, half_w = h // 2, w // 2
group2_mean[half_h:h, half_w:w] = np.tile(abnormal, (half_h, half_w, 1, 1))

In [4]:
#plot_dti_2d(group1_mean, figname="group1_mean.pdf")
#plot_dti_2d(group2_mean, figname="group2_mean.pdf")

In [5]:
# Simulation settings: per-group sample sizes, grid size, matrix dimension N,
# and q = N(N+1)/2 (the size of the covariance passed to SPD_normal).
n1, n2 = 30, 30
h, w = 20, 20
N = 2
q = N * (N + 1) // 2

field_shape = (h, w, N, N)
group1_high = np.zeros((n1,) + field_shape)  # high variance
group2_high = np.zeros((n2,) + field_shape)
group1_low = np.zeros((n1,) + field_shape)   # low variance
group2_low = np.zeros((n2,) + field_shape)

# (range of sig2, group-1 output, group-2 output); the order here fixes the
# order in which the random stream is consumed at each location.
variance_settings = (
    ((0.3, 0.8), group1_high, group2_high),  # high variance
    ((0.1, 0.3), group1_low, group2_low),    # low variance
)

np.random.seed(2021)
for i, j in np.ndindex(h, w):
    for (sig2_min, sig2_max), out1, out2 in variance_settings:
        # One variance per location and setting, shared by both groups.
        sig2 = np.random.uniform(sig2_min, sig2_max)
        out1[:, i, j] = SPD_normal(n1, group1_mean[i, j], sig2 * np.eye(q))
        out2[:, i, j] = SPD_normal(n2, group2_mean[i, j], sig2 * np.eye(q))

In [6]:
# Per-location FM of each group's samples, computed with FM_logE
# (presumably the Frechet mean under the log-Euclidean metric -- confirm
# against the helper's definition).
fm_shape = (h, w, N, N)
group1_high_FM = np.zeros(fm_shape)
group2_high_FM = np.zeros(fm_shape)
group1_low_FM = np.zeros(fm_shape)
group2_low_FM = np.zeros(fm_shape)

# (destination field, sample stack) pairs processed at every grid location.
fm_jobs = (
    (group1_high_FM, group1_high),
    (group2_high_FM, group2_high),
    (group1_low_FM, group1_low),
    (group2_low_FM, group2_low),
)
for i, j in np.ndindex(h, w):
    for fm_out, samples in fm_jobs:
        fm_out[i, j] = FM_logE(samples[:, i, j])

In [7]:
# Per-location q x q covariance of each group's samples (via cov_logE),
# followed by the pooled covariance of the two groups.
cov_shape = (h, w, q, q)
group1_high_cov = np.zeros(cov_shape)
group2_high_cov = np.zeros(cov_shape)
group1_low_cov = np.zeros(cov_shape)
group2_low_cov = np.zeros(cov_shape)

# (destination field, sample stack) pairs processed at every grid location.
cov_jobs = (
    (group1_high_cov, group1_high),
    (group2_high_cov, group2_high),
    (group1_low_cov, group1_low),
    (group2_low_cov, group2_low),
)
for i, j in np.ndindex(h, w):
    for cov_out, samples in cov_jobs:
        cov_out[i, j] = cov_logE(samples[:, i, j])


def _pooled(cov_a, cov_b):
    """Weight each group's covariance by its degrees of freedom and average."""
    return ((n1-1)*cov_a + (n2-1)*cov_b)/(n1+n2-2)


# pooled covariances
pool_cov_high = _pooled(group1_high_cov, group2_high_cov)
pool_cov_low = _pooled(group1_low_cov, group2_low_cov)


In [8]:
# Two-sample Hotelling's t^2 statistic at every grid location, for both
# variance settings, then rescaled to an F statistic.
size_factor = 1/(1/n1 + 1/n2)  # loop-invariant sample-size factor


def _hotelling_t2(fm_a, fm_b, pooled_cov):
    """Return size_factor * delta * inv(pooled_cov) * delta^T for one location.

    `vec` is expected to return a (1, q) row vector, which is why the scalar
    is read out with [0][0].
    """
    delta = vec(fm_a) - vec(fm_b)
    quad_form = np.matmul(delta, np.matmul(np.linalg.inv(pooled_cov), delta.T))
    return size_factor * quad_form[0][0]


high_t2 = np.zeros((h, w))
low_t2 = np.zeros((h, w))
for i, j in np.ndindex(h, w):
    high_t2[i, j] = _hotelling_t2(group1_high_FM[i, j], group2_high_FM[i, j],
                                  pool_cov_high[i, j])
    low_t2[i, j] = _hotelling_t2(group1_low_FM[i, j], group2_low_FM[i, j],
                                 pool_cov_low[i, j])

# transform to F statistics
# NOTE(review): the textbook two-sample result is
# (nu - q + 1)/(nu * q) * t^2 ~ F(q, nu - q + 1); the factor below uses
# (nu - q - 2) -- confirm against the derivation this analysis follows.
nu = n1 + n2 - 2
f_scale = (nu - q - 2)/(nu * q)
high_f = f_scale * high_t2
low_f = f_scale * low_t2

In [9]:
# Estimate the non-centrality parameter at every location in two ways:
#   MOM -- method of moments, truncated at zero;
#   EB  -- the estimate returned by tweedie_ncF.
df1 = q
df2 = nu - q - 1

# Inverting E[F] = df2*(df1 + lam) / (df1*(df2 - 2)) for lam gives
# lam = mom_slope * F - df1.
mom_slope = df1*(df2 - 2)/df2
high_MOM = np.maximum(mom_slope * high_f - df1, 0)
low_MOM = np.maximum(mom_slope * low_f - df1, 0)

# tweedie_ncF is fed the flattened F field; its output is folded back onto
# the grid. Note the two settings use different maxit values.
high_EB = tweedie_ncF(high_f.reshape(-1), df1, df2, K=3, maxit=5000).reshape((h, w))
low_EB = tweedie_ncF(low_f.reshape(-1), df1, df2, K=3, maxit=1000).reshape((h, w))

In [10]:
# Heat map of the MOM estimates, high-variance setting.
# The colour range is pinned to [0, 170], the same range used for the
# high-variance EB heat map.
image = plt.imshow(high_MOM, cmap='hot')
plt.colorbar(image)
image.axes.invert_yaxis()  # put row 0 at the bottom
image.set_clim(0, 170)
plt.savefig('high_MOM.png', dpi=300)
plt.clf()

<Figure size 432x288 with 0 Axes>

In [11]:
# Heat map of the MOM estimates, low-variance setting.
# The colour range is pinned to [0, 530], the same range used for the
# low-variance EB heat map.
image = plt.imshow(low_MOM, cmap='hot')
plt.colorbar(image)
image.axes.invert_yaxis()  # put row 0 at the bottom
image.set_clim(0, 530)
plt.savefig('low_MOM.png', dpi=300)
plt.clf()


<Figure size 432x288 with 0 Axes>

In [12]:
# Heat map of the EB (Tweedie) estimates, high-variance setting.
# The colour range is pinned to [0, 170], the same range used for the
# high-variance MOM heat map.
image = plt.imshow(high_EB, cmap='hot')
plt.colorbar(image)
image.axes.invert_yaxis()  # put row 0 at the bottom
image.set_clim(0, 170)
plt.savefig('high_EB.png', dpi=300)
plt.clf()

<Figure size 432x288 with 0 Axes>

In [13]:
# Heat map of the EB (Tweedie) estimates, low-variance setting.
# The colour range is pinned to [0, 530], the same range used for the
# low-variance MOM heat map.
image = plt.imshow(low_EB, cmap='hot')
plt.colorbar(image)
image.axes.invert_yaxis()  # put row 0 at the bottom
image.set_clim(0, 530)
plt.savefig('low_EB.png', dpi=300)
plt.clf()

<Figure size 432x288 with 0 Axes>

In [14]:
# Reference quantity for the density plots: the sample-size factor times the
# squared distance between the vectorised normal and abnormal mean matrices.
# NOTE(review): logm/expm are imported but not used in this cell; whether
# `vec` already applies a matrix log is not visible here -- confirm.
from scipy.linalg import logm, expm
m1, m2 = (vec(mean[0, 0]) for mean in (normal, abnormal))
size_factor = 1/(1/n1 + 1/n2)
d = size_factor * np.sum((m1 - m2)**2)


In [15]:
# Density plot, high-variance setting: kernel density of the MOM and Tweedie
# estimates over the grid block [h/2:h, w/2:w], where the two group means
# differ.
lo, up = 0.3, 0.8
u = np.random.uniform(lo, up, size=10000)  # only referenced by the disabled kdeplot below
sns.set_style('whitegrid')
#sns.kdeplot(np.array(d/u))

abnormal_block = (slice(h // 2, h), slice(w // 2, w))
for estimate, tag in ((high_MOM, 'MOM'), (high_EB, 'Tweedie')):
    sns.kdeplot(estimate[abnormal_block].reshape(-1), label=tag)

# d * log(up/lo)/(up - lo) is d times the mean of 1/s for s ~ Uniform(lo, up).
reference_ncp = d*np.log(up/lo)/(up-lo)
plt.axvline(x=reference_ncp, color='green', linestyle='--',
            label='Non-centrality parameter')
plt.legend()
plt.savefig('high_density.png', dpi=300)
plt.clf()


<Figure size 432x288 with 0 Axes>

In [16]:
# Density plot, low-variance setting: kernel density of the MOM and Tweedie
# estimates over the grid block [h/2:h, w/2:w], where the two group means
# differ.
lo, up = 0.1, 0.3
u = np.random.uniform(lo, up, size=10000)  # only referenced by the disabled kdeplot below
sns.set_style('whitegrid')
#sns.kdeplot(np.array(d/u))

abnormal_block = (slice(h // 2, h), slice(w // 2, w))
for estimate, tag in ((low_MOM, 'MOM'), (low_EB, 'Tweedie')):
    sns.kdeplot(estimate[abnormal_block].reshape(-1), label=tag)

# d * log(up/lo)/(up - lo) is d times the mean of 1/s for s ~ Uniform(lo, up).
reference_ncp = d*np.log(up/lo)/(up-lo)
plt.axvline(x=reference_ncp, color='green', linestyle='--',
            label='Non-centrality parameter')
plt.legend()
plt.savefig('low_density.png', dpi=300)
plt.clf()



<Figure size 432x288 with 0 Axes>