In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gym
import numpy as np
import random


import os, sys
sys.path.insert(0,'..')

from collections import deque 

from dfibert.tracker.nn.rl import Agent, Action_Scheduler, DQN
import dfibert.envs.RLtractEnvironment as RLTe
from dfibert.cache import save_vtk_streamlines

from dfibert.envs._state import TractographyState

import matplotlib.pyplot as plt
%matplotlib notebook

from train import load_model

In [2]:
import importlib
# Reload the environment module so that edits to its source file take effect
# in this kernel without a restart. The reloaded module is the cell's output.
importlib.reload(RLTe)

<module 'dfibert.envs.RLtractEnvironment' from '../dfibert/envs/RLtractEnvironment.py'>

In [3]:
# Earlier configuration that pointed at a different reference streamline file:
#env = RLTe.RLtractEnvironment(stepWidth=0.8, action_space=100, device = 'cpu', pReferenceStreamlines='data/HCP307200_DTI_min40.vtk')

# Build the tractography environment (100 actions requested) using the
# reference streamline file that this notebook writes further down.
env = RLTe.RLtractEnvironment(
    stepWidth=0.8,
    action_space=100,
    device='cpu',
    pReferenceStreamlines='dti_ijk_0.8_maxDirecGetter.vtk',
    tracking_in_RAS=False,
)
n_actions = env.action_space.n

Loading precomputed streamlines (dti_ijk_0.8_maxDirecGetter.vtk) for ID 100307
Repulsion100!
Computing ODF


In [None]:
from scipy.interpolate import RegularGridInterpolator
import dipy.reconst.dti as dti

from dipy.direction import peaks_from_model
from dipy.data import get_sphere


# Fit a least-squares diffusion tensor model to the DWI data, restricted to
# the voxels selected by env.dataset.data.binarymask.
dti_model = dti.TensorModel(env.dataset.data.gtab, fit_method='LS')
dti_fit = dti_model.fit(env.dataset.data.dwi, mask=env.dataset.data.binarymask)

# TODO: verify that the correct data is used for tractography. The data has
# 288 gradient directions, so it looks like all b-values are being used.
mysphere = get_sphere('repulsion100')
odf = dti_fit.odf(mysphere)

# Index ranges along the first three axes of the ODF volume; these form the
# grid on which the interpolators are defined.
x_range, y_range, z_range = (np.arange(extent) for extent in odf.shape[:3])

#affine = env.dataset.data.aff # tracking in RAS
affine = np.eye(4) # tracking in IJK

In [None]:
# Interpolators defined on the voxel index grid, so the ODF and FA volumes can
# be sampled at arbitrary (x, y, z) positions inside that grid.
#dir_interpolator = RegularGridInterpolator((x_range,y_range,z_range), dir)
odf_interpolator = RegularGridInterpolator((x_range,y_range,z_range), odf)
fa_interpolator = RegularGridInterpolator((x_range,y_range,z_range), dti_fit.fa)
#pd_interpolator = RegularGridInterpolator((x_range,y_range,z_range), peak_indices.peak_dirs)

In [None]:
# Debug inspection: evaluates to the `get_fa` attribute itself (no call is made).
env.dataset.get_fa

In [None]:
# Extract up to two ODF peaks per voxel selected by the binary mask.
# The result is a dipy "Peaks and Metrics" object.
peak_indices = peaks_from_model(
    model=dti_model,
    data=env.dataset.data.dwi,
    sphere=mysphere,
    relative_peak_threshold=.2,
    min_separation_angle=25,
    mask=env.dataset.data.binarymask,
    npeaks=2,
)

In [None]:
import dipy
# Direction getter built from the spherical-harmonic coefficients stored on
# the peak extraction result. 80 is the second positional argument
# (presumably dipy's max_angle in degrees — TODO confirm against the dipy API).
dg = dipy.direction.DeterministicMaximumDirectionGetter.from_shcoeff(
    peak_indices.shm_coeff,
    80,
    peak_indices.sphere,
)

