In [10]:
import os
import sys

# Make the bundled deepFibreTracking checkout importable *before* any dfibert import.
sys.path.append('ext/deepFibreTracking/')

import gym
import numpy as np
import torch

import dfibert.envs.RLtractEnvironment as RLTe
from dfibert.data.exceptions import PointOutsideOfDWIError
from dfibert.tracker.nn.rl import Agent, Action_Scheduler

In [11]:
# --- Training hyperparameters (counts are environment steps unless noted) ---
max_steps = 3_000_000            # total training-step budget; also the scheduler's horizon
replay_memory_size = 5_000       # replay memory capacity; optimisation starts only after this many steps
agent_history_length = 1         # forwarded to Agent(agent_history_length=...)
evaluate_every = 20_000          # training steps between evaluation rounds
eval_runs = 5                    # episodes per evaluation round
network_update_every = 600       # steps between target-network weight syncs

max_episode_length = 200         # hard cap on steps per episode (training and evaluation)


In [12]:
# Compute device for the agent's networks: the GPU if present, otherwise the CPU.
# The tracking environment itself is always constructed on the CPU.
# NOTE(review): states from `env` may therefore sit on a different device than
# the agent's networks — confirm that Agent / Action_Scheduler move inputs as needed.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
env = RLTe.RLtractEnvironment(device='cpu')

# Number of discrete actions; used to size the Q-network output and the scheduler.
n_actions = env.action_space.n

Loading precomputed streamlines (data/HCP307200_DTI_smallSet.vtk) for ID 100307


In [15]:
print("Init agent")
state = env.reset()
agent = Agent(n_actions=n_actions, inp_size=state.getValue().shape[0], device=device, hidden=256,
              agent_history_length=agent_history_length, memory_size=replay_memory_size)

print("Init epsilon-greedy action scheduler")
action_scheduler = Action_Scheduler(num_actions=n_actions, max_steps=max_steps,
                                    replay_memory_start_size=replay_memory_size, model=agent.main_dqn)


def select_action(step, state, evaluation=False):
    """Pick an action for `state` through the action scheduler.

    If reading the state or querying the scheduler raises
    ``PointOutsideOfDWIError``, fall back to the last action index
    (``n_actions - 1``). Training and evaluation share this helper so both
    handle that case identically (previously evaluation had no fallback).

    Args:
        step: global step counter forwarded to the scheduler.
        state: environment state exposing ``getValue()``.
        evaluation: if True, request the scheduler's evaluation-mode action.

    Returns:
        The action index to pass to ``env.step``.
    """
    try:
        observation = state.getValue().unsqueeze(0)  # prepend a size-1 (batch) dimension
        if evaluation:
            return action_scheduler.get_action(step, observation, evaluation=True)
        return action_scheduler.get_action(step, observation)
    except PointOutsideOfDWIError:
        # NOTE(review): presumably the last action ends the streamline — confirm in RLtractEnvironment.
        return n_actions - 1


step_counter = 0
rewards = []  # summed reward of every finished training episode

print("Start training...")
while step_counter < max_steps:
    epoch_step = 0

    ######## training: interact with env, fill replay memory, optimise
    # Also stop when the global step budget runs out mid-epoch.
    while epoch_step < evaluate_every and step_counter < max_steps:
        state = env.reset()
        episode_reward_sum = 0

        for _ in range(max_episode_length):
            action = select_action(step_counter, state)
            next_state, reward, terminal = env.step(action)

            step_counter += 1
            epoch_step += 1
            episode_reward_sum += reward

            agent.replay_memory.add_experience(action=action,
                                               state=state,
                                               reward=reward,
                                               terminal=terminal)
            state = next_state

            # Optimise only after `replay_memory_size` warm-up steps.
            if step_counter > replay_memory_size:
                loss = agent.optimize()  # kept at notebook scope for inspection
                # Periodically copy the online network's weights into the target network.
                if step_counter % network_update_every == 0:
                    agent.target_dqn.load_state_dict(agent.main_dqn.state_dict())

            if terminal:
                # Episode ended early. Do NOT zero episode_reward_sum here: it is
                # recorded right after this loop, and the next episode starts
                # with its own env.reset().
                break

        rewards.append(episode_reward_sum)
        if len(rewards) % 10 == 0:
            print("[{}] {}, {}".format(len(rewards), step_counter, np.mean(rewards[-100:])))

    ########## evaluation: scheduler in evaluation mode, no learning or memory writes
    eval_rewards = []
    for _ in range(eval_runs):
        state = env.reset()
        eval_episode_reward = 0
        for _ in range(max_episode_length):
            action = select_action(step_counter, state, evaluation=True)
            next_state, reward, terminal = env.step(action)
            eval_episode_reward += reward
            state = next_state
            if terminal:
                break
        eval_rewards.append(eval_episode_reward)

    print("Evaluation score:", np.mean(eval_rewards))


KeyboardInterrupt: 

In [None]:
# Save the online network's weights. The file name is tagged with the current
# step count and the mean reward over the last 100 training episodes.
checkpoint_dir = 'checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)  # pure-Python replacement for `!mkdir -p`

# Guard against an empty reward history (np.mean([]) warns and returns nan).
recent_reward = np.mean(rewards[-100:]) if rewards else float('nan')
checkpoint_path = os.path.join(
    checkpoint_dir,
    'fiber_agent_{}_reward_{:.2f}.pth'.format(step_counter, recent_reward),
)
torch.save(agent.main_dqn.state_dict(), checkpoint_path)