# Reinforcement Learning- and FEM-based Inverse Design

## Experiment Logger

In [1]:
import os
import neptune.new as neptune
from neptune.new.types import File
from neptune.new.utils import stringify_unsupported

# Neptune project / notebook identifiers, exported as environment variables
# (presumably picked up when the Neptune run is created later — confirm).
os.environ.update({
    'NEPTUNE_PROJECT': "pil-clemson/metamtl-rl-test",
    'NEPTUNE_NOTEBOOK_ID': "45d03d69-6ac7-41ca-8af8-80caaa73aad5",
    'NEPTUNE_NOTEBOOK_PATH': "metamaterial-rl/RemoteFEM-DQN.ipynb",
})

# Placeholder for the Neptune run handle; code below logs through `exp[...]`,
# so it must be assigned a real run before training.
exp = None

In [2]:
experiment_repeat: int = 5  # number of experiment repetitions

In [3]:
tags = 'Cloak RandInit Dev'.split()  # experiment tags

## Imports

In [4]:
from __future__ import annotations
from typing import Union, Optional, Callable, Any
from typing import Tuple, List, Set, Dict
from typing import NamedTuple
from typing import Generator

In [5]:
from collections import defaultdict, deque
from types import SimpleNamespace
import queue
from queue import PriorityQueue
from enum import Enum

In [6]:
from dataclasses import dataclass, field

In [7]:
import traceback
import tracemalloc

In [8]:
import ipywidgets as widgets
from IPython.display import clear_output

In [9]:
import os
import sys
import copy
import time
from datetime import datetime, timedelta
from pprint import pformat, pprint
import multiprocessing
import random
import math
import itertools
import uuid

In [10]:
from tqdm.notebook import trange, tqdm

In [11]:
import matplotlib.pyplot as plt

In [12]:
import plotly.express as px

In [13]:
import torch
from torch import nn

from torch import Tensor, BoolTensor

from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from torchvision.transforms import PILToTensor

# Record the PyTorch version in the notebook output for reproducibility.
print(' '.join(['PyTorch version:', str(torch.__version__)]))

PyTorch version: 1.13.0


In [14]:
import torchinfo

In [15]:
import numpy as np

In [16]:
from skimage.measure import block_reduce

In [17]:
from SimHubClient import SimHubClient

In [18]:
# Tensors / modules created without an explicit dtype use float64, matching the
# `.double()` / `dtype=torch.double` used for states and networks.
torch.set_default_dtype(torch.float64)

## Computing Devices

In [19]:
cpu_core_count = multiprocessing.cpu_count()  # CPUs reported by the OS
print('CPU Cores:', cpu_core_count)

CPU Cores: 56


In [20]:
# Total physical memory = page size * number of physical pages, queried with
# os.sysconf (POSIX only; not available on Windows).
mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')
mem_gib = mem_bytes / 1024. ** 3
print('Memory size:', int(mem_gib), 'GiB')

Memory size: 376 GiB


In [21]:
# Names of all CUDA devices visible to PyTorch (empty list when none).
available_gpus = []
for device_index in range(torch.cuda.device_count()):
    available_gpus.append(torch.cuda.get_device_name(device_index))
print('GPUs:', available_gpus)

GPUs: ['Tesla V100S-PCIE-32GB', 'Tesla V100S-PCIE-32GB']


In [22]:
# Computation is pinned to the CPU on purpose, even when CUDA is available:
# states are converted with plain torch.tensor(...) (no device argument), so
# placing the networks on a GPU would cause device-mismatch errors.
# The variable keeps the name `cuda` because the rest of the notebook uses it.
cuda = torch.device('cpu')
print('Current computing device:', cuda)

Current computing device: cpu


## Debug Flags

In [23]:
# Debug switches live in the `DEBUG` namespace class defined in the next cell.

In [24]:
class DEBUG:
    """Namespace of debug / logging switches, accessed as class attributes."""

    # Per-episode result generation and its visualisation
    result_generation: bool = True
    result_visualization: bool = False

    # Transition text logs; the agent accumulates lines in these buffers and
    # flushes them to files at the end of each episode.
    transition_log: bool = True
    transition_log_buffer: str = ''
    transition_log_buffer_gen: str = ''

    # Metric logging switches
    action_log: bool = False
    prediction_log: bool = False
    state_log: bool = True
    epsilon_log: bool = True
    reward_log: bool = True
    state_visualization: bool = False
    state_target_diff_visualization: bool = False

    visualization_sampling_rate: float = 0.0

    optimizer_sample_log: bool = False

    start_from_goal: bool = False

    trace_memory: bool = False

    # Set to True by the agent when it enters generation mode.
    in_generation_mode: bool = False

    transition_checked = None

In [25]:
if DEBUG.trace_memory:
    # PYTHONTRACEMALLOC is only read at interpreter start-up, so setting it here
    # only affects child processes started later. The frame depth for *this*
    # process must be passed to tracemalloc.start() explicitly.
    os.environ['PYTHONTRACEMALLOC'] = '3'
    tracemalloc.start(3)

## Helper Functions

In [26]:
def clip(x, l, u):
    """Clamp ``x`` to ``[l, u]``; the lower bound is checked first."""
    if x < l:
        return l
    if x > u:
        return u
    return x

## Hyperparameters

In [27]:
# Environment / FEM-result settings.
environment_config = dict(
    rings=3,                        # number of adjustable rings
    result_size=(400, 400),         # shape each simulated temperature field is reshaped to
    result_range=(293.15, 353.15),  # expected field value range (presumably for plotting)
)

