In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import numpy as np
import random


import os, sys
sys.path.insert(0,'..')

from collections import deque 

from dfibert.tracker.nn.rl import Agent, Action_Scheduler, DQN
import dfibert.envs.RLTractEnvironment as RLTe
from dfibert.cache import save_vtk_streamlines

from dfibert.envs._state import TractographyState

import matplotlib.pyplot as plt
%matplotlib notebook

from train import load_model

In [2]:
import importlib
# Re-execute the already imported environment module so that the `RLTe` name
# refers to the freshly loaded code without restarting the kernel.
importlib.reload(RLTe)

<module 'dfibert.envs.RLTractEnvironment' from '../dfibert/envs/RLTractEnvironment.py'>

In [3]:
# Build the tractography environment on the CPU with a step width of 0.8,
# 100 discrete actions and reference streamlines read from the given VTK file.
env = RLTe.RLTractEnvironment(stepWidth=0.8, action_space=100, device = 'cpu', 
                              pReferenceStreamlines='data/dti_ijk_0.8_maxDirecGetter.vtk', 
                              tracking_in_RAS = False, odf_state = False)
# Number of discrete actions exposed by the environment's action space.
# NOTE(review): the Agent built further down hard-codes n_actions=20 instead of
# using this value - confirm which one is intended.
n_actions = env.action_space.n

Loading precomputed streamlines (data/dti_ijk_0.8_maxDirecGetter.vtk) for ID 100307
sphere_odf = sphere_action = repulsion100
Computing ODF


### Load pretrained agent and analyse its behaviour

In [4]:
# --- Experiment configuration -------------------------------------------------
# Index of the reference streamline used for env.reset() throughout the notebook.
streamline_index = 2

# Hyper-parameters handed to the Agent constructor in the following cells.
replay_memory_size = 100000
agent_history_length = 1
batch_size = 512
learning_rate = 0.000001

# Training-schedule settings. None of them is read by the cells visible in this
# notebook; max_episode_length is re-assigned by the evaluation cell.
max_steps = 30000000
evaluate_every = 200000
eval_runs = 5#20
network_update_every = 10000
start_learning = 10000
eps_annealing_steps = 400000
max_episode_length = 2000

# Single source of truth for the compute device. A preceding `device = "cpu"`
# assignment was dead code: it was overwritten here before ever being read.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Start two points before the end of the reference streamline and display the
# reference points next to the coordinate of the initial state.
state = env.reset(streamline_index=streamline_index, start_index=-2, start_middle_of_streamline=False)
env.referenceStreamline_ijk, state.getCoordinate()

(tensor([[26.9703, 85.2038, 58.4752],
         [26.2059, 85.4285, 58.4027],
         [25.6337, 85.6031, 57.8716],
         [25.0615, 85.7776, 57.3405],
         [24.4312, 85.6666, 56.8605],
         [23.6809, 85.4569, 56.6787],
         [22.9376, 85.5205, 56.3899],
         [22.1942, 85.5841, 56.1011],
         [21.4509, 85.6478, 55.8123],
         [20.7255, 85.3239, 55.9061],
         [20.0000, 85.0000, 56.0000],
         [19.2745, 84.6761, 56.0939]]), tensor([20., 85., 56.]))

In [5]:
# Restore a pretrained checkpoint and copy its weights into both networks of a
# freshly constructed agent.
# Author's note on this checkpoint:
#   this agent got DWI input data and then predicts the direction out of 20 possible actions
#   reward: odf20
model, step_counter, mean_reward, epsilon = load_model('defi_2.86_dwi_odf100.pt')
#model, step_counter, mean_reward, epsilon = load_model('defi_4.95_dwi_odf20.pt')

agent = Agent(
    n_actions=20,
    inp_size=state.getValue().shape,
    device=device,
    hidden=10,
    gamma=0.99,
    agent_history_length=agent_history_length,
    memory_size=replay_memory_size,
    batch_size=batch_size,
    learning_rate=learning_rate,
)

# Main and target network start from the same checkpoint weights.
agent.main_dqn.load_state_dict(model)
agent.target_dqn.load_state_dict(model)

Loading checkpoint from defi_2.86_dwi_odf100.pt


<All keys matched successfully>

