# Binary Classification
This notebook aims to build intuition for the Binary Classification environment.

Let's start by creating the environment itself.

In [None]:
from concurrent.futures.thread import ThreadPoolExecutor

from edesdetectrl.dataloaders.echonet import Echonet
import edesdetectrl.environments.binary_classification
import gym
from edesdetectrl.config import config

volumetracings_csv_file = config["data"]["volumetracings_path"]
filelist_csv_file = config["data"]["filelist_path"]
videos_dir = config["data"]["videos_path"]
split = "TEST"
thread_pool_executor = ThreadPoolExecutor()
echonet = Echonet(split)
seq_iterator = echonet.get_random_generator(42, thread_pool_executor, 5)
env = gym.make("EDESClassification-v0", seq_iterator=seq_iterator)


In [None]:
observation = env.reset()
num_channels = observation.shape[0]
print(f"Shape of observation: {observation.shape} <- (num_channels, width, height)")

Let's take a look at what an observation looks like.

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(ncols=num_channels, figsize=(16, 8), dpi=100)

current_frame = num_channels // 2  # index of the middle (current) frame
for i in range(num_channels):
    if i == current_frame:
        axes[i].set_title("↓ Current frame ↓")
    elif i == current_frame-1:
        axes[i].set_title("←← Previous frames")
    elif i == current_frame+1:
        axes[i].set_title("Next frames →→")
        
    axes[i].imshow(observation[i], cmap='copper' if i == current_frame else 'gray')
plt.show()


Observations are pre-processed as the first step inside the neural network model. This is what a pre-processed observation looks like.

_NB! Pre-processing is currently bugged and is not in use._

In [None]:
import edesdetectrl.model as model
import jax.numpy as jnp

# Expected input shape is (batch_size, width, height, channels)
# Observation shape is (channels, width, height)
observation_reshaped = jnp.expand_dims(jnp.transpose(observation, (1,2,0)),0)
pre_processed_observation = model.pre_process_frames(observation_reshaped)
fig, axes = plt.subplots(ncols=num_channels, figsize=(16, 8), dpi=100)

current_frame = num_channels // 2  # index of the middle (current) frame
for i in range(num_channels):
    if i == current_frame:
        axes[i].set_title("↓ Current frame ↓")
    elif i == current_frame-1:
        axes[i].set_title("←← Previous frames")
    elif i == current_frame+1:
        axes[i].set_title("Next frames →→")
        
    axes[i].imshow(pre_processed_observation[0,:,:,i], cmap='copper' if i == current_frame else 'gray')
plt.show()
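
The reshaping step used in the cell above can be checked in isolation. The following is a minimal sketch using NumPy instead of `jax.numpy`, with a made-up observation shape of `(7, 112, 96)`; the real shape comes from `env.reset()` and may differ.

```python
import numpy as np

# Hypothetical observation of shape (channels, width, height); values are
# arbitrary and only used to track where each channel ends up.
observation = np.arange(7 * 112 * 96, dtype=np.float32).reshape(7, 112, 96)

# Move channels last, then add a leading batch axis of size 1.
batched = np.expand_dims(np.transpose(observation, (1, 2, 0)), 0)

print(batched.shape)

# Channel i of the original is the i-th slice along the last axis.
same_frame = np.array_equal(batched[0, :, :, 3], observation[3])
print(same_frame)
```

Using different sizes for width and height in the sketch makes it easy to see that only the channel axis moves.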

# Evaluation
Let's look at how a trained model performs. To run this code, you need an already trained model available at the path given by `config["data"]["trained_params_path"]`. See `config_upstream.toml`.

The agent in the script below follows the trained Q-function greedily, meaning that it always selects the action with the highest estimated Q-value (no exploration). We record the Q-values and the rewards received at every step so that we can plot them afterwards. Instead of plotting the Q-values directly, we plot the so-called advantage. Put simply, the advantage of an action is its Q-value minus the average of the Q-values over all actions in that state. Here, this is done only to make the plots easier to read.
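
As a small illustration of the advantage described above, here is a sketch with invented Q-values for three time steps (the numbers are not from a trained model). It also checks that subtracting the per-step mean does not change which action the greedy agent picks.

```python
import numpy as np

# Hypothetical Q-values; each row is a time step, columns are (Diastole, Systole).
q_values = np.array([[1.0, 3.0], [2.5, 2.0], [0.5, 0.5]])

# Advantage: each Q-value minus the mean over actions at the same time step.
advantage = q_values - q_values.mean(axis=1, keepdims=True)

# The greedy action is the same whether computed from Q-values or advantages.
greedy_actions = q_values.argmax(axis=1)
same_choice = bool((greedy_actions == advantage.argmax(axis=1)).all())

print(advantage)
print(greedy_actions, same_choice)
```

With two actions, the two advantages at each step are always equal in magnitude and opposite in sign, which is why the two curves in the plot mirror each other.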

The blue and orange lines represent the estimated advantage of taking the D (diastole) or S (systole) action, respectively. When D has a higher value than S, the agent selects D, and vice versa. The green line represents the reward, which is 1 if the selected action was correct and 0 otherwise. A perfect agent would take actions such that the reward, the green line, is always 1.

In [None]:
import pickle

import haiku as hk
import jax
from edesdetectrl.util import functional

network = functional.chainf(
    model.get_func_approx(env.action_space.n),
    hk.transform,
    hk.without_apply_rng,
)
with open(config["data"]["trained_params_path"], "rb") as f:
    trained_params = pickle.load(f)
q = jax.jit(lambda s: network.apply(trained_params, s)[0])


s = env.reset()
done = False

states = [s]
actions = []
rewards = []
q_values = []
while not done:
    qs = q(s)
    a = int(jnp.argmax(qs))  # convert to a plain Python int before passing it to the env
    s, r, done, info = env.step(a)

    states.append(s)
    actions.append(a)
    rewards.append(r)
    q_values.append(qs)


def calc_advantage(t):
    d, s = t
    v = (d + s) / 2
    return d - v, s - v


advantage = list(map(calc_advantage, q_values))

fig, ax = plt.subplots()
ax.plot(advantage)
ax.plot(rewards)
ax.legend(["Diastole", "Systole", "Reward"])
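
Since the reward is 1 for a correct action and 0 otherwise, the mean reward over an episode is the fraction of frames the agent classified correctly. Below is a minimal sketch; `greedy_accuracy` is a hypothetical helper, not part of `edesdetectrl`, and the example rewards are invented.

```python
def greedy_accuracy(rewards):
    """Return the fraction of steps with reward 1, assuming 0/1 rewards."""
    if len(rewards) == 0:
        raise ValueError("rewards must contain at least one step")
    return sum(float(r) for r in rewards) / len(rewards)


# Invented rewards for a four-step episode: three correct actions, one wrong.
example_rewards = [1, 1, 0, 1]
accuracy = greedy_accuracy(example_rewards)
print(accuracy)
```

Passing the `rewards` list collected in the loop above would summarize the whole episode in a single number.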
