In [None]:
from diffpool_helpers.model.diffpool_continuous import TSDiffPool
import pickle
import os
import json
import pandas as pd
import numpy as np
import curvlearn as cv
from curvlearn.manifolds.manifold import Manifold
import torch
import torch.nn as nn
import argparse
import time
import dgl

In [2]:
def _load_pickle(path):
    """Deserialize and return the object stored in the pickle file at `path`.

    NOTE: unpickling can execute arbitrary code embedded in the file, so only
    point this at files you produced yourself.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Load the pre-built splits from the working directory (test first, then train).
test_set = _load_pickle('test_set.pkl')
train_set = _load_pickle('train_set.pkl')

In [3]:
# Replace every node-data entry of every graph with its `.float()` version,
# walking the training split first and the test split second.
for split in (train_set, test_set):
    for graph, _ in split:
        for feat_name, feat in graph.ndata.items():
            graph.ndata[feat_name] = feat.float()

In [4]:
batch_size = 20


def _make_batches(pairs, size):
    """Group (graph, label) pairs into consecutive full batches of `size`.

    Args:
        pairs: sliceable sequence of (graph, label) tuples.
        size: number of graphs per batch.

    Returns:
        A list of (batched_graph, labels) tuples, where `batched_graph` is the
        result of `dgl.batch` over the graphs of the chunk and `labels` is the
        plain Python list of the matching labels. A trailing chunk with fewer
        than `size` items is dropped.
    """
    batches = []
    for i in range(len(pairs) // size):
        # The upper bound is (i + 1) * size. The previous test-set loop used
        # (size + 1) * i, which produced an empty first slice and then slices
        # of 1, 2, 3, ... graphs instead of `size` graphs each.
        chunk = pairs[i * size:(i + 1) * size]
        graphs = [graph for graph, _ in chunk]
        labels = [label for _, label in chunk]
        batches.append((dgl.batch(graphs), labels))
    return batches


# Both splits go through the same helper so they are batched identically.
train_batched = _make_batches(train_set, batch_size)
test_batched = _make_batches(test_set, batch_size)

# Train Actual Model

In [5]:
def get_default_args():
    """Return the notebook's default settings as an `argparse.Namespace`.

    The namespace carries run settings (dataset name, checkpoint directory,
    epoch count, CUDA flag, gradient-clip value) together with a set of
    model hyper-parameters.
    """
    defaults = {
        'dataset': 'default_dataset',
        'save_dir': 'checkpoints',
        'epoch': 150,
        'cuda': 0,
        'clip': 0.5,
        'input_dim': 2,
        'hidden_dim': 2,
        'embedding_dim': 16,
        'activation': 'relu',
        'n_layers': 3,
        'dropout': 0.5,
        'n_pooling': 1,
        'linkpred': False,
        'batch_size': 1,
        'aggregator_type': 'mean',
        'assign_dim': 1,
        'pool_ratio': 0.5,
        'cat': True,
    }
    return argparse.Namespace(**defaults)

In [6]:
# Stand-alone dimension settings. The constructor call below does not read
# them; it passes its own literal values.
input_dim = 1
hidden_dim = 4
embedding_dim = 4

# Keyword arguments for TSDiffPool, gathered in one place so the
# configuration can be inspected before the model is built.
diffpool_config = dict(
    input_dim=2,
    hidden_dim=1,
    embedding_dim=1,  # original author was unsure about this value
    label_dim=batch_size,  # tied to the batch size (original note: "linear task")
    activation=nn.ReLU(),
    n_layers=4,
    dropout=0.5,
    n_pooling=1,
    linkpred=False,
    batch_size=1,
    aggregator_type="meanpool",
    assign_dim=8,
    pool_ratio=0.5,
)
model = TSDiffPool(**diffpool_config)

In [7]:
def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    """
    Train `model` on `dataset`, checkpointing and early-stopping on held-out loss.

    Args:
        dataset: sized iterable of (graph, labels) pairs. Each graph must
            expose `.ndata` and `.to(device)`; labels are wrapped as
            `torch.tensor([labels], dtype=torch.float32)` before being handed
            to `model.loss`.
        model: torch module that is called on a graph and provides
            `loss(prediction, target)`.
        prog_args: namespace; this function reads `save_dir`, `dataset`,
            `epoch`, `cuda` and `clip`.
        same_feat: unused; kept so existing callers keep working.
        val_dataset: optional iterable of (graph, labels) pairs used for the
            per-epoch held-out loss. When None, the notebook-level `test_set`
            global is used, which matches the previous behaviour.

    Returns:
        A `(losses, early_stopping_logger)` tuple on every exit path.
        `losses` holds the mean training loss of each completed epoch.
        `early_stopping_logger` holds the best epoch index and, under the
        historical key "val_acc", the best mean held-out *loss*.
    """
    ckpt_dir = os.path.join(prog_args.save_dir, prog_args.dataset)
    os.makedirs(ckpt_dir, exist_ok=True)

    # NOTE(review): the fallback `test_set` holds un-batched graphs while
    # training consumes batched ones — confirm this is the intended
    # evaluation input, or pass a batched `val_dataset` explicitly.
    eval_data = val_dataset if val_dataset is not None else test_set

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()), lr=0.005
    )
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}

    if prog_args.cuda > 0:
        torch.cuda.set_device(0)

    # Keep inputs on whatever device the model lives on. Previously batches
    # were pushed to the GPU whenever one was available, even though the model
    # was never moved there, and `.cuda()` was called on plain Python label
    # lists, which raises AttributeError.
    device = next(model.parameters()).device

    def _prepare(graph, labels):
        """Move a graph and its float32 label tensor onto the model's device."""
        target = torch.tensor([labels], dtype=torch.float32).to(device)
        return graph.to(device), target

    losses = []
    best_test = float('inf')
    for epoch in range(prog_args.epoch):
        model.train()
        epoch_loss = []
        y_preds = torch.zeros(len(dataset))
        for batch_idx, (batch_graph, graph_labels) in enumerate(dataset):
            # Snapshot the items so entries can be reassigned while looping.
            for key, value in list(batch_graph.ndata.items()):
                batch_graph.ndata[key] = value.float()
            batch_graph, target = _prepare(batch_graph, graph_labels)

            model.zero_grad()
            ypred = model(batch_graph)
            y_preds[batch_idx] = torch.mean(ypred).item()
            loss = model.loss(ypred, target)
            epoch_loss.append(loss.item())
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()
        losses.append(np.mean(epoch_loss))
        print(f"EPOCH {epoch} LOSS {np.mean(epoch_loss)}")
        print(f"Predictions: {y_preds.mean()}")

        # Held-out evaluation for this epoch.
        model.eval()
        test_losses = []
        with torch.no_grad():
            for batch_graph, graph_labels in eval_data:
                batch_graph, target = _prepare(batch_graph, graph_labels)
                ypred = model(batch_graph)
                test_losses.append(model.loss(ypred, target).item())

        if test_losses:
            mean_test_loss = np.mean(test_losses)
            if mean_test_loss < best_test:
                # Update the value that is actually compared above. The old
                # code assigned a different name, so the threshold stayed at
                # infinity and early stopping could never trigger.
                best_test = mean_test_loss
                early_stopping_logger["best_epoch"] = epoch
                early_stopping_logger["val_acc"] = mean_test_loss
                torch.save(
                    model.state_dict(),
                    os.path.join(
                        ckpt_dir,
                        "model_{}_{}.pth".format(prog_args.dataset, epoch),
                    ),
                )
            else:
                print("Test loss is greater than the best val loss. Quitting...")
                # Same return shape as the normal exit so callers can unpack.
                return losses, early_stopping_logger
        # With no held-out data there is nothing to compare, so training
        # simply continues without checkpointing.
        torch.cuda.empty_cache()
    return losses, early_stopping_logger

In [None]:
# Launch training on the batched training split with the default settings.
train_args = get_default_args()
losses, early_stopping_logger = train(
    dataset=train_batched,
    model=model,
    prog_args=train_args,
    same_feat=True,
    val_dataset=None,
)

In [9]:
# Persist the trained model object to disk.
# `pickle` is already imported in the notebook's import cell, so it is not
# re-imported here (imports belong in one place at the top).
# NOTE(review): pickling the whole module object ties the file to the exact
# class definition and import path at load time; saving `model.state_dict()`
# is the more portable option. Left unchanged to keep the output format.
file_path = 'model_52_4_PL.pkl'

with open(file_path, 'wb') as file:
    pickle.dump(model, file)