In [None]:
from generate_mdps import generate_datsets
from dataset import MDPDataset, AllNodeFeatures, InMemoryMDPDataset, TransitionsOnEdge
from experiment import Experiment
from MDP_helpers import calculate_gap, multiclass_recall_score

In [None]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import optuna
import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GCN, GAT
from torch.utils.data import random_split
from collections import defaultdict
from sklearn.metrics import recall_score, accuracy_score

from time import time
from tqdm import tqdm

In [None]:
# Seed every random number generator the notebook relies on.
# torch.manual_seed seeds the CPU generator (used by random_split, DataLoader
# shuffling and weight initialisation on CPU) as well as all CUDA devices;
# torch.cuda.manual_seed alone would leave the CPU generator unseeded.
SEED = 12345
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
# (Assign 'cpu' here by hand to force CPU execution while debugging.)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Name of the file (inside the experiment folder) that stores hyper-parameter
# search records, and name of the experiment folder itself.
hparam_file = "hparams"
experiment_name = "GCN_weighted_fully_connected"

# Create Results/<experiment_name>, including the parent "Results" folder when
# it does not exist yet; exist_ok makes re-running this cell a no-op.
os.makedirs(f"Results/{experiment_name}", exist_ok=True)

In [None]:
# --- Experiment configuration -------------------------------------------------
# Number of datasets requested from generate_datsets below.
N_datasets = 100

# Problem-size parameters forwarded to generate_datsets.
# K is also used as the model's `out_channels`, i.e. the number of classes
# predicted per node, and as the argument passed to calculate_gap.
N_sites = 5
N_species = 20
K = 7

# GNN architecture class instantiated for both the hyper-parameter search and
# the final training run.
gnn_model=GCN

# N_epochs / lr drive the final training loop; N_epochs_H / lr_H drive each
# Optuna trial inside `objective`.
N_epochs = 3000
N_epochs_H = 1000
lr = 0.001
lr_H = 0.01
# When True, `edge_attr=data.edge_features` is passed to the model's forward call.
edge_attributes = False

# NOTE(review): presumably each site can be in one of 3 states, giving 3**N_sites
# MDP states -- confirm against generate_mdps.
N_states = 3**N_sites
print(f"N_states: {N_states}")
# Folder name shared by the generator call below and the dataset loader
# (which reads from datasets/<dataset_folder>).
dataset_folder = f"Reserve_MDP_{N_states}_{K}"
# NOTE(review): remove_previous=False presumably keeps already generated files
# instead of regenerating them -- confirm in generate_mdps.
generate_datsets(N_sites, N_species, K, N_datasets, remove_previous=False, folder=dataset_folder)

In [None]:
# Load the dataset stored under datasets/<dataset_folder>, applying
# AllNodeFeatures(thresh=-1) as the pre_transform.
# NOTE(review): the meaning of thresh=-1 is defined in dataset.py -- confirm there.
dataset = InMemoryMDPDataset(f"datasets/{dataset_folder}", pre_transform=AllNodeFeatures(thresh=-1))

In [None]:
# Sanity check: True when the `R` tensors of samples 0 and 5 are element-wise equal.
torch.all(dataset[0].R == dataset[5].R)

In [None]:
# Fractions of the data used for training and testing. test_ratio is
# informational only: the test size is computed as the remainder.
train_ratio = 0.8
test_ratio = 0.2

train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size

# The training portion is split again (same ratio) into the subsets used for
# hyper-parameter search: one to fit each trial, one to score it.
h_param_size = int(train_ratio*train_size)
val_size = train_size - h_param_size

# A dedicated, explicitly seeded generator makes the split membership identical
# on every run, independent of how much global RNG state was consumed earlier.
split_generator = torch.Generator().manual_seed(12345)
train_set, test_set = random_split(
    dataset, [train_size, test_size], generator=split_generator
)
hparam_train_set, hparam_val_set = random_split(
    train_set, [h_param_size, val_size], generator=split_generator
)

# One graph per batch. Only the loaders used for optimisation are shuffled;
# the evaluation loaders feed order-independent averages.
train_data = DataLoader(train_set, batch_size=1, shuffle=True)
test_data = DataLoader(test_set, batch_size=1, shuffle=False)
hparam_train_data = DataLoader(hparam_train_set, batch_size=1, shuffle=True)
hparam_val_data = DataLoader(hparam_val_set, batch_size=1, shuffle=False)

