# Notebook providing Concept-Based Explanations to End Users at deployment time using a Joint Embedding Model and concept functions

In [4]:
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.pardir, 'src')))

from policy import ResNet, ConceptNet
from env import GoEnv
from jem import data_utils
import numpy as np

# Which policy network to deploy: True selects the concept bottleneck model
# (ConceptNet); False selects the plain ResNet policy.
USE_CONCEPT_BOTTLENECK_MODEL = False # If False, use the ResNet model

# Only consulted on the ResNet path, where it is forwarded to
# data_utils.get_explanation_from_state.
USE_JOINT_EMBEDDING_MODEL = False

# Hard limit on the number of moves in one game, shared by both models.
MOVE_CAP = 100

# Board size and komi for the selected model (presumably the settings each
# checkpoint was trained with).
BOARD_SIZE, KOMI = (5, 1.5) if USE_CONCEPT_BOTTLENECK_MODEL else (7, 3.5)

# Load the policy network for the selected architecture. The ResNet checkpoint
# directory is keyed by board size, so derive it from BOARD_SIZE instead of
# hard-coding 7; otherwise changing BOARD_SIZE would silently pair a 7x7
# model with a different board.
if USE_CONCEPT_BOTTLENECK_MODEL:
    model = ConceptNet(BOARD_SIZE, load_path='../models/cbm/net_1000.keras')
else:
    model = ResNet(
        BOARD_SIZE,
        load_path=f'../models/play_against/resnet/board_size_{BOARD_SIZE}/net_800.keras',
    )

# Create the Go environment with the configured board size and komi, and
# reset it before the first move.
go_env = GoEnv(BOARD_SIZE, KOMI)

go_env.reset()

# Number of moves played so far; the loop stops at MOVE_CAP even if the game
# has not ended.
move_nr = 0
# Snapshot of the canonical state before any move (not used further in this
# cell; kept in case later cells reference it).
init_state = go_env.canonical_state()

# Updated from the third value returned by go_env.step().
game_over = False

# Black always starts; curr_player is 0 for Black and 1 for White, and is fed
# to the model as a constant plane.
curr_player = 0
# History planes for the model input; all start as empty boards.
#   prev_turn_state:      side-to-move's stones as captured two iterations ago.
#   temp_prev_turn_state: one-iteration buffer; holds what becomes
#                         prev_turn_state on the next iteration.
#   prev_opposing_state:  opponent's stones as captured one iteration ago.
prev_turn_state = np.zeros((BOARD_SIZE, BOARD_SIZE))
temp_prev_turn_state = np.zeros((BOARD_SIZE, BOARD_SIZE))
prev_opposing_state = np.zeros((BOARD_SIZE, BOARD_SIZE))

