In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Lambda
import torchinfo
from sklearn.metrics import roc_auc_score, roc_curve

import numpy as np
import math
import time

In [2]:
# Pick the compute device once; later cells read the global `device`.
# Falls back to the CPU so the notebook still runs on machines without CUDA
# (otherwise `device` would be undefined and later cells would raise NameError).
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    # An explicit index makes `device.index` report 0 instead of None.
    device = torch.device("cuda", 0)
    print("GPU Device:【{}:{}】".format(device.type, device.index))
else:
    device = torch.device("cpu")
    print("CPU Device:【{}】".format(device.type))

GPU Device:【cuda:None】


In [3]:
class CustomDataset(Dataset):
    """In-memory dataset pairing a data array with a label array.

    Both arrays are wrapped with ``torch.from_numpy`` (sharing memory with the
    NumPy inputs until ``shuffle`` replaces them). ``__getitem__`` returns
    ``(data[idx], labels[idx, 0])``, so ``labels`` must be at least 2-D and only
    its first column is used as the target.

    Args:
        data: Sample array; the first axis indexes samples.
        labels: Label array with the same first-axis length as ``data``.
        transform: Optional callable applied to each sample (default: none).
        target_transform: Optional callable applied to each label (default: none).
    """

    def __init__(self, data: np.ndarray, labels: np.ndarray, transform=None,
                 target_transform=None):
        self.data: torch.Tensor = torch.from_numpy(data)
        self.labels: torch.Tensor = torch.from_numpy(labels)
        # The transforms used to be discarded unconditionally. They are now
        # honoured; the default is "no transform", which matches the old
        # effective behaviour. (torchvision's ToTensor expects PIL images /
        # ndarrays, not tensors, so the old default could not have been applied.)
        self.transform = transform
        self.target_transform = target_transform

    def shuffle(self, seed=None):
        """Permute samples and labels with one shared random permutation.

        Args:
            seed: Seed of the permutation. When ``None``, a seed in [1, 65535) is
                drawn from NumPy's global RNG. The seed used is stored in
                ``self.shuffle_seed`` and printed so the order can be reproduced.

        ``self.data``/``self.labels`` are replaced by permuted copies; the NumPy
        arrays given to ``__init__`` are not modified.
        """
        self.shuffle_seed = np.random.randint(1, 65535) if seed is None else seed
        print(f"随机种子：{self.shuffle_seed}")
        # np.random.shuffle on a torch.Tensor swaps rows through views and can
        # leave duplicated rows (NumPy warns about this for objects with view
        # semantics). Index both tensors with one explicit permutation instead.
        # A local RandomState also leaves NumPy's global RNG state untouched.
        perm = np.random.RandomState(self.shuffle_seed).permutation(len(self.labels))
        index = torch.as_tensor(perm, dtype=torch.long)
        self.data = self.data[index]
        self.labels = self.labels[index]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        data = self.data[idx]
        # Only the first label column is used as the target.
        label = self.labels[idx, 0]
        if self.transform:
            data = self.transform(data)
        if self.target_transform:
            label = self.target_transform(label)
        return data, label

