In [1]:
from generate_mdps import generate_datsets
from dataset import MDPDataset, AllNodeFeatures, InMemoryMDPDataset, TransitionsOnEdge
from experiment import Experiment
from MDP_helpers import calculate_gap, multiclass_recall_score

In [2]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import optuna
import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GCN, GAT
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import random_split
from collections import defaultdict
from sklearn.metrics import recall_score

from time import time
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Single seed shared by every random number generator used in this notebook.
SEED = 12345

# torch.manual_seed seeds the default CPU generator as well as the CUDA
# generators.  The CPU generator is the one used by random_split and by
# shuffling DataLoaders, so seeding only the CUDA generator left the
# train/test split different on every run.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
# Use the GPU whenever PyTorch can see one and fall back to the CPU otherwise.
# To force CPU execution, assign device = 'cpu' after this selection.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

cuda


In [5]:
# hparam_file = "gat_hparams"
# experiment_name = "gcn_243_hparam_30"

# os.mkdir(f"Results/{experiment_name}") if not os.path.isdir(f"Results/{experiment_name}") else ...

In [6]:
# Fourth positional argument of generate_datsets
# (presumably the number of MDP instances to create -- confirm against generate_mdps).
N_datasets = 100

# Problem dimensions forwarded to the generator; K is also the number of
# output classes of the GNN trained further down.
N_sites = 5
N_species = 20
K = 7

# The state count is 3 raised to the number of sites.
N_states = 3 ** N_sites
print(f"N_states: {N_states}")

# Kept as the last bare expression so the cell displays the generator's
# return value.
generate_datsets(
    N_sites,
    N_species,
    K,
    N_datasets,
    remove_previous=False,
    folder="Reserve_MDP_243_7",
)

N_states: 243


'Data already exists'

In [7]:
# Alternative dataset location derived from the state count, kept for reference:
# dataset = InMemoryMDPDataset(f"datasets/mdp_{N_states}_state", pre_transform=TransitionsOnEdge())

# Load the generated MDPs; TransitionsOnEdge is handed over as the
# pre_transform of the dataset.
dataset = InMemoryMDPDataset(
    "datasets/Reserve_MDP_243_7",
    pre_transform=TransitionsOnEdge(),
)

  edge_features = torch.Tensor([P[:, i[0], i[1]].numpy() for i in edges])


In [8]:
# Fractions of the dataset used for training and held out for testing.
train_ratio = 0.8
test_ratio = 0.2
# test_size below is derived as the remainder, so make sure the two declared
# ratios actually describe the same split.
assert abs(train_ratio + test_ratio - 1.0) < 1e-9, "train_ratio and test_ratio must sum to 1"

train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size

# A dedicated, explicitly seeded generator makes the split identical on every
# Restart & Run All, independently of how much global RNG state was consumed
# by earlier cells.
split_generator = torch.Generator().manual_seed(12345)
train_set, test_set = random_split(
    dataset,
    [train_size, test_size],
    generator=split_generator,
)

train_data = DataLoader(train_set, batch_size=1, shuffle=True)
# Evaluation only averages metrics over the whole test set, so shuffling it
# adds nothing and just consumes random state.
test_data = DataLoader(test_set, batch_size=1, shuffle=False)

In [None]:
# Upper bound on the number of training epochs per trial.
N_epochs = 500


def _forward(model, data):
    """Run ``model`` on one batch after moving the tensors it needs to ``device``."""
    return model(
        x=data.x.to(device),
        edge_index=data.edges.to(device),
        edge_attr=data.edge_features.to(device),
    )


