# ReadMe

The structure of this notebook is as follows:

0) Imports; defining the dataset
1) Functions for finding the *source* slices closest to a given *target* scan
2) Generating the "raw" closeness dictionary
3) Generating sorted lists of the closest slices (in 2D/2.5D/3D fashion -- see the paper)
4) Examples of the "raw" dictionary and of the closest-slice list entries

# Miscellaneous

In [1]:
import os.path
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import json
from os.path import join as jp
from kswap.module.unet import UNet2D
import random
import piq
from scipy import ndimage
from joblib import Parallel, delayed
from dpipe.io import load
from kswap.utils import choose_root

from kswap.dataset.cc359 import CC359, scale_mri
from dpipe.dataset.wrappers import apply, cache_methods

from tqdm.notebook import tqdm

In [2]:
# Defining the Dataset
data_path = choose_root(
    '/path/cc359',
    '/path/raid/cc359',
)

low = 0
high = 172

preprocessed_dataset = apply(CC359(data_path, low, high), load_image=scale_mri)
dataset = apply(cache_methods(apply(preprocessed_dataset, load_image=np.float16)), load_image=np.float32)

# Functions

In [3]:
def find_nearest_25d(dataset, s_slices_ids, t_slice, coordinate, n_nearest, offset=(60, 60), device_name='cpu'):
    """
    Measure the closeness of the *source* slices to a given *target* slice.

    Parameters
    ----------
    dataset
        Object forwarded to ``assess_slice_25d``, which calls ``dataset.load_image(id_)`` on it.
    s_slices_ids
        Iterable with the ids of the source scans to be scored.
    t_slice
        2D numpy array holding the target slice.
    coordinate
        Index of the target slice along the last axis of its scan; the source scans
        are assessed at (and around) the same index.
    n_nearest
        Currently unused: no top-n selection happens here, the scores of *all* the
        requested source scans are returned. Kept for backward compatibility.
    offset
        ``(offset_x, offset_y)`` half-sizes of the window cropped around the mass centre.
    device_name
        torch device on which the similarity is computed.

    Returns
    -------
    list
        ``(id_, scores)`` pairs in the order of ``s_slices_ids``, where ``scores`` is the
        tuple returned by ``assess_slice_25d`` for the source scan ``id_``.
    """

    # crop around the mass centre
    # NOTE(review): if the mass centre is closer than `offset` to the image border, the
    # negative start index wraps around (python slicing) -- confirm inputs never hit this.
    offset_x, offset_y = offset
    x, y = map(int, ndimage.center_of_mass(t_slice))
    t_crop = t_slice[x - offset_x: x + offset_x, y - offset_y: y + offset_y].copy()

    # `from_numpy(...).to(...)` instead of `torch.tensor(<tensor>)`: the latter
    # copy-constructs from a tensor and emits a UserWarning
    t_crop = torch.from_numpy(t_crop).to(device_name)[None, None, ...]

    return [(id_, assess_slice_25d(dataset, coordinate, offset_x, offset_y, device_name, t_crop, id_))
            for id_ in s_slices_ids]

def _srsim_to_target(s_slice, t_slice, offset_x, offset_y, device_name):
    """Crop `s_slice` around its mass centre and return its SR-SIM score w.r.t. `t_slice` as a python float."""
    x, y = map(int, ndimage.center_of_mass(s_slice))
    s_crop = s_slice[x - offset_x: x + offset_x, y - offset_y: y + offset_y].copy()
    s_crop = torch.from_numpy(s_crop).to(device_name)[None, None, ...]
    score = piq.srsim(s_crop, t_slice, data_range=1.)  # best one
    # `.item()` works on every device, whereas `.numpy()` raises for non-CPU tensors
    return score.item()


