# Reinforcement Learning- and FEM-based Inverse Design

## Experiment Logger

In [1]:
import os
import neptune.new as neptune

# Neptune picks up the project and the notebook identity from these environment variables.
os.environ.update({
    'NEPTUNE_PROJECT': "pil-clemson/metamtl-rl",
    'NEPTUNE_NOTEBOOK_ID': "45d03d69-6ac7-41ca-8af8-80caaa73aad5",
    'NEPTUNE_NOTEBOOK_PATH': "metamaterial-rl/MPFEM-DQN.ipynb",
})

# The run that all later cells log into; it is only stopped automatically when the kernel ends.
exp = neptune.init(project="pil-clemson/metamtl-rl")

https://app.neptune.ai/pil-clemson/metamtl-rl/e/METAMTLRL-137
Remember to stop your run once you've finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


## Import

In [1]:
from __future__ import annotations
from typing import Union, Optional, Callable, Any
from typing import Tuple, List, Set, Dict
from typing import NamedTuple
from typing import Generator

In [2]:
from collections import defaultdict, deque
from types import SimpleNamespace
import queue
from queue import PriorityQueue
from enum import Enum

In [3]:
from dataclasses import dataclass, field

In [4]:
import traceback

In [5]:
import os
import sys
import copy
import time
from datetime import datetime, timedelta
import multiprocessing
import random
import math
import itertools
import uuid

In [6]:
import torch
from torch import nn

from torch import Tensor, BoolTensor

from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from torchvision.transforms import PILToTensor

print(f'PyTorch version: {torch.__version__}')

PyTorch version: 1.13.0


In [110]:
import torchinfo

## Computing Devices

In [14]:
print(f'CPU Cores: {multiprocessing.cpu_count()}')

CPU Cores: 56


In [15]:
# Physical memory = number of physical pages x page size (POSIX sysconf, Unix only).
mem_bytes = os.sysconf('SC_PHYS_PAGES') * os.sysconf('SC_PAGE_SIZE')
mem_gib = mem_bytes / 1024 ** 3
print(f'Memory size: {int(mem_gib)} GiB')

Memory size: 376 GiB


In [16]:
# Names of all CUDA devices visible to PyTorch (empty list on CPU-only machines).
available_gpus = list(map(torch.cuda.get_device_name, range(torch.cuda.device_count())))
print(f'GPUs: {available_gpus}')

GPUs: ['Tesla V100S-PCIE-32GB', 'Tesla V100S-PCIE-32GB']


In [17]:
# Device used for the tensors and networks created in this notebook.
# NOTE(review): both branches of the conditional select the CPU, so `cuda` is the CPU
# even when GPUs are available (see the printed output below). This may be deliberate
# (e.g. to keep tensors shared with the FEM worker processes on the CPU) - TODO confirm;
# if GPU use is intended, the first branch should be torch.device('cuda').
cuda = torch.device('cpu') if torch.cuda.is_available() else torch.device('cpu')
print('Current computing device:', cuda)

Current computing device: cpu


## Batch Logger

The Batch Logger logs values for the current step without requiring the step number on every call.
Steps from all episodes are mapped to a single log step (`episode_multiplier * episode + step`), so all steps can be plotted on a single series.

In [19]:
class BatchLogger():
    """Log values against a single, monotonically increasing "log step".

    Call :meth:`step` once per RL step. Afterwards, log values with either
    ``logger(key, value)`` or ``logger[key] = value``, without passing the
    step again. The log step is ``episode_multiplier * episode + step``, so
    steps of all episodes land on one series. It stays monotonic only while
    ``step < episode_multiplier``.

    Args:
        experiment_logger: object indexed by key whose entries provide
            ``.log(value, step=...)`` (the Neptune run above).
        episode_multiplier: log-step offset between consecutive episodes.
    """

    def __init__(self, experiment_logger, episode_multiplier: int = 1000) -> None:
        self.exp = experiment_logger
        self.episode_multiplier: int = episode_multiplier
        # Initialised here so that `episode` exists before the first `step()` call.
        self.episode: int = 0
        self.log_step: int = 0

    def step(self, episode: int, step: int) -> None:
        """Set the episode/step used by subsequent log calls."""
        self.episode = episode
        self.log_step = self.episode_multiplier * episode + step

    def __call__(self, key: str, value: Union[float, str]) -> None:
        """Log ``value`` under ``key`` at the current log step."""
        self.exp[key].log(value, step=self.log_step)

    def __setitem__(self, key: str, value: Union[float, str]) -> None:
        """Alias of ``__call__`` that enables ``logger[key] = value``."""
        self(key, value)

In [20]:
# Notebook-wide step logger; callers set the position with `log.step(episode, step)` before logging.
log = BatchLogger(experiment_logger=exp)

## Hyperparameters

In [77]:
def parameters():
    """Return ``(environment_parameters, hyperparameters)`` as plain dictionaries."""
    environment_parameters = dict(
        grid_size=(4, 4),
    )

    hyperparameters = dict(
        # FEM simulation
        fem_task_pool_size=32,
        measurement_pooling_radius=0.,

        # DQN training
        target_update_interval=100,
        optimization_iter=10,
        experience_replay_size=10000,
        replay_batch_size=32,
        lr=.001,
        discount_factor=.9,
        explore_factor_initial=1.,
        explore_factor_minimal=0.05,
        explore_factor_halflife=2000.,

        # Termination
        final_threshold=0.01,
        final_step_threshold=100,

        # Budget
        max_episode=1000,
        max_step_per_episode=1000,
    )

    return environment_parameters, hyperparameters


# Attribute-style access: `ep.grid_size`, `hp.lr`, ...
ep, hp = (SimpleNamespace(**params) for params in parameters())

## Reinforcement Learning Environment

### Interfaces

In [74]:
class State:
    """Dictionary-backed RL state.

    Items live in ``self._dict`` and are accessed as ``state[key]``.
    :meth:`step` applies an action to a deep copy, so the original state is
    never mutated. Subclasses implement :meth:`to_tensor`.
    """

    def __init__(self) -> None:
        # Only the backing dict is stored on the instance. The item/len methods are
        # defined on the class, because Python looks up `state[k]`, `len(state)` and
        # `del state[k]` on the type. Bound dict methods stored on the instance were
        # ignored by those operators. `copy.deepcopy` also copied the bound
        # `dict.__getitem__` by reference, so it kept reading the original dict.
        self._dict: Dict[str, Any] = dict()

    def __getitem__(self, key): return self._dict[key]

    def __setitem__(self, key, value): self._dict[key] = value

    def __delitem__(self, key): del self._dict[key]

    def __len__(self): return len(self._dict)

    def step(self, action: Action) -> State:
        """Return the state produced by ``action`` applied to a deep copy of this state."""
        return action(copy.deepcopy(self))

    def to_tensor(self) -> torch.Tensor:
        """Encode the state as a tensor; must be overridden by subclasses."""
        raise NotImplementedError

In [75]:
# An action is a function mapping a state to its successor state.
# `State.step` calls it on a deep copy, so an action may mutate its argument in place.
Action = Callable[[State], State]

In [76]:
class Environment:
    """Abstract RL environment owning the current state and a list of actions.

    Subclasses populate ``_action_space`` and implement ``reset`` and ``step``.
    An action is addressed by its index in the action space.
    """

    def __init__(self) -> None:
        self._state: State = None                 # set by subclasses
        self._action_space: List[Action] = []     # index -> Action

    @property
    def state(self) -> State:
        """The current state (``None`` until a subclass sets it)."""
        return self._state

    @property
    def action_space(self) -> List[Action]:
        """All available actions, in index order."""
        return self._action_space

    def action_count(self) -> int:
        """Number of available actions."""
        count = len(self._action_space)
        return count

    def reset(self) -> None:
        """Restore the initial state."""
        raise NotImplementedError

    def step(self, action_index: int) -> None:
        """Apply the action at ``action_index`` to the current state."""
        raise NotImplementedError


