In [2]:
import os
import argparse
import json
import numpy as np
import random
from typing import Tuple
import torch
from torch import nn, optim
from torch.distributed import Backend
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils.data import Dataset, DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import nibabel as nib
from training import Trainer
from notionIntegration import tqdm_notion
from dotenv import load_dotenv
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from skimage.transform import resize

from MedT_C import MedT_C

# Name of the JSON metadata file that ADNI reads from the dataset root directory.
METADATA_FILENAME = "meta_data_with_label.json"

# Load environment variables from a local .env file (python-dotenv).
load_dotenv()

class ADNI(Dataset):
    """ADNI Dataset."""

    def __init__(self, root, transform=None):
        self._transform = transform
        self._root = root

        # Load information from metadata file.
        metadata_file = open(os.path.join(self._root, METADATA_FILENAME))
        self._metadata = json.load(metadata_file)
        self._data = next(os.walk(os.path.join(root, "images")), (None, None, []))[2]  # [] if no file

    def __len__(self):
        return len(self._metadata)*16 # Because we took 16 slices.

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.toList()

        filename = self._data[idx]
        nifti_image = nib.load(os.path.join(self._root, "images", filename))
        nii_data = nifti_image.get_fdata()

        if self._transform:
            nii_data = self._transform(nii_data)

        label = self._metadata[filename.split("_")[0]]["label"]

        return (nii_data, label)

    class Resize(object):
        def __init__(self, output_size):
            assert isinstance(output_size, (int, tuple))
            if isinstance(output_size, int):
                self.output_size = (output_size, output_size)
            else:
                assert len(output_size) == 2
                self.output_size = output_size
        
        def __call__(self, sample):
            return resize(sample, self.output_size)
    
    class Normalize(object):        
        def __call__(self, sample):
            return sample*255.0/sample.max()


# Preprocessing: rescale each slice so its max is 255, resize to 224x224,
# then convert to a tensor.
transform = transforms.Compose(
    [
        ADNI.Normalize(),
        ADNI.Resize((224, 224)),
        transforms.ToTensor(),
        # transforms.Normalize((91.8696,), (495.5406,))
    ]
)

dataset = ADNI("./data/ADNI_sliced", transform=transform)

# img = dataset.__getitem__(50)[0]
# plt.imshow(img)

image_loader = DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=False,
    num_workers=8,
    pin_memory=True,
)

# Running sums of pixel values and squared pixel values, used below to compute
# the dataset mean/std. Shape (1,) = one channel. Accumulate in float64: the
# sum of squares over every pixel of every slice is large, float32 accumulation
# loses precision, and the later E[x^2] - E[x]^2 subtraction amplifies it.
psum = torch.zeros(1, dtype=torch.float64)
psum_sq = torch.zeros(1, dtype=torch.float64)

# Loop through all images once, reducing over batch (0), height (2) and width (3).
for data, labels in tqdm(image_loader):
    data = data.to(torch.float64)
    psum += data.sum(axis=[0, 2, 3])
    psum_sq += (data ** 2).sum(axis=[0, 2, 3])

# (images, labels) = next(iter(image_loader))
def show(img):
    """Display a channel-first image tensor with matplotlib, hiding both axes.

    The tensor is converted to NumPy and its first axis moved last
    (presumably (C, H, W) -> (H, W, C)) before being drawn with
    nearest-neighbour interpolation.
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    image = plt.imshow(channels_last, interpolation='nearest')
    for axis in (image.axes.get_xaxis(), image.axes.get_yaxis()):
        axis.set_visible(False)
    plt.show()

# show(make_grid(images.cpu().data))

# show(make_grid(zeros.cpu().data))

100%|██████████| 547/547 [00:16<00:00, 32.59it/s]


In [3]:
####### FINAL CALCULATIONS

# Pixel count: assumes every item is one single-channel 224x224 slice. This
# must match the ADNI.Resize size used in the preprocessing transform.
count = len(dataset) * 224 * 224

# Population mean and std via Var[x] = E[x^2] - E[x]^2.
total_mean = psum / count
total_var = (psum_sq / count) - (total_mean ** 2)
# Floating-point rounding can push the difference slightly below zero when the
# true variance is ~0; clamp so the square root never produces NaN.
total_std = torch.sqrt(total_var.clamp(min=0))

# output
print('mean: ' + str(total_mean))
print('std:  ' + str(total_std))

mean: tensor([27.0715])
std:  tensor([53.3210])