In [None]:
from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
# FA volume used by the stopping criterion; NaN entries are replaced by 0.
# NOTE(review): the assignment below writes into the array returned by
# `dti_fit.fa` rather than into a copy — confirm that sharing is intended.
fa_img = dti_fit.fa
fa_img[np.isnan(fa_img)] = 0
# Threshold stopping criterion on the FA volume, with threshold 0.1.
stopping_criterion = ThresholdStoppingCriterion(fa_img, .1)

In [None]:
from dipy.tracking import utils

# Binary seed mask: 1 where FA >= 0.2 and 0 elsewhere, stored in fa_img's
# dtype. fa_img has its NaNs zeroed in the cell that defines it, so no
# NaN handling is needed here.
seed_mask = (fa_img >= 0.2).astype(fa_img.dtype)

# Seeds generated from the mask with an identity affine.
seeds = utils.seeds_from_mask(seed_mask, affine=np.eye(4), density=1) # tracking in IJK

In [None]:
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.streamline import Streamlines

# Set up local tracking; the actual computation happens when the generator
# is consumed below.
# EuDX, https://dipy.org/documentation/1.0.0./examples_built/tracking_introduction_eudx/#garyfallidis12
# EuDx => change dg into pam
streamlines_generator = LocalTracking(
    dg,
    stopping_criterion,
    seeds,
    affine=np.eye(4),  # tracking in IJK
    step_size=.8,
)

# Consume the generator into a Streamlines object.
streamlines = Streamlines(streamlines_generator)
# tracked_streamlines = filter(lambda sl: len(sl) >= 10, tracked_streamlines)
# Show the first streamline as the cell output.
streamlines[0]

In [None]:
# Keep only streamlines that consist of at least 10 points.
streamlines_cropped = [sl for sl in streamlines if len(sl) >= 10]
# Fraction of streamlines that survive the length filter (cell output).
len(streamlines_cropped) / len(streamlines)

In [None]:
# Write the length-filtered streamlines to the VTK file that the environment
# is configured to load as its reference streamlines.
save_vtk_streamlines(streamlines=streamlines_cropped, filename="dti_ijk_0.8_maxDirecGetter.vtk")

## Ground-truth direction

In [None]:
cool_sl = 2  # NOTE(review): not referenced by any other cell visible here
idx = 4  # index of the reference-streamline point inspected in the next cells
ref_sl = env.referenceStreamline_ijk  # reference streamline attribute of the environment

In [None]:
# Unit-length difference vector between reference points idx and idx+1,
# i.e. the local direction of the reference streamline.
diff_vector = (ref_sl[idx+1] - ref_sl[idx])
diff_vector_norm = diff_vector / torch.sqrt(torch.sum(diff_vector**2))
# The cell output is the tuple ("gt", unit direction).
"gt", diff_vector_norm

In [None]:
#pv_norm = pd_interpolator(ref_sl[idx+1]) / np.sqrt(np.sum((pd_interpolator(ref_sl[idx+1])**2)))

# Index of the largest interpolated ODF value at the inspected point, and the
# sphere vertex that belongs to that index.
odf_max = np.argmax(odf_interpolator(ref_sl[idx]))
odf_max_norm = mysphere.vertices[odf_max]

#"pv", pv_norm, 
# The cell output is the tuple ("odf", sphere vertex of the ODF maximum).
"odf", odf_max_norm

In [None]:
# Plot the interpolated ODF values at the inspected point as individual dots.
odf_x = odf_interpolator(ref_sl[idx]).squeeze()
plt.plot(odf_x,'.')

## Start RL training

In [None]:
# Training configuration. Comments describe how each value is used by the
# cells visible in this notebook.
max_steps = 30000000  # the RL loop runs while step_counter < max_steps
replay_memory_size = 100000  # passed to Agent as memory_size
agent_history_length = 1  # passed to Agent as agent_history_length
evaluate_every = 200000  # environment steps per training phase before an evaluation run
eval_runs = 5  # evaluation episodes per evaluation run (alternative value noted here before: 20)
network_update_every = 10000  # target network is synced when step_counter is a multiple of this
start_learning = 10000  # optimisation and epsilon decay only start after this many steps
eps_annealing_steps = 400000  # NOTE(review): not referenced by the cells visible here

