In [1]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import os
from stable_baselines3 import PPO
from discovery.utils.feat_extractors import NatureCNN
from stable_baselines3.common.utils import obs_as_tensor
from discovery.experiments.FeatAct_minigrid.helpers import pre_process_obs
import cv2

from discovery.utils import filesys

pygame 2.5.2 (SDL 2.28.3, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
filesys.set_directory_in_project()
agent = PPO.load("discovery/experiments/FeatAct_atari/models/Seaquest-v5_mpqgvvr1.zip")

Changed working directory to /Users/szepi1991/Code/discovery


Load in the seaquest dataset.

In [4]:

path = f"datasets/AAD/clean/SeaquestNoFrameskip-v4/episode(1).hdf5"
with h5py.File(path, "r") as f:
    state = f["state"][...]
labels = np.load("datasets/AAD/clean/SeaquestNoFrameskip-v4/episode(1)_labels.npy")

In [5]:
state[0, :, :, :].shape

(210, 160, 3)

In [6]:
def pre_process_atari(dataset: np.ndarray):
    num_images = dataset.shape[0]
    preprocessed_images = np.zeros((num_images, 84, 84), dtype=np.uint8)

    for i in range(num_images):
        preprocessed_images[i] = cv2.resize(cv2.cvtColor(dataset[i], cv2.COLOR_RGB2GRAY), (84, 84))

        # Stack frames
        stacked_images = np.zeros((num_images - 3, 4, 84, 84), dtype=np.uint8)
        for i in range(num_images - 3):
            stacked_images[i] = np.stack(
                [
                    preprocessed_images[i],
                    preprocessed_images[i + 1],
                    preprocessed_images[i + 2],
                    preprocessed_images[i + 3],
                ]
            )

    return stacked_images


def stack_labels(labels):
    stacked_labels = np.zeros((labels.shape[0] - 3, 1), dtype=np.uint8)
    for i in range(labels.shape[0] - 3):
        stacked_labels[i] = labels[i + 3]
    return np.squeeze(stacked_labels)

In [7]:
def preproc_to_feats_atari(model, preprocs):
    with torch.no_grad():
        tensors = obs_as_tensor(preprocs, model.device)
        print(tensors.shape)
        feats = model.policy.extract_features(tensors) 
    return feats

In [8]:
pre_processed_states = pre_process_atari(state)

In [9]:
# turn all 1s and 3s into 0s
labels[labels == 1] = 0
labels[labels == 3] = 0
# turn all 2s into 1s
labels[labels == 2] = 1

In [10]:
stacked_labels = stack_labels(labels)

In [11]:
pre_processed_states.shape

(1354, 4, 84, 84)

In [12]:
stacked_labels.shape

(1354,)

In [13]:
feats = preproc_to_feats_atari(agent, pre_processed_states)

torch.Size([1354, 4, 84, 84])


In [14]:
from discovery.utils import sg_detection 
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import copy
import tqdm
import torch.nn as nn
import torch.optim as optim
import importlib

def train_classifier(clf, X, labels,
                     n_epochs=500,
                     batch_size=32,
                     test_size=0.2, random_state=0):
    # X = torch.cat(feats, dim=0)
    y = torch.tensor(labels).float()
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        test_size=test_size,
                                                        random_state=random_state)
    
    best_acc = - np.inf
    best_weights = None
    batch_start = torch.arange(0, len(X_train), batch_size) # TODO: check if the last batch is included
    loss_fn = nn.BCELoss(reduction='none')  # reduction='none' to get per-sample loss, not mean

    num_pos = y_train.sum()
    num_neg = len(y_train) - num_pos
    base_weight = torch.tensor([1.0, num_neg/num_pos]) # for weighted mean in loss calculation
    
    optimizer = optim.Adam(clf.parameters(), lr=0.0001)
    # TODO: collect positive examples, and concatenate them to each batch

    for epoch in range(n_epochs):
        clf.train()
        with tqdm.tqdm(batch_start, unit="batch", mininterval=0, disable=False) as bar:
            bar.set_description(f"Epoch {epoch}")
            for start in bar:
                # take a batch
                X_batch = X_train[start:start+batch_size]
                y_batch = y_train[start:start+batch_size]
                # forward pass
                y_pred = clf(X_batch)
                y_batch = y_batch.unsqueeze(1)
                weight = torch.where(y_batch == 1, base_weight[1], base_weight[0])
                loss2 = loss_fn(y_pred, y_batch)
                final_loss = torch.mean(weight*loss2)
                # backward pass
                optimizer.zero_grad()
                final_loss.backward()
                # update weights
                optimizer.step()
                # print progress
                acc = (y_pred.round() == y_batch).float().mean()
                bar.set_postfix(
                    loss=float(final_loss),
                    acc=float(acc)
                )
        # # evaluate accuracy at end of each epoch
        # clf.eval()
        # y_pred = clf(X_test)
        # acc = (y_pred.round() == y_test).float().mean()
        # acc = float(acc)
        if acc > best_acc:
            best_acc = acc
            best_weights = copy.deepcopy(clf.state_dict())
    return best_acc

In [15]:
stacked_labels[0]

0

In [16]:
clf = sg_detection.LinearClassifier(input_size=512)
acc = train_classifier(clf, feats, stacked_labels)

Epoch 0:   0%|          | 0/34 [00:00<?, ?batch/s]

