# resnet50 cross attention baseline

In [1]:
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = "1" 
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score,confusion_matrix,classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import hiddenlayer as hl
import torch
import torch.nn as nn
from torch.optim import SGD,Adam
import torch.utils.data as Data
from torchvision import models
from  torchvision import transforms
from  torchvision.datasets import ImageFolder
import pickle as pkl
import torchvision.models as models
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [2]:
class FineTuneResnet50(nn.Module):
    def __init__(self, num_class=3):
        super(FineTuneResnet50, self).__init__()
        self.num_class = num_class
        resnet50_net_MRI = models.resnet50(pretrained=True)
        resnet50_net_PET = models.resnet50(pretrained=True)
        self.features_MRI = nn.Sequential(*list(resnet50_net_MRI.children())[:-1])
        self.features_PET = nn.Sequential(*list(resnet50_net_PET.children())[:-1])
        self.fc_comb = nn.Sequential(
            nn.Linear(4096,256),
            nn.ReLU(),
            nn.Dropout(p = 0.5),
            nn.Linear(256,128),
            nn.ReLU(),
            nn.Dropout(p = 0.5),
            nn.Linear(128,3)
        )
        
         
    def MA(self, x, label):
        # x (k, v), label (q)
        B, C_kv = x.shape
        B, C_q = label.shape
        self.kv = nn.Linear(C_kv, C_kv * 3 * 2).cuda()
        self.q = nn.Linear(C_q, C_kv * 3).cuda()
        self.at_fx = nn.Linear(C_kv * 3, C_kv).cuda()
        #self.ffn = nn.Linear(C_kv, C_kv).cuda()
        kv = self.kv(x).reshape(2, B, 3, C_kv)
        k, v = kv[0], kv[1]
        q = self.q(label).reshape(B, 3, C_kv)
        attn = torch.einsum("bhq,bhk->bhqk", [q, k])
        attn = attn.softmax(dim=-1)
        x_ = torch.einsum("bhqk,bhk->bhq", [attn, v])
        x_ = x_.reshape(B, C_kv * 3)
        x = self.at_fx(x_) + x
        #x = self.ffn(x) + x
        return x    
        

 
    def forward(self, MRI,PET):
        MRI = self.features_MRI(MRI)
        PET = self.features_PET(PET)
        MRI = MRI.view(MRI.size(0),-1)
        PET = PET.view(PET.size(0),-1)# 将第二次卷积的输出拉伸为一行
        
        MRI_ma = self.MA(MRI,PET)
        PET_ma = self.MA(PET,MRI)
        MRI = torch.cat((MRI_ma, PET_ma), 1)
        MRI = self.fc_comb(MRI)
        return MRI


In [3]:
MyResnet = FineTuneResnet50()

In [4]:
MyResnet

FineTuneResnet50(
  (features_MRI): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0

In [5]:
#定义优化器
optimizer = torch.optim.Adam(MyResnet.parameters(),lr=0.00001,weight_decay=0.01)
loss_func = nn.CrossEntropyLoss()#损失函数

In [6]:
#记录训练过程指标
historyl = hl.History()
#使用Canves进行可视化

canvasl = hl.Canvas()

In [7]:
from torch.utils.data import DataLoader
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root):
        super(MyDataset, self).__init__()
        MRI_PET_match_all = pkl.load(open(root,"rb"),encoding='iso-8859-1')
        MRI = []
        PET = []
        group = []
        for index,row in MRI_PET_match_all.iterrows():
            MRI.append(row['MRI_img_array'])
            PET.append(row['PET_img_array'])
            group_ = torch.tensor(row['Group'],dtype=torch.float)
            group.append(group_)
        self.MRI = MRI
        self.PET = PET
        self.group = group  

    def __getitem__(self, index):
        mri =torch.from_numpy(self.MRI[index].transpose([2,0,1])).float().to(DEVICE)
        pet = torch.from_numpy(self.PET[index].transpose([2,0,1])).float().to(DEVICE)
        group = self.group[index].to(DEVICE)
        return mri,pet,group

    def __len__(self):
        return len(self.MRI)

