In [1]:
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = "0" 
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score,confusion_matrix,classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import hiddenlayer as hl
import torch
import torch.nn as nn
from torch.optim import SGD,Adam
import torch.utils.data as Data
from torchvision import models
from  torchvision import transforms
from  torchvision.datasets import ImageFolder
import pickle as pkl
import torchvision.models as models
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from datetime import datetime


In [2]:
# 使用预训练模型
resnet = models.resnet18(pretrained=True)


# 修改最后的全连接层改为十分类
resnet.fc = nn.Linear(512, 3)



In [3]:
# 查看模型一级子层个数
count = 0
for layer in resnet.children():
    count = count + 1
print(count)  # 共10个一级子层

10


In [4]:
# 迁移学习冻结部分参数
count = 0
for layer in resnet.children():
    count = count + 1
    if count < 8:
        for param in layer.parameters():
            param.requires_grad = False  # 前9个一级子层无需参数更新
    # 最后一个(第10个)调整过输出的全连接层不冻住参数

In [5]:
for k, v in resnet.named_parameters():
    print("{}:参数需要计算梯度并更新吗?{}".format(k, v.requires_grad))  # False表示冻结了，True表示没有冻结

conv1.weight:参数需要计算梯度并更新吗?False
bn1.weight:参数需要计算梯度并更新吗?False
bn1.bias:参数需要计算梯度并更新吗?False
layer1.0.conv1.weight:参数需要计算梯度并更新吗?False
layer1.0.bn1.weight:参数需要计算梯度并更新吗?False
layer1.0.bn1.bias:参数需要计算梯度并更新吗?False
layer1.0.conv2.weight:参数需要计算梯度并更新吗?False
layer1.0.bn2.weight:参数需要计算梯度并更新吗?False
layer1.0.bn2.bias:参数需要计算梯度并更新吗?False
layer1.1.conv1.weight:参数需要计算梯度并更新吗?False
layer1.1.bn1.weight:参数需要计算梯度并更新吗?False
layer1.1.bn1.bias:参数需要计算梯度并更新吗?False
layer1.1.conv2.weight:参数需要计算梯度并更新吗?False
layer1.1.bn2.weight:参数需要计算梯度并更新吗?False
layer1.1.bn2.bias:参数需要计算梯度并更新吗?False
layer2.0.conv1.weight:参数需要计算梯度并更新吗?False
layer2.0.bn1.weight:参数需要计算梯度并更新吗?False
layer2.0.bn1.bias:参数需要计算梯度并更新吗?False
layer2.0.conv2.weight:参数需要计算梯度并更新吗?False
layer2.0.bn2.weight:参数需要计算梯度并更新吗?False
layer2.0.bn2.bias:参数需要计算梯度并更新吗?False
layer2.0.downsample.0.weight:参数需要计算梯度并更新吗?False
layer2.0.downsample.1.weight:参数需要计算梯度并更新吗?False
layer2.0.downsample.1.bias:参数需要计算梯度并更新吗?False
layer2.1.conv1.weight:参数需要计算梯度并更新吗?False
layer2.1.bn1.weight:参数需要计

In [6]:
from torch.utils.data import DataLoader
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root,transform_pet = None):
        super(MyDataset, self).__init__()
        MRI_PET_match_all = pkl.load(open(root,"rb"),encoding='iso-8859-1')
        MRI = []
        PET = []
        group = []
        for index,row in MRI_PET_match_all.iterrows():
            MRI.append(row['MRI_img_array'])
            PET.append(row['PET_img_array'])
            group_ = torch.tensor(row['Group'],dtype=torch.int64)
            group.append(group_)
        self.MRI = MRI
        self.PET = PET
        self.group = group  
        self.transform_pet = transform_pet

    def __getitem__(self, index):
        mri =torch.from_numpy(self.MRI[index].transpose([2,0,1])).float().to(DEVICE)
        pet = torch.from_numpy(self.PET[index].transpose([2,0,1])).float().to(DEVICE)
        pet = self.transform_pet(pet)
        group = self.group[index].to(DEVICE)
        return pet,group

    def __len__(self):
        return len(self.MRI)

#not 16
# train_mean_mri = [4.1620684, 4.1620684, 4.1620684]
# train_std_mri = [5.2131376, 5.2131376, 5.2131376]
# train_mean_pet = [4.081158, 4.081158, 4.081158] 
# train_std_pet = [5.1888165, 5.1888165, 5.1888165]

# test_mean_mri = [4.1623616, 4.1623616, 4.1623616]
# test_std_mri = [5.2136188, 5.2136188, 5.2136188]
# test_mean_pet = [4.106387, 4.106387, 4.106387]
# test_std_pet = [5.18535, 5.18535, 5.18535]

#16-in-1
train_mean_mri = [4.176061, 4.176061, 4.176061]
train_std_mri = [5.231413, 5.231413, 5.231413]
train_mean_pet = [4.017313, 4.017313, 4.017313] 
train_std_pet = [5.1714053, 5.1714053, 5.1714053]