### Data Classes

In [78]:
class Transition(NamedTuple):
    """One experience tuple as stored by ``ReplayMemory.push``."""
    state: State            # state before the action
    action_index: int       # index into the environment's action space
    reward: float
    next_state: State       # state after the action

In [79]:
class TurnableGridState(State):
    """State holding an ``ep.grid_size`` matrix of angles (key ``'angle_matrix'``), initially all 0."""

    def __init__(self) -> None:
        super().__init__()
        grid_shape = ep.grid_size
        self._dict['angle_matrix'] = torch.zeros(grid_shape)

    def to_tensor(self) -> torch.Tensor:
        """Flatten the angle matrix into a 1-D tensor."""
        angles = self._dict['angle_matrix']
        return torch.flatten(angles)

In [106]:
class TurnableGridEnvironment(Environment):
    """Grid environment whose actions change one cell's angle by -15 or +15.

    For every cell ``(i, j)`` in row-major order there are two consecutive
    actions: ``-15`` first, then ``+15``.
    """

    def __init__(self) -> None:
        super().__init__()
        self._state = TurnableGridState()

        def make_turn(row, col, delta):
            # Factory, so that every action captures its own (row, col, delta).
            def turn(state):
                state['angle_matrix'][row, col] += delta
                return state
            return turn

        self._action_space.extend(
            make_turn(row, col, delta)
            for row in range(ep.grid_size[0])
            for col in range(ep.grid_size[1])
            for delta in (-15, 15)
        )

    def reset(self) -> None:
        """Start again from an all-zero angle matrix."""
        self._state = TurnableGridState()

    def step(self, action_index: int) -> None:
        """Replace the state by the result of the selected action (applied to a copy)."""
        chosen = self._action_space[action_index]
        self._state = self._state.step(chosen)

### FEM Task Data Classes

In [23]:
# Data Classes

class TaskStatus(Enum):
    """Lifecycle of an FEM (sub)task."""
    Pending = 0
    Running = 1
    Successful = 2
    Skipped = -1
    Failed = -2


# A FEMTask instance represents the FEM simulation of an RL transition
# (two states: `state` and `next_state`). Instances are ordered by `step` only.
@dataclass(order=True)
class FEMTask():
    episode: int = field(compare=False)
    step: int
    state: State = field(compare=False)
    state_key: str = field(compare=False)
    action_index: int = field(compare=False)
    next_state: State = field(compare=False)
    next_state_key: str = field(compare=False)


# Timing / diagnostic record of one FEMSubtask, filled in by the FEM worker.
@dataclass()
class FEMSubtaskLog():
    fem_worker: str = None                      # name of the worker process
    queue_time: timedelta = None                # time the subtask waited before running
    mesh_time: Union[str, timedelta] = None     # duration, or 'Failed'
    solve_time: Union[str, timedelta] = None    # duration, or 'Failed' / 'Skipped'
    error_msg: str = ''                         # newline-terminated error messages


# A FEMSubtask instance represents the FEM simulation of a single state.
# Instances are ordered by `step` only.
@dataclass(order=True)
class FEMSubtask():
    episode: int = field(compare=False)
    step: int
    state: State = field(compare=False)
    state_key: str = field(compare=False)
    successful: bool = field(default=False, compare=False, init=False)
    file_id: str = field(default='', compare=False, init=False)
    vtk_path: str = field(default='', compare=False, init=False)
    measurement: Tensor = field(default=None, compare=False, init=False, repr=False)
    # default_factory gives every subtask its own log. `default=FEMSubtaskLog()` would share
    # one mutable instance between all subtasks, and Python >= 3.11 rejects it (ValueError).
    log: FEMSubtaskLog = field(default_factory=FEMSubtaskLog, compare=False, init=False)


# Measurements of both states of a FEMTask; a measurement is None when the
# corresponding subtask did not succeed.
@dataclass(order=True)
class FEMTaskResult():
    episode: int = field(compare=False)
    step: int
    state: State = field(compare=False)
    state_measurement: Tensor = field(compare=False, repr=False)
    action_index: int = field(compare=False)
    next_state: State = field(compare=False)
    next_state_measurement: Tensor = field(compare=False, repr=False)

### GridHoleBoardEnv Class

In [24]:
class GridHoleBoardEnv():
    """Rectangular board with a ``grid_size`` grid of holes, one per cell.

    The RL state is the matrix ``holes[x, y]`` of hole sizes. The FEM worker
    uses each size as the size of a ``'ball'`` mesher object. Hole ``(x, y)``
    is centred in its grid cell (``holes_center``, kept on the CPU).

    Each hole has two actions, stored as ``((x, y), size_change)`` tuples:
    ``+0.5`` at index ``2 * (x * grid_size[1] + y)`` and ``-0.5`` at the next
    index. ``step`` clamps sizes to ``hole_size_limit``.

    Args:
        size: board width and height.
        grid_size: number of cells along each axis.
    """

    def __init__(self,
                 size: Tuple[float, float],
                 grid_size: Tuple[int, int]) -> None:
        self.size: Tuple[float, float] = size
        self.grid_size: Tuple[int, int] = grid_size
        self.cell_size: Tuple[float, float] = (size[0] / grid_size[0], size[1] / grid_size[1])

        # Hole sizes are kept within [0, min(cell_size) / 2 - 1].
        self.hole_size_limit: Tuple[float, float] = (0., min(self.cell_size) / 2 - 1)

        # (x, y) -> size
        self.holes: Tensor = torch.zeros(self.grid_size, device=cuda)

        # (x, y) -> (x_coord, y_coord)
        self.holes_center: Tensor = torch.zeros((*self.grid_size, 2))

        self.action_space: List[Action] = list()
        self.action_indices: Dict[Action, int] = dict()

        for x in range(self.grid_size[0]):
            for y in range(self.grid_size[1]):
                self.holes_center[x, y, 0] = (x + 0.5) * self.cell_size[0]
                self.holes_center[x, y, 1] = (y + 0.5) * self.cell_size[1]
                for size_change in (0.5, -0.5):
                    action = ((x, y), size_change)
                    # The new action's index is the current length (no O(n) list.index lookup).
                    self.action_indices[action] = len(self.action_space)
                    self.action_space.append(action)

        # Not updated by `step`: every action is always considered valid.
        self.valid_action_mask: BoolTensor = torch.tensor([True] * len(self.action_space))

    def reset(self) -> None:
        """Set every hole to size 1."""
        self.holes: Tensor = torch.ones(self.grid_size, device=cuda)

    def random(self) -> None:
        """Set every hole to a uniformly random size in [0, 3); not clamped to the limits."""
        self.holes: Tensor = torch.rand(self.grid_size, device=cuda) * 3

    def step(self, action_index: int) -> None:
        """Change one hole's size by the action's delta, clamped to ``hole_size_limit``."""
        (x, y), size_change = self.action_space[action_index]

        new_size = self.holes[x, y] + size_change
        self.holes[x, y] = torch.clamp(new_size, *self.hole_size_limit)

    def get_state(self) -> State:
        """Return a copy of the hole-size matrix."""
        return self.holes.clone()

## Multiprocess FEM

### FEMConfig Class

### GridHoleBoardFEMTask