# DQN / episode hyperparameters.
hyperparameters = dict(
    # Training schedule & replay
    target_update_interval=10,          # copy policy -> target whenever total_steps is a multiple of this
    optimization_iterations=1,          # gradient updates per environment step
    experience_replay_capacity=10000,
    replay_batch_size=32,
    lr=0.001,                           # NOTE(review): confirm this reaches QNet; its Adam lr defaults to 0.001
    discount_factor=0.9,

    # Epsilon-greedy exploration
    epsilon_initial=1.0,
    epsilon_minimal=0.1,
    epsilon_halflife=2000,              # training steps for epsilon to halve
    epsilon_boost_preterminal=0.3,      # added to epsilon until a terminal state has been reached
    epsilon_generate=0.1,               # fixed epsilon while generating results

    # Episode length
    max_episode=300,
    max_step_per_episode=500,

    # Rewards & termination
    goal_reward=10000.0,
    terminal_threshold=[0.1, 0.5],      # success when [dT ratio, neutrality] are both <= these
    invalid_state_penalty=0.0,          # reward when the state's simulation did not finish
    failed_episode_penalty=-10.0,       # reward on step-limit / low-value terminals

    # Early stop when stuck in a low-value region
    low_value_earlier_stop_min_step=150,
    low_value_earlier_stop_threshold=[0.3, 1.5],
    low_value_earlier_stop_step_count=10,

    reward_func=['-logx-0.5(norm,1.4,1.7)', '-logx'],  # applied to [dT ratio, neutrality], then summed
)




## Reinforcement Learning Environment

## Interfaces and Dataclasses

In [28]:
class State(dict):
    """
    Base class for environment states: a dict of design parameters.

    Subclasses must implement ``to_tensor`` to produce the network input.
    """

    def __init__(self) -> None:
        super().__init__()

    def step(self, action: 'Action') -> 'State':
        """Apply ``action`` to a deep copy of this state and return what the action returns."""
        successor = copy.deepcopy(self)
        return action(successor)

    def to_tensor(self) -> Tensor:
        """Encode the state as a tensor (abstract)."""
        raise NotImplementedError

In [29]:
class Action:
    """A named, callable state transformation; ``repr`` is its name."""

    def __init__(self, name: str, action: Callable[[State], State]) -> None:
        self.name, self.action = name, action

    def __call__(self, state: State) -> State:
        transform = self.action
        return transform(state)

    def __repr__(self) -> str:
        return self.name

In [30]:
class Environment:
    """
    Abstract RL environment holding the current state and a discrete action space.

    Subclasses fill ``_action_space`` (list index = action id) and implement
    ``reset`` and ``step``.
    """

    def __init__(self) -> None:
        self._state: State = None
        self._action_space: List[Action] = []
        self._valid_actions: BoolTensor = None  # not read by any code in this notebook section

    def __repr__(self) -> str:
        return f'''{self.__class__.__name__}(
    Action space size: {self.action_count()}
    Current state: {self.state}
)'''

    @property
    def state(self) -> State:
        """The current state (``None`` until a subclass sets it)."""
        return self._state

    @property
    def action_space(self) -> List[Action]:
        """Ordered list of actions; indices are the action ids."""
        return self._action_space

    def action_count(self) -> int:
        """Number of discrete actions."""
        return len(self._action_space)

    def reset(self) -> None:
        """Start a new episode (abstract)."""
        raise NotImplementedError

    def step(self, action_index: int) -> None:
        """Apply the action with id ``action_index`` (abstract)."""
        raise NotImplementedError


In [31]:
class ReplayTransition(NamedTuple):
    """One experience-replay entry; built positionally, so field order matters."""

    state: State          # current state (the agent stores state.to_tensor())
    action_index: int     # id of the action taken
    reward: float
    next_state: State     # next state tensor, or None when `state` was terminal
    note: str             # free-form tag; the agent uses '<episode>-<step>'

In [32]:
@dataclass
class SimulationTransition:
    """
    One (state, action, next_state) step plus its FEM simulation data and reward.

    Filled in stages: the agent sets the stepping fields and simulator task
    ids, the simulation results are attached once the tasks finish, and the
    reward function fills in the figures of merit, terminal flags and reward.
    """
    # From stepping
    episode: int
    step: int

    state: State
    action_index: int
    next_state: State

    action_name: str = None
    action_type: str = None   # 'Prediction' or 'Random'

    # From FEM simulator
    state_id: str = None
    next_state_id: str = None

    state_sim: Dict[str, Any] = None
    next_state_sim: Dict[str, Any] = None

    # From reward function; the *_sim_value fields hold [dT ratio, neutrality].
    state_sim_value: Optional[List[float]] = None
    is_state_terminal: bool = None
    state_terminal_type: int = None   # 0 = goal, 1 = step limit, 2 = low-value stop

    next_state_sim_value: Optional[List[float]] = None
    is_next_state_terminal: bool = None
    next_state_terminal_type: int = None

    reward: float = None

    def __repr__(self) -> str:
        # Each side of the arrow reports its own state's value and terminal flag.
        loop_marker = "Loop!" if self.state_id == self.next_state_id else ""
        return (f'[{self.episode}-{self.step}] '
                f'{self.state}({self.state_sim_value}, {self.is_state_terminal})'
                f' =={self.action_name}({self.action_type})==> '
                f'{self.next_state}({self.next_state_sim_value}, {self.is_next_state_terminal})'
                f'  R:{self.reward} {loop_marker}')

In [33]:
# Type of a reward function: takes a transition with simulation results
# attached and returns it (FEMReward, for example, fills in values, terminal
# flags and reward on the same object and returns it).
RewardFunc = Callable[[SimulationTransition], SimulationTransition]

