# Reinforcement Learning- and FEM-based Inverse Design 

## Import

In [1]:
from typing import Union, Optional, Callable, Any
from typing import Tuple, List, Set, Dict
from typing import NamedTuple

In [2]:
from collections import defaultdict, deque

In [3]:
import os
import sys
import time
from datetime import datetime
import random
import math
import itertools
import uuid

The `multiprocess` package is used instead of the built-in `multiprocessing` module to support multiprocessing in a Jupyter environment.

In [4]:
import multiprocess
from multiprocess import Process, Pool, Queue, Manager

In [5]:
import numpy as np
import pandas as pd
from PIL import Image, ImageChops

In [6]:
import matplotlib.pyplot as plt

In [7]:
import getfem as gf

initializing ...
numthread = 1


In [8]:
import pyvista as pv
from pyvirtualdisplay.display import Display

In [9]:
import torch
from torch import nn

from torch import Tensor

from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer

from torchvision.transforms import PILToTensor

print(torch.__version__)

1.10.0


## Computing Devices

In [10]:
multiprocess.cpu_count()

56

In [11]:
# Total physical memory in bytes, computed as page size * number of physical pages (os.sysconf)
mem_bytes = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES')  # e.g. 4015976448
mem_gib = mem_bytes/(1024.**3)
mem_gib

376.29273986816406

In [12]:
available_gpus = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
available_gpus

['Tesla V100S-PCIE-32GB', 'Tesla V100S-PCIE-32GB']

In [13]:
# Compute device used by the rest of the notebook.
# NOTE: despite its name, `cuda` holds the CPU device when no CUDA GPU is available.
cuda = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Current computing device:', cuda)

Current computing device: cuda


In [14]:
# Worker processes are started with the 'spawn' method.
# force=True keeps this cell re-runnable: without it, executing the cell a second
# time in the same kernel raises "RuntimeError: context has already been set".
multiprocess.set_start_method('spawn', force=True)

## Debug Helpers

In [15]:
def debug(*args):
    """Append one line to ``debug.log``, formatted the way ``print(*args)`` would.

    The file is opened in append mode and closed again on every call.
    """
    with open('debug.log', mode='a') as log_file:
        print(*args, file=log_file)


# Mark the beginning of this kernel session in the log.
debug('New Session Started Here')

## Log Cleanup

In [16]:
!bash clean.sh

rm: cannot remove 'fem_err.log': No such file or directory
rm: cannot remove 'render_err.log': No such file or directory
rm: cannot remove 'reward.log': No such file or directory


## Reinforcement Learning Environment

### Data Classes

In [17]:
# Data Classes

# A state is the tensor returned by GridHoleBoardEnv.get_state(): one hole size per grid cell.
State = Tensor

# An action is ((x, y) grid index of the hole, signed change applied to that hole's size).
Action = Tuple[Tuple[int, int], float]


class Transition(NamedTuple):
    """One experience record as stored by ReplayMemory.push()."""
    state: State          # state before the action
    action_index: int     # index of the chosen action
    reward: float         # reward received for the move
    next_state: State     # state after the action


class FEMTask(NamedTuple):
    """A state/next-state pair together with the string keys of both states.

    NOTE(review): presumably describes a transition whose FEM results are still
    pending (the Agent keeps a ``Dict[int, FEMTask]`` buffer) - confirm against the Agent code.
    """
    state: State
    state_key: str
    action_index: int
    next_state: State
    next_state_key: str

### GridHoleBoardEnv Class

In [18]:
class GridHoleBoardEnv():
    """Board divided into a regular grid, with one adjustable hole per grid cell.

    The environment state is the tensor of hole sizes (``self.holes``), indexed by
    the (x, y) grid position. The action space holds two actions per enabled
    hole: grow its size by 0.5 or shrink it by 0.5.
    """

    def __init__(self, 
                 size: Tuple[float, float], 
                 grid_size: Tuple[int, int],
                 holes_disabled: Optional[Set[Tuple[int, int]]] = None) -> None:
        """
        Args:
            size: Board extent (width, height).
            grid_size: Number of grid cells along x and along y.
            holes_disabled: Grid positions (x, y) whose holes get no actions.
                ``None`` (or an empty collection) disables nothing.
        """
        self.size: Tuple[float, float] = size
        self.grid_size: Tuple[int, int] = grid_size
        self.cell_size: Tuple[float, float] = (size[0] / grid_size[0], size[1] / grid_size[1])

        # (x, y) -> size
        self.holes: Tensor = torch.zeros(self.grid_size, device=cuda)

        # (x, y) -> (x_coord, y_coord); created without a device argument, unlike `holes`.
        self.holes_center: Tensor = torch.zeros((*self.grid_size, 2))

        # Always a real set. The previous default `{}` was an empty dict, which
        # only worked by accident for the membership test below.
        self.holes_disabled: Set[Tuple[int, int]] = set(holes_disabled) if holes_disabled else set()

        self.action_space: List[Tuple[Tuple[int, int], float]] = list()

        for x in range(self.grid_size[0]):
            for y in range(self.grid_size[1]):
                # Each hole is centred in its grid cell.
                self.holes_center[x, y, 0] = (x + 0.5) * self.cell_size[0]
                self.holes_center[x, y, 1] = (y + 0.5) * self.cell_size[1]
                if (x, y) not in self.holes_disabled:
                    self.action_space.append(((x, y), 0.5))
                    self.action_space.append(((x, y), -0.5))

    def reset(self) -> None:
        """Set every hole size to 1."""
        self.holes = torch.ones(self.grid_size, device=cuda)

    def random(self) -> None:
        """Draw every hole size uniformly from [0, 3)."""
        self.holes = torch.rand(self.grid_size, device=cuda) * 3

    def step(self, action: Action) -> None:
        """Apply ``action`` = ((x, y), size_change) to the addressed hole.

        The resulting size is clamped to ``[0, min(cell_size) / 2 - 1]``.
        """
        (x, y), size_change = action

        self.holes[x, y] += size_change

        self.holes[x, y] = torch.clamp(self.holes[x, y], 0., min(self.cell_size) / 2 - 1)

    def get_state(self) -> State:
        """Return a copy of the hole-size tensor, so callers cannot alias the live state."""
        return self.holes.clone()

## FEM

### FEMPhysic Class