Epoch 0: 100%|██████████| 34/34 [00:00<00:00, 592.86batch/s, acc=0.222, loss=1.34]  
Epoch 1: 100%|██████████| 34/34 [00:00<00:00, 1243.97batch/s, acc=0.37, loss=0.978] 
Epoch 2: 100%|██████████| 34/34 [00:00<00:00, 971.40batch/s, acc=0.444, loss=0.847] 
Epoch 3: 100%|██████████| 34/34 [00:00<00:00, 1114.28batch/s, acc=0.444, loss=0.765]
Epoch 4: 100%|██████████| 34/34 [00:00<00:00, 1255.17batch/s, acc=0.556, loss=0.7]  
Epoch 5: 100%|██████████| 34/34 [00:00<00:00, 1229.18batch/s, acc=0.556, loss=0.648]
Epoch 6: 100%|██████████| 34/34 [00:00<00:00, 1214.43batch/s, acc=0.704, loss=0.606]
Epoch 7: 100%|██████████| 34/34 [00:00<00:00, 1209.29batch/s, acc=0.778, loss=0.57] 
Epoch 8: 100%|██████████| 34/34 [00:00<00:00, 1221.22batch/s, acc=0.778, loss=0.54] 
Epoch 9: 100%|██████████| 34/34 [00:00<00:00, 1242.56batch/s, acc=0.815, loss=0.513]
Epoch 10: 100%|██████████| 34/34 [00:00<00:00, 535.01batch/s, acc=0.815, loss=0.49]  
Epoch 11: 100%|██████████| 34/34 [00:00<00:00, 988.95batch/s, ac

In [17]:
from sklearn.metrics import confusion_matrix
def classifier_performance(clf, X, labels):
    y = torch.tensor(labels).float().squeeze()
    y_pred = clf(X)
    acc = (y_pred.round() == y).float().mean()
    print("Accuracy: ", acc)
    y_pred_np = y_pred.detach().numpy()
    c_m = confusion_matrix(labels, y_pred_np.round())
    print("Confusion Matrix: ")
    print(c_m)
    return acc, c_m

In [18]:
classifier_performance(clf, feats, stacked_labels)

Accuracy:  tensor(0.9261)
Confusion Matrix: 
[[1292   15]
 [   5   42]]


(tensor(0.9261),
 array([[1292,   15],
        [   5,   42]]))

In [19]:
clf_nl = sg_detection.NonLinearClassifier(input_size=512, hidden_size=64)
acc = train_classifier(clf_nl, feats, stacked_labels)

Epoch 0: 100%|██████████| 34/34 [00:00<00:00, 1066.17batch/s, acc=0.667, loss=0.567]
Epoch 1: 100%|██████████| 34/34 [00:00<00:00, 1049.87batch/s, acc=0.593, loss=0.564]
Epoch 2: 100%|██████████| 34/34 [00:00<00:00, 1088.17batch/s, acc=0.852, loss=0.468]
Epoch 3: 100%|██████████| 34/34 [00:00<00:00, 1081.88batch/s, acc=0.889, loss=0.406]
Epoch 4: 100%|██████████| 34/34 [00:00<00:00, 1085.29batch/s, acc=0.889, loss=0.359]
Epoch 5: 100%|██████████| 34/34 [00:00<00:00, 1091.42batch/s, acc=0.889, loss=0.327]
Epoch 6: 100%|██████████| 34/34 [00:00<00:00, 879.44batch/s, acc=0.889, loss=0.301] 
Epoch 7: 100%|██████████| 34/34 [00:00<00:00, 1052.90batch/s, acc=0.889, loss=0.277]
Epoch 8: 100%|██████████| 34/34 [00:00<00:00, 1040.84batch/s, acc=0.889, loss=0.262]
Epoch 9: 100%|██████████| 34/34 [00:00<00:00, 1042.57batch/s, acc=0.926, loss=0.245]
Epoch 10: 100%|██████████| 34/34 [00:00<00:00, 1095.99batch/s, acc=0.926, loss=0.231]
Epoch 11: 100%|██████████| 34/34 [00:00<00:00, 1069.17batch/s, a

In [20]:
classifier_performance(clf_nl, feats, stacked_labels)

Accuracy:  tensor(0.9337)
Confusion Matrix: 
[[1301    6]
 [   7   40]]


(tensor(0.9337),
 array([[1301,    6],
        [   7,   40]]))

------

In [21]:
obs = pre_processed_states[0, :, :, :]
obs = obs.reshape(1, 4, 84, 84)
obs = obs_as_tensor(obs, agent.policy.device)
print(obs.shape)
agent.policy.extract_features(obs)

torch.Size([1, 4, 84, 84])


tensor([[10.6792,  0.0000,  3.2849,  0.0000,  0.0000,  4.6934,  6.2134,  0.0000,
          0.7665,  0.0000,  6.0376,  9.9950,  7.6318,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  2.3148,  1.9839,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.9741,  0.0000,  0.0000,
          5.9235,  0.0000,  0.0000,  0.0000,  0.0000,  3.1284,  1.7241,  0.0000,
          0.0000,  0.0000,  0.0000,  1.6094,  1.1486,  9.5892,  3.9028,  7.3212,
          0.0000,  0.0000,  0.0000,  2.1245,  3.3880,  0.0000,  0.0000,  0.0000,
          0.6887,  0.0000,  6.6848,  0.0000,  9.5831,  0.1434,  0.0000,  0.5921,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4345,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0816,  2.1754,  0.0000,  5.8832,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  5.9995, 10.1024,  0.0000,  7.9316,  0.0000,  0.0513,
          0.0000,  0.0000,  

NameError: name 'obs_to_feats' is not defined