In [6]:
# Disabled alternative: flip the condition to load the ODF-state checkpoint
# instead of the agent restored above.
if False:
    # Author's note on this checkpoint:
    #   this agent got ODF input data and then predicts the direction out of 100 possible actions
    #   reward: odf20
    # NOTE(review): the note mentions 100 actions while the Agent below is
    # built with n_actions=20 - confirm.
    model, step_counter, mean_reward, epsilon = load_model('defi_4.81_odf20state_odf20.pt')
    agent = Agent(
        n_actions=20,
        inp_size=state.getValue().shape,
        device=device,
        hidden=10,
        gamma=0.99,
        agent_history_length=agent_history_length,
        memory_size=replay_memory_size,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )

    # Main and target network start from the same checkpoint weights.
    for dqn in (agent.main_dqn, agent.target_dqn):
        dqn.load_state_dict(model)

In [9]:
eval_rewards = []
all_distances = []
all_states = []
l2s = []
max_episode_length = 15
fa_threshold = 0.1
K = 3

#agent.main_dqn.eval()
for _ in range(1):
    eval_steps = 0
    state = env.reset(streamline_index=streamline_index, start_middle_of_streamline=False, start_index=-1)
    next_state = state
    all_states.append(state.getCoordinate())
    eval_episode_reward = 0
    episode_final = 0
    while eval_steps < max_episode_length:
        
        with torch.no_grad():
            state_v = torch.from_numpy(state.getValue()).unsqueeze(0).float().to(device)
            action = torch.argmax(agent.main_dqn(state_v)).item()
        
        my_position = all_states[-1]
        current_direction = None
        
        if(eval_steps > 0):
            # compute tangent of previous step
            current_direction = all_states[-1] - all_states[-2]
            current_direction = current_direction / torch.sqrt(torch.sum(current_direction**2))
            current_direction = current_direction.view(1,3)
        
        #action = get_multi_best_action(current_direction, odf_interpolator, my_position, mysphere, sphere_verts_torch, K = K)
        action = env._get_best_action(current_direction, my_position)

        
        next_state, reward, terminal, _ = env.step(action)
        
        #reward = reward
        
        print(eval_steps, my_position, "=>", next_state.getCoordinate().numpy(), action, reward)

        
        all_distances.append(reward)
        all_states.append(next_state.getCoordinate().squeeze())
                
        state = next_state
        print("---")
        if terminal:
            terminal = False
            break
            
        eval_episode_reward += reward.squeeze()
        eval_steps += 1

    eval_rewards.append(eval_episode_reward)

print("Evaluation score:", np.min(eval_rewards))

########################
### visualise streamline
########################
%matplotlib notebook 
state = env.reset(streamline_index=streamline_index) 

states = torch.stack(all_states)

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot3D(env.referenceStreamline_ijk.T[0], env.referenceStreamline_ijk.T[1], env.referenceStreamline_ijk.T[2], '-*')
ax.plot3D(states.T[0], states.T[1], states.T[2])
plt.legend(['gt', 'agent'])

Max reward: 1.00
0 tensor([19.2745, 84.6761, 56.0939]) => [[20.00000015 84.99999907 55.99999948]] tensor(96) tensor(1., dtype=torch.float64)
---
Max reward: 1.00
1 tensor([20.0000, 85.0000, 56.0000], dtype=torch.float64) => [[20.72545845 85.32388119 55.90613453]] tensor(96) tensor([1.0000], dtype=torch.float64)
---
Max reward: 1.00
2 tensor([20.7255, 85.3239, 55.9061], dtype=torch.float64) => [[21.45091674 85.64776332 55.81226957]] tensor(96) tensor([1.0000], dtype=torch.float64)
---
Max reward: 0.77
3 tensor([21.4509, 85.6478, 55.8123], dtype=torch.float64) => [[22.19424354 85.58413612 56.10108933]] tensor(34) tensor([0.7680], dtype=torch.float64)
---
Max reward: 1.00
4 tensor([22.1942, 85.5841, 56.1011], dtype=torch.float64) => [[22.93757034 85.52050893 56.38990909]] tensor(34) tensor([1.], dtype=torch.float64)
---
Max reward: 1.00
5 tensor([22.9376, 85.5205, 56.3899], dtype=torch.float64) => [[23.68089713 85.45688173 56.67872885]] tensor(34) tensor([0.9999], dtype=torch.float64)
---

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x2aad28e344e0>

### Debug deterministic tracking of Dipy

In [None]:
from scipy.interpolate import RegularGridInterpolator
import dipy.reconst.dti as dti

from dipy.direction import peaks_from_model
from dipy.data import get_sphere


# Fit a least-squares tensor model to the DWI volume of the environment's
# dataset, restricted to its binary mask.
dti_model = dti.TensorModel(env.dataset.data.gtab, fit_method='LS')
dti_fit = dti_model.fit(
    env.dataset.data.dwi,
    mask=env.dataset.data.binarymask,
)