In [19]:
class FEMPhysic():
    """Constants of the FEM model: physical data plus the finite element degree.

    All values are plain class attributes; the derived Lame coefficients are
    computed once, when the class body is executed.
    """

    # ---------------------------------------------------------------
    # Physical parameters
    # ---------------------------------------------------------------
    epsilon = 0.02          # Thickness of the plate (cm)
    E = 21e6                # Young Modulus (N/cm^2)
    nu = 0.3                # Poisson ratio

    # First Lame coefficient (N/cm^2)
    clambda = E * nu / ((1 + nu) * (1 - 2 * nu))
    # Second Lame coefficient (N/cm^2)
    cmu = E / (2 * (1 + nu))
    # Lame coefficient for Plane stress (N/cm^2)
    clambdastar = 2 * clambda * cmu / (clambda + 2 * cmu)

    F = 100e2               # Force density at the right boundary (N/cm^2)
    kappa = 4.0             # Thermal conductivity (W/(cm K))
    D = 10.0                # Heat transfer coefficient (W/(K cm^2))
    air_temp = 20.0         # Temperature of the air in oC.
    alpha_th = 16.6e-6      # Thermal expansion coefficient (/K).
    T0 = 20.0               # Reference temperature in oC.
    rho_0 = 1.754e-8        # Resistance temperature coefficient at T0 = 20oC
    alpha = 0.0039          # Second resistance temperature coefficient.

    # ---------------------------------------------------------------
    # Numerical parameters
    # ---------------------------------------------------------------
    elements_degree = 2     # Degree of the finite element methods

### GridHoleBoardFEMTask

