In [1]:
import torch
from torch import nn
import seaborn as sns
import matplotlib.pyplot as plt
import sys
import os
import math
import nibabel as nib
import numpy as np
import pandas as pd
sys.path.append('../..')
from src.utils.data import writePandas, getPandas, getConfig, getDict
from src.model.feature import load_radiomics
os.chdir('../../..')

In [2]:
class AfterNet(nn.Module):
    """Single-query cross-attention head that outputs one logit per sample.

    A 4-value texture vector is projected to a query, while the flattened
    image features are projected to keys and values. The attended value
    vector is added to a per-row affine projection of the raw features
    (residual branch), layer-normalised and mapped to a single logit.

    Args:
        hidden_dim: Width of the query/key/value projections, of the layer
            norm and of the input to the final linear layer.

    Note:
        The residual branch yields one value per row of the flattened input
        (its dimension 1). That vector is added to the ``hidden_dim``-wide
        attention output, so the two sizes must match (or broadcast).
    """

    def __init__(self, hidden_dim=512) -> None:
        super().__init__()
        # Honour the constructor argument; it used to be pinned to 512 here
        # and in the layer norm, which broke every other width.
        self.hidden_dim = hidden_dim
        self.feature_size = 150
        self.flatten = nn.Flatten(start_dim=2, end_dim=-1)
        self.train_qw = nn.Linear(4, hidden_dim)
        nn.init.normal_(self.train_qw.weight, mean=0, std=0.01)
        self.train_kw = nn.Linear(self.feature_size, hidden_dim)
        nn.init.normal_(self.train_kw.weight, mean=0, std=0.01)
        self.train_vw = nn.Linear(self.feature_size, hidden_dim)
        nn.init.normal_(self.train_vw.weight, mean=0, std=0.01)
        self.train_softmax = nn.Softmax(dim=2)
        self.train_affine = nn.Linear(self.feature_size, 1)
        self.flatten2 = nn.Flatten()
        self.train_layernorm = nn.LayerNorm(hidden_dim)
        self.train_fc = nn.Linear(hidden_dim, 1)

    def forward(self, x, tex, demo):
        """Compute the logit for a batch.

        Args:
            x: Image features. Everything from dimension 2 onwards is
                flattened and must amount to ``feature_size`` values per row.
            tex: Texture features, 3-D with a last dimension of 4 (it feeds
                ``train_qw`` and then ``torch.bmm``).
            demo: Accepted for interface compatibility; not used here.

        Returns:
            Tensor whose last dimension is 1 (the output of ``train_fc``).
        """
        x = self.flatten(x)
        res_x = x
        q = self.train_qw(tex)
        k = self.train_kw(x)
        v = self.train_vw(x)
        x = torch.bmm(q, k.transpose(1, 2))
        # NOTE(review): the scores are scaled by sqrt(feature_size) although
        # the dot product runs over hidden_dim values - confirm this is intended.
        x = x / math.sqrt(self.feature_size)
        x = self.train_softmax(x)
        x = torch.bmm(x, v)
        x = self.flatten2(x)
        # Residual branch: one scalar per feature row, laid out as a flat vector.
        res_x = self.train_affine(res_x)
        res_x = res_x.transpose(1, 2)
        res_x = self.flatten2(res_x)
        x = x + res_x
        x = self.train_layernorm(x)
        x = self.train_fc(x)
        return x
    
