### 脑电图注意力网络（GAT）

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import add_self_loops, degree

In [2]:
# import numpy as np

# def data_split(train_ratio=0.7):
#     load_dir = '../global_data/time_76800x32x128/'

#     trials = np.load(load_dir + 'trials.npy')
#     bases = np.load(load_dir + 'bases.npy')
#     labels = np.load(load_dir + 'labels.npy')
#     # print(trials.shape, bases.shape, labels.shape)
    
#     # 去基线
#     for i, base in enumerate(bases):
#         trials[i * 60 : (i + 1) * 60] -= base
    
#     # 离散化标签
#     labels = np.where(labels >= 5, 1, 0)

#     # 复制标签以对齐样本
#     labels = np.repeat(labels, 60, axis = 0)
#     # print(labels.shape)
    
#     shuffle_list = np.arange(trials.shape[0])
#     np.random.shuffle(shuffle_list)
#     trials = trials[shuffle_list]
#     labels = labels[shuffle_list]
    
#     cut_point = int(trials.shape[0] * train_ratio)
#     train_features, train_labels = trials[:cut_point], labels[:cut_point]
#     test_features, test_labels = trials[cut_point:], labels[cut_point:]
    
#     train_features = train_features.reshape((cut_point, 32 * 128))
#     test_features = test_features.reshape((trials.shape[0] - cut_point, 32 * 128))
    
#     mean = train_features.mean(axis = 0)
#     std = train_features.std(axis = 0)
    
#     train_features = (train_features - mean) / std
#     test_features = (test_features - mean) / std
    
#     train_features = train_features.reshape((cut_point, 32, 128))
#     test_features = test_features.reshape((trials.shape[0] - cut_point, 32, 128))
    
#     save_dir = 'data/data_split/'
#     np.save(save_dir + 'train_features.npy', train_features)
#     np.save(save_dir + 'train_labels.npy', train_labels)
#     np.save(save_dir + 'test_features.npy', test_features)
#     np.save(save_dir + 'test_labels.npy', test_labels)

# data_split(train_ratio=0.9)

In [3]:
def load_data(is_train_data=True):
    save_dir = 'data/data_split/'
    if is_train_data:
        features = np.load(save_dir + 'train_features.npy')
        labels = np.load(save_dir + 'train_labels.npy')
    else:
        features = np.load(save_dir + 'test_features.npy')
        labels = np.load(save_dir + 'test_labels.npy')
    return features, labels

In [4]:
def get_edge_index(create_complete_graph=False, self_loop_only=False):
    edge_index = [[],[]]
    weight = []
    
    if create_complete_graph:
        for i in range(32):
            for j in range(32):
                edge_index[0].append(i)
                edge_index[1].append(j)
        edge_index = torch.tensor(np.array(edge_index), dtype=torch.long)
        return edge_index
    
    
    if self_loop_only:
        edge_index = torch.tensor(np.array(edge_index), dtype=torch.long)
        edge_index, _ = add_self_loops(edge_index, num_nodes=32)
        return edge_index
    
    adjacency_edge = {
        1:[2],
        2:[3, 19],
        3:[5, 6],
        4:[5],
        5:[8, 7],
        6:[7, 24],
        7:[9, 10],
        8:[9],
        9:[12, 11],
        10:[11, 16],
        11:[13],
        12:[],
        13:[14, 15],
        14:[15],
        15:[],
        16:[13, 31],
        17:[18],
        18:[19, 20],
        19:[6, 23],
        20:[23, 22],
        21:[22],
        22:[25, 26],
        23:[24, 25],
        24:[10, 28],
        25:[28, 27],
        26:[27],
        27:[29, 30],
        28:[16, 29],
        29:[31],
        30:[],
        31:[15, 32],
        32:[15]
    }
    
    for start, end_list in adjacency_edge.items():
        if len(end_list) == 0:
            continue
        for end in end_list:
            edge_index[0].append(start - 1)
            edge_index[1].append(end - 1)
            edge_index[0].append(end - 1)
            edge_index[1].append(start - 1)
           
    edge_index = torch.tensor(np.array(edge_index), dtype=torch.long)
    
    edge_index, _ = add_self_loops(edge_index, num_nodes=32)
    return edge_index

In [5]:
def get_edge_index(create_complete_graph=False, self_loop_only=False):
    edge_index = [[],[]]
    weight = []
    
    if self_loop_only:
        edge_index = torch.tensor(np.array(edge_index), dtype=torch.long)
        edge_index, _ = add_self_loops(edge_index, num_nodes=32)
        return edge_index
    
    adjacency_edge = {
        1:[2],
        2:[3, 19],
        3:[5, 6],
        4:[5],
        5:[8, 7],
        6:[7, 24],
        7:[9, 10],
        8:[9],
        9:[12, 11],
        10:[11, 16],
        11:[13],
        12:[],
        13:[14, 15],
        14:[15],
        15:[],
        16:[13, 31],
        17:[18],
        18:[19, 20],
        19:[6, 23],
        20:[23, 22],
        21:[22],
        22:[25, 26],
        23:[24, 25],
        24:[10, 28],
        25:[28, 27],
        26:[27],
        27:[29, 30],
        28:[16, 29],
        29:[31],
        30:[],
        31:[15, 32],
        32:[15]
    }
    
    for start, end_list in adjacency_edge.items():
        if len(end_list) == 0:
            continue
        for end in end_list:
            edge_index[0].append(start - 1)
            edge_index[1].append(end - 1)
            edge_index[0].append(end - 1)
            edge_index[1].append(start - 1)
           
    edge_index = torch.tensor(np.array(edge_index), dtype=torch.long)
    
    edge_index, _ = add_self_loops(edge_index, num_nodes=32)
    return edge_index