# NOTE(review): the RL loop hard-codes its own episode cap, and the evaluation
# cell later rebinds this name — confirm which value is meant to apply.
max_episode_length = 2000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

batch_size = 512  # passed to Agent as batch_size
learning_rate = 0.000001  # passed to Agent as learning_rate

In [None]:
# Reset once to obtain a state whose shape defines the network input size.
state = env.reset().getValue()
print(state.shape)

# Build the agent from the configuration values above.
agent = Agent(
    n_actions=n_actions,
    inp_size=state.shape,
    device=device,
    hidden=10,
    gamma=0.99,
    agent_history_length=agent_history_length,
    memory_size=replay_memory_size,
    batch_size=batch_size,
    learning_rate=learning_rate,
)

In [None]:
#### Fill replay memory with perfect actions for supervised approach

from tqdm import trange
# Build a new Agent (presumably so the replay memory starts fresh for the
# supervised pre-fill — confirm against Agent's constructor).
state = env.reset().getValue()
agent = Agent(
    n_actions=n_actions,
    inp_size=state.shape,
    device=device,
    hidden=10,
    gamma=0.99,
    agent_history_length=agent_history_length,
    memory_size=replay_memory_size,
    batch_size=512,
    learning_rate=learning_rate,
)

# Roll out the environment's own best action on streamlines 0..59 and store
# every transition in the replay memory.
overall_runs = 0
overall_reward = []
for streamline_idx in trange(60):
    state = env.reset(streamline_index=streamline_idx)
    episode_reward = 0
    terminal = False

    # The loop condition alone ends the episode; no extra break is needed.
    while not terminal:
        action = env._get_best_action()
        next_state, reward, terminal, _ = env.step(action)

        agent.replay_memory.add_experience(action=action,
                                           state=state.getValue(),
                                           reward=reward,
                                           new_state=next_state.getValue(),
                                           terminal=terminal)

        episode_reward += reward
        state = next_state

    overall_reward.append(episode_reward)
    # Number of finished episodes, kept separate from the loop index.
    overall_runs = len(overall_reward)
    print(overall_runs, np.mean(overall_reward[-100:]))
print("Replay memory ready")

Training cell

In [None]:
# Supervised pre-training: treat the stored actions as class labels and train
# the main network with a cross-entropy loss on its outputs.
#
# Training uses `agent.optimizer`, i.e. whatever optimizer the Agent was built
# with.
# NOTE(review): an Adam optimizer over `agent.target_dqn` with lr 1e-4 used to
# be constructed at the top of this cell, but it was never assigned or stepped
# and so had no effect. If a different learning rate is wanted for
# pre-training, assign a new optimizer over `agent.main_dqn` to
# `agent.optimizer` instead.
losses = []
for i in trange(70000):
    # Only states and actions are needed for the supervised objective.
    states, actions, _, _, _ = agent.replay_memory.get_minibatch()

    states = torch.FloatTensor(states).to(agent.device)
    actions = torch.LongTensor(actions).to(agent.device)
    predicted_q = agent.main_dqn(states)
    loss = torch.nn.functional.cross_entropy(predicted_q, actions)

    agent.optimizer.zero_grad()
    loss.backward()
    agent.optimizer.step()
    losses.append(loss.item())

# Forward-looking running mean over up to 99 consecutive losses; windows near
# the end of the run are shorter.
mean_losses = [np.mean(losses[i:i+99]) for i in range(len(losses))]

fig, ax = plt.subplots()
ax.plot(range(len(losses)), losses, label="loss")
ax.plot(range(len(mean_losses)), mean_losses, label="running mean (99 steps)")
ax.set_xlabel("optimisation step")
ax.set_ylabel("cross-entropy loss")
ax.legend()
plt.show()

# Debug data generation

In [None]:
import os, sys

import gym
from gym.spaces import Discrete, Box
import numpy as np

from dipy.data import get_sphere
from dipy.data import HemiSphere, Sphere
from dipy.core.sphere import disperse_charges
import torch