In [4]:
def load_dataset(path="dataset.npz", train_percent=0.8, seed=None) -> tuple:
    """Load ``data``/``labels`` from an ``.npz`` archive and split them into train/test sets.

    ``data`` is cast to float32 and reshaped to ``(-1, 1, 116, 116)``; ``labels``
    is cast to int64. Samples and labels are shuffled with the same permutation.
    The first ``int(N * train_percent)`` samples then form the training set.

    Args:
        path: Path of the ``.npz`` archive holding ``data`` and ``labels``.
        train_percent: Fraction of samples assigned to the training set, in [0, 1].
        seed: Shuffle seed. ``None`` (default) draws one in [1, 65535) from NumPy's
            global RNG. Pass a previously printed seed to reproduce that split.

    Returns:
        ``(train_dataset, test_dataset)`` as ``CustomDataset`` instances.

    Raises:
        ValueError: if ``train_percent`` is outside [0, 1] or ``data`` and
            ``labels`` contain a different number of samples.
    """
    if not 0.0 <= train_percent <= 1.0:
        raise ValueError(f"train_percent must be in [0, 1], got {train_percent}")
    with np.load(path) as dataset:
        full_data = dataset["data"].astype(np.float32).reshape((-1, 1, 116, 116))
        full_labels = dataset["labels"].astype(np.int64)
    n_samples = full_data.shape[0]
    if full_labels.shape[0] != n_samples:
        raise ValueError(
            f"data has {n_samples} samples but labels has {full_labels.shape[0]}"
        )
    train_size = int(n_samples * train_percent)
    test_size = n_samples - train_size
    if seed is None:
        seed = np.random.randint(1, 65535)
    # A legacy RandomState permutation draws the same Fisher-Yates swaps as
    # np.random.seed(seed) + np.random.shuffle, so a given seed gives the same
    # split. It does not reseed NumPy's global RNG.
    perm = np.random.RandomState(seed).permutation(n_samples)
    full_data, full_labels = full_data[perm], full_labels[perm]
    train_data, test_data = full_data[:train_size], full_data[train_size:]
    train_labels, test_labels = full_labels[:train_size], full_labels[train_size:]
    print(f"训练集大小：{train_size}", f"测试集大小：{test_size}", f"随机种子：{seed}")
    train_dataset = CustomDataset(train_data, train_labels)
    test_dataset = CustomDataset(test_data, test_labels)
    return train_dataset, test_dataset

In [5]:
# Other available file: D:\datasets\ABIDE\ABIDE_FC_dataset.npz
# NOTE(review): absolute local path; consider a configurable data directory.
DATASET_PATH = "D:\\datasets\\ABIDE\\ABIDE_FC_augmented_dataset.npz"
TRAIN_PERCENT = 0.8
train_dataset, test_dataset = load_dataset(DATASET_PATH, TRAIN_PERCENT)

训练集大小：9636 测试集大小：2409 随机种子：11000


In [6]:
batch_size = 64

# Create data loaders. The training loader reshuffles every epoch, so batch
# composition changes between epochs. The test loader keeps a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

# Peek at one batch to report tensor shapes.
X, y = next(iter(test_dataloader))
data_shape = X.shape
label_shape = y.shape
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")

Shape of X [N, C, H, W]: torch.Size([64, 1, 116, 116])
Shape of y: torch.Size([64]) torch.int64


In [7]:
class DWConv(nn.Module):
    """Depthwise 3x3 convolution with one filter per channel (``groups=dim``).

    Stride 1 and padding 1, so both the channel count and the spatial size of
    the input are preserved.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dwconv = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            groups=dim,
        )

    def forward(self, x):
        return self.dwconv(x)

In [8]:
class LKA(nn.Module):
    """Large-kernel attention over a ``dim``-channel feature map.

    The attention map is computed in three steps:

    - a 5x5 depthwise conv;
    - a 7x7 depthwise conv with dilation 3 (effective kernel 19);
    - a 1x1 pointwise conv.

    All three preserve the spatial size: padding 2 for the 5x5 conv and
    padding 9 = (19 - 1) / 2 for the dilated 7x7 conv.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv_spatial = nn.Conv2d(dim, dim, 7, stride=1, padding=9, groups=dim, dilation=3)
        self.conv1 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        """Return ``(x * attn, attn)``; ``attn`` has the same shape as ``x``."""
        attn = self.conv1(self.conv_spatial(self.conv0(x)))
        # x is never modified in place, so no defensive clone is needed.
        return x * attn, attn

