In [None]:
import os, glob
import matplotlib.pyplot as plt
import json
import numpy as np
import random
import dipy.io.image
import dipy.io
import dipy.core.gradients

In [None]:
# Root of the extracted ABCD diffusion MRI data. Set the ABCD_DMRI_DIR environment
# variable to point the notebook elsewhere without editing this cell.
data_dir = os.environ.get('ABCD_DMRI_DIR', '/home/ebrahim/data/abcd/DMRI_extracted')
# Every directory matching <data_dir>/*ABCD-MPROC-DTI*/sub-*/ses-*/dwi/.
# Sorted so the order (and hence any seeded random choice from it) does not depend
# on the filesystem's listing order.
img_dirs = sorted(glob.glob(os.path.join(data_dir, '*ABCD-MPROC-DTI*/sub-*/ses-*/dwi/')))

In [None]:
# Read the T1 registration matrix stored in each image directory's JSON sidecar and
# report whether all of them are the 4x4 identity.
reg_mats = []
for img_dir in img_dirs:
    # sorted() so the chosen sidecar is deterministic if several .json files exist.
    json_path = sorted(glob.glob(os.path.join(img_dir, '*.json')))[0]
    with open(json_path) as f:
        json_info = json.load(f)
    reg_mats.append(np.array(json_info['registration_matrix_T1']))

# np.array_equal simply returns False for a wrongly-shaped matrix, whereas the
# elementwise `==` can raise (or warn, on older NumPy) when shapes do not broadcast.
identity = np.eye(4)
non_identity_dirs = [d for d, m in zip(img_dirs, reg_mats) if not np.array_equal(m, identity)]

if not reg_mats:
    # all() over an empty sequence is True; report this explicitly instead of
    # claiming that every matrix is the identity.
    print("No registration matrices were found.")
elif not non_identity_dirs:
    print("All matrices provided for affine registration to the T1 image are the identity matrix.")
else:
    print(f"{len(non_identity_dirs)} of {len(reg_mats)} registration matrices are not the identity matrix.")

In [None]:
def load_img_dir(img_dir):
    """Load the DWI array and gradient table from one image directory.

    The directory must contain at least one file matching each of ``*.nii``,
    ``*.bval`` and ``*.bvec``. If several files match a pattern, the
    lexicographically first one is used.

    Parameters
    ----------
    img_dir : str
        Path to a directory holding a NIfTI DWI volume plus its b-value and
        b-vector text files.

    Returns
    -------
    data : numpy.ndarray
        Image array as returned by ``dipy.io.image.load_nifti``. The rest of this
        notebook treats the last axis as the diffusion-weighted volume index.
    gtab : dipy.core.gradients.GradientTable
        Gradient table built from the b-values and b-vectors.

    Raises
    ------
    FileNotFoundError
        If no file matching one of the required patterns exists in ``img_dir``.
    """
    def _find_one(pattern):
        # First match in sorted order; raise a descriptive error instead of an
        # opaque IndexError when nothing matches.
        matches = sorted(glob.glob(os.path.join(img_dir, pattern)))
        if not matches:
            raise FileNotFoundError(f"No file matching {pattern!r} in {img_dir!r}")
        return matches[0]

    dwi_path = _find_one('*.nii')
    bval_path = _find_one('*.bval')
    bvec_path = _find_one('*.bvec')

    # The affine is not needed by callers, so it is discarded.
    data, _affine = dipy.io.image.load_nifti(dwi_path)
    bvals, bvecs = dipy.io.read_bvals_bvecs(bval_path, bvec_path)
    # bvecs by keyword: recent DIPY releases warn when it is passed positionally.
    gtab = dipy.core.gradients.gradient_table(bvals, bvecs=bvecs)

    return data, gtab

In [None]:
def checkerboard(i1, i2, e1, e2, n1, n2):
    """Return +1 or -1 for pixel ``(i1, i2)`` of an ``e1`` x ``e2`` checkerboard.

    The grid is split into ``n1`` tiles along the first axis and ``n2`` along the
    second (tile index = truncated ``i / e * n``). The sign is the parity of the
    sum of the two tile indices, so neighbouring tiles alternate and the tile
    containing ``(0, 0)`` is +1. When ``n`` does not divide ``e``, tile widths
    differ by one pixel.
    """
    return (-1) ** (int((i1 / e1) * n1) + int((i2 / e2) * n2))


# Checkerboard extent in pixels; must equal the in-plane shape of the image slices
# it is overlaid on.
e1 = 140
e2 = 140
# Number of tiles along each axis.
n1 = 8
n2 = 8

# Vectorised equivalent of evaluating `checkerboard` at every pixel: compute the
# tile index of each row and column once (same float ops and truncation), then
# take the parity of their broadcast sum.
tile_rows = (np.arange(e1) / e1 * n1).astype(int)
tile_cols = (np.arange(e2) / e2 * n2).astype(int)
cb = np.where((tile_rows[:, None] + tile_cols[None, :]) % 2 == 0, 1, -1)
# True on the +1 tiles.
cb_mask = (cb == 1)

In [None]:
# load a random pair of images
img_dir1, img_dir2 = np.random.choice(img_dirs,2,replace=False)
data1, gtab1 = load_img_dir(img_dir1)
data2, gtab2 = load_img_dir(img_dir2)

# pick a random index for which diffusion weighted image to look at
dwi_index1 = np.random.randint(data1.shape[-1])
dwi_index2 = np.random.randint(data2.shape[-1])
# pick a random axial slice to look at, from the middle-ish
axial_slice = np.random.randint(data1.shape[2]/3, data1.shape[2] * 2/3)

fig, axs = plt.subplots(1,2, figsize=(16,16))
im1 = data1[:,:,axial_slice,dwi_index1].T
im2 = data2[:,:,axial_slice,dwi_index2].T
axs[0].imshow(im1, cmap='gray', origin='lower')
axs[1].imshow(im2, cmap='gray', origin='lower')
plt.show()

assert(im1.shape==im2.shape)
assert(im1.shape==cb.shape)
im3 = np.zeros_like(im1)
im3[cb_mask] = im1[cb_mask]
im3[~cb_mask] = im2[~cb_mask] / im2.max() * im1.max()
plt.figure(figsize=(12,12))
plt.imshow(im3, cmap='gray', origin='lower')
plt.show()