In [25]:
def FEMWorkerInitializer(mp_result_queue: Queue,
                         mp_file_cache: Dict[str, str],
                         mp_measurement_cache: Dict[str, Tensor],
                         env_board_size: Tuple[float, float],
                         env_grid_size: Tuple[int, int],
                         env_holes_center: Tensor,
                         env_measuring_spots: Tensor,
                         env_pooling_radius: float,
                         shared_fem_parameters: SimpleNamespace) -> None:
    """Pool initializer: publish the shared FEM context as module globals of the worker.

    ``GridHoleBoardFEMTask`` reads these globals (``result_queue``, ``file_cache``,
    ``measurement_cache``, ``board_size``, ``grid_size``, ``holes_center``,
    ``measuring_spots``, ``pooling_radius``, ``fem_parameters``) instead of
    receiving them with every task.
    """
    globals().update(
        result_queue=mp_result_queue,
        file_cache=mp_file_cache,
        measurement_cache=mp_measurement_cache,
        board_size=env_board_size,
        grid_size=env_grid_size,
        holes_center=env_holes_center,
        measuring_spots=env_measuring_spots,
        pooling_radius=env_pooling_radius,
        fem_parameters=shared_fem_parameters,
    )

In [26]:
def GridHoleBoardFEMTask(subtask: FEMSubtask, queue_time: datetime) -> None:
    """Simulate one board state with FEM and report it on ``result_queue``.

    Runs inside a pool worker whose globals (``result_queue``, ``file_cache``,
    ``measurement_cache``, ``board_size``, ``grid_size``, ``holes_center``,
    ``measuring_spots``, ``pooling_radius``, ``fem_parameters``) were set by
    ``FEMWorkerInitializer``.

    Stages (timings and errors are recorded in ``subtask.log``):
      1. mesh the board, subtracting a ``'ball'`` of size ``subtask.state[x, y]``
         per cell; holes smaller than 1 % of the element diameter are skipped;
      2. build the GetFEM model (displacement ``u``, temperature ``theta``,
         potential ``V``), disable ``u`` and solve;
      3. export the temperature to ``temp/<uuid>.vtk`` and sample it at
         ``measuring_spots``: nearest mesh point, or mean within ``pooling_radius``.

    On success ``subtask.successful`` is set, and the measurement is stored in
    ``subtask.measurement``, ``measurement_cache`` and (path) ``file_cache``.
    The subtask is put on ``result_queue`` in every case, so the parent never
    waits for a result lost to an exception.

    Args:
        subtask: the state to simulate; mutated in place.
        queue_time: wall-clock time at which the subtask was queued.
    """
    import os
    import sys
    import multiprocess

    # Silence console output of the FEM libraries: point file descriptor 1 (stdout)
    # at /dev/null and keep a duplicate of the original descriptor to restore it.
    devnull = open(os.devnull, 'w')
    oldstdout_fno = os.dup(sys.stdout.fileno())
    os.dup2(devnull.fileno(), 1)

    try:
        import uuid
        from datetime import datetime

        import getfem as gf
        import numpy as np

        import torch

        import pyvista as pv

        from scipy.spatial import cKDTree

        # Creating references for shorter notation
        state = subtask.state
        state_key = subtask.state_key

        fp = fem_parameters

        # Timer; `queue_time` becomes the time this subtask waited before running.
        start_time = datetime.now()
        queue_time = datetime.now() - queue_time

        subtask.log.queue_time = queue_time
        subtask.log.fem_worker = multiprocess.current_process().name

        failed = False

        # ---- 1. Generate mesh ----
        try:
            board = gf.MesherObject('rectangle', [0., 0.], list(board_size))
            holes: List[gf.MesherObject] = list()

            for x in range(grid_size[0]):
                for y in range(grid_size[1]):
                    center = holes_center[x, y].tolist()
                    size = state[x, y].item()

                    if size < 0.01 * fp.element_diameter: continue

                    holes.append(gf.MesherObject('ball', center, size))

            if holes:
                holes_union = gf.MesherObject('union', *holes)
                mesher = gf.MesherObject('set minus', board, holes_union)
            else:
                mesher = board

            gf.util('trace level', 0)   # No trace for mesh generation
            mesh = gf.Mesh('generate', mesher, fp.element_diameter, 2, measuring_spots.T)

            boundary: Dict[str, int] = dict()

            # Boundary of the holes: outer faces inside the board inset by 1 on every side
            boundary['HOLE_BOUND'] = 1
            mesh.set_region(boundary['HOLE_BOUND'],
                            mesh.outer_faces_in_box([1., 1.],
                                                    [board_size[0] - 1, board_size[1] - 1]))

            # Board sides, selected by the direction of the outer face normals
            boundary['LEFT_BOUND'] = 2
            mesh.set_region(boundary['LEFT_BOUND'], mesh.outer_faces_with_direction([-1., 0.], 0.01))

            boundary['RIGHT_BOUND'] = 3
            mesh.set_region(boundary['RIGHT_BOUND'], mesh.outer_faces_with_direction([ 1., 0.], 0.01))

            boundary['TOP_BOUND'] = 4
            mesh.set_region(boundary['TOP_BOUND'], mesh.outer_faces_with_direction([0.,  1.], 0.01))

            boundary['BOTTOM_BOUND'] = 5
            mesh.set_region(boundary['BOTTOM_BOUND'], mesh.outer_faces_with_direction([0., -1.], 0.01))

            # Hole faces must not be part of the side regions
            mesh.region_subtract( boundary['RIGHT_BOUND'], boundary['HOLE_BOUND'])
            mesh.region_subtract(  boundary['LEFT_BOUND'], boundary['HOLE_BOUND'])
            mesh.region_subtract(   boundary['TOP_BOUND'], boundary['HOLE_BOUND'])
            mesh.region_subtract(boundary['BOTTOM_BOUND'], boundary['HOLE_BOUND'])

            # One region per hole (ids 7, 8, ...) plus their union (id 6)
            region_id = 7
            for x in range(grid_size[0]):
                for y in range(grid_size[1]):
                    center = holes_center[x, y].tolist()
                    size = state[x, y].item()
                    bound_key = f'HOLE{x}_{y}_BOUND'
                    boundary[bound_key] = region_id
                    mesh.set_region(boundary[bound_key],
                                    mesh.outer_faces_in_ball(center, size + 0.01 * fp.element_diameter))
                    if region_id == 7:
                        boundary['HOLE_UNION_BOUND'] = 6
                        mesh.set_region(boundary['HOLE_UNION_BOUND'],
                                        mesh.outer_faces_in_ball(center, size + 0.01 * fp.element_diameter))
                    else:
                        mesh.region_merge(boundary['HOLE_UNION_BOUND'], boundary[bound_key])
                    region_id += 1

            # Sanity check: the per-hole regions together must equal HOLE_BOUND
            np.testing.assert_array_equal(mesh.region(boundary['HOLE_BOUND']),
                                          mesh.region(boundary['HOLE_UNION_BOUND']))

            mesh_time = datetime.now() - start_time
            subtask.log.mesh_time = mesh_time
        except Exception as err:
            subtask.log.mesh_time = 'Failed'
            subtask.log.solve_time = 'Skipped'
            subtask.log.error_msg += str(err) + '\n'

            failed = True

        # ---- 2. Solve ----
        if not failed:
            try:
                #
                # Definition of finite elements methods and integration method
                #

                mfu = gf.MeshFem(mesh, 2)  # Finite element for the elastic displacement
                mfu.set_classical_fem(fp.elements_degree)
                mft = gf.MeshFem(mesh, 1)  # Finite element for temperature and electrical field
                mft.set_classical_fem(fp.elements_degree)
                mfvm = gf.MeshFem(mesh, 1) # Finite element for Von Mises stress interpolation
                mfvm.set_classical_discontinuous_fem(fp.elements_degree)
                mim = gf.MeshIm(mesh, fp.elements_degree * 2)   # Integration method

                md = gf.Model('real')
                md.add_fem_variable('u', mfu)       # Displacement of the structure
                md.add_fem_variable('theta', mft)   # Temperature
                md.add_fem_variable('V', mft)       # Electric potential

                # Membrane elastic deformation
                md.add_initialized_data('cmu', [fp.cmu])
                md.add_initialized_data('clambdastar', [fp.clambdastar])
                md.add_isotropic_linearized_elasticity_brick(mim, 'u', 'clambdastar', 'cmu')

                md.add_Dirichlet_condition_with_multipliers(mim, 'u', fp.elements_degree - 1, boundary['LEFT_BOUND'])
                md.add_initialized_data('Fdata', [fp.F * fp.epsilon, 0])
                md.add_source_term_brick(mim, 'u', 'Fdata', boundary['RIGHT_BOUND'])

                # Electrical field
                sigmaeps = '(eps/(rho_0*(1+alpha*(theta-T0))))'
                md.add_initialized_data('eps', [fp.epsilon])
                md.add_initialized_data('rho_0', [fp.rho_0])
                md.add_initialized_data('alpha', [fp.alpha])
                md.add_initialized_data('T0', [fp.T0])
                md.add_nonlinear_term(mim, sigmaeps+'*(Grad_V.Grad_Test_V)')
                md.add_Dirichlet_condition_with_multipliers(mim, 'V', fp.elements_degree - 1, boundary['RIGHT_BOUND'])
                md.add_initialized_data('DdataV', [2.])
                md.add_Dirichlet_condition_with_multipliers(mim, 'V', fp.elements_degree - 1, boundary['LEFT_BOUND'], 'DdataV')

                # Thermal problem
                md.add_initialized_data('kaeps', [fp.kappa * fp.epsilon])
                md.add_generic_elliptic_brick(mim, 'theta', 'kaeps')
                md.add_initialized_data('D2', [fp.D * 2])
                md.add_initialized_data('D2airt', [fp.air_temp * fp.D * 2])
                md.add_mass_brick(mim, 'theta', 'D2')
                md.add_source_term_brick(mim, 'theta', 'D2airt')
                md.add_initialized_data('Deps', [fp.D / fp.epsilon])
                md.add_initialized_data('Depsairt', [fp.air_temp * fp.D / fp.epsilon])
                md.add_Fourier_Robin_brick(mim, 'theta', 'Deps', boundary['TOP_BOUND'])
                md.add_source_term_brick(mim, 'theta', 'Depsairt', boundary['TOP_BOUND'])
                md.add_Fourier_Robin_brick(mim, 'theta', 'Deps', boundary['BOTTOM_BOUND'])
                md.add_source_term_brick(mim, 'theta', 'Depsairt', boundary['BOTTOM_BOUND'])

                # Joule heating term
                md.add_nonlinear_term(mim, '-' + sigmaeps + '*Norm_sqr(Grad_V)*Test_theta')

                # Thermal expansion term
                md.add_initialized_data('beta', [fp.alpha_th * fp.E / (1 - 2 * fp.nu)])
                md.add_linear_term(mim, 'beta*(T0-theta)*Trace(Grad_Test_u)')

                #
                # Model solve
                #

                md.disable_variable('u')
                md.solve('max_res', 1E-9, 'max_iter', 100)

                #
                # Solution export
                #
                THETA = md.variable('theta')

                file_id = uuid.uuid4()

                vtk_path = f'temp/{file_id}.vtk'
                mft.export_to_vtk(vtk_path, mft, THETA, 'Temperature')

                solve_time = datetime.now() - start_time - mesh_time
                subtask.log.solve_time = solve_time
            except Exception as err:
                subtask.log.solve_time = 'Failed'
                subtask.log.error_msg += str(err) + '\n'

                failed = True

        # ---- 3. Sample the exported temperature at the measuring spots ----
        if not failed:
            try:
                m = pv.read(vtk_path)

                tree = cKDTree(m.points.astype(np.double))
                # The 2-D measuring spots are padded with z = 0 to match the 3-D mesh points.
                spots_3d = torch.hstack([measuring_spots, torch.zeros(measuring_spots.size()[0]).unsqueeze(1)])

                if pooling_radius == 0.:
                    # Value at the nearest mesh point
                    _, point_indices = tree.query(spots_3d)
                    values = np.array(m.active_scalars)[point_indices]
                else:
                    # Mean over all mesh points within `pooling_radius`
                    point_pools = tree.query_ball_point(spots_3d, pooling_radius)
                    values = []
                    for pool in point_pools:
                        values.append(np.array(m.active_scalars)[pool].mean())

                measurement = torch.tensor(values)
            except Exception as err:
                subtask.log.error_msg += str(err) + '\n'
                failed = True

        if not failed:
            subtask.file_id = str(file_id)
            subtask.vtk_path = vtk_path
            # The file is cached only after it has been read back successfully.
            file_cache[state_key] = vtk_path
            measurement_cache[state_key] = measurement
            subtask.measurement = measurement
            # Without this flag the parent marks every freshly solved state as Failed.
            subtask.successful = True
    except Exception as err:
        # Failure outside the guarded stages (e.g. an import error): report it with the subtask.
        subtask.log.error_msg += str(err) + '\n'
        for stage in ('queue_time', 'mesh_time', 'solve_time'):
            if getattr(subtask.log, stage) is None:
                setattr(subtask.log, stage, 'Failed')
    finally:
        # Always restore stdout and release both descriptors (the dup'ed one leaked per task before).
        os.dup2(oldstdout_fno, 1)
        os.close(oldstdout_fno)
        devnull.close()

    result_queue.put(subtask)