In [9]:
class ResBlock(nn.Module):
    """Residual block computing ``BN(LKA(conv1x1(x))) + x``.

    Sub-modules are created lazily by :meth:`build`, which needs the input
    shape, so ``build`` must be called before ``forward``.

    Args:
        out_channels: Stored but not used; the residual addition requires the
            output to keep the input's channel count.
        device: Passed to ``self.to`` at the end of ``build``.
    """

    def __init__(self, out_channels: int = None, device=None, *args, **wargs):
        super().__init__(*args, **wargs)
        self.__built = False
        self.out_channels = out_channels
        self.device = device

    def build(self, input_shape):
        """Create the sub-modules for inputs shaped ``input_shape`` (channels at index 1)."""
        self.input_shape = input_shape
        in_channels = input_shape[1]
        # Creation order is kept so default weight init consumes the RNG in the
        # same sequence as before.
        self.conv = nn.Conv2d(in_channels, in_channels, 1)
        self.lka = LKA(in_channels)
        # NOTE(review): constructed but never applied in forward().
        self.activation = nn.ReLU()
        self.bn = nn.BatchNorm2d(in_channels)
        # NOTE(review): only used when the spatial size changes in forward().
        # DWConv has stride 1, so it cannot actually reconcile differing sizes.
        self.downconv = DWConv(in_channels)
        # Output shape is reported as equal to the input shape (used for model summaries).
        self.output_shape = input_shape
        self.__built = True
        self.to(self.device)

    def forward(self, x: torch.Tensor):
        """Return ``(block_output, attention_map)`` for a batched 4-D input tensor."""
        fx, attn = self.lka(self.conv(x))
        fx = self.bn(fx)
        if fx.shape[-1] != x.shape[-1]:
            x = self.downconv(x)
        return fx + x, attn

    def freeze(self):
        """Disable gradients for every parameter of this block."""
        for param in self.parameters():
            param.requires_grad_(False)