In [None]:
def _hparam_weighted_loss(model, data):
    """Class-weighted cross-entropy of `model` on one batch from a loader.

    The forward call receives `edge_attr` only when the notebook-level flag
    `edge_attributes` is set. Class weights are the relative label frequencies
    of the batch, exactly as in the rest of the notebook.
    """
    inputs = {"x": data.x.to(device), "edge_index": data.edges.to(device)}
    if edge_attributes:
        inputs["edge_attr"] = data.edge_features.to(device)
    pred = model(**inputs)

    # minlength=K keeps the weight vector aligned with the K output logits even
    # when the highest-numbered classes do not occur in this batch; without it
    # F.cross_entropy raises on a weight vector that is too short.
    weight = torch.bincount(data.k_labels, minlength=K)
    weight = weight / weight.sum()
    return F.cross_entropy(pred, data.k_labels.to(device), weight=weight.to(device))


def objective(trial):
    """Optuna objective: train one model and return its validation loss.

    Samples `hidden_channels`, `num_layers`, `dropout` and `weight_decay` from
    `trial`, trains a fresh `gnn_model` for `N_epochs_H` full-batch epochs on
    `hparam_train_data` (one optimizer step per epoch on the mean loss over all
    batches), and evaluates it on `hparam_val_data`.

    Args:
        trial: the `optuna` trial used to sample hyper-parameters.

    Returns:
        float: mean class-weighted cross-entropy over `hparam_val_data`
        (lower is better).
    """
    hidden_channels = trial.suggest_int("hidden_channels", 100, 250)
    num_layers = trial.suggest_int("num_layers", 1, 3)
    dropout = trial.suggest_float("dropout", 0, 0.05)
    weight_decay = trial.suggest_float("weight_decay", 1e-3, 0.05)

    model = gnn_model(
        in_channels=dataset[0].x.shape[1],
        out_channels=K,
        hidden_channels=hidden_channels,
        num_layers=num_layers,
        dropout=dropout
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr_H, weight_decay=weight_decay)

    # The model stays in training mode for the whole loop so that dropout is
    # active in every epoch, not just the first one.
    model.train()
    for epoch in range(N_epochs_H):
        optimizer.zero_grad()
        loss = 0
        for data in hparam_train_data:
            loss += _hparam_weighted_loss(model, data)

        # Average over the loader that was actually iterated.
        loss /= len(hparam_train_data)
        loss.backward()
        optimizer.step()

    # Only the final validation loss is reported, so it is computed once after
    # training, without building autograd graphs.
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for data in hparam_val_data:
            test_loss += _hparam_weighted_loss(model, data)
    test_loss /= len(hparam_val_data)
    return float(test_loss)

# Run the hyper-parameter search: 30 trials of `objective`, keeping the
# parameter set with the smallest returned value.
study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=30)

# Best parameter set; read again by the final training cell.
best_params = study.best_params

print(study.best_value)
print(best_params)

In [None]:
# Flatten every trial that produced a value into one record:
# {"score": <objective value>, <hyper-parameter name>: <value>, ...}.
# Trials without values (e.g. failed or interrupted ones) are skipped because
# they have no score to index into.
trials = [
    {"score": trial.values[0], **trial.params}
    for trial in study.get_trials()
    if trial.values is not None
]
trials = pd.DataFrame(trials)
# The study minimises the score, so ascending order puts the best trial first.
trials = trials.sort_values(by='score', ascending=True)

# Persist one record per trial in the experiment's hyper-parameter file.
experiment = Experiment(savefile=f"Results/{experiment_name}/{hparam_file}")
for i in trials.index:
    experiment.save(trials.loc[i].to_dict())

In [None]:
# Reload the saved hyper-parameter records and show the five with the lowest score.
pd.DataFrame(experiment.load()).sort_values(by="score", ascending=True).head()

In [None]:
# all_results[trial_id][metric_name] -> list with one value per epoch.
all_results = defaultdict(lambda : defaultdict(list))
trial_name = "Trial"
trial_num = 0

# Hyper-parameters selected by the Optuna study.
hidden_channels = int(best_params['hidden_channels'])
num_layers = int(best_params['num_layers'])
dropout = best_params['dropout']
weight_decay = best_params['weight_decay']

model = gnn_model(
    in_channels=dataset[0].x.shape[1],
    out_channels=K,
    hidden_channels=hidden_channels,
    num_layers=num_layers,
    dropout=dropout
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)


def _forward_logits(net, graph):
    """Run `net` on one batch, passing `edge_attr` only when `edge_attributes` is set."""
    inputs = {"x": graph.x.to(device), "edge_index": graph.edges.to(device)}
    if edge_attributes:
        inputs["edge_attr"] = graph.edge_features.to(device)
    return net(**inputs)


def _class_weighted_ce(logits, labels):
    """Cross-entropy weighted by the relative label frequencies of the batch."""
    # minlength=K keeps the weight vector as long as the K output logits even
    # when the highest-numbered classes are absent from this batch.
    weight = torch.bincount(labels, minlength=K)
    weight = weight / weight.sum()
    return F.cross_entropy(logits, labels.to(device), weight=weight.to(device))