In [20]:
def GridHoleBoardFEMTask(render_queue: Queue,
                         result_cache: Dict[str, str],
                         task_id: int,
                         state_key: str,
                         state: State,
                         board_size: Tuple[float, float],
                         grid_size: Tuple[int, int],
                         holes_center: Tensor,
                         element_diameter: float,
                         fem_physic: FEMPhysic,
                         queue_time: datetime) -> None:
    """Mesh the perforated board for one state, solve the FEM model and export the temperature field.

    For every outcome handled here exactly one tuple
    ``(task_id, state_key, vtk_path, file_id, fem_worker, queue_time, mesh_time, solve_time)``
    is put on ``render_queue``:

    * cache hit   - path/id rebuilt from ``result_cache``; both times are ``'Skipped'``.
      A cached ``None`` (the state failed earlier) is forwarded as ``vtk_path=None``.
    * mesh failed - ``vtk_path`` and ``file_id`` are ``None``, ``mesh_time='Failed'``.
    * solve failed - ``vtk_path`` and ``file_id`` are ``None``, ``solve_time='Failed'``.
    * success     - ``temp/<uuid>.vtk`` holding the ``Temperature`` field.

    Mesh and solve failures are additionally appended to ``fem_err.log``.

    Args:
        render_queue: Queue consumed by the render workers.
        result_cache: Shared mapping state_key -> file id of an earlier result.
        task_id: Identifier forwarded unchanged to the render queue.
        state_key: Cache key of ``state``.
        state: Hole sizes indexed by (x, y); used as the ball radius of each hole.
        board_size: Board extent (width, height).
        grid_size: Number of holes along x and y.
        holes_center: Hole centre coordinates, indexed ``[x, y] -> (x_coord, y_coord)``.
        element_diameter: Target element size handed to the mesh generator.
        fem_physic: Physical and numerical constants of the model.
        queue_time: Moment the task was queued; used to report the waiting time.
    """
    # Imports are done inside the function; NOTE(review): presumably so the task is
    # self-contained when executed in a spawned worker process - confirm.
    import os
    import sys
    import multiprocess

    # Point file descriptor 1 at /dev/null while the task runs, so the console
    # output of GetFEM (it prints on import, see the import cell) is discarded.
    devnull = open(os.devnull, 'w')
    oldstdout_fno = os.dup(sys.stdout.fileno())
    os.dup2(devnull.fileno(), 1)

    try:
        import uuid
        from datetime import datetime

        import getfem as gf
        import numpy as np

        start_time = datetime.now()
        queue_time = datetime.now() - queue_time

        fem_worker = multiprocess.current_process().name

        def report(vtk_path, file_id, mesh_time, solve_time):
            # Single place that defines the tuple layout expected by RenderTask.
            render_queue.put((task_id, state_key,
                              vtk_path, file_id,
                              fem_worker, queue_time, mesh_time, solve_time))

        def log_failure(stage, err):
            # Keep a trace of the swallowed exception; the caller only sees 'Failed'.
            with open('fem_err.log', 'a') as log:
                print('========', file=log)
                print(datetime.now(), file=log)
                print('Task', task_id, file=log)
                print('Key', state_key, file=log)
                print('Stage', stage, file=log)
                print('Error', repr(err), file=log)

        # Skip if already cached
        if state_key in result_cache:
            cached_file_id = result_cache[state_key]
            if cached_file_id is None:
                # The state is known to be unsolvable/unrenderable: do not
                # fabricate a 'temp/None.vtk' path for it.
                report(None, None, 'Skipped', 'Skipped')
            else:
                report(f'temp/{cached_file_id}.vtk', cached_file_id, 'Skipped', 'Skipped')
            return

        # Generate Mesh
        try:
            board = gf.MesherObject('rectangle', [0., 0.], list(board_size))
            holes: List[gf.MesherObject] = list()

            for x in range(grid_size[0]):
                for y in range(grid_size[1]):
                    center = holes_center[x, y].tolist()
                    size = state[x, y].item()

                    # Holes much smaller than an element are left out of the geometry.
                    if size < 0.01 * element_diameter:
                        continue

                    holes.append(gf.MesherObject('ball', center, size))

            if holes:
                holes_union = gf.MesherObject('union', *holes)
                mesher = gf.MesherObject('set minus', board, holes_union)
            else:
                mesher = board

            gf.util('trace level', 0)   # No trace for mesh generation
            mesh = gf.Mesh('generate', mesher, element_diameter, 2)

            boundary: Dict[str, int] = dict()

            # Boundary of the holes: outer faces inside the board shrunk by 1 on every side.
            boundary['HOLE_BOUND'] = 1
            mesh.set_region(boundary['HOLE_BOUND'],
                            mesh.outer_faces_in_box([1., 1.],
                                                    [board_size[0] - 1, board_size[1] - 1]))

            boundary['LEFT_BOUND'] = 2
            mesh.set_region(boundary['LEFT_BOUND'], mesh.outer_faces_with_direction([-1., 0.], 0.01))

            boundary['RIGHT_BOUND'] = 3
            mesh.set_region(boundary['RIGHT_BOUND'], mesh.outer_faces_with_direction([ 1., 0.], 0.01))

            boundary['TOP_BOUND'] = 4
            mesh.set_region(boundary['TOP_BOUND'], mesh.outer_faces_with_direction([0.,  1.], 0.01))

            boundary['BOTTOM_BOUND'] = 5
            mesh.set_region(boundary['BOTTOM_BOUND'], mesh.outer_faces_with_direction([0., -1.], 0.01))

            # Faces that belong to a hole must not count as part of a board edge.
            mesh.region_subtract( boundary['RIGHT_BOUND'], boundary['HOLE_BOUND'])
            mesh.region_subtract(  boundary['LEFT_BOUND'], boundary['HOLE_BOUND'])
            mesh.region_subtract(   boundary['TOP_BOUND'], boundary['HOLE_BOUND'])
            mesh.region_subtract(boundary['BOTTOM_BOUND'], boundary['HOLE_BOUND'])

            # One region per hole (ids 7, 8, ...) plus region 6 as the union of all of them.
            region_id = 7
            for x in range(grid_size[0]):
                for y in range(grid_size[1]):
                    center = holes_center[x, y].tolist()
                    size = state[x, y].item()
                    bound_key = f'HOLE{x}_{y}_BOUND'
                    boundary[bound_key] = region_id
                    mesh.set_region(boundary[bound_key],
                                    mesh.outer_faces_in_ball(center, size + 0.01 * element_diameter))
                    if region_id == 7:
                        # First hole: it initialises the union region.
                        boundary['HOLE_UNION_BOUND'] = 6
                        mesh.set_region(boundary['HOLE_UNION_BOUND'],
                                        mesh.outer_faces_in_ball(center, size + 0.01 * element_diameter))
                    else:
                        mesh.region_merge(boundary['HOLE_UNION_BOUND'], boundary[bound_key])
                    region_id += 1

            # Sanity check: the box-based and the ball-based hole boundaries must agree,
            # otherwise the mesh is treated as failed.
            np.testing.assert_array_equal(mesh.region(boundary['HOLE_BOUND']),
                                          mesh.region(boundary['HOLE_UNION_BOUND']))
            mesh_time = datetime.now() - start_time
        except Exception as err:
            report(None, None, 'Failed', 'Skipped')
            log_failure('mesh', err)
            return

        # Solve
        try:
            fp = fem_physic

            #
            # Definition of finite elements methods and integration method
            #

            mfu = gf.MeshFem(mesh, 2)  # Finite element for the elastic displacement
            mfu.set_classical_fem(fp.elements_degree)
            mft = gf.MeshFem(mesh, 1)  # Finite element for temperature and electrical field
            mft.set_classical_fem(fp.elements_degree)
            mfvm = gf.MeshFem(mesh, 1) # Finite element for Von Mises stress interpolation
            mfvm.set_classical_discontinuous_fem(fp.elements_degree)
            mim = gf.MeshIm(mesh, fp.elements_degree * 2)   # Integration method

            md = gf.Model('real')
            md.add_fem_variable('u', mfu)       # Displacement of the structure
            md.add_fem_variable('theta', mft)   # Temperature
            md.add_fem_variable('V', mft)       # Electric potential

            # Membrane elastic deformation
            md.add_initialized_data('cmu', [fp.cmu])
            md.add_initialized_data('clambdastar', [fp.clambdastar])
            md.add_isotropic_linearized_elasticity_brick(mim, 'u', 'clambdastar', 'cmu')

            md.add_Dirichlet_condition_with_multipliers(mim, 'u', fp.elements_degree - 1, boundary['LEFT_BOUND'])
            md.add_initialized_data('Fdata', [fp.F * fp.epsilon, 0])
            md.add_source_term_brick(mim, 'u', 'Fdata', boundary['RIGHT_BOUND'])

            # Electrical field
            sigmaeps = '(eps/(rho_0*(1+alpha*(theta-T0))))'
            md.add_initialized_data('eps', [fp.epsilon])
            md.add_initialized_data('rho_0', [fp.rho_0])
            md.add_initialized_data('alpha', [fp.alpha])
            md.add_initialized_data('T0', [fp.T0])
            md.add_nonlinear_term(mim, sigmaeps+'*(Grad_V.Grad_Test_V)')
            md.add_Dirichlet_condition_with_multipliers(mim, 'V', fp.elements_degree - 1, boundary['RIGHT_BOUND'])
            md.add_initialized_data('DdataV', [2.])
            md.add_Dirichlet_condition_with_multipliers(mim, 'V', fp.elements_degree - 1, boundary['LEFT_BOUND'], 'DdataV')

            # Thermal problem
            md.add_initialized_data('kaeps', [fp.kappa * fp.epsilon])
            md.add_generic_elliptic_brick(mim, 'theta', 'kaeps')
            md.add_initialized_data('D2', [fp.D * 2])
            md.add_initialized_data('D2airt', [fp.air_temp * fp.D * 2])
            md.add_mass_brick(mim, 'theta', 'D2')
            md.add_source_term_brick(mim, 'theta', 'D2airt')
            md.add_initialized_data('Deps', [fp.D / fp.epsilon])
            md.add_initialized_data('Depsairt', [fp.air_temp * fp.D / fp.epsilon])
            md.add_Fourier_Robin_brick(mim, 'theta', 'Deps', boundary['TOP_BOUND'])
            md.add_source_term_brick(mim, 'theta', 'Depsairt', boundary['TOP_BOUND'])
            md.add_Fourier_Robin_brick(mim, 'theta', 'Deps', boundary['BOTTOM_BOUND'])
            md.add_source_term_brick(mim, 'theta', 'Depsairt', boundary['BOTTOM_BOUND'])

            # Joule heating term
            md.add_nonlinear_term(mim, '-' + sigmaeps + '*Norm_sqr(Grad_V)*Test_theta')

            # Thermal expansion term
            md.add_initialized_data('beta', [fp.alpha_th * fp.E / (1 - 2 * fp.nu)])
            md.add_linear_term(mim, 'beta*(T0-theta)*Trace(Grad_Test_u)')

            #
            # Model solve
            #

            # The displacement variable is disabled before solving; only the
            # temperature field is exported below.
            md.disable_variable('u')
            md.solve('max_res', 1E-9, 'max_iter', 100)

            #
            # Solution export
            #
            THETA = md.variable('theta')

            file_id = uuid.uuid4()

            vtk_path = f'temp/{file_id}.vtk'
            mft.export_to_vtk(vtk_path, mft, THETA, 'Temperature')

            solve_time = datetime.now() - start_time - mesh_time
        except Exception as err:
            report(None, None, mesh_time, 'Failed')
            log_failure('solve', err)
            return

        report(vtk_path, file_id, mesh_time, solve_time)
    finally:
        # Always give stdout back to the worker and release both descriptors.
        # Pool workers are reused, so a leaked descriptor per task would add up.
        os.dup2(oldstdout_fno, 1)
        os.close(oldstdout_fno)
        devnull.close()

### RenderTask

