# Precocessing Dataset

In [16]:
import random
import torch
import numpy as np
import os
import matplotlib.image as mpimg
from torch.utils.data import Dataset, DataLoader, random_split

In [8]:
def set_seed(seed):
    """
    Use this to set ALL the random seeds to a fixed value and take out any randomness from cuda kernels
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.benchmark = True  ##uses the inbuilt cudnn auto-tuner to find the fastest convolution algorithms. -
    torch.backends.cudnn.enabled   = True

    return True

In [11]:
class Crop3D:
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, input, output):
        h, w, d = input.shape
        new_h, new_w, new_d = self.output_size
        

        top = np.random.randint(0, h - new_h + 1)
        left = np.random.randint(0, w - new_w + 1)
        front = np.random.randint(0, d - new_d + 1)
        # top = 800
        # left = 1000
        # front = 1000
        
        input = input[top: top + new_h, left: left + new_w, front: front + new_d]
        output = output[top: top + new_h, left: left + new_w, front: front + new_d]
        return input, output
    

class CTDataset(Dataset):
    def __init__(self, base_dir, global_min, global_max,transform=None):
        self.base_dir = base_dir
        self.dry_scan_dirs = [os.path.join(base_dir, d) for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("004_estaillades1_DI")]
        self.exper_dirs = [os.path.join(base_dir, d) for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir, d)) and d.startswith("032_estaillades1_q01_fw07_us")]
        self.global_min = global_min
        self.global_max = global_max
        self.sample_list = self._create_sample_list()
        self.transform = transform
        
        self.dry_scan_path = self._create_dry_scan_list()
        
    def _create_dry_scan_list(self):
        sub_dirs = self.dry_scan_dirs
        total_time_steps = []

        for sub_dir in sub_dirs:
            time_steps = [os.path.join(sub_dir, f) for f in os.listdir(sub_dir) if os.path.isdir(os.path.join(sub_dir, f))]
            total_time_steps.extend(time_steps)
        total_time_steps.sort()
        
        sample_list = []
        for time_step in total_time_steps:
            for i in range(1600):
                formatted_number = f"{(i+1):04d}"
                img_path = os.path.join(time_step, f'004_estaillades1_DI_{formatted_number}.rec.16bit.tif')
                if os.path.exists(img_path):
                    sample_list.append(img_path)
        return sample_list
        
    def _create_sample_list(self, dry = False):
        sample_list = []
        
        if not dry:
            sub_dirs = self.exper_dirs
            total_time_steps = []
            
            for sub_dir in sub_dirs:
                time_steps = [os.path.join(sub_dir, f) for f in os.listdir(sub_dir) if os.path.isdir(os.path.join(sub_dir, f))]
                total_time_steps.extend(time_steps)
            total_time_steps.sort()
        
#             for time_step in total_time_steps:
#                 folder_name = os.path.basename(time_step)[-5:]
#                 for i in range(1600):
#                     formatted_number = f"{(i+1):04d}"
#                     img_path = os.path.join(time_step, f'032_estaillades1_q01_fw07_us_{folder_name}_{formatted_number}.rec.16bit.tif')
#                     if os.path.exists(img_path):
#                         sample_list.append(img_path
        
#             for time_step in total_time_steps:
#                 for i in range(1600):
#                     formatted_number = f"{(i+1):04d}"
#                     img_path = os.path.join(time_step, f'004_estaillades1_DI_{formatted_number}.rec.16bit.tif')
#                     if os.path.exists(img_path):
#                         sample_list.append(img_path)
            
        return total_time_steps

    def __len__(self):
        return len(self.sample_list)

    def preprocessing(self, idx):
        sample_path = self.sample_list[idx]
#         dry_scan_path = self.dry_scan_path
#         print(dry_scan_path)
        
        img = np.zeros((1600,2016,2016))
#         dry_scan = np.zeros((1600,2016,2016))
        
        folder_name_input = os.path.basename(sample_path)[-5:]
#         folder_name_output = os.path.basename(output_path)[-5:]
        for i in range(1600):
            formatted_number = f"{(i+1):04d}"
            img_path = os.path.join(sample_path, f'032_estaillades1_q01_fw07_us_{folder_name_input}_{formatted_number}.rec.16bit.tif')
#             output_one_channel_path = os.path.join(output_path, f'032_estaillades1_q01_fw07_us_{folder_name_output}_{formatted_number}.rec.16bit.tif')
#             dry_scan_one_channel_path = dry_scan_path[i]
            if os.path.exists(img_path):
                img_one_channel = mpimg.imread(img_path)
                img_one_channel_normalized = (img_one_channel - self.global_min) / (self.global_max - self.global_min)
                img[i,:,:] = img_one_channel_normalized
                
#                 output_scan_one_channel = mpimg.imread(output_one_channel_path)
#                 output_scan_one_channel_normalized = (output_scan_one_channel- self.global_min) / (self.global_max - self.global_min)
#                 output[i,:,:] = output_scan_one_channel_normalized
                
#                 dry_scan_one_channel = mpimg.imread(dry_scan_one_channel_path)
#                 dry_scan_one_channel_normalized = (dry_scan_one_channel - self.global_min) / (self.global_max - self.global_min)
#                 dry_scan[i,:,:] = dry_scan_one_channel_normalized
                
        print(img_path)
        img = torch.from_numpy(img).float()
#         dry_scan = torch.from_numpy(dry_scan).float()
        print("the img shape:",img.shape)
        torch.save(img, f'/rds/general/user/hw123/ephemeral/dataset/wet_scan/time_step_{folder_name_input}.pt')
        print(f"time_step_{folder_name_input} is saved")
        
        
        
#         img = mpimg.imread(img_path)
#         img_normalized = (img - self.global_min) / (self.global_max - self.global_min)
#         img_normalized = torch.tensor(img_normalized, dtype=torch.float32)
        
#         if self.transform:
#             img, output = self.transform(img, output)
        

In [12]:
## Cacluate the max and min
def estimate_global_min_max(base_dir, num_samples = 4, num_slices = 100):
    sub_dirs = [os.path.join(base_dir, d) for d in os.listdir(base_dir) if os.path.isdir(os.path.join(base_dir,d))]
    sampled_dirs = random.sample(sub_dirs, min(num_samples, len(sub_dirs)))
    
    print(sampled_dirs)
    global_min = float('inf')
    global_max = float('-inf')

    for sub_dir in sampled_dirs:
        folder_name = os.path.basename(sub_dir)[-5:]
        sampled_slices = random.sample(range(1600), num_slices)
        
        for i in sampled_slices:
            formatted_number = f"{(i+1):04d}"
            img_path = os.path.join(sub_dir, f'032_estaillades1_q01_fw07_us_{folder_name}_{formatted_number}.rec.16bit.tif')
            
            if os.path.exists(img_path):
                img = mpimg.imread(img_path)
                global_min = min(global_min, img.min())
                global_max = max(global_max, img.max())
    
    return global_min, global_max

In [13]:
base_dir = f'../data_set/IRP_porescale_AI/032_estaillades1_q01_fw07_us'
set_seed(42)
global_min, global_max = estimate_global_min_max(base_dir)
print(f"Global min: {global_min}, Global max: {global_max}")

['../data_set/IRP_porescale_AI/032_estaillades1_q01_fw07_us/rec_16bit_phase_00019', '../data_set/IRP_porescale_AI/032_estaillades1_q01_fw07_us/rec_16bit_phase_00004', '../data_set/IRP_porescale_AI/032_estaillades1_q01_fw07_us/rec_16bit_phase_00009', '../data_set/IRP_porescale_AI/032_estaillades1_q01_fw07_us/rec_16bit_phase_00026']
Global min: 0, Global max: 65535


In [14]:
base_dir = f'../data_set/IRP_porescale_AI'
dataset = CTDataset(base_dir=base_dir,global_max=global_max,global_min=global_min)
len(dataset)

30

In [None]:
for i in range(len(dataset)):
    dataset.preprocessing(i)

In [17]:
img = torch.load(f'/rds/general/user/hw123/ephemeral/dataset/wet_scan/time_step_00030.pt')
img.shape,img

(torch.Size([1600, 2016, 2016]),
 tensor([[[0.3121, 0.3144, 0.3094,  ..., 0.3037, 0.2889, 0.2913],
          [0.2910, 0.3065, 0.2989,  ..., 0.2914, 0.2905, 0.2998],
          [0.2976, 0.3067, 0.2987,  ..., 0.2964, 0.2841, 0.2918],
          ...,
          [0.2940, 0.3015, 0.2916,  ..., 0.3165, 0.3387, 0.2903],
          [0.2888, 0.3017, 0.2944,  ..., 0.2994, 0.3283, 0.3230],
          [0.2966, 0.3093, 0.3048,  ..., 0.3074, 0.3006, 0.3048]],
 
         [[0.2933, 0.2940, 0.3025,  ..., 0.2984, 0.2974, 0.3012],
          [0.2853, 0.3066, 0.3144,  ..., 0.3080, 0.2985, 0.3064],
          [0.2762, 0.2868, 0.3096,  ..., 0.3091, 0.2837, 0.3095],
          ...,
          [0.3116, 0.3083, 0.3129,  ..., 0.3051, 0.3053, 0.2924],
          [0.3025, 0.3097, 0.3120,  ..., 0.3138, 0.3142, 0.2966],
          [0.3119, 0.3078, 0.2919,  ..., 0.3177, 0.3140, 0.3215]],
 
         [[0.2718, 0.2832, 0.3159,  ..., 0.3145, 0.3128, 0.2996],
          [0.2872, 0.2880, 0.3038,  ..., 0.3100, 0.2972, 0.2799],
       

In [5]:
dry = torch.load(f'../data_set/IRP_porescale_tensor/dry_scan/dry_scan.pt')

In [6]:
dry

tensor([[[0.2600, 0.2712, 0.2880,  ..., 0.2985, 0.3037, 0.2907],
         [0.2866, 0.2713, 0.2797,  ..., 0.2955, 0.2798, 0.2751],
         [0.3160, 0.2963, 0.2928,  ..., 0.2678, 0.2712, 0.2769],
         ...,
         [0.2999, 0.3286, 0.3283,  ..., 0.2923, 0.2996, 0.2880],
         [0.3217, 0.3284, 0.2912,  ..., 0.2966, 0.3095, 0.2918],
         [0.2964, 0.2893, 0.2870,  ..., 0.2951, 0.3130, 0.3215]],

        [[0.2640, 0.2807, 0.2925,  ..., 0.3103, 0.3147, 0.3093],
         [0.2831, 0.2821, 0.2859,  ..., 0.3004, 0.2840, 0.2728],
         [0.3056, 0.2986, 0.2973,  ..., 0.2723, 0.2734, 0.2778],
         ...,
         [0.3057, 0.2969, 0.3073,  ..., 0.3142, 0.3072, 0.2917],
         [0.2958, 0.2970, 0.3004,  ..., 0.3153, 0.3226, 0.2978],
         [0.2883, 0.2952, 0.3036,  ..., 0.2973, 0.3312, 0.3331]],

        [[0.2624, 0.2789, 0.3001,  ..., 0.2951, 0.2884, 0.3064],
         [0.2843, 0.2847, 0.2913,  ..., 0.2841, 0.2930, 0.3138],
         [0.2956, 0.2967, 0.2992,  ..., 0.2871, 0.3022, 0.

In [18]:
dataset._create_dry_scan_list()

['../data_set/IRP_porescale_AI/004_estaillades1_DI/rec_16bit_phase/004_estaillades1_DI_0001.rec.16bit.tif',
 '../data_set/IRP_porescale_AI/004_estaillades1_DI/rec_16bit_phase/004_estaillades1_DI_0002.rec.16bit.tif',
 '../data_set/IRP_porescale_AI/004_estaillades1_DI/rec_16bit_phase/004_estaillades1_DI_0003.rec.16bit.tif',
 '../data_set/IRP_porescale_AI/004_estaillades1_DI/rec_16bit_phase/004_estaillades1_DI_0004.rec.16bit.tif',
 '../data_set/IRP_porescale_AI/004_estaillades1_DI/rec_16bit_phase/004_estaillades1_DI_0005.rec.16bit.tif',
 '../data_set/IRP_porescale_AI/004_estaillades1_DI/rec_16bit_phase/004_estaillades1_DI_0006.rec.16bit.tif',
 '../data_set/IRP_porescale_AI/004_estaillades1_DI/rec_16bit_phase/004_estaillades1_DI_0007.rec.16bit.tif',
 '../data_set/IRP_porescale_AI/004_estaillades1_DI/rec_16bit_phase/004_estaillades1_DI_0008.rec.16bit.tif',
 '../data_set/IRP_porescale_AI/004_estaillades1_DI/rec_16bit_phase/004_estaillades1_DI_0009.rec.16bit.tif',
 '../data_set/IRP_porescale_