## State and Environment

In [34]:
class HarvestRingState(State):
    """
    Randomly initialised design state of the ring structure.

    ``self['r']``: 5 sorted radii ``[50, r1, r2, r3, 100]``. The three inner
    radii are distinct multiples of 5 sampled from 55..95, the range the
    radius actions clip to.
    ``self['k']``: 6 values ``[0, k1, k2, k3, k4, 10]``. The four adjustable
    values are multiples of 5 sampled from 0..60, the range the k actions
    clip to.
    """

    def __init__(self) -> None:
        super().__init__()
        # # Fixed start (alternative)
        # self['r'] = [50, 60., 70., 80., 100.]
        # self['k'] = [0., 30., 30., 30., 30., 10.]

        # Random start: radii are sampled first, then the k values.
        inner_radii = random.sample([float(x) for x in range(55, 100, 5)], k=3)
        self['r'] = sorted(inner_radii + [50., 100.])

        # randint(0, 12) * 5 keeps every initial k inside 0..60, the clip range of
        # the k-adjustment actions (a value of 65 could never be reached again
        # once the agent moved away from it).
        self['k'] = [0.] + [float(random.randint(0, 12) * 5) for _ in range(4)] + [10.]

    def to_tensor(self) -> torch.Tensor:
        """Concatenate radii and k values into a 1-D float64 tensor (5 + 6 = 11 elements)."""
        return torch.cat([torch.tensor(self['r']), torch.tensor(self['k'])]).double()

In [35]:
# Sanity check: a freshly sampled state encodes to an 11-element tensor (5 radii followed by 6 k values).
HarvestRingState().to_tensor()

tensor([ 50.,  55.,  65.,  90., 100.,   0.,  30.,  25.,  35.,   5.,  10.])

In [36]:
class HarvRingEnvironment(Environment):
    """
    Ring-structure design environment with 14 discrete actions.

    Action ids (the append order fixes them):
      0-5  : radius of ring 1, 2, 3 by +5 / -5, clipped to [55, 95]. A move
             that would make the radius equal to a neighbouring radius is
             cancelled (the state is left unchanged).
      6-13 : k of ring 1, 2, 3, 4 by +5 / -5, clipped to [0, 60].
    """

    def __init__(self) -> None:
        super().__init__()

        self.rings = environment_config['rings']
        self.reset()

        def adjust_ring_r(ring: int, r_mod: float, lower_bound: float, upper_bound: float):
            def action(state: State):
                current = state['r'][ring]
                proposed = clip(current + r_mod, lower_bound, upper_bound)
                neighbours = (state['r'][ring - 1], state['r'][ring + 1])
                # Radii must stay distinct; a move onto a neighbour is cancelled.
                state['r'][ring] = current if proposed in neighbours else proposed
                return state
            return Action(f'{ring}:r{r_mod}', action)

        def adjust_ring_k(ring, k_mod: float, lower_bound: float, upper_bound: float):
            def action(state: State):
                state['k'][ring] = clip(state['k'][ring] + k_mod, lower_bound, upper_bound)
                return state
            return Action(f'{ring}:k{k_mod}', action)

        for ring in (1, 2, 3):
            for delta in (+5., -5.):
                self._action_space.append(adjust_ring_r(ring, delta, 55., 95.))

        for ring in (1, 2, 3, 4):
            for delta in (+5., -5.):
                self._action_space.append(adjust_ring_k(ring, delta, 0., 60.))

        # TBD: adjust k of the board (region 3)

    def reset(self) -> None:
        """Replace the current state with a freshly sampled HarvestRingState."""
        self._state = HarvestRingState()

    def step(self, action_index: int) -> None:
        """Apply action ``action_index`` to a copy of the current state and make the result current."""
        self._state = self._state.step(self._action_space[action_index])

## DQN

### FEM-based Reward & Terminal Function

In [37]:
def badloe(reward_thresholds: List[float]):
    """
    Build a piecewise reward of a scalar ``eta``.

    ``reward_thresholds`` is ``(low, mid, high)``. ``eta < low`` yields -10,
    otherwise ``eta > high`` yields 10000, otherwise ``(eta / mid) ** 9 - 1``.
    """
    low, mid, high = reward_thresholds

    def func(eta):
        return -10 if eta < low else 10000 if eta > high else (eta / mid) ** 9 - 1

    return func


def _neg_log(x):
    """-ln(x), the common core of the log-shaped rewards."""
    return -np.log(x)


# Reward shaping functions, selected by name through hyperparameters['reward_func'].
reward_funcs = {
    'linear': lambda x: x,
    '-logx': _neg_log,
    '-logx-0.5': lambda x: _neg_log(x) - 0.5,
    # NOTE(review): the key mentions 1.4 but the offset used is 0.9 — confirm.
    '-logx-0.5(norm,1.4,1.7)': lambda x: (_neg_log(x) + 0.9) / 1.7 - 1,
    # 'badloe': badloe(hyperparameters['reward_thresholds']),
}

