In [None]:
# %%
from generate_mdps import generate_datsets
from dataset import MDPDataset, AllNodeFeatures, InMemoryMDPDataset, TransitionsOnEdge
from experiment import Experiment
from MDP_helpers import calculate_gap, multiclass_recall_score

# %%
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import optuna
import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from torch_geometric.nn.models import GCN, GAT
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import random_split
from collections import defaultdict
from sklearn.metrics import recall_score, accuracy_score
from sklearn.model_selection import KFold
from torch.utils.data import SubsetRandomSampler

from time import time
from tqdm import tqdm

# Seed every RNG the notebook draws from. torch.manual_seed seeds the CPU
# generator (used by random_split, DataLoader shuffling, weight init and
# dropout) as well as all CUDA devices; torch.cuda.manual_seed alone would
# leave the CPU generator unseeded.
torch.manual_seed(12345)
np.random.seed(12345)

import pickle

In [None]:
# Detect whether CUDA is available, then deliberately pin execution to the CPU.
# NOTE(review): the models built in `objective` are never moved to `device`,
# so removing the override below would mix CPU parameters with CUDA inputs.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device='cpu'  # hard override of the auto-detected value above
print(device)

In [None]:
# Experiment configuration: everything a reader might want to tune lives here.
experiment_name = "GCN_GCN"  # not referenced in the visible cells
hparam_file = "hparams"  # not referenced in the visible cells
ascending = True  # not referenced in the visible cells
edge_attributes = True  # not referenced in the visible cells
model_1 = GCN  # built with out_channels=1; its sigmoid output is used as a per-node mask
model_2 = GCN  # built with out_channels=K; its argmax is the predicted class per node
pre_transform = AllNodeFeatures()  # handed to InMemoryMDPDataset as `pre_transform`
lr = 0.001  # learning rate shared by both SGD optimizers
N_epochs = 3000  # not referenced in the visible cells
N_epochs_H = 3000  # training epochs per hyper-parameter trial in `objective`

N_datasets = 100  # forwarded to generate_datsets

N_sites = 5  # forwarded to generate_datsets
N_species = 20  # forwarded to generate_datsets
K = 7  # number of output classes of model_2; also forwarded to generate_datsets and calculate_gap
N_states = 3**N_sites  # presumably three possible states per site -- TODO confirm

recreate_data = False  # forwarded to generate_datsets as `remove_previous`
N_trials = 5  # NOTE(review): the Optuna study below hard-codes n_trials=30 and does not read this

train_ratio = 0.8  # used for both the train/test split and the inner hparam train/val split
test_ratio = 1-train_ratio

filename = f'Reserve_MDP_{N_states}_{K}'  # same string as the dataset folder name built in the data cell

In [None]:
# Create the MDP files under `dataset_folder` (`recreate_data` is forwarded as
# `remove_previous`), then load them through the configured pre-transform.
dataset_folder = f"Reserve_MDP_{N_states}_{K}"
print(f"MDP Data: N_states: {N_states}")
generate_datsets(
    N_sites,
    N_species,
    K,
    N_datasets,
    remove_previous=recreate_data,
    folder=dataset_folder,
)

print("Loading data into dataloader")
print(pre_transform)
dataset = InMemoryMDPDataset(
    f"datasets/{dataset_folder}",
    pre_transform=pre_transform,
)
print(dataset[0])

# Sanity check: two different samples sharing the exact same `R` tensor
# suggests the generator produced duplicate MDPs.
rewards_match = torch.all(dataset[0].R == dataset[5].R)
if rewards_match:
    raise Exception("Datasets are likely identical!!")

In [None]:
def _shuffled_loader(subset):
    """Wrap `subset` in a DataLoader yielding one graph per batch, in shuffled order."""
    return DataLoader(subset, batch_size=1, shuffle=True)


# Outer split: train vs. held-out test.
train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size
train_set, test_set = random_split(dataset, [train_size, test_size])

# Inner split of the training portion, used only for hyper-parameter search.
h_param_size = int(train_ratio*train_size)
val_size = train_size - h_param_size
hparam_train_set, hparam_val_set = random_split(train_set, [h_param_size, val_size])

train_data, test_data, hparam_train_data, hparam_val_data = (
    _shuffled_loader(subset)
    for subset in (train_set, test_set, hparam_train_set, hparam_val_set)
)

In [None]:
def loss_func(logits, target, V, mask, weights, lam, mu):
    """Weighted cross-entropy on `logits` minus a mask-weighted linear term in `V`.

    The returned value is

        mean_i( -sum_c weights[c] * log_softmax(logits)[i, c] * onehot(target)[i, c] )
        - sum( (mu - lam * V) * mask.flatten() )

    Args:
        logits: 2-D tensor of unnormalised class scores, one row per sample.
        target: 1-D int64 tensor of class indices, one per row of `logits`.
        V: tensor combined elementwise with the flattened mask.
            NOTE(review): assumed to hold one value per sample (1-D); a
            column-shaped V would broadcast against the flattened mask --
            confirm against the dataset.
        mask: tensor that is flattened before use (the caller passes the
            sigmoid of the first model's output).
        weights: tensor broadcastable against `logits` (per-class weights).
        lam: scalar multiplier on `V` inside the mask term.
        mu: scalar offset inside the mask term.

    Returns:
        A scalar tensor.
    """
    # log_softmax is the numerically stable form of log(exp(x) / sum(exp(x)));
    # the explicit exp/divide/log chain yields nan once exp() overflows.
    log_probs = torch.log_softmax(logits, dim=1)

    # Build the one-hot targets with the same shape, dtype and device as the
    # logits so the function does not depend on a global class count or on
    # the tensors living on the CPU.
    targets_one_hot = torch.zeros_like(log_probs)
    targets_one_hot.scatter_(1, target.unsqueeze(1), 1.0)

    ce_loss = -(weights * log_probs * targets_one_hot).sum(dim=1).mean()
    mask_term = ((mu - lam * V) * mask.flatten()).sum()

    return ce_loss - mask_term