### RenderTask

### GridHoleBoardFEMSimulator

In [27]:
class GridHoleBoardFEMSimulator():
    """Run FEM simulations of RL transitions in a process pool, with caching.

    Every transition ``(state, action, next_state)`` becomes a :class:`FEMTask`
    made of two :class:`FEMSubtask` objects, one per state. A state is
    identified by ``str(state.flatten().tolist())``. Before a new FEM job is
    submitted, the state is looked up, in order, in:
      1. this episode's ``subtask_status``;
      2. the shared measurement cache;
      3. the shared VTK-file cache.
    ``get_transition_result`` yields the transitions in queueing order and
    blocks until their subtasks have finished.
    """

    def __init__(self,
                 environment: GridHoleBoardEnv,
                 *,
                 fem_task_pool_size: Optional[int] = None,
                 measuring_spots: Optional[List[List[float]]] = None,
                 pooling_radius: Optional[float] = None) -> None:
        """
        Args:
            environment: board whose size, grid and hole centres are simulated.
            fem_task_pool_size: number of worker processes (default ``hp.fem_task_pool_size``).
            measuring_spots: points where the temperature is sampled (default ``ep.measuring_spots``).
            pooling_radius: ``0`` samples the nearest mesh point; otherwise the mean over
                mesh points within this radius (default ``hp.measurement_pooling_radius``).
        """
        # Defaults are looked up when the simulator is created, not when the class is
        # defined. Re-running the parameter cell then takes effect, and defining the
        # class does not fail when a parameter (e.g. `ep.measuring_spots`) is not set.
        if fem_task_pool_size is None:
            fem_task_pool_size = hp.fem_task_pool_size
        if measuring_spots is None:
            measuring_spots = ep.measuring_spots
        if pooling_radius is None:
            pooling_radius = hp.measurement_pooling_radius

        self.environment = environment
        self.board_size = environment.size
        self.grid_size = environment.grid_size
        self.holes_center = environment.holes_center.cpu()
        self.measuring_spots = torch.tensor(measuring_spots)
        self.pooling_radius = pooling_radius

        self.mp_manager: Manager = Manager()

        # Workers put their finished FEMSubtasks here.
        self.mp_result_queue: Queue = self.mp_manager.Queue()

        # Shared across workers and episodes: state_key -> VTK path / measurement.
        self.mp_file_cache: Dict[str, str] = self.mp_manager.dict()
        self.mp_measurement_cache: Dict[str, Tensor] = self.mp_manager.dict()

        self.mp_fem_task_pool_size: int = fem_task_pool_size
        self.mp_fem_task_pool: Pool = self._new_pool()

        self.task_list: List[FEMTask] = list()
        self.subtask_status: Dict[str, TaskStatus] = dict()    # state_key -> task_status

        self.cache_stat: Dict[str, int] = self._empty_cache_stat()

    @staticmethod
    def _empty_cache_stat() -> Dict[str, int]:
        """Zeroed counters of where subtask results came from."""
        return {'newly_solved': 0, 'same_episode': 0, 'value_cache': 0, 'file_cache': 0}

    def _new_pool(self) -> Pool:
        """Create a worker pool whose workers receive the shared FEM context."""
        return Pool(self.mp_fem_task_pool_size,
                    initializer=FEMWorkerInitializer,
                    initargs=(self.mp_result_queue,
                              self.mp_file_cache,
                              self.mp_measurement_cache,
                              self.board_size,
                              self.grid_size,
                              self.holes_center,
                              self.measuring_spots,
                              self.pooling_radius,
                              fp))

    def _drain_result_queue(self) -> None:
        """Discard results still queued by a previous pool."""
        while True:
            try:
                self.mp_result_queue.get_nowait()
            except queue.Empty:
                break

    def clear_image_cache(self):
        """Forget all cached measurements (the VTK-file cache is kept)."""
        self.mp_measurement_cache.clear()

    def terminate_current_queue(self):
        """Stop all workers immediately; queued and running FEM jobs are dropped."""
        self.mp_fem_task_pool.terminate()

    def reset_task_pool(self):
        """Log cache statistics, forget this episode's tasks and start a fresh pool.

        The previous pool is terminated, since its workers would otherwise leak.
        Results it has already queued are discarded so they cannot be taken
        for results of the next episode.
        """
        exp['cache_stat'].log(str(self.cache_stat))
        self.cache_stat = self._empty_cache_stat()
        exp['cache_stat'].log(str({'subtask_count': len(self.subtask_status),
                                   'value_cache_size': len(self.mp_measurement_cache),
                                   'file_cache_size': len(self.mp_file_cache)}))
        self.task_list.clear()
        self.subtask_status.clear()

        self.mp_fem_task_pool.terminate()
        self.mp_fem_task_pool.join()
        self._drain_result_queue()
        self.mp_fem_task_pool = self._new_pool()

    def queue_subtask(self, subtask: FEMSubtask) -> None:
        """Ensure the state of ``subtask`` gets a measurement, reusing cached results.

        Lookup order: already queued in this episode -> measurement cache ->
        VTK-file cache (re-sampled synchronously) -> new asynchronous FEM job.
        """
        state_key = subtask.state_key

        if state_key in self.subtask_status:
            exp['state_cache'].log(f'From Same Episode: {state_key}')
            self.cache_stat['same_episode'] += 1
        elif state_key in self.mp_measurement_cache:
            exp['state_cache'].log(f'From ImgCache: {state_key}')
            self.cache_stat['value_cache'] += 1
            self.subtask_status[state_key] = TaskStatus.Successful
        elif state_key in self.mp_file_cache:
            exp['state_cache'].log(f'From FileCache: {state_key}')
            self.cache_stat['file_cache'] += 1
            self.load_from_cache(state_key)
        else:
            # Run FEM for unseen state
            exp['state_cache'].log(f'Solve New: {state_key}')
            self.cache_stat['newly_solved'] += 1
            self.subtask_status[state_key] = TaskStatus.Pending
            queue_time = datetime.now()

            def err_callback(err):
                with open('fem_err.log', 'a') as log:
                    print('========', file=log)
                    print(datetime.now(), file=log)
                    print(subtask, file=log)
                    print(err, file=log)
                # Report the (unsuccessful) subtask so the state is marked Failed
                # instead of staying Pending until the result queue times out.
                self.mp_result_queue.put(subtask)

            self.mp_fem_task_pool.apply_async(func=GridHoleBoardFEMTask,
                                              args=(subtask, queue_time),
                                              error_callback=err_callback)

    def load_from_cache(self, state_key: str):
        """Re-sample the cached VTK file of ``state_key`` and mark the state as solved.

        Runs synchronously in the calling process; no FEM solve is needed.
        """
        # Imported here, as in the FEM worker, because they are not among the
        # notebook's top-level imports.
        import numpy as np
        import pyvista as pv
        from scipy.spatial import cKDTree

        vtk_path = self.mp_file_cache[state_key]

        m = pv.read(vtk_path)

        tree = cKDTree(m.points.astype(np.double))
        # The 2-D measuring spots are padded with z = 0 to match the 3-D mesh points.
        spots_3d = torch.hstack([self.measuring_spots, torch.zeros(self.measuring_spots.size()[0]).unsqueeze(1)])

        if self.pooling_radius == 0.:
            _, point_indices = tree.query(spots_3d)
            values = np.array(m.active_scalars)[point_indices]
        else:
            point_pools = tree.query_ball_point(spots_3d, self.pooling_radius)
            values = [np.array(m.active_scalars)[pool].mean() for pool in point_pools]

        self.mp_measurement_cache[state_key] = torch.tensor(values)
        self.subtask_status[state_key] = TaskStatus.Successful

    def queue_transition(self, episode: int, step: int,
                         state: State, action_index: int, next_state: State) -> None:
        """Queue the FEM subtasks of both states of a transition and remember the transition."""
        state_key = str(state.flatten().tolist())
        next_state_key = str(next_state.flatten().tolist())

        self.queue_subtask(FEMSubtask(episode, step, state, state_key))
        self.queue_subtask(FEMSubtask(episode, step, next_state, next_state_key))

        task = FEMTask(episode, step, state, state_key, action_index, next_state, next_state_key)

        self.task_list.append(task)

    def get_transition_result(self) -> Generator[FEMTaskResult, None, None]:
        """Yield a FEMTaskResult per queued transition, in queueing order.

        Blocks while a subtask of the current transition is pending. A
        measurement is ``None`` when its subtask did not succeed.
        """
        for task in self.task_list:
            # Retrieve more subtask results if there is pending subtask in current task
            while self.subtask_status[task.state_key] is TaskStatus.Pending \
                    or self.subtask_status[task.next_state_key] is TaskStatus.Pending:
                self.retrieve_next_task_result()

            result = FEMTaskResult(task.episode, task.step,
                                   task.state, None,
                                   task.action_index,
                                   task.next_state, None)

            if self.subtask_status[task.state_key] is TaskStatus.Successful:
                result.state_measurement = self.mp_measurement_cache[task.state_key]

            if self.subtask_status[task.next_state_key] is TaskStatus.Successful:
                result.next_state_measurement = self.mp_measurement_cache[task.next_state_key]

            yield result

    def retrieve_next_task_result(self) -> None:
        """Wait (up to 30 s) for the next finished subtask; log its timings and record its status.

        Raises:
            queue.Empty: if no result arrives within 30 seconds.
        """
        result: FEMSubtask = self.mp_result_queue.get(timeout=30)
        exp['measurement'].log(result.measurement)

        def seconds(duration) -> float:
            # Stage times are a timedelta, or a marker string / None when the stage did not run.
            return duration.total_seconds() if isinstance(duration, timedelta) else 0.

        exp['fem_times/queue_time'].log(seconds(result.log.queue_time))
        exp['fem_times/mesh_time'].log(seconds(result.log.mesh_time))
        exp['fem_times/solve_time'].log(seconds(result.log.solve_time))

        if result.successful:
            self.subtask_status[result.state_key] = TaskStatus.Successful
        else:
            self.subtask_status[result.state_key] = TaskStatus.Failed