class FEMReward():
    """
    Reward / terminal function computed from FEM temperature-field results.

    A finished simulation's field (item ``[2]`` of
    ``output['temperature_distribution']``, reshaped to
    ``environment_config['result_size']``) is reduced to two "lower is better"
    figures of merit:

    * ``[0]``: ``|T[c0, c1 - r0] - T[c0, c1 + r0]| / (0.3 * r0)``. This is the
      temperature difference between the two field points in the centre row
      ``c0`` at column offset ``±r0`` from the centre column ``c1``, relative
      to ``0.3 * r0``. ``r0`` is the innermost radius ``state['r'][0]``.
    * ``[1]``: the mean absolute deviation of the whole field from the
      reference field stored at ``reference_path``.

    A state is a successful terminal when both values are at or below
    ``terminal_threshold``. The reward of a transition is
    ``reward_func[0](value[0]) + reward_func[1](value[1])``, computed from the
    *current* state. It is replaced by ``goal_reward`` or
    ``failed_episode_penalty`` on terminal states.
    """

    # Reference temperature field for the neutrality metric; loaded once, on first use.
    reference_path: str = 'ref400.npy'

    def __init__(self,
                 hyperparameters: Dict[str, Any]) -> None:

        self.reward_func = [reward_funcs[func] for func in hyperparameters['reward_func']]

        self.max_step_per_episode = hyperparameters['max_step_per_episode']

        self.goal_reward = hyperparameters['goal_reward']
        self.terminal_threshold = hyperparameters['terminal_threshold']
        self.invalid_state_penalty = hyperparameters['invalid_state_penalty']
        self.failed_episode_penalty = hyperparameters['failed_episode_penalty']

        # Episode currently being scored; the low-value counter is reset when it changes.
        self.current_episode = 0

        self.low_value_min_step = hyperparameters['low_value_earlier_stop_min_step']
        self.low_value_threshold = hyperparameters['low_value_earlier_stop_threshold']
        self.low_value_step_threshold = hyperparameters['low_value_earlier_stop_step_count']
        self.low_value_step_count = 0

    def __call__(self, transition: SimulationTransition) -> SimulationTransition:
        """
        Fill in the figures of merit, terminal flags and reward of ``transition``.

        Parameters
        ----------
        transition : SimulationTransition
            A transition whose ``state_sim`` / ``next_state_sim`` are attached.
            Simulations whose ``status`` is not ``'done'`` are ignored.

        Returns
        -------
        SimulationTransition
            The same object, updated in place (``state_sim_value``,
            ``is_state_terminal``, ``state_terminal_type``,
            ``next_state_sim_value``, ``is_next_state_terminal``, ``reward``).

        Terminal types: 0 = goal reached; 1 = step limit; 2 = the state was
        low-valued (both values above ``low_value_earlier_stop_threshold``) for
        ``low_value_earlier_stop_step_count`` consecutive steps after
        ``low_value_earlier_stop_min_step``.

        Raises
        ------
        Propagates errors from ``np.load`` (e.g. a missing reference file) and
        from the simulation-result lookup (e.g. ``KeyError`` for a malformed
        result).
        """
        if transition.episode != self.current_episode:
            self.low_value_step_count = 0
            self.current_episode = transition.episode

        cloaked_radius = int(transition.state['r'][0])

        if transition.state_sim and transition.state_sim['status'] == 'done':
            transition.state_sim_value = self._sim_value(transition.state_sim, cloaked_radius)
            transition.is_state_terminal = self._meets_terminal_threshold(transition.state_sim_value)

        if transition.next_state_sim and transition.next_state_sim['status'] == 'done':
            # Thresholds are checked element-wise on the *next* state's own values.
            transition.next_state_sim_value = self._sim_value(transition.next_state_sim, cloaked_radius)
            transition.is_next_state_terminal = self._meets_terminal_threshold(transition.next_state_sim_value)

        if transition.state_sim_value:
            dT_reward = self.reward_func[0](transition.state_sim_value[0])
            mv_reward = self.reward_func[1](transition.state_sim_value[1])
            transition.reward = dT_reward + mv_reward
            exp['reward_dT'].append(dT_reward)
            exp['reward_Mv'].append(mv_reward)
            exp[f'reward_total'].append(transition.reward)
            exp[f'reward_total_episode/{transition.episode}'].append(transition.reward)
        else:
            transition.reward = self.invalid_state_penalty

        # Reward only for primary terminal condition
        if transition.is_state_terminal:
            transition.reward = self.goal_reward
            transition.state_terminal_type = 0  # 0 for success episode
        # Alt-term 1 (after X step)
        elif transition.step >= self.max_step_per_episode - 1:
            transition.is_state_terminal = True
            transition.state_terminal_type = 1
            transition.reward = self.failed_episode_penalty
        # Alt-term 2 (trapped in low-val area); skipped when the state's
        # simulation did not finish, since there is no value to judge.
        elif transition.step >= self.low_value_min_step and transition.state_sim_value is not None:
            if transition.state_sim_value[0] > self.low_value_threshold[0] \
                    and transition.state_sim_value[1] > self.low_value_threshold[1]:
                self.low_value_step_count += 1
                exp['low_val'].append(f'E{transition.episode} Low value step {self.low_value_step_count} {transition.state_sim_value}')
                if self.low_value_step_count >= self.low_value_step_threshold:
                    transition.is_state_terminal = True
                    transition.state_terminal_type = 2
                    transition.reward = self.failed_episode_penalty
                    exp['low_val'].append(f'E{transition.episode} Low value terminal {self.low_value_step_count}')
            else:
                if self.low_value_step_count > 0:
                    exp['low_val'].append(f'E{transition.episode} Low value count reset')
                self.low_value_step_count = 0

        # LOGGING
        if transition.state_sim and transition.state_sim['status'] == 'done':
            if DEBUG.state_log:
                if not DEBUG.in_generation_mode:
                    exp['state_values_dT'].append(transition.state_sim_value[0])
                    exp['state_values_Mv'].append(transition.state_sim_value[1])
                else:
                    exp['state_values_gen_dT'].append(transition.state_sim_value[0])
                    exp['state_values_gen_Mv'].append(transition.state_sim_value[1])

        if DEBUG.reward_log and not DEBUG.in_generation_mode:
            exp['reward'].append(transition.reward)

        return transition

    def _meets_terminal_threshold(self, sim_value: Optional[List[float]]) -> bool:
        """True when both figures of merit are at or below their terminal thresholds."""
        return sim_value is not None \
            and sim_value[0] <= self.terminal_threshold[0] \
            and sim_value[1] <= self.terminal_threshold[1]

    def _sim_value(self, sim: Dict[str, Any], cloaked_radius: int) -> List[float]:
        """Reduce one finished simulation result to ``[dT ratio, neutrality]``."""
        size = environment_config['result_size']
        center = (int(size[0] / 2), int(size[1] / 2))
        t_dist = sim['output']['temperature_distribution'][2].reshape(size)

        t_a = t_dist[center[0], center[1] - cloaked_radius]
        t_b = t_dist[center[0], center[1] + cloaked_radius]
        t_delta = np.abs(t_a - t_b)

        # Mean over all size[0] * size[1] field points.
        neutrality = np.sum(np.abs(t_dist - self._reference_distribution())) / (size[0] * size[1])

        return [t_delta / (.3 * cloaked_radius), neutrality]

    def _reference_distribution(self) -> np.ndarray:
        """Reference temperature field, loaded from ``reference_path`` on first use and cached."""
        reference = getattr(self, '_t_dist_ref', None)
        if reference is None:
            reference = np.load(self.reference_path)
            self._t_dist_ref = reference
        return reference