def assess_slice_25d(dataset, coordinate, offset_x, offset_y, device_name, t_slice, id_):
    """
    (*source* slice proximity <-> *target* slice) closeness

    Parameters
    ----------
    dataset
        Object providing ``load_image(id_)``; the returned scan is indexed as ``scan[:, :, k]``.
    coordinate
        Index (along the last axis) of the source slice that corresponds to the target slice.
    offset_x, offset_y
        Half-sizes of the window cropped around the mass centre of every source slice.
    device_name
        torch device on which the similarity is computed.
    t_slice
        Already cropped target slice: a tensor with two leading singleton axes
        (see ``find_nearest_25d``).
    id_
        Id of the source scan.

    Returns
    -------
    tuple
        ``(score, shift)`` pairs. When the slices ``coordinate - 2 .. coordinate + 2`` all
        exist, one pair per ``shift`` in ``-2 .. 2``; otherwise a single pair with ``shift == 0``.
    """

    scan = dataset.load_image(id_)
    depth = scan.shape[-1]

    # the +-2 neighbourhood is assessed only when it fits into the scan entirely
    shifts = range(-2, 3) if 1 < coordinate < depth - 2 else (0,)

    return tuple(
        (_srsim_to_target(scan[:, :, coordinate + shift], t_slice, offset_x, offset_y, device_name), shift)
        for shift in shifts
    )

In [4]:
# domain index -> domain name
id2dom = dict(enumerate(('sm15', 'sm3', 'ge15', 'ge3', 'ph15', 'ph3')))

In [6]:
# ordered pair of distinct domain indices -> number of the experiment in the split

id2dom = {0: 'sm15', 1: 'sm3', 2: 'ge15', 3: 'ge3', 4: 'ph15', 5: 'ph3'}
pairs = [(1, 2), (3, 5), (5, 2), (0, 5), (5, 0), (2, 4)]

# experiments are numbered row by row over the 6 x 6 grid of domain indices, diagonal skipped
ordered_pairs = [(first, second) for first in range(6) for second in range(6) if first != second]
pair2exp = {pair: number for number, pair in enumerate(ordered_pairs)}
count = len(pair2exp)

exps = list(map(pair2exp.__getitem__, pairs))

In [7]:
# experiment numbers of the selected domain pairs, in the same order as `pairs`
exps

[6, 19, 27, 4, 25, 13]

# Closeness dictionary

In [None]:
def _closeness_to_sources(target_ids, source_ids):
    """For every slice of every target scan, score all the source scans with `find_nearest_25d`."""
    scores = {}
    for id_target in target_ids:
        scores[id_target] = {}
        scan = dataset.load_image(id_target)
        n_slices = scan.shape[-1]
        for i in tqdm(range(n_slices)):
            nearest = find_nearest_25d(dataset=dataset, s_slices_ids=source_ids, t_slice=scan[:, :, i],
                                       coordinate=i, n_nearest=7, offset=(60, 60), device_name='cpu')
            scores[id_target][i] = tuple(nearest)
    return scores


scores_val = {}
scores_test = {}

for pair in pairs:

    # Getting all the required scans IDs
    n_exp = pair2exp[pair]
    split_path = '/path/exps/split/experiment_' + str(n_exp)

    train_source_ids = load(jp(split_path, 'train_s_ids.json'))
    val_target_ids = load(jp(split_path, 'val_t_ids.json'))
    test_target_ids = load(jp(split_path, 'test_t_ids.json'))

    # Val Scans: val slice -> closest slices
    scores_val = _closeness_to_sources(val_target_ids, train_source_ids)
    with open('/path/exps/closest_scans/closest_scans_val_' + str(n_exp) + '.json', 'w') as fp:
        json.dump(scores_val, fp)

    # Test Scans: test slice -> closest slices
    scores_test = _closeness_to_sources(test_target_ids, train_source_ids)
    with open('/path/exps/closest_scans/closest_scans_test_' + str(n_exp) + '.json', 'w') as fp:
        json.dump(scores_test, fp)

# 2D, 2.5D and 3D closeness

## 2D

In [None]:
def _arrange_2d(closest_raw):
    """
    Turn a raw closeness dictionary into the 2D lists of closest slices.

    For every target scan and target slice, only the in-plane score (position == 0) of each
    source scan is kept; the resulting (s_id, position, score) entries are sorted by
    descending score.
    """
    arranged = {}
    for t_id, per_slice in closest_raw.items():
        arranged[t_id] = {}
        for slice_id, sources in per_slice.items():
            in_plane = [(s_id, position, score)
                        for s_id, scores in sources
                        for score, position in scores
                        if position == 0]
            arranged[t_id][slice_id] = sorted(in_plane, key=lambda entry: entry[2], reverse=True)
    return arranged


