In [4]:
'''
Pytorch图像预处理时，通常使用transforms.Normalize(mean, std)对图像按通道进行标准化，即减去均值，再除以方差。这样做可以加快模型的收敛速度。其中参数mean和std分别表示图像每个通道的均值和方差序列。

Imagenet数据集的均值和方差为：mean=(0.485, 0.456, 0.406)，std=(0.229, 0.224, 0.225)，因为这是在百万张图像上计算而得的，所以我们通常见到在训练过程中使用它们做标准化。而对于特定的数据集，选择这个值的结果可能并不理想。接下来给出计算特定数据集的均值和方差的方法。

'''

'\nPytorch图像预处理时，通常使用transforms.Normalize(mean, std)对图像按通道进行标准化，即减去均值，再除以方差。这样做可以加快模型的收敛速度。其中参数mean和std分别表示图像每个通道的均值和方差序列。\n\nImagenet数据集的均值和方差为：mean=(0.485, 0.456, 0.406)，std=(0.229, 0.224, 0.225)，因为这是在百万张图像上计算而得的，所以我们通常见到在训练过程中使用它们做标准化。而对于特定的数据集，选择这个值的结果可能并不理想。接下来给出计算特定数据集的均值和方差的方法。\n\n'

In [1]:
#只有MRI可以跑通
import os 
os.environ['CUDA_VISIBLE_DEVICES'] = "0" 
import time
# import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
import pickle as p
import hiddenlayer as hl
import math

if torch.cuda.is_available():
    print("cuda is available")
    torch.backends.cudnn.deterministic = True

# Device
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device:', DEVICE)

cuda is available
Device: cuda:0


In [3]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root):
        super(MyDataset, self).__init__()
        MRI_PET_match_all = p.load(open(root,"rb"),encoding='iso-8859-1')
        MRI = []
        PET = []
        group = []
        for index,row in MRI_PET_match_all.iterrows():
            MRI.append(row['MRI_img_array'])
            PET.append(row['PET_img_array'])
            group.append(row['Group'])
        self.MRI = MRI
        self.PET = PET
        self.group = group  

    def __getitem__(self, index):
        mri = torch.from_numpy(self.MRI[index].transpose(2,0,1)).float()
        pet = torch.from_numpy(self.PET[index].transpose(2,0,1)).float()
        group = self.group[index]
        
        return mri,pet,group

    def __len__(self):
        return len(self.MRI)

train_data = MyDataset("/home/gc/gechang/gec_multi_fusion/end_to_end/train.pkl")
#valid_data = MyDataset("/home/gc/gechang/gec_multi_fusion/utils/valid_onlyone.pkl")
test_data = MyDataset("/home/gc/gechang/gec_multi_fusion/end_to_end/test.pkl")

train_loader = DataLoader(train_data, batch_size = 16, num_workers = 8, shuffle=True)
#valid_loader = DataLoader(valid_data, batch_size = 16, num_workers = 8, shuffle=True)
test_loader = DataLoader(test_data, batch_size = 16, num_workers = 8)

In [4]:

train_mean_mri = torch.zeros(3)
train_std_mri = torch.zeros(3)
train_mean_pet = torch.zeros(3)
train_std_pet = torch.zeros(3)
for mri,pet,group in train_loader:
    for d in range(3):
        train_mean_mri[d] += mri[:, d, :, :].mean()
        train_std_mri[d] += mri[:, d, :, :].std()
        train_mean_pet[d] += pet[:, d, :, :].mean()
        train_std_pet[d] += pet[:, d, :, :].std()
train_mean_mri.div_(len(train_data))
train_std_mri.div_(len(train_data))
train_mean_pet.div_(len(train_data))
train_std_pet.div_(len(train_data))
print( list(train_mean_mri.numpy()), list(train_std_mri.numpy()),list(train_mean_pet.numpy()), list(train_std_pet.numpy()))



[4.176061, 4.176061, 4.176061] [5.231413, 5.231413, 5.231413] [4.017313, 4.017313, 4.017313] [5.1714053, 5.1714053, 5.1714053]


In [5]:

test_mean_mri = torch.zeros(3)
test_std_mri = torch.zeros(3)
test_mean_pet = torch.zeros(3)
test_std_pet = torch.zeros(3)
for mri,pet,group in test_loader:
    for d in range(3):
        test_mean_mri[d] += mri[:, d, :, :].mean()
        test_std_mri[d] += mri[:, d, :, :].std()
        test_mean_pet[d] += pet[:, d, :, :].mean()
        test_std_pet[d] += pet[:, d, :, :].std()
test_mean_mri.div_(len(test_data))
test_std_mri.div_(len(test_data))
test_mean_pet.div_(len(test_data))
test_std_pet.div_(len(test_data))
print( list(test_mean_mri.numpy()), list(test_std_mri.numpy()),list(test_mean_pet.numpy()), list(test_std_pet.numpy()))



[4.3262687, 4.3262687, 4.3262687] [5.419836, 5.419836, 5.419836] [4.1623745, 4.1623745, 4.1623745] [5.3586817, 5.3586817, 5.3586817]


In [8]:

valid_mean_mri = torch.zeros(3)
valid_std_mri = torch.zeros(3)
valid_mean_pet = torch.zeros(3)
valid_std_pet = torch.zeros(3)
for mri,pet,group in valid_loader:
    for d in range(3):
        valid_mean_mri[d] += mri[:, d, :, :].mean()
        valid_std_mri[d] += mri[:, d, :, :].std()
        valid_mean_pet[d] += pet[:, d, :, :].mean()
        valid_std_pet[d] += pet[:, d, :, :].std()
valid_mean_mri.div_(len(valid_data))
valid_std_mri.div_(len(valid_data))
valid_mean_pet.div_(len(valid_data))
valid_std_pet.div_(len(valid_data))
print( list(valid_mean_mri.numpy()), list(valid_std_mri.numpy()),list(valid_mean_pet.numpy()), list(valid_std_pet.numpy()))

[4.7232037, 4.7232037, 4.7232037] [5.8529644, 5.8529644, 5.8529644] [4.5747848, 4.5747848, 4.5747848] [5.809622, 5.809622, 5.809622]
