In [4]:
import scipy 
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
import torch
import torch.nn as nn
import math
import random

In [2]:
import torcheeg as teeg
from torcheeg.datasets.constants.emotion_recognition.dreamer import DREAMER_ADJACENCY_MATRIX, DREAMER_CHANNEL_LOCATION_DICT
from torcheeg.datasets.constants.emotion_recognition.seed import SEED_ADJACENCY_MATRIX, SEED_CHANNEL_LOCATION_DICT
from torcheeg.datasets.constants.emotion_recognition.seed_iv import SEED_IV_ADJACENCY_MATRIX, SEED_IV_CHANNEL_LOCATION_DICT
from torcheeg.datasets import DREAMERDataset, SEEDDataset, SEEDIVDataset
from torcheeg import transforms
from torcheeg.transforms.pyg import ToG

In [5]:
def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(73)

In [118]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.parametrizations import weight_norm


class GatedLinearUnits(nn.Module):
    def __init__(self, in_channels, out_channels, hid_channels=16, kernel_size=2, dilation=1,  groups=4, activate=False):
        super(GatedLinearUnits, self).__init__()

        self.kernel_size = kernel_size
        self.dilation = dilation
        self.activate = activate

        self.conv = weight_norm(nn.Conv2d(in_channels, out_channels,
                                          (1, kernel_size), 
                                          dilation=(1, dilation), bias=True, groups=groups))
        nn.init.xavier_uniform_(self.conv.weight, gain=np.sqrt(2.0))
        nn.init.constant_(self.conv.bias, 0.1)
        self.gate = weight_norm(nn.Conv2d(in_channels, out_channels,
                                          (1, kernel_size),
                                          dilation=(1, dilation), bias=True, groups=groups))
        nn.init.xavier_uniform_(self.gate.weight, gain=np.sqrt(2.0))
        nn.init.constant_(self.gate.bias, 0.1)
        self.downsample = weight_norm(nn.Conv2d(in_channels, out_channels, (1, 1), bias=True))
        nn.init.xavier_uniform_(self.downsample.weight, gain=np.sqrt(2.0))
        nn.init.constant_(self.downsample.bias, 0.1)

        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.2)
        self.bn.weight.data.fill_(1)
        self.bn.bias.data.fill_(0.1)

        self.sigmod = nn.Sigmoid()
        
    def forward(self, X):
        res = X
        gate = X
        print('X_in', X.shape)
        # X = nn.functional.pad(X, ((self.kernel_size-1)*self.dilation, 0, 0))
        out = self.conv(X)
        if self.activate:
            out = F.tanh(out)

        gate = nn.functional.pad(gate, ((self.kernel_size-1)*self.dilation, 0, 0))
        gate = self.gate(gate)
        gate = self.sigmod(gate)

        out = torch.mul(out, gate)
        ones = torch.ones_like(gate)

        # print('res', res.shape, out.shape)
        if res.shape[1] != out.shape[1]:
            res = self.downsample(res)
        res = torch.mul(res, ones-gate)
        out = out + res
        # out = self.bn(self.relu(out))
        out = self.relu(self.bn(out))
        # print('X_out', out.shape)
        return out