### Network Definition

In [38]:
class Model():
    """Container bundling a network with its loss function and optimizer; calling it runs the network."""

    def __init__(self, network: nn.Module, loss_func: _Loss, optimizer: Optimizer):
        self.network, self.loss_func, self.optimizer = network, loss_func, optimizer

    def __call__(self, network_input: Tensor) -> Tensor:
        forward = self.network
        return forward(network_input)

In [39]:
def QNet(state_size: int = 11, action_number: int = 14, target_network: bool = False,
         lr: float = 0.001):
    """
    Build the fully connected Q-network (state_size -> 128 -> 256 -> action_number, ReLU).

    Parameters
    ----------
    state_size : int
        Length of the flattened state tensor (11 for HarvestRingState).
    action_number : int
        Number of discrete actions (14 for HarvRingEnvironment).
    target_network : bool
        If True, return the network without loss or optimizer; the target
        network is only updated by copying the policy weights.
    lr : float
        Adam learning rate for the policy network (default 0.001). Pass
        ``hyperparameters['lr']`` so the configured value takes effect.

    Returns
    -------
    Model
        The network wrapped with ``nn.HuberLoss`` and ``Adam`` (or with
        ``None`` for both when ``target_network`` is True).
    """
    net = nn.Sequential(
        nn.Linear(state_size, 128, device=cuda, dtype=torch.double),
        nn.ReLU(),
        nn.Linear(128, 256, device=cuda, dtype=torch.double),
        nn.ReLU(),
        nn.Linear(256, action_number, device=cuda, dtype=torch.double),
    )
    if target_network:
        return Model(network=net, loss_func=None, optimizer=None)
    # exp['Network'] = str(torchinfo.summary(net, input_size=(32, state_size),
    #                                        device=cuda, verbose=0))
    return Model(network=net, loss_func=nn.HuberLoss(), optimizer=torch.optim.Adam(net.parameters(), lr))

### Replay Memory Class

In [40]:
class ReplayMemory():
    """Fixed-capacity FIFO experience buffer; the oldest entries are evicted once full."""

    def __init__(self, capacity):
        self.memory: deque = deque(maxlen=capacity)

    def push(self, *args):
        """Store one entry; ``args`` are the ReplayTransition fields in order."""
        entry = ReplayTransition(*args)
        self.memory.append(entry)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct entries drawn uniformly without replacement."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

### Agent Class