epochs = tqdm(range(N_epochs))
for epoch in epochs:
    # --- training: one optimizer step on the mean loss over all training batches
    model.train()
    optimizer.zero_grad()

    loss = 0
    for data in train_data:
        pred = _forward_logits(model, data)
        loss += _class_weighted_ce(pred, data.k_labels)

    loss /= len(train_data)
    loss.backward()
    optimizer.step()

    # Store plain floats so the history does not keep tensors alive.
    all_results[trial_name]['training_loss'].append(loss.item())

    # --- evaluation on the test loader; no gradients are needed here
    model.eval()
    test_loss = 0
    avg_gap = 0
    avg_error = 0
    avg_recall = 0
    avg_acc = 0
    with torch.no_grad():
        for data in test_data:
            pred = _forward_logits(model, data)
            pred_k = F.softmax(pred, dim=1).argmax(axis=1)
            pred_k_cpu = pred_k.to('cpu')

            test_loss += _class_weighted_ce(pred, data.k_labels).item()
            # NOTE(review): pred_k lives on `device` while calculate_gap is told
            # device='cpu' -- confirm in MDP_helpers that it moves its inputs.
            gap, error = calculate_gap(data.P, data.R, data.V, pred_k, K, device='cpu')
            avg_gap += gap
            avg_error += error
            avg_recall += recall_score(data.k_labels, pred_k_cpu, average="macro")
            avg_acc += accuracy_score(data.k_labels, pred_k_cpu)

    # Average every metric over the test batches.
    test_loss /= len(test_data)
    avg_gap /= len(test_data)
    avg_error /= len(test_data)
    avg_recall /= len(test_data)
    avg_acc /= len(test_data)

    all_results[trial_name]['test_loss'].append(test_loss)
    all_results[trial_name]['test_gap'].append(avg_gap)
    all_results[trial_name]['test_error'].append(avg_error)
    all_results[trial_name]['test_recall'].append(avg_recall)
    all_results[trial_name]['test_accuracy'].append(avg_acc)

    epochs.set_description(f"Trial {trial_num}, Epoch {epoch+1}/{N_epochs}, Loss {test_loss:.4f}, Gap {avg_gap:.4f}, Recall {avg_recall:.4f}, Accuracy {avg_acc:.4f}")

In [None]:
# Drop a stale "trial_0" entry if one exists. Results are recorded under
# "Trial", so on a fresh run the key is absent; the default avoids a KeyError
# that would otherwise stop Restart & Run All at this cell.
all_results.pop("trial_0", None);

In [None]:
# For each recorded metric, assemble a table with one column per trial and one
# row per epoch, write it to Results/<experiment_name>/<metric>.csv and keep it
# in `processed` for later cells.
processed = {}
for metric in all_results["Trial"].keys():
    per_trial = {}
    for trial_id in all_results.keys():
        per_trial[trial_id] = all_results[trial_id][metric]
    frame = pd.DataFrame(per_trial).astype(float)
    frame.to_csv(f"Results/{experiment_name}/{metric}.csv")
    processed[metric] = frame

In [None]:
# Rebuild the per-metric tables (one column per trial, one row per epoch) so
# that this cell can be run on its own.
processed = {}
for key in all_results["Trial"].keys():
    df = pd.DataFrame({trial_id:all_results[trial_id][key] for trial_id in all_results.keys()}).astype(float)
    processed[key] = df

print("Generating plots")
n_plots = len(processed)
n_cols = 2
# Ceiling division; at least one row so the figure is always valid.
n_rows = max(1, (n_plots + n_cols - 1) // n_cols)

# squeeze=False keeps `ax` two-dimensional even for a single row, so the
# ax[row, col] indexing below works for any number of metrics.
fig, ax = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 6*n_rows), squeeze=False)

for count, key in enumerate(processed):
    row, col = divmod(count, n_cols)

    # Long format: one value per (epoch, trial), indexed by epoch only, so that
    # seaborn aggregates across trials at every epoch.
    df_long = processed[key].stack()
    df_long.index = df_long.index.to_flat_index().map(lambda x: x[0])
    sns.lineplot(df_long, errorbar='ci', ax=ax[row, col])
    ax[row, col].set_ylabel(key)
    ax[row, col].set_xlabel("Epoch")

# Hide panels left empty when the number of metrics does not fill the grid.
for unused_ax in ax.flat[n_plots:]:
    unused_ax.set_visible(False)

fig.savefig(f"Results/{experiment_name}/plots.png")

In [None]:
# Display the dict of per-metric DataFrames (metric name -> epochs x trials table).
processed