from dfibert.data.postprocessing import res100, resample
from dfibert.data import HCPDataContainer, ISMRMDataContainer, PointOutsideOfDWIError
from dfibert.tracker import StreamlinesFromFileTracker
from dfibert.util import get_grid

import shapely.geometry as geom
from shapely.ops import nearest_points
from shapely.strtree import STRtree


from collections import deque

# Build an HCPDataContainer for ID '100307' and call its normalize().
# NOTE(review): `dataset` is not used by the other cells visible here; they
# read `env.dataset` instead.
dataset = HCPDataContainer('100307')
dataset.normalize()

In [None]:
# Recompute an ODF locally: sample the DWI signal on a grid around the
# coordinate of the current `next_state` and fit the DTI model to that sample.
# NOTE(review): `next_state` comes from whichever rollout cell ran last, and
# this cell rebinds the notebook-level `dti_fit`, `mysphere` and `odf`. Cells
# that expect the whole-volume versions must not run after it without
# re-fitting.
coord, data = next_state.getCoordinate(), next_state.getValue()
grid = get_grid(np.array([3,3,3]))  # presumably offsets of a 3x3x3 neighbourhood — confirm in dfibert.util
ras_points = env.dataset.to_ras(coord)
ras_points = grid + ras_points

interpolated_dwi = env.dataset.get_interpolated_dwi(ras_points, postprocessing=None)
#interpolated_dwi = np.rollaxis(interpolated_dwi,3)

dti_fit = dti_model.fit(interpolated_dwi)
mysphere = get_sphere('repulsion100')
odf = dti_fit.odf(mysphere)

In [None]:
# Compare the ODF averaged over all grid points (first curve) with the ODF at
# grid index (1, 1, 1) (second curve); the x axis is the sphere vertex index.
plt.figure()
plt.plot(range(100), np.mean(odf.reshape(-1,100), axis=0))
plt.plot(range(100), odf[1,1,1,:])

In [None]:
# Debug inspection: shape of the locally interpolated DWI sample.
interpolated_dwi.shape

In [None]:
# Accuracy of the main network on one replay minibatch: the predicted action
# is the arg-max over the network outputs, compared with the stored action.
states, actions, _, _, _ = agent.replay_memory.get_minibatch()
states = torch.FloatTensor(states).to(agent.device)
with torch.no_grad():  # inference only, no autograd graph needed
    predicted_q = torch.argmax(agent.main_dqn(states), dim=1)

# Compare the whole batch at once instead of element by element.
false = int(np.count_nonzero(predicted_q.cpu().numpy() != np.asarray(actions)))

print("Accuracy =", 1 - false / len(actions))

In [None]:
# Epsilon-greedy DQN training, alternating a training phase with a short
# greedy evaluation phase until `max_steps` environment steps are reached.
#
# NOTE(review): several values are hard-coded in this loop instead of coming
# from the configuration cell: the episode cap (1000 steps), the discount
# factor (0.9995, while the Agent is built with gamma=0.99) and the epsilon
# schedule (eps_annealing_steps is not used). Confirm which are intended.
step_counter = 0
eps_rewards = []
episode_lengths = []

eps = 1.0

