In [1]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch.utils.data import DataLoader
from torch_geometric.nn.models import GCN, GAT

import numpy as np
import matplotlib.pyplot as plt

from time import time
from tqdm import tqdm

from collections import defaultdict

import optuna

from sklearn.metrics import recall_score

from reserve import generate_reserve
from MDP_helpers import MDP, relabel_k
from kmdp_toolbox import aStarAbs, sk_to_s
from experiment import Experiment

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
device='cpu'

In [3]:
torch.cuda.empty_cache()

In [4]:
# Generate Reserve Data
N_datasets = 100

N_sites = 5
N_species = 20
K = 7

N_states = 3**N_sites
N_actions = N_sites

print(f"Generating {N_datasets} MDPs with {N_states} states and {N_actions} actions \n")

mdp_datasets = []
for i in tqdm(range(N_datasets)):
    pj = np.random.random() # random probability between 0 and 1

    T, R = generate_reserve(N_sites, N_species, pj=pj, seed=i) 
    mdp = MDP(T, R, gamma=0.99)
    mdp.solve_MDP()

    mdp.k_states, mdp.K = aStarAbs(P=mdp.transitions, R=mdp.rewards, V=mdp.optimal_values, policy=mdp.optimal_policy, K=K, precision=1e-6)

    mdp_datasets.append(mdp)

Generating 100 MDPs with 243 states and 5 actions 



  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:10<00:00,  9.62it/s]


In [5]:
# dataset = []

# for i in range(len(mdp_datasets)):
#     # Set up data as graphs with node features defined by rewards and transition probabilities. Approach is to make dataset tabular as set node features
#     P = mdp_datasets[i].transitions
#     T = np.empty((N_states, N_states*N_actions))
#     for j in range(N_states):
#         T[j, :] = P[:, j, :].reshape(1, -1)

#     x = np.concatenate([T, mdp_datasets[i].rewards], axis=1)

#     # Cound whether transition are non-zero for any action
#     p_sum = np.sum((P> 0), axis=0)
#     edges = [[i, j] for i, j in zip(*np.where(p_sum > 0))]

#     # Convert to torch
#     x = torch.tensor(x, dtype=torch.float).to(device)
#     edges = torch.tensor(edges, dtype=torch.int).T.to(device)

#     k = torch.tensor(mdp_datasets[i].k_states, dtype=torch.int64).to(device)

#     T = torch.tensor(mdp_datasets[i].transitions).to('cpu')
#     R = torch.tensor(mdp_datasets[i].rewards).to('cpu')
#     V = torch.tensor(mdp_datasets[i].optimal_values).to('cpu')

#     dataset.append(
#         Data(
#             x=x,
#             edges=edges,
#             k_labels=k,
#             T = T,
#             R = R,
#             V = V
#         )
#     )

In [6]:
dataset = []

for i in range(len(mdp_datasets)):
    # Set up data as graphs with node features defined by rewards and transition probabilities. Approach is to make dataset tabular as set node features
    P = mdp_datasets[i].transitions

    # Cound whether transition are non-zero for any action
    p_sum = np.sum((P> 0), axis=0)
    edges = [[i, j] for i, j in zip(*np.where(p_sum > 0))]
    edge_features = [P[:, i[0], i[1]] for i in edges]

    # Convert to torch
    edges = torch.tensor(edges, dtype=torch.int64).T.to(device)
    edge_features = torch.tensor(edge_features, dtype = torch.float64).to(device)

    k = torch.tensor(mdp_datasets[i].k_states, dtype=torch.int64).to(device)
    T = torch.tensor(mdp_datasets[i].transitions).to('cpu')
    R = torch.tensor(mdp_datasets[i].rewards, dtype = torch.float32).to('cpu')
    V = torch.tensor(mdp_datasets[i].optimal_values).to('cpu')

    dataset.append(
        Data(
            x=R,
            edges=edges,
            edge_features = edge_features,
            k_labels=k,
            T = T,
            R = R,
            V = V
        )
    )

  edge_features = torch.tensor(edge_features, dtype = torch.float64).to(device)


