In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import json
from copy import deepcopy

from cthulhu import CthulhuShield
from study import LOG_DIR, CthulhuGrid

import numpy as np

from matplotlib import pyplot as plt
import matplotlib as mpl

In [None]:
# Hardware handle used by the manual stim()/stop() checks in the following cells.
# NOTE(review): presumably opens the device connection on construction — confirm in CthulhuShield.
shield = CthulhuShield()

In [None]:
# Indices 16 and 17: the two beyond 0-15 (18 indices in total, see the range(18) cell).
shield.stim(list(range(16, 18)))

In [None]:
# Indices 0-3; NOTE(review): presumably the first row of a 4x4 layout — confirm against grid.poses.
shield.stim(list(range(4)))

In [None]:
# Indices 12-15; NOTE(review): presumably the last row of a 4x4 layout — confirm against grid.poses.
shield.stim(list(range(12, 16)))

In [None]:
# Indices 3, 7, 11, 15 (stride 4); NOTE(review): presumably the last column of a 4x4 layout — confirm.
shield.stim(list(range(3, 16, 4)))

In [None]:
# Indices 0, 4, 8, 12 (stride 4); NOTE(review): presumably the first column of a 4x4 layout — confirm.
shield.stim(list(range(0, 16, 4)))

In [None]:
# Every index 0-17 at once.
shield.stim([*range(18)])

In [None]:
# NOTE(review): presumably ends the stimulation started in the cells above — confirm in CthulhuShield.
shield.stop()

In [None]:
# Load every JSON session log from LOG_DIR. Events within a session are sorted
# by 'timestamp', then sessions are ordered by the timestamp of their first event.
log_dir = LOG_DIR
logs = []
# os.listdir order is arbitrary; sorting makes tie-breaking in the stable
# session sort below deterministic across machines and runs.
for fname in sorted(os.listdir(log_dir)):
  if not fname.endswith('.json'):
    continue
  path = os.path.join(log_dir, fname)
  # JSON is UTF-8 by spec; don't depend on the platform default encoding.
  with open(path, 'r', encoding='utf-8') as f:
    events = json.load(f)
  if not events:
    # An empty session has no first event to order by (and no trials); skip it
    # instead of crashing with IndexError in the session sort.
    continue
  logs.append(sorted(events, key=lambda event: event['timestamp']))
logs = sorted(logs, key=lambda session: session[0]['timestamp'])

In [None]:
def parse(log):
  """Replay one session's event log into per-trial (pattern, final state) pairs.

  Events are processed in order:
    * ``'reset'``  - opens a trial: deep-copies ``event['pattern']`` and the
      initial ``event['states']``.
    * ``'toggle'`` - flips ``states[event['idx']]`` of the open trial.
    * ``'submit'`` - records a snapshot of the open trial's pattern and states.

  Args:
    log: Iterable of event dicts, each with a ``'name'`` key.

  Returns:
    Tuple ``(patterns, final_states)`` of equal-length lists, one entry per
    ``'submit'`` event.

  Raises:
    ValueError: if an event name is unrecognised, or a ``'toggle'``/``'submit'``
      appears before any ``'reset'``.
  """
  patterns = []
  final_states = []
  # Explicit flag instead of relying on `pattern`/`states` being bound: a
  # toggle/submit before the first reset would otherwise surface as an
  # obscure UnboundLocalError.
  in_trial = False
  for event in log:
    name = event['name']
    if name == 'reset':
      pattern = deepcopy(event['pattern'])
      states = deepcopy(event['states'])
      in_trial = True
    elif name in ('toggle', 'submit') and not in_trial:
      raise ValueError(f'{name!r} event before any reset')
    elif name == 'toggle':
      idx = event['idx']
      states[idx] = not states[idx]
    elif name == 'submit':
      # Snapshot: later toggles must not alter an already-recorded trial.
      patterns.append(deepcopy(pattern))
      final_states.append(deepcopy(states))
    else:
      raise ValueError(f'unknown event name: {name!r}')
  return patterns, final_states

In [None]:
# Flatten every session's submitted trials into two row-aligned arrays:
# X[i] is the pattern of trial i and Y[i] the state vector at its submit.
X_rows, Y_rows = [], []
for session in logs:
  patterns, final_states = parse(session)
  X_rows.extend(patterns)
  Y_rows.extend(final_states)