train_data = MyDataset("/home/gc/gechang/gec_multi_fusion/end_to_end/train.pkl")
test_data = MyDataset("/home/gc/gechang/gec_multi_fusion/end_to_end/test.pkl")

train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
test_loader = DataLoader(test_data, batch_size=8)


In [8]:
#对模型进行迭代训练，对所有的数据训练epoch轮
for epoch in range(40):
    train_loss_epoch = 0
    val_loss_epoch = 0
    train_corrects = 0
    val_corrects = 0
    #对训练数据的加载器进行迭代计算
    MyResnet.train().cuda()
    for step,(mri,pet,group) in enumerate(train_loader):
        ##计算每个batch的损失
        output = MyResnet(mri,pet)
        loss = loss_func(output,group.long())#交叉熵损失函数
        pre_lab = torch.argmax(output,1).to(DEVICE)
        optimizer.zero_grad()#每个迭代步的梯度初始化为0
        loss.backward()#损失的后向传播，计算梯度
        optimizer.step()#使用梯度进行优化
        train_loss_epoch += loss.item()*group.size(0)
        train_corrects += torch.sum(pre_lab == group.to(DEVICE).data)
    #计算一个epoch的损失和精度
    train_loss = train_loss_epoch/len(train_data.group)
    train_acc = train_corrects.double()/len(train_data.group)
    print("---------------------------------------------------")
    print("epoch:",epoch,"train_loss:",train_loss,"train_acc:",train_acc)
     #计算在验证集上的表现
    MyResnet.eval()
    for step,(mri,pet,group) in enumerate(test_loader):
        output = MyResnet(mri,pet)
        loss = loss_func(output,group.long())
        pre_lab = torch.argmax(output,1).to(DEVICE)
        val_loss_epoch += loss.item()*group.size(0)
        val_corrects += torch.sum(pre_lab == group.to(DEVICE).data)

    #计算一个epoch上的输出loss和acc
    val_loss = val_loss_epoch/len(test_data.group)
    val_acc = val_corrects.double()/len(test_data.group)
    print("epoch:",epoch,"val_loss:",val_loss,"val_acc:",val_acc)
    #保存每个epoch上的输出loss和acc
    historyl.log(epoch,train_loss=train_loss,val_loss = val_loss,train_acc = train_acc.item(),val_acc = val_acc.item())
    #可视化网络训练的过程
    # with canvasl:
    #     canvasl.draw_plot([historyl["train_loss"],historyl["val_loss"]])
    #     canvasl.draw_plot([historyl["train_acc"],historyl["val_acc"]])


---------------------------------------------------
epoch: 0 train_loss: 1.0000896543600928 train_acc: tensor(0.5533, device='cuda:0', dtype=torch.float64)
epoch: 0 val_loss: 0.9738983535766601 val_acc: tensor(0.5450, device='cuda:0', dtype=torch.float64)
---------------------------------------------------
epoch: 1 train_loss: 0.871044895879298 train_acc: tensor(0.5885, device='cuda:0', dtype=torch.float64)
epoch: 1 val_loss: 0.8790412139892578 val_acc: tensor(0.5550, device='cuda:0', dtype=torch.float64)
---------------------------------------------------
epoch: 2 train_loss: 0.6854362866211416 train_acc: tensor(0.7390, device='cuda:0', dtype=torch.float64)
epoch: 2 val_loss: 0.7578091204166413 val_acc: tensor(0.6050, device='cuda:0', dtype=torch.float64)
---------------------------------------------------
epoch: 3 train_loss: 0.4971851312755192 train_acc: tensor(0.8432, device='cuda:0', dtype=torch.float64)
epoch: 3 val_loss: 0.6147321951389313 val_acc: tensor(0.7600, device='cuda:0'