In [1]:
import h5py
import numpy as np
import time

In [2]:
class DataWindow():
    def __init__(self, comp_env_window, R, S, beta, k, validRS, a_0):
        self.comp_env_window = comp_env_window
        self.R = R
        self.S = S
        self.beta = beta
        self.k = k
        self.validRS = validRS
        self.a_0 = a_0
    
    def __str__(self):
        return f"comp_env_window: {self.comp_env_window}\nR: {self.R}\nS: {self.S}\nbeta: {self.beta}\nk: {self.k}\nvalidRS: {self.validRS}\na_0: {self.a_0}"

In [3]:
def print_dataset_sizes(h5_path):
    with h5py.File(h5_path, 'r') as file:
        for group in file.keys():
            dataset_name = f'{group}/comp_env_interp_1'
            if dataset_name in file:
                data = np.array(file[dataset_name])
                print(f'Size of dataset {dataset_name}: {data.shape}')
            else:
                print(f'Dataset {dataset_name} not found')

def print_first_group_datasets(h5_path):
    with h5py.File(h5_path, 'r') as file:
        first_group = list(file.keys())[0]
        print(f"Datasets in the first group ({first_group}):")
        for dataset in file[first_group].keys():
            data = np.array(file[first_group][dataset])
            print(f" - {dataset}: {data.shape}")

In [4]:
h5_path = '../data/dataoncosalud/res_valid/breast_comp_env_data.h5'
# print_dataset_sizes(h5_path)
print_first_group_datasets(h5_path)

Datasets in the first group (file_0001):
 - R_matrix: (188, 200)
 - S_matrix: (188, 200)
 - a_0: (1, 1)
 - b_0: (1, 1)
 - beta_matrix: (188, 200)
 - comp_env_interp_1: (244, 256)
 - env_rf_interp: (244, 256)
 - k_matrix: (188, 200)
 - n: (1, 1)
 - validRS: (188, 200)


In [5]:
def calculate_total_windows(h5_path, n):
    total_windows = 0
    with h5py.File(h5_path, 'r') as file:
        for group in file.keys():
            dataset_name = f'{group}/comp_env_interp_1'
            if dataset_name in file:
                data = np.array(file[dataset_name])
                if data.ndim == 2:
                    num_windows = ((data.shape[0] - n + 1) * (data.shape[1] - n + 1))
                    total_windows += num_windows
                else:
                    raise ValueError(f"Dataset {dataset_name} is not 2-dimensional")
            else:
                raise ValueError(f"Dataset {dataset_name} not found")
    return total_windows

# def get_window_xy(h5_path, n, window_idx):
#     total_windows = calculate_total_windows(h5_path, n)
#     if window_idx >= total_windows:
#         raise IndexError("Window index out of range")
    
#     current_window = 0
#     with h5py.File(h5_path, 'r') as file:
#         for group in file.keys():
#             dataset_name = f'{group}/comp_env_interp_1'
#             validRS_name = f'{group}/validRS'
#             if dataset_name in file and validRS_name in file:
#                 data = np.array(file[dataset_name])
#                 if data.ndim == 2:
#                     num_windows = ((data.shape[0] - n + 1) * (data.shape[1] - n + 1))
#                     if current_window + num_windows > window_idx:
#                         local_idx = window_idx - current_window
#                         row_idx = local_idx // (data.shape[1] - n + 1)
#                         col_idx = local_idx % (data.shape[1] - n + 1)
#                         comp_env_window = data[row_idx:row_idx+n, col_idx:col_idx+n]

#                         validRS = np.array(file[validRS_name])
#                         validRS_value = validRS[row_idx, col_idx]

#                         return comp_env_window, validRS_value
#                     current_window += num_windows
#                 else:
#                     raise ValueError(f"Dataset {dataset_name} is not 2-dimensional")
#             else:
#                 raise ValueError(f"Dataset {dataset_name} or {validRS_name} not found")
#     raise IndexError("Window index out of range")