#TODO: Issue => are we using the correct data for tractography actually??? The data got 288 gradient directions
# seems like its using the data of all bvals!!!

# Sample the fitted tensors' ODF on the 'repulsion100' sphere.
mysphere = get_sphere('repulsion100')
odf = dti_fit.odf(mysphere)

## set up interpolator for directions
# Integer index ranges along the first three axes of the ODF array.
x_range, y_range, z_range = (np.arange(extent) for extent in odf.shape[:3])

# Identity affine => tracking in IJK; the dataset affine would track in RAS.
#affine = env.dataset.data.aff # tracking in RAS
affine = np.eye(4) # tracking in IJK

In [None]:
# Grid interpolators defined over the index ranges of the ODF array, so that
# ODF and FA values can be queried at arbitrary (i, j, k) positions inside it.
#dir_interpolator = RegularGridInterpolator((x_range,y_range,z_range), dir)
odf_interpolator = RegularGridInterpolator((x_range,y_range,z_range), odf)
fa_interpolator = RegularGridInterpolator((x_range,y_range,z_range), dti_fit.fa)
#pd_interpolator = RegularGridInterpolator((x_range,y_range,z_range), peak_indices.peak_dirs)

In [None]:
# Displays the `get_fa` attribute of the dataset; it is not called here.
env.dataset.get_fa

In [None]:
# Extract at most two peaks per voxel inside the binary mask by fitting
# `dti_model` to the DWI volume and evaluating it on `mysphere`.
# The thresholds (relative peak height 0.2, minimum separation 25) follow the
# parameter semantics documented for dipy's peaks_from_model.
peak_indices = peaks_from_model(
    model=dti_model, data=env.dataset.data.dwi, sphere=mysphere, relative_peak_threshold=.2,
    min_separation_angle=25, mask=env.dataset.data.binarymask, npeaks=2) # Peaks and Metrics object

In [None]:
import dipy
# Deterministic direction getter built from the spherical-harmonic coefficients
# and the sphere stored on `peak_indices`.
# NOTE(review): 80 is passed as the second positional argument - presumably the
# maximum turning angle in degrees; confirm against the dipy documentation.
dg = dipy.direction.DeterministicMaximumDirectionGetter.from_shcoeff(peak_indices.shm_coeff, 80, peak_indices.sphere)

In [None]:
from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
# FA map of the tensor fit.
fa_img = dti_fit.fa
# Replace NaN entries by 0 in place.
# NOTE(review): if `dti_fit.fa` hands out a cached array, this also changes the
# values seen by anything else holding that array (e.g. `fa_interpolator`) -
# confirm whether a copy is wanted here.
fa_img[np.isnan(fa_img)] = 0
# Threshold-based stopping criterion on the FA map with threshold 0.1.
stopping_criterion = ThresholdStoppingCriterion(fa_img, .1)

In [None]:
from dipy.tracking import utils

# Binarise a copy of the FA map: 1 where FA >= 0.2, 0 where FA < 0.2.
# Both masks are taken before writing; since neither 1 nor 0 moves a voxel
# into the other class, the result equals that of two sequential assignments.
seed_mask = fa_img.copy()
is_seed_voxel = seed_mask >= 0.2
is_background_voxel = seed_mask < 0.2
seed_mask[is_seed_voxel] = 1
seed_mask[is_background_voxel] = 0

# One seed per selected voxel; the identity affine keeps the seeds in IJK.
seeds = utils.seeds_from_mask(seed_mask, affine=np.eye(4), density=1) # tracking in IJK

In [None]:
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.streamline import Streamlines

# Initialize local tracking - computation happens in the next step. 
# EuDX, https://dipy.org/documentation/1.0.0./examples_built/tracking_introduction_eudx/#garyfallidis12
# EuDx => change dg into pam
# The step size of .8 is the same value as the stepWidth the environment was
# created with; the identity affine keeps tracking in IJK.
streamlines_generator = LocalTracking(
    dg, stopping_criterion, seeds, affine=np.eye(4), step_size=.8) # tracking in IJK

# Generate streamlines object
# Consuming the generator here is what actually runs the tracking.
streamlines = Streamlines(streamlines_generator)
# Display the first streamline as a quick sanity check.
streamlines[0]
# tracked_streamlines = filter(lambda sl: len(sl) >= 10, tracked_streamlines)

In [None]:
# Keep only streamlines with at least 10 points, then show the kept fraction.
streamlines_cropped = [sl for sl in streamlines if len(sl) >= 10]
len(streamlines_cropped) / len(streamlines)