In [21]:
def RenderTask(render_queue: Queue, result_queue: Queue, 
               result_cache: Dict[str, str], result_image_cache: Dict[str, Tensor],
               device: torch.device) -> None:
    """Render worker loop: turn exported VTK files into image tensors.

    Runs forever. Each iteration takes one tuple
    ``(task_id, state_key, vtk_path, file_id, fem_worker, queue_time, mesh_time, solve_time)``
    from ``render_queue`` and puts
    ``(task_id, state_key, vtk_path, img_path, img_tensor)`` on ``result_queue``:

    * ``vtk_path is None`` (FEM failed): the failure is stored in both caches as
      ``None`` and all three result fields are ``None``.
    * ``state_key`` already in ``result_image_cache``: the cached tensor is returned.
    * otherwise the mesh is rendered, saved to ``result/<file_id>.png``, cropped to
      its content, converted to a float tensor on ``device`` and cached. A rendering
      error is written to ``render_err.log`` and reported as ``None`` image.

    Every iteration appends a record to ``fem.log``.

    Args:
        render_queue: Source of finished FEM tasks.
        result_queue: Destination of rendered results.
        result_cache: Shared mapping state_key -> file id.
        result_image_cache: Shared mapping state_key -> image tensor.
        device: Device the image tensors are moved to.
    """
    from datetime import datetime
    import multiprocess

    from PIL import Image, ImageChops

    import torch
    from torchvision.transforms import PILToTensor

    import pyvista as pv
    from pyvirtualdisplay.display import Display

    def write_log(entries):
        # One record of fem.log: separator, timestamp, then one "label value" line per entry.
        with open('fem.log', 'a') as log:
            print('========', file=log)
            print(datetime.now(), file=log)
            for label, value in entries:
                print(label, value, file=log)

    # Render Image
    with Display(visible=0, size=(1280, 1024)):
        # A single off-screen plotter is reused for every task of this worker.
        p = pv.Plotter(off_screen=True, lighting='three lights')
        p.enable_3_lights()
        if p.scalar_bars:
            for sb in list(p.scalar_bars.keys()):
                p.remove_scalar_bar(sb)

        to_tensor = PILToTensor()

        render_worker = multiprocess.current_process().name

        # Starting Pipeline render_queue => this => result_queue
        while True:
            task_id, state_key, vtk_path, file_id, fem_worker, queue_time, mesh_time, solve_time = render_queue.get()

            # Start the clock only after a task arrived, so the reported render
            # time does not include the time spent waiting on the queue.
            start_time = datetime.now()

            skipped_entries = [('Task', task_id),
                               ('Key', state_key),
                               ('Queue time', queue_time),
                               ('Mesh time', mesh_time),
                               ('Solve time', solve_time),
                               ('Render time', 'Skipped')]

            if vtk_path is None:
                # FEM stage failed: remember the failure and pass it on.
                result_cache[state_key] = None
                result_image_cache[state_key] = None
                result_queue.put((task_id, state_key, None, None, None))
                write_log(skipped_entries)
                continue

            # A single lookup instead of "in" followed by "[]": the cache can be
            # cleared from another process between the two calls.
            try:
                cached_image = result_image_cache[state_key]
                image_is_cached = True
            except KeyError:
                cached_image = None
                image_is_cached = False

            if image_is_cached:
                result_queue.put((task_id, state_key, vtk_path, f'result/{file_id}.png', cached_image))
                write_log(skipped_entries)
                continue

            try:
                m = pv.read(vtk_path)

                a = p.add_mesh(m, line_width=5, cmap='Greys_r', clim=[20, 60], show_scalar_bar=False)
                try:
                    p.view_xy()

                    img_path = f'result/{file_id}.png'

                    img_arr = p.screenshot(filename=img_path, transparent_background=True, window_size=[512, 384])
                finally:
                    # Remove the mesh even when the screenshot fails; otherwise it
                    # would stay in the shared plotter and show up in later renders.
                    p.remove_actor(a, render=False)
                # Try whether deep clean is necessary for large dataset
                #p.deep_clean()

                # Crop the screenshot to its content: difference against a
                # canvas filled with the colour of the top-left pixel.
                img = Image.fromarray(img_arr)
                bg = Image.new(img.mode, img.size, img.getpixel((0,0)))
                diff = ImageChops.difference(img, bg)
                diff = ImageChops.add(diff, diff, 2.0, -100)
                bbox = diff.getbbox()
                if bbox:
                    img = img.crop(bbox)
                img_tensor = to_tensor(img).to(device=device, dtype=torch.float)

                result_cache[state_key] = file_id
                result_image_cache[state_key] = img_tensor

                render_time = datetime.now() - start_time

            except Exception as e:
                img_path = None
                img_tensor = None
                render_time = 'Failed'
                result_cache[state_key] = None
                result_image_cache[state_key] = None
                with open('render_err.log', 'a') as log:
                    print('========', file=log)
                    print(datetime.now(), file=log)
                    print('Task', task_id, file=log)
                    print('Key', state_key, file=log)
                    print('Temp file', vtk_path, file=log)
                    print('Error', e, file=log)

            result_queue.put((task_id, state_key, vtk_path, img_path, img_tensor))

            write_log([('Task', task_id),
                       ('Key', state_key),
                       ('Temp file', vtk_path),
                       ('Result file', img_path),
                       ('FEM worker', fem_worker),
                       ('Render worker', render_worker),
                       ('Queue time', queue_time),
                       ('Mesh time', mesh_time),
                       ('Solve time', solve_time),
                       ('Render time', render_time)])

### GridHoleBoardFEMSimulator