class TimeBlock(nn.Module):
    """
    Neural network block that applies a temporal convolution to each node of
    a graph in isolation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, nhid_channels=128, dropout=0.6, layer=3):
        """
        :param in_channels: Number of input features at each node in each time
        step.
        :param out_channels: Desired number of output channels at each node in
        each time step.
        :param kernel_size: Size of the 1D temporal kernel.
        """
        super(TimeBlock, self).__init__()
        layers = []
        for i in range(layer):
            print('in_channels', in_channels)
            if i == 0:
                layers.append(GatedLinearUnits(in_channels, nhid_channels, kernel_size=1, dilation=1, groups=1))
                print(i, in_channels, nhid_channels)
            elif i == layer-1:
                layers.append(GatedLinearUnits(nhid_channels, out_channels, kernel_size=1, dilation=1, groups=1))
                print(i, nhid_channels, out_channels)
            else:
                layers.append(GatedLinearUnits(nhid_channels, nhid_channels, kernel_size=kernel_size, dilation=2**(i), groups=1))
                print(i, nhid_channels, nhid_channels)
                
        self.units = nn.Sequential(*layers)

    def forward(self, X):
        print('in', X.shape)
        X = X.permute(0, 2, 1)
        out = self.units(X)
        out = out.permute(0, 1, 2)
        print('out', X.shape)
        return out


In [87]:
def conv_init(conv):
    nn.init.kaiming_normal_(conv.weight, mode='fan_out')
    # nn.init.constant_(conv.bias, 0)


def bn_init(bn, scale):
    nn.init.constant_(bn.weight, scale)
    nn.init.constant_(bn.bias, 0)


def fc_init(fc):
    nn.init.xavier_normal_(fc.weight)
    nn.init.constant_(fc.bias, 0)

In [114]:
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, in_features, out_features, num_nodes, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.num_nodes = num_nodes

        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=np.sqrt(2.0))
        self.a1 = nn.Parameter(torch.zeros(size=(out_features, 1)))
        self.a2 = nn.Parameter(torch.zeros(size=(out_features, 1)))
        nn.init.xavier_uniform_(self.a1.data, gain=np.sqrt(2.0))
        nn.init.xavier_uniform_(self.a2.data, gain=np.sqrt(2.0))

        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.downsample = nn.Conv1d(in_features, out_features, 1)

        self.bias = nn.Parameter(torch.zeros(num_nodes, out_features))


    def forward(self, input, adj):
        batch_size = input.size(0)
        h = torch.bmm(input, self.W.expand(batch_size, self.in_features, self.out_features))
        f_1 = torch.bmm(h, self.a1.expand(batch_size, self.out_features, 1))
        f_2 = torch.bmm(h, self.a2.expand(batch_size, self.out_features, 1))
        e = self.leakyrelu(f_1 + f_2.transpose(2,1))
        # add by xyk
        attention = torch.mul(adj, e)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.bmm(attention, h) + self.bias.expand(batch_size, self.num_nodes, self.out_features)
        if input.shape[-1] != h_prime.shape[-1]:
            input = self.downsample(input.permute(0, 2, 1)).permute(0, 2, 1).contiguous()
            h_prime = h_prime + input
        else:
            h_prime = h_prime + input
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


class STGATBlock(nn.Module):
    def __init__(self, in_channels, spatial_channels, out_channels,
                 num_nodes, num_timesteps_input,
                 dropout=0.6, alpha=0.2, nheads=4, concat=True):
        super(STGATBlock, self).__init__()
        self.nheads = nheads
        self.concat = concat
        self.spatial_channels = spatial_channels
        self.temporal1 = nn.Sequential(TimeBlock(in_channels=in_channels,
                                                 out_channels=out_channels),
        )
        self.attentions = [GraphAttentionLayer(
            out_channels*(num_timesteps_input),
            spatial_channels, num_nodes=num_nodes,
            dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]
        
        self.relu = nn.ReLU()
        self.batch_norm = nn.BatchNorm2d(num_nodes)

    
    def forward(self, X, A_hat):
        residual = X
        t = self.temporal1(X)
        t = t.contiguous().view(t.shape[0], t.shape[1], -1)
        if self.concat:
            t2 = torch.cat([att(t, A_hat) for att in self.attentions], dim=2)
        else:
            t2 = sum([att(t, A_hat) for att in self.attentions]) / self.nheads

        t2 = t2.view(t2.shape[0], t2.shape[1], -1, self.spatial_channels)
        t3 = t2
        if t3.shape[-1] == residual.shape[-1]:
            t3 = t3 + residual[:,:,-t3.shape[2]:,:]
        else:
            t3 = t3
        return self.relu(self.batch_norm(t3))


class DSTANet(nn.Module):
    def __init__(self, num_class=60, num_point=128, num_frame=32, num_subset=3, dropout=0., config=None, num_person=2,
                 num_channel=64, glo_reg_s=True, att_s=True, glo_reg_t=False, att_t=True,
                 use_temporal_att=True, use_spatial_att=True, attentiondrop=0, dropout2d=0, use_pet=True, use_pes=True):
        super(DSTANet, self).__init__()

        self.out_channels = config[-1][1]
        in_channels = config[0][0]

        self.input_map = nn.Sequential(
            nn.Conv2d(num_channel, in_channels//4, 1),
            nn.BatchNorm2d(in_channels//4),
            nn.LeakyReLU(0.1),
        )
        self.diff_map1 = nn.Sequential(
            nn.Conv1d(in_channels//4, in_channels//2, 3),
            nn.BatchNorm1d(in_channels//2),
            nn.LeakyReLU(0.1),
        )
        self.diff_map2 = nn.Sequential(
            nn.Conv1d(in_channels//2, in_channels, 3),
            nn.BatchNorm1d(in_channels),
            nn.LeakyReLU(0.1),
        )

        
        # param = {
        #     'num_node': num_point,
        #     'num_subset': num_subset,
        #     'glo_reg_s': glo_reg_s,
        #     'att_s': att_s,
        #     'glo_reg_t': glo_reg_t,
        #     'att_t': att_t,
        #     'use_spatial_att': use_spatial_att,
        #     'use_temporal_att': use_temporal_att,
        #     'use_pet': use_pet,
        #     'use_pes': use_pes,
        #     'attentiondrop': attentiondrop
        # }
        
        self.transformer_block = STGATBlock(in_channels=in_channels, out_channels=self.out_channels,
                                            concat=False,
                                            spatial_channels=self.out_channels,
                                            num_nodes=num_channel, 
                                            num_timesteps_input=num_point-4, nheads=4)
        # self.graph_layers = nn.ModuleList()
        # for index, (in_channels, out_channels, inter_channels, stride) in enumerate(config):
        #     self.graph_layers.append(
        #         STAttentionBlock(in_channels, out_channels, inter_channels, stride=stride, num_frame=num_frame,
        #                          **param))
        #     num_frame = int(num_frame / stride + 0.5)

        self.fc = nn.Linear(self.out_channels, num_class)

        # self.drop_out = nn.Dropout(dropout)
        # self.drop_out2d = nn.Dropout2d(dropout2d)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2d):
                bn_init(m, 1)
            elif isinstance(m, nn.Linear):
                fc_init(m)
        

    def forward(self, x, A):
        """

        :param x: N x C x T
                N - batch size, C - channels amount, T - timestamps amount
        :return: classes scores
        """
        N, C, T = x.shape

        # Features reducing part
        x = self.input_map(x.view(N, C, 1, T))
        
        x = self.diff_map1(x.view(N, -1, T))
        x = self.diff_map2(x)
        print(x.shape)
        x = self.transformer_block(x, A)
        

        # for i, m in enumerate(self.graph_layers):
        #     x = m(x)

        # # NM, C, T, V
        # x = x.view(N, M, self.out_channels, -1)
        # x = x.permute(0, 1, 3, 2).contiguous().view(N, -1, self.out_channels, 1)  # whole channels of one spatial
        # x = self.drop_out2d(x)
        # x = x.mean(3).mean(1)

        # x = self.drop_out(x)  # whole spatial of one channel

        # return self.fc(x)

In [8]:
DATA_DIR = os.path.join("..", "data")

In [9]:
dataset = DREAMERDataset(io_path=os.path.join(DATA_DIR, "dreamer"),
                         mat_path=os.path.join(DATA_DIR, 'DREAMER.mat'),
                         online_transform=transforms.Compose([
                             ToG(DREAMER_ADJACENCY_MATRIX)
                         ]),
                         label_transform=transforms.Compose([
                             transforms.Select('arousal'),
                             transforms.Binary(3.0)
                         ]),
                         # num_worker=4
                        )

dataset already exists at path ..\data\dreamer, reading from path...


In [15]:
from torcheeg.model_selection import KFoldPerSubject

SPLIT_PATH = os.path.join(DATA_DIR, 'dreamer_split')
k_fold = KFoldPerSubject(n_splits=5,
                         split_path=SPLIT_PATH,
                         shuffle=True)

In [14]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [119]:
import torch.nn as nn
# from torcheeg.models import CCNN

from torcheeg.model_selection import train_test_split
# from torch.utils.data.dataloader import DataLoader
from torch_geometric.loader import DataLoader
from torch_geometric.utils import scatter

loss_fn = nn.CrossEntropyLoss()
batch_size = 16
n_channels = 14

test_accs = []
test_losses = []

for i, (train_dataset, test_dataset) in enumerate(k_fold.split(dataset)):
    # initialize model
    # model = CCNN(num_classes=2, in_channels=4, grid_size=(9, 9)).to(device)
    model = DSTANet(num_class=2, num_channel=n_channels, config=[[64, 2], [2, 2]]).to(device)
    # initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=1e-4)  # official: weight_decay=5e-1
    # split train and val
    train_dataset, val_dataset = train_test_split(
        train_dataset,
        test_size=0.2,
        split_path=SPLIT_PATH+f'{i}',
        shuffle=True)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
    for batch_idx, batch in enumerate(train_loader):
        X = batch[0].x.view(batch_size, n_channels, -1).to(device)
        y = batch[1].to(device)
        A = torch.sparse_coo_tensor(batch[0].edge_index,
                                    batch[0].edge_weight,
                                    size=(batch_size, n_channels*n_channels)
                                   ).to_dense().view(batch_size, n_channels, n_channels)
        # A = batch[0].edge_index.view(batch_size, 2, -1)
        model(X, A)

        
        break
    break
        # print(batch[0].shape)
        # print(np.unique(batch[1]))
        # X = batch[0].to(device)
        # y = batch[1].to(device)


in_channels 64
0 64 128
in_channels 64
1 128 128
in_channels 64
2 128 2
torch.Size([16, 64, 124])
in torch.Size([16, 64, 124])
X_in torch.Size([16, 124, 64])


RuntimeError: Given groups=1, weight of size [128, 64, 1, 1], expected input[1, 16, 124, 64] to have 64 channels, but got 16 channels instead