In [None]:
# Notebook for concept detection in a neural network

In [10]:
import numpy as np
import os
import sys
from tqdm import tqdm

sys.path.append(os.path.abspath(os.path.join(os.path.pardir, 'src')))

import concepts
import env
from policy import ActorCriticNet

# --- Experiment configuration ---------------------------------------------

# File-name prefix of the saved checkpoints.
model_name = "net"

# Training epochs whose checkpoints are loaded as agents.
agents_to_sample = [0, 150, 300, 450, 600]

# Board geometry and the directory holding the checkpoints trained for it.
board_size = 5
full_model_path = "../models/saved_sessions/board_size_5/"
dims = (4, board_size, board_size)  # presumably (planes, rows, cols) of one state — TODO confirm

# Concept probed in this notebook; the name is read off the function itself
# so the two cannot drift apart when the concept is swapped.
CONCEPT_FUNC = concepts.concept_area_advantage
CONCEPT_NAME = CONCEPT_FUNC.__name__


In [11]:
def load_model(full_name, model_name, epoch):
    """Load the checkpoint saved for ``model_name`` at a given training epoch.

    The checkpoint is looked up at ``<full_name>/<model_name>_<epoch>.keras``.

    Args:
        full_name: Directory containing the checkpoints, with or without a
            trailing path separator.
        model_name: File-name prefix shared by all checkpoints.
        epoch: Epoch number identifying the checkpoint to load.

    Returns:
        An ``ActorCriticNet`` built from the notebook-level ``board_size`` and
        the resolved checkpoint path.
    """
    # os.path.join adds the separator only when it is missing, so a directory
    # given without a trailing slash no longer yields "dirnet_0.keras".
    model_path = os.path.join(full_name, f"{model_name}_{epoch}.keras")
    return ActorCriticNet(board_size, model_path)

# Build one network per requested checkpoint, in the order of agents_to_sample.
agents = [
    load_model(full_model_path, model_name, checkpoint_epoch)
    for checkpoint_epoch in agents_to_sample
]



In [12]:
def play_match(agents: "list[ActorCriticNet]", board_size, concept_function, sample_ratio=0.2):
    """Play one game between two agents and collect concept-labelled states.

    Two distinct agents are drawn at random from ``agents`` (a single agent
    plays against itself) and alternate moves until the environment reports
    the game as over. Before each move the current state is kept with
    probability ``sample_ratio`` and filed under the positive or negative
    list depending on ``concept_function(state)``. The terminal state is not
    sampled.

    Args:
        agents: Pool of agents to draw the two players from. Each must expose
            ``best_action(state)``.
        board_size: Board size passed to ``env.GoEnv``.
        concept_function: Callable taking a state and returning a truthy
            value when the concept is present in it.
        sample_ratio: Probability of sampling each visited state.

    Returns:
        A ``(positive_cases, negative_cases)`` tuple of lists of states.

    Raises:
        ValueError: If ``agents`` is empty.
    """
    if len(agents) == 0:
        raise ValueError("play_match needs at least one agent")

    # Seat two players drawn from the WHOLE pool. Indexing `agents` directly
    # with the 0/1 turn counter would only ever use the first two entries.
    # Drawing without replacement returns the pair in random order, which
    # also randomises who moves first.
    if len(agents) >= 2:
        seats = np.random.choice(len(agents), size=2, replace=False)
        players = [agents[i] for i in seats]
    else:
        players = [agents[0], agents[0]]

    go_env = env.GoEnv(board_size)
    state = go_env.reset()

    positive_cases = []
    negative_cases = []

    current_player = 0
    game_over = False

    while not game_over:
        if np.random.random() < sample_ratio:
            # Store a snapshot rather than the live object: whether the
            # environment reuses its state buffer between steps is not
            # visible here, and aliasing would corrupt earlier samples.
            snapshot = np.copy(state)
            if concept_function(state):
                positive_cases.append(snapshot)
            else:
                negative_cases.append(snapshot)

        action = players[current_player].best_action(state)
        state, _, game_over, _ = go_env.step(action)

        # Hand the move to the other seat.
        current_player = 1 - current_player

    return positive_cases, negative_cases

In [13]:
positive_cases = []
negative_cases = []

CASES_TO_SAMPLE = 1000 # 25000

# Keep playing until BOTH classes hold CASES_TO_SAMPLE states. Stopping as
# soon as the positives are full can leave fewer negatives than requested,
# which silently produces an unbalanced data set.
# NOTE: if the concept never (or always) holds, one class stays empty and
# this loop does not terminate.
with tqdm(total=CASES_TO_SAMPLE, desc="Positive cases") as positive_bar:
    while min(len(positive_cases), len(negative_cases)) < CASES_TO_SAMPLE:
        pos, neg = play_match(agents, board_size, CONCEPT_FUNC)

        # Advance the bar only by the positives still needed so it never
        # overshoots its total.
        missing_positives = CASES_TO_SAMPLE - len(positive_cases)
        positive_bar.update(min(len(pos), missing_positives))

        positive_cases.extend(pos)
        negative_cases.extend(neg)

        # Cap both lists right away so the majority class does not keep
        # growing in memory while the minority class catches up.
        del positive_cases[CASES_TO_SAMPLE:]
        del negative_cases[CASES_TO_SAMPLE:]

        positive_bar.set_postfix(negative_cases=len(negative_cases))

Positive cases: 100%|██████████| 1000/1000 [01:42<00:00, 12.62it/s]

Positive cases: 100%|██████████| 1000/1000 [02:00<00:00, 12.62it/s]

In [None]:
# Convert both sample lists to NumPy arrays and report their shapes.
positive_cases, negative_cases = (
    np.array(samples) for samples in (positive_cases, negative_cases)
)

print("Positive cases: ", positive_cases.shape)
print("Negative cases: ", negative_cases.shape)