In [3]:
import gymnasium as gym
import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch
import os
from stable_baselines3 import PPO
from discovery.utils.feat_extractors import NatureCNN
from stable_baselines3.common.utils import obs_as_tensor
from discovery.experiments.FeatAct_minigrid.helpers import pre_process_obs
import cv2

from discovery.utils import filesys

In [4]:
filesys.set_directory_in_project()
agent = PPO.load("discovery/experiments/FeatAct_atari/models/Seaquest-v5_mpqgvvr1.zip")

FileNotFoundError: [Errno 2] No such file or directory: '/Users/kevinroice/Documents/research/discovery/discovery/experiments/FeatAct_atari/models/Seaquest-v5_mpqgvvr1.zip.zip'

Load in the seaquest dataset.

In [None]:

path = f"/Users/kevinroice/Documents/research/discovery/datasets/AAD/clean/SeaquestNoFrameskip-v4/episode(1).hdf5"
with h5py.File(path, "r") as f:
    state = f["state"][...]
labels = np.load("/Users/kevinroice/Documents/research/discovery/datasets/AAD/clean/SeaquestNoFrameskip-v4/episode(1)_labels.npy")

In [None]:
def obs_to_feats(model, obss):
    feats = []
    with torch.no_grad():
        for obs in obss:
            obs = pre_process_obs(obs[0], model)
            # print(obs[0].shape)
            if model.__class__.__name__ == "DoubleDQN":
                x = model.policy.extract_features(obs, model.policy.q_net.features_extractor)
            elif model.__class__.__name__ == "PPO":
                x = model.policy.extract_features(obs)
            feats.append(x)  
    return feats

In [None]:
state[0, :, :, :].shape

(210, 160, 3)

In [None]:
def pre_process_atari(dataset: np.ndarray):
    num_images = dataset.shape[0]
    preprocessed_images = np.zeros((num_images, 84, 84), dtype=np.uint8)

    for i in range(num_images):
        preprocessed_images[i] = cv2.resize(cv2.cvtColor(dataset[i], cv2.COLOR_RGB2GRAY), (84, 84))

        # Stack frames
        stacked_images = np.zeros((num_images - 3, 4, 84, 84), dtype=np.uint8)
        for i in range(num_images - 3):
            stacked_images[i] = np.stack(
                [
                    preprocessed_images[i],
                    preprocessed_images[i + 1],
                    preprocessed_images[i + 2],
                    preprocessed_images[i + 3],
                ]
            )

    return stacked_images


def stack_labels(labels):
    stacked_labels = np.zeros((labels.shape[0] - 3, 1), dtype=np.uint8)
    for i in range(labels.shape[0] - 3):
        stacked_labels[i] = labels[i + 3]
    return stacked_labels

In [None]:
def preproc_to_feats_atari(model, preprocs):
    with torch.no_grad():
        tensors = obs_as_tensor(preprocs, model.device)
        print(tensors.shape)
        feats = model.policy.extract_features(tensors) 
    return feats

In [None]:
pre_processed_states = pre_process_atari(state)

In [None]:
stacked_labels = stack_labels(labels)

In [None]:
pre_processed_states.shape

(1354, 4, 84, 84)

In [None]:
stacked_labels.shape

(1354, 1)

In [None]:
# turn all numbers into 0 and 2s into 1
stacked_labels[stacked_labels == 2] = 1

In [None]:
print(stacked_labels)

[[1]
 [1]
 [1]
 ...
 [0]
 [0]
 [0]]


In [None]:
feats = preproc_to_feats_atari(agent, pre_processed_states)

torch.Size([1354, 4, 84, 84])


In [None]:
from discovery.utils import sg_detection 
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import copy
import tqdm
import torch.nn as nn
import torch.optim as optim
import importlib

def train_classifier(clf, X, labels,
                     n_epochs=500,
                     batch_size=32,
                     test_size=0.2, random_state=0):
    # X = torch.cat(feats, dim=0)
    y = torch.tensor(labels).float().squeeze()
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        test_size=test_size,
                                                        random_state=random_state)
    
    best_acc = - np.inf
    best_weights = None
    batch_start = torch.arange(0, len(X_train), batch_size) # TODO: check if the last batch is included
    loss_fn = nn.BCELoss(reduction='none')  # reduction='none' to get per-sample loss, not mean

    num_pos = y_train.sum()
    num_neg = len(y_train) - num_pos
    base_weight = torch.tensor([1.0, num_neg/num_pos]) # for weighted mean in loss calculation
    
    optimizer = optim.Adam(clf.parameters(), lr=0.0001)
    # TODO: collect positive examples, and concatenate them to each batch

    for epoch in range(n_epochs):
        clf.train()
        with tqdm.tqdm(batch_start, unit="batch", mininterval=0, disable=False) as bar:
            bar.set_description(f"Epoch {epoch}")
            for start in bar:
                # take a batch
                X_batch = X_train[start:start+batch_size]
                y_batch = y_train[start:start+batch_size]
                # forward pass
                y_pred = clf(X_batch)
                y_batch = y_batch.unsqueeze(1)
                weight = torch.where(y_batch == 1, base_weight[1], base_weight[0])
                loss2 = loss_fn(y_pred, y_batch)
                final_loss = torch.mean(weight*loss2)
                # backward pass
                optimizer.zero_grad()
                final_loss.backward()
                # update weights
                optimizer.step()
                # print progress
                acc = (y_pred.round() == y_batch).float().mean()
                bar.set_postfix(
                    loss=float(final_loss),
                    acc=float(acc)
                )
        # # evaluate accuracy at end of each epoch
        # clf.eval()
        # y_pred = clf(X_test)
        # acc = (y_pred.round() == y_test).float().mean()
        # acc = float(acc)
        if acc > best_acc:
            best_acc = acc
            best_weights = copy.deepcopy(clf.state_dict())
    return best_acc