In [22]:
class GridHoleBoardFEMSimulator():
    """Owns the process pools, queues and shared caches of the FEM -> render pipeline.

    FEM tasks are submitted with ``queue()``; they put their output on the render
    queue, from which the render workers (started in ``__init__``) produce the
    entries of ``mp_result_queue``.
    """

    def __init__(self, 
                 environment: GridHoleBoardEnv, 
                 *, 
                 element_diameter: float = 2, 
                 fem_task_pool_size: int = 32,
                 render_task_pool_size: int = 8) -> None:
        """
        Args:
            environment: Board whose size, grid and hole centres are simulated.
            element_diameter: Mesh element size passed to every FEM task.
            fem_task_pool_size: Number of FEM worker processes.
            render_task_pool_size: Number of render worker processes.
        """
        # Geometry taken over from the environment.
        self.environment = environment
        self.board_size = environment.size
        self.grid_size = environment.grid_size
        self.holes_center = environment.holes_center.cpu()

        self.element_diameter: float = element_diameter
        self.fem_physic: FEMPhysic = FEMPhysic()

        # Manager that serves the shared queues and dictionaries below.
        self.mp_manager: Manager = Manager()

        # Worker pools.
        self.mp_fem_task_pool_size: int = fem_task_pool_size
        self.mp_fem_task_pool: Pool = Pool(fem_task_pool_size)

        self.mp_render_task_pool_size: int = render_task_pool_size
        self.mp_render_task_pool: Pool = Pool(render_task_pool_size)

        # Pipeline queues: FEM workers -> render workers -> consumer.
        self.mp_render_queue: Queue = self.mp_manager.Queue()
        self.mp_result_queue: Queue = self.mp_manager.Queue()

        # Shared caches, both keyed by state key.
        self.mp_result_cache: Dict[str, str] = self.mp_manager.dict()
        self.mp_result_image_cache: Dict[str, Tensor] = self.mp_manager.dict()

        def log_render_error(err):
            with open('render_err.log', 'a') as log:
                print('========', file=log)
                print(datetime.now(), file=log)
                print(err, file=log)
            # TODO Restarting render process after failed

        # One RenderTask call per render worker, all with the same arguments.
        render_args = (self.mp_render_queue, self.mp_result_queue,
                       self.mp_result_cache, self.mp_result_image_cache,
                       cuda)
        self.mp_render_task_pool.starmap_async(
            RenderTask,
            [render_args for _ in range(self.mp_render_task_pool_size)],
            error_callback=log_render_error)

    def clear_image_cache(self):
        """Drop every cached image tensor."""
        self.mp_result_image_cache.clear()

    def terminate_current_queue(self):
        """Terminate the FEM pool."""
        self.mp_fem_task_pool.terminate()

    def reset_task_pool(self):
        """Replace the FEM pool with a new one of the configured size."""
        self.mp_fem_task_pool = Pool(self.mp_fem_task_pool_size)

    def queue(self, task_id: int, state_key: str, state: State) -> None:
        """Submit one FEM task for ``state`` (copied to the CPU) to the FEM pool."""
        queue_time = datetime.now()

        def log_fem_error(err):
            with open('fem_err.log', 'a') as log:
                print('========', file=log)
                print(datetime.now(), file=log)
                print('Task', task_id, file=log)
                print('Key', state_key, file=log)
                print('State', state, file=log)
                print(err, file=log)

        task_args = (self.mp_render_queue,
                     self.mp_result_cache,
                     task_id,
                     state_key, state.cpu(),
                     self.board_size, self.grid_size, self.holes_center,
                     self.element_diameter, self.fem_physic, queue_time)
        self.mp_fem_task_pool.apply_async(func=GridHoleBoardFEMTask,
                                          args=task_args,
                                          error_callback=log_fem_error)

    def complete(self):
        """Close the FEM pool and wait until all of its queued tasks have finished."""
        self.mp_fem_task_pool.close()
        self.mp_fem_task_pool.join()

## DQN 

### Hyperparameters

In [23]:
# NOTE(review): the first five names look like DQN settings (exploration rate and its
# decay, learning rate, replay batch size, target-network sync interval); their use
# is not visible in this part of the notebook - confirm.
epsilon = 1.
epsilon_decay = .995
lr = .001
replay_batch_size = 32
target_update_interval = 10

# Termination limits read by MockFinalFunc: an episode ends once the summed squared
# error to the target is below `final_threshold` or `final_step_threshold` steps are reached.
final_threshold = 0.01
final_step_threshold = 100

### Mock Functions

In [24]:
# Hole sizes the mock functions treat as the goal (12 entries).
target_state = torch.tensor([0, 3, 0, 3, 
                             0, 3, 0, 3, 
                             0, 3, 0, 3], device=cuda)


def MockRewardFunc(state: State, action_index: int, next_state: State) -> float:
    """Reward = decrease of the summed squared error to ``target_state`` caused by the move."""
    sse = nn.MSELoss(reduction='sum')
    error_before = float(sse(state.flatten(), target_state))
    error_after = float(sse(next_state.flatten(), target_state))
    return error_before - error_after


def MockFinalFunc(episode: int, step: int, state: State, action_index: int, reward: float, next_state: State):
    """Return the transition tuple, with ``next_state`` replaced by ``None`` on termination.

    The transition is terminal when the error of ``state`` is below
    ``final_threshold`` or ``step`` has reached ``final_step_threshold``.
    """
    sse = nn.MSELoss(reduction='sum')
    target_reached = float(sse(state.flatten(), target_state)) < final_threshold
    if target_reached or step >= final_step_threshold:
        next_state = None
    return state, action_index, reward, next_state


### FEM-based Reward & Final Function

In [25]:
class FEMReward():
    """Reward function that compares rendered FEM images with a target image at selected pixels.

    NOTE(review): ``max_step`` (constructor) and ``step`` (call) are accepted but
    never used by this class - confirm whether a step limit was intended here.
    """

    def __init__(self, 
                 target: str, 
                 compare_spot: List[Tuple[int, int]],
                 max_step: int = 1000,
                 goal_reward: float = 10000.,
                 final_state_threshold: float = 0.01,
                 invalid_state_penalty: float = -100.,
                 revalid_state_reward: float = 50.) -> None:
        """
        Args:
            target: Path of the target image; only its first channel is used.
            compare_spot: Pixel indices at which images are compared.
            max_step: Currently unused.
            goal_reward: Bonus added when the next state is within the threshold.
            final_state_threshold: Error at or below which a state counts as the goal.
            invalid_state_penalty: Reward for moving into, or staying in, a state without image.
            revalid_state_reward: Reward for moving from a state without image to one with an image.
        """
        target_tensor = PILToTensor()(Image.open(target))
        self.target_image = target_tensor[0].to(device=cuda, dtype=torch.float)

        # Create a bool mask for extracting the comparing points
        self.compare_spot = compare_spot
        self.compare_mask = torch.zeros(self.target_image.size(), dtype=torch.bool, device=cuda)
        for pixel in compare_spot:
            self.compare_mask[pixel] = True

        # Target values at the comparison pixels.
        self.target_spot = self.target_image[self.compare_mask]

        self.loss = nn.MSELoss(reduction='sum')

        self.goal_reward = goal_reward
        self.final_state_threshold = final_state_threshold
        self.invalid_state_penalty = invalid_state_penalty
        self.revalid_state_reward = revalid_state_reward

    def __call__(self, step: int, state_fem_image: Optional[Tensor], 
                 action_index: int, next_state_fem_image: Optional[Tensor]):
        """Return ``(reward, final)`` for one move and append a line to ``reward.log``.

        An image of ``None`` marks a state without a usable FEM result.
        ``final`` is True when the current state's error is within the threshold.
        """
        has_state = state_fem_image is not None
        has_next = next_state_fem_image is not None

        final = False
        if has_state:
            error = float(self.loss(state_fem_image[0][self.compare_mask], self.target_spot))
            final = error <= self.final_state_threshold

        reward = 0.
        if has_next:
            next_error = float(self.loss(next_state_fem_image[0][self.compare_mask], self.target_spot))
            if next_error <= self.final_state_threshold:
                reward += self.goal_reward

        if has_state and has_next:
            # Regular loss-based reward for valid-to-valid move
            reward += error - next_error
        elif has_next:
            # Give a one-time reward for backing to a valid state, should be small
            reward += self.revalid_state_reward
        else:
            # Penalty for stepping into, or remaining in, an invalid state (unsolvable fem)
            reward += self.invalid_state_penalty

        # One tab-terminated column each for reward, error and next error;
        # a missing error is written as six spaces.
        columns = [
            f'{reward:.3f}',
            f'{error:.3f}' if has_state else '      ',
            f'{next_error:.3f}' if has_next else '      ',
        ]
        with open('reward.log', 'a') as log:
            for column in columns:
                print(column, end='\t', file=log)
            print(file=log)
        return reward, final