In [None]:
# Write the length-filtered streamlines to a VTK file in the working directory.
# NOTE(review): the environment above loads 'data/dti_ijk_0.8_maxDirecGetter.vtk';
# this call writes the file without the 'data/' prefix - confirm the intended path.
save_vtk_streamlines(streamlines=streamlines_cropped, filename="dti_ijk_0.8_maxDirecGetter.vtk")

## Ground-truth direction

In [None]:
# `cool_sl` is not read by any cell visible in this notebook.
cool_sl = 2
# Index of the reference-streamline point inspected in the following cells.
idx = 4
# Reference streamline currently loaded in the environment.
ref_sl = env.referenceStreamline_ijk

In [None]:
# Unit-length tangent of the reference streamline between points idx and idx+1.
diff_vector = ref_sl[idx + 1] - ref_sl[idx]
diff_vector_norm = diff_vector / diff_vector.pow(2).sum().sqrt()
("gt", diff_vector_norm)

In [None]:
#pv_norm = pd_interpolator(ref_sl[idx+1]) / np.sqrt(np.sum((pd_interpolator(ref_sl[idx+1])**2)))

# Index of the largest interpolated ODF value at reference point `idx` ...
odf_max = np.argmax(odf_interpolator(ref_sl[idx]))
# ... and the sphere vertex that belongs to this index.
odf_max_norm = mysphere.vertices[odf_max]

#"pv", pv_norm, 
"odf", odf_max_norm

In [None]:
# Interpolated ODF values at reference point `idx`, one dot per sphere vertex.
odf_x = odf_interpolator(ref_sl[idx]).squeeze()
plt.plot(odf_x,'.')

# Debug data generation

In [None]:
import os, sys

import gym
from gym.spaces import Discrete, Box
import numpy as np

from dipy.data import get_sphere
from dipy.data import HemiSphere, Sphere
from dipy.core.sphere import disperse_charges
import torch


from dfibert.data.postprocessing import res100, resample
from dfibert.data import HCPDataContainer, ISMRMDataContainer, PointOutsideOfDWIError
from dfibert.tracker import StreamlinesFromFileTracker
from dfibert.util import get_grid

import shapely.geometry as geom
from shapely.ops import nearest_points
from shapely.strtree import STRtree


from collections import deque

# Load the HCP data container with ID '100307' and normalise it.
# NOTE(review): the cells visible below use `env.dataset`, not this `dataset`
# variable - confirm whether it is still needed.
dataset = HCPDataContainer('100307')
dataset.normalize()

In [None]:
# Recompute the ODF from DWI values interpolated around the coordinate of
# `next_state` (the last state produced by the evaluation loop above).
coord, data = next_state.getCoordinate(), next_state.getValue()
# Grid built from a [3, 3, 3] specification - presumably a 3x3x3 neighbourhood
# of offsets; verify against dfibert.util.get_grid.
grid = get_grid(np.array([3,3,3]))
ras_points = env.dataset.to_ras(coord)
ras_points = grid + ras_points

interpolated_dwi = env.dataset.get_interpolated_dwi(ras_points, postprocessing=None)
#interpolated_dwi = np.rollaxis(interpolated_dwi,3)

# NOTE(review): this rebinds `dti_fit`, `mysphere` and `odf`, which earlier
# cells use for the whole-volume fit. Re-running those earlier cells afterwards
# without re-fitting picks up the local values instead.
dti_fit = dti_model.fit(interpolated_dwi)
mysphere = get_sphere('repulsion100')
odf = dti_fit.odf(mysphere)

In [None]:
# Compare the ODF averaged over all entries of `odf` (reshaped to rows of 100
# values) with the ODF at index [1, 1, 1].
plt.figure()
plt.plot(range(100), np.mean(odf.reshape(-1,100), axis=0))
plt.plot(range(100), odf[1,1,1,:])

In [None]:
# Shape of the interpolated DWI block computed above.
interpolated_dwi.shape

In [None]:
# Share of one replay-memory minibatch for which the network's greedy action
# equals the action stored in the memory.
states, actions, _, _, _ = agent.replay_memory.get_minibatch()
states = torch.FloatTensor(states).to(agent.device)
predicted_q = torch.argmax(agent.main_dqn(states), dim=1)

# Count the positions where prediction and stored action disagree.
false = sum(1 for i in range(len(actions)) if predicted_q[i] != actions[i])

print("Accuracy =", 1 - false / len(actions))

# Analysis of our environment

In [None]:
# Best action at `my_position` when no previous direction is supplied.
env._get_best_action(None, my_position)