for exp in exps:
    # both splits are processed: the 3D cell and the examples below read the *test* 2D
    # files, which this cell did not produce when it handled the val split only
    for split in ('val', 'test'):
        prefix = '/path/exps/closest_scans/closest_scans_' + split + '_' + str(exp)

        with open(prefix + '.json') as fp:
            closest_raw = json.load(fp)

        closest_arranged = _arrange_2d(closest_raw)

        with open(prefix + '_2d.json', 'w') as fp:
            json.dump(closest_arranged, fp)

## 2.5D

In [None]:
for exp in exps:

    prefix = '/path/exps/closest_scans/closest_scans_test_' + str(exp)
    with open(prefix + '.json') as fp:
        closest_raw = json.load(fp)

    # every (source scan, position) candidate is kept and ranked by descending score
    closest_arranged = {
        t_id: {
            slice_id: sorted(
                [(s_id, position, score) for s_id, scores in sources for score, position in scores],
                key=lambda entry: entry[2], reverse=True,
            )
            for slice_id, sources in per_slice.items()
        }
        for t_id, per_slice in closest_raw.items()
    }

    with open(prefix + '_25d.json', 'w') as fp:
        json.dump(closest_arranged, fp)

## 3D

In [None]:
def _rank_sources_3d(closest_2d):
    """
    Turn the 2D lists of closest slices into scan-level (3D) rankings.

    For every target scan, the slice-to-slice scores of each source scan are averaged over
    all the target slices, and the source scans are sorted by the descending mean. Every
    slice of the target scan is mapped to that same list of [s_id, 0, mean_score] entries,
    so the output has the closest[t_id][slice_id] layout of the 2D input.
    """
    ranked = {}
    for t_id, per_slice in closest_2d.items():
        # s_id -> list of slice-to-slice scores, in the order of first appearance
        scores_by_source = {}
        for slice_results in per_slice.values():
            for entry in slice_results:
                s_id, score = entry[0], entry[2]
                scores_by_source.setdefault(s_id, []).append(score)

        mean_scores = [(s_id, float(np.mean(scores))) for s_id, scores in scores_by_source.items()]
        mean_scores.sort(key=lambda item: item[1], reverse=True)
        scan_ranking = [[s_id, 0, mean_score] for s_id, mean_score in mean_scores]

        # iterate over the slice ids of this very target scan
        # (the original loop referred to an undefined `slice_ids`)
        ranked[t_id] = {slice_id: scan_ranking for slice_id in per_slice}
    return ranked


for exp in exps:

    path = '/path/exps/closest_scans/closest_scans_test_' + str(exp) + '_2d.json'
    with open(path) as fp:
        closest_val = json.load(fp)

    # closest_val_3d_final[t_id][slice_id] is what we actually need
    closest_val_3d_final = _rank_sources_3d(closest_val)

    with open('/path/exps/closest_scans/closest_scans_test_' + str(exp) + '_3d.json', 'w') as fp:
        json.dump(closest_val_3d_final, fp)

# Examples

## Raw closeness dict 

In [52]:
# raw closeness dictionary, test split, experiment 6
path = '/path/exps/closest_scans/closest_scans_test_' + str(6) + '.json'
with open(path, mode='r') as json_file:
    closest = json.load(json_file)

In [55]:
# first three source-scan entries stored for slice '10' of the target scan 'CC0243'
closest['CC0243']['10'][:3]

[['CC0216',
  [[0.7414243817329407, -2],
   [0.7566266655921936, -1],
   [0.7579689621925354, 0],
   [0.7535974383354187, 1],
   [0.7527554631233215, 2]]],
 ['CC0225',
  [[0.7327763438224792, -2],
   [0.7237334847450256, -1],
   [0.7115591168403625, 0],
   [0.7117815613746643, 1],
   [0.724881649017334, 2]]],
 ['CC0193',
  [[0.7191252708435059, -2],
   [0.7312114834785461, -1],
   [0.7469754815101624, 0],
   [0.73780357837677, 1],
   [0.7465262413024902, 2]]]]

In [56]:
# first three source-scan entries stored for slice '15' of the target scan 'CC0243'
closest['CC0243']['15'][:3]