### Network Definition

In [26]:
# Network Container
class Model():
    """Groups a network with its loss function and optimizer; calling the object runs the network."""

    def __init__(self, network: nn.Module, loss_func: _Loss, optimizer: Optimizer):
        self.network, self.loss_func, self.optimizer = network, loss_func, optimizer

    def __call__(self, network_input: Tensor) -> Tensor:
        """Forward ``network_input`` through ``self.network`` and return its output."""
        network_output = self.network(network_input)
        return network_output

In [27]:
def QNet(state_size: int = 12, action_number: int = 24, target_network: bool = False,
         hidden_sizes: Tuple[int, int] = (100, 200), learning_rate: float = 0.001):
    """Build the Q-network (two hidden ReLU layers) wrapped in a ``Model``.

    Args:
        state_size: Number of input features.
        action_number: Number of outputs, one Q-value per action.
        target_network: If True, the model gets neither a loss function nor an
            optimizer.
        hidden_sizes: Widths of the first and second hidden layer. The default
            reproduces the previously hard-coded 100 and 200.
        learning_rate: Adam learning rate of the trainable network. The default
            reproduces the previously hard-coded 0.001.

    Returns:
        A ``Model`` whose layers are created on the ``cuda`` device.
    """
    first_hidden, second_hidden = hidden_sizes
    net = nn.Sequential(
        nn.Linear(state_size, first_hidden, device=cuda),
        nn.ReLU(),
        nn.Linear(first_hidden, second_hidden, device=cuda),
        nn.ReLU(),
        nn.Linear(second_hidden, action_number, device=cuda),
    )
    if target_network:
        return Model(network=net, loss_func=None, optimizer=None)
    return Model(network=net,
                 loss_func=nn.SmoothL1Loss(),
                 optimizer=torch.optim.Adam(net.parameters(), learning_rate))

### Replay Memory Class

In [28]:
class ReplayMemory():
    """Bounded FIFO buffer of ``Transition`` records; the oldest entry is dropped when full."""

    def __init__(self, capacity):
        self.memory: deque = deque(maxlen=capacity)

    def push(self, *args):
        """Build a ``Transition`` from ``args`` and append it to the buffer."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen at random."""
        return random.sample(self.memory, k=batch_size)

    def __len__(self):
        """Number of records currently stored."""
        stored = len(self.memory)
        return stored

### Agent Class