In [None]:
def objective(trial):
    """Optuna objective: train both models for one trial and score them on validation data.

    Samples the loss coefficients (`lam`, `mu`) and, for each of the two
    models, the hidden width, depth, dropout and weight decay. Both models are
    trained jointly for `N_epochs_H` epochs with one optimizer step per epoch
    on the loss averaged over every graph in `hparam_train_data`. The trial is
    then scored on `hparam_val_data`.

    Args:
        trial: the Optuna trial used to draw hyper-parameters via
            `suggest_float` / `suggest_int`.

    Returns:
        The mean of the first value returned by `calculate_gap` over the
        validation loader (lower is better for the study that consumes it).
    """
    lam = trial.suggest_float("lam", 1e-3, 0.05)
    mu = trial.suggest_float("mu", 1e-3, 0.05)

    hidden_channels_1 = trial.suggest_int("hidden_channels_1", 100, 250)
    hidden_channels_2 = trial.suggest_int("hidden_channels_2", 100, 250)

    num_layers_1 = trial.suggest_int("num_layers_1", 1, 3)
    num_layers_2 = trial.suggest_int("num_layers_2", 1, 3)

    dropout_1 = trial.suggest_float("dropout_1", 0, 0.05)
    dropout_2 = trial.suggest_float("dropout_2", 0, 0.05)

    weight_decay_1 = trial.suggest_float("weight_decay_1", 1e-3, 0.05)
    weight_decay_2 = trial.suggest_float("weight_decay_2", 1e-3, 0.05)

    # Model 1 emits one score per node (turned into a mask by a sigmoid);
    # model 2 emits K class scores per node. Both are placed on `device`
    # because their inputs are moved there as well.
    m1 = model_1(
        in_channels=-1,
        hidden_channels=hidden_channels_1,
        num_layers=num_layers_1,
        dropout=dropout_1,
        out_channels=1,
    ).to(device)

    m2 = model_2(
        in_channels=-1,
        out_channels=K,
        hidden_channels=hidden_channels_2,
        num_layers=num_layers_2,
        dropout=dropout_2,
    ).to(device)

    optimizer_1 = torch.optim.SGD(m1.parameters(), lr=lr, weight_decay=weight_decay_1)
    optimizer_2 = torch.optim.SGD(m2.parameters(), lr=lr, weight_decay=weight_decay_2)

    # Uniform class weights are identical for every graph, so build them once.
    uniform_weights = torch.ones(K, device=device)

    m1.train()
    m2.train()
    for epoch in range(N_epochs_H):
        optimizer_1.zero_grad()
        optimizer_2.zero_grad()

        # Accumulate the loss over every training graph and take a single
        # optimizer step per epoch on the average.
        loss = 0
        for data in hparam_train_data:
            x = data.x.to(device)
            edge_index = data.edges.to(device)

            mask = torch.sigmoid(m1(x=x, edge_index=edge_index))
            pred2 = m2(x=x, edge_index=edge_index)

            loss += loss_func(
                logits=pred2,
                target=data.k_labels.to(device),
                V=data.V.to(device),
                mask=mask,
                weights=uniform_weights,
                lam=lam,
                mu=mu,
            )

        loss /= len(hparam_train_data)
        loss.backward()
        optimizer_1.step()
        optimizer_2.step()

    test_gap = 0
    m1.eval()
    m2.eval()
    for data in hparam_val_data:
        x = data.x.to(device)
        edge_index = data.edges.to(device)

        # No gradients are needed for scoring.
        with torch.no_grad():
            mask = torch.sigmoid(m1(x=x, edge_index=edge_index))
            pred2 = m2(x=x, edge_index=edge_index)

        pred_k = F.softmax(pred2, dim=1).argmax(axis=1)

        # Nodes whose mask exceeds 0.5 are relabelled with unique indices
        # K, K+1, ... in node order, i.e. outside the range of the K classes.
        flagged = (mask > 0.5).flatten()
        pred_k[flagged] = K + torch.arange(int(flagged.sum()), device=pred_k.device)

        # calculate_gap is a project helper; only its first return value is used.
        gap, _ = calculate_gap(data.P, data.R, data.V, pred_k, K, device='cpu')
        test_gap += gap

    test_gap /= len(hparam_val_data)

    return test_gap


# Seed the sampler so repeated runs propose the same sequence of trials; an
# unseeded sampler makes the search differ on every run even though the
# notebook seeds its other RNGs. TPESampler is the sampler create_study uses
# by default for a single objective, so only the seeding changes.
sampler = optuna.samplers.TPESampler(seed=12345)
study = optuna.create_study(direction='minimize', sampler=sampler)
study.optimize(objective, n_trials=30)

best_params = study.best_params

print(study.best_value)
print(best_params)