from torch.utils.data import dataset, dataloader
class MyDataset(dataset.Dataset):
    """Map-style dataset yielding ``(img, tex, demo, label)`` for an index.

    The image array is read once from the ``.npy`` file at ``path``; all other
    fields come from the per-sample sequences handed to the constructor. The
    same position is used to index the image array and every sequence.
    """

    def __init__(self, keys, path, radiomics, cats, ages, sexs, scores, ledds, durations):
        self.keys, self.radiomics, self.labels = keys, radiomics, cats
        self.ages, self.sexs, self.scores = ages, sexs, scores
        self.ledds, self.durations = ledds, durations
        self.data = np.load(path)

    def __len__(self):
        """Number of samples, defined by the length of ``keys``."""
        return len(self.keys)

    def __getitem__(self, index):
        """Return the image row plus texture, demographic and label tensors."""
        # The key is fetched but is not part of the returned tuple.
        _ = self.keys[index]
        texture = np.array(self.radiomics[index])[np.newaxis]
        image = self.data[index]
        # Fixed order: age, sex, score, LEDD, duration.
        demographic_sources = (self.ages, self.sexs, self.scores, self.ledds, self.durations)
        demographics = np.array([source[index] for source in demographic_sources])
        target = np.array([self.labels[index]])
        return (
            image,
            torch.from_numpy(texture),
            torch.from_numpy(demographics),
            torch.from_numpy(target),
        )

In [3]:
# Patient table and the train/test row indices stored in the data config.
data = getPandas('pat_data')
conf = getConfig('data')
train_idx, test_idx = (conf['indices']['pat'][split] for split in ('train', 'test'))
keys, paths, cats = (data[column].values for column in ('KEY', 'ANTs_Reg', 'CAT'))

# Radiomic feature table without its KEY column.
radiomics = getPandas('pat_ANTs_Reg_radiomic').drop(columns=['KEY'])

train_keys = keys[train_idx]
train_path = 'data/bin/pat_resnet_10.npy'

# Four selected radiomic features, z-scored with statistics computed on the
# training rows only.
radiomic_cols = [
    'rTHA_original_gldm_LargeDependenceHighGrayLevelEmphasis',
    'rTHA_original_glszm_LargeAreaHighGrayLevelEmphasis',
    'rSN_original_glcm_ClusterProminence',
    'rCAU_original_gldm_LargeDependenceHighGrayLevelEmphasis',
]
train_radiomics = radiomics.iloc[train_idx][radiomic_cols].values
radiomic_mean, radiomic_std = train_radiomics.mean(axis=0), train_radiomics.std(axis=0)
train_radiomics = (train_radiomics - radiomic_mean) / radiomic_std

train_cats = cats[train_idx]
# Age, LEDD and duration are divided by 100; sex and score are taken as stored.
train_ages, train_ledds, train_durations = (
    data[column].values[train_idx] / 100 for column in ('AGE', 'LEDD', 'DURATION')
)
train_sexs, train_scores = (
    data[column].values[train_idx] for column in ('SEX', 'NUPDR3OF')
)

In [4]:
from torch import nn, optim
from torcheval.metrics import BinaryAUROC
# Kept only so the name `AUC` stays importable for any other cell that uses it.
from torcheval.metrics.aggregation.auc import AUC

# Binary classification on raw logits.
loss_fn = nn.BCEWithLogitsLoss()
optim_params = []

data_set = MyDataset(train_keys, train_path, train_radiomics, train_cats, train_ages, train_sexs, train_scores, train_ledds, train_durations)
data_loader = dataloader.DataLoader(data_set, batch_size=8, shuffle=True)
epoch = 40

# ROC AUC of predicted probabilities against binary labels. torcheval's
# aggregation `AUC` integrates arbitrary (x, y) points with the trapezoidal
# rule, so feeding it (probability, label) pairs does not give a ROC AUC.
# BinaryAUROC exposes the same reset()/update(input, target)/compute() calls.
metric = BinaryAUROC()