In [29]:
class Agent():
    """Q-learning agent whose rewards are computed from deferred FEM results.

    During an episode the agent only acts and queues FEM simulations for the
    states it visits (``step``).  After the episode, ``complete_fem_rewarding``
    turns the finished simulations into rewards and stores the transitions in
    the replay memory, from which ``optimize`` trains the Q-network.
    """

    def __init__(self,
                 environment: GridHoleBoardEnv, 
                 fem_simulator: GridHoleBoardFEMSimulator,
                 fem_reward_func: Callable[[int, Tensor, int, Tensor], Tuple[float, bool]],
                 q_network: Model, 
                 target_network: Model, 
                 target_update_interval: int = 100, 
                 optimization_epoch: int = 1,
                 experience_replay: Optional[ReplayMemory] = None, 
                 replay_batch_size: int = 32, 
                 discount_factor: float = 0.9,
                 explore_factor_start: float = 1,
                 explore_factor_end: float = 0.05,
                 explore_factor_decay: float = 200,
                ) -> None:
        """Store the collaborators and hyper-parameters.

        Args:
            environment: Environment the agent acts in.
            fem_simulator: Simulator that FEM tasks are queued on.
            fem_reward_func: Called as ``(step, state_image, action_index,
                next_state_image)``; returns ``(reward, final)``.
            q_network: Trained network; its ``loss_func`` and ``optimizer``
                are used by ``optimize``.
            target_network: Network used to evaluate next-state values.
            target_update_interval: The target network is synchronised every
                this many episodes.
            optimization_epoch: Number of ``optimize`` calls per episode.
            experience_replay: Replay memory to use.  ``None`` (default)
                creates a private ``ReplayMemory(10000)`` for this agent, so
                agents no longer share one default instance.
            replay_batch_size: Number of transitions per optimisation batch.
            discount_factor: Discount applied to next-state values.
            explore_factor_start: Initial probability of a random action.
            explore_factor_end: Asymptotic probability of a random action.
            explore_factor_decay: Decay constant (in selected actions) of the
                exploration probability.
        """
        self.environment: GridHoleBoardEnv = environment
        self.fem_simulator: GridHoleBoardFEMSimulator = fem_simulator
        self.fem_reward_func: Callable[[int, Tensor, int, Tensor], Tuple[float, bool]] = fem_reward_func

        self.q_network: Model = q_network
        self.target_network: Model = target_network
        self.target_update_interval: int = target_update_interval

        self.optimization_epoch: int = optimization_epoch
        # A default argument is evaluated once at definition time; building the
        # memory here gives every agent its own buffer.
        if experience_replay is None:
            experience_replay = ReplayMemory(10000)
        self.experience_replay: ReplayMemory = experience_replay
        self.replay_batch_size: int = replay_batch_size

        self.discount_factor: float = discount_factor
        self.explore_factor_start: float = explore_factor_start
        self.explore_factor_end: float = explore_factor_end
        self.explore_factor_decay: float = explore_factor_decay
        self.explored_step: int = 0

        # Transitions waiting for their FEM results, keyed by step number.
        self.fem_task_buffer: Dict[int, FEMTask] = dict()

        # Loss of every optimisation step, in order.
        self.running_loss: List[float] = []

    def select_action(self, disable_explore: bool = False) -> Tuple[State, int]:
        """Pick an action for the current state, epsilon-greedily.

        The exploration probability decays exponentially from
        ``explore_factor_start`` to ``explore_factor_end`` with the number of
        actions selected so far.

        Args:
            disable_explore: When True, always take the greedy action.

        Returns:
            ``(state, action_index)`` where ``action_index`` is a plain ``int``
            index into ``environment.action_space``.
        """
        state = self.environment.get_state()
        decay = math.exp(-1. * self.explored_step / self.explore_factor_decay)
        explore_factor = self.explore_factor_end \
            + (self.explore_factor_start - self.explore_factor_end) * decay
        # random.random() is evaluated first on purpose so the random stream is
        # consumed identically whether or not exploration is disabled.
        if random.random() > explore_factor or disable_explore:
            with torch.no_grad():
                prediction = self.q_network(state.flatten())
            # Convert to int so the value matches the annotation and no tensor
            # ends up stored in the replay memory.
            action_index = int(prediction.argmax())
        else:
            action_index = random.randrange(len(self.environment.action_space))
        self.explored_step += 1
        return state, action_index

    def step(self, episode: int, step: int, disable_explore: bool = False, max_step: int = 1000) -> bool:
        """Take one action and queue FEM simulations for both states involved.

        The transition is parked in ``fem_task_buffer`` under ``step`` until
        ``complete_fem_rewarding`` can compute its reward.

        Args:
            episode: Episode number (currently unused).
            step: Step number within the episode; also used as the task id.
            disable_explore: Forwarded to ``select_action``.
            max_step: Step number from which the episode counts as finished.

        Returns:
            True once ``step >= max_step``.
        """
        state, action_index = self.select_action(disable_explore)

        self.environment.step(self.environment.action_space[action_index])
        next_state = self.environment.get_state()

        # The stringified flattened state serves as the simulation key.
        state_key = str(state.flatten().tolist())
        next_state_key = str(next_state.flatten().tolist())

        task = FEMTask(state, state_key, action_index, next_state, next_state_key)
        self.fem_task_buffer[step] = task

        self.fem_simulator.queue(step, state_key, state)
        self.fem_simulator.queue(step, next_state_key, next_state)

        # Return True if the next state is final
        return step >= max_step

    def complete_fem_rewarding(self) -> None:
        """Turn finished FEM results into rewards and store the transitions.

        Results are read from the simulator's result queue until every
        buffered transition has been handled.  A transition is handled once
        both of its state keys are present in the simulator's image cache; it
        is dropped without being stored if either cached image is ``None``.
        If the reward function reports a final transition, the remaining
        tasks of the episode are terminated.  The buffer is emptied on exit.
        """
        fem = self.fem_simulator    # Create local reference to shorter statements
        print('Start computing rewards based on completed FEM simulation')
        received = 0
        while len(self.fem_task_buffer) > 0:
            if received % 100 == 0:
                print(f'Received result: {received} Remaining transition: {len(self.fem_task_buffer)}')

            # Only the task id is needed here; images are read from the cache.
            task_id, _state_key, _vtk_path, _img_path, _img_tensor = fem.mp_result_queue.get()
            received += 1

            task = self.fem_task_buffer.get(task_id)
            if task is None:
                # Result for a transition that was already handled.
                continue

            step = task_id

            cache = fem.mp_result_image_cache
            if task.state_key not in cache or task.next_state_key not in cache:
                # The other half of this transition has not arrived yet.
                continue

            # Remove transition from wait list
            self.fem_task_buffer.pop(task_id)

            state_image = cache[task.state_key]
            next_state_image = cache[task.next_state_key]
            if state_image is None or next_state_image is None:
                continue

            reward, final = self.fem_reward_func(step, state_image, task.action_index, next_state_image)

            if final:
                # A final transition is stored with next_state=None.
                self.experience_replay.push(task.state, task.action_index, reward, None)
                fem.terminate_current_queue()
                print('Early stop triggered, terminating FEM tasks in current episode')
                break
            self.experience_replay.push(task.state, task.action_index, reward, task.next_state)

        self.fem_task_buffer.clear()

    def optimize(self) -> None:
        """Run one optimisation step of the Q-network on a replayed batch.

        Does nothing until the replay memory holds at least
        ``replay_batch_size`` transitions.  The target is
        ``reward + discount_factor * max_a Q_target(next_state, a)``, with a
        next-state value of zero for final transitions.  Gradients are clamped
        to ``[-1, 1]`` before the optimizer step, and the loss is appended to
        ``running_loss``.
        """
        # Was a bare `replay_batch_size`, which is not defined in this scope.
        if len(self.experience_replay) < self.replay_batch_size:
            return

        samples = self.experience_replay.sample(self.replay_batch_size)
        batch = Transition(*zip(*samples))

        non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                      device=cuda, dtype=torch.bool)
        non_final_next_states = [s.flatten() for s in batch.next_state if s is not None]

        state_batch = torch.stack([s.flatten() for s in batch.state])
        action_batch = torch.tensor(batch.action_index, device=cuda).unsqueeze(1)
        reward_batch = torch.tensor(batch.reward, device=cuda)

        # Q(s, a) for the actions that were actually taken.
        state_action_values = self.q_network(state_batch).gather(1, action_batch)

        next_state_values = torch.zeros(self.replay_batch_size, device=cuda)
        # torch.stack raises on an empty list, so skip the target network when
        # every sampled transition is final.
        if non_final_next_states:
            with torch.no_grad():
                next_state_values[non_final_mask] = \
                    self.target_network(torch.stack(non_final_next_states)).max(1)[0]

        expected_state_action_values = (next_state_values * self.discount_factor) + reward_batch

        loss = self.q_network.loss_func(state_action_values, expected_state_action_values.unsqueeze(1))
        self.q_network.optimizer.zero_grad()
        loss.backward()
        self.running_loss.append(float(loss))
        # Element-wise clamp of every gradient to [-1, 1].
        torch.nn.utils.clip_grad_value_(self.q_network.network.parameters(), 1)
        self.q_network.optimizer.step()

    def update_target_network(self) -> None:
        """Copy the Q-network's weights into the target network."""
        self.target_network.network.load_state_dict(self.q_network.network.state_dict())

    def train(self, episode: int = 1000, step_per_episode: int = 1000) -> None:
        """Train for ``episode`` episodes of ``step_per_episode`` steps each.

        Per episode: reset the environment, take a fixed number of steps
        (the boolean returned by ``step`` is ignored), compute the deferred
        rewards, run ``optimization_epoch`` optimisation steps, clear the
        simulator's image cache and task pool, and synchronise the target
        network every ``target_update_interval`` episodes (including episode 0).
        """
        for e in range(episode):
            start_time = datetime.now()
            self.environment.reset()
            # Use fixed number of steps in each episode
            for step_num in range(step_per_episode):
                self.step(e, step_num)
            print(f'Episode {e} completed')
            self.complete_fem_rewarding()
            for _ in range(self.optimization_epoch):
                self.optimize()
            self.fem_simulator.clear_image_cache()
            self.fem_simulator.reset_task_pool()
            if e % self.target_update_interval == 0:
                self.update_target_network()

            end_time = datetime.now()
            print(f'Episode {e} completed in {end_time - start_time}')

    def generate(self, step: int = 100) -> Tuple[State, int]:
        """Reset the environment and let the agent act for up to ``step`` steps.

        Returns:
            ``(state, last_step)``: the environment state after the rollout and
            the index of the last step taken (``-1`` if ``step`` is 0).
        """
        # NOTE(review): exploration stays enabled here and every step also
        # queues FEM tasks via self.step — confirm that this is intended.
        self.environment.reset()
        last_step = -1  # stays -1 when the loop body never runs
        for last_step in range(step):
            if self.step(100000000, last_step):
                print(f'Completed in {step} steps                    ', end='\r\n')
                break
        return self.environment.get_state(), last_step