In [7]:
data_split = int(len(dataset)*0.8)
train_data = dataset[:data_split]
test_data = dataset[data_split:]

hparam_split = int(len(train_data)*0.8)
train_data_hparam = train_data[:hparam_split]
val_data = train_data[hparam_split:]

In [8]:
gcn_model = GAT(
    in_channels=dataset[0].x.shape[1], 
    out_channels=K, 
    hidden_channels=100, 
    num_layers=1, 
).to(device)

gcn_model(
    x = dataset[0].x,
    edge_index = dataset[0].edges,
    edge_attr = dataset[0].edge_features
)


tensor([[-0.1786,  0.1220, -0.3189,  ..., -0.2409, -1.5827, -0.0484],
        [-0.0556,  0.1150, -0.2488,  ..., -0.0731, -0.8838, -0.0679],
        [ 0.3129,  0.5221, -0.8134,  ...,  0.3928, -0.9328, -0.3539],
        ...,
        [-0.3002, -0.1895,  0.1075,  ..., -0.4598, -1.0013,  0.1877],
        [-0.1475, -0.0624,  0.0040,  ..., -0.2167, -0.6073,  0.0670],
        [-0.0799,  0.0234, -0.1372,  ..., -0.1471, -0.7381,  0.0165]],
       grad_fn=<AddBackward0>)

In [9]:
def buildKMDP(T: torch.tensor, R: torch.tensor, predicted_k_states: torch.tensor, K: int) -> torch.tensor:
    """ Implement buildKMDP using inbuilt torch functions to keep everything on device """
    K2S = sk_to_s(predicted_k_states, K)
    weights = (1/torch.bincount(predicted_k_states))[predicted_k_states]

    RK = torch.empty(size=(K, N_actions), device=device, dtype=torch.float64)
    # R = torch.tensor(mdp.rewards).to(device)

    TK = torch.empty(size=(N_actions, K, K), device=device, dtype=torch.float64)
    # T = torch.tensor(mdp.transitions).to(device)

    for k in range(K):
        RK[k] = (R.T*weights).T[predicted_k_states==k].sum(axis=0)
        for kp in range(K):
            TK[:, k, kp] = (T[:, :, predicted_k_states==kp].sum(axis=2) * weights)[:, predicted_k_states==k].sum(axis=1)
    return TK, RK, K2S

In [10]:
# value iteration

def valueIteration(T: torch.tensor, R: torch.tensor, gamma = 0.99, epsilon=1e-4, N_iter=10000) -> torch.tensor:
    """ Implement Value Iteration in the pytorch environment """
    N_states, N_actions = R.shape
    V = torch.zeros(size=[N_states], device=device, dtype=torch.float64)
    Q = torch.empty(size=[N_states, N_actions], device=device, dtype=torch.float64)
    for i in range(N_iter):
        for a in range(N_actions):
            Q[:, a] = R[:, a].T + gamma*T[a, :, :]@V

        V_new, policy = Q.max(axis=1)

        if torch.all(torch.abs(V_new - V) < epsilon):
            break
        
        if i == N_iter - 1:
            raise Exception("Did not converge in time. Consider increasing the number of iterations.")

        V = V_new
    return V_new, policy

In [11]:
def valueFunction(T, R, policy, gamma=0.99, epsilon=1e-3, N_iter = 1e6):
    """ Calculate the value function of an mdp given a policy """
    N_states, N_actions = R.shape
    
    V = torch.zeros(size=[N_states])
    V_new = torch.zeros(size=[N_states])

    count = 0
    converged=False
    while not converged:
        for s in range(N_states):
            V_new[s] = R[s, policy[s]] + gamma*(T[policy[s], s]*V).sum()

        if torch.max(V_new - V) < epsilon:
            converged = True

        V = 1*V_new

        count += 1
        if count >= N_iter:
            print("Did not converge")
            break
        
    return V_new