In [10]:
class LKAResNet(nn.Module):
    """Stack of LKA residual blocks followed by a flatten + linear 2-class head.

    Usage:

    1. Construct the model.
    2. Call :meth:`compile`. It infers the input shape from the first batch,
       builds the layers and creates the loss and optimizer.
    3. Call :meth:`fit`.

    Args:
        device: Passed to ``.to()`` for the model and for every batch it processes.
    """

    def __init__(self, device=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__built: bool = False
        self.device = device

    def build(self, input_shape):
        """Create two ``ResBlock``s for ``input_shape`` and a linear head with 2 outputs."""
        # An ordinary ResNet, but with the blocks kept in a list. This is
        # intended to allow appending new blocks during training; nothing in
        # this class does so yet.
        self.blocks: nn.ModuleList = nn.ModuleList([ResBlock() for _ in range(2)])
        for block in self.blocks:
            block.build(input_shape)
            block.to(self.device)
        self.flatten = nn.Flatten()
        # ResBlock reports output shape == input shape, so the head sees C*H*W features.
        self.linear = nn.Linear(int(np.prod(input_shape[1:])), 2)
        self.__built = True

    def compile(self, dataloader: DataLoader, loss_fn, optimizer, lr=1e-2):
        """Build the network from the first batch of ``dataloader`` and set up training.

        Args:
            dataloader: Loader whose first batch fixes ``input_shape``/``output_shape``.
            loss_fn: Loss class; instantiated here with no arguments.
            optimizer: Optimizer class; called with the trainable parameters and ``lr``.
            lr: Learning rate passed to ``optimizer``.

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        self.batch_size: int = dataloader.batch_size
        first_batch = next(iter(dataloader), None)
        if first_batch is None:
            raise ValueError("compile() needs a dataloader that yields at least one batch")
        X, y = first_batch
        self.input_shape: tuple = X.shape
        self.output_shape: tuple = y.shape
        self.build(self.input_shape)
        self.loss_fn = loss_fn()
        self.optimizer = optimizer(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
        self.to(self.device)

    def forward(self, x):
        for blk in self.blocks:
            x, _ = blk(x)
        return self.linear(self.flatten(x))

    def get_attn(self, x):
        """Return the attention map of every block for input ``x`` as one NumPy array.

        Switches the model to eval mode and leaves it there. Each block's map is
        reshaped to ``(-1, H, W)`` using its own last two dimensions, so the
        result has shape ``(n_blocks, N * C, H, W)``.
        """
        self.eval()
        attention = []
        with torch.no_grad():
            for blk in self.blocks:
                x, attn = blk(x)
                attention.append(attn.cpu().numpy().reshape((-1, *attn.shape[-2:])))
        return np.array(attention)

    @staticmethod
    def _synchronize():
        """Wait for queued CUDA work so wall-clock timings are accurate; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def fit(self, dataloader: DataLoader, epochs: int = 1, test_dataloader=None):
        """Train for ``epochs`` epochs, optionally evaluating after each one.

        Returns:
            ``(accuracies, mean_losses, epoch_times)``: lists with one entry per
            epoch. Times are the summed per-batch training time in seconds.
        """
        size = len(dataloader.dataset)
        # len(dataloader) includes the final partial batch. size // batch_size
        # does not, which inflated the averaged loss and the time per batch.
        num_batches = len(dataloader)
        time_collection = []
        loss_collection = []
        correct_collection = []
        for epoch in range(epochs):
            print(f"Epoch: {epoch+1}/{epochs}")
            self.train()
            loss, correct = 0, 0
            time_delta = 0
            current = 0
            for batch, (X, y) in enumerate(dataloader):
                X = X.to(self.device)
                y = y.to(self.device)

                # Start timing (after pending GPU work has finished).
                self._synchronize()
                time_start = time.perf_counter()

                # Compute prediction error
                pred = self.forward(X)
                batch_loss = self.loss_fn(pred, y)

                # Backpropagation
                batch_loss.backward()
                self.optimizer.step()
                self.optimizer.zero_grad()

                # Stop timing
                self._synchronize()
                time_end = time.perf_counter()

                current += len(X)

                batch_loss = batch_loss.item()
                loss += batch_loss

                batch_correct = (pred.argmax(1) == y).type(torch.float).sum().item()
                correct += batch_correct
                batch_correct /= len(X)

                batch_time = time_end - time_start
                time_delta += batch_time
                print(f"\r{batch+1}/{num_batches}  [{current:>3d}/{size:>3d}] - batch loss: {batch_loss:>7f} - batch accuracy: {(100*batch_correct):>0.1f}% - {batch_time*1000:>0.3f}ms", end="", flush=True)
            # max(..., 1) avoids ZeroDivisionError on an empty loader.
            loss /= max(num_batches, 1)
            correct /= max(size, 1)
            print(f"\n-- Average loss: {loss:>7f} - Accuracy: {(100*correct):>0.1f}% - {time_delta/max(num_batches, 1)*1000:>0.3f}ms/batch")
            time_collection.append(time_delta)
            loss_collection.append(loss)
            correct_collection.append(correct)
            if test_dataloader is not None:
                self.test(test_dataloader)
        print("\n", torchinfo.summary(self, input_size=self.input_shape))
        return correct_collection, loss_collection, time_collection

    def test(self, dataloader: DataLoader, return_preds=False):
        """Evaluate on ``dataloader`` and print accuracy and mean batch loss.

        Returns:
            When ``return_preds`` is true, ``(ys, preds)``: float64 NumPy arrays of
            the true labels and argmax predictions. Otherwise ``None``.
        """
        size = len(dataloader.dataset)
        num_batches = len(dataloader)
        ys, preds = [], []
        self.eval()
        test_loss, correct = 0, 0
        with torch.no_grad():
            for X, y in dataloader:
                # Use the model's own device, not a notebook-global variable.
                X, y = X.to(self.device), y.to(self.device)
                pred = self.forward(X)
                test_loss += self.loss_fn(pred, y).item()
                pred_labels = pred.argmax(1)
                correct += (pred_labels == y).type(torch.float).sum().item()
                if return_preds:
                    ys.append(y.cpu().numpy())
                    preds.append(pred_labels.cpu().numpy())
        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        test_loss /= max(num_batches, 1)
        correct /= max(size, 1)
        print(f"Test Accuracy: {(100*correct):>0.1f}%, Average loss: {test_loss:>8f} \n")
        if return_preds:
            # Concatenate once instead of re-stacking every batch. float64 keeps
            # the dtype the previous np.hstack accumulation produced.
            ys_arr = np.concatenate(ys).astype(np.float64) if ys else np.array([])
            preds_arr = np.concatenate(preds).astype(np.float64) if preds else np.array([])
            return ys_arr, preds_arr

In [14]:
# Training hyper-parameters. loss_fn and optimizer are passed as classes;
# LKAResNet.compile instantiates them.
lr = 1e-4
epochs = 5
loss_fn, optimizer = nn.CrossEntropyLoss, torch.optim.Adam

In [15]:
# Per-model training histories, keyed by model name.
correct = {}
loss = {}
timing = {}

In [16]:
# Build and train the RLKA model, storing its per-epoch history under "RLKA".
base_model = LKAResNet(device)
base_model.compile(train_dataloader, loss_fn=loss_fn, optimizer=optimizer, lr=lr)
history = base_model.fit(train_dataloader, epochs=epochs, test_dataloader=test_dataloader)
correct["RLKA"], loss["RLKA"], timing["RLKA"] = history

Epoch: 1/5
151/151  [9636/9636] - batch loss: 0.544637 - batch accuracy: 80.6% - 23.000ms
-- Average loss: 0.594560 - Accuracy: 70.8% - 45.430ms/batch
Test Accuracy: 77.6%, Average loss: 0.513704 

Epoch: 2/5
151/151  [9636/9636] - batch loss: 0.419087 - batch accuracy: 91.7% - 23.998ms
-- Average loss: 0.452903 - Accuracy: 83.2% - 44.006ms/batch
Test Accuracy: 83.1%, Average loss: 0.420595 

Epoch: 3/5
151/151  [9636/9636] - batch loss: 0.324374 - batch accuracy: 94.4% - 23.000ms
-- Average loss: 0.362488 - Accuracy: 88.0% - 44.250ms/batch
Test Accuracy: 86.8%, Average loss: 0.352300 

Epoch: 4/5
151/151  [9636/9636] - batch loss: 0.251975 - batch accuracy: 97.2% - 25.000ms
-- Average loss: 0.295222 - Accuracy: 91.1% - 45.255ms/batch
Test Accuracy: 89.7%, Average loss: 0.298740 

Epoch: 5/5
151/151  [9636/9636] - batch loss: 0.196415 - batch accuracy: 97.2% - 24.001ms
-- Average loss: 0.242222 - Accuracy: 93.4% - 45.017ms/batch
Test Accuracy: 91.8%, Average loss: 0.254690 


Layer (ty

In [17]:
# Evaluate on the training split first, then on the held-out test split.
for split_loader in (train_dataloader, test_dataloader):
    base_model.test(split_loader)

Test Accuracy: 94.8%, Average loss: 0.210122 

Test Accuracy: 91.8%, Average loss: 0.254690 



In [18]:
# Attention maps of every block for the first test batch (saved by the next cell).
# next(iter(...)) fails loudly on an empty loader. The former for/break
# silently left `attn` stale or undefined in that case.
x, y = next(iter(test_dataloader))
attn = base_model.get_attn(x.to(device))  # get_attn already returns a NumPy array

In [19]:
import pickle

# Persist the attention maps computed above for offline analysis.
with open("./attention.pkl", "wb") as attn_file:
    pickle.dump(attn, attn_file)