# Getting Started with Med-ImageTools
This notebook is designed to showcase the core functionalities of Med-ImageTools and help you get started with the package. The notebook will guide you through the following steps:

1. Installing the package
2. Processing a sample TCIA dataset using `AutoPipeline` for deep learning segmentation
   
   i. Understanding the crawl outputs from an `AutoPipeline` dry run
   
   ii. Understanding the full outputs from `AutoPipeline`
   
3. *(Optional) Processing a sample TCIA dataset using `AutoPipeline` with radiotherapy data*

## 1. Installing the package

Med-ImageTools is available on PyPI and can be installed using pip:

```
pip install med-imagetools
```

In [None]:
!pip install med-imagetools

In [None]:
try:
    import imgtools
    print("Looks like Med-ImageTools is installed!")
except ImportError as e:
    print(e, "Please install the package with: pip install med-imagetools")

## 2. Processing a sample dataset with the AutoPipeline feature

We're going to start off by defining where the dataset is located, and where you want the processed outputs to be saved. 

In [None]:
INPUT_PATH  = "/path/to/tcia/dataset"
OUTPUT_PATH = "/path/to/output/folder"
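Before running the pipeline, it can help to catch typos in these paths early. Below is a minimal sketch using only the standard library; `check_paths` is a hypothetical helper, and creating the output folder up front is a precaution, not something `autopipeline` is known to require.

```python
import os

def check_paths(input_path, output_path):
    """Hypothetical helper: validate the input folder and make sure the output folder exists."""
    if not os.path.isdir(input_path):
        raise FileNotFoundError(f"Input dataset folder not found: {input_path}")
    # Create the output folder and any missing parents; no-op if it already exists
    os.makedirs(output_path, exist_ok=True)
    return os.path.abspath(input_path), os.path.abspath(output_path)
```

Usage: `check_paths(INPUT_PATH, OUTPUT_PATH)` raises immediately if the dataset path is wrong, instead of failing partway through a run.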

### a. AutoPipeline dry-run
Now let's dry-run AutoPipeline to understand its crawl functionality.  
We'll run the `autopipeline` command with the **--dry_run** flag, which crawls the dataset and shows what would be processed without actually running the pipeline.

In [None]:
!autopipeline \
     $INPUT_PATH \
     $OUTPUT_PATH \
     --modalities CT,RTSTRUCT \
     --n_jobs 4 \
     --dry_run
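If you prefer to call the pipeline from a Python script rather than a notebook `!` shell command, the same invocation can be assembled as an argument list and passed to `subprocess.run`. This is a sketch: `build_autopipeline_cmd` and `run_autopipeline` are hypothetical helpers, and they only use the flags shown in the cell above.

```python
import subprocess

def build_autopipeline_cmd(input_path, output_path, modalities=("CT", "RTSTRUCT"),
                           n_jobs=4, dry_run=False):
    """Hypothetical helper: build the autopipeline argument list used in this notebook."""
    cmd = [
        "autopipeline",
        str(input_path),
        str(output_path),
        "--modalities", ",".join(modalities),
        "--n_jobs", str(n_jobs),
    ]
    if dry_run:
        cmd.append("--dry_run")
    return cmd

def run_autopipeline(input_path, output_path, **kwargs):
    """Run the command; raises CalledProcessError if autopipeline exits with an error."""
    subprocess.run(build_autopipeline_cmd(input_path, output_path, **kwargs), check=True)
```

Usage: `run_autopipeline(INPUT_PATH, OUTPUT_PATH, dry_run=True)`. Passing a list instead of a single string avoids shell quoting problems when paths contain spaces.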

Running `AutoPipeline` creates a `.imgtools` folder in the dataset's parent directory.

```
parent_folder
├───.imgtools
│   ├── imgtools_dataset.csv
│   ├── imgtools_dataset.json
│   └── imgtools_dataset_edges.csv
│ 
└───dataset
    ├── patient-001
    ├── patient-002
    ...
```
 
There are three files in the `.imgtools` folder:

* `imgtools_dataset.csv` contains the metadata for the dataset
* `imgtools_dataset.json` contains the metadata for the dataset in JSON format
* `imgtools_dataset_edges.csv` contains the "edges" for the dataset.
    * An edge is a pair of DICOM series that are linked through their metadata, for example an RTSTRUCT and the CT it was contoured on.
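To make the idea of an edge concrete, here is a simplified illustration in plain Python. In DICOM, an RTSTRUCT records the SeriesInstanceUID of the image series it was contoured on, so matching that reference against the CT series yields an RTSTRUCT-CT pair. This is a toy sketch with made-up UIDs and hypothetical dictionary keys, not Med-ImageTools' actual crawling logic.

```python
def find_rtstruct_ct_edges(series):
    """Pair each RTSTRUCT with the CT series it references.

    `series` is a list of dicts with hypothetical keys:
    "uid", "modality" and, for RTSTRUCTs, "referenced_uid".
    """
    ct_uids = {s["uid"] for s in series if s["modality"] == "CT"}
    edges = []
    for s in series:
        if s["modality"] == "RTSTRUCT" and s.get("referenced_uid") in ct_uids:
            # Each edge is (RTSTRUCT series UID, CT series UID)
            edges.append((s["uid"], s["referenced_uid"]))
    return edges

series = [
    {"uid": "1.1", "modality": "CT"},
    {"uid": "1.2", "modality": "RTSTRUCT", "referenced_uid": "1.1"},
    {"uid": "1.3", "modality": "RTSTRUCT", "referenced_uid": "9.9"},  # references a CT that is not present
]
print(find_rtstruct_ct_edges(series))  # [('1.2', '1.1')]
```

The unmatched RTSTRUCT produces no edge, which is why a dataset with missing series can end up with fewer edges than expected.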


In [None]:
import pandas as pd
import os

# abspath strips any trailing slash, so dirname returns the actual parent folder
parent_folder   = os.path.dirname(os.path.abspath(INPUT_PATH))
imgtools_folder = os.path.join(parent_folder, ".imgtools")
imgtools_files  = sorted(os.listdir(imgtools_folder))

print("Files generated by Med-ImageTools:\n")
print("\n".join(imgtools_files))

This is what the crawled dataset looks like. 
Each row represents a DICOM series (CT, MRI, PET, RTSTRUCT, SEG, RTDOSE, RTPLAN, etc).

In [None]:
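# Select the crawl table by name rather than by position, since os.listdir order is arbitrary
crawl_file = next(f for f in imgtools_files if f.endswith(".csv") and not f.endswith("_edges.csv"))
df_crawl = pd.read_csv(os.path.join(imgtools_folder, crawl_file), index_col=0)
df_crawl.head(5)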
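For a quick overview of what the crawler found, you can tabulate the number of series per patient and modality. This sketch assumes the crawl table has `patient_ID` and `modality` columns; check `df_crawl.columns` and adjust the names if your version of Med-ImageTools uses different ones. `modality_table` is a hypothetical helper.

```python
import pandas as pd

def modality_table(df_crawl, patient_col="patient_ID", modality_col="modality"):
    """Count DICOM series per patient (rows) and modality (columns)."""
    missing = [c for c in (patient_col, modality_col) if c not in df_crawl.columns]
    if missing:
        raise KeyError(f"Columns {missing} not found; available: {list(df_crawl.columns)}")
    # crosstab fills combinations that never occur with 0
    return pd.crosstab(df_crawl[patient_col], df_crawl[modality_col])
```

Usage: `modality_table(df_crawl)`. Patients with a 0 in the RTSTRUCT column have no contours and will not produce training masks.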

This is what the edges table looks like. Each row links two DICOM series.

In [None]:
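# Select the edges table by its suffix rather than by position in the directory listing
edges_file = next(f for f in imgtools_files if f.endswith("_edges.csv"))
df_edges = pd.read_csv(os.path.join(imgtools_folder, edges_file))
df_edges.head(5)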

Let's see how many edges of each type this dataset contains.
Med-ImageTools detects 8 edge types, stored as numeric codes in the `edge_type` column:
* (0) RTDOSE-RTSTRUCT
* (1) RTDOSE-CT
* (2) RTSTRUCT-CT
* (3) RTSTRUCT-PET
* (4) CT-PET
* (5) RTDOSE-PET
* (6) RTPLAN-RTSTRUCT
* (7) SEG-CT

In [None]:
print(df_edges.edge_type.value_counts())
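The raw `value_counts()` output shows numeric codes. A small lookup table transcribed from the list above makes it readable; `EDGE_TYPES` and `label_edge_counts` are hypothetical names used only in this sketch.

```python
import pandas as pd

# Edge-type codes, transcribed from the list above
EDGE_TYPES = {
    0: "RTDOSE-RTSTRUCT",
    1: "RTDOSE-CT",
    2: "RTSTRUCT-CT",
    3: "RTSTRUCT-PET",
    4: "CT-PET",
    5: "RTDOSE-PET",
    6: "RTPLAN-RTSTRUCT",
    7: "SEG-CT",
}

def label_edge_counts(edge_type):
    """Count edges per type and replace the numeric codes with readable names."""
    counts = edge_type.value_counts()
    counts.index = counts.index.map(lambda code: EDGE_TYPES.get(code, f"unknown ({code})"))
    return counts
```

Usage: `label_edge_counts(df_edges.edge_type)`. To keep only the RTSTRUCT-CT pairs used for segmentation, filter with `df_edges[df_edges.edge_type == 2]`.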

### b. AutoPipeline FULL run
Now let's actually run the AutoPipeline and see what we get!

In [None]:
!autopipeline \
     $INPUT_PATH \
     $OUTPUT_PATH \
     --modalities CT,RTSTRUCT \
     --n_jobs 4 

The output folder will be structured like this:
```
output_folder
├── dataset.csv
├── report.md
│ 
├── 0_patient-001
│   ├── CT
│   │   └── CT.nii.gz
│   └── RTSTRUCT_CT
│       ├── Head.nii.gz
│       ├── Shoulder.nii.gz
│       ├── Knees.nii.gz
│       └── Toes.nii.gz
│ 
├── 1_patient-002
├── 2_patient-003
...
```
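The mask filenames double as ROI names, which is handy when choosing the `roi` argument for the PyTorch dataset later on. A minimal sketch, assuming masks are saved as `<ROI name>.nii.gz` as in the tree above; `list_rois` is a hypothetical helper.

```python
import os

def list_rois(mask_folder, suffix=".nii.gz"):
    """Return the sorted ROI names found in an RTSTRUCT_* output folder."""
    return sorted(
        name[: -len(suffix)]          # "Head.nii.gz" -> "Head"
        for name in os.listdir(mask_folder)
        if name.endswith(suffix)
    )
```

Usage: `list_rois(os.path.join(OUTPUT_PATH, "0_patient-001", "RTSTRUCT_CT"))`. For the example tree above this would return `['Head', 'Knees', 'Shoulder', 'Toes']`.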

Let's see what's inside the folder:

In [None]:
output_folders = [path for path in os.listdir(OUTPUT_PATH) if os.path.isdir(os.path.join(OUTPUT_PATH, path))]
print("Output folders:\n")
print("\n".join(output_folders))
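Each output folder name is the patient ID prefixed by an index (for example `0_patient-001`). To recover the two parts, splitting on the first underscore works as long as the index itself contains no underscore. This is a sketch based on the naming shown above; `split_output_folder` is a hypothetical helper.

```python
def split_output_folder(folder_name):
    """Split '0_patient-001' into (0, 'patient-001')."""
    index, sep, patient_id = folder_name.partition("_")
    if not sep or not index.isdigit():
        raise ValueError(f"Unexpected folder name: {folder_name!r}")
    return int(index), patient_id
```

Usage: `sorted(split_output_folder(f) for f in output_folders)` orders the patients by their numeric index, which `os.listdir` does not guarantee. Patient IDs that contain underscores stay intact, e.g. `"12_HN_001"` becomes `(12, "HN_001")`.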


Let's take a look at the `dataset.csv` file. 

In [None]:
df_data = pd.read_csv(os.path.join(OUTPUT_PATH, "dataset.csv"), index_col=0)
df_data.head(5)

There are 3 main types of columns in the `dataset.csv` file that are important for analysis:
* `patient_ID`: the patient ID without the index prefix used in the output folder names (e.g., `patient-001` rather than `0_patient-001`).
* `output_folder_{modality}`: path to the output folder for each modality.
  * For example, for the CT,RTSTRUCT modality pair, the columns will be `output_folder_CT` and `output_folder_RTSTRUCT_CT`.
* DICOM imaging metadata: imaging parameters extracted from the DICOM headers.
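Not every patient necessarily has every modality. Assuming a missing modality shows up as an empty (NaN) cell in its `output_folder_*` column, you can keep only the patients that have both a CT and an RTSTRUCT output before building a training set. `complete_cases` is a hypothetical helper.

```python
import pandas as pd

def complete_cases(df_data, required=("output_folder_CT", "output_folder_RTSTRUCT_CT")):
    """Keep rows that have a value in every required output_folder_* column."""
    missing = [c for c in required if c not in df_data.columns]
    if missing:
        raise KeyError(f"Columns not found in dataset.csv: {missing}")
    # dropna treats both NaN and None as missing
    return df_data.dropna(subset=list(required))
```

Usage: `df_train = complete_cases(df_data)`.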

In [None]:
df_data.columns.tolist()

If you want to build a PyTorch Dataset/DataLoader from a Med-ImageTools processed dataset, `dataset.csv` tells you where each patient's image and mask files are.

Here's an example of what a PyTorch Dataset/DataLoader might look like:

```python
import os

import pandas as pd
import SimpleITK as sitk
from torch.utils.data import Dataset, DataLoader

class MedImageToolsDataset(Dataset):
    def __init__(self,
                 data_folder, 
                 roi="GTV"):
        self.data_dir = data_folder
        self.data_df  = pd.read_csv(os.path.join(data_folder, "dataset.csv"))
        self.roi      = roi

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        # get the row of the dataframe
        row = self.data_df.iloc[idx]
        
        # get image and mask
        img = sitk.ReadImage(os.path.join(self.data_dir, row["output_folder_CT"], "CT.nii.gz"))
        mask = sitk.ReadImage(os.path.join(self.data_dir, row["output_folder_RTSTRUCT_CT"], f"{self.roi}.nii.gz"))
        
        # return the pair!
        return img, mask

# Create a DataLoader. SimpleITK images can't be stacked by the default
# collate function, so keep each batch as a list of (img, mask) pairs.
dataloader = DataLoader(MedImageToolsDataset(OUTPUT_PATH),
                        batch_size=32,
                        collate_fn=lambda batch: batch)
```

A more complete version of the code, with safer error handling, regex-based ROI matching and more comments, looks like this. Try it out!

In [None]:
import os
import re
import warnings

import pandas as pd
import SimpleITK as sitk
from torch.utils.data import Dataset, DataLoader

# Define the Dataset class
class MedImageToolsDataset(Dataset):
    def __init__(self, 
                 data_folder, 
                 roi="GTV"):
        """
        Parameters
        ----------
        data_folder : str
            Path to the folder containing the dataset.csv file and the output folders
        roi : str
            Name of the Region of Interest (ROI) to extract from the RTSTRUCT masks. 
            Regular expressions are accepted and matched case-insensitively.
        """

        if not os.path.exists(data_folder):
            raise FileNotFoundError(f"Folder {data_folder} does not exist")
        self.data_dir = data_folder

        # Load the dataset.csv file
        data_df_path = os.path.join(data_folder, "dataset.csv")
        if not os.path.exists(data_df_path):
            raise FileNotFoundError(f"File dataset.csv not found in {data_folder}")
        self.data_df  = pd.read_csv(data_df_path)

        # Columns such as output_folder_CT and output_folder_RTSTRUCT_CT
        self.output_cols   = [col for col in self.data_df.columns if col.startswith("output_folder_")]
        self.roi           = roi

    def __len__(self):
        return len(self.data_df)

    def _find_mask(self, mask_folder):
        """Return the path of the first mask whose ROI name matches self.roi, or None."""
        for mask_file in sorted(os.listdir(mask_folder)):
            roi_name = mask_file.split(".")[0]  # "LUNG_L.nii.gz" -> "LUNG_L"
            if re.fullmatch(self.roi, roi_name, flags=re.IGNORECASE) or self.roi in roi_name:
                return os.path.join(mask_folder, mask_file)
        return None

    def __getitem__(self, idx):
        row = self.data_df.iloc[idx]
        img, mask = None, None

        for col in self.output_cols:
            # Skip modalities that were not generated for this patient
            if pd.isna(row[col]):
                continue

            if col == "output_folder_CT":
                img_path = os.path.join(self.data_dir, row[col], "CT.nii.gz")
                if not os.path.exists(img_path):
                    raise FileNotFoundError(f"CT image not found at {img_path}")
                img = sitk.ReadImage(img_path)

            elif "RTSTRUCT" in col:
                mask_folder = os.path.join(self.data_dir, row[col])
                if not os.path.isdir(mask_folder):
                    continue
                mask_path = self._find_mask(mask_folder)
                if mask_path is not None:
                    mask = sitk.ReadImage(mask_path)

        if img is None or mask is None:
            # Returning None lets the collate function drop this patient
            warnings.warn(f"Row {idx}: CT image or mask of {self.roi} not found, skipping")
            return None
        return img, mask
        
# Define a collate function
def my_collate(batch):
    """Drop samples that returned None and keep the rest as a list of (img, mask) pairs.

    SimpleITK images cannot be stacked into tensors by PyTorch's default collate
    function, so the batch stays a plain Python list.
    """
    return [x for x in batch if x is not None]

# Create a DataLoader
dataloader = DataLoader(MedImageToolsDataset(OUTPUT_PATH, 
                                             roi="LUNG_L"), 
                        batch_size=4, 
                        collate_fn=my_collate)

# Print the first batch of data
batch = next(iter(dataloader))
print("Batch size:", len(batch))

if batch:
    img, mask = batch[0]
    print(f"Image: {img.GetSize()}")
    print(f"Mask: {mask.GetSize()}")
else:
    print("No patient in the first batch had both a CT image and the requested ROI")