## Training

In [30]:
# Environment and FEM simulator used for training.
env = GridHoleBoardEnv(grid_size=(4, 3), size=(80, 60))

fem = GridHoleBoardFEMSimulator(env, render_task_pool_size=8, fem_task_pool_size=32)

# Eight points on a 2 x 4 lattice, listed row by row (x varies fastest).
# NOTE(review): presumably pixel coordinates in target.png — confirm against FEMReward.
reward_func = FEMReward(
    'target.png',
    [(x, y) for y in (36, 110, 184, 258) for x in (73, 147)],
)

# The Q-network and the target network share the same layer sizes.
agent = Agent(
    env,
    fem,
    reward_func,
    QNet(state_size=12, action_number=24),
    QNet(state_size=12, action_number=24, target_network=True),
)

In [None]:
# Optional rollout before training:
#print(agent.generate())

# Train for 1000 episodes with the default number of steps per episode.
agent.train(episode=1000)

# Optional rollout after training:
#print(agent.generate())

Episode 0 completed
Start computing rewards based on completed FEM simulation
Received result: 0 Remaining transition: 1000
Received result: 100 Remaining transition: 951
Received result: 200 Remaining transition: 900
Received result: 300 Remaining transition: 850
Received result: 400 Remaining transition: 799
Received result: 500 Remaining transition: 749
Received result: 600 Remaining transition: 699
Received result: 700 Remaining transition: 649
Received result: 800 Remaining transition: 600
Received result: 900 Remaining transition: 549
Received result: 1000 Remaining transition: 499
Received result: 1100 Remaining transition: 449
Received result: 1200 Remaining transition: 399
Received result: 1300 Remaining transition: 349
Received result: 1400 Remaining transition: 299
Received result: 1500 Remaining transition: 248
Received result: 1600 Remaining transition: 200
Received result: 1700 Remaining transition: 148
Received result: 1800 Remaining transition: 99
Received result: 1900 



Received result: 1100 Remaining transition: 449
Received result: 1200 Remaining transition: 399
Received result: 1300 Remaining transition: 350
Received result: 1400 Remaining transition: 301
Received result: 1500 Remaining transition: 250
Received result: 1600 Remaining transition: 201
Received result: 1700 Remaining transition: 150
Received result: 1800 Remaining transition: 101
Received result: 1900 Remaining transition: 51
Received result: 2000 Remaining transition: 2
Episode 19 completed in 0:14:36.414716
Episode 20 completed
Start computing rewards based on completed FEM simulation
Received result: 0 Remaining transition: 1000
Received result: 100 Remaining transition: 950
Received result: 200 Remaining transition: 900
Received result: 300 Remaining transition: 849
Received result: 400 Remaining transition: 798
Received result: 500 Remaining transition: 749
Received result: 600 Remaining transition: 699
Received result: 700 Remaining transition: 648
Received result: 800 Remaining



Received result: 1900 Remaining transition: 49




Received result: 2000 Remaining transition: 1
Episode 24 completed in 0:11:08.417564
Episode 25 completed
Start computing rewards based on completed FEM simulation
Received result: 0 Remaining transition: 1000
Received result: 100 Remaining transition: 948
Received result: 200 Remaining transition: 899
Received result: 300 Remaining transition: 851
Received result: 400 Remaining transition: 800
Received result: 500 Remaining transition: 747
Received result: 600 Remaining transition: 699
Received result: 700 Remaining transition: 650
Received result: 800 Remaining transition: 598
Received result: 900 Remaining transition: 549
Received result: 1000 Remaining transition: 499
Received result: 1100 Remaining transition: 446
Received result: 1200 Remaining transition: 400
Received result: 1300 Remaining transition: 349
Received result: 1400 Remaining transition: 299
Received result: 1500 Remaining transition: 249
Received result: 1600 Remaining transition: 200
Received result: 1700 Remaining

In [None]:
# Scatter of the loss recorded after every optimisation step.
fig, ax = plt.subplots()
# vmin/vmax only affect colour mapping and have no effect without `c`, so they
# are dropped.  NOTE(review): if the intent was to clip the y-axis to [0, 2],
# add ax.set_ylim(0, 2).
ax.scatter(range(len(agent.running_loss)), agent.running_loss, s=1)
ax.set_xlabel('Optimization step')
ax.set_ylabel('Loss')
ax.set_title('Q-network training loss')
plt.show()

## Manuscript

In [None]:
# Build a fresh environment (same size/grid arguments as the training setup),
# apply one action and capture the resulting state for the cells below.
# NOTE: this rebinds the notebook-level names `env` and `fem` that the
# training cells created.
env = GridHoleBoardEnv(size=(80, 60), grid_size=(4, 3))

env.reset()
# NOTE(review): the action looks like ((grid index pair), value) — confirm
# against the environment's action format.
env.step(((1, 1), 2.5))

state = env.get_state()

# Simulator constructed with its default pool sizes.
fem = GridHoleBoardFEMSimulator(env)

In [None]:
# String form of the flattened state — the same expression Agent.step uses to build its state keys.
str(state.flatten().tolist())

In [None]:
# Size, as reported by sys.getsizeof, of the state-key string object.
sys.getsizeof(str(state.flatten().tolist()))

In [None]:
# Generate the mesh and boundary objects for the captured state.
mesh, boundary = fem.generate_mesh(state)

In [None]:
# Solve on the generated mesh; the result is bound to `filename`
# (presumably the path of the solver output — TODO confirm).
filename = fem.solve(mesh, boundary)

In [None]:
# Render the solver result referenced by `filename`; the return value is displayed as the cell output.
fem.render_image(filename)

In [None]:
# NOTE(review): presumably runs mesh generation, solving and rendering for `state`
# in a single call — confirm against GridHoleBoardFEMSimulator.run.
fem.run(state)