In [12]:
def calculate_gap(T, R, V, predicted_k_states, K):
    # predicted_k_states = #F.softmax(prediction, dim=1).argmax(axis=1)

    new_K = len(predicted_k_states.unique())
    predicted_k_states = relabel_k(predicted_k_states, K) if new_K != K else predicted_k_states

    PK, RK, K2S = buildKMDP(T, R, predicted_k_states, new_K)
    _, kmdp_policy = valueIteration(PK, RK, gamma=0.85, N_iter=50000, epsilon=1e-1)

    k_policy = torch.empty(size=[N_states], dtype=torch.int64)

    for k in range(new_K):
        k_policy[K2S[k]] = kmdp_policy[k]

    V_K = valueFunction(T, R, k_policy)

    gap = torch.max(torch.abs(V - V_K))
    error = gap/max(V)

    return gap, error


In [13]:
# ChatGPT written
def multiclass_recall_score(y_true, y_pred, average='macro'):
    """
    Calculate the multiclass recall score using PyTorch.

    Parameters:
    - y_true (torch.Tensor): True labels (ground truth).
    - y_pred (torch.Tensor): Predicted labels.
    - average (str): Type of averaging to use for multiclass recall.
        - 'macro' (default): Calculate recall for each class and then take the average.
        - 'micro': Calculate recall globally by considering all instances.
        - 'weighted': Calculate recall for each class and weight them by support.

    Returns:
    - recall (float): The multiclass recall score.
    """
    assert len(y_true) == len(y_pred), "Input arrays must have the same length"

    if average not in ('macro', 'micro', 'weighted'):
        raise ValueError("Invalid 'average' parameter. Use 'macro', 'micro', or 'weighted'.")

    num_classes = len(torch.unique(y_true))
    recall_per_class = []

    for class_label in range(num_classes):
        true_positive = torch.sum((y_true == class_label) & (y_pred == class_label)).item()
        false_negative = torch.sum((y_true == class_label) & (y_pred != class_label)).item()
        recall = true_positive / (true_positive + false_negative + 1e-10)  # Adding a small epsilon to avoid division by zero
        recall_per_class.append(recall)

    if average == 'macro':
        return sum(recall_per_class) / num_classes
    elif average == 'micro':
        total_true_positives = torch.sum((y_true == y_pred) & (y_true == class_label)).item()
        total_false_negatives = torch.sum((y_true != y_pred) & (y_true == class_label)).item()
        return total_true_positives / (total_true_positives + total_false_negatives + 1e-10)
    elif average == 'weighted':
        class_counts = [torch.sum(y_true == class_label).item() for class_label in range(num_classes)]
        total_samples = len(y_true)
        weights = [count / total_samples for count in class_counts]
        weighted_recall = sum([recall_per_class[i] * weights[i] for i in range(num_classes)])
        return weighted_recall

In [14]:
from torch.optim.lr_scheduler import ExponentialLR

In [15]:
N_epochs = 500

def objective(trial):
    hidden_channels = trial.suggest_int("hidden_channels", 30, 200)
    num_layers = trial.suggest_int("num_layers", 1, 3)
    dropout = trial.suggest_float("dropout", 0, 0.1)
    lr = trial.suggest_float("lr", 1e-2, 10)
    weight_decay = trial.suggest_float("weight_decay", 1e-3, 1e-1)

    gamma = trial.suggest_float("gamma", 0, 1)

    # gcn_model = GCN(
    #     in_channels=dataset[0].x.shape[1], 
    #     out_channels=K, 
    #     hidden_channels=hidden_channels, 
    #     num_layers=num_layers, 
    #     dropout=dropout
    # ).to(device)

    gcn_model = GAT(
        in_channels=dataset[0].x.shape[1], 
        out_channels=K, 
        hidden_channels=hidden_channels, 
        num_layers=num_layers, 
        dropout=dropout
    ).to(device)


    optimizer = torch.optim.Adam(gcn_model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_function = torch.nn.CrossEntropyLoss()
    lr_sheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    gcn_model.train()
    for epoch in range(N_epochs):
        optimizer.zero_grad()     
        loss = 0
        for data in train_data_hparam:
            pred = gcn_model(x = data.x, edge_index=data.edges, edge_attr=edge_features)
            loss += loss_function(pred, data.k_labels)
        loss.backward()
        optimizer.step()
        lr_sheduler.step()
    
    gcn_model.eval()
    errors = []
    recall = []
    for data in val_data:
        out = gcn_model(x = data.x, edge_index=data.edges, edge_attr=edge_features)
        pred = F.softmax(out, dim=1).argmax(axis=1).to('cpu')

        _, error = calculate_gap(data.T, data.R, data.V, pred, K)
        errors.append(error.to('cpu'))
        recall.append(
            recall_score(data.k_labels.to('cpu'), pred.to('cpu'), average="macro")
        )
    
    return np.mean(errors) + (1 - np.mean(recall)) # Minimise errors while maximising recall score

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=90, n_jobs=6)