[['CC0216',
  [[0.777565598487854, -2],
   [0.7713488936424255, -1],
   [0.7729378342628479, 0],
   [0.7835105061531067, 1],
   [0.7913953065872192, 2]]],
 ['CC0225',
  [[0.7752798199653625, -2],
   [0.7741305232048035, -1],
   [0.7721046805381775, 0],
   [0.7716405391693115, 1],
   [0.773045003414154, 2]]],
 ['CC0193',
  [[0.7597838640213013, -2],
   [0.7712560296058655, -1],
   [0.7688422203063965, 0],
   [0.7702457904815674, 1],
   [0.7759168744087219, 2]]]]

## 2D

In [28]:
# 2D lists of closest slices, test split, experiment 6
path = '/path/exps/closest_scans/closest_scans_test_' + str(6) + '_2d.json'
with open(path, mode='r') as json_file:
    closest_2d = json.load(json_file)

In [29]:
# seven best-scored [s_id, position, score] entries for slice '10' of 'CC0243'
closest_2d['CC0243']['10'][:7]

[['CC0219', 0, 0.7949985861778259],
 ['CC0231', 0, 0.7910090088844299],
 ['CC0181', 0, 0.7881765961647034],
 ['CC0204', 0, 0.7867870330810547],
 ['CC0222', 0, 0.7814748287200928],
 ['CC0210', 0, 0.7783483266830444],
 ['CC0212', 0, 0.7776628732681274]]

In [30]:
# seven best-scored [s_id, position, score] entries for slice '15' of 'CC0243'
closest_2d['CC0243']['15'][:7]

[['CC0222', 0, 0.7991904020309448],
 ['CC0236', 0, 0.7952978610992432],
 ['CC0181', 0, 0.7952671051025391],
 ['CC0219', 0, 0.7950671315193176],
 ['CC0204', 0, 0.7925722002983093],
 ['CC0226', 0, 0.7894681692123413],
 ['CC0199', 0, 0.7872225642204285]]

## 3D

In [38]:
# 3D (scan-level) lists of closest scans, test split, experiment 6
path = '/path/exps/closest_scans/closest_scans_test_' + str(6) + '_3d.json'
with open(path, mode='r') as json_file:
    closest_3d = json.load(json_file)

In [39]:
# seven best-scored entries for slice '10' of 'CC0243'
closest_3d['CC0243']['10'][:7]

[['CC0181', 0, 0.8067700686149819],
 ['CC0219', 0, 0.8047209330076395],
 ['CC0221', 0, 0.8046921737665353],
 ['CC0191', 0, 0.8031747906013976],
 ['CC0213', 0, 0.8026486472334973],
 ['CC0210', 0, 0.8011291844900265],
 ['CC0199', 0, 0.8009275679671487]]

In [40]:
# seven best-scored entries for slice '15' of 'CC0243'
closest_3d['CC0243']['15'][:7]

[['CC0181', 0, 0.8067700686149819],
 ['CC0219', 0, 0.8047209330076395],
 ['CC0221', 0, 0.8046921737665353],
 ['CC0191', 0, 0.8031747906013976],
 ['CC0213', 0, 0.8026486472334973],
 ['CC0210', 0, 0.8011291844900265],
 ['CC0199', 0, 0.8009275679671487]]

## 2.5D

In [43]:
# 2.5D lists of closest slices, test split, experiment 6
path = '/path/exps/closest_scans/closest_scans_test_' + str(6) + '_25d.json'
with open(path, mode='r') as json_file:
    closest_25d = json.load(json_file)

In [44]:
# seven best-scored [s_id, position, score] entries for slice '10' of 'CC0243'
closest_25d['CC0243']['10'][:7]

[['CC0204', -2, 0.7964747548103333],
 ['CC0219', 0, 0.7949985861778259],
 ['CC0219', 1, 0.7927069067955017],
 ['CC0231', 0, 0.7910090088844299],
 ['CC0204', -1, 0.789961576461792],
 ['CC0222', 2, 0.7892383337020874],
 ['CC0231', 2, 0.7891520857810974]]

In [45]:
# seven best-scored [s_id, position, score] entries for slice '15' of 'CC0243'
closest_25d['CC0243']['15'][:7]

[['CC0222', -1, 0.8012664914131165],
 ['CC0219', 1, 0.7992594838142395],
 ['CC0222', 0, 0.7991904020309448],
 ['CC0219', 2, 0.7979041934013367],
 ['CC0181', -1, 0.7967404127120972],
 ['CC0236', 0, 0.7952978610992432],
 ['CC0181', 0, 0.7952671051025391]]