In [41]:
class Agent():
    def __init__(self, environment: Environment, simulator: SimHubClient, reward_func: RewardFunc, 
                 policy_network: nn.Module, target_network: nn.Module, hyperparameters: Dict[str, Any]) -> None:
        self.environment: Environment = environment
        self.fem_simulator: SimHubClient = simulator
        self.fem_reward_func: RewardFunc = reward_func
        
        self.policy_network: Model = policy_network
        self.target_network: Model = target_network  
        
        self.target_update_interval: int = hyperparameters['target_update_interval']

        self.optimization_iterations: int = hyperparameters['optimization_iterations']
        self.max_step_per_episode: int = hyperparameters['max_step_per_episode']
        self.experience_replay: ReplayMemory = ReplayMemory(hyperparameters['experience_replay_capacity'])
        self.replay_batch_size: int = hyperparameters['replay_batch_size']

        self.discount_factor: float = hyperparameters['discount_factor']
        self.epsilon_initial: float = hyperparameters['epsilon_initial']
        self.epsilon_minimal: float = hyperparameters['epsilon_minimal']
        self.epsilon_halflife: float = hyperparameters['epsilon_halflife']
        self.epsilon_generate: float = hyperparameters['epsilon_generate']
        
        self.epsilon_boost_preterminal: float = hyperparameters['epsilon_boost_preterminal']
        
        
        
        self.pending_transitions: List[SimulationTransition] = list()
        
        # Set to true when generating result
        self.generation_mode: bool = False
        self.explored_step: int = 0
        
        self.total_steps: int = 0
        
        self.episode: int = 0
        self.step_num: int = 0
        
        # Logging result of episode and boost epsilon when needed
        self.previous_episode_terminal: List[bool] = list()
        self.terminal_reached: bool = False
        
        self.convergence_episode: int = 0
        self.convergence_step: int = 100000
        self.convergence_episode_gen: int = 0
        self.convergence_step_gen: int = 100000

        
    def updaet_action_mask(self) -> None:
        ...
        
    
    def select_action(self) -> Tuple[State, int, str]:
        """
        Choose an action for the current state with an epsilon-greedy policy.

        Epsilon decays exponentially with ``total_steps`` (half-life
        ``epsilon_halflife``) down to ``epsilon_minimal``. It is increased by
        ``epsilon_boost_preterminal`` while ``terminal_reached`` is False, and
        replaced by ``epsilon_generate`` in generation mode. An epsilon of 1 or
        more means the action is always random.

        Returns
        -------
        State
            The current state instance (not copied).
        int
            Index of the chosen action in the environment's action space.
        str
            ``"Prediction"`` (greedy w.r.t. the policy network) or ``"Random"``.
        """
        state = self.environment.state

        # Determining epsilon
        epsilon = max(self.epsilon_initial * (0.5 ** (self.total_steps / self.epsilon_halflife)),
                      self.epsilon_minimal)
        if not self.terminal_reached:
            epsilon += self.epsilon_boost_preterminal
        if self.generation_mode:
            epsilon = self.epsilon_generate

        if DEBUG.epsilon_log:
            log_step = self.episode + self.step_num / self.max_step_per_episode
            if self.generation_mode:
                exp['epsilon_gen'].append(epsilon, step=log_step)
            else:
                exp['epsilon'].append(epsilon, step=log_step)

        if random.random() > epsilon:
            # Pure inference: no autograd graph is needed to pick an action.
            with torch.no_grad():
                prediction = self.policy_network(state.to_tensor().flatten()).flatten()

            if DEBUG.prediction_log:
                if not self.generation_mode:
                    log_target = 'prediction'
                else:
                    log_target = 'prediction_gen'
                exp[f'{log_target}/{self.episode}'].append(f'Step {self.step_num}', step=self.step_num)
                exp[f'{log_target}/{self.episode}'].append(str(state), step=self.step_num+0.1)
                preds = []
                for i in range(len(prediction)):
                    preds.append((prediction[i].item(), str(self.environment.action_space[i])))
                preds.sort(reverse=True)
                exp[f'{log_target}/{self.episode}'].append(pformat(preds), step=self.step_num+0.2)

            action_index = prediction.argmax().item()
            action_type = 'Prediction'
        else:
            action_index = random.randrange(len(self.environment.action_space))
            action_type = 'Random'
            self.explored_step += 1

        if not self.generation_mode:
            self.total_steps += 1
        return state, action_index, action_type
    
    def step(self) -> SimulationTransition:
        """
        Take one epsilon-greedy step in the environment and queue its FEM work.

        The states before and after the action are both submitted to the FEM
        simulator. The transition, holding the returned task ids, is appended
        to ``pending_transitions`` so ``compute_pending_rewards`` can score it
        later.

        Returns
        -------
        SimulationTransition
            The newly queued transition (simulation results not yet attached).
        """
        state, action_index, action_type = self.select_action()
        self.environment.step(action_index)

        transition = SimulationTransition(self.episode, self.step_num, state, action_index,
                                          self.environment.state)
        # submit_task returns a pair; only its first item (the task id) is kept.
        transition.state_id, _ = self.fem_simulator.submit_task(state)
        transition.next_state_id, _ = self.fem_simulator.submit_task(transition.next_state)
        transition.action_type = action_type
        transition.action_name = self.environment.action_space[action_index].name

        self.pending_transitions.append(transition)
        return transition
    
    def compute_reward(self, transition: SimulationTransition) -> None:
        """
        Score a transition whose simulation results are attached, and record it.

        The reward function fills in values, terminal flags and reward on
        ``transition``. The transition's text form is appended to the matching
        DEBUG buffer. Outside generation mode it is also pushed to experience
        replay as tensors, with ``next_state`` replaced by ``None`` when the
        current state is terminal.
        """
        self.fem_reward_func(transition)

        if DEBUG.transition_log:
            line = str(transition) + '\n'
            if self.generation_mode:
                DEBUG.transition_log_buffer_gen += line
            else:
                DEBUG.transition_log_buffer += line

        if self.generation_mode:
            return

        next_state_tensor = None if transition.is_state_terminal else transition.next_state.to_tensor()
        self.experience_replay.push(transition.state.to_tensor(),
                                    transition.action_index,
                                    transition.reward,
                                    next_state_tensor,
                                    f'{transition.episode}-{transition.step}')

        
    def compute_pending_rewards(self) -> Tuple[SimulationTransition, float]:
        episode_return: float = 0.
        for transition in tqdm(self.pending_transitions):
            terminal_transition = transition
            
            transition.state_sim = self.fem_simulator.wait_for_task(transition.state_id)
            transition.next_state_sim = self.fem_simulator.wait_for_task(transition.next_state_id)
        
            self.compute_reward(transition)
            
            episode_return = episode_return * self.discount_factor + transition.reward;
            
            if transition.is_state_terminal:
                self.terminal_reached = True
                if not self.generation_mode:
                    exp['terminal_type'].append(transition.state_terminal_type, step=transition.episode)
                else:
                    exp['terminal_type_gen'].append(transition.state_terminal_type, step=transition.episode)
                if transition.step < self.max_step_per_episode - 1: 
                    break

                
        self.pending_transitions.clear()
        return transition, episode_return
        
    def optimize(self) -> None:
        if len(self.experience_replay) < self.replay_batch_size: return

        for i in range(self.optimization_iterations):
            samples = self.experience_replay.sample(self.replay_batch_size)
            batch = ReplayTransition(*zip(*samples))
            
            if DEBUG.optimizer_sample_log:
                filename = f'logs/sampled_transition-{self.total_steps + i / self.optimization_iterations}.log'
                with open(filename, 'w') as fp:
                    pprint(samples, stream=fp)
                exp['sampled_transition'].upload_files(filename)

            non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                                      batch.next_state)), device=cuda, dtype=torch.bool)
            # If none of the transition has a valid next_step, skip the round
            if not non_final_mask.any():
                return
            non_final_next_states = torch.stack([s.flatten() for s in batch.next_state
                                                            if s is not None])

            state_batch = torch.stack([s.flatten() for s in batch.state])
            action_batch = torch.tensor(batch.action_index, device=cuda).unsqueeze(1)
            reward_batch = torch.tensor(batch.reward, device=cuda)

            state_action_values = self.policy_network(state_batch).gather(1, action_batch)

            next_state_values = torch.zeros(self.replay_batch_size, device=cuda)
            next_state_values[non_final_mask] = self.target_network(non_final_next_states).max(1)[0].detach()

            expected_state_action_values = (next_state_values * self.discount_factor) + reward_batch

            loss = self.policy_network.loss_func(state_action_values, expected_state_action_values.unsqueeze(1))
            optimization_loss = float(loss)
            self.policy_network.optimizer.zero_grad()
            loss.backward()
            for param in self.policy_network.network.parameters():
                param.grad.data.clamp_(-1, 1)
            self.policy_network.optimizer.step()
            
            exp['optimization_loss'].append(optimization_loss, step=self.total_steps + i / self.optimization_iterations)
        
    def update_target_network(self) -> None:
        self.target_network.network.load_state_dict(self.policy_network.network.state_dict())
        
    def train(self, episodes: int) -> None:
        """
        Train for ``episodes`` episodes.

        Each episode runs these phases in order:

        1. Reset the environment.
        2. Take ``max_step_per_episode`` steps. After every step, optimise the
           policy network, and copy its weights into the target network
           whenever ``total_steps`` is a multiple of ``target_update_interval``.
        3. Wait for and score all queued FEM transitions, then log the episode
           metrics.
        4. Optionally run a generation pass.
        5. Write the debug logs and save a checkpoint of the policy model.

        Logs go under ``logs/`` and checkpoints under ``checkpoints/``; both
        directories are created if missing.
        """
        os.makedirs('logs', exist_ok=True)
        os.makedirs('checkpoints', exist_ok=True)

        for episode in range(episodes):
            print('')
            print(f'Episode: {episode}')
            self.episode = episode

            self.environment.reset()

            print('Stepping...')
            for self.step_num in trange(self.max_step_per_episode):
                self.step()

                self.optimize()

                if self.total_steps % self.target_update_interval == 0:
                    self.update_target_network()

            exp['total_explored'].append(self.explored_step, step=self.episode)

            print('Processing rewards...')
            transition, episode_return = self.compute_pending_rewards()
            self._log_training_episode(transition, episode_return)

            self.fem_simulator.clear_tasks()

            if DEBUG.result_generation:
                generated_result = self.generate()

                exp['generated_result'].append(str(generated_result), step=self.episode)

                if DEBUG.result_visualization:
                    print('Visualizing result...')
                    log_vis_sim(generated_result.state_sim['output']['temperature_distribution'][2], 'generated_state_vis',
                                append=True, step=self.episode, vrange=(293.15, 353.15))

            if DEBUG.transition_log:
                self._flush_transition_log(f'logs/transition-{self.episode}.log', 'transition_log_buffer')
                self._flush_transition_log(f'logs/transition-gen-{self.episode}.log', 'transition_log_buffer_gen')

            if DEBUG.trace_memory:
                snapshot = tracemalloc.take_snapshot()
                with open(f'logs/mem{self.episode}.log', 'w') as fp:
                    for line in snapshot.statistics('lineno')[:30]:
                        print(line, file=fp)

            # Pickles the whole Model wrapper (network, loss, optimizer), not just a state_dict.
            checkpoint_file = f'checkpoints/checkpoint_{self.episode}.pt'
            torch.save(self.policy_network, checkpoint_file)
            exp[f'model_checkpoints/policy_network/{self.episode}'].upload(checkpoint_file)

    def _log_training_episode(self, transition: SimulationTransition, episode_return: float) -> None:
        """Print/log the outcome of a training episode and update the convergence tracker."""
        if transition.is_state_terminal and transition.state_terminal_type == 0:
            print(f'Terminal state found in episode {transition.episode} step {transition.step} deltaT {transition.state_sim_value}:')
            print(transition.state)
            exp['goal_reached'].append(f'{transition.state}-{transition.state_sim_value}', step=self.episode)

        print(f'Episode return: {episode_return}')
        exp['episode_return'].append(episode_return, step=self.episode)

        exp['terminal_step'].append(transition.step, step=self.episode)

        # Remember the episode in which the terminal step last dropped; an
        # episode needing more steps than the current best resets the tracker.
        if transition.step < self.convergence_step:
            self.convergence_step = transition.step
            self.convergence_episode = transition.episode
        elif transition.step > self.convergence_step:
            self.convergence_step = 100000
            self.convergence_episode = 0

    def _flush_transition_log(self, log_file: str, buffer_name: str) -> None:
        """Write the DEBUG text buffer ``buffer_name`` to ``log_file``, empty the buffer and upload the file."""
        with open(log_file, 'w') as fp:
            fp.write(getattr(DEBUG, buffer_name))
        setattr(DEBUG, buffer_name, '')
        exp['transition_log'].upload_files(log_file)
            
        
    def generate(self) -> State:
        """Run one generation (evaluation) episode and log its outcome to Neptune.

        Switches the agent (and ``DEBUG``, when set) into generation mode and resets the
        environment. It then calls ``self.step()`` ``self.max_step_per_episode`` times and
        evaluates the pending rewards via ``self.compute_pending_rewards()``. The results
        are logged under the ``*_gen`` Neptune fields, using ``self.episode`` as the step.

        Also updates ``self.convergence_step_gen`` / ``self.convergence_episode_gen``:
        - A shorter terminal step is recorded.
        - A longer one resets the record to the sentinel (100000, episode 0).

        Generation mode is always switched off again, even if stepping or reward
        evaluation raises.

        Returns:
            The final transition returned by ``compute_pending_rewards()``. The caller
            reads ``.state_sim_value`` / ``.state_sim`` from it. NOTE(review): the ``State``
            annotation does not match this; it is kept only for interface compatibility.
        """
        self.generation_mode = True
        if DEBUG:
            DEBUG.in_generation_mode = True

        try:
            print('Generating...')

            self.environment.reset()

            # `self.step()` presumably reads `self.step_num`, hence the attribute loop target.
            for self.step_num in trange(self.max_step_per_episode):
                self.step()

            print('Evaluating states...')
            transition, episode_return = self.compute_pending_rewards()

            # Terminal type 0 is the case this notebook reports as "goal reached".
            if transition.is_state_terminal and transition.state_terminal_type == 0:
                print(f'Terminal state reached in step {transition.step} deltaT {transition.state_sim_value}:')
                print(transition.state)
                exp['goal_reached_gen'].append(f'{transition.state}-{transition.state_sim_value}', step=self.episode)

            print(f'Episode return: {episode_return}')
            exp['episode_return_gen'].append(episode_return, step=self.episode)

            exp['terminal_step_gen'].append(transition.step, step=self.episode)

            # Convergence bookkeeping for generation episodes. Both comparisons must use the
            # *generation* counter, not the training-side `self.convergence_step`.
            if transition.step < self.convergence_step_gen:
                self.convergence_step_gen = transition.step
                self.convergence_episode_gen = transition.episode
            elif transition.step > self.convergence_step_gen:
                # A longer episode invalidates the recorded convergence point; reset to the
                # sentinel so the next episode is recorded afresh.
                self.convergence_step_gen = 100000
                self.convergence_episode_gen = 0
        finally:
            # Never leave the agent stuck in generation mode after an exception.
            self.generation_mode = False
            if DEBUG:
                DEBUG.in_generation_mode = False
        return transition
        

