In [1]:
# import necessary libraries
import h5py
import os
import numpy as np
from PIL import Image
import cv2

## Load Data
Loads the balanced dataset (per class: 500 train, 100 validation, 100 test) and the imbalanced dataset (same as balanced, except class 2 is 836 train, 50 validation, 50 test). In both, class 4 (healthy), when included, is 219 train, 50 validation, 50 test.

In [96]:
def _pick_split(counts: dict, label: int, limits: dict):
    """Return the first split that still has room for `label`, or None if all are full.

    Splits are tried in the order train -> validation -> test. The counter of
    the chosen split is incremented as a side effect.
    """
    for split in ('train', 'validation', 'test'):
        if counts[split][label] < limits[split]:
            counts[split][label] += 1
            return split
    return None


def _to_uint8(img: np.ndarray) -> np.ndarray:
    """Convert an image array to uint8 without wrap-around.

    A bare ``astype(np.uint8)`` reduces every value above 255 modulo 256, which
    scrambles intensities when the stored image is wider than 8 bits. Arrays
    that are already uint8 are returned unchanged; anything else is min-max
    rescaled to 0-255 first.
    """
    if img.dtype == np.uint8:
        return img
    img = img.astype(np.float64)
    lo, hi = img.min(), img.max()
    if hi > lo:
        img = (img - lo) / (hi - lo) * 255.0
    else:
        # Constant image: nothing to stretch, just keep it inside the uint8 range.
        img = np.clip(img, 0, 255)
    return img.astype(np.uint8)


def load_data(save_path: str, data_type: str, balanced: bool):
    """Build a train/validation/test image folder tree from the raw datasets.

    Images are resized to 64x64 and written as
    ``{save_path}/{split}/{label}/{n}.jpg``. The required output directories
    are created if they do not exist.

    Args:
        save_path: Root directory of the output tree.
        data_type: Which classes to export:
            ``"tumor healthy"`` -> classes 1-4,
            ``"tumor only"``    -> classes 1-3,
            ``"2 tumors"``      -> classes 1-2.
        balanced: If True every tumor class gets 500/100/100
            (train/validation/test) images. If False class 2 instead gets
            836/50/50. Class 4 (healthy) always gets 219/50/50.

    Raises:
        ValueError: If `data_type` is not one of the three accepted strings.

    Returns:
        None. The function only writes files.
    """
    if data_type == "tumor healthy":
        classes = [1, 2, 3, 4]
    elif data_type == "tumor only":
        classes = [1, 2, 3]
    elif data_type == "2 tumors":
        classes = [1, 2]
    else:
        raise ValueError('Invalid data type given.')

    splits = ('train', 'validation', 'test')

    # Make sure every destination folder exists so img.save() cannot fail on a
    # fresh checkout.
    for split in splits:
        for cls in classes:
            os.makedirs(f"{save_path}/{split}/{cls}", exist_ok=True)

    # Per-class quotas for each split.
    default_limits = {'train': 500, 'validation': 100, 'test': 100}
    tumor_limits = {cls: default_limits for cls in classes if cls != 4}
    if not balanced and 2 in tumor_limits:
        tumor_limits[2] = {'train': 836, 'validation': 50, 'test': 50}
    healthy_limits = {'train': 219, 'validation': 50, 'test': 50}

    counts = {split: {cls: 0 for cls in classes} for split in splits}

    # load fig share data (tumors)
    tumor_dirs = [
        'raw_data/fig_share/brainTumorDataPublic_1766',
        'raw_data/fig_share/brainTumorDataPublic_7671532',
        'raw_data/fig_share/brainTumorDataPublic_15332298',
        'raw_data/fig_share/brainTumorDataPublic_22993064',
    ]
    img_idx = 0
    for src_dir in tumor_dirs:
        # sorted() makes the split assignment reproducible; os.listdir order is
        # arbitrary.
        for fi_name in sorted(os.listdir(src_dir)):
            path = os.path.join(src_dir, fi_name)

            with h5py.File(path, 'r') as file:
                # Read the label first so unwanted files are skipped before the
                # (more expensive) image decode.
                label = int(file['cjdata']['label'][:][0][0])
                if label not in tumor_limits:
                    # Skip only this file; other files in the directory may
                    # still belong to a requested class.
                    continue

                folder = _pick_split(counts, label, tumor_limits[label])
                if folder is None:
                    # This class is full, but other classes may not be.
                    continue

                img = file['cjdata']['image'][:]

            img = cv2.resize(_to_uint8(img), (64, 64))
            Image.fromarray(img).save(f"{save_path}/{folder}/{label}/{img_idx}.jpg")
            img_idx += 1

    # load sartaj data (healthy) if needed
    if 4 in classes:
        healthy_dir = 'raw_data/SARTAJ'

        for healthy_idx, fi_name in enumerate(sorted(os.listdir(healthy_dir))):
            folder = _pick_split(counts, 4, healthy_limits)
            if folder is None:
                # Only one class comes from this directory, so once it is full
                # there is nothing left to do.
                break

            # Context manager closes the source file handle promptly.
            with Image.open(os.path.join(healthy_dir, fi_name)) as img:
                img.resize((64, 64)).save(f"{save_path}/{folder}/4/{healthy_idx}.jpg")


Load Tumor and Healthy Data

In [45]:
# Export classes 1-4 (tumors + healthy) with balanced tumor quotas
load_data(
    save_path="balanced_data/tumor_healthy",
    data_type="tumor healthy",
    balanced=True,
)

In [102]:
# Export classes 1-4 (tumors + healthy) with the imbalanced class-2 quota
load_data(
    save_path="imbalanced_data/tumor_healthy",
    data_type="tumor healthy",
    balanced=False,
)

Load Tumor Data

In [55]:
# Export classes 1-3 (tumors only) with balanced quotas
load_data(
    save_path="balanced_data/tumor_only",
    data_type="tumor only",
    balanced=True,
)

In [104]:
# Export classes 1-3 (tumors only) with the imbalanced class-2 quota
load_data(
    save_path="imbalanced_data/tumor_only",
    data_type="tumor only",
    balanced=False,
)

Load 2 Tumor Data

In [105]:
# Export classes 1-2 (two tumor types) with balanced quotas
load_data(
    save_path="balanced_data/2_tumors",
    data_type="2 tumors",
    balanced=True,
)

In [106]:
# Export classes 1-2 (two tumor types) with the imbalanced class-2 quota
load_data(
    save_path="imbalanced_data/2_tumors",
    data_type="2 tumors",
    balanced=False,
)