print("Start training...")
while step_counter < max_steps:
    epoch_step = 0

    ## training phase: play episodes until the next evaluation is due
    while (epoch_step < evaluate_every) or (step_counter < start_learning):
        state = env.reset()
        episode_reward_sum = 0
        terminal = False
        episode_step_counter = 0

        # reduce epsilon (multiplicative decay once per episode, floor 0.01)
        if step_counter > start_learning:
            eps = max(eps * 0.999, 0.01)

        # play an episode
        while episode_step_counter <= 1000.:

            # get an action with epsilon-greedy strategy
            if random.random() < eps:
                action = np.random.randint(env.action_space.n)           # either random action
                #action = env._get_best_action()
            else:                                                        # or action from agent
                agent.main_dqn.eval()
                with torch.no_grad():
                    state_v = torch.from_numpy(state.getValue()).unsqueeze(0).float().to(device)
                    action = torch.argmax(agent.main_dqn(state_v)).item()
                agent.main_dqn.train()

            # perform step on environment
            next_state, reward, terminal, _ = env.step(action)

            episode_step_counter += 1
            step_counter += 1
            epoch_step += 1

            episode_reward_sum += reward

            # store experience in replay buffer
            agent.replay_memory.add_experience(action=action, state = state.getValue(), reward=reward, new_state = next_state.getValue(), terminal=terminal)

            state = next_state

            # optimize agent after certain amount of steps
            if step_counter > start_learning and step_counter % 4 == 0:

                # original optimization function
                #agent.optimize()

                ### debugging optimization function
                states, actions, rewards, new_states, terminal_flags = agent.replay_memory.get_minibatch()

                states = torch.from_numpy(states).to(device)
                next_states = torch.from_numpy(new_states).to(device)
                actions = torch.from_numpy(actions).unsqueeze(1).long().to(device)
                rewards = torch.from_numpy(rewards).to(device)
                terminal_flags = torch.from_numpy(terminal_flags).to(device)

                # Q(s, a) of the actions that were actually taken
                state_action_values = agent.main_dqn(states).gather(1, actions).squeeze(-1)
                # Double-DQN target: the main network picks the next action,
                # the target network evaluates it.
                next_state_actions = torch.argmax(agent.main_dqn(next_states), dim=1)
                next_state_values = agent.target_dqn(next_states).gather(1, next_state_actions.unsqueeze(-1)).squeeze(-1)
                # no bootstrapping from terminal transitions
                next_state_values[terminal_flags] = 0.0
                expected_state_action_values = next_state_values.detach() * 0.9995 + rewards

                loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
                agent.optimizer.zero_grad()
                loss.backward()
                agent.optimizer.step()

            # update target network after certain amount of steps
            if step_counter > start_learning and step_counter % network_update_every == 0:
                agent.target_dqn.load_state_dict(agent.main_dqn.state_dict())

            # if episode has ended, step out of the episode while loop
            if terminal:
                break

        # keep track of past episode rewards
        eps_rewards.append(episode_reward_sum)
        if len(eps_rewards) % 20 == 0:
            print("{}, done {} episodes, {}, current eps {}".format(step_counter, len(eps_rewards), np.mean(eps_rewards[-100:]), eps))

    ## evaluation: greedy rollouts with the main network in eval mode
    eval_rewards = []
    episode_final = 0
    agent.main_dqn.eval()
    for _ in range(eval_runs):
        eval_steps = 0
        state = env.reset()

        eval_episode_reward = 0

        # play an episode
        while eval_steps < 1000:
            # get the action from the agent
            with torch.no_grad():
                state_v = torch.from_numpy(state.getValue()).unsqueeze(0).float().to(device)
                action = torch.argmax(agent.main_dqn(state_v)).item()

            # perform a step on the environment
            next_state, reward, terminal, _ = env.step(action)

            eval_steps += 1

            eval_episode_reward += reward
            state = next_state

            # step out of the episode while loop once the episode has ended;
            # a final reward of exactly 1 is counted as reaching the end state
            if terminal:
                if reward == 1.:
                    episode_final += 1
                break

        eval_rewards.append(eval_episode_reward)

    # Switch back to training mode. Without this, the optimisation steps that
    # follow an evaluation would run in eval mode until the next greedy action.
    agent.main_dqn.train()

    print("Evaluation score:", np.mean(eval_rewards))
    print("{} of {} episodes ended close to / at the final state.".format(episode_final, eval_runs))

# Evaluation of trained agent

In [None]:
import glob
# Checkpoints of the run with the new reward (penalty when leaving brain mask).
checkpoint_pattern = 'test_newReward/checkpoints/*.pt'
paths = glob.glob(checkpoint_pattern)
if not paths:
    # max() on an empty list would only report "max() arg is an empty sequence"
    raise FileNotFoundError("No checkpoint files match %r" % checkpoint_pattern)

# Pick the checkpoint with the largest ctime and load it.
p_cp = max(paths, key=os.path.getctime)
model, step_counter, mean_reward, epsilon = load_model(p_cp)