X = np.array(X_rows)
Y = np.array(Y_rows)
# Fail fast here (no submitted trials, or rows of differing length) rather than
# with an obscure axis error in the accuracy cells below.
assert X.ndim == 2 and X.shape == Y.shape, (X.shape, Y.shape)

In [None]:
# Shapes of the pattern matrix X and the final-state matrix Y.
(X.shape, Y.shape)

In [None]:
# Mean entry of the patterns vs. the final states.
# NOTE(review): reads as "fraction of active entries" only if values are 0/1 or bool — confirm.
(X.mean(), Y.mean())

In [None]:
# Per-item accuracy: fraction of entries where the final state equals the pattern.
accs = (X == Y).mean(axis=1)
accs.mean()

In [None]:
# Per-item accuracy over the study, against a chance baseline.
# X.mean() (the previous baseline) is the accuracy of answering "on" everywhere,
# not of random guessing: with 25% active entries an all-"off" answer already
# scores 75%. Instead use the expected agreement of an independent guesser that
# matches the observed response rate (the chance term p_e of Cohen's kappa):
#   p_e = p_x * p_y + (1 - p_x) * (1 - p_y)
# NOTE(review): assumes X and Y hold 0/1 (or bool) entries — confirm.
p_x, p_y = X.mean(), Y.mean()
chance_acc = p_x * p_y + (1 - p_x) * (1 - p_y)

fig, ax = plt.subplots()
# Start x at 1 so the axis actually counts items completed.
ax.plot(np.arange(1, len(accs) + 1), accs, color='orange', label='Observed')
ax.axhline(y=chance_acc, linestyle='--', label='Chance (independent guessing)', color='gray')
ax.set_xlabel('Number of Items Completed')
ax.set_ylabel('Prediction Accuracy')
ax.set_ylim(0, 1)
ax.legend(loc='best')
plt.show()

In [None]:
# Build a grid around a debug-mode shield only to read marker positions for plotting.
# NOTE(review): presumably debug_mode avoids driving the hardware — confirm in CthulhuShield.
debug_shield = CthulhuShield(debug_mode=True)
grid = CthulhuGrid(debug_shield, None)
poses = grid.poses

In [None]:
def c(x, y):
  """Map a (pattern entry, final-state entry) pair to a marker colour.

  Returns 'green' when both are truthy, 'red' when only ``x`` is,
  'blue' when only ``y`` is, and 'gray' when neither is.
  """
  palette = {
      (True, True): 'green',
      (True, False): 'red',
      (False, True): 'blue',
      (False, False): 'gray',
  }
  return palette[(bool(x), bool(y))]

def plot(i, marker_scale=50):
  """Draw item ``i`` on the marker layout, colour-coding pattern vs. final state.

  Markers sit at ``poses`` (y negated, as in the original figure) and are
  coloured by ``c(X[i][k], Y[i][k])``. A legend spells out the colour code so
  the figure can be read on its own.

  Args:
    i: Row index into the module-level ``X``/``Y`` arrays.
    marker_scale: Multiplier on ``grid.radius`` for the scatter marker size;
      the default of 50 reproduces the original figure.
  """
  from matplotlib.patches import Patch  # local: only needed for legend handles

  colors = [c(x, y) for x, y in zip(X[i], Y[i])]
  acc = np.mean(X[i] == Y[i])
  fig, ax = plt.subplots()
  ax.set_title(f'Item {i}: accuracy {acc:.0%}')
  ax.scatter(poses[:, 0], -poses[:, 1], s=grid.radius * marker_scale, c=colors)
  # Equal aspect keeps the spatial layout undistorted.
  ax.set_aspect('equal')
  ax.axis('off')
  # Must mirror the branches of c(): x is the pattern entry, y the final state.
  legend_entries = [
      ('green', 'pattern + state'),
      ('red', 'pattern only'),
      ('blue', 'state only'),
      ('gray', 'neither'),
  ]
  ax.legend(handles=[Patch(color=col, label=lbl) for col, lbl in legend_entries],
            loc='upper left', bbox_to_anchor=(1, 1))
  plt.show()
  # Release the figure explicitly; this is called once per item in a loop.
  plt.close(fig)

In [None]:
# Render every item in order, each followed by a 60-asterisk separator line.
for i in range(len(X)):
  plot(i)
  print('*' * 60)