## DQN

### FEM-based Reward & Final Function

In [28]:
class FEMReward():
    """Reward and termination signal computed from FEM measurements.

    The error of a state is the summed squared difference between its
    measurement and ``target``. For a transition the reward is

        error(state) - error(next_state)
        + goal_reward   if error(next_state) <= final_state_error_threshold
        + loop_penalty  if next_state equals state

    and the transition is final when ``error(state)`` is within the threshold.
    ``(None, False)`` is returned when the state's measurement is missing.
    When only the next state's measurement is missing, just the loop penalty
    (or 0) is returned, and nothing is logged.
    """

    def __init__(self,
                 environment: GridHoleBoardEnv,
                 target: Tensor,
                 max_step: int = hp.max_step_per_episode,
                 goal_reward: float = 10000.,
                 final_state_error_threshold: float = 1,
                 loop_penalty: float = -10000) -> None:
        self.environment: GridHoleBoardEnv = environment

        self.target_spot = target
        exp['target_value'] = target

        # Summed (not averaged) squared error over all measuring spots.
        self.loss = nn.MSELoss(reduction='sum')

        self.goal_reward = goal_reward
        self.final_state_error_threshold = final_state_error_threshold
        self.loop_penalty = loop_penalty
        # NOTE(review): `max_step` is accepted but not used.

    def __call__(self, result: FEMTaskResult) -> Tuple[Optional[float], bool]:
        if result.state_measurement is None:
            return None, False

        error = float(self.loss(result.state_measurement, self.target_spot))
        final = error <= self.final_state_error_threshold

        # An action that leaves the state unchanged (s_n -> s_n) is penalised.
        stayed = torch.equal(result.state, result.next_state)
        reward = 0. + self.loop_penalty if stayed else 0.

        if result.next_state_measurement is None:
            return reward, final

        next_error = float(self.loss(result.next_state_measurement, self.target_spot))

        # Reward the decrease of the error, plus a bonus for reaching the goal.
        reward += error - next_error
        if next_error <= self.final_state_error_threshold:
            reward += self.goal_reward

        log.step(result.episode, result.step)
        log['rewards/reward'] = reward
        log['rewards/error'] = error

        return reward, final