# Both networks receive the same loaded weights.
agent.main_dqn.load_state_dict(model)
agent.target_dqn.load_state_dict(model)

In [None]:
# Debug inspection: evaluates to the `get_fa` attribute itself (no call is made).
env.dataset.get_fa

In [None]:
# Vertices of `mysphere` as a torch tensor; passed to (and read by) the
# action-selection helpers defined in this cell.
sphere_verts_torch = torch.from_numpy(mysphere.vertices)

def get_multi_best_action(current_direction, odf_interpolator, my_position, sphere, sphere_verts_torch, K = 3):
    """Choose the action index with the highest reward over the K strongest ODF peaks.

    The per-peak rewards come from get_multi_best_action_ODF (one row per
    peak, one column per sphere vertex). If a current direction is given, each
    row is weighted by the signed cosine similarity between every sphere
    vertex and that direction. The rows are then reduced with an element-wise
    maximum and the arg-max column is returned.

    Args:
        current_direction: direction of the previous step, or None to skip
            the directional weighting.
        odf_interpolator: callable forwarded to get_multi_best_action_ODF.
        my_position: position forwarded to get_multi_best_action_ODF.
        sphere: sphere object forwarded to get_multi_best_action_ODF.
        sphere_verts_torch: tensor of sphere vertices used for the weighting.
        K: number of ODF peaks to consider.

    Returns:
        Tensor holding the index of the best action. The maximum reward is
        also printed.
    """
    # one reward row per ODF peak
    per_peak_reward = get_multi_best_action_ODF(odf_interpolator, my_position, sphere, K)

    if current_direction is not None:
        alignment = torch.nn.functional.cosine_similarity(sphere_verts_torch, current_direction)
        per_peak_reward = per_peak_reward * alignment.view(1, -1)

    # best reward per action across all peaks
    reward = per_peak_reward.max(dim=0).values
    best_action = reward.argmax()
    print("Max reward: %.2f" % (reward.max().cpu().detach().numpy()))
    return best_action

def get_best_action(current_direction, odf_interpolator, my_position, sphere, sphere_verts_torch, K = 2):
    """Choose the action index best aligned with the main ODF peak at a position.

    The reward of every sphere vertex is the absolute cosine similarity to the
    main peak direction returned by get_best_action_ODF. If a current
    direction is given, the reward is additionally multiplied by the signed
    cosine similarity between each vertex and that direction.

    Args:
        current_direction: direction of the previous step, or None to skip
            the directional weighting.
        odf_interpolator: callable forwarded to get_best_action_ODF.
        my_position: position forwarded to get_best_action_ODF.
        sphere: sphere object forwarded to get_best_action_ODF.
        sphere_verts_torch: tensor of sphere vertices.
        K: accepted but not used by this function.

    Returns:
        Tensor holding the index of the best action. The maximum reward is
        also printed.
    """
    # main peak from ODF
    #peak_dir = get_best_action_pd(odf_interpolator, my_position, sphere)
    main_peak = torch.from_numpy(get_best_action_ODF(odf_interpolator, my_position, sphere))

    # unsigned cosine similarity between the peak and all directions
    reward = torch.nn.functional.cosine_similarity(main_peak.view(1, -1), sphere_verts_torch).abs()

    if current_direction is not None:
        heading_alignment = torch.nn.functional.cosine_similarity(sphere_verts_torch, current_direction)
        reward = reward * heading_alignment

    best_action = reward.argmax()
    print("Max reward: %.2f" % (reward.max().cpu().detach().numpy()))
    return best_action