In [None]:
# Training run(s): one fresh AfterNet per entry of `lr_lsit`, trained with
# gradient accumulation and evaluated on `data_loader` after every epoch.
loss_rec = []
lr_lsit = [1e-3]
accumulation_steps = 8
# Mean absolute weight gradient per optimizer step (query/key/value/final layer).
gq_list = []
gv_list = []
gk_list = []
gl_list = []
hidden_dim = 512
# Use the GPU when one is visible, otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for lr in lr_lsit:
    loss_list = []
    model = AfterNet(hidden_dim=hidden_dim)
    # Element counts of the tracked weight matrices, used to average gradients.
    kv_size = model.feature_size * hidden_dim
    q_size = 4 * hidden_dim
    model = model.to(device)
    # Only parameters whose name contains 'train' are optimised.
    for name, param in model.named_parameters():
        param.requires_grad = 'train' in name
    # NOTE(review): every group sets its own 'lr', which takes precedence over
    # the default `lr=lr`, so sweeping `lr_lsit` does not change the step size.
    optimizer = optim.SGD([
            {'params': model.train_qw.parameters(), 'lr': 1e-2},
            {'params': model.train_kw.parameters(), 'lr': 1e-2},
            {'params': model.train_vw.parameters(), 'lr': 1e-2},
            {'params': model.train_fc.parameters(), 'lr': 1e-2},
            {'params': model.train_affine.parameters(), 'lr': 1e-2},
            {'params': model.train_layernorm.parameters(), 'lr': 1e-2}
        ], lr=lr)

    # (weight, element count, history list, plot title) in print order.
    tracked = (
        (model.train_qw.weight, q_size, gq_list, 'gq'),
        (model.train_kw.weight, kv_size, gk_list, 'gk'),
        (model.train_vw.weight, kv_size, gv_list, 'gv'),
        (model.train_fc.weight, hidden_dim, gl_list, 'gl'),
    )

    for i in range(epoch):
        model.train()
        # Discard gradients left by a trailing, incomplete accumulation window
        # of the previous epoch so they cannot leak into this epoch's first step.
        optimizer.zero_grad()
        for j, (img, tex, demo, label) in enumerate(data_loader):
            img = img.to(device).float()
            tex = tex.to(device).float()
            demo = demo.to(device).float()
            label = label.to(device).float()
            out = model(img, tex, demo)
            # Scale so the accumulated gradient is the mean over the window.
            loss = loss_fn(out, label) / accumulation_steps
            loss.backward()
            if (j + 1) % accumulation_steps != 0:
                continue

            optimizer.step()
            step_stats = []
            for weight, size, history, _ in tracked:
                grad_mean = np.sum(np.abs((weight.grad / size).cpu().detach().numpy()))
                history.append(grad_mean)
                step_stats.append(grad_mean)
            print('gq: {}, gk: {}, gv: {}, gl: {}'.format(*step_stats))
            if i > 1:
                # The first recorded step is skipped in every plot.
                for _, _, history, title in tracked:
                    sns.lineplot(history[1:])
                    plt.title(title)
                    plt.show()
            optimizer.zero_grad()

        model.eval()
        with torch.no_grad():
            total = 0
            correct = 0
            total_loss = 0
            metric.reset()
            output = np.array([])
            for img, tex, demo, label in data_loader:
                img = img.to(device).float()
                tex = tex.to(device).float()
                demo = demo.to(device).float()
                label = label.to(device).float()
                out = model(img, tex, demo)
                loss = loss_fn(out, label)
                total_loss += loss.item()
                # reshape(-1) keeps a 1-D shape even for a final batch of one
                # sample, which squeeze() would collapse to a 0-d tensor.
                output = np.append(output, out.reshape(-1).cpu().numpy())
                prob = torch.sigmoid(out)
                metric.update(prob.reshape(-1).cpu(), label.reshape(-1).cpu())
                pred = torch.round(prob)
                total += label.size(0)
                correct += (pred == label).sum().item()
            sns.histplot(output)
            plt.show()
            # total_loss sums per-batch means, so average over batches; the
            # recorded value now matches the printed one.
            loss_item = total_loss / len(data_loader)
            loss_list.append(loss_item)
            print('epoch: {}, loss: {}'.format(i, loss_item))
            print('acc: {}'.format(correct / total))
            print('auc: {}'.format(metric.compute().cpu().detach().numpy()))
    loss_rec.append(loss_list)