# Data Preprocessing for NSCLC

## Step 0: Download and Organize NSCLC Datasets

Download the following public NSCLC datasets from The Cancer Imaging Archive (TCIA):

- [NSCLC-Radiomics](https://www.cancerimagingarchive.net/collection/nsclc-radiomics/)
- [NSCLC-Radiomics-Interobserver1](https://www.cancerimagingarchive.net/collection/nsclc-radiomics-interobserver1/)
- [RIDER-LungCT-Seg](https://www.cancerimagingarchive.net/collection/rider-lung-ct/)
- [NSCLC Radiogenomics](https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/)
- [Lung-PET-CT-Dx](https://www.cancerimagingarchive.net/collection/lung-pet-ct-dx/)

After downloading, organize the data into the following directory structure:

```text
/workspace/data/Original_dataset/NSCLC/
├── NSCLC-Radiomics/
│   ├── LUNG1-001/
│   │   ├── ct.nii.gz           # CT image
│   │   ├── gtv_1.nii.gz        # Tumor 1 mask (binary: 1 = tumor, 0 = background)
│   │   ├── gtv_2.nii.gz        # Tumor 2 mask (if multiple tumors)
│   └── ...
├── NSCLC-Radiomics-Interobserver1/
│   ├── interobs01/
│   │   ├── ct.nii.gz
│   │   ├── gtv_1.nii.gz
│   └── ...
├── RIDER-LungCT-Seg/
│   ├── RIDER-0000000000/
│   │   ├── ct.nii.gz
│   │   ├── gtv_1.nii.gz
│   └── ...
├── NSCLC-Radiogenomics/
│   ├── R01-001/
│   │   ├── ct.nii.gz
│   │   ├── gtv_1.nii.gz
│   └── ...
├── Lung-PET-CT-Dx/
│   ├── Lung_Dx-A0001/
│   │   ├── ct.nii.gz
│   │   ├── gtv_1.nii.gz
│   └── ...
```

## Tumor Subtype Labels
In addition to organizing the image and mask files, prepare a JSON file named `gtv_labels.json` in each dataset folder (if applicable) to store the tumor subtype for each tumor **instance**.

Each key should follow the format `<patient_id>_<tumor_index>`, where `tumor_index` matches the suffix in the corresponding `gtv_i.nii.gz` file.


### Example:
If a folder contains:
```text
LUNG1-001/
├── ct.nii.gz
├── gtv_1.nii.gz
├── gtv_2.nii.gz
```
Then the `gtv_labels.json` file should look like:
```json
{
  "LUNG1-001_1": "SCC",
  "LUNG1-001_2": "ADC"
}
```

### Allowed label values:
- `SCC`: Squamous Cell Carcinoma
- `ADC`: Adenocarcinoma
- `LCC`: Large Cell Carcinoma
- `NOS`: Not Otherwise Specified
- `NaN`: Unknown or unavailable

This label file will be used later for subtype-aware training and evaluation.


## Step 1: Resample and Pad Volumes

To ensure consistency across all samples, we perform the following preprocessing steps:

### 1. Resample to Isotropic Spacing
- All CT images and tumor masks are resampled to a uniform voxel spacing of **(1.5, 1.5, 1.5)** mm using:
  - **Linear interpolation** for CT images
  - **Nearest-neighbor interpolation** for binary tumor masks

### 2. Zero-Padding for Minimum Size
- The image encoder requires input volumes of shape **(128, 128, 128)**.
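- If a resampled volume is smaller than 128 voxels along any axis, **constant padding** is applied up to 128 voxels:
  - Padding is added to the **end** of each axis (i.e., inferior, posterior, or right side).
  - For CT images, the padding value is set to the **minimum intensity** of the image.
  - For masks, the padding value is set to **0** (background).
- If an axis is already at least 128 voxels long but its length is not a multiple of 4, it is padded in the same way up to the next multiple of 4 (required for APE extraction).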

### Output Structure
Preprocessed volumes are saved in the following directory structure:
```text
/workspace/data/NSCLC/
├── images/        # Resampled and padded CT images
├── labels/        # Resampled and padded tumor masks
```

Each output file follows the naming convention:
```
<patient_id>_<tumor_index>.nii.gz
```
**Examples:**
- `LUNG1-001_1.nii.gz`
- `R01-001_2.nii.gz`

In [20]:
import os
import numpy as np
import SimpleITK as sitk
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import json

In [21]:
def is_mask_volume_zero(mask_sitk):
    """Return whether the voxel values of ``mask_sitk`` sum to zero.

    The image is copied into a NumPy array and all of its voxels are summed;
    for a 0/1 mask a zero sum means that no voxel is marked.
    """
    voxels = sitk.GetArrayFromImage(mask_sitk)
    voxel_total = np.sum(voxels)
    return voxel_total == 0

def resample_volume(volume, interpolator=sitk.sitkLinear, new_spacing=(1.5, 1.5, 1.5)):
    """Resample ``volume`` onto a grid with voxel spacing ``new_spacing``.

    The output keeps the origin, direction and pixel type of the input. Its
    size is chosen per axis as ``round(size * old_spacing / new_spacing)`` so
    that roughly the same physical extent is covered.

    Args:
        volume: SimpleITK image to resample.
        interpolator: SimpleITK interpolator constant (for example
            ``sitk.sitkLinear`` or ``sitk.sitkNearestNeighbor``).
        new_spacing: Target spacing, one value per image axis.

    Returns:
        The resampled SimpleITK image.
    """
    original_spacing = volume.GetSpacing()
    original_size = volume.GetSize()
    new_size = [
        int(round(size * spacing / target))
        for size, spacing, target in zip(original_size, original_spacing, new_spacing)
    ]

    # Output voxels that fall outside the input are filled with the smallest
    # value present in the input.
    min_value = float(sitk.GetArrayViewFromImage(volume).min())
    return sitk.Resample(volume, new_size, sitk.Transform(), interpolator,
                         volume.GetOrigin(), new_spacing, volume.GetDirection(), min_value,
                         volume.GetPixelID())

In [None]:
# Step 1 driver: resample every CT / tumor-mask pair to isotropic spacing, pad
# it at the end of each axis, and write one image file and one label file per
# tumor instance, plus a merged gtv_labels.json for all datasets.
data_root = '/workspace/data/Original_dataset/NSCLC'

save_root = '/workspace/data/NSCLC'
images_save_root = os.path.join(save_root, 'images')
labels_save_root = os.path.join(save_root, 'labels')
os.makedirs(images_save_root, exist_ok=True)
os.makedirs(labels_save_root, exist_ok=True)

total_gtv_labels = dict()
gtv_types = ["SCC", "ADC", "LCC", "NOS", "NaN"]

target_spacing = (1.5, 1.5, 1.5)
min_size = 128


def _end_padding(length, minimum=min_size, multiple=4):
    """Return how many voxels to append to an axis of ``length`` voxels.

    Axes shorter than ``minimum`` are padded up to ``minimum``; longer axes
    are padded up to the next multiple of ``multiple`` (needed for APE).
    """
    if length < minimum:
        return minimum - length
    return -length % multiple


datasets = os.listdir(data_root)
for dataset in datasets:
    dataset_root = os.path.join(data_root, dataset)
    if not os.path.isdir(dataset_root):
        # Stray files in the data root are not datasets.
        continue

    gtv_label_path = os.path.join(dataset_root, 'gtv_labels.json')
    if os.path.exists(gtv_label_path):
        with open(gtv_label_path, 'r') as f:
            gtv_labels = json.load(f)
    else:
        # The dataset is still processed; its tumors simply get the "NaN" label.
        print(f'No GTV labels found for dataset {dataset}, labelling its tumors as NaN')
        gtv_labels = {}

    print(f'Processing dataset: {dataset}')
    pats = [f for f in os.listdir(dataset_root) if os.path.isdir(os.path.join(dataset_root, f))]
    for pat in tqdm(pats):
        pat_root = os.path.join(dataset_root, pat)

        gtv_files = [f for f in os.listdir(pat_root) if 'gtv' in f and f.endswith('.nii.gz')]
        if not gtv_files:
            print(f'{dataset}-{pat} does not have GTV mask')
            continue

        # The raw CT is kept untouched so that every mask of this patient is
        # compared against (and resampled alongside) the same original grid.
        img_sitk = sitk.ReadImage(os.path.join(pat_root, 'ct.nii.gz'))
        padded_img_sitk = None  # resampled + padded CT, built on first valid mask
        pad_widths = None

        for gtv_file in gtv_files:
            target_label = int(gtv_file.split('.')[0].split('_')[-1])

            gtv_name = pat+'_'+str(target_label)
            if gtv_name not in gtv_labels:
                print(f'{dataset}-{pat} does not have GTV label {gtv_name}')
                total_gtv_labels[gtv_name] = "NaN"
            elif gtv_labels[gtv_name] not in gtv_types:
                print(f'{dataset}-{pat} has unknown GTV type {gtv_labels[gtv_name]} for {gtv_name}, setting to NaN')
                total_gtv_labels[gtv_name] = "NaN"
            else:
                total_gtv_labels[gtv_name] = gtv_labels[gtv_name]

            target_sitk = sitk.ReadImage(os.path.join(pat_root, gtv_file))

            if img_sitk.GetSize() != target_sitk.GetSize():
                print(f'{dataset}-{pat} does not have same shape between image and target mask {img_sitk.GetSize()} != {target_sitk.GetSize()}')
                continue

            if is_mask_volume_zero(target_sitk):
                print(f'{dataset}-{pat}-{target_label} mask is empty')
                continue

            if padded_img_sitk is None:
                # Resample spacing to 1.5x1.5x1.5 and pad the CT once per patient.
                resampled_img = resample_volume(img_sitk, interpolator=sitk.sitkLinear, new_spacing=target_spacing)
                img_arr = sitk.GetArrayFromImage(resampled_img)
                pad_widths = tuple((0, _end_padding(axis_len)) for axis_len in img_arr.shape)
                img_arr = np.pad(img_arr, pad_widths, 'constant', constant_values=img_arr.min())
                padded_img_sitk = sitk.GetImageFromArray(img_arr)
                padded_img_sitk.SetSpacing(target_spacing)

            resampled_target = resample_volume(target_sitk, interpolator=sitk.sitkNearestNeighbor, new_spacing=target_spacing)
            target_arr = sitk.GetArrayFromImage(resampled_target).astype(np.uint8)
            # The mask receives exactly the padding computed for the CT.
            target_arr = np.pad(target_arr, pad_widths, 'constant', constant_values=0)
            padded_target_sitk = sitk.GetImageFromArray(target_arr)
            padded_target_sitk.SetSpacing(target_spacing)

            # Write inside the per-mask loop so that every tumor instance gets
            # its own <patient_id>_<tumor_index>.nii.gz pair.
            sample_file = gtv_name + '.nii.gz'
            sitk.WriteImage(padded_img_sitk, os.path.join(images_save_root, sample_file))
            sitk.WriteImage(padded_target_sitk, os.path.join(labels_save_root, sample_file))

# Save the GTV labels
gtv_labels_save_path = os.path.join(save_root, 'gtv_labels.json')
with open(gtv_labels_save_path, 'w') as f:
    json.dump(total_gtv_labels, f, indent=4)
print("\nGTV labels successfully saved!")
print(f"File path: {gtv_labels_save_path}")
print(f"Total number of GTV instances: {len(total_gtv_labels)}\n")

# Count the types of GTV labels
gtv_type_counts = {gtv_type: 0 for gtv_type in gtv_types}
for label in total_gtv_labels.values():
    if label in gtv_type_counts:
        gtv_type_counts[label] += 1

print("GTV Subtype Distribution")
print("-" * 30)
for gtv_type, count in gtv_type_counts.items():
    print(f"{gtv_type:<5}: {count}")
print("-" * 30)

## Step 2: Extract APE (Anatomical Positional Embedding)

To extract anatomical positional embeddings (APE) for each CT volume, please run the following notebook in a **separate Docker container** to avoid environment conflicts:

**[extract_ape_NSCLC.ipynb](extract_ape_NSCLC.ipynb)**

This will generate:
- `.npy`: a single 3D tensor with **3 channels** in shape `(3, H, W, D)`  
- `.nii.gz`: **three separate 3D NIfTI volumes** (one per channel)  

The outputs will be saved to:

```text
/workspace/data/NSCLC/
├── apes_npy/      # 3-channel APE tensors as .npy
├── apes_nii/      # Individual APE channels as .nii.gz
```


## Step 3: Extract Radiomics Features

We extract handcrafted radiomics features from the preprocessed CT images and corresponding tumor masks using the [PyRadiomics](https://pyradiomics.readthedocs.io/) library.

### Features extracted:
For each tumor region, we compute four types of features:
- 14 Shape-based features
- 18 Histogram-based first-order features
- 24 Texture features (GLCM)
- 16 Texture features (GLSZM)

### Output:
All valid features are saved to a single CSV file: `/workspace/data/NSCLC/radiomics_features_with_gtv_labels.csv`
- Each row corresponds to one tumor sample (ID: `<patient_id>_<tumor_index>`)
- Each feature column represents a radiomics feature; the final `GTV_label` column holds the tumor subtype from `gtv_labels.json`
- Samples with missing or invalid segmentations are excluded from the output

In [7]:
import os
import numpy as np
import SimpleITK as sitk
import pandas as pd
import radiomics
from radiomics import featureextractor, firstorder, glcm, imageoperations, shape, glszm
import json
import warnings
# Ignore specific warning messages
warnings.simplefilter('ignore', DeprecationWarning)
from tqdm import tqdm
logger = radiomics.logging.getLogger("radiomics")
logger.setLevel(radiomics.logging.ERROR)

In [8]:
def Shape_Feature_Extract(ID, image, ROI):
    """Compute the PyRadiomics shape features of ``ROI`` on ``image``.

    Returns a one-row DataFrame whose first column is ``ID`` and whose other
    columns are the feature names prefixed with ``Shape_``.
    """
    extractor = radiomics.shape.RadiomicsShape(image, ROI)
    extractor.enableAllFeatures()
    extractor.execute()

    table = pd.DataFrame([extractor.featureValues])
    table.insert(loc=0, column='ID', value=ID)
    feature_names = list(table.columns[1:])
    table.columns = ['ID'] + ['Shape_' + name for name in feature_names]
    return table

def Hist_Feature_Extract(ID, image, ROI):
    """Compute the PyRadiomics first-order features of ``ROI`` on ``image``.

    The extractor is configured with 128 bins, no interpolator and verbose
    output. Returns a one-row DataFrame whose first column is ``ID`` and whose
    other columns are the feature names prefixed with ``Hist_``.
    """
    extractor = radiomics.firstorder.RadiomicsFirstOrder(
        image, ROI, binCount=128, interpolator=None, verbose=True
    )
    extractor.enableAllFeatures()
    extractor.execute()

    table = pd.DataFrame([extractor.featureValues])
    table.insert(loc=0, column='ID', value=ID)
    feature_names = list(table.columns[1:])
    table.columns = ['ID'] + ['Hist_' + name for name in feature_names]
    return table

def GLCM_Feature_Extract(ID, image, ROI):
    """Compute the PyRadiomics GLCM texture features of ``ROI`` on ``image``.

    The extractor is configured with 128 bins, no interpolator and verbose
    output. Returns a one-row DataFrame whose first column is ``ID`` and whose
    other columns are the feature names prefixed with ``GLCM_``.
    """
    extractor = radiomics.glcm.RadiomicsGLCM(
        image, ROI, binCount=128, interpolator=None, verbose=True
    )
    extractor.enableAllFeatures()
    extractor.execute()

    table = pd.DataFrame([extractor.featureValues])
    table.insert(loc=0, column='ID', value=ID)
    feature_names = list(table.columns[1:])
    table.columns = ['ID'] + ['GLCM_' + name for name in feature_names]
    return table

def GLSZM_Feature_Extract(ID, image, ROI):
    """Compute the PyRadiomics GLSZM texture features of ``ROI`` on ``image``.

    The extractor is configured with 128 bins, no interpolator and verbose
    output. Returns a one-row DataFrame whose first column is ``ID`` and whose
    other columns are the feature names prefixed with ``GLSZM_``.
    """
    extractor = radiomics.glszm.RadiomicsGLSZM(
        image, ROI, binCount=128, interpolator=None, verbose=True
    )
    extractor.enableAllFeatures()
    extractor.execute()

    table = pd.DataFrame([extractor.featureValues])
    table.insert(loc=0, column='ID', value=ID)
    feature_names = list(table.columns[1:])
    table.columns = ['ID'] + ['GLSZM_' + name for name in feature_names]
    return table

In [None]:
# Step 3 driver: extract shape / first-order / GLCM / GLSZM features for every
# preprocessed tumor sample, log the samples that had to be excluded, and save
# the merged feature table together with the tumor subtype labels.
data_root = '/workspace/data/NSCLC'
imgs = os.listdir(os.path.join(data_root, 'labels'))

shape_storage = dict()
hist_storage = dict()
glcm_storage = dict()
glszm_storage = dict()

Except_img = dict()
exception_log_path = os.path.join(data_root, "Exception_patients.json")


def _save_exception_log():
    """Write the current content of ``Except_img`` to the exception log file."""
    with open(exception_log_path, 'w') as f:
        json.dump(Except_img, f, indent=4)


def _to_float_features(feature_frame):
    """Return the feature columns (all but the leading ID) of row 0 as floats."""
    return {name: float(value) for name, value in feature_frame.iloc[0, 1:].items()}


for img in tqdm(imgs):
    if 'nii' not in img:
        continue
    img_name = img.split('.')[0]
    img_path = os.path.join(data_root, 'images', img)
    seg_path = os.path.join(data_root, 'labels', img)

    try:
        img_sitk = sitk.ReadImage(img_path)
        target_sitk = sitk.ReadImage(seg_path)
        target_sitk = target_sitk != 0

        # empty seg
        if sitk.GetArrayFromImage(target_sitk).sum() == 0:
            print(f'{img_name} has no segmentation')
            Except_img[img_name] = 'no segmentation'
            _save_exception_log()
            continue

        shape_features = Shape_Feature_Extract(img_name, img_sitk, target_sitk)
        hist_features = Hist_Feature_Extract(img_name, img_sitk, target_sitk)
        glcm_features = GLCM_Feature_Extract(img_name, img_sitk, target_sitk)
        glszm_features = GLSZM_Feature_Extract(img_name, img_sitk, target_sitk)

    except Exception as exc:
        # Only ordinary errors are treated as a failed sample; KeyboardInterrupt
        # and SystemExit propagate so the run can actually be stopped.
        print(f'{img_name} feature extraction failed: {exc}')
        Except_img[img_name] = 'Feature extraction failed'
        _save_exception_log()
        continue

    shape_values = _to_float_features(shape_features)
    hist_values = _to_float_features(hist_features)
    glcm_values = _to_float_features(glcm_features)
    glszm_values = _to_float_features(glszm_features)

    # A sample is dropped as soon as any of its features is NaN.
    has_nan = any(
        np.isnan(value)
        for values in (shape_values, hist_values, glcm_values, glszm_values)
        for value in values.values()
    )
    if has_nan:
        print(f'{img_name} has nan value')
        Except_img[img_name] = 'nan value'
        _save_exception_log()
        continue

    shape_storage[img_name] = shape_values
    hist_storage[img_name] = hist_values
    glcm_storage[img_name] = glcm_values
    glszm_storage[img_name] = glszm_values

print('Done')

# Always (re)write the log at the end, even when no sample was excluded.
_save_exception_log()

df_shape = pd.DataFrame(shape_storage).T
df_hist = pd.DataFrame(hist_storage).T
df_glcm = pd.DataFrame(glcm_storage).T
df_glszm = pd.DataFrame(glszm_storage).T

radiomics_features = pd.concat([df_shape, df_hist, df_glcm, df_glszm], axis=1)

print('Shape features:', df_shape.shape)
print('Histogram features:', df_hist.shape)
print('GLCM features:', df_glcm.shape)
print('GLSZM features:', df_glszm.shape)
print('Total Radiomics features:', radiomics_features.shape)

radiomics_features.index.name = 'ID'
radiomics_features.reset_index(inplace=True)
# NOTE: a label-free 'radiomics_features.csv' is intentionally not written;
# only the table with the GTV labels below is saved.

# Load the GTV labels
gtv_labels_path = os.path.join(data_root, 'gtv_labels.json')
with open(gtv_labels_path, 'r') as f:
    gtv_labels = json.load(f)

# Add GTV labels to the radiomics features DataFrame
radiomics_features['GTV_label'] = radiomics_features['ID'].map(gtv_labels)

# Save the updated DataFrame with GTV labels
radiomics_features.to_csv(os.path.join(data_root, 'radiomics_features_with_gtv_labels.csv'), index=False)
print('Radiomics features with GTV labels saved to radiomics_features_with_gtv_labels.csv')

## Step 4: Convert to HDF5 Format

To facilitate fast I/O and unified data handling in downstream tasks, we convert all preprocessed data into a single HDF5 file.

Each tumor sample includes:
- **CT image** (`image`): 3D volume of shape (H, W, D)
- **Tumor mask** (`tumor`): Binary segmentation mask aligned with the image
- **Anatomical Positional Embedding** (`ape`): 3-channel tensor of shape (3, H, W, D)

The output HDF5 file is structured as:

```text
/workspace/data/NSCLC/NSCLC_data.hdf5
├── LUNG1-001_1/
│   ├── image   → (H, W, D)
│   ├── tumor   → (H, W, D)
│   ├── ape     → (3, H, W, D)
├── R01-001_1/
│   ├── ...
```

In [29]:
import os
import numpy as np
import SimpleITK as sitk
from tqdm import tqdm
import json
import h5py

In [None]:
# Step 4 driver: pack the CT image, tumor mask and APE tensor of every sample
# into one HDF5 file, with one group per sample.
data_root = '/workspace/data/NSCLC'
imgs = os.listdir(os.path.join(data_root, 'labels'))

hdf5_save_path = os.path.join(data_root, 'NSCLC_data.hdf5')

# Mode 'w' starts from an empty file; the handle is opened once for the whole
# run instead of being reopened for every sample.
with h5py.File(hdf5_save_path, 'w') as hf:
    for img in tqdm(imgs):
        if 'nii' not in img:
            # Ignore anything in the labels folder that is not a NIfTI volume.
            continue
        img_name = img.split('.')[0]
        img_path = os.path.join(data_root, 'images', img)
        seg_path = os.path.join(data_root, 'labels', img)
        ape_path = os.path.join(data_root, 'apes_npy', img_name + '.npy')

        if not os.path.exists(ape_path):
            # Report and skip rather than abort the whole conversion.
            print(f'{img_name} has no APE file at {ape_path}, skipping...')
            continue

        # Reverse the axis order of the arrays returned by SimpleITK.
        img_arr = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(img_path)), (2, 1, 0))
        seg_arr = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(seg_path)), (2, 1, 0))
        ape_arr = np.load(ape_path)

        # The APE tensor carries one extra leading axis in front of the volume axes.
        if img_arr.shape != seg_arr.shape or img_arr.shape != ape_arr.shape[1:]:
            print(f'{img_name} has inconsistent shapes: {img_arr.shape}, {seg_arr.shape}, {ape_arr.shape}')
            continue

        grp = hf.create_group(img_name)
        grp.create_dataset("image", data=img_arr, compression="lzf")
        grp.create_dataset("tumor", data=seg_arr, compression="lzf")
        grp.create_dataset("ape", data=ape_arr, compression="lzf")

print(f'\nAll data saved to {hdf5_save_path}')

## Step 5: Prepare Radiomics JSONL Dataset

In this step, we structure and export the radiomics feature dataset for model training or retrieval tasks.

### Process Overview
- Load `radiomics_features_with_gtv_labels.csv`, which contains extracted radiomics features and tumor subtype labels.
- Use patient split information from `data_split.json` to assign each sample to `train`, `validation`, or `test` sets.

### Radiomics Normalization
- For each radiomics feature (excluding the label), compute the **min/max values from the training set only**.
- Save the normalization statistics to: `/workspace/data/NSCLC/radiomics_features_min_max.json`
  
### Output Files
The following .jsonl files are created in `/workspace/data/NSCLC/`, where each line is a JSON object in the following format:

```json
{
  "id": "LUNG1-001_1",
  "radiomics": {
    "Shape_Elongation": 0.676522,
    "Hist_Entropy": 4.897085,
    ...
  },
  "label": "LCC"
}
```

- `train.jsonl`: Training set samples
- `val.jsonl`: Validation set samples
- `test.jsonl`: Test set samples
- `total.jsonl`: Full dataset (union of all above)

### Label Distribution
The script also prints out the number of samples per tumor subtype (`SCC`, `LCC`, `NOS`, `ADC`, `NaN`) in each split to help verify class balance.

In [21]:
import os
import pandas as pd
import json
from tqdm import tqdm
from collections import OrderedDict


In [22]:
# Step 5 inputs: the patient split definition and the radiomics feature table.
data_root = '/workspace/data/NSCLC'

# data_split.json provides the 'train', 'validation' and 'test' entries.
data_split_path = os.path.join(data_root, 'data_split.json')
with open(data_split_path) as split_file:
    data_split = json.load(split_file)

train_pats, val_pats, test_pats = (
    data_split[split_key] for split_key in ('train', 'validation', 'test')
)

# Feature table written by the radiomics extraction step.
radiomics_features_path = os.path.join(data_root, 'radiomics_features_with_gtv_labels.csv')
radiomics_features = pd.read_csv(radiomics_features_path)

In [None]:
# Step 5 driver: assign every radiomics sample to its split, store the
# training-set min/max of each feature, print the subtype distribution and
# write the train / val / test / total JSONL files.
trainset = []
valset = []
testset = []
totalset = []

# Feature columns are everything between the leading ID column and the
# trailing GTV_label column.
feature_columns = radiomics_features.columns[1:-1]

# Build the lookup sets once; membership is then tested once per sample.
train_ids = set(train_pats)
val_ids = set(val_pats)
test_ids = set(test_pats)

# Rows are collected in lists and turned into DataFrames once at the end,
# instead of growing a DataFrame one row at a time.
train_rows = []
val_rows = []
test_rows = []

for idx, row in tqdm(radiomics_features.iterrows(), total=len(radiomics_features)):
    img_id = row['ID']
    pat_id = "_".join(img_id.split('_')[:-1])

    # pandas reads the text "NaN" in the CSV back as a missing value; restore
    # the string label so the JSONL holds valid JSON instead of a bare NaN.
    label = row['GTV_label']
    if pd.isna(label):
        label = "NaN"

    features = row.iloc[1:-1]

    my_data = OrderedDict()
    my_data['id'] = img_id
    my_data['radiomics'] = features.to_dict()
    my_data['label'] = label

    if pat_id in train_ids:
        split_samples, split_rows = trainset, train_rows
    elif pat_id in val_ids:
        split_samples, split_rows = valset, val_rows
    elif pat_id in test_ids:
        split_samples, split_rows = testset, test_rows
    else:
        print(f'Patient {pat_id} not found in any split, skipping...')
        continue

    split_samples.append(my_data)
    totalset.append(my_data)
    split_rows.append(features)

train_radiomics = pd.DataFrame(train_rows, columns=feature_columns).reset_index(drop=True)
val_radiomics = pd.DataFrame(val_rows, columns=feature_columns).reset_index(drop=True)
test_radiomics = pd.DataFrame(test_rows, columns=feature_columns).reset_index(drop=True)

print(f'\nTotal number of patients: {len(radiomics_features)}')
print(f'Train set size: {len(trainset)}')
print(f'Validation set size: {len(valset)}')
print(f'Test set size: {len(testset)}')

# Normalisation statistics come from the training split only.
radiomics_features_min_max = dict()
for col in train_radiomics.columns:
    radiomics_features_min_max[col] = (float(train_radiomics[col].min()), float(train_radiomics[col].max()))

radiomics_features_min_max_path = os.path.join(data_root, 'radiomics_features_min_max.json')
with open(radiomics_features_min_max_path, 'w') as f:
    json.dump(radiomics_features_min_max, f, indent=4)


def _subtype_counts(samples):
    """Count samples per subtype; any label other than the four named subtypes counts as NaN."""
    counts = {'SCC': 0, 'LCC': 0, 'NOS': 0, 'ADC': 0, 'NaN': 0}
    for sample in samples:
        sample_label = sample['label']
        if sample_label in ('SCC', 'LCC', 'NOS', 'ADC'):
            counts[sample_label] += 1
        else:
            counts['NaN'] += 1
    return counts


print("\nGTV Subtype Distribution")
for split_name, split_samples in (('Train', trainset), ('Val', valset), ('Test', testset)):
    counts = _subtype_counts(split_samples)
    print(f"{split_name} SCC: {counts['SCC']} / LCC: {counts['LCC']} / NOS: {counts['NOS']} / ADC: {counts['ADC']} / NaN: {counts['NaN']}")

# Save the datasets
for jsonl_name, split_samples in (
    ('train.jsonl', trainset),
    ('val.jsonl', valset),
    ('test.jsonl', testset),
    ('total.jsonl', totalset),
):
    jsonl_file_path = os.path.join(data_root, jsonl_name)
    with open(jsonl_file_path, 'w') as f:
        for data in split_samples:
            json.dump(data, f)
            f.write('\n')