In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np 
import torch_geometric.nn as pyg_nn 
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Dataset, Data, DataLoader
import torch.optim as optim
import os.path as osp
import scipy.io as sio
from datetime import datetime
from tensorboardX import SummaryWriter
from dataset import QUASARDataset
from model import ModelS

In [2]:
# dir = '/Users/hankyang/Datasets/QUASAR'
dir = '/home/hank/Datasets/QUASAR'
dataset = QUASARDataset(dir)
writer = SummaryWriter("./log/" + datetime.now().strftime("%Y%m%d-%H%M%S"))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
# loader  = DataLoader(dataset,batch_size=1,shuffle=True)
def train(dataset,writer):
    # build model
    model   = ModelS(mp_input_dim=6,mp_hidden_dim=32,mp_output_dim=64,mp_num_layers=1, 
                     primal_node_mlp_hidden_dim=64,primal_node_mlp_output_dim=10,
                     dual_node_mlp_hidden_dim=64,dual_node_mlp_output_dim=10,
                     node_mlp_num_layers=1,
                     primal_edge_mlp_hidden_dim=64,primal_edge_mlp_output_dim=10, 
                     dual_edge_mlp_hidden_dim=64,dual_edge_mlp_output_dim=6, 
                     edge_mlp_num_layers=1, 
                     dropout_rate=0.2)
    model.double() # convert all parameters to double
    model.to(device)
    opt = optim.Adam(model.parameters(),lr=0.01)

    # train
    num_epoches = 10
    for epoch in range(num_epoches):
        total_loss = 0
        model.train()
        for i in range(len(dataset)):
            opt.zero_grad()
            data = dataset[i].to(device)
            x, X, S, Aty = model(data)
            loss = model.loss(data,X,S,Aty)
            loss.backward()
            opt.step()
            total_loss += loss.item()
            print('graph {}. loss: {:.4f}.'.format(i,loss.item()))
        total_loss /= len(dataset)
        writer.add_scalar("loss", total_loss, epoch)
        print("Epoch {}. Loss: {:.4f}.".format(epoch, total_loss))
    return model

In [8]:
model = train(dataset,writer)

graph 0. loss: 11564.0055.
graph 1. loss: 11479.8635.
graph 2. loss: 11375.2209.
graph 3. loss: 11479.7515.
graph 4. loss: 11407.2278.
graph 5. loss: 11484.6304.
graph 6. loss: 11556.8459.
graph 7. loss: 11607.1547.
graph 8. loss: 11517.2378.
graph 9. loss: 11489.4784.
graph 10. loss: 11629.6318.
graph 11. loss: 11425.0710.
graph 12. loss: 11361.8278.
graph 13. loss: 11536.2976.
graph 14. loss: 11566.8950.
graph 15. loss: 11442.8712.
graph 16. loss: 11465.8513.
graph 17. loss: 11467.0776.
graph 18. loss: 11667.9359.
graph 19. loss: 11466.3238.
graph 20. loss: 11489.5964.
graph 21. loss: 11567.1923.
graph 22. loss: 11616.5023.
graph 23. loss: 11434.3724.
graph 24. loss: 11459.6099.
graph 25. loss: 11425.7042.
graph 26. loss: 11655.6241.
graph 27. loss: 11325.9890.
graph 28. loss: 11889.2223.
graph 29. loss: 11532.0765.
graph 30. loss: 11799.4028.
graph 31. loss: 11520.3834.
graph 32. loss: 11436.6303.
graph 33. loss: 11531.4156.
graph 34. loss: 11396.4439.
graph 35. loss: 11467.1679.
gr

KeyboardInterrupt: 