In [1]:
import random

import numpy as np

# Single seed for every stochastic step in this notebook.
# `rng` alone is not enough: it is never passed to any later call, so code that
# draws from the global NumPy / `random` state would otherwise run unseeded and
# results would not reproduce under Restart & Run All.
seed = 24
random.seed(seed)
np.random.seed(seed)
rng = np.random.RandomState(seed)

In [2]:
import json
import yaml
import argparse
import numpy as np
from pathlib import Path
from datetime import datetime

from src.registries import DATASETS, MODELS, OPTIMIZERS, EVALUATORS
from src.hooks.extract import extract_paired_trajectory_activations
from src.evaluator.registry import run_evals

# Import to register
import src.datasets.datasets
import src.models.models
import src.optimizers.optimizers


def load_yaml(path: str) -> dict:
    """Load a YAML file whose top level is a mapping and return it as a dict.

    Args:
        path: Path to the YAML file (a ``str`` or any path-like ``open`` accepts).

    Returns:
        The parsed top-level mapping.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file is empty or its top level is not a mapping.
    """
    # Decode as UTF-8 explicitly; the platform default encoding (e.g. cp1252 on
    # Windows) would mis-read any non-ASCII characters in the config.
    with open(path, encoding="utf-8") as f:
        data = yaml.safe_load(f)
    # safe_load returns None for an empty document; fail here, naming the file,
    # instead of later with "'NoneType' object is not subscriptable".
    if not isinstance(data, dict):
        raise ValueError(
            f"{path}: expected a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data


def save_json(obj, path: Path):
    """Serialize ``obj`` to ``path`` as JSON indented by two spaces.

    Any existing file at ``path`` is truncated and overwritten.
    """
    with open(path, mode='w') as handle:
        json.dump(obj, fp=handle, indent=2)


def make_run_dir(name: str) -> Path:
    """Create and return a new, unused run directory ``runs/<name>/<timestamp>``.

    The timestamp has one-second resolution (``%Y%m%d_%H%M%S``). If that
    directory already exists (e.g. two runs started within the same second), a
    numeric suffix is appended (``<timestamp>_1``, ``<timestamp>_2``, ...) so a
    run never writes into another run's directory.

    Args:
        name: Experiment name; used as a sub-directory of ``runs/``.

    Returns:
        Path of the freshly created directory (relative to the working directory).
    """
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    parent = Path("runs") / name
    parent.mkdir(parents=True, exist_ok=True)

    run_dir = parent / ts
    suffix = 1
    while True:
        try:
            # mkdir without exist_ok atomically claims the directory, so
            # concurrent runs cannot both end up using the same path.
            run_dir.mkdir()
            return run_dir
        except FileExistsError:
            run_dir = parent / f"{ts}_{suffix}"
            suffix += 1


# Experiment configuration (model, position sources, hook layers) for this run.
CONFIG_PATH = 'configs/novelty.yaml'
cfg = load_yaml(CONFIG_PATH)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Build the network: the `source` key of the model config selects the
# constructor looked up in MODELS, which is then called with the full config.
model_cfg = cfg['model']
net = MODELS.get(model_cfg['source'])(model_cfg)

In [4]:
# Novelty needs positions from both human games and AlphaZero (AZ) games.
# This cell loads the human side via the DATASETS entry named by
# cfg['positions']['source'] and saves a copy to human_games.npy.
# NOTE(review): the .npy is written to the working directory (not a run dir)
# and is overwritten on every run.
pos_cfg = cfg['positions']
human_loader = DATASETS.get(pos_cfg['source'])
human_games = human_loader(pos_cfg)
np.save("human_games.npy", human_games)

Loading Kaggle Othello dataset...
Loaded 25657 games from Kaggle

Collection statistics:
  Total games processed: 446
  Successful parses: 446 (100.0%)
  No positions in range [0,60]: 0
  Parse failures: 0
  Too short: 0
  Total positions collected: 25037


In [None]:
# Load the AZ-side positions. The dataset loader receives the network through
# its config, so pass it in a *copy* of cfg['dataset']: assigning into
# cfg['dataset'] directly would store a model object inside the loaded config
# (presumably not JSON-serializable, so e.g. save_json(cfg, ...) would fail).
# TODO: confirm whether these need to be freshly played (self-play) games.
ds_cfg = {**cfg['dataset'], 'net': net}
az_games = DATASETS.get(ds_cfg['source'])(ds_cfg)

In [6]:
# Next step: per-layer activations of `net` on both position sets.
from src.hooks.extract import extract_features_by_layer

# Same hook layers for both sets so their activations are comparable.
hook_layers = cfg['hooks']['layers']

# Human positions first, then AZ positions.
human_activations = extract_features_by_layer(net, human_games, hook_layers)
az_activations = extract_features_by_layer(net, az_games, hook_layers)

In [7]:
from src.novelty.novelty import NoveltyFilter

# Layer whose activations are compared by the novelty filter. It has to be one
# of the layers hooked during extraction (cfg['hooks']['layers']).
NOVELTY_LAYER = 'bn2'

# Fail with an actionable message instead of a bare KeyError if the config's
# hook layers no longer include NOVELTY_LAYER.
# Assumes the activation containers support `in` (e.g. dicts keyed by layer
# name) — TODO confirm against extract_features_by_layer.
missing_sets = [
    label
    for label, acts in (('human', human_activations), ('AZ', az_activations))
    if NOVELTY_LAYER not in acts
]
if missing_sets:
    raise KeyError(
        f"layer {NOVELTY_LAYER!r} missing from {', '.join(missing_sets)} "
        f"activations; add it to cfg['hooks']['layers']"
    )

# NOTE(review): `filter` shadows the builtin filter() for the rest of the
# kernel session; rename (e.g. to `novelty_filter`) once downstream cells are
# updated.
filter = NoveltyFilter(human_activations[NOVELTY_LAYER], az_activations[NOVELTY_LAYER])