## Training

In [None]:


# Repeat the full train -> final-generate cycle `experiment_repeat` times, one Neptune run each.
for i in range(experiment_repeat):
    # Open a fresh Neptune run. Agent methods log through this module-level `exp`.
    exp = neptune.init_run(project="pil-clemson/metamtl-rl-test",
                           capture_hardware_metrics=True,
                           capture_stderr=True,
                           capture_stdout=True,
                           source_files=['RemoteFEM-DQN-Harv.ipynb'],
                          )
    fem = None
    try:
        exp['sys/tags'].add(tags)

        # Mark every repetition after the first so reruns can be filtered in Neptune.
        if i > 0:
            exp['sys/tags'].add(['Rerun'])

        # Record the configuration this run was trained with.
        exp['EnvConfig'] = stringify_unsupported(environment_config)
        exp['Hyperparameters'] = stringify_unsupported(hyperparameters)

        # Build the environment, the remote FEM (SimHub) client, the reward and the agent.
        env = HarvRingEnvironment()

        fem = SimHubClient('10.128.97.115')
        fem.set_experiment('./elmer_thermal_cloak_ring/elmer_task.yml')

        reward_func = FEMReward(hyperparameters)

        agent = Agent(env, fem, reward_func, QNet(), QNet(target_network=True), hyperparameters)

        # Train for the configured number of episodes.
        agent.train(hyperparameters['max_episode'])

        # Convergence summary: "training_episode(step)/generation_episode(step)".
        exp['convergence'] = f'{agent.convergence_episode}({agent.convergence_step})' \
                            + f'/{agent.convergence_episode_gen}({agent.convergence_step_gen})'

        # Final generation episode, logged one episode index past the last training episode.
        agent.episode += 1
        generated_result = agent.generate()

        exp['generated_result_final'] = str(generated_result)

        exp['sys/tags'].add(['Done'])
        # NOTE(review): 'Sucessful' is misspelled; left unchanged in case existing Neptune
        # queries/dashboards filter on this exact tag.
        if generated_result.state_sim_value <= hyperparameters['terminal_threshold']:
            exp['sys/tags'].add(['Sucessful'])

        print('Done')
    finally:
        # Always release the remote FEM session and stop the Neptune run, even when
        # training raises or the kernel is interrupted. Otherwise both stay open.
        try:
            if fem is not None:
                fem.close()
        finally:
            exp.stop()

    clear_output(wait=True)

https://app.neptune.ai/pil-clemson/metamtl-rl-test/e/RLTEST-95
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api/run#stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.
Waiting for database connection..


  exp['EnvConfig'] = stringify_unsupported(environment_config)


Connected to database

Episode: 0
Stepping...


  0%|          | 0/500 [00:00<?, ?it/s]

Processing rewards...


  0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
# Deliberate stop: keeps "Run All" from continuing past the training cell into manual cleanup.
raise NotImplementedError()

In [None]:
# Manual cleanup after an interrupted run: close the SimHub client if one was ever created.
# Guarded so this cell does not raise NameError on a fresh kernel.
if 'fem' in globals():
    fem.close()

## 