def get_best_action_ODF(odf_interpolator, my_position, sphere):
    """Return the sphere vertex at which the interpolated ODF is largest.

    Args:
        odf_interpolator: callable evaluated at `my_position`; its squeezed
            result is searched for the maximum.
        my_position: position handed to the interpolator unchanged.
        sphere: object whose `vertices` are indexed with the arg-max index.

    Returns:
        The entry of `sphere.vertices` at the arg-max of the ODF.
    """
    # Disabled alternative: ODF computation on a 3x3x3 grid.
    #   coord, data = next_state.getCoordinate(), next_state.getValue()
    #   grid = get_grid(np.array([3,3,3]))
    #   ras_points = env.dataset.to_ras(my_position)
    #   ras_points = grid + ras_points
    #   interpolated_dwi = env.dataset.get_interpolated_dwi(ras_points, postprocessing=None)
    #   dti_fit = dti_model.fit(interpolated_dwi)
    #   coolsl0_odf = dti_fit.odf(sphere)
    #   coolsl0_odf = np.mean(coolsl0_odf.reshape(-1,100), axis=0)

    # ODF interpolation
    odf_values = odf_interpolator(my_position).squeeze()
    return sphere.vertices[np.argmax(odf_values)]

def get_multi_best_action_ODF(odf_interpolator, my_position, sphere, K = 3):
    """Reward every sphere direction against each of the K strongest ODF peaks.

    The sphere vertices are taken from the `sphere` argument rather than from
    a notebook-level tensor, so the function does not depend on hidden state
    and always compares peaks and candidate directions on the same sphere.

    Args:
        odf_interpolator: callable evaluated at `my_position`; its squeezed
            result is treated as one ODF value per sphere vertex.
        my_position: position handed to the interpolator unchanged.
        sphere: object whose `vertices` is a NumPy array with 3 columns.
        K: number of largest ODF values (peaks) to use.

    Returns:
        Tensor with K rows and one column per sphere vertex. Each entry is the
        absolute cosine similarity between a peak direction and a vertex. The
        rows are not sorted by ODF value.
    """
    my_odf = odf_interpolator(my_position).squeeze()

    # Indices of the K largest ODF values; argpartition does not sort them.
    k_largest = np.argpartition(my_odf, -K)[-K:]

    all_dirs = torch.from_numpy(sphere.vertices).view(-1, 3)
    peak_dirs_torch = torch.from_numpy(sphere.vertices[k_largest]).view(K, 3)

    # abs(): a peak and its antipode describe the same fibre orientation.
    rewards = torch.stack([
        torch.nn.functional.cosine_similarity(peak_dirs_torch[k:k+1, :], all_dirs).abs()
        for k in range(K)
    ])
    return rewards

def get_best_action_pd(pd_interpolator, my_position, sphere):
    """Return the `[0, 0]` entry of the interpolated peak directions.

    Args:
        pd_interpolator: callable evaluated at `my_position`.
        my_position: position handed to the interpolator unchanged.
        sphere: accepted for signature parity with the ODF helper; unused.

    Returns:
        The `[0, 0]` element of the interpolator's result.
    """
    peaks_at_position = pd_interpolator(my_position)
    return peaks_at_position[0, 0]

In [None]:
# Roll out a single short episode on one reference streamline, choosing every
# action with the environment's own `_get_best_action`, and record visited
# coordinates, rewards and L2 distances for inspection.
eval_rewards = []
all_distances = []  # per-step rewards (appended below)
all_states = []  # visited coordinates, starting with the reset position
l2s = []  # env.l2_distance after every step
max_episode_length = 10  # NOTE(review): rebinds the value set in the configuration cell
streamline_index = 11
fa_threshold = 0.1
K = 3  # only used by the commented-out get_multi_best_action call below