In [6]:
# edge_index = [[],[]]
# weight = []

# #用一个字典保存 通道下标对应 9 * 9 矩阵的下标
# chan_to_1020={0:[0,3],1:[1,3],2:[2,2],3:[2,0],4:[3,1],5:[3,3],6:[4,2],7:[4,0],8:[5,1],
#               9:[5,3],10:[6,2],11:[6,0],12:[7,3],13:[8,3],14:[8,4],15:[6,4],16:[0,5],
#               17:[1,5],18:[2,4],19:[2,6],20:[2,8],21:[3,7],22:[3,5],23:[4,4],24:[4,6],
#                 25:[4,8],26:[5,7],27:[5,5],28:[6,6],29:[6,8],30:[7,5],31:[8,5]}
# maps = np.zeros(shape=(9, 9), dtype=int)

# for k, v in chan_to_1020.items():
#     maps[v[0]][v[1]] = k + 1
# print(maps)
# plt.matshow(maps)

In [7]:
from torch_geometric.data import InMemoryDataset, Data, Dataset

class MyDataset(InMemoryDataset):
    is_train_data = None
    edge_index = None
    def __init__(self, root, is_train_data, edge_index):
        self.is_train_data = is_train_data
        self.edge_index = edge_index
        super(MyDataset, self).__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])
        
    @property
    def raw_file_names(self):
        return []
    
    # 检查self.processed_dir目录下是否存在self.processed_file_names属性方法返回的所有文件，没有就会走process
    @property
    def processed_file_names(self):
        if self.is_train_data:
            return ['train.dataset']
        return ['test.datset']
    
    def download(self):
        pass
    
    def process(self):
        features, labels = None, None
        
        if self.is_train_data:
            features, labels = load_data(is_train_data=True)
        else:
            features, labels = load_data(is_train_data=False)
        
        data_list = []
        for i in range(features.shape[0]):
            x = torch.tensor(features[i], dtype=torch.float)
            y = torch.tensor(labels[i].reshape(1, -1), dtype=torch.long)
            data = Data(x = x, edge_index=self.edge_index, y=y)
            data_list.append(data)
        data, slices = self.collate(data_list)
        
        torch.save((data, slices), self.processed_paths[0])

+ data.x: Node feature matrix with shape [num_nodes, num_node_features]

+ data.edge_index: Graph connectivity in COO format with shape [2, num_edges] and type torch.long

+ data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]

+ data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]

+ data.pos: Node position matrix with shape [num_nodes, num_dimensions]

--- 

- train_mask denotes against which nodes to train (140 nodes),

- val_mask denotes which nodes to use for validation, e.g., to perform early stopping (500 nodes),

- test_mask denotes against which nodes to test (1000 nodes).

In [8]:
from torch_geometric.nn import TopKPooling, SAGEConv, GCNConv, GATConv
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
import torch.nn.functional as F
import torch.nn as nn

embed_dim = 128
class GAT(torch.nn.Module):
    def __init__(self):
        super(GAT, self).__init__()
        
        self.temporalMLPs1 = []
        
        for i in range(32):
            self.temporalMLPs1.append(nn.Linear(128, 256, device=device))
        
        self.temporalMLPs2 = []
        
        for i in range(32):
            self.temporalMLPs2.append(nn.Linear(256, 256, device=device))
        
        self.lin1 = torch.nn.Linear(8192, 512)
        self.lin2 = torch.nn.Linear(512, 128)
        self.lin3 = torch.nn.Linear(128, 2)
        
    def forward(self, data):
        # x： n * 1, 其中每个图中点的个数是不同的
        x, edge_index, batch = data.x, data.edge_index, data.batch
        
        batch_size = data.y.shape[0]
        x = x.view(batch_size, 32, 128)
        
        temporalMLPs_out = []
        
        for i in range(32):
            x1 = self.temporalMLPs1[i](x[:, i, :])
            x1 = F.relu(x1)
            x1 = F.dropout(x1, 0.2)
            x1 = self.temporalMLPs2[i](x1)
            x1 = F.relu(x1)
            x1 = F.dropout(x1, 0.2)
            temporalMLPs_out.append(x1)
        
        # concat
        x = torch.concat(temporalMLPs_out, dim = 1)
        
        x = F.relu(self.lin1(x))
        x = F.relu(self.lin2(x))
#         x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin3(x)
        return x

