In [2]:
import logging
import os
import argparse
import math
import random
import tqdm
import numpy as np
import pandas as pd
from sklearn import preprocessing

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils as utils
import matplotlib.pyplot as plt
import torch.nn as nn

from script import dataloader, utility, earlystopping
from model import models, layers


In [None]:
from main import *

In [None]:
class dilated_inception(nn.Module):
    def __init__(self, cin, cout, dilation_factor=2):
        super(dilated_inception, self).__init__()
        self.tconv = nn.ModuleList()
        self.kernel_set = [2,3,6,7]
        cout = int(cout/len(self.kernel_set))
        for kern in self.kernel_set:
            self.tconv.append(nn.Conv2d(cin,cout,(1,kern),dilation=(1,dilation_factor)))

    def forward(self,input):
        x = []
        for i in range(len(self.kernel_set)):
            x.append(self.tconv[i](input))
        for i in range(len(self.kernel_set)):
            x[i] = x[i][...,-x[-1].size(3):]
        x = torch.cat(x,dim=1)
        return x

In [None]:
logging.basicConfig(level=logging.INFO)

args, device, blocks = get_parameters()
n_vertex, zscore, train_iter, val_iter, test_iter = data_preparate(args, device)    # n_vertex = 207,

for x, y in tqdm.tqdm(train_iter):
    num = 0

print(x.shape)

In [26]:
kern = 2
dilation_factor = 1
cin = 12
cout = 16

TCconv = nn.ModuleList()
kernel_set = [2,3,4,5]

x = torch.rand(32,1,12,207)
cin = x.shape[1]

conv1 = nn.Conv2d(cin,cout,kernel_size=(1,2))
conv2 = nn.Conv2d(cin,cout,kernel_size=(1,3))
conv3 = nn.Conv2d(cin,cout,kernel_size=(1,4))
conv4 = nn.Conv2d(cin,cout,kernel_size=(1,5))

In [20]:
x.permute(0,1,3,2).shape

torch.Size([32, 1, 207, 12])

In [27]:
print(conv1(x.permute(0,1,3,2)).shape)
print(conv2(x.permute(0,1,3,2)).shape)
print(conv3(x.permute(0,1,3,2)).shape)
print(conv4(x.permute(0,1,3,2)).shape)

torch.Size([32, 16, 207, 11])
torch.Size([32, 16, 207, 10])
torch.Size([32, 16, 207, 9])
torch.Size([32, 16, 207, 8])


In [28]:
x1 = conv1(x.permute(0,1,3,2))
x2 = conv2(x.permute(0,1,3,2))
x3 = conv3(x.permute(0,1,3,2))
x4 = conv4(x.permute(0,1,3,2))

In [35]:
out = []
out.append(x1[...,-x4.shape[-1]:])
out.append(x2[...,-x4.shape[-1]:])
out.append(x3[...,-x4.shape[-1]:])
out.append(x4[...,-x4.shape[-1]:])

In [36]:
torch.cat(out,dim=1).shape

torch.Size([32, 64, 207, 8])

In [None]:
for kernel_size in kernel_set:
    TCconv.append(nn.Conv2d(cin,cout,kernel_size=(1,kernel_size),dilation=(1,dilation_factor)))

In [38]:
x.permute(0,1,3,2).shape

torch.Size([32, 1, 207, 12])

In [40]:
class Dialated_Block(nn.Module):
    def __init__(self, cin, cout, dilation_factor=1):
        super(Dialated_Block,self).__init__()
        self.TCconv = nn.ModuleList()
        self.kernel_set = [2,3,4,5]
        cout = int(cout / len(self.kernel_set))
        for kernel_size in self.kernel_set:
            self.TCconv.append(nn.Conv2d(cin,cout,kernel_size=(1,kernel_size),dilation=(1,dilation_factor)))
        
    def forward(self,input):
        x = []
        for i in range(len(self.kernel_set)):
            x.append(self.TCconv[i](input))
            
        for i in range(len(self.kernel_set)):
            x[i] = x[i][...,-x[-1].shape[-1]:]
        x = torch.cat(x,dim=1)
        x = x.permute(0,1,3,2)
        return x


In [45]:
class Dialated_Block(nn.Module):
    def __init__(self, cin, cout, dilation_factor=1):
        super(Dialated_Block,self).__init__()
        self.TCconv = nn.ModuleList()
        self.kernel_set = [2,3]
        cout = int(cout / len(self.kernel_set))
        for kernel_size in self.kernel_set:
            self.TCconv.append(nn.Conv2d(cin,cout,kernel_size=(1,kernel_size),dilation=(1,dilation_factor)))
        
    def forward(self,input):
        input = input.permute(0,1,3,2)
        x = []
        
        for i in range(len(self.kernel_set)):
            x.append(self.TCconv[i](input))
        for i in range(len(self.kernel_set)):
            x[i] = x[i][...,-x[-1].shape[-1]:]
        x = torch.cat(x,dim=1)
        x = x.permute(0,1,3,2)
        return x


In [46]:
x.shape

torch.Size([32, 1, 12, 207])

In [41]:
tep_layer = Dialated_Block(1,64)

In [42]:
tep_layer

Dialated_Block(
  (TCconv): ModuleList(
    (0): Conv2d(1, 16, kernel_size=(1, 2), stride=(1, 1))
    (1): Conv2d(1, 16, kernel_size=(1, 3), stride=(1, 1))
    (2): Conv2d(1, 16, kernel_size=(1, 4), stride=(1, 1))
    (3): Conv2d(1, 16, kernel_size=(1, 5), stride=(1, 1))
  )
)

In [44]:
tep_layer(x.permute(0,1,3,2)).shape  

torch.Size([32, 64, 207, 8])