# Reinforcement Learning- and FEM-based Inverse Design

## Experiment Logger

In [1]:
import neptune.new as neptune

# Create the Neptune run object for the hard-coded project; later cells assign
# parameter dicts to it and log values through it.
exp = neptune.init(project="pil-clemson/metamtl-rl",)

https://app.neptune.ai/pil-clemson/metamtl-rl/e/METAMTLRL-81
Remember to stop your run once you’ve finished logging your metadata (https://docs.neptune.ai/api-reference/run#.stop). It will be stopped automatically only when the notebook kernel/interactive console is terminated.


## Import

In [2]:
from typing import Union, Optional, Callable, Any
from typing import Tuple, List, Set, Dict
from typing import NamedTuple
from typing import Generator

In [3]:
from collections import defaultdict, deque
from types import SimpleNamespace
import queue
from queue import PriorityQueue
from enum import Enum

In [4]:
from dataclasses import dataclass, field

In [5]:
import traceback

In [6]:
import os
import sys
import time
from datetime import datetime, timedelta
import random
import math
import itertools
import uuid

The `multiprocess` package is used instead of the built-in `multiprocessing` module to support multiprocessing in a Jupyter environment.

In [7]:
import multiprocess
from multiprocess import Process, Pool, Queue, Manager

In [8]:
import numpy as np
import pandas as pd
from PIL import Image, ImageChops

In [9]:
import matplotlib.pyplot as plt

In [10]:
import getfem as gf

initializing ...
numthread = 1


In [11]:
import pyvista as pv
from pyvirtualdisplay.display import Display

In [2]:
import torch
from torch import nn

from torch import Tensor

from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from torchvision.transforms import PILToTensor

print('PyTorch version:', torch.__version__)

1.10.0


## Computing Devices

In [13]:
print('CPU Cores:', multiprocess.cpu_count())

56

In [14]:
# Total physical memory via os.sysconf: page size (bytes) times the number of physical pages.
mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')  # e.g. 4015976448
# Convert bytes to GiB (1 GiB = 1024**3 bytes); the printed value is truncated to an integer.
mem_gib = mem_bytes/(1024.**3)
print('Memory size:', int(mem_gib), 'GiB')

376.29273986816406

In [15]:
available_gpus = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
print('GPUs:', available_gpus)

['Tesla V100S-PCIE-32GB', 'Tesla V100S-PCIE-32GB']

In [16]:
# NOTE(review): both branches of this conditional evaluate to torch.device('cpu'),
# so the CPU is selected even when CUDA is available — confirm this is intentional.
cuda = torch.device('cpu') if torch.cuda.is_available() else torch.device('cpu')
print('Current computing device:', cuda)

Current computing device: cpu


In [17]:
# Use the 'spawn' start method for all worker processes.
# force=True keeps this cell re-runnable: without it, executing the cell a second
# time in the same kernel raises RuntimeError because the context is already set.
multiprocess.set_start_method('spawn', force=True)

## Batch Logger

The Batch Logger logs a value for the current step without requiring the step number to be specified on every call.
Steps from all episodes are also converted to a single log-step, computed from the episode and step numbers, so that all steps can be plotted as a single series.

In [18]:
class BatchLogger():
    """Log values to an experiment logger at an implicitly tracked step.

    The current position is set once with :meth:`step`; every value logged
    afterwards is written at ``episode_multiplier * episode + step``. Steps
    greater than or equal to ``episode_multiplier`` therefore map onto the
    log-steps of the following episode.

    Args:
        experiment_logger: Object supporting ``logger[key].log(value, step=...)``.
        episode_multiplier: Number of log-steps reserved for each episode.
    """

    def __init__(self, experiment_logger, episode_multiplier: int = 1000) -> None:
        self.exp = experiment_logger
        self.episode_multiplier: int = episode_multiplier
        # Defined up front so the attribute exists before the first step() call.
        self.episode: int = 0
        self.log_step: int = 0

    def step(self, episode: int, step: int) -> None:
        """Set the episode/step that subsequent log calls are attributed to."""
        self.episode = episode
        self.log_step = self.episode_multiplier * episode + step

    def __call__(self, key: str, value: Union[float, str]) -> None:
        """Log ``value`` under ``key`` at the current log-step."""
        self.exp[key].log(value, step=self.log_step)

    def __setitem__(self, key: str, value: Union[float, str]) -> None:
        """``logger[key] = value`` is an alias of ``logger(key, value)``."""
        self(key, value)

In [19]:
# Notebook-wide BatchLogger wrapping the experiment run `exp`; uses the default
# episode multiplier of 1000.
log = BatchLogger(exp)

## Local Log Cleanup

In [20]:
!bash clean.sh

rm: cannot remove 'debug.log': No such file or directory
rm: cannot remove 'fem_err.log': No such file or directory
rm: cannot remove 'fem.log': No such file or directory
rm: cannot remove 'render_err.log': No such file or directory
rm: cannot remove 'reward.log': No such file or directory


## Hyperparameters & FEM Physics

In [21]:
def parameters():
    """Assemble the run configuration, record it on ``exp`` and return it.

    Four dicts are built (environment, hyperparameters, network description,
    FEM parameters) and each is assigned to a key of the module-level
    experiment logger ``exp``.

    Returns:
        Tuple ``(ep, hp, fp)`` of plain dicts: environment parameters,
        hyperparameters and FEM parameters. The network description is only
        logged, not returned.
    """
    # --- Environment: board dimensions and the hole grid laid over it ---
    environment_params = {'size': (80, 60), 'grid_size': (4, 3)}

    # --- Hyperparameters ---
    hyper_params = {
        # Worker pool sizes
        'fem_task_pool_size': 32,
        'render_task_pool_size': 8,

        # Learning
        'target_update_interval': 100,
        'optimization_iter': 1,
        'experience_replay_size': 10000,
        'replay_batch_size': 32,
        'lr': .001,
        'discount_factor': .9,
        'explore_factor_start': 1.,
        'explore_factor_end': 0.05,
        'explore_factor_decay': 200.,

        # Termination
        'final_threshold': 0.01,
        'final_step_threshold': 100,

        # Run length
        'max_episode': 1000,
        'max_step_per_episode': 1000,
    }

    # --- Network description: currently logged only, it has no programmatic effect ---
    network_params = {
        'state_space_size': 12,
        'layers': [100, 200],
        'action_space_size': 24,
        'activation_func': 'ReLU',
    }

    # --- Physical constants from which the Lame coefficients are derived ---
    thickness = .02          # Thickness of the plate (cm)
    young_modulus = 21E6     # Young Modulus (N/cm^2)
    poisson_ratio = 0.3      # Poisson ratio

    # First Lame coefficient (N/cm^2)
    lame_first = young_modulus * poisson_ratio / ((1 + poisson_ratio) * (1 - 2 * poisson_ratio))
    # Second Lame coefficient (N/cm^2)
    lame_second = young_modulus / (2 * (1 + poisson_ratio))
    # Lame coefficient for plane stress (N/cm^2)
    lame_plane_stress = 2 * lame_first * lame_second / (lame_first + 2 * lame_second)

    # --- FEM parameters ---
    fem_params = {
        # Physical parameters
        'epsilon': thickness,
        'E': young_modulus,
        'nu': poisson_ratio,
        'clambda': lame_first,
        'cmu': lame_second,
        'clambdastar': lame_plane_stress,
        'F': 100E2,             # Force density at the right boundary (N/cm^2)
        'kappa': 4.,            # Thermal conductivity (W/(cm K))
        'D': 10.,               # Heat transfer coefficient (W/(K cm^2))
        'air_temp': 20.,        # Temperature of the air in oC.
        'alpha_th': 16.6E-6,    # Thermal expansion coefficient (/K).
        'T0': 20.,              # Reference temperature in oC.
        'rho_0': 1.754E-8,      # Resistance temperature coefficient at T0 = 20oC
        'alpha': 0.0039,        # Second resistance temperature coefficient.

        # Numerical parameters
        'elements_degree': 2,   # Degree of the finite element methods
        'element_diameter': 2,
    }

    # Record every group on the experiment logger.
    for log_key, group in (('Environment_Parameters', environment_params),
                           ('Hyperparameters', hyper_params),
                           ('Network', network_params),
                           ('FEM_Parameters', fem_params)):
        exp[log_key] = group

    return environment_params, hyper_params, fem_params

# Build the parameter dicts; the call also assigns them to the experiment logger.
ep, hp, fp = parameters()

# Re-wrap each dict as a SimpleNamespace so entries are read as attributes
# (e.g. hp.lr instead of hp['lr']).
ep = SimpleNamespace(**ep)
hp = SimpleNamespace(**hp)
fp = SimpleNamespace(**fp)



## Reinforcement Learning Environment

### Data Classes

In [22]:
# Data Classes

# Environment state: a tensor (GridHoleBoardEnv.get_state returns a clone of its hole-size tensor).
State = Tensor

# ((x, y) grid index of a hole, size change applied to that hole)
Action = Tuple[Tuple[int, int], float]

class Transition(NamedTuple):
    """One RL transition: state, chosen action index, reward and resulting state."""
    state: State
    action_index: int
    reward: float
    next_state: State

class TaskStatus(Enum):
    """Lifecycle status of a FEM subtask; failure-like outcomes use negative values."""
    Pending = 0
    Running = 1
    Successful = 2
    Skipped = -1
    Failed = -2

@dataclass(order=True)
class FEMTask():
    """The FEM simulation work belonging to one RL transition.

    Ordering and equality consider ``step`` only; every other field is
    declared with ``compare=False``.
    """
    episode: int = field(compare=False)
    step: int
    state: State = field(compare=False)
    state_key: str = field(compare=False)
    action_index: int = field(compare=False)
    next_state: State = field(compare=False)
    next_state_key: str = field(compare=False)



class FEMSubtaskLog():
    """Timing and diagnostic record attached to a single :class:`FEMSubtask`.

    The class attributes below are the defaults; workers overwrite them on the
    instance. The ``*_time`` entries hold a ``timedelta`` on success or a
    status string (``'Failed'`` / ``'Skipped'``) otherwise.
    """
    fem_worker: Optional[str] = None
    render_worker: Optional[str] = None
    queue_time: Optional[timedelta] = None
    mesh_time: Union[str, timedelta, None] = None
    solve_time: Union[str, timedelta, None] = None
    render_time: Union[str, timedelta, None] = None
    error_msg: str = ''

@dataclass(order=True)
class FEMSubtask():
    """The FEM simulation of a single state.

    ``episode``, ``step``, ``state`` and ``state_key`` are supplied by the
    caller; the remaining fields are excluded from ``__init__`` and start at
    their defaults. Ordering and equality consider ``step`` only.
    """
    episode: int = field(compare=False)
    step: int
    state: State = field(compare=False)
    state_key: str = field(compare=False)
    successful: bool = field(default=False, compare=False, init=False)
    file_id: str = field(default='', compare=False, init=False)
    vtk_path: str = field(default='', compare=False, init=False)
    img_path: str = field(default='', compare=False, init=False)
    img_tensor: Optional[Tensor] = field(default=None, compare=False, init=False, repr=False)
    # default_factory gives every subtask its own log object. A plain
    # default=FEMSubtaskLog() would be created once at class-definition time and
    # shared (and mutated) by all subtasks living in the same process.
    log: FEMSubtaskLog = field(default_factory=FEMSubtaskLog, compare=False, init=False)

@dataclass(order=True)
class FEMTaskResult():
    """Result of a :class:`FEMTask`: the transition plus the image of each state.

    An image field may be ``None`` (callers construct results with ``None`` and
    fill in images afterwards). Ordering and equality consider ``step`` only.
    """
    episode: int = field(compare=False)
    step: int
    state: State = field(compare=False)
    state_image: Tensor = field(compare=False, repr=False)
    action_index: int = field(compare=False)
    next_state: State = field(compare=False)
    next_state_image: Tensor = field(compare=False, repr=False)

### GridHoleBoardEnv Class

In [23]:
class GridHoleBoardEnv():
    """Rectangular board divided into a grid of cells, each holding one resizable hole.

    The state is the tensor of hole sizes, indexed ``[x, y]``. The action space
    contains two actions per enabled hole: grow by 0.5 and shrink by 0.5.

    Args:
        size: Board extent ``(width, height)``.
        grid_size: Number of cells ``(nx, ny)``.
        holes_disabled: Grid indices ``(x, y)`` of holes that get no actions.
            ``None`` or an empty collection means every hole is adjustable.
    """

    def __init__(self,
                 size: Tuple[float, float],
                 grid_size: Tuple[int, int],
                 holes_disabled: Optional[Set[Tuple[int, int]]] = None) -> None:
        self.size: Tuple[float, float] = size
        self.grid_size: Tuple[int, int] = grid_size
        self.cell_size: Tuple[float, float] = (size[0] / grid_size[0], size[1] / grid_size[1])
        # (x, y) -> size
        self.holes: Tensor = torch.zeros(self.grid_size, device=cuda)

        # (x, y) -> (x_coord, y_coord)
        self.holes_center: Tensor = torch.zeros((*self.grid_size, 2))

        # Always a set (copied from the argument), never the empty-dict literal `{}`.
        self.holes_disabled: Set[Tuple[int, int]] = set(holes_disabled) if holes_disabled else set()

        self.action_space: List[Tuple[Tuple[int, int], float]] = list()

        for x in range(self.grid_size[0]):
            for y in range(self.grid_size[1]):
                # Each hole is centred in its grid cell.
                self.holes_center[x, y, 0] = (x + 0.5) * self.cell_size[0]
                self.holes_center[x, y, 1] = (y + 0.5) * self.cell_size[1]
                if (x, y) not in self.holes_disabled:
                    self.action_space.append(((x, y), 0.5))
                    self.action_space.append(((x, y), -0.5))

    def reset(self) -> None:
        """Set every hole size to 1."""
        self.holes: Tensor = torch.ones(self.grid_size, device=cuda)

    def random(self) -> None:
        """Draw every hole size uniformly from [0, 3)."""
        self.holes: Tensor = torch.rand(self.grid_size, device=cuda) * 3

    def step(self, action: Action) -> None:
        """Apply ``action`` in place: change one hole's size, then clamp it.

        The size is kept within ``[0, min(cell_size) / 2 - 1]``, i.e. the hole
        stays at least 1 unit away from the edge of its cell along the
        smaller cell dimension.
        """
        (x, y), size_change = action

        self.holes[x, y] += size_change

        self.holes[x, y] = torch.clamp(self.holes[x, y], 0., min(self.cell_size) / 2 - 1)

    def get_state(self) -> State:
        """Return a copy of the hole-size tensor (later steps do not modify it)."""
        return self.holes.clone()

## Multiprocess FEM

### FEMConfig Class

### GridHoleBoardFEMTask

In [24]:
def FEMWorkerInitializer(mp_render_queue: Queue,
                         mp_result_cache: Dict[str, str],
                         env_board_size: Tuple[float, float],
                         env_grid_size: Tuple[int, int],
                         env_holes_center: Tensor,
                         shared_fem_parameters: SimpleNamespace) -> None:
    """Pool initializer: publish the shared objects as module-level globals.

    Each argument is bound to a global name (``render_queue``,
    ``result_cache``, ``board_size``, ``grid_size``, ``holes_center``,
    ``fem_parameters``) so that task functions running later in the same
    process can read them without receiving them as arguments.
    """
    global render_queue, result_cache, board_size, grid_size, holes_center, fem_parameters

    # Inter-process channels
    render_queue, result_cache = mp_render_queue, mp_result_cache
    # Board geometry
    board_size, grid_size, holes_center = env_board_size, env_grid_size, env_holes_center
    # FEM configuration
    fem_parameters = shared_fem_parameters

In [25]:
def GridHoleBoardFEMTask(subtask: FEMSubtask, queue_time: datetime) -> None:
    """Mesh the board for ``subtask.state``, solve the FEM model and export the temperature.

    Runs inside a FEM pool worker and relies on the globals installed by
    ``FEMWorkerInitializer`` (``render_queue``, ``board_size``, ``grid_size``,
    ``holes_center``, ``fem_parameters``). All other names are imported locally.

    On every exit path handled here the subtask is put on ``render_queue``:

    * success: ``subtask.file_id`` / ``subtask.vtk_path`` point at the exported
      ``temp/<uuid>.vtk`` file;
    * meshing or solving failure: those two fields keep their empty defaults
      and ``subtask.log`` carries ``'Failed'`` / ``'Skipped'`` markers plus the
      error text.

    Args:
        subtask: State to simulate; its ``log`` is filled in with timings.
        queue_time: Moment the task was submitted, used to measure queue latency.
    """
    import os
    import sys

    # Point fd 1 at /dev/null for the duration of the task (presumably to hide
    # console output from the native FEM library — TODO confirm). The previous
    # target is kept in a duplicate descriptor so it can be restored.
    devnull = open(os.devnull, 'w')
    oldstdout_fno = os.dup(sys.stdout.fileno())
    os.dup2(devnull.fileno(), 1)

    try:
        import multiprocess
        import uuid
        from datetime import datetime

        import getfem as gf
        import numpy as np

        # Creating references for shorter notation
        state = subtask.state
        fp = fem_parameters

        # Timer
        start_time = datetime.now()
        queue_time = datetime.now() - queue_time

        subtask.log.queue_time = queue_time
        subtask.log.fem_worker = multiprocess.current_process().name

        # Generate Mesh
        try:
            board = gf.MesherObject('rectangle', [0., 0.], list(board_size))
            holes = []

            for x in range(grid_size[0]):
                for y in range(grid_size[1]):
                    center = holes_center[x, y].tolist()
                    size = state[x, y].item()

                    # Holes below 1% of the element diameter are left out of the geometry.
                    if size < 0.01 * fp.element_diameter:
                        continue

                    holes.append(gf.MesherObject('ball', center, size))

            if holes:
                holes_union = gf.MesherObject('union', *holes)
                mesher = gf.MesherObject('set minus', board, holes_union)
            else:
                mesher = board

            gf.util('trace level', 0)   # No trace for mesh generation
            mesh = gf.Mesh('generate', mesher, fp.element_diameter, 2)

            boundary = {}

            # Boundary of the holes: outer faces at least 1 unit inside the board edges
            boundary['HOLE_BOUND'] = 1
            mesh.set_region(boundary['HOLE_BOUND'],
                            mesh.outer_faces_in_box([1., 1.],
                                                    [board_size[0] - 1, board_size[1] - 1]))

            boundary['LEFT_BOUND'] = 2
            mesh.set_region(boundary['LEFT_BOUND'], mesh.outer_faces_with_direction([-1., 0.], 0.01))

            boundary['RIGHT_BOUND'] = 3
            mesh.set_region(boundary['RIGHT_BOUND'], mesh.outer_faces_with_direction([ 1., 0.], 0.01))

            boundary['TOP_BOUND'] = 4
            mesh.set_region(boundary['TOP_BOUND'], mesh.outer_faces_with_direction([0.,  1.], 0.01))

            boundary['BOTTOM_BOUND'] = 5
            mesh.set_region(boundary['BOTTOM_BOUND'], mesh.outer_faces_with_direction([0., -1.], 0.01))

            # Hole faces that happen to share a direction with a board edge
            # are removed from the edge regions.
            mesh.region_subtract( boundary['RIGHT_BOUND'], boundary['HOLE_BOUND'])
            mesh.region_subtract(  boundary['LEFT_BOUND'], boundary['HOLE_BOUND'])
            mesh.region_subtract(   boundary['TOP_BOUND'], boundary['HOLE_BOUND'])
            mesh.region_subtract(boundary['BOTTOM_BOUND'], boundary['HOLE_BOUND'])

            # One region per hole (ids 7, 8, ...) and region 6 as the union of all of them.
            region_id = 7
            for x in range(grid_size[0]):
                for y in range(grid_size[1]):
                    center = holes_center[x, y].tolist()
                    size = state[x, y].item()
                    bound_key = f'HOLE{x}_{y}_BOUND'
                    boundary[bound_key] = region_id
                    mesh.set_region(boundary[bound_key],
                                    mesh.outer_faces_in_ball(center, size + 0.01 * fp.element_diameter))
                    if region_id == 7:
                        # First hole: create the union region from it.
                        boundary['HOLE_UNION_BOUND'] = 6
                        mesh.set_region(boundary['HOLE_UNION_BOUND'],
                                        mesh.outer_faces_in_ball(center, size + 0.01 * fp.element_diameter))
                    else:
                        mesh.region_merge(boundary['HOLE_UNION_BOUND'], boundary[bound_key])
                    region_id += 1

            # Sanity check: both ways of selecting the hole faces must agree.
            np.testing.assert_array_equal(mesh.region(boundary['HOLE_BOUND']),
                                          mesh.region(boundary['HOLE_UNION_BOUND']))

            mesh_time = datetime.now() - start_time
            subtask.log.mesh_time = mesh_time
        except Exception as err:
            subtask.log.mesh_time = 'Failed'
            subtask.log.solve_time = 'Skipped'
            subtask.log.error_msg += str(err) + '\n'

            # Forward the failed subtask so the pipeline can report it downstream.
            render_queue.put(subtask)
            return

        # Solve
        try:
            #
            # Definition of finite elements methods and integration method
            #

            mfu = gf.MeshFem(mesh, 2)  # Finite element for the elastic displacement
            mfu.set_classical_fem(fp.elements_degree)
            mft = gf.MeshFem(mesh, 1)  # Finite element for temperature and electrical field
            mft.set_classical_fem(fp.elements_degree)
            mfvm = gf.MeshFem(mesh, 1) # Finite element for Von Mises stress interpolation
            mfvm.set_classical_discontinuous_fem(fp.elements_degree)
            mim = gf.MeshIm(mesh, fp.elements_degree * 2)   # Integration method

            md = gf.Model('real')
            md.add_fem_variable('u', mfu)       # Displacement of the structure
            md.add_fem_variable('theta', mft)   # Temperature
            md.add_fem_variable('V', mft)       # Electric potential

            # Membrane elastic deformation
            md.add_initialized_data('cmu', [fp.cmu])
            md.add_initialized_data('clambdastar', [fp.clambdastar])
            md.add_isotropic_linearized_elasticity_brick(mim, 'u', 'clambdastar', 'cmu')

            md.add_Dirichlet_condition_with_multipliers(mim, 'u', fp.elements_degree - 1, boundary['LEFT_BOUND'])
            md.add_initialized_data('Fdata', [fp.F * fp.epsilon, 0])
            md.add_source_term_brick(mim, 'u', 'Fdata', boundary['RIGHT_BOUND'])

            # Electrical field
            sigmaeps = '(eps/(rho_0*(1+alpha*(theta-T0))))'
            md.add_initialized_data('eps', [fp.epsilon])
            md.add_initialized_data('rho_0', [fp.rho_0])
            md.add_initialized_data('alpha', [fp.alpha])
            md.add_initialized_data('T0', [fp.T0])
            md.add_nonlinear_term(mim, sigmaeps+'*(Grad_V.Grad_Test_V)')
            md.add_Dirichlet_condition_with_multipliers(mim, 'V', fp.elements_degree - 1, boundary['RIGHT_BOUND'])
            md.add_initialized_data('DdataV', [2.])
            md.add_Dirichlet_condition_with_multipliers(mim, 'V', fp.elements_degree - 1, boundary['LEFT_BOUND'], 'DdataV')

            # Thermal problem
            md.add_initialized_data('kaeps', [fp.kappa * fp.epsilon])
            md.add_generic_elliptic_brick(mim, 'theta', 'kaeps')
            md.add_initialized_data('D2', [fp.D * 2])
            md.add_initialized_data('D2airt', [fp.air_temp * fp.D * 2])
            md.add_mass_brick(mim, 'theta', 'D2')
            md.add_source_term_brick(mim, 'theta', 'D2airt')
            md.add_initialized_data('Deps', [fp.D / fp.epsilon])
            md.add_initialized_data('Depsairt', [fp.air_temp * fp.D / fp.epsilon])
            md.add_Fourier_Robin_brick(mim, 'theta', 'Deps', boundary['TOP_BOUND'])
            md.add_source_term_brick(mim, 'theta', 'Depsairt', boundary['TOP_BOUND'])
            md.add_Fourier_Robin_brick(mim, 'theta', 'Deps', boundary['BOTTOM_BOUND'])
            md.add_source_term_brick(mim, 'theta', 'Depsairt', boundary['BOTTOM_BOUND'])

            # Joule heating term
            md.add_nonlinear_term(mim, '-' + sigmaeps + '*Norm_sqr(Grad_V)*Test_theta')

            # Thermal expansion term
            md.add_initialized_data('beta', [fp.alpha_th * fp.E / (1 - 2 * fp.nu)])
            md.add_linear_term(mim, 'beta*(T0-theta)*Trace(Grad_Test_u)')

            #
            # Model solve
            #

            # The displacement variable is disabled, so the solve covers 'theta' and 'V'.
            md.disable_variable('u')
            md.solve('max_res', 1E-9, 'max_iter', 100)

            #
            # Solution export
            #
            THETA = md.variable('theta')

            # Stored as str so the value matches the `str` type declared for FEMSubtask.file_id.
            file_id = str(uuid.uuid4())

            # NOTE(review): the 'temp' directory is not created here — it is
            # assumed to exist already.
            vtk_path = f'temp/{file_id}.vtk'
            mft.export_to_vtk(vtk_path, mft, THETA, 'Temperature')

            solve_time = datetime.now() - start_time - mesh_time
            subtask.log.solve_time = solve_time
        except Exception as err:
            subtask.log.solve_time = 'Failed'
            subtask.log.error_msg += str(err) + '\n'

            render_queue.put(subtask)
            return

        subtask.file_id = file_id
        subtask.vtk_path = vtk_path

        render_queue.put(subtask)
    finally:
        # Single cleanup point for every exit path (including unexpected
        # exceptions): restore fd 1, then release both helper descriptors.
        # Closing the duplicate matters because pool workers run many tasks and
        # would otherwise leak one descriptor per task.
        os.dup2(oldstdout_fno, 1)
        os.close(oldstdout_fno)
        devnull.close()

### RenderTask

In [26]:
def RenderTask(render_queue: Queue, result_queue: Queue,
               result_cache: Dict[str, str], result_image_cache: Dict[str, Tensor],
               device: torch.device) -> None:
    """Render worker loop: turn exported VTK files into cropped PNG images and tensors.

    Consumes subtasks from ``render_queue`` forever and puts each one on
    ``result_queue`` exactly once:

    * a subtask without ``vtk_path`` / ``file_id`` (its FEM stage failed) is
      forwarded with ``render_time = 'Skipped'``;
    * a render failure is forwarded with ``render_time = 'Failed'`` and the
      error text appended to ``log.error_msg``;
    * a success is forwarded with ``img_path``, ``img_tensor`` and
      ``successful = True`` set, after the result was stored in both caches.

    Args:
        render_queue: Source of subtasks produced by the FEM workers.
        result_queue: Sink for processed subtasks.
        result_cache: ``state_key -> file_id`` of rendered results.
        result_image_cache: ``state_key -> image tensor`` of rendered results.
        device: Device the image tensor is moved to.
    """
    from datetime import datetime
    import multiprocess

    from PIL import Image, ImageChops

    from torchvision.transforms import PILToTensor

    import pyvista as pv
    from pyvirtualdisplay.display import Display

    # Render Image
    with Display(visible=0, size=(1280, 1024)):
        # One off-screen plotter is reused for every subtask.
        p = pv.Plotter(off_screen=True, lighting='three lights')
        p.enable_3_lights()
        if p.scalar_bars:
            for sb in list(p.scalar_bars.keys()):
                p.remove_scalar_bar(sb)

        to_tensor = PILToTensor()

        render_worker = multiprocess.current_process().name

        # Starting Pipeline render_queue => this => result_queue
        while True:
            subtask = render_queue.get()

            # Timer starts after the blocking get() so idle waiting is not
            # counted as render time.
            start_time = datetime.now()
            subtask.log.render_worker = render_worker

            vtk_path = subtask.vtk_path
            file_id = subtask.file_id
            state_key = subtask.state_key

            # Skip if FEM task failed in any step
            if vtk_path == '' or file_id == '':
                subtask.log.render_time = 'Skipped'
                result_queue.put(subtask)
                continue

            try:
                m = pv.read(vtk_path)

                a = p.add_mesh(m, line_width=5, cmap='Greys_r', clim=[20, 60], show_scalar_bar=False)
                try:
                    p.view_xy()
                    img_arr = p.screenshot(filename=None, transparent_background=True, window_size=[512, 384])
                finally:
                    # Always detach the mesh so a failed render cannot leak
                    # into the next subtask's image.
                    p.remove_actor(a, render=False)
                    # Try whether deep clean is necessary for large dataset
                    #p.deep_clean()

                # Crop away the uniform background, using the top-left pixel as reference.
                img = Image.fromarray(img_arr)
                bg = Image.new(img.mode, img.size, img.getpixel((0,0)))
                diff = ImageChops.difference(img, bg)
                diff = ImageChops.add(diff, diff, 2.0, -100)
                bbox = diff.getbbox()
                if bbox:
                    img = img.crop(bbox)

                img_path = f'result/{file_id}.png'
                img.save(img_path)

                img_tensor = to_tensor(img).to(device=device)

                result_cache[state_key] = file_id
                result_image_cache[state_key] = img_tensor
            except Exception as err:
                subtask.log.render_time = 'Failed'
                subtask.log.error_msg += str(err) + '\n'

                result_queue.put(subtask)
                # Do not fall through to the success path: it would report the
                # subtask a second time, as successful, with stale values.
                continue

            subtask.log.render_time = datetime.now() - start_time

            subtask.img_path = img_path
            subtask.img_tensor = img_tensor

            subtask.successful = True
            result_queue.put(subtask)

### GridHoleBoardFEMSimulator

In [27]:
class GridHoleBoardFEMSimulator():
    """Coordinate the FEM and render worker pools and cache their results.

    Pipeline: ``queue_transition`` -> FEM pool (``GridHoleBoardFEMTask``) ->
    render queue -> render pool (``RenderTask``) -> result queue ->
    ``get_transition_result``.

    A state is identified by ``state_key`` (string form of its flattened
    values). Before a state is solved, three caches are consulted in order:
    the statuses of the current batch, the shared image cache, and the shared
    ``state_key -> file_id`` cache of images saved under ``result/``.

    Args:
        environment: Board whose geometry is handed to the FEM workers.
        fem_task_pool_size: Number of FEM worker processes.
        render_task_pool_size: Number of render worker processes.
    """

    def __init__(self,
                 environment: GridHoleBoardEnv,
                 *,
                 fem_task_pool_size: int = hp.fem_task_pool_size,
                 render_task_pool_size: int = hp.render_task_pool_size) -> None:
        self.environment = environment
        self.board_size = environment.size
        self.grid_size = environment.grid_size
        self.holes_center = environment.holes_center.cpu()

        self.mp_manager: Manager = Manager()

        self.mp_render_queue: Queue = self.mp_manager.Queue()
        self.mp_result_queue: Queue = self.mp_manager.Queue()

        self.mp_result_cache: Dict[str, str] = self.mp_manager.dict()
        self.mp_result_image_cache: Dict[str, Tensor] = self.mp_manager.dict()

        self.mp_fem_task_pool_size: int = fem_task_pool_size
        self.mp_fem_task_pool: Pool = self._create_fem_task_pool()

        self.mp_render_task_pool_size: int = render_task_pool_size
        self.mp_render_task_pool: Pool = Pool(render_task_pool_size)

        self.task_list: List[FEMTask] = list()
        self.subtask_status: Dict[str, TaskStatus] = dict()    # state_key -> task_status

        self.cache_stat: Dict[str, int] = self._new_cache_stat()

        def render_err_callback(err):
            with open('render_err.log', 'a') as log:
                print('========', file=log)
                print(datetime.now(), file=log)
                print(err, file=log)
            # TODO Restarting render process after failed

        # Start one long-running RenderTask loop per render worker.
        self.mp_render_task_pool.starmap_async(RenderTask,
                                               [(self.mp_render_queue, self.mp_result_queue,
                                                 self.mp_result_cache, self.mp_result_image_cache,
                                                 cuda)]
                                               * self.mp_render_task_pool_size,
                                               error_callback=render_err_callback)

    def _create_fem_task_pool(self) -> Pool:
        """Build a FEM worker pool whose workers are initialised with the shared objects."""
        return Pool(self.mp_fem_task_pool_size,
                    initializer=FEMWorkerInitializer,
                    initargs=(self.mp_render_queue,
                              self.mp_result_cache,
                              self.board_size,
                              self.grid_size,
                              self.holes_center,
                              fp))

    @staticmethod
    def _new_cache_stat() -> Dict[str, int]:
        """Return zeroed counters for where queued states were served from."""
        return {'newly_solved': 0, 'same_episode': 0, 'image_cache': 0, 'file_cache': 0}

    @staticmethod
    def _seconds(value: Any) -> float:
        """Convert a logged duration to seconds; status strings and unset values count as 0."""
        return value.total_seconds() if isinstance(value, timedelta) else 0.

    def clear_image_cache(self):
        """Drop every entry of the shared image cache."""
        self.mp_result_image_cache.clear()

    def terminate_current_queue(self):
        """Terminate the FEM worker pool."""
        self.mp_fem_task_pool.terminate()

    def reset_task_pool(self):
        """Log and reset the cache counters, forget the current batch and start a fresh FEM pool."""
        exp['cache_stat'].log(str(self.cache_stat))
        self.cache_stat = self._new_cache_stat()
        self.task_list.clear()
        self.subtask_status.clear()
        self.mp_fem_task_pool = self._create_fem_task_pool()

    def queue_subtask(self, subtask: FEMSubtask) -> None:
        """Make sure a result for ``subtask.state_key`` is available or on its way.

        Only a state found in none of the caches is submitted to the FEM pool.
        """
        state_key = subtask.state_key

        if state_key in self.subtask_status:
            # Already handled in the current batch.
            exp['state_cache'].log(f'From Same Episode: {state_key}')
            self.cache_stat['same_episode'] += 1
        elif state_key in self.mp_result_image_cache:
            # Fallback to image cache
            exp['state_cache'].log(f'From ImgCache: {state_key}')
            self.cache_stat['image_cache'] += 1
            self.subtask_status[state_key] = TaskStatus.Successful
        elif state_key in self.mp_result_cache:
            # Fallback to file cache
            exp['state_cache'].log(f'From FileCache: {state_key}')
            self.cache_stat['file_cache'] += 1
            self.load_from_cache(state_key)
        else:
            # Run FEM for unseen state
            exp['state_cache'].log(f'Solve New: {state_key}')
            self.cache_stat['newly_solved'] += 1
            self.subtask_status[state_key] = TaskStatus.Pending
            queue_time = datetime.now()

            def err_callback(err):
                with open('fem_err.log', 'a') as log:
                    print('========', file=log)
                    print(datetime.now(), file=log)
                    print(subtask, file=log)
                    print(err, file=log)

            self.mp_fem_task_pool.apply_async(func=GridHoleBoardFEMTask,
                                              args=(subtask, queue_time),
                                              error_callback=err_callback)

    def load_from_cache(self, state_key: str):
        """Reload the saved image of ``state_key`` into the image cache and mark it successful."""
        to_tensor = PILToTensor()
        file_id = self.mp_result_cache[state_key]
        img_path = f'result/{file_id}.png'
        # The context manager closes the file handle once the pixels are converted.
        with Image.open(img_path) as img:
            img_tensor = to_tensor(img).to(device=cuda)
        self.mp_result_image_cache[state_key] = img_tensor
        self.subtask_status[state_key] = TaskStatus.Successful

    def queue_transition(self, episode: int, step: int,
                         state: State, action_index: int, next_state: State) -> None:
        """Queue the simulation of both states of a transition and remember the transition."""
        state_key = str(state.flatten().tolist())
        next_state_key = str(next_state.flatten().tolist())

        self.queue_subtask(FEMSubtask(episode, step, state, state_key))
        self.queue_subtask(FEMSubtask(episode, step, next_state, next_state_key))

        task = FEMTask(episode, step, state, state_key, action_index, next_state, next_state_key)

        self.task_list.append(task)

    def get_transition_result(self) -> Generator[FEMTaskResult, None, None]:
        """Yield one result per queued transition, in queueing order.

        Blocks until neither state of the transition is pending. An image
        field stays ``None`` when the simulation of that state did not succeed.
        """
        for task in self.task_list:
            # Retrieve more subtask results if there is pending subtask in current task
            while self.subtask_status[task.state_key] is TaskStatus.Pending \
                    or self.subtask_status[task.next_state_key] is TaskStatus.Pending:
                self.retrieve_next_task_result()

            result = FEMTaskResult(task.episode, task.step,
                                   task.state, None,
                                   task.action_index,
                                   task.next_state, None)

            if self.subtask_status[task.state_key] is TaskStatus.Successful:
                result.state_image = self.mp_result_image_cache[task.state_key]

            if self.subtask_status[task.next_state_key] is TaskStatus.Successful:
                result.next_state_image = self.mp_result_image_cache[task.next_state_key]

            yield result

    def retrieve_next_task_result(self) -> None:
        """Take one finished subtask from the result queue, log its timings and record its status.

        Waits up to 600 seconds; the queue's timeout exception is not caught here.
        """
        # Read from this instance's queue rather than a module-level simulator object.
        result: FEMSubtask = self.mp_result_queue.get(timeout=600)

        for stage in ('queue_time', 'mesh_time', 'solve_time', 'render_time'):
            exp[f'fem_times/{stage}'].log(self._seconds(getattr(result.log, stage)))

        if result.successful:
            self.subtask_status[result.state_key] = TaskStatus.Successful
        else:
            self.subtask_status[result.state_key] = TaskStatus.Failed

## DQN

### FEM-based Reward & Final Function

In [28]:
class FEMReward():
    """Reward and termination function based on FEM result images.

    The target image is loaded once and reduced to the values found at the
    ``compare_spot`` index pairs.  A state's error is the summed squared
    difference (``nn.MSELoss(reduction='sum')``) between those target values
    and the values at the same positions of the state's FEM image.

    Args:
        environment: stored on the instance; not used by the reward itself.
        target: path of the target image, opened with PIL.
        compare_spot: index pairs into the first channel of the image tensors.
        max_step: accepted but not used anywhere in this class.
        goal_reward: bonus added when the next state's error is within the
            threshold.
        final_state_error_threshold: error at or below which a state counts
            as final.
        out_of_bound_penalty: added to the reward when ``state`` and
            ``next_state`` are equal.
    """

    def __init__(self,
                 environment: GridHoleBoardEnv,
                 target: str,
                 compare_spot: List[Tuple[int, int]],
                 max_step: int = 1000,
                 goal_reward: float = 10000.,
                 final_state_error_threshold: float = 0.01,
                 out_of_bound_penalty: float = -1000) -> None:
        self.environment: GridHoleBoardEnv = environment

        # Scalar settings
        self.goal_reward = goal_reward
        self.final_state_error_threshold = final_state_error_threshold
        self.out_of_bound_penalty = out_of_bound_penalty

        # Only the first channel of the target image is kept
        target_tensor = PILToTensor()(Image.open(target))
        self.target_image: Tensor = target_tensor[0].to(device=cuda, dtype=torch.float)

        # Bool mask that is True exactly at the comparing points
        self.compare_spot = compare_spot
        mask = torch.zeros_like(self.target_image, dtype=torch.bool)
        for spot in compare_spot:
            mask[spot] = True
        self.compare_mask = mask

        self.target_spot = self.target_image[self.compare_mask]
        self.loss = nn.MSELoss(reduction='sum')

    def _spot_error(self, fem_image: Tensor) -> float:
        """Return the summed squared error of ``fem_image`` against the target spots."""
        # Index 0 selects the first entry along the leading axis
        # (presumably the channel axis -- TODO confirm against the renderer).
        spots = fem_image[0][self.compare_mask]
        return float(self.loss(spots, self.target_spot))

    def __call__(self, result: FEMTaskResult) -> Tuple[Optional[float], bool]:
        """Compute ``(reward, final)`` for one transition result.

        Returns ``(None, False)`` when the state has no FEM image.  ``final``
        describes the *current* state (its error is within the threshold),
        whereas the goal bonus is granted on the *next* state's error.  When
        the next state has no FEM image, only the out-of-bound penalty (if
        any) is returned and nothing is logged.

        NOTE(review): ``log`` is not defined in the visible part of the
        notebook; it is assumed to be a module-level logger -- confirm.
        """
        state_image = result.state_fem_image
        if state_image is None:
            return None, False

        error = self._spot_error(state_image)
        final = error <= self.final_state_error_threshold

        reward = 0.
        # An unchanged state means the move was blocked by the environment,
        # so the out-of-bound penalty is applied
        if torch.equal(result.state, result.next_state):
            reward += self.out_of_bound_penalty

        next_state_image = result.next_state_fem_image
        if next_state_image is None:
            return reward, final

        next_error = self._spot_error(next_state_image)

        # Positive when the move brought the design closer to the target
        reward += error - next_error

        # Extra points if the next state is final
        if next_error <= self.final_state_error_threshold:
            reward += self.goal_reward

        log.step(result.episode, result.step)
        log['rewards/reward'] = reward
        log['rewards/error'] = error

        return reward, final

### Network Definition

In [29]:
# Network Container
class Model():
    """Bundle a network with the loss function and optimizer used to train it.

    The container is callable: calling it forwards the input to the wrapped
    network and returns the network's output.  ``loss_func`` and ``optimizer``
    are stored exactly as given (``QNet`` passes ``None`` for both when it
    builds a target network).
    """

    def __init__(self, network: nn.Module, loss_func: _Loss, optimizer: Optimizer):
        self.network, self.loss_func, self.optimizer = network, loss_func, optimizer

    def __call__(self, network_input: Tensor) -> Tensor:
        output = self.network(network_input)
        return output

In [30]:
def QNet(state_size: int = 12, action_number: int = 24, target_network: bool = False):
    """Build the Q-value network and wrap it in a ``Model``.

    The network is a three-layer perceptron (``state_size`` -> 100 -> 200 ->
    ``action_number``) with ReLU activations, created on the ``cuda`` device.

    Args:
        state_size: length of the flattened state vector.
        action_number: number of actions, i.e. outputs of the network.
        target_network: when True the ``Model`` carries no loss function and
            no optimizer; otherwise it gets ``SmoothL1Loss`` and Adam with a
            learning rate of 0.001.
    """
    layers = [
        nn.Linear(state_size, 100, device=cuda),
        nn.ReLU(),
        nn.Linear(100, 200, device=cuda),
        nn.ReLU(),
        nn.Linear(200, action_number, device=cuda),
    ]
    net = nn.Sequential(*layers)

    if not target_network:
        loss_func = nn.SmoothL1Loss()
        optimizer = torch.optim.Adam(net.parameters(), 0.001)
        return Model(network=net, loss_func=loss_func, optimizer=optimizer)

    return Model(network=net, loss_func=None, optimizer=None)

### Replay Memory Class

In [31]:
class ReplayMemory():
    """Fixed-capacity buffer of ``Transition`` records.

    Backed by a ``deque`` with ``maxlen=capacity``, so appending to a full
    buffer drops the oldest record.
    """

    def __init__(self, capacity):
        self.memory: deque = deque(maxlen=capacity)

    def push(self, *args):
        """Build a ``Transition`` from ``args`` and append it to the buffer."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct records chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

### Agent Class

In [32]:
class Agent():
    """Deep Q-learning agent whose rewards come from batched FEM simulations.

    Each training episode has three phases (see ``train``): the agent takes a
    fixed number of steps and queues every transition on the FEM simulator;
    the finished simulations are turned into rewards and pushed into the
    replay memory; finally the Q-network is optimised on mini-batches sampled
    from that memory.

    Args:
        environment: board environment providing states and the action space.
        fem_simulator: simulator on which transitions are queued and from
            which results are read back.
        fem_reward_func: callable that receives one simulator result and
            returns ``(reward, final)``; a ``None`` reward means the result
            cannot be used.
        q_network: trainable ``Model`` (its ``loss_func`` and ``optimizer``
            are used by ``optimize``).
        target_network: ``Model`` used to compute bootstrapped targets.
        target_update_interval: the target network is synchronised on every
            episode whose index is a multiple of this value.
        optimization_iter: number of ``optimize`` calls per episode.
        max_step_per_episode: number of environment steps per episode.
        experience_replay: replay buffer; when omitted, a new
            ``ReplayMemory(hp.experience_replay_size)`` is created for this
            agent.
        replay_batch_size: mini-batch size used by ``optimize``.
        discount_factor: discount applied to the next state's value.
        explore_factor_start: exploration probability at step 0.
        explore_factor_end: exploration probability approached as steps grow.
        explore_factor_decay: decay constant (in steps) of the exponential
            schedule between the two.

    Attributes:
        running_loss: loss value of every optimisation iteration, in order.
        explored_step: number of actions selected so far.
    """

    def __init__(self,
                 environment: GridHoleBoardEnv,
                 fem_simulator: GridHoleBoardFEMSimulator,
                 fem_reward_func: Callable[..., Tuple[Optional[float], bool]],
                 q_network: Model,
                 target_network: Model,
                 target_update_interval: int = hp.target_update_interval,
                 optimization_iter: int = hp.optimization_iter,
                 max_step_per_episode: int = hp.max_step_per_episode,
                 experience_replay: Optional[ReplayMemory] = None,
                 replay_batch_size: int = hp.replay_batch_size,
                 discount_factor: float = hp.discount_factor,
                 explore_factor_start: float = hp.explore_factor_start,
                 explore_factor_end: float = hp.explore_factor_end,
                 explore_factor_decay: float = hp.explore_factor_decay,
                ) -> None:
        self.environment: GridHoleBoardEnv = environment
        self.fem_simulator: GridHoleBoardFEMSimulator = fem_simulator
        self.fem_reward_func: Callable[..., Tuple[Optional[float], bool]] = fem_reward_func

        self.q_network: Model = q_network
        self.target_network: Model = target_network
        self.target_update_interval: int = target_update_interval

        self.optimization_iter: int = optimization_iter
        self.max_step_per_episode: int = max_step_per_episode
        # A default built in the signature would be created once and shared by
        # every Agent; build a private buffer per instance instead.  The
        # explicit None test matters because an empty ReplayMemory is falsy.
        if experience_replay is None:
            experience_replay = ReplayMemory(hp.experience_replay_size)
        self.experience_replay: ReplayMemory = experience_replay
        self.replay_batch_size: int = replay_batch_size

        self.discount_factor: float = discount_factor
        self.explore_factor_start: float = explore_factor_start
        self.explore_factor_end: float = explore_factor_end
        self.explore_factor_decay: float = explore_factor_decay
        self.explored_step: int = 0

        # One entry per optimisation iteration; read by the loss plot cell
        self.running_loss: List[float] = []

    def select_action(self, disable_explore: bool = False) -> Tuple[State, int]:
        """Pick an action for the current state with an epsilon-greedy policy.

        The exploration probability decays exponentially from
        ``explore_factor_start`` to ``explore_factor_end`` with the number of
        actions selected so far, and is logged to ``exp`` on every call.

        Args:
            disable_explore: always take the greedy action when True.

        Returns:
            The current state and the index of the chosen action.
        """
        state = self.environment.get_state()
        explore_factor = self.explore_factor_end \
                            + (self.explore_factor_start - self.explore_factor_end) \
                            * math.exp(-1. * self.explored_step / self.explore_factor_decay)
        exp['explore_factor'].log(explore_factor, step=self.explored_step)
        # random.random() is drawn first on purpose so the random stream does
        # not depend on disable_explore
        if random.random() > explore_factor or disable_explore:
            # Action selection needs no gradients
            with torch.no_grad():
                prediction = self.q_network(state.flatten())
            action_index = prediction.argmax().item()
        else:
            action_index = random.randrange(len(self.environment.action_space))
        self.explored_step += 1
        return state, action_index

    def step(self, episode: int, step: int, disable_explore: bool = False) -> bool:
        """Take one environment step and queue the transition for FEM evaluation.

        Args:
            episode: episode index, forwarded to the simulator.
            step: step index inside the episode, forwarded to the simulator.
            disable_explore: forwarded to ``select_action``.

        Returns:
            True when ``step`` has reached ``max_step_per_episode``.  Whether
            the state is final is only known after the FEM results are read.
        """
        state, action_index = self.select_action(disable_explore)
        action = self.environment.action_space[action_index]

        self.environment.step(action)
        next_state = self.environment.get_state()

        action_log = f'{action} ({state[action[0]]} => {next_state[action[0]]})'
        exp['actions'].log(action_log)

        self.fem_simulator.queue_transition(episode, step, state, action_index, next_state)

        return step >= self.max_step_per_episode

    def complete_fem_rewarding(self, episode: int) -> None:
        """Turn finished FEM results into rewards and store them for replay.

        Results whose reward is ``None`` are skipped.  A transition flagged as
        final is stored with ``next_state=None``, the simulator's current
        queue is terminated and processing stops.

        Args:
            episode: index of the episode being processed (currently unused).
        """
        simulator = self.fem_simulator
        print('Start computing rewards based on completed FEM simulation')
        stored = 0

        for result in simulator.get_transition_result():
            reward, final = self.fem_reward_func(result)
            if reward is None:
                continue

            next_state = None if final else result.next_state
            self.experience_replay.push(result.state, result.action_index, reward, next_state)

            stored += 1
            # Report progress once every 100 stored transitions
            if stored % 100 == 0:
                print(f'{stored}/{self.max_step_per_episode}', end='\r')

            if final:
                simulator.terminate_current_queue()
                print('Early stop triggered')
                return

    def optimize(self) -> None:
        """Run one optimisation iteration on a sampled mini-batch.

        Does nothing until the replay memory holds at least
        ``replay_batch_size`` transitions.  The loss is computed between
        ``Q(state, action)`` and ``reward + discount_factor * max_a
        Q_target(next_state, a)``, where the second term is zero for
        transitions without a next state.  Gradients are clipped element-wise
        to [-1, 1] and the loss value is appended to ``running_loss``.
        """
        if len(self.experience_replay) < self.replay_batch_size:
            return

        samples = self.experience_replay.sample(self.replay_batch_size)
        batch = Transition(*zip(*samples))

        non_final_mask = torch.tensor(tuple(s is not None for s in batch.next_state),
                                      device=cuda, dtype=torch.bool)
        non_final_next_states = [s.flatten() for s in batch.next_state if s is not None]

        state_batch = torch.stack([s.flatten() for s in batch.state])
        action_batch = torch.tensor(batch.action_index, device=cuda).unsqueeze(1)
        reward_batch = torch.tensor(batch.reward, device=cuda)

        state_action_values = self.q_network(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(self.replay_batch_size, device=cuda)
        # torch.stack rejects an empty list, so a batch made only of terminal
        # transitions keeps all-zero next-state values
        if non_final_next_states:
            with torch.no_grad():
                next_state_values[non_final_mask] = \
                    self.target_network(torch.stack(non_final_next_states)).max(1)[0]

        expected_state_action_values = (next_state_values * self.discount_factor) + reward_batch

        loss = self.q_network.loss_func(state_action_values, expected_state_action_values.unsqueeze(1))
        self.q_network.optimizer.zero_grad()
        loss.backward()
        # Clamp every gradient element to [-1, 1]; parameters without a
        # gradient are skipped
        nn.utils.clip_grad_value_(self.q_network.network.parameters(), 1.0)
        self.q_network.optimizer.step()

        self.running_loss.append(loss.item())

    def update_target_network(self) -> None:
        """Copy the Q-network's weights into the target network."""
        self.target_network.network.load_state_dict(self.q_network.network.state_dict())

    def train(self, episode: int = hp.max_episode) -> None:
        """Train for ``episode`` episodes.

        Every episode resets the environment, takes ``max_step_per_episode``
        steps, converts the FEM results into replay transitions, runs
        ``optimization_iter`` optimisation iterations, clears the simulator's
        image cache and task pool, and synchronises the target network when
        the episode index is a multiple of ``target_update_interval``.
        """
        for e in range(episode):
            start_time = datetime.now()
            self.environment.reset()
            # Use fixed number of steps in each episode
            for step_num in range(self.max_step_per_episode):
                self.step(e, step_num)
            print(f'Episode {e} completed')
            self.complete_fem_rewarding(e)
            for _ in range(self.optimization_iter):
                self.optimize()
            self.fem_simulator.clear_image_cache()
            self.fem_simulator.reset_task_pool()
            if e % self.target_update_interval == 0:
                self.update_target_network()

            end_time = datetime.now()
            print(f'Episode {e} completed in {end_time - start_time}')

    def generate(self, step: int = 100) -> State:
        """Placeholder for greedy design generation; currently returns None."""
        pass
        # TODO: rewrite
        # self.environment.reset()
        # for i in range(step):
        #     if self.step(100000000, i):
        #         print(f'Completed in {step} steps                    ', end='\r\n')
        #         break
        # return self.environment.get_state(), i

## Training

In [33]:
# Environment and the FEM simulator that evaluates its states
env = GridHoleBoardEnv(size=ep.size, grid_size=ep.grid_size)
fem = GridHoleBoardFEMSimulator(env)

# The eight compare spots form a regular 2 x 4 grid of index pairs into the
# target image: first index in (73, 147), second in (36, 110, 184, 258).
reward_func = FEMReward(
    env,
    'target.png',
    [(first, second) for second in (36, 110, 184, 258) for first in (73, 147)],
)

# Trainable Q-network first, then the target network
agent = Agent(env, fem, reward_func, QNet(), QNet(target_network=True))

In [None]:
# Agent.generate() is still a stub, so the before/after samples stay disabled.
#print(agent.generate())

# Run 1000 training episodes; per-episode timing is printed by Agent.train
agent.train(episode=1000)

#print(agent.generate())

Episode 0 completed
Start computing rewards based on completed FEM simulation
Episode 0 completed in 0:01:43.933093
Episode 1 completed
Start computing rewards based on completed FEM simulation
Episode 1 completed in 0:00:21.685933
Episode 2 completed
Start computing rewards based on completed FEM simulation
Episode 2 completed in 0:00:36.146539
Episode 3 completed
Start computing rewards based on completed FEM simulation
Episode 3 completed in 0:00:19.423703
Episode 4 completed
Start computing rewards based on completed FEM simulation
Episode 4 completed in 0:00:16.995210
Episode 5 completed
Start computing rewards based on completed FEM simulation
Episode 5 completed in 0:00:15.272616
Episode 6 completed
Start computing rewards based on completed FEM simulation
Episode 6 completed in 0:00:40.447813
Episode 7 completed
Start computing rewards based on completed FEM simulation
Episode 7 completed in 0:00:15.187500
Episode 8 completed
Start computing rewards based on completed FEM simul

In [None]:
# Scatter plot of the training loss, one point per optimisation iteration.
# The series is read defensively: if the agent does not expose `running_loss`
# the plot is simply empty instead of raising AttributeError.
running_loss = list(getattr(agent, 'running_loss', []))

fig, ax = plt.subplots()
# vmin/vmax only scale colour mapping and are ignored when no `c` is given,
# so they are not passed.
ax.scatter(range(len(running_loss)), running_loss, s=1)
ax.set(title='Q-network training loss',
       xlabel='Optimisation iteration',
       ylabel='Loss')
plt.show()

In [None]:
# Stop the Neptune run created by neptune.init at the top of the notebook;
# otherwise it is only stopped when the kernel terminates.
exp.stop()

## Manuscript

In [None]:
# Scratch: display the string form of an integer
a = str(123)
a

In [5]:
# Scratch check: torch.equal compares contents, so two separately built
# tensors with the same values are reported equal.  FEMReward uses the same
# call to detect a transition that left the state unchanged.
a = torch.tensor([1, 2])
b = torch.tensor([1, 2])
torch.equal(a, b)


True

In [None]:
# Scratch: zero-initialised counters; nothing in the visible part of the
# notebook reads or updates them -- TODO confirm whether this is still needed.
cache_stat: Dict[str, int] = {'solved': 0, 'image_cache': 0, 'file_cache': 0}

In [None]:
# Scratch: memory figures for GPU 0, all in bytes
t = torch.cuda.get_device_properties(0).total_memory  # total memory of the device
r = torch.cuda.memory_reserved(0)  # memory held by PyTorch's caching allocator
a = torch.cuda.memory_allocated(0)  # memory currently occupied by tensors
f = r-a  # free inside reserved