### Network Definition

In [29]:
# Network Container
class Model():
    """Bundle of a network with the loss function and optimizer used to train it.

    Calling the container forwards its argument straight to ``network``.
    ``loss_func`` and ``optimizer`` may be None for a network that is never
    trained directly.
    """
    def __init__(self, network: nn.Module, loss_func: _Loss, optimizer: Optimizer):
        self.network, self.loss_func, self.optimizer = network, loss_func, optimizer

    def __call__(self, network_input: Tensor) -> Tensor:
        forward = self.network
        return forward(network_input)

In [30]:
def QNet(state_size: int = 12, action_number: int = 24, target_network: bool = False):
    """Build the fully connected Q-network ``state -> 100 -> 200 -> action values``.

    Args:
        state_size: length of the flattened state vector fed to the network.
        action_number: number of discrete actions (one output value each).
        target_network: when True, return a bare container (no loss, no
            optimizer, no experiment logging) for use as the DQN target network.

    Returns:
        Model wrapping the network. The trainable variant carries a
        SmoothL1Loss and an Adam optimizer with learning rate 0.001.
    """
    widths = (state_size, 100, 200, action_number)
    layers = []
    for depth, (fan_in, fan_out) in enumerate(zip(widths, widths[1:])):
        if depth > 0:
            layers.append(nn.ReLU())  # activation between consecutive linear layers
        layers.append(nn.Linear(fan_in, fan_out, device=cuda))
    net = nn.Sequential(*layers)

    if target_network:
        return Model(network=net, loss_func=None, optimizer=None)

    # Record the layer summary (for a batch of 32) in the experiment log.
    summary = torchinfo.summary(net, input_size=(32, state_size), device=cuda, verbose=0)
    exp['Network'] = str(summary)
    criterion = nn.SmoothL1Loss()
    adam = torch.optim.Adam(net.parameters(), 0.001)
    return Model(network=net, loss_func=criterion, optimizer=adam)

### Replay Memory Class

In [31]:
class ReplayMemory():
    """Fixed-capacity experience buffer; the oldest entries are evicted first."""
    def __init__(self, capacity):
        self.memory: deque = deque(maxlen=capacity)

    def push(self, *args):
        """Wrap ``args`` in a ``Transition`` and append it to the buffer."""
        transition = Transition(*args)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return ``batch_size`` stored entries drawn uniformly without replacement."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        return self.memory.__len__()

### Agent Class

