In [1]:
from generate_mdps import generate_datsets, valueIteration
from dataset import MDPDataset, AllNodeFeatures, InMemoryMDPDataset, TransitionsOnEdge
from experiment import Experiment
from MDP_helpers import calculate_gap, multiclass_recall_score
from kmdp_toolbox import aStarAbs

In [2]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import optuna
import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GCN, GAT
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import random_split
from collections import defaultdict
from sklearn.metrics import recall_score

from time import time
from tqdm import tqdm

import pickle

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import warnings
warnings.filterwarnings('ignore')

In [4]:
torch.cuda.manual_seed(12345)
np.random.seed(12345)

In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device='cpu'
print(device)

cuda


In [6]:
N_datasets = 1

N_sites = 6
N_species = 20
K = 8

N_states = 3**N_sites
print(f"N_states: {N_states}")
generate_datsets(N_sites, N_species, K, N_datasets, remove_previous=False, folder="hparam_data")

N_states: 729


'Data already exists'

In [7]:
# dataset = MDPDataset(f"datasets/hparam_data", pre_transform=TransitionsOnEdge())
dataset = InMemoryMDPDataset(f"datasets/hparam_data", pre_transform=AllNodeFeatures())
mdp = dataset[0]

In [8]:
gap, _ = calculate_gap(mdp.P, mdp.R, mdp.V, mdp.k_labels, K)
print(gap)

tensor(0.0144)


In [9]:
state_gaps = pd.DataFrame(index = np.arange(N_states), columns = np.arange(K))

In [10]:
k_labels = 1*mdp.k_labels

for i in tqdm(range(N_states)):
    k_labels = 1*mdp.k_labels
    for j in range(K):
        k_labels[i] = j
        gap_ij, _ = calculate_gap(mdp.P, mdp.R, mdp.V, k_labels, K)
        state_gaps.loc[i, j] = float(gap_ij)

100%|██████████| 729/729 [04:12<00:00,  2.88it/s]


In [11]:
state_gaps

Unnamed: 0,0,1,2,3,4,5,6,7
0,1.014938,0.081602,0.072176,0.02856,0.014446,0.025757,0.093995,0.014446
1,0.014446,0.264561,0.090727,0.094483,0.054577,0.061214,0.014446,0.054577
2,1.014938,0.827007,0.058672,0.027634,0.014446,0.025603,0.08419,0.014446
3,0.014446,0.174024,0.357069,0.153692,0.014446,0.165456,0.154922,0.014446
4,0.014446,0.198595,0.198591,0.094309,0.014446,0.096728,0.047385,0.014446
...,...,...,...,...,...,...,...,...
724,0.014446,0.014446,0.014446,0.014446,0.014446,0.014446,0.014446,0.014446
725,0.014446,0.014446,0.014446,0.014446,0.072417,0.014446,0.014446,0.072176
726,1.014938,0.014446,1.010271,1.010271,1.010271,1.010271,1.010271,0.827007
727,0.014446,0.014446,0.014446,0.014446,0.014446,0.014446,0.014446,0.014446


Seems to be two classes of sensitive states
1. States that are sensitive if put into a certain k_state. This could represent a state that is insignificant to all k_states except one in which the decision would be disastrous
2. States that need to be in a certain k_state for the problem to solve well. This could represent critical decisions in certain situations

Make the following definitions
0. Insensitve States: States that arent sensitive to any degree
1. Type 1 sensitive: States that are sensitive to 1 or 2 abstract states
2. Type 2 sensitive: States that are sensitive to all but 1 or 2 abstract states
3. Mixed sensitivity: States that are sensitive to a number of abstract states, but arent classified as Type 1 or Type 2

Set the threshold for strong sensitivity to be 0.5

Only 8 are of type 1 where there is a single unique state 

In [12]:
sensitivity = pd.DataFrame(index = np.arange(N_states), columns=["Type"])
n = (state_gaps > 0.1).sum(axis=1)
sensitivity.loc[n == 0, "Type"] = "Insensitive" 
sensitivity.loc[n.isin([1, 2]), "Type"] = "Type 1" 
sensitivity.loc[n.isin([6, 7]), "Type"] = "Type 2" 
sensitivity.loc[~n.isin([0, 1, 2, 6, 7]), "Type"] = "Mixed" 

In [13]:
sensitivity["Type"].value_counts()

Type
Mixed          383
Insensitive    155
Type 2         147
Type 1          44
Name: count, dtype: int64

In [14]:
sensitivity["V"] = mdp.V.numpy()

In [15]:
iqr = lambda x: np.percentile(x, 75) - np.percentile(x, 25)

In [16]:
sensitivity.groupby("Type")["V"].agg([np.mean, np.std, np.median, np.min, np.max, iqr])

