In [5]:
import torch
import torch.nn.functional as F
import numpy as np
from model import TemporalGNNBatch
from dataloader import DataLoader
from torch_geometric_temporal.signal import temporal_signal_split
from tqdm import tqdm 
import torch.optim.lr_scheduler as lr_scheduler
import datetime
import os

from torch_geometric_temporal.nn.recurrent import A3TGCN,A3TGCN2
import yaml
from tensorboardX import SummaryWriter

# Seed PyTorch's random number generator so runs are repeatable.
torch.manual_seed(100)

# Experiment name; it selects the YAML file ./config/<name_exp>.yml.
name_exp = 'biovid'
# Use a context manager so the file handle is closed as soon as the YAML has
# been parsed (a bare open() left it open for the life of the kernel).
with open(f"./config/{name_exp}.yml", 'r') as config_file:
    config = yaml.safe_load(config_file)


In [3]:
class TemporalGNNBatch(torch.nn.Module):
    """Batched A3TGCN2 encoder followed by a two-layer linear head.

    The node embeddings returned by ``A3TGCN2`` are flattened into one vector
    of ``embed_dim * num_nodes`` values per sample, projected to ``hidden_dim``
    features, then to ``output_features`` values, and finally squashed with a
    sigmoid scaled by ``output_scale``.

    NOTE(review): this notebook also imports ``TemporalGNNBatch`` from
    ``model``; whichever definition is executed last wins — confirm which one
    is intended.

    Args:
        node_features: Input channels per node (``in_channels`` of A3TGCN2).
        output_features: Number of values predicted per sample.
        num_nodes: Number of graph nodes; fixes the flattened embedding width.
        embed_dim: Output channels of the A3TGCN2 layer.
        periods: Number of time steps handed to A3TGCN2.
        batch_size: Batch size handed to A3TGCN2.
        hidden_dim: Width of the intermediate linear layer (default 500).
        dropout_p: Dropout probability applied to the embeddings (default 0.2).
        output_scale: Factor applied to the sigmoid output (default 5.0).
    """

    def __init__(self, node_features=4, output_features=1, num_nodes=51,
                 embed_dim=32, periods=137, batch_size=32,
                 hidden_dim=500, dropout_p=0.2, output_scale=5.0):
        super().__init__()
        self.node_features = node_features
        self.embed_dim = embed_dim
        self.periods = periods
        self.output_features = output_features
        self.num_nodes = num_nodes
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.output_scale = output_scale

        # Submodule names and registration order are kept so that the module
        # repr and state_dict keys stay the same.
        self.tgnn = A3TGCN2(in_channels=self.node_features,
                            out_channels=self.embed_dim,
                            periods=self.periods,
                            batch_size=self.batch_size)
        self.dropout = torch.nn.Dropout(dropout_p)
        # Input: one flattened vector per sample holding an embed_dim-sized
        # embedding for every node (rather than the full `periods` axis).
        self.linear_1 = torch.nn.Linear(self.embed_dim * self.num_nodes,
                                        self.hidden_dim)
        # Maps (batch, hidden_dim) -> (batch, output_features).
        self.linear_2 = torch.nn.Linear(self.hidden_dim, self.output_features)

    def forward(self, x, edge_index):
        """Compute one bounded prediction vector per sample.

        Args:
            x: Node features over time. Presumably shaped
                (batch, num_nodes, node_features, periods), as the batched
                A3TGCN2 layer expects — TODO confirm against the data loader.
            edge_index: Graph edge indices, shape [2, num_edges].

        Returns:
            Tensor of shape (batch, output_features) holding
            ``output_scale * sigmoid(logits)``, i.e. values between 0 and
            ``output_scale``.
        """
        h = self.tgnn(x, edge_index)  # presumably (batch, num_nodes, embed_dim)
        h = self.dropout(h)
        # reshape instead of view: view raises on a non-contiguous tensor,
        # reshape copies only when it has to.
        h = h.reshape(-1, self.embed_dim * self.num_nodes)
        h = self.linear_1(h)
        # No activation between the two linear layers: the ReLU calls were
        # commented out in the original cell, so that behavior is kept.
        h = self.linear_2(h)
        return self.output_scale * torch.sigmoid(h)


In [6]:
# Build the model with its default hyperparameters; as the last expression of
# the cell, its repr (the layer summary) is displayed as the cell output.
# NOTE(review): `TemporalGNNBatch` is both imported from `model` and defined in
# this notebook, so the class instantiated here depends on which cell ran
# last — confirm which one is intended.
TemporalGNNBatch()

TemporalGNNBatch(
  (tgnn): A3TGCN2(
    (_base_tgcn): TGCN2(
      (conv_z): GCNConv(4, 32)
      (linear_z): Linear(in_features=64, out_features=32, bias=True)
      (conv_r): GCNConv(4, 32)
      (linear_r): Linear(in_features=64, out_features=32, bias=True)
      (conv_h): GCNConv(4, 32)
      (linear_h): Linear(in_features=64, out_features=32, bias=True)
    )
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (linear_1): Linear(in_features=1632, out_features=500, bias=True)
  (linear_2): Linear(in_features=500, out_features=1, bias=True)
)

In [7]:
# Preferred GPU index when several GPUs are present.
CUDA_DEVICE_INDEX = 1

if torch.cuda.is_available():
    print("set cuda device")
    device="cuda"
    # torch.cuda.set_device raises on an index that does not exist (e.g. on a
    # single-GPU host), so fall back to GPU 0 instead of crashing.
    if CUDA_DEVICE_INDEX < torch.cuda.device_count():
        gpu_index = CUDA_DEVICE_INDEX
    else:
        gpu_index = 0
        print(f"Warning: GPU {CUDA_DEVICE_INDEX} not found, using GPU 0")
    # Makes `gpu_index` the current device, which the bare "cuda" string refers to.
    torch.cuda.set_device(gpu_index)
else:
    device="cpu"
    print('Warning: Using CPU')

set cuda device