In [32]:
class Agent():
    """DQN agent whose rewards come from FEM simulations of visited states.

    During an episode the agent only acts: every transition is queued with
    ``fem_simulator``. After the episode, completed FEM results are scored by
    ``fem_reward_func``, pushed into ``experience_replay`` and used to optimise
    ``q_network`` against ``target_network``.

    Args:
        environment: board environment providing states, actions and action masks.
        fem_simulator: queues transitions and yields their FEM results.
        fem_reward_func: maps one FEM result to ``(reward or None, final)``.
        q_network: online network container (must carry loss_func and optimizer).
        target_network: network container used for bootstrap targets.
        target_update_interval: copy online -> target weights every this many episodes.
        optimization_iter: number of optimisation batches run after each episode.
        max_step_per_episode: number of actions taken per episode.
        experience_replay: replay buffer. When omitted, a new
            ``ReplayMemory(hp.experience_replay_size)`` is created per agent.
        replay_batch_size: transitions per optimisation batch.
        discount_factor: DQN discount factor (gamma).
        explore_factor_initial, explore_factor_minimal, explore_factor_halflife:
            epsilon schedule ``max(initial * 0.5 ** (explored_step / halflife), minimal)``.
    """
    def __init__(self,
                 environment: GridHoleBoardEnv,
                 fem_simulator: GridHoleBoardFEMSimulator,
                 fem_reward_func: Callable[[FEMTaskResult], Tuple[Optional[float], bool]],
                 q_network: Model,
                 target_network: Model,
                 target_update_interval: int = hp.target_update_interval,
                 optimization_iter: int = hp.optimization_iter,
                 max_step_per_episode: int = hp.max_step_per_episode,
                 experience_replay: Optional[ReplayMemory] = None,
                 replay_batch_size: int = hp.replay_batch_size,
                 discount_factor: float = hp.discount_factor,
                 explore_factor_initial: float = hp.explore_factor_initial,
                 explore_factor_minimal: float = hp.explore_factor_minimal,
                 explore_factor_halflife: float = hp.explore_factor_halflife,
                ) -> None:
        self.environment: GridHoleBoardEnv = environment
        self.fem_simulator: GridHoleBoardFEMSimulator = fem_simulator
        self.fem_reward_func = fem_reward_func

        self.q_network: Model = q_network
        self.target_network: Model = target_network
        self.target_update_interval: int = target_update_interval

        if experience_replay is None:
            # A fresh buffer per agent. A default-argument instance would be
            # created once and shared by every Agent built in the kernel session.
            experience_replay = ReplayMemory(hp.experience_replay_size)

        self.optimization_iter: int = optimization_iter
        self.max_step_per_episode: int = max_step_per_episode
        self.experience_replay: ReplayMemory = experience_replay
        self.replay_batch_size: int = replay_batch_size

        self.discount_factor: float = discount_factor
        self.explore_factor_initial: float = explore_factor_initial
        self.explore_factor_minimal: float = explore_factor_minimal
        self.explore_factor_halflife: float = explore_factor_halflife
        self.explored_step: int = 0

        # Loss of every optimisation batch, kept locally for plotting.
        self.running_loss: List[float] = []

    def select_action(self, disable_explore: bool = False) -> Tuple[State, int, str]:
        """Choose an action with exponentially decaying epsilon-greedy exploration.

        Returns:
            ``(state before acting, action index, 'Pred' | 'Rand')``.
        """
        state = self.environment.get_state()
        explore_factor = max(self.explore_factor_initial
                             * (0.5 ** (self.explored_step / self.explore_factor_halflife)),
                             self.explore_factor_minimal)
        exp['explore_factor'].log(explore_factor, step=self.explored_step)

        # random.random() is evaluated first so RNG consumption does not
        # depend on disable_explore.
        if random.random() > explore_factor or disable_explore:
            # Inference only: no autograd graph is needed to pick an action.
            with torch.no_grad():
                prediction = self.q_network(state.flatten()).flatten()
            mask = self.environment.valid_action_mask
            # NOTE(review): masked_fill overwrites entries where the mask is True,
            # so True must mean "not allowed" for the argmax to pick a valid
            # action — confirm against GridHoleBoardEnv.valid_action_mask.
            masked_prediction = prediction.masked_fill(mask, -1000000.)
            action_index = masked_prediction.argmax().item()
            exp['actions'].log(f'Action {action_index}')
            exp['actions'].log(prediction)
            exp['actions'].log(mask)
            exp['actions'].log(masked_prediction)
            action_type = 'Pred'
        else:
            action_index = random.randrange(len(self.environment.action_space))
            action_type = 'Rand'
        self.explored_step += 1
        return state, action_index, action_type

    def step(self, episode: int, step: int, disable_explore: bool = False) -> bool:
        """Take one action, log it and queue the transition for FEM simulation.

        Args:
            episode: episode index, forwarded to the FEM simulator.
            step: 0-based step index within the episode.
            disable_explore: always act greedily when True.

        Returns:
            True if this was the last step of the episode.
        """
        state, action_index, action_type = self.select_action(disable_explore)

        if action_type == 'Pred':
            exp['actions'].log(self.environment.valid_action_mask[action_index])

        self.environment.step(action_index)
        next_state = self.environment.get_state()

        action = self.environment.action_space[action_index]
        action_log = f'{self.explored_step} {action_type} {action} ({state[action[0]]} => {next_state[action[0]]})'
        exp['actions'].log(action_log)

        self.fem_simulator.queue_transition(episode, step, state, action_index, next_state)

        # `step` is 0-based, so the final step of an episode is max_step_per_episode - 1.
        return step + 1 >= self.max_step_per_episode

    def complete_fem_rewarding(self, episode: int) -> None:
        """Score finished FEM results and store them in the replay memory.

        Transitions whose reward is None are skipped. When a transition is
        final, it is stored with ``next_state=None``, the simulator's current
        queue is terminated and the method returns early.
        ``episode`` is currently unused.
        """
        fem = self.fem_simulator    # Create local reference to shorter statements
        print('Start computing rewards based on completed FEM simulation')
        rewarded = 0

        for result in fem.get_transition_result():
            reward, final = self.fem_reward_func(result)
            if reward is None:
                continue

            next_state = None if final else result.next_state
            self.experience_replay.push(result.state, result.action_index, reward, next_state)

            rewarded += 1
            if rewarded % 100 == 0:
                print(f'{rewarded}/{self.max_step_per_episode}', end='\r')

            if final:
                fem.terminate_current_queue()
                print('Early stop triggered')
                return

    def optimize(self) -> None:
        """Run one DQN optimisation step on a random replay batch.

        Does nothing until the replay memory holds at least one batch.
        """
        if len(self.experience_replay) < self.replay_batch_size:
            return

        samples = self.experience_replay.sample(self.replay_batch_size)
        batch = Transition(*zip(*samples))

        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=cuda, dtype=torch.bool)
        non_final_next = [s.flatten() for s in batch.next_state if s is not None]

        state_batch = torch.stack([s.flatten() for s in batch.state])
        action_batch = torch.tensor(batch.action_index, device=cuda).unsqueeze(1)
        reward_batch = torch.tensor(batch.reward, device=cuda)

        state_action_values = self.q_network(state_batch).gather(1, action_batch)

        # Terminal transitions keep a bootstrap value of 0.
        next_state_values = torch.zeros(self.replay_batch_size, device=cuda)
        if non_final_next:  # torch.stack([]) would raise on an all-terminal batch
            with torch.no_grad():
                next_state_values[non_final_mask] = \
                    self.target_network(torch.stack(non_final_next)).max(1)[0]

        expected_state_action_values = (next_state_values * self.discount_factor) + reward_batch

        loss = self.q_network.loss_func(state_action_values, expected_state_action_values.unsqueeze(1))
        loss_value = float(loss)
        exp['optimization_loss'].log(loss_value)
        self.running_loss.append(loss_value)

        self.q_network.optimizer.zero_grad()
        loss.backward()
        # Clamp each gradient element into [-1, 1] (skips params without grads).
        torch.nn.utils.clip_grad_value_(self.q_network.network.parameters(), 1)
        self.q_network.optimizer.step()

    def update_target_network(self) -> None:
        """Copy the online network's weights into the target network."""
        self.target_network.network.load_state_dict(self.q_network.network.state_dict())

    def train(self, episode: int = hp.max_episode) -> None:
        """Run ``episode`` training episodes of fixed length.

        Each episode acts for ``max_step_per_episode`` steps, scores the FEM
        results, runs ``optimization_iter`` optimisation steps and resets the
        simulator. The target network is refreshed every
        ``target_update_interval`` episodes.
        """
        for e in range(episode):
            print(f'Episode {e}')
            start_time = datetime.now()
            self.environment.reset()
            # Use fixed number of steps in each episode
            for step_num in range(self.max_step_per_episode):
                self.step(e, step_num)
            self.complete_fem_rewarding(e)
            for _ in range(self.optimization_iter):
                self.optimize()
            self.fem_simulator.reset_task_pool()
            self.fem_simulator.clear_image_cache()
            if e % self.target_update_interval == 0:
                self.update_target_network()

            end_time = datetime.now()
            print(f'Episode {e} completed in {end_time - start_time}')

    def generate(self, step: int = 100) -> State:
        """Placeholder for greedy design generation; currently returns None."""
        pass
        # TODO: rewrite
        # self.environment.reset()
        # for i in range(step):
        #     if self.step(100000000, i):
        #         print(f'Completed in {step} steps                    ', end='\r\n')
        #         break
        # return self.environment.get_state(), i

## Training

In [None]:
env = GridHoleBoardEnv(size=ep.size, grid_size=ep.grid_size)

fem = GridHoleBoardFEMSimulator(env, measuring_spots=ep.measuring_spots,
                                pooling_radius=hp.measurement_pooling_radius)

# Target FEM measurement: 30 for each of the 8 measurement entries
# (presumably one per measuring spot — confirm against ep.measuring_spots).
target_measurement = torch.full((8,), 30.)
reward_func = FEMReward(env, target_measurement)

agent = Agent(env, fem, reward_func, QNet(), QNet(target_network=True))
# Standard DQN initialisation: start the target network as a copy of the
# online network rather than with independent random weights.
agent.update_target_network()

In [None]:
# Train for a fixed number of episodes; progress is printed per episode.
agent.train(episode=1000)