#agent.main_dqn.eval()
for _ in range(1):
    eval_steps = 0
    state = env.reset(streamline_index=streamline_index)
    next_state = state
    #state = env.reset()
    #print(state.getCoordinate())
    all_states.append(state.getCoordinate())
    #transition = init_transition()
    #all_states.append(torch.tensor(list(transition)[:3]))
    eval_episode_reward = 0
    episode_final = 0
    #print(env.referenceStreamline_ijk[:6])
    while eval_steps < max_episode_length:
        # The string below is disabled code (agent-driven action selection).
        '''
        #action = torch.argmax(main_dqn(torch.FloatTensor(state.getValue()).unsqueeze(0).to(device)))
        #action = env._get_best_action()
        with torch.no_grad():
            state_v = torch.from_numpy(state.getValue()).unsqueeze(0).float().to(device)
            action = torch.argmax(agent.main_dqn(state_v)).item()
        '''
        #action = get_best_action(next_state, env, dir_interpolator)
        
        my_position = all_states[-1]
        current_direction = None
        
        if(eval_steps > 0):
            # compute tangent of previous step
            current_direction = all_states[-1] - all_states[-2]
            current_direction = current_direction / torch.sqrt(torch.sum(current_direction**2))
            current_direction = current_direction.view(1,3)
        
        #action = get_multi_best_action(current_direction, odf_interpolator, my_position, mysphere, sphere_verts_torch, K = K)
        # NOTE(review): here `_get_best_action` receives two arguments, while the
        # replay pre-fill cell calls it without any — confirm the signature.
        action = env._get_best_action(current_direction, my_position)

        
        next_state, reward, terminal, _ = env.step(action)
        
        eval_episode_reward += reward
        print(eval_steps, action, next_state.getCoordinate().numpy(), reward)       
        eval_steps += 1
        
        #if(env.line.distance(geom.Point(next_state.getCoordinate())) > 0.1):
        #    print("We left our streamline. Switching to closest one")
        #    break
        l2s.append(env.l2_distance.detach().cpu().numpy())
        
        # NOTE(review): with max_episode_length = 10 this branch cannot be reached.
        if eval_steps == 1000:
            terminal = True
        all_distances.append(reward)
        all_states.append(next_state.getCoordinate())
        
        
        # Additional stop: terminate once the interpolated FA at the new
        # position drops below fa_threshold.
        fa_x = fa_interpolator(next_state.getCoordinate())
        if(fa_x < fa_threshold):
            print("fa_threshold reached.. terminated. %.2f" % (fa_x))
            terminal = True
        
        state = next_state
        if terminal:
            terminal = False
            break

    eval_rewards.append(eval_episode_reward)

# Reports the minimum (not the mean) episode reward; with one run both coincide.
print("Evaluation score:", np.min(eval_rewards))

########################
### visualise streamline
########################
# Plot the first max_episode_length+1 points of the reference streamline
# (line with star markers) against the coordinates visited during the rollout.
%matplotlib notebook 
# Reset so that env.referenceStreamline_ijk belongs to `streamline_index` again.
state = env.reset(streamline_index=streamline_index) 

# Stack the visited coordinates into one tensor; .T gives one row per axis.
states = torch.stack(all_states)

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot3D(env.referenceStreamline_ijk.T[0][0:max_episode_length+1], env.referenceStreamline_ijk.T[1][0:max_episode_length+1], env.referenceStreamline_ijk.T[2][0:max_episode_length+1], '-*')
ax.plot3D(states.T[0], states.T[1], states.T[2])

In [None]:
# NOTE(review): in the cells visible here `k_largest` is only assigned as a
# local inside get_multi_best_action_ODF, so this fails on a fresh kernel
# unless it is defined at notebook level elsewhere.
k_largest.shape

In [None]:
# NOTE(review): in the cells visible here `my_odf` is only assigned as a
# local inside get_multi_best_action_ODF, so this fails on a fresh kernel
# unless it is defined at notebook level elsewhere.
np.argmax(my_odf.squeeze())

In [None]:
# Debug inspection: list of coordinates collected by the rollout cell.
all_states

In [None]:
### visualise streamline
# Plot the first 10 points of the reference streamline against the first 10
# visited coordinates.
# NOTE(review): this resets the environment to streamline 50, while
# `all_states` holds whatever the rollout cell tracked (that cell uses its own
# `streamline_index`). Make sure both refer to the same streamline before
# comparing the two curves.
%matplotlib notebook 
state = env.reset(streamline_index=50) 

states = torch.stack(all_states)

fig = plt.figure()
ax = plt.axes(projection='3d')
ax.plot3D(env.referenceStreamline_ijk.T[0][0:10], env.referenceStreamline_ijk.T[1][0:10], env.referenceStreamline_ijk.T[2][0:10])
ax.plot3D(states.T[0][0:10], states.T[1][0:10], states.T[2][0:10])