while not game_over and move_nr < MOVE_CAP:
    # Observe the position and legal moves for the side to move. The indexing
    # below suggests canonical_state()[0] holds the side-to-move's stones and
    # [1] the opponent's -- TODO confirm against GoEnv.
    curr_state = go_env.canonical_state()
    valid_moves = go_env.valid_moves()

    
    # Model input: 5 stacked planes, presumably (5, BOARD_SIZE, BOARD_SIZE):
    #   [own stones, own history, opponent stones, opponent history, player].
    # NOTE(review): prev_turn_state lags two plies while prev_opposing_state
    # lags one ply, whereas the post-move state built below uses one-ply
    # history for both sides. Confirm this matches the encoding the model was
    # trained with.
    state = np.array([curr_state[0], 
                      prev_turn_state, 
                      curr_state[1],
                      prev_opposing_state, 
                      np.full((BOARD_SIZE, BOARD_SIZE), curr_player)])
        
    # Get the action and value estimate from the model. The concept bottleneck
    # model also returns a concept explanation, decoded further below.
    if USE_CONCEPT_BOTTLENECK_MODEL:
        explanation, action, value = model.best_action_with_explanation(state, valid_moves=valid_moves)
    else:
        action, value = model.best_action(state, valid_moves=valid_moves)

    # Apply the action to the environment; only the third return value
    # (whether the game is over) is used.
    _, _, game_over, _ = go_env.step(action)

    # After the move the canonical view appears to belong to the next player,
    # so the mover's stones are read from index 1 below (presumably -- see the
    # TODO above).
    state_after_action = go_env.canonical_state()

    # Post-move state from the mover's perspective, with the pre-move board
    # (curr_state) as one-ply history for both sides. Passed to
    # get_explanation_from_state on the ResNet path.
    state = np.array([state_after_action[1], 
                      curr_state[0], 
                      state_after_action[0],
                      curr_state[1], 
                      np.full((BOARD_SIZE, BOARD_SIZE), curr_player)])
    
    # Render the environment
    go_env.render()

    # Provide explanation for the move: decode the model's concept output
    # (CBM), or derive an explanation from the post-move state (ResNet).
    if USE_CONCEPT_BOTTLENECK_MODEL:
        explanation, reward = data_utils.get_explanation_from_index(explanation)
    else:
        explanation, reward = data_utils.get_explanation_from_state(state, USE_JOINT_EMBEDDING_MODEL)

    print(f'Explanation: {explanation}\nReward: {reward}\nValue: {value}')

    # Flip the player (0 <-> 1)
    curr_player = 1 - curr_player

    # Rotate the history planes for the next side to move: its opponent (the
    # player who just moved) contributes curr_state[0], and its own history
    # plane comes from the buffer filled one iteration earlier.
    prev_turn_state = temp_prev_turn_state
    prev_opposing_state = curr_state[0]
    temp_prev_turn_state = prev_opposing_state

    # Increment the move number
    move_nr += 1

# Get the winner of the game from Black's perspective (1 for a Black win,
# -1 for a Black loss).
winner = go_env.winning()

# Print a single message chosen by a conditional expression, rather than using
# a conditional expression as a statement to pick between two print calls.
# NOTE(review): any value other than 1 (e.g. 0, if winning() can report a
# draw) is reported as a White win.
print("Black won!" if winner == 1 else "White won!")


	0 1 2 3 4 5 6 
0	╔═╤═╤═╤═╤═╤═╗
1	╟─┼─┼─┼─┼─┼─╢
2	╟─┼─┼─┼─┼─┼─╢
3	╟─┼─┼─○─┼─┼─╢
4	╟─┼─┼─┼─┼─┼─╢
5	╟─┼─┼─┼─┼─┼─╢
6	╚═╧═╧═╧═╧═╧═╝
	Turn: WHITE, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 49, White Area: 0

Explanation: plays in the center of the board in the opening to gain control
Reward: 0.1
Value: -0.17767879366874695
	0 1 2 3 4 5 6 
0	╔═╤═╤═╤═╤═╤═●
1	╟─┼─┼─┼─┼─┼─╢
2	╟─┼─┼─┼─┼─┼─╢
3	╟─┼─┼─○─┼─┼─╢
4	╟─┼─┼─┼─┼─┼─╢
5	╟─┼─┼─┼─┼─┼─╢
6	╚═╧═╧═╧═╧═╧═╝
	Turn: BLACK, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 1, White Area: 1

Explanation: creates an area advantage
Reward: 0.2
Value: 0.0058687468990683556
	0 1 2 3 4 5 6 
0	╔═╤═╤═╤═╤═╤═●
1	╟─┼─┼─┼─┼─┼─╢
2	╟─┼─┼─┼─┼─┼─╢
3	╟─┼─┼─○─○─┼─╢
4	╟─┼─┼─┼─┼─┼─╢
5	╟─┼─┼─┼─┼─┼─╢
6	╚═╧═╧═╧═╧═╧═╝
	Turn: WHITE, Game State (ONGOING|PASSED|END): ONGOING
	Black Area: 2, White Area: 1

Explanation: plays in the center of the board in the opening to gain control
Reward: 0.1
Value: 0.7144497036933899
	0 1 2 3 4 5 6 
0	╔═╤═╤═╤═╤═╤═●
1	╟─┼─┼─┼─┼─┼