Episode 0
Start computing rewards based on completed FEM simulation
Episode 0 completed in 0:04:49.626524
Episode 1
Start computing rewards based on completed FEM simulation
Episode 1 completed in 0:03:47.057468
Episode 2
Start computing rewards based on completed FEM simulation
Episode 2 completed in 0:04:11.333822
Episode 3
Start computing rewards based on completed FEM simulation
Episode 3 completed in 0:03:16.926786
Episode 4
Start computing rewards based on completed FEM simulation
Episode 4 completed in 0:02:00.740221
Episode 5
Start computing rewards based on completed FEM simulation
Episode 5 completed in 0:01:43.419541
Episode 6
Start computing rewards based on completed FEM simulation
Episode 6 completed in 0:01:07.019182
Episode 7
Start computing rewards based on completed FEM simulation
Episode 7 completed in 0:01:18.984613
Episode 8
Start computing rewards based on completed FEM simulation
Episode 8 completed in 0:00:39.376069
Episode 9
Start computing rewards based on com



Episode 140 completed in 0:00:49.976751
Episode 141
Start computing rewards based on completed FEM simulation
Episode 141 completed in 0:00:57.193732
Episode 142
Start computing rewards based on completed FEM simulation
Episode 142 completed in 0:01:09.235314
Episode 143
Start computing rewards based on completed FEM simulation
Episode 143 completed in 0:00:43.082637
Episode 144
Start computing rewards based on completed FEM simulation
Episode 144 completed in 0:00:56.241977
Episode 145
Start computing rewards based on completed FEM simulation
Episode 145 completed in 0:00:41.646944
Episode 146
Start computing rewards based on completed FEM simulation
Episode 146 completed in 0:00:50.599988
Episode 147
Start computing rewards based on completed FEM simulation
Episode 147 completed in 0:00:56.950024
Episode 148
Start computing rewards based on completed FEM simulation
Episode 148 completed in 0:00:53.302917
Episode 149
Start computing rewards based on completed FEM simulation
Episode 14



Episode 322 completed in 0:01:06.560471
Episode 323
Start computing rewards based on completed FEM simulation
Episode 323 completed in 0:00:32.085999
Episode 324
Start computing rewards based on completed FEM simulation
Episode 324 completed in 0:00:46.746361
Episode 325
Start computing rewards based on completed FEM simulation
Episode 325 completed in 0:01:03.271094
Episode 326
Start computing rewards based on completed FEM simulation
Episode 326 completed in 0:00:49.704407
Episode 327
Start computing rewards based on completed FEM simulation
Episode 327 completed in 0:00:42.141635
Episode 328
Start computing rewards based on completed FEM simulation
Episode 328 completed in 0:00:38.973609
Episode 329
Start computing rewards based on completed FEM simulation
Episode 329 completed in 0:00:46.857976
Episode 330
Start computing rewards based on completed FEM simulation
Episode 330 completed in 0:00:48.317233
Episode 331
Start computing rewards based on completed FEM simulation
Episode 33



Episode 417 completed in 0:00:54.973775
Episode 418
Start computing rewards based on completed FEM simulation
Episode 418 completed in 0:00:45.520235
Episode 419
Start computing rewards based on completed FEM simulation
Episode 419 completed in 0:00:36.856806
Episode 420
Start computing rewards based on completed FEM simulation
Episode 420 completed in 0:00:48.679983
Episode 421
Start computing rewards based on completed FEM simulation
Episode 421 completed in 0:00:32.280533
Episode 422
Start computing rewards based on completed FEM simulation
Episode 422 completed in 0:00:44.101568
Episode 423
Start computing rewards based on completed FEM simulation
Episode 423 completed in 0:00:57.543735
Episode 424
Start computing rewards based on completed FEM simulation
Episode 424 completed in 0:00:35.124028
Episode 425
Start computing rewards based on completed FEM simulation
Episode 425 completed in 0:00:54.369072
Episode 426
Start computing rewards based on completed FEM simulation
Episode 42



Episode 533 completed in 0:01:06.628838
Episode 534
Start computing rewards based on completed FEM simulation
Episode 534 completed in 0:00:52.886358
Episode 535
Start computing rewards based on completed FEM simulation
Episode 535 completed in 0:00:40.880449
Episode 536
Start computing rewards based on completed FEM simulation
Episode 536 completed in 0:00:57.746520
Episode 537
Start computing rewards based on completed FEM simulation
Episode 537 completed in 0:00:51.247688
Episode 538
Start computing rewards based on completed FEM simulation
Episode 538 completed in 0:00:51.339772
Episode 539
Start computing rewards based on completed FEM simulation
Episode 539 completed in 0:00:26.531769
Episode 540
Start computing rewards based on completed FEM simulation
Episode 540 completed in 0:00:32.380989
Episode 541
Start computing rewards based on completed FEM simulation
Episode 541 completed in 0:00:33.595030
Episode 542
Start computing rewards based on completed FEM simulation
Episode 54



Episode 557 completed in 0:01:02.820882
Episode 558
Start computing rewards based on completed FEM simulation
Episode 558 completed in 0:00:52.350367
Episode 559
Start computing rewards based on completed FEM simulation
Episode 559 completed in 0:00:37.566721
Episode 560
Start computing rewards based on completed FEM simulation
Episode 560 completed in 0:00:53.159959
Episode 561
Start computing rewards based on completed FEM simulation
Episode 561 completed in 0:00:56.824520
Episode 562
Start computing rewards based on completed FEM simulation
Episode 562 completed in 0:00:36.921134
Episode 563
Start computing rewards based on completed FEM simulation
Episode 563 completed in 0:00:32.932789
Episode 564
Start computing rewards based on completed FEM simulation
Episode 564 completed in 0:00:42.676298
Episode 565
Start computing rewards based on completed FEM simulation
Episode 565 completed in 0:00:31.600945
Episode 566
Start computing rewards based on completed FEM simulation
Episode 56



Episode 568 completed in 0:01:03.541960
Episode 569
Start computing rewards based on completed FEM simulation
Episode 569 completed in 0:01:06.913614
Episode 570
Start computing rewards based on completed FEM simulation
Episode 570 completed in 0:00:44.235814
Episode 571
Start computing rewards based on completed FEM simulation
Episode 571 completed in 0:00:43.512216
Episode 572
Start computing rewards based on completed FEM simulation
Episode 572 completed in 0:00:35.870264
Episode 573
Start computing rewards based on completed FEM simulation
Episode 573 completed in 0:00:33.499247
Episode 574
Start computing rewards based on completed FEM simulation
Episode 574 completed in 0:00:52.303651
Episode 575
Start computing rewards based on completed FEM simulation
Episode 575 completed in 0:00:35.681980
Episode 576
Start computing rewards based on completed FEM simulation
Episode 576 completed in 0:00:50.056549
Episode 577
Start computing rewards based on completed FEM simulation
Episode 57

In [None]:
# Scatter of the agent's recorded training losses in the order they were recorded.
# (vmin/vmax are omitted: without `c`, matplotlib ignores them and warns.)
fig, ax = plt.subplots()
ax.scatter(range(len(agent.running_loss)), agent.running_loss, s=1)
ax.set(title='DQN optimisation loss', xlabel='Optimisation step', ylabel='Loss')
plt.show()

In [None]:
# Stop the Neptune run once logging is finished; otherwise it stays open until
# the notebook kernel terminates.
exp.stop()

## Manuscript (scratch cells)

In [None]:
a = f'{123}'  # scratch: string form of an integer
a

In [None]:
# Scratch: torch.equal is True only when shapes and all elements match.
a, b = torch.tensor([1, 2]), torch.tensor([1, 2])
torch.equal(a, b)


In [None]:
cache_stat: Dict[str, int] = dict(solved=0, image_cache=0, file_cache=0)  # counters start at zero

In [None]:
# GPU 0 memory counters: total, reserved by the caching allocator, allocated.
t, r, a = (torch.cuda.get_device_properties(0).total_memory,
           torch.cuda.memory_reserved(0),
           torch.cuda.memory_allocated(0))
f = r - a  # free inside reserved