def get_window_xy(h5_path, n, window_idx):
    total_windows = calculate_total_windows(h5_path, n)
    if window_idx >= total_windows:
        raise IndexError("Window index out of range")
    
    current_window = 0
    with h5py.File(h5_path, 'r') as file:
        for group in file.keys():
            dataset_name = f'{group}/comp_env_interp_1'
            R_matrix_name = f'{group}/R_matrix'
            S_matrix_name = f'{group}/S_matrix'
            a_0_name = f'{group}/a_0'
            beta_matrix_name = f'{group}/beta_matrix'
            k_matrix_name = f'{group}/k_matrix'
            validRS_name = f'{group}/validRS'
            if dataset_name in file and validRS_name in file:
                data = np.array(file[dataset_name])
                if data.ndim == 2:
                    num_windows = ((data.shape[0] - n + 1) * (data.shape[1] - n + 1))
                    if current_window + num_windows > window_idx:
                        local_idx = window_idx - current_window
                        row_idx = local_idx // (data.shape[1] - n + 1)
                        col_idx = local_idx % (data.shape[1] - n + 1)
                        comp_env_window = data[row_idx:row_idx+n, col_idx:col_idx+n]

                        R = np.array(file[R_matrix_name])
                        R_value = R[row_idx, col_idx]

                        S = np.array(file[S_matrix_name])
                        S_value = S[row_idx, col_idx]

                        beta = np.array(file[beta_matrix_name])
                        beta_value = beta[row_idx, col_idx]

                        k = np.array(file[k_matrix_name])
                        k_value = k[row_idx, col_idx]

                        a_0 = np.array(file[a_0_name])
                        a_value = a_0[0][0]

                        validRS = np.array(file[validRS_name])
                        validRS_value = validRS[row_idx, col_idx]

                        return DataWindow(comp_env_window, R_value, S_value, beta_value, k_value, validRS_value, a_value)
                    current_window += num_windows
                else:
                    raise ValueError(f"Dataset {dataset_name} is not 2-dimensional")
            else:
                raise ValueError(f"Dataset {dataset_name} or {validRS_name} not found")
    raise IndexError("Window index out of range")

# Example usage
n=57
total_windows = calculate_total_windows(h5_path, n)
print(f'Total number of {n}x{n} windows: {total_windows}')

window_idx = 8964152
comp_env_window = get_window_xy(h5_path, n, window_idx)
print(comp_env_window)

Total number of 57x57 windows: 22988144
comp_env_window: [[137.75766683 133.2111603  141.7927883  ... 177.76114717 176.87902033
  173.15100083]
 [164.15741549 167.85168201 171.77379703 ... 171.75821628 162.46338627
  154.54537015]
 [153.70165877 163.45111897 165.52004356 ... 150.50108707 154.16908375
  162.24663787]
 ...
 [153.71254841 165.316187   157.39771172 ... 155.60340152 165.28534673
  152.62939152]
 [155.92038459 161.35482999 157.31458726 ... 147.1035658  156.74833358
  146.60485593]
 [160.48098243 166.74628761 160.25674777 ... 161.06037448 162.99252765
  159.6153866 ]]
R: 1.8394473726407958
S: 1.445157404233818
beta: 0.38330175426580004
k: 3.3610563904074784
validRS: 0.0
a_0: 22.4888433408961


In [6]:
def create_validRS_dataset(h5_path):
    validRS_values = []
    with h5py.File(h5_path, 'r') as file:
        for group in file.keys():
            print(f"{group}/{len(file.keys())}")
            validRS_name = f'{group}/validRS'
            if validRS_name in file:
                validRS = np.array(file[validRS_name])
                if validRS.ndim == 2:
                    for row_idx in range(validRS.shape[0]):
                        for col_idx in range(validRS.shape[1]):
                            validRS_value = validRS[row_idx, col_idx]
                            validRS_values.append(validRS_value)
                else:
                    raise ValueError(f"Dataset {validRS_name} is not 2-dimensional")
            else:
                raise ValueError(f"Dataset {validRS_name} not found")
    validRS_values = np.array(validRS_values)
    return validRS_values

validRS_values=create_validRS_dataset(h5_path)

file_0001/451
file_0002/451
file_0003/451
file_0004/451
file_0005/451
file_0006/451
file_0007/451
file_0008/451
file_0009/451
file_0010/451
file_0011/451
file_0012/451
file_0013/451
file_0014/451
file_0015/451
file_0016/451
file_0017/451
file_0018/451
file_0019/451
file_0020/451
file_0021/451
file_0022/451
file_0023/451
file_0024/451
file_0025/451
file_0026/451
file_0027/451
file_0028/451
file_0029/451
file_0030/451
file_0031/451
file_0032/451
file_0033/451
file_0034/451
file_0035/451
file_0036/451
file_0037/451
file_0038/451
file_0039/451
file_0040/451
file_0041/451
file_0042/451
file_0043/451
file_0044/451
file_0045/451
file_0046/451
file_0047/451
file_0048/451
file_0049/451
file_0050/451
file_0051/451
file_0052/451
file_0053/451
file_0054/451
file_0055/451
file_0056/451
file_0057/451
file_0058/451
file_0059/451
file_0060/451
file_0061/451
file_0062/451
file_0063/451
file_0064/451
file_0065/451
file_0066/451
file_0067/451
file_0068/451
file_0069/451
file_0070/451
file_0071/451
file_0