In [9]:
def train(emo_dim):
    
    train_loss = 0
    train_acc = 0
    model.train()
    
    for batch_id, batch in enumerate(trainDataLoader):
        batch.to(device)
        opt.zero_grad()
        output = model(batch)
        loss = crit(output, batch.y[:, emo_dim])
        loss.backward()
        opt.step()
        
        train_loss += loss.item()
        output = torch.max(output, 1)[1]
        train_acc += (output == batch.y[:, emo_dim]).sum()
    
    num_train_sample = len(trainDataLoader.dataset)
    train_loss = train_loss / num_train_sample
    train_acc = train_acc / num_train_sample
    
    # check测试集的性能
    vali_loss = 0
    vali_acc = 0
    model.eval()
    
    for batch in testDataLoader:
        batch.to(device)
        output = model(batch)
        loss = crit(output, batch.y[:, emo_dim])
        vali_loss += loss.item()
        output = torch.max(output, 1)[1]
        vali_acc += (output == batch.y[:, emo_dim]).sum()
    
    num_test_sample = len(testDataLoader.dataset)
    vali_loss = vali_loss / num_test_sample
    vali_acc = vali_acc / num_test_sample
    
    print(f'train_loss:{train_loss:.6f}, train_acc:{train_acc:.6f}, test_loss:{vali_loss:.6f}, test_acc:{vali_acc:.6f}')
    
    return train_loss, train_acc, vali_loss, vali_acc

# 超参设置

In [10]:
create_complete_graph = False
self_loop_only = False
emo_dim = 0
batch_size = 32

device = torch.device('cuda')

In [11]:
from torch_geometric.loader import DataLoader

edge_index = get_edge_index(create_complete_graph=create_complete_graph, self_loop_only=self_loop_only)

trainData = MyDataset(root='data/data_split', is_train_data=True, edge_index=edge_index)
trainDataLoader = DataLoader(trainData, batch_size=batch_size, shuffle=True)

testData = MyDataset(root='data/data_split', is_train_data=False, edge_index=edge_index)
testDataLoader = DataLoader(testData, batch_size=batch_size)

In [12]:
model = GAT().to(device)
opt = torch.optim.Adam(model.parameters())
crit = nn.CrossEntropyLoss().to(device)

In [13]:
for epoch in range(1000):
    print(f'->epoch:{epoch + 1}', end = ', ')
    train_loss, train_acc, val_loss, val_acc = train(emo_dim)
#     print(f'->epoch:{epoch:3d}, train_loss={train_loss:.6f}, train_acc={train_acc:.4f}, val_loss={val_loss:.6f}, val_acc={val_acc:.4f}')

->epoch:1, train_loss:0.018787, train_acc:0.668287, test_loss:0.016049, test_acc:0.747135
->epoch:2, train_loss:0.013750, train_acc:0.792549, test_loss:0.012537, test_acc:0.817708
->epoch:3, train_loss:0.010569, train_acc:0.848669, test_loss:0.011563, test_acc:0.842969
->epoch:4, train_loss:0.008929, train_acc:0.877734, test_loss:0.010657, test_acc:0.852734
->epoch:5, train_loss:0.007764, train_acc:0.895298, test_loss:0.011569, test_acc:0.847005
->epoch:6, train_loss:0.006855, train_acc:0.909505, test_loss:0.009483, test_acc:0.877995
->epoch:7, train_loss:0.006265, train_acc:0.916999, test_loss:0.010241, test_acc:0.871484
->epoch:8, train_loss:0.005754, train_acc:0.923669, test_loss:0.008613, test_acc:0.887760
->epoch:9, train_loss:0.005399, train_acc:0.928472, test_loss:0.008488, test_acc:0.892839
->epoch:10, train_loss:0.004973, train_acc:0.934823, test_loss:0.008635, test_acc:0.894141
->epoch:11, train_loss:0.004695, train_acc:0.939222, test_loss:0.008537, test_acc:0.895313
->epoch:

- MLP_base ->epoch:195, train_loss:0.000070, train_acc:0.999494, test_loss:0.022968, test_acc:0.933854
- MLP_2层dropout（p=0.2）

# 增加模型容量
- ->epoch:86, train_loss:0.002378, train_acc:0.980787, test_loss:0.017223, test_acc:0.908203
- 改为heads=3， ->epoch:167, train_loss:0.001960, train_acc:0.985171, test_loss:0.022945, test_acc:0.910417

## 比较实验
### GCN
+ 仅包括自环时，->epoch:25, train_loss:0.000922, train_acc:0.990784, test_loss:0.010875, test_acc:0.919401
+ 加上3x3卷积核的邻接边时，->epoch:32, train_loss:0.000819, train_acc:0.992173, test_loss:0.020655, test_acc:0.895313，邻接边设计的不好，限制了模型的发挥
+ 别人的方法的准确率：89/90、93/94
### GAT
+ 仅包括自环时，->epoch:30, train_loss:0.001252, train_acc:0.986531, test_loss:0.015308, test_acc:0.912630
+ 使用自己设计的边，->epoch:123, train_loss:0.002104, train_acc:0.982161, test_loss:0.029411, test_acc:0.904688