In [1]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import os
from stable_baselines3 import PPO
from discovery.utils.feat_extractors import NatureCNN
from stable_baselines3.common.utils import obs_as_tensor
from discovery.experiments.FeatAct_minigrid.helpers import pre_process_obs
import cv2

from discovery.utils import filesys

pygame 2.5.2 (SDL 2.28.3, Python 3.10.11)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
filesys.set_directory_in_project()
agent = PPO.load("discovery/experiments/FeatAct_atari/models/Seaquest-v5_mpqgvvr1.zip")

Changed working directory to /Users/kevinroice/Documents/research/discovery


Exception: code expected at most 16 arguments, got 18
Exception: code expected at most 16 arguments, got 18


Load in the seaquest dataset.

In [3]:

path = f"/Users/kevinroice/Documents/research/discovery/datasets/AAD/clean/SeaquestNoFrameskip-v4/episode(1).hdf5"
with h5py.File(path, "r") as f:
    state = f["state"][...]
labels = np.load("/Users/kevinroice/Documents/research/discovery/datasets/AAD/clean/SeaquestNoFrameskip-v4/episode(1)_labels.npy")

In [4]:
state[0, :, :, :].shape

(210, 160, 3)

In [17]:
def pre_process_atari(dataset: np.ndarray):
    num_images = dataset.shape[0]
    preprocessed_images = np.zeros((num_images, 84, 84), dtype=np.uint8)

    for i in range(num_images):
        preprocessed_images[i] = cv2.resize(cv2.cvtColor(dataset[i], cv2.COLOR_RGB2GRAY), (84, 84))

        # Stack frames
        stacked_images = np.zeros((num_images - 3, 4, 84, 84), dtype=np.uint8)
        for i in range(num_images - 3):
            stacked_images[i] = np.stack(
                [
                    preprocessed_images[i],
                    preprocessed_images[i + 1],
                    preprocessed_images[i + 2],
                    preprocessed_images[i + 3],
                ]
            )

    return stacked_images


def stack_labels(labels):
    stacked_labels = np.zeros((labels.shape[0] - 3, 1), dtype=np.uint8)
    for i in range(labels.shape[0] - 3):
        stacked_labels[i] = labels[i + 3]
    return np.squeeze(stacked_labels)

In [6]:
def preproc_to_feats_atari(model, preprocs):
    with torch.no_grad():
        tensors = obs_as_tensor(preprocs, model.device)
        print(tensors.shape)
        feats = model.policy.extract_features(tensors) 
    return feats

In [7]:
pre_processed_states = pre_process_atari(state)

In [9]:
# turn all 1s and 3s into 0s
labels[labels == 1] = 0
labels[labels == 3] = 0
# turn all 2s into 1s
labels[labels == 2] = 1

In [18]:
stacked_labels = stack_labels(labels)

In [21]:
pre_processed_states.shape

(1354, 4, 84, 84)

In [22]:
stacked_labels.shape

(1354,)

In [24]:
feats = preproc_to_feats_atari(agent, pre_processed_states)

torch.Size([1354, 4, 84, 84])


In [25]:
from discovery.utils import sg_detection 
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import copy
import tqdm
import torch.nn as nn
import torch.optim as optim
import importlib

def train_classifier(clf, X, labels,
                     n_epochs=500,
                     batch_size=32,
                     test_size=0.2, random_state=0):
    # X = torch.cat(feats, dim=0)
    y = torch.tensor(labels).float()
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        test_size=test_size,
                                                        random_state=random_state)
    
    best_acc = - np.inf
    best_weights = None
    batch_start = torch.arange(0, len(X_train), batch_size) # TODO: check if the last batch is included
    loss_fn = nn.BCELoss(reduction='none')  # reduction='none' to get per-sample loss, not mean

    num_pos = y_train.sum()
    num_neg = len(y_train) - num_pos
    base_weight = torch.tensor([1.0, num_neg/num_pos]) # for weighted mean in loss calculation
    
    optimizer = optim.Adam(clf.parameters(), lr=0.0001)
    # TODO: collect positive examples, and concatenate them to each batch

    for epoch in range(n_epochs):
        clf.train()
        with tqdm.tqdm(batch_start, unit="batch", mininterval=0, disable=False) as bar:
            bar.set_description(f"Epoch {epoch}")
            for start in bar:
                # take a batch
                X_batch = X_train[start:start+batch_size]
                y_batch = y_train[start:start+batch_size]
                # forward pass
                y_pred = clf(X_batch)
                y_batch = y_batch.unsqueeze(1)
                weight = torch.where(y_batch == 1, base_weight[1], base_weight[0])
                loss2 = loss_fn(y_pred, y_batch)
                final_loss = torch.mean(weight*loss2)
                # backward pass
                optimizer.zero_grad()
                final_loss.backward()
                # update weights
                optimizer.step()
                # print progress
                acc = (y_pred.round() == y_batch).float().mean()
                bar.set_postfix(
                    loss=float(final_loss),
                    acc=float(acc)
                )
        # # evaluate accuracy at end of each epoch
        # clf.eval()
        # y_pred = clf(X_test)
        # acc = (y_pred.round() == y_test).float().mean()
        # acc = float(acc)
        if acc > best_acc:
            best_acc = acc
            best_weights = copy.deepcopy(clf.state_dict())
    return best_acc

In [26]:
stacked_labels[0]

0

In [27]:
clf = sg_detection.LinearClassifier(input_size=512)
acc = train_classifier(clf, feats, stacked_labels)