def objective(trial):
    """Optuna objective: train a GAT with sampled hyperparameters and score it.

    The model is trained for at most ``N_epochs`` epochs.  In every epoch the
    cross-entropy loss is summed over all batches of ``train_data`` and a
    single optimizer step is taken on that sum.  Training stops early once the
    summed loss has stayed within 1e-9 of its last recorded value for more
    than 25 consecutive checks.

    Parameters
    ----------
    trial : optuna.trial.Trial
        Trial used to sample the hyperparameters.

    Returns
    -------
    float
        Mean macro recall on ``test_data`` minus the mean error reported by
        ``calculate_gap`` (higher is better).
    """
    hidden_channels = trial.suggest_int("hidden_channels", 1, 300)
    num_layers = trial.suggest_int("num_layers", 1, 3)
    dropout = trial.suggest_float("dropout", 1e-6, 0.5)
    lr = trial.suggest_float("lr", 1e-6, 4)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1)
    # NOTE(review): "gamma" (LR-scheduler decay, scheduler currently disabled)
    # and "vi_hidden_channels" are sampled but not used by the model below.
    # They are still suggested so that study.best_params keeps the same keys.
    trial.suggest_float("gamma", 0.85, 1)
    trial.suggest_int("vi_hidden_channels", 60, 150)

    # Alternative (narrower) search ranges, kept for reference:
    # hidden_channels = trial.suggest_int("hidden_channels", 150, 225)
    # num_layers = 1#trial.suggest_int("num_layers", 1, 3)
    # dropout = trial.suggest_float("dropout", 0, 0.05)
    # lr = trial.suggest_float("lr", 2, 8)
    # weight_decay = trial.suggest_float("weight_decay", 1e-3, 0.04)
    # gamma = trial.suggest_float("gamma", 0.95, 1)

    # NOTE(review): no `edge_dim` is passed to GAT.  PyG's GATConv appears to
    # consume `edge_attr` only when `edge_dim` is set, so the edge features
    # handed to the model may currently be ignored -- confirm against the
    # installed torch_geometric version.
    gnn_model = GAT(
        in_channels=dataset[0].x.shape[1],
        out_channels=K,
        hidden_channels=hidden_channels,
        num_layers=num_layers,
        dropout=dropout,
    ).to(device)

    optimizer = torch.optim.Adam(gnn_model.parameters(), lr=lr, weight_decay=weight_decay)
    loss_function = torch.nn.CrossEntropyLoss()
    # lr_sheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma)

    # ---- training -------------------------------------------------------
    gnn_model.train()
    previous_loss = 0.0   # last loss value that differed from its predecessor
    stalled_epochs = 0    # consecutive epochs without a loss change
    for epoch in range(N_epochs):
        optimizer.zero_grad()
        loss = 0
        for data in train_data:
            pred = _forward(gnn_model, data)
            loss += loss_function(pred, data.k_labels.to(device))

        # Compare plain Python floats: keeping the loss tensor itself around
        # would hold a reference to the previous epoch's autograd graph.
        loss_value = loss.item()
        if abs(loss_value - previous_loss) < 1e-9:
            if stalled_epochs > 25:
                break
            stalled_epochs += 1
        else:
            previous_loss = loss_value
            stalled_epochs = 0

        loss.backward()
        optimizer.step()
        # lr_sheduler.step()

    # ---- evaluation -----------------------------------------------------
    gnn_model.eval()
    errors = []
    recall = []
    for data in test_data:
        # No gradients are needed for inference; skipping the autograd graph
        # saves memory and time.
        with torch.no_grad():
            out = _forward(gnn_model, data)
        # argmax over the raw logits selects the same class as argmax over
        # their softmax, so the softmax is unnecessary.
        pred = out.argmax(dim=1)

        _, error = calculate_gap(data.P, data.R, data.V, pred, K, device=device)
        errors.append(error.detach().to('cpu').numpy())
        recall.append(
            recall_score(data.k_labels.to('cpu'), pred.to('cpu'), average="macro")
        )

    # Minimise errors while maximising recall score.
    return float(np.mean(recall) - np.mean(errors))

# Create a study that maximises the objective value, then run the search with
# a single worker for a fixed budget of 30 trials.
study = optuna.create_study(direction='maximize')
study.optimize(
    objective,
    n_trials=30,
    n_jobs=1,
)

# Report the best score first, followed by the hyperparameters behind it.
best_params = study.best_params
print(study.best_value, best_params, sep="\n")