best_params = study.best_params

print(study.best_value)
print(best_params)

[I 2023-10-30 21:17:05,371] A new study created in memory with name: no-name-68abbcb5-c30f-4d70-a9eb-689da12cd6a6


  Q[:, a] = R[:, a].T + gamma*T[a, :, :]@V
[I 2023-10-30 21:20:39,422] Trial 5 finished with value: 1.5576930118179741 and parameters: {'hidden_channels': 111, 'num_layers': 1, 'dropout': 0.09231175350377001, 'lr': 2.7246438322489674, 'weight_decay': 0.019752491727814386, 'gamma': 0.6029542839683466}. Best is trial 5 with value: 1.5576930118179741.
[I 2023-10-30 21:24:35,985] Trial 3 finished with value: 1.8357622095694102 and parameters: {'hidden_channels': 39, 'num_layers': 2, 'dropout': 0.09268147576393308, 'lr': 9.964703875139556, 'weight_decay': 0.08159747132472983, 'gamma': 0.335353677880762}. Best is trial 5 with value: 1.5576930118179741.
[I 2023-10-30 21:28:28,814] Trial 7 finished with value: 1.785645023643093 and parameters: {'hidden_channels': 133, 'num_layers': 1, 'dropout': 0.087763728765618, 'lr': 6.185326640515875, 'weight_decay': 0.0022345241801896626, 'gamma': 0.4995737128193001}. Best is trial 5 with value: 1.5576930118179741.
[I 2023-10-30 21:29:13,981] Trial 2 fini

0.9333078923915921
{'hidden_channels': 87, 'num_layers': 1, 'dropout': 0.04048283236093458, 'lr': 7.584595173582405, 'weight_decay': 0.0681759884657168, 'gamma': 0.8881174140897286}


In [16]:
import pandas as pd
trials = [i for i in map(lambda x: dict([("score", x.values[0]),*(x.params).items()]), study.get_trials())]
trials = pd.DataFrame(trials)

score	hidden_channels	num_layers	dropout	lr	weight_decay	weight_param	gamma
21	0.403245	67	1	0.051031	4.809483	0.045045	8	0.972926
20	0.403374	63	1	0.003532	4.839212	0.040630	8	0.989183
14	1.023662	19	1	0.295559	5.025916	0.053842	7	0.798589

In [17]:
trials.sort_values(by="score")

Unnamed: 0,score,hidden_channels,num_layers,dropout,lr,weight_decay,gamma
78,0.933308,87,1,0.040483,7.584595,0.068176,0.888117
87,0.956534,97,1,0.050359,7.184763,0.062805,0.849785
26,0.965460,59,1,0.041301,8.093975,0.072873,0.892203
39,0.985529,53,1,0.055260,9.369357,0.062743,0.873322
57,0.996693,61,1,0.045105,6.495390,0.073942,0.945401
...,...,...,...,...,...,...,...
44,1.800246,124,3,0.059509,7.898251,0.080566,0.870065
62,1.810941,155,2,0.053284,8.436308,0.058149,0.841841
3,1.835762,39,2,0.092681,9.964704,0.081597,0.335354
13,1.851202,84,1,0.005799,2.517000,0.002008,0.480816


In [18]:
experiment = Experiment(savefile="gat_hparams")

In [19]:
for i in trials.index:
    trials.loc[i].to_dict()
    experiment.save(trials.loc[i].to_dict())