In [None]:
# Unit vector from the second-to-last to the last reference point, as a
# (1, 3) row.
gt_direction = env.referenceStreamline_ijk[-1] - env.referenceStreamline_ijk[-2]
gt_direction = (gt_direction / gt_direction.pow(2).sum().sqrt()).view(1, 3)

gt_direction

In [None]:
# Direction stored at index 96, the action index printed for the first steps
# of the evaluation run above.
env.directions[96]

In [None]:
# Position left over from the last iteration of the evaluation loop.
my_position

In [None]:
# Re-run the action selection for the very first position of the rollout.
my_position = all_states[0]
eval_steps = 0

# With eval_steps == 0 there is no previous step, so no incoming direction is
# passed on; the branch only becomes active when eval_steps is edited above.
current_direction = None
if eval_steps > 0:
    # compute tangent of previous step
    step_vector = all_states[-1] - all_states[-2]
    current_direction = (step_vector / step_vector.pow(2).sum().sqrt()).view(1, 3)

action = env._get_best_action(current_direction, my_position)
action

In [None]:
# Manual computation of a per-action reward at `my_position`.

# main peak from ODF
peak_dir = env._get_best_action_ODF(my_position)

# cosine similarity wrt. all directions, sign removed
peak_alignment = F.cosine_similarity(peak_dir.view(1, -1), env.directions)
reward = abs(peak_alignment)

# Weight by the alignment with the incoming direction, if there is one.
if current_direction is not None:
    reward = reward * F.cosine_similarity(current_direction, env.directions)

best_action = torch.argmax(reward)
print("Max reward: %.2f" % (torch.max(reward).cpu().detach().numpy()))

In [None]:
# Reward derived directly from the interpolated ODF at `my_position`.
# The entry at index [1, 1, 1] of the interpolated block is taken and
# flattened to 100 values.
odf_l = torch.from_numpy(env.interpolate_odf_at_state(stateCoordinates=my_position))[:,1,1,1].view(100)
#odf_l = torch.mean(odf_l, 1)

# Scale so that the largest ODF value gives a reward of 1.
reward = odf_l / torch.max(odf_l)

# Only weight by the incoming direction when there is one. `current_direction`
# is None at the first step, and cosine_similarity cannot take None.
if current_direction is not None:
    reward = reward * (torch.nn.functional.cosine_similarity(current_direction, env.directions))

# Index of the ODF maximum next to the index of the best reward.
torch.argmax(odf_l), torch.argmax(reward)

In [None]:
# Shape of the similarity between the incoming direction and all directions.
# NOTE(review): `current_direction` is None after the first-position cell
# above, in which case this call cannot work - confirm the intended cell order.
(torch.nn.functional.cosine_similarity(current_direction, env.directions)).shape

In [None]:
# Shape of the most recently computed reward tensor.
reward.shape

In [None]:
# Reward entry at index 73 (hard-coded index chosen by the author).
reward[73]

In [None]:
# Scatter the reward of every action index.
fig = plt.figure()
plt.plot(reward.cpu().detach().numpy().squeeze(), '.')

In [None]:
# Full table of directions the environment compares against.
env.directions

In [None]:
# Incoming direction from the most recent cell that set it (may be None).
current_direction

In [None]:
# Current position as a NumPy array on the CPU.
my_position.cpu().numpy()

In [None]:
# Shape of the transposed `states` tensor.
# NOTE(review): `states` is assigned both by the plotting code and by the
# replay-memory cell, so the result depends on which one ran last.
states.T.shape

In [None]:
# NOTE(review): `k_largest` is not defined in any cell visible in this
# notebook; on a fresh kernel this cell raises NameError.
k_largest.shape

In [None]:
# NOTE(review): `my_odf` is not defined in any cell visible in this notebook;
# on a fresh kernel this cell raises NameError.
np.argmax(my_odf.squeeze())

In [None]:
# Coordinates collected by the evaluation loop (start position first).
all_states

In [None]:
### visualise streamline
%matplotlib notebook 
# NOTE(review): the environment is reset to streamline 50 here, whereas
# `all_states` was collected with `streamline_index` in the evaluation cell, so
# the two curves may belong to different streamlines - confirm.
state = env.reset(streamline_index=50) 

# One row per visited coordinate.
states = torch.stack(all_states)

# Plot at most the first ten points of the reference streamline and of the
# visited coordinates.
fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot3D(env.referenceStreamline_ijk.T[0][0:10], env.referenceStreamline_ijk.T[1][0:10], env.referenceStreamline_ijk.T[2][0:10])
ax.plot3D(states.T[0][0:10], states.T[1][0:10], states.T[2][0:10])