test_mean_mri = [4.3262687, 4.3262687, 4.3262687]
test_std_mri = [5.419836, 5.419836, 5.419836] 
test_mean_pet = [4.1623745, 4.1623745, 4.1623745]
test_std_pet = [5.3586817, 5.3586817, 5.3586817]

valid_mean_mri = [4.7232037, 4.7232037, 4.7232037]
valid_std_mri =  [5.8529644, 5.8529644, 5.8529644] 
valid_mean_pet = [4.5747848, 4.5747848, 4.5747848]
valid_std_pet = [5.809622, 5.809622, 5.809622]
    
train_transform_mri = transforms.Compose([
    transforms.Normalize(train_mean_mri,train_std_mri),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
    
])

train_transform_pet = transforms.Compose([
    transforms.Normalize(train_mean_pet,train_std_pet),
    transforms.RandomVerticalFlip(),
    transforms.RandomHorizontalFlip(),
])

test_transform_mri = transforms.Compose([
    
    transforms.Normalize(test_mean_mri,test_std_mri),
])

test_transform_pet = transforms.Compose([
   
    transforms.Normalize(test_mean_pet,test_std_pet)
])


train_data = MyDataset("/home/gc/gechang/gec_multi_fusion/end_to_end/train.pkl",transform_pet =train_transform_pet)
test_data = MyDataset("/home/gc/gechang/gec_multi_fusion/end_to_end/test.pkl",transform_pet =test_transform_pet)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_data, batch_size=32)


In [7]:

# 计算准确率
def get_acc(output, label):
    total = output.shape[0]
    _, pred_label = output.max(1)
    num_correct = (pred_label == label).sum().item()
    return num_correct / total


In [8]:

def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
    prev_time = datetime.now()
    for epoch in range(num_epochs):
        train_loss = 0
        train_acc = 0
        net = net.train()
        for im, label in train_data:
            im = im.to(DEVICE)  # (bs, 3, h, w)
            label = label.to(DEVICE)  # (bs, h, w)
            # forward
            output = net(im)
            loss = criterion(output, label)
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
 
            train_loss += loss.item()
            train_acc += get_acc(output, label)
 
        cur_time = datetime.now()
        h, remainder = divmod((cur_time - prev_time).seconds, 3600)
        m, s = divmod(remainder, 60)
        time_str = "Time %02d:%02d:%02d" % (h, m, s)
        if valid_data is not None:
            valid_loss = 0
            valid_acc = 0
            net = net.eval()
            for im, label in valid_data:
                im = im.to(DEVICE)  # (bs, 3, h, w)
                label = label.to(DEVICE)  # (bs, h, w)
                output = net(im)
                loss = criterion(output, label)
                valid_loss += loss.item()
                valid_acc += get_acc(output, label)
            epoch_str = (
                    "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
                    % (epoch, train_loss / len(train_data),
                       train_acc / len(train_data), valid_loss / len(valid_data),
                       valid_acc / len(valid_data)))
        else:
            epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
                         (epoch, train_loss / len(train_data),
                          train_acc / len(train_data)))
        prev_time = cur_time
        print(epoch_str + time_str)


In [9]:
resnet = resnet.to(DEVICE)
criterion = nn.CrossEntropyLoss()  # 损失函数
# 只需要优化最后一层参数
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, resnet.parameters()), lr=1e-3, weight_decay=1e-3, momentum=0.9)  # 优化器

# train
train(resnet, train_loader, test_loader, 130, optimizer, criterion)


Epoch 0. Train Loss: 0.964457, Train Acc: 0.557931, Valid Loss: 0.973052, Valid Acc: 0.522321, Time 00:00:04
Epoch 1. Train Loss: 0.806322, Train Acc: 0.665603, Valid Loss: 0.903139, Valid Acc: 0.580357, Time 00:00:06
Epoch 2. Train Loss: 0.724097, Train Acc: 0.711336, Valid Loss: 0.889219, Valid Acc: 0.589286, Time 00:00:08
Epoch 3. Train Loss: 0.653078, Train Acc: 0.740474, Valid Loss: 0.834940, Valid Acc: 0.589286, Time 00:00:08
Epoch 4. Train Loss: 0.594379, Train Acc: 0.772586, Valid Loss: 0.824213, Valid Acc: 0.625000, Time 00:00:08
Epoch 5. Train Loss: 0.553511, Train Acc: 0.781336, Valid Loss: 0.831803, Valid Acc: 0.629464, Time 00:00:08
Epoch 6. Train Loss: 0.499178, Train Acc: 0.811466, Valid Loss: 0.716014, Valid Acc: 0.683036, Time 00:00:08
Epoch 7. Train Loss: 0.449337, Train Acc: 0.850733, Valid Loss: 0.726203, Valid Acc: 0.687500, Time 00:00:07
Epoch 8. Train Loss: 0.425154, Train Acc: 0.840733, Valid Loss: 0.707814, Valid Acc: 0.687500, Time 00:00:08
Epoch 9. Train Loss

KeyboardInterrupt: 