# Notebook providing concept-based explanations to end users at deployment time, using either the Joint Embedding Model or concept functions

In [5]:
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.pardir, 'src')))

from policy import ResNet
from env import GoEnv
from jem import data_utils
import numpy as np

# --- Demo configuration -------------------------------------------------
# Passed to data_utils.get_explanation_from_state to select how the
# explanation is produced (Joint Embedding Model when True).
USE_JOINT_EMBEDDING_MODEL = False
BOARD_SIZE = 7
KOMI = 5.5
# Hard upper bound on moves so the demo loop always terminates.
MOVE_CAP = 100

# Derive the checkpoint directory from BOARD_SIZE so the loaded network can
# never silently disagree with the board size used by the environment.
MODEL_PATH = f'../models/play_against/resnet/board_size_{BOARD_SIZE}/net_800.keras'
model = ResNet(BOARD_SIZE, load_path=MODEL_PATH)

go_env = GoEnv(BOARD_SIZE, KOMI)
go_env.reset()

# Number of moves played so far.
move_nr = 0
# Canonical board right after reset (kept for inspection).
init_state = go_env.canonical_state()

game_over = False

# Black always starts: 0 = black to move, 1 = white to move. This value is
# also what fills the constant "turn" plane of the network input.
curr_player = 0

# History planes, all empty before the first move. Each is a separate array.
prev_turn_state = np.zeros((BOARD_SIZE, BOARD_SIZE))
temp_prev_turn_state = np.zeros((BOARD_SIZE, BOARD_SIZE))
prev_opposing_state = np.zeros((BOARD_SIZE, BOARD_SIZE))

def _stack_planes(own, own_prev, opp, opp_prev, player):
    """Stack four board planes plus a constant turn plane into one array.

    The turn plane is all zeros when ``player`` is 0 and all ones when it is
    1 (float64, same as ``np.zeros``/``np.ones``). This is the layout this
    notebook feeds both to the model and to
    ``data_utils.get_explanation_from_state``.
    """
    turn_plane = np.full((BOARD_SIZE, BOARD_SIZE), player, dtype=float)
    return np.array([own, own_prev, opp, opp_prev, turn_plane])


while not game_over and move_nr < MOVE_CAP:
    # Board from the point of view of the player to move. The code below
    # treats index 0 as that player's stones and index 1 as the opponent's.
    curr_state = go_env.canonical_state()
    valid_moves = go_env.valid_moves()

    # Model input: own stones, own history plane, opponent stones,
    # opponent history plane, turn plane.
    model_input = _stack_planes(curr_state[0], prev_turn_state, curr_state[1],
                                prev_opposing_state, curr_player)

    # Ask the network for its move and its evaluation of the position.
    action, value = model.best_action(model_input, valid_moves=valid_moves)

    # Play the move.
    _, _, game_over, _ = go_env.step(action)

    # After the move the canonical view is read with indices swapped: index 1
    # is taken to be the stones of the player who just moved.
    # NOTE(review): relies on GoEnv flipping perspective after step() — confirm.
    state_after_action = go_env.canonical_state()

    # Explanation input (mover's perspective): mover's stones after / before
    # the move, opponent's stones after / before the move, mover's turn plane.
    state = _stack_planes(state_after_action[1], curr_state[0],
                          state_after_action[0], curr_state[1], curr_player)

    go_env.render()

    # Explain the move that was just played.
    explanation, reward = data_utils.get_explanation_from_state(state, USE_JOINT_EMBEDDING_MODEL)

    print(f'Explanation: {explanation}\nReward: {reward}\nValue: {value}')

    # Hand the turn to the other player.
    curr_player = 1 - curr_player

    # Shift the history planes for the next player's input:
    # - prev_opposing_state: the just-moved player's stones before their move.
    # - prev_turn_state: the next player's own stones at the start of that
    #   player's previous turn (two plies back), delayed one move through
    #   temp_prev_turn_state.
    # NOTE(review): this history differs from the one-ply history used for the
    # explanation state above — confirm it matches the network's training encoding.
    prev_turn_state = temp_prev_turn_state
    prev_opposing_state = curr_state[0]
    temp_prev_turn_state = prev_opposing_state

    move_nr += 1

# Result of the game from black's perspective (1 = black wins, -1 = white wins).
winner = go_env.winning()

# If the loop ended because of MOVE_CAP rather than a finished game, the
# result below is only an evaluation of the current position.
if not game_over:
    print(f"Game stopped at the move cap ({MOVE_CAP}); result reflects the current position.")

if winner == 1:
    print("Black won!")
elif winner == -1:
    print("White won!")
else:
    # Any other value (e.g. a 0 tie, if winning() can return one — confirm)
    # must not be reported as a white win.
    print("Draw!")


	0 1 2 3 4 5 6 
0	╔═╤═╤═╤═╤═╤═╗
1	╟─┼─┼─┼─┼─┼─╢
2	╟─┼─┼─┼─○─┼─╢
3	╟─┼─┼─┼─┼─┼─╢
4	╟─┼─┼─┼─┼─┼─╢
5	╟─┼─┼─┼─┼─┼─╢
6	╚═╧═╧═╧═╧═╧═╝
	Turn: WHITE, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 49, White Area: 0

Explanation: plays in the center of the board in the opening to gain control
Reward: 0.1
Value: -0.17767879366874695
	0 1 2 3 4 5 6 
0	╔═╤═╤═╤═╤═╤═╗
1	╟─┼─┼─┼─┼─┼─╢
2	╟─┼─┼─┼─○─┼─╢
3	╟─┼─┼─┼─┼─┼─╢
4	╟─┼─┼─┼─●─┼─╢
5	╟─┼─┼─┼─┼─┼─╢
6	╚═╧═╧═╧═╧═╧═╝
	Turn: BLACK, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 1, White Area: 1

Explanation: plays in the center of the board in the opening to gain control
Reward: 0.1
Value: 0.21696071326732635
	0 1 2 3 4 5 6 
0	╔═╤═╤═╤═╤═╤═╗
1	╟─┼─┼─┼─┼─┼─╢
2	╟─○─┼─┼─○─┼─╢
3	╟─┼─┼─┼─┼─┼─╢
4	╟─┼─┼─┼─●─┼─╢
5	╟─┼─┼─┼─┼─┼─╢
6	╚═╧═╧═╧═╧═╧═╝
	Turn: WHITE, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 2, White Area: 1

Explanation: a generic move not tied to a strategy
Reward: 0
Value: 0.04214867949485779
	0 1 2 3 4 5 6 
0	╔═╤═╤═╤═╤═╤═╗
1	╟─