Epoch 0: 100%|██████████| 34/34 [00:00<00:00, 177.88batch/s, acc=0.815, loss=0.519]
Epoch 1: 100%|██████████| 34/34 [00:00<00:00, 319.82batch/s, acc=0.815, loss=0.552]
Epoch 2: 100%|██████████| 34/34 [00:00<00:00, 314.19batch/s, acc=0.815, loss=0.538]
Epoch 3: 100%|██████████| 34/34 [00:00<00:00, 310.84batch/s, acc=0.852, loss=0.501]
Epoch 4: 100%|██████████| 34/34 [00:00<00:00, 299.15batch/s, acc=0.889, loss=0.458]
Epoch 5: 100%|██████████| 34/34 [00:00<00:00, 282.95batch/s, acc=0.889, loss=0.419]
Epoch 6: 100%|██████████| 34/34 [00:00<00:00, 286.41batch/s, acc=0.889, loss=0.384]
Epoch 7: 100%|██████████| 34/34 [00:00<00:00, 218.26batch/s, acc=0.889, loss=0.355]
Epoch 8: 100%|██████████| 34/34 [00:00<00:00, 202.58batch/s, acc=0.889, loss=0.331]
Epoch 9: 100%|██████████| 34/34 [00:00<00:00, 285.27batch/s, acc=0.889, loss=0.311]
Epoch 10: 100%|██████████| 34/34 [00:00<00:00, 280.37batch/s, acc=0.889, loss=0.295]
Epoch 11: 100%|██████████| 34/34 [00:00<00:00, 276.38batch/s, acc=0.926, lo

In [28]:
from sklearn.metrics import confusion_matrix
def classifier_performance(clf, X, labels):
    y = torch.tensor(labels).float().squeeze()
    y_pred = clf(X)
    acc = (y_pred.round() == y).float().mean()
    print("Accuracy: ", acc)
    y_pred_np = y_pred.detach().numpy()
    c_m = confusion_matrix(labels, y_pred_np.round())
    print("Confusion Matrix: ")
    print(c_m)
    return acc, c_m

In [29]:
classifier_performance(clf, feats, stacked_labels)

Accuracy:  tensor(0.9261)
Confusion Matrix: 
[[1292   15]
 [   5   42]]


(tensor(0.9261),
 array([[1292,   15],
        [   5,   42]]))

In [30]:
clf_nl = sg_detection.NonLinearClassifier(input_size=512, hidden_size=64)
acc = train_classifier(clf_nl, feats, stacked_labels)

Epoch 0: 100%|██████████| 34/34 [00:00<00:00, 218.37batch/s, acc=0.667, loss=0.63] 
Epoch 1: 100%|██████████| 34/34 [00:00<00:00, 204.51batch/s, acc=0.815, loss=0.522]
Epoch 2: 100%|██████████| 34/34 [00:00<00:00, 165.90batch/s, acc=0.852, loss=0.45] 
Epoch 3: 100%|██████████| 34/34 [00:00<00:00, 219.12batch/s, acc=0.852, loss=0.391]
Epoch 4: 100%|██████████| 34/34 [00:00<00:00, 207.78batch/s, acc=0.852, loss=0.343]
Epoch 5: 100%|██████████| 34/34 [00:00<00:00, 90.01batch/s, acc=0.852, loss=0.316] 
Epoch 6: 100%|██████████| 34/34 [00:00<00:00, 152.36batch/s, acc=0.852, loss=0.286]
Epoch 7: 100%|██████████| 34/34 [00:00<00:00, 197.59batch/s, acc=0.889, loss=0.266]
Epoch 8: 100%|██████████| 34/34 [00:00<00:00, 195.24batch/s, acc=0.889, loss=0.249]
Epoch 9: 100%|██████████| 34/34 [00:00<00:00, 271.01batch/s, acc=0.889, loss=0.235]
Epoch 10: 100%|██████████| 34/34 [00:00<00:00, 193.96batch/s, acc=0.889, loss=0.22] 
Epoch 11: 100%|██████████| 34/34 [00:00<00:00, 189.14batch/s, acc=0.926, lo

In [31]:
classifier_performance(clf_nl, feats, stacked_labels)

Accuracy:  tensor(0.9330)
Confusion Matrix: 
[[1300    7]
 [   7   40]]


(tensor(0.9330),
 array([[1300,    7],
        [   7,   40]]))

------

In [None]:
obs = pre_processed_states[0, :, :, :]
obs = obs.reshape(1, 4, 84, 84)
obs = obs_as_tensor(obs, agent.policy.device)
print(obs.shape)
agent.policy.extract_features(obs)

torch.Size([1, 4, 84, 84])


tensor([[10.7101,  0.0000,  3.3007,  0.0000,  0.0000,  4.7056,  6.2294,  0.0000,
          0.7875,  0.0000,  6.0332,  9.9827,  7.6485,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  2.3347,  1.9806,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.9752,  0.0000,  0.0000,
          5.9296,  0.0000,  0.0000,  0.0000,  0.0000,  3.1411,  1.7217,  0.0000,
          0.0000,  0.0000,  0.0000,  1.5900,  1.1487,  9.6075,  3.9059,  7.3202,
          0.0000,  0.0000,  0.0000,  2.1230,  3.4050,  0.0000,  0.0000,  0.0000,
          0.6814,  0.0000,  6.6914,  0.0000,  9.6066,  0.1442,  0.0000,  0.5422,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4422,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0876,  2.1760,  0.0000,  5.8867,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  6.0191, 10.1115,  0.0000,  7.9544,  0.0000,  0.0627,
          0.0000,  0.0000,  

In [None]:
feats = obs_to_feats(agent, [state[0, :, :, :]])

ValueError: axes don't match array