In [7]:
validRS_values.shape

(22988144,)

In [8]:
from sklearn.model_selection import train_test_split
import pickle

# Get indices of 0s and 1s in validRS_values
zero_indices = np.where(validRS_values == 0)[0]
one_indices = np.where(validRS_values == 1)[0]

# Calculate the number of 1s to be used in the dataset
num_ones = len(one_indices)
num_zeros = num_ones

# Select an equal number of 0s
selected_zero_indices, _ = train_test_split(zero_indices, train_size=num_zeros, random_state=42)

# Combine the selected 0s and 1s
selected_indices = np.concatenate((selected_zero_indices, one_indices))

# Split the selected indices into training and validation sets
train_indices, val_indices = train_test_split(selected_indices, test_size=0.2, random_state=42)

# Ensure 50-50 ratio in both splits
train_zeros = train_indices[validRS_values[train_indices] == 0]
train_ones = train_indices[validRS_values[train_indices] == 1]
val_zeros = val_indices[validRS_values[val_indices] == 0]
val_ones = val_indices[validRS_values[val_indices] == 1]

# Adjust the training set to have equal number of 0s and 1s
if len(train_zeros) > len(train_ones):
    train_zeros = train_zeros[:len(train_ones)]
else:
    train_ones = train_ones[:len(train_zeros)]

# Adjust the validation set to have equal number of 0s and 1s
if len(val_zeros) > len(val_ones):
    val_zeros = val_zeros[:len(val_ones)]
else:
    val_ones = val_ones[:len(val_zeros)]

# Combine adjusted indices
train_indices = np.empty((train_zeros.size + train_ones.size,), dtype=train_zeros.dtype)
train_indices[0::2] = train_zeros
train_indices[1::2] = train_ones

val_indices = np.empty((val_zeros.size + val_ones.size,), dtype=val_zeros.dtype)
val_indices[0::2] = val_zeros
val_indices[1::2] = val_ones

# Save the splits into a pickle file
split_data = {
    'train_files': train_indices,
    'val_files': val_indices
}

# with open('breast_data_splits_CNN.pkl', 'wb') as f:
#     pickle.dump(split_data, f)

print(f'Training files: {len(train_indices)}, Validation files: {len(val_indices)}')
print('Training files: 838308, Validation files: 209280')

Training files: 698046, Validation files: 174468
Training files: 838308, Validation files: 209280


In [9]:
print(len(train_zeros), len(train_ones), len(val_zeros), len(val_ones))
print(validRS_values[train_indices].sum(), validRS_values[val_indices].sum())
print(validRS_values[train_indices[0:len(train_indices)//2]].sum(), validRS_values[val_indices[0:len(val_indices)//2]].sum())

# print('419154 419154 104640 104640')
# print('419154.0 104640.0')
# print('209577.0 52320.0')

349023 349023 87234 87234
349023.0 87234.0
174511.0 43617.0


In [10]:
from tqdm import tqdm

def create_windows_array(h5_path, indices, n):
    num_windows = len(indices)

    windows = np.empty(num_windows, dtype=DataWindow)
    
    for i, idx in enumerate(tqdm(indices, desc="Loading windows")):
        windows[i] = get_window_xy(h5_path, n, idx)
    
    return windows

# Create arrays for training and validation sets
train_windows = create_windows_array(h5_path, train_indices[:3200], n)
val_windows = create_windows_array(h5_path, val_indices[:800], n)

print(f'Train comp_env_windows shape: {train_windows.shape}')
print(f'Validation comp_env_windows shape: {val_windows.shape}')

# Save the arrays into a pickle file
data_arrays = {
    'train_windows': train_windows,
    'val_windows': val_windows
}

with open('breast_data_arrays_CNN.pkl', 'wb') as f:
    pickle.dump(data_arrays, f)

Loading windows: 100%|██████████| 3200/3200 [16:14<00:00,  3.29it/s]
Loading windows: 100%|██████████| 800/800 [04:11<00:00,  3.18it/s]


Train comp_env_windows shape: (3200,)
Validation comp_env_windows shape: (800,)