Unnamed: 0_level_0,mean,std,median,amin,amax,<lambda_0>
Type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
Insensitive,0.061054,0.034258,0.073632,0.009837,0.113492,0.072076
Mixed,0.374646,0.280737,0.271847,0.112612,1.1157,0.262981
Type 1,0.291272,0.407839,0.015188,0.0,1.156662,0.484352
Type 2,0.32793,0.255359,0.232091,0.101564,1.047122,0.3237


In [17]:
with open("Results/GCN_weighted/Reserve_MDP_729_8/model.pkl", "rb") as file:
    model_weighted = pickle.load(file).to('cpu')
    file.close()

with open("Models/gcn_729_8.pckl", "rb") as file:
    model_unweighted = pickle.load(file).to('cpu')
    file.close()

In [18]:
pred_unweighted = model_unweighted(
    x = mdp.x,
    edge_index=mdp.edges,
)
pred_k_unweighted = F.softmax(pred_unweighted, dim=1).argmax(axis=1)

pred_weighted = model_weighted(
    x = mdp.x,
    edge_index=mdp.edges,
)
pred_k_weighted = F.softmax(pred_weighted, dim=1).argmax(axis=1)

In [19]:
sensitivity["k_state"] = mdp["k_labels"]
sensitivity["pred_k_unweighted"] = pred_k_unweighted
sensitivity["pred_k_weighted"] = pred_k_weighted
sensitivity["correct_unweighted"] = sensitivity["k_state"] == sensitivity["pred_k_unweighted"]
sensitivity["correct_weighted"] = sensitivity["k_state"] == sensitivity["pred_k_weighted"]

In [25]:
sensitivity.head()

Unnamed: 0,Type,V,k_state,pred_k_unweighted,pred_k_weighted,correct_unweighted,correct_weighted
0,Type 1,1.156662,7,1,3,False,False
1,Type 1,0.339792,6,4,2,False,False
2,Type 1,1.131185,7,3,5,False,False
3,Mixed,0.44173,4,4,4,True,True
4,Type 1,0.239324,4,4,4,True,True


In [21]:
counts = pd.DataFrame()
counts["k_state"] = sensitivity["k_state"].value_counts()
counts["pred_k_unweighted"] = sensitivity["pred_k_unweighted"].value_counts()
counts["pred_k_weighted"] = sensitivity["pred_k_weighted"].value_counts()
counts

Unnamed: 0_level_0,k_state,pred_k_unweighted,pred_k_weighted
k_state,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
4,172,215.0,110.0
2,137,80.0,179.0
1,130,226.0,161.0
6,117,116.0,96.0
3,103,53.0,74.0
5,64,39.0,109.0
7,5,,
0,1,,


In [22]:
df = sensitivity.groupby("Type")["correct_unweighted"].agg([np.sum, len])
df["Accuracy"] = df["sum"]/df["len"]
df["Most common state"] = sensitivity.groupby("Type")["k_state"].agg(lambda x: x.value_counts().idxmax())
df["Most common predicted state"] = sensitivity.groupby("Type")["pred_k_unweighted"].agg(lambda x: x.value_counts().idxmax())
df

Unnamed: 0_level_0,sum,len,Accuracy,Most common state,Most common predicted state
Type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Insensitive,81,155,0.522581,3,4
Mixed,280,383,0.73107,4,4
Type 1,11,44,0.25,4,4
Type 2,133,147,0.904762,1,1


In [23]:
df = sensitivity.groupby("Type")["correct_weighted"].agg([np.sum, len])
df["Accuracy"] = df["sum"]/df["len"]
df["Most common state"] = sensitivity.groupby("Type")["k_state"].agg(lambda x: x.value_counts().idxmax())
df["Most common predicted state"] = sensitivity.groupby("Type")["pred_k_weighted"].agg(lambda x: x.value_counts().idxmax())
df

Unnamed: 0_level_0,sum,len,Accuracy,Most common state,Most common predicted state
Type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Insensitive,101,155,0.651613,3,2
Mixed,245,383,0.639687,4,2
Type 1,8,44,0.181818,4,2
Type 2,139,147,0.945578,1,1


In [24]:
pd.merge(sensitivity, state_gaps, left_index=True, right_index=True).groupby("Type")[state_gaps.columns].mean()

Unnamed: 0_level_0,0,1,2,3,4,5,6,7
Type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
Insensitive,0.023818,0.050019,0.0474,0.042863,0.046078,0.048367,0.053243,0.046076
Mixed,0.177783,0.191548,0.174964,0.226123,0.166721,0.215988,0.185095,0.098949
Type 1,0.721479,0.085198,0.058614,0.092835,0.055905,0.080708,0.083179,0.027898
Type 2,0.243532,0.200082,0.250158,0.241216,0.293612,0.23771,0.214351,0.254167
