In [15]:
from MDP_helpers import calculate_gap
from dataset import MDPDataset, AllNodeFeatures, InMemoryMDPDataset, TransitionsOnEdge
from generate_mdps import generate_datsets
from experiment import Experiment

In [16]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import optuna
import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GCN, GAT
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import random_split
from collections import defaultdict
from sklearn.metrics import recall_score

from time import time
from tqdm import tqdm

In [17]:
torch.cuda.manual_seed(12345)
np.random.seed(12345)

In [18]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device='cpu'

In [19]:
N_datasets = 100

N_sites = 5
N_species = 20
K = 7

N_states = 3**N_sites
print(f"N_states: {N_states}")
generate_datsets(N_sites, N_species, K, N_datasets, remove_previous=False, folder="hparam_data")

N_states: 243
Deleting folder  datasets/hparam_data/raw
Generating 100 MDPs with 243 states and 5 actions 



  0%|          | 0/100 [00:00<?, ?it/s]

100%|██████████| 100/100 [00:19<00:00,  5.03it/s]


In [20]:
# dataset = InMemoryMDPDataset(f"datasets/mdp_{N_states}_state", pre_transform=TransitionsOnEdge())
dataset = InMemoryMDPDataset(f"datasets/mdp_{N_states}_state", pre_transform=AllNodeFeatures())

In [21]:
torch.all(dataset[0].R == dataset[5].R)

tensor(False)

In [22]:
data_split = int(len(dataset)*0.8)
train_data = dataset[:data_split]
test_data = dataset[data_split:]

hparam_split = int(len(train_data)*0.8)
train_data_hparam = train_data[:hparam_split]
val_data = train_data[hparam_split:]

In [23]:
from torch.optim.lr_scheduler import ExponentialLR

In [29]:
N_epochs = 500

def objective(trial):
    hidden_channels = trial.suggest_int("hidden_channels", 150, 225)
    num_layers = 1#trial.suggest_int("num_layers", 1, 3)
    dropout = trial.suggest_float("dropout", 0, 0.05)
    lr = trial.suggest_float("lr", 2, 8)
    weight_decay = trial.suggest_float("weight_decay", 1e-3, 0.04)

    gamma = trial.suggest_float("gamma", 0.95, 1)

    gcn_model = GCN(
        in_channels=dataset[0].x.shape[1], 
        out_channels=K, 
        hidden_channels=hidden_channels, 
        num_layers=num_layers, 
        dropout=dropout
    ).to(device)


    optimizer = torch.optim.Adam(gcn_model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_function = torch.nn.CrossEntropyLoss()
    lr_sheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    gcn_model.train()
    for epoch in range(N_epochs):
        optimizer.zero_grad()     
        loss = 0
        for data in train_data_hparam:
            pred = gcn_model(
                x = data.x.to(device), 
                edge_index=data.edges.to(device), 
                # edge_attr=edge_features
            )
            loss += loss_function(pred, data.k_labels.to(device))
        loss.backward()
        optimizer.step()
        lr_sheduler.step()
    
    gcn_model.eval()
    errors = []
    recall = []
    for data in val_data:
        out = gcn_model(
                x = data.x.to(device), 
                edge_index=data.edges.to(device), 
                # edge_attr=edge_features
            )
        pred = F.softmax(out, dim=1).argmax(axis=1)#.to('cpu')

        _, error = calculate_gap(data.P, data.R, data.V, pred, K, device=device)
        errors.append(error.to('cpu'))
        recall.append(
            recall_score(data.k_labels.to('cpu'), pred.to('cpu'), average="macro")
        )
    
    return np.mean(recall) - np.mean(errors) # Minimise errors while maximising recall score

study = optuna.create_study(direction='maximize')
study.optimize(objective, n_trials=15, n_jobs=1)

best_params = study.best_params

print(study.best_value)
print(best_params)

[I 2023-11-01 18:01:19,532] A new study created in memory with name: no-name-325fa119-e223-45cf-b3a6-552d097f2aad
[I 2023-11-01 18:01:36,260] Trial 0 finished with value: 0.3650246668315851 and parameters: {'hidden_channels': 186, 'dropout': 0.015995434981939807, 'lr': 2.734393372431539, 'weight_decay': 0.0031204553909045256, 'gamma': 0.9540675060228595}. Best is trial 0 with value: 0.3650246668315851.
[I 2023-11-01 18:01:52,010] Trial 1 finished with value: 0.6003516783212614 and parameters: {'hidden_channels': 170, 'dropout': 0.03718399312608728, 'lr': 6.256611530360379, 'weight_decay': 0.01067138133398014, 'gamma': 0.970472237538076}. Best is trial 1 with value: 0.6003516783212614.
[I 2023-11-01 18:02:08,639] Trial 2 finished with value: 0.6395676311156622 and parameters: {'hidden_channels': 195, 'dropout': 0.01405175681230887, 'lr': 2.656559835518944, 'weight_decay': 0.004537985683567207, 'gamma': 0.9885185131486003}. Best is trial 2 with value: 0.6395676311156622.
[I 2023-11-01 18

0.6395676311156622
{'hidden_channels': 195, 'dropout': 0.01405175681230887, 'lr': 2.656559835518944, 'weight_decay': 0.004537985683567207, 'gamma': 0.9885185131486003}


In [25]:
import pandas as pd
trials = [i for i in map(lambda x: dict([("score", x.values[0]),*(x.params).items()]), study.get_trials())]
trials = pd.DataFrame(trials)

score	hidden_channels	num_layers	dropout	lr	weight_decay	weight_param	gamma
21	0.403245	67	1	0.051031	4.809483	0.045045	8	0.972926
20	0.403374	63	1	0.003532	4.839212	0.040630	8	0.989183
14	1.023662	19	1	0.295559	5.025916	0.053842	7	0.798589

In [26]:
trials.sort_values(by="score", ascending=False)#.plot.scatter(x="weight_decay", y="score")

Unnamed: 0,score,hidden_channels,dropout,lr,weight_decay,gamma
14,0.599886,189,0.030787,3.640513,0.008176,0.989859
2,0.57262,151,0.030954,2.301565,0.004388,0.984601
12,0.527417,170,0.027613,7.781805,0.011654,0.985464
11,0.518439,174,0.026031,7.728138,0.010252,0.980316
5,0.496078,172,0.023114,7.176134,0.002211,0.965279
7,0.469434,156,0.040329,3.638957,0.001374,0.994972
4,0.458307,216,0.025998,6.629531,0.0258,0.997925
8,0.455355,209,0.034853,5.428958,0.002279,0.99737
1,0.433783,150,0.01313,6.85955,0.028589,0.967631
3,0.42767,157,0.036784,2.719543,0.031845,0.974796


In [27]:
# experiment = Experiment(savefile="gat_hparams")

In [28]:
# for i in trials.index:
#     trials.loc[i].to_dict()
    # experiment.save(trials.loc[i].to_dict())