In [None]:
from discovery.utils import sg_detection 
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import copy
import tqdm
import torch.nn as nn
import torch.optim as optim
import importlib

def calculate_class_weights(labels, n_classes):
    class_counts = torch.bincount(labels.squeeze(), minlength=n_classes)
    class_weights = 1. / class_counts.float()
    class_weights /= class_weights.sum()  # normalize the weights
    return class_weights


def train_classifier(clf, X, labels, n_classes,
                     n_epochs=500, batch_size=32,
                     test_size=0.2, random_state=0):
    # X = torch.cat(feats, dim=0)
    y = torch.tensor(labels).long() # ensure labels are integers
    X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                        test_size=test_size,
                                                        random_state=random_state)
    
    # Calculate class weights
    print(type(y_train[0]), y_train[0], y_train[0].shape)
    class_weights = calculate_class_weights(y_train, n_classes)
    
    best_acc = - np.inf
    best_weights = None
    batch_start = torch.arange(0, len(X_train), batch_size)  # Check if the last batch is included
    print(class_weights)
    loss_fn = nn.CrossEntropyLoss(weight=class_weights)  # Use class weights in the loss function
    
    optimizer = optim.Adam(clf.parameters(), lr=0.0001)
    
    for epoch in range(n_epochs):
        clf.train()
        with tqdm.tqdm(batch_start, unit="batch", mininterval=0, disable=False) as bar:
            bar.set_description(f"Epoch {epoch}")
            for start in bar:
                # Take a batch
                X_batch = X_train[start:start+batch_size]
                y_batch = y_train[start:start+batch_size]
                # Ensure y_batch is of type LongTensor
                y_batch = y_batch.long()
                # Forward pass
                y_pred = clf(X_batch)
                # Ensure y_pred is of type FloatTensor
                y_pred = y_pred.type(torch.FloatTensor)
                y_batch = y_batch.type(torch.FloatTensor)
                print(type(y_pred), y_pred.shape)
                loss = loss_fn(y_pred, y_batch)
                # Backward pass
                optimizer.zero_grad()
                loss.backward()
                # Update weights
                optimizer.step()
                # Print progress
                _, y_pred_labels = torch.max(y_pred, 1)
                acc = (y_pred_labels == y_batch).float().mean()
                bar.set_postfix(
                    loss=float(loss),
                    acc=float(acc)
                )
        # Evaluate accuracy at end of each epoch
        clf.eval()
        with torch.no_grad():
            y_pred = clf(X_test)
            _, y_pred_labels = torch.max(y_pred, 1)
            acc = (y_pred_labels == y_test).float().mean()
            acc = float(acc)
        if acc > best_acc:
            best_acc = acc
            best_weights = copy.deepcopy(clf.state_dict())
    
    clf.load_state_dict(best_weights)
    return best_acc

In [None]:
stacked_labels[0]

array([1], dtype=uint8)

In [None]:
clf = sg_detection.LinearClassifier(input_size=512)
acc = train_classifier(clf, feats, stacked_labels)

Epoch 0: 100%|██████████| 34/34 [00:00<00:00, 114.22batch/s, acc=0.63, loss=1.36]  
Epoch 1: 100%|██████████| 34/34 [00:00<00:00, 93.60batch/s, acc=0.667, loss=1.2]   
Epoch 2: 100%|██████████| 34/34 [00:00<00:00, 196.03batch/s, acc=0.704, loss=1.04] 
Epoch 3: 100%|██████████| 34/34 [00:00<00:00, 269.14batch/s, acc=0.852, loss=0.908]
Epoch 4: 100%|██████████| 34/34 [00:00<00:00, 259.00batch/s, acc=0.852, loss=0.803]
Epoch 5: 100%|██████████| 34/34 [00:00<00:00, 242.62batch/s, acc=0.852, loss=0.723]
Epoch 6: 100%|██████████| 34/34 [00:00<00:00, 257.02batch/s, acc=0.852, loss=0.658]
Epoch 7: 100%|██████████| 34/34 [00:00<00:00, 227.32batch/s, acc=0.852, loss=0.605]
Epoch 8: 100%|██████████| 34/34 [00:00<00:00, 157.23batch/s, acc=0.889, loss=0.56] 
Epoch 9: 100%|██████████| 34/34 [00:00<00:00, 240.80batch/s, acc=0.889, loss=0.521]
Epoch 10: 100%|██████████| 34/34 [00:00<00:00, 232.28batch/s, acc=0.889, loss=0.488]
Epoch 11: 100%|██████████| 34/34 [00:00<00:00, 243.32batch/s, acc=0.926, lo

In [None]:
from sklearn.metrics import confusion_matrix
def classifier_performance(clf, X, labels):
    y = torch.tensor(labels).float().squeeze()
    y_pred = clf(X)


    acc = (y_pred.round() == y).float().mean()
    print("Accuracy: ", acc)
    y_pred_np = y_pred.detach().numpy()
    c_m = confusion_matrix(labels, y_pred_np.round())
    print("Confusion Matrix: ")
    print(c_m)
    return acc, c_m

In [None]:
classifier_performance(clf, feats, stacked_labels)

Accuracy:  tensor(0.7764)
Confusion Matrix: 
[[1170   16    0]
 [   5  162    0]
 [   0    1    0]]


(tensor(0.7764),
 array([[1170,   16,    0],
        [   5,  162,    0],
        [   0,    1,    0]]))

In [None]:
clf_nl = sg_detection.NonLinearClassifier(input_size=512, hidden_size=64)
acc = train_classifier(clf_nl, feats, stacked_labels)

Epoch 0: 100%|██████████| 34/34 [00:00<00:00, 162.65batch/s, acc=0.926, loss=0.8]  
Epoch 1: 100%|██████████| 34/34 [00:00<00:00, 92.14batch/s, acc=0.963, loss=0.61]  
Epoch 2: 100%|██████████| 34/34 [00:00<00:00, 180.30batch/s, acc=0.963, loss=0.496]
Epoch 3: 100%|██████████| 34/34 [00:00<00:00, 224.10batch/s, acc=0.963, loss=0.411]
Epoch 4: 100%|██████████| 34/34 [00:00<00:00, 239.95batch/s, acc=0.963, loss=0.345]
Epoch 5: 100%|██████████| 34/34 [00:00<00:00, 235.24batch/s, acc=0.963, loss=0.293]
Epoch 6: 100%|██████████| 34/34 [00:00<00:00, 246.98batch/s, acc=0.963, loss=0.254]
Epoch 7: 100%|██████████| 34/34 [00:00<00:00, 232.41batch/s, acc=0.963, loss=0.223]
Epoch 8: 100%|██████████| 34/34 [00:00<00:00, 241.11batch/s, acc=0.963, loss=0.199]
Epoch 9: 100%|██████████| 34/34 [00:00<00:00, 232.45batch/s, acc=0.963, loss=0.178]
Epoch 10: 100%|██████████| 34/34 [00:00<00:00, 235.22batch/s, acc=0.963, loss=0.16] 
Epoch 11: 100%|██████████| 34/34 [00:00<00:00, 227.67batch/s, acc=0.963, lo

In [None]:
classifier_performance(clf_nl, feats, stacked_labels)

Accuracy:  tensor(0.7814)
Confusion Matrix: 
[[1178    8    0]
 [   6  161    0]
 [   0    1    0]]


(tensor(0.7814),
 array([[1178,    8,    0],
        [   6,  161,    0],
        [   0,    1,    0]]))

------

In [None]:
obs = pre_processed_states[0, :, :, :]
obs = obs.reshape(1, 4, 84, 84)
obs = obs_as_tensor(obs, agent.policy.device)
print(obs.shape)
agent.policy.extract_features(obs)

torch.Size([1, 4, 84, 84])


tensor([[10.7101,  0.0000,  3.3007,  0.0000,  0.0000,  4.7056,  6.2294,  0.0000,
          0.7875,  0.0000,  6.0332,  9.9827,  7.6485,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  2.3347,  1.9806,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.9752,  0.0000,  0.0000,
          5.9296,  0.0000,  0.0000,  0.0000,  0.0000,  3.1411,  1.7217,  0.0000,
          0.0000,  0.0000,  0.0000,  1.5900,  1.1487,  9.6075,  3.9059,  7.3202,
          0.0000,  0.0000,  0.0000,  2.1230,  3.4050,  0.0000,  0.0000,  0.0000,
          0.6814,  0.0000,  6.6914,  0.0000,  9.6066,  0.1442,  0.0000,  0.5422,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.4422,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0876,  2.1760,  0.0000,  5.8867,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  6.0191, 10.1115,  0.0000,  7.9544,  0.0000,  0.0627,
          0.0000,  0.0000,  

In [None]:
feats = obs_to_feats(agent, [state[0, :, :, :]])

ValueError: axes don't match array