
Commit

update segmentation example
ncullen93 committed Apr 1, 2024
1 parent 2d904d6 commit 37c7113
Showing 2 changed files with 24 additions and 13 deletions.
28 changes: 17 additions & 11 deletions examples/Liver-tumor-segmentation.ipynb
@@ -9,17 +9,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Liver tumor segmentation with nitrain\n",
"# Liver segmentation with nitrain\n",
"\n",
"![image.png](attachment:image.png)\n",
"\n",
"This example shows you how to train a model to perform liver tumor segmentation using nitrain. It is a classic example of medical image segmentation. \n",
"This example shows you how to train a model to perform liver segmentation using nitrain. It is a classic example of medical image segmentation. \n",
"\n",
"We will create a model with keras and do everything else (data sampling + augmentation, training, explaining results) with nitrain.\n",
"\n",
"## About the data\n",
"\n",
"The dataset can be downloaded from the [Liver Tumor Segmentation](https://www.kaggle.com/datasets/andrewmvd/liver-tumor-segmentation/data) dataset on Kaggle. It is about 5 GB in size and contains 130 CT scans of the liver along with associated segmentation images where tumors have been identified.\n",
"The dataset can be downloaded from the [Liver Tumor Segmentation](https://www.kaggle.com/datasets/andrewmvd/liver-tumor-segmentation/data) dataset on Kaggle. It is about 5 GB in size and contains 130 CT scans of the liver along with associated segmentation images identifying both the liver and any tumors within it. We will only use the liver segmentation in this example.\n",
"\n",
"To run this example, download the dataset (\"archive.zip\") and unpack it onto your desktop. Then we are ready to go!"
]
@@ -100,23 +100,27 @@
"\n",
"### Applying fixed transforms\n",
"\n",
"We don't need to do much pre-processing of the images to get them into a format ready for training. To make things easier for the model, we will normalize the intensity of the input images to be between 0 and 1. Additionally, we will downsample the images to a smaller size to make training go faster."
"We don't need to do much pre-processing of the images to get them into a format ready for training. To make things easier for the model, we will normalize the intensity of the input image to be between 0 and 1. \n",
"\n",
"The segmentation images use a value of 1 for liver voxels, 2 for tumor voxels within the liver, and 0 everywhere else (i.e., any part of the image that is not liver). Since we only care about the liver segmentation, we can add a custom transform that discards the tumor label, leaving a binary liver mask.\n",
"\n",
"Additionally, we will downsample both the input image and its segmentation to a smaller size to make training go faster. Finally, we will reorient the images so that the model is trained on views along the inferior-posterior axis. You can think of these as cross-sections of the abdomen going successively down the spine."
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "FolderDataset.__init__() got an unexpected keyword argument 'co_transforms'",
"ename": "NameError",
"evalue": "name 'FolderDataset' is not defined",
"output_type": "error",
"traceback": [
"---------------------------------------------------------------------------",
"TypeError                                 Traceback (most recent call last)",
"Cell In[4], line 2\n      1 from nitrain import transforms as tx\n----> 2 dataset = FolderDataset('~/Desktop/kaggle-liver-ct',\n      3     x={'pattern': 'volumes/volume-{id}.nii'},\n      4     y={'pattern': 'segmentations/segmentation-{id}.nii'},\n      5     x_transforms=[tx.RangeNormalize(0, 1)],\n      6     co_transforms=[tx.Resample((128, 128, 64))])",
"TypeError: FolderDataset.__init__() got an unexpected keyword argument 'co_transforms'"
"NameError                                 Traceback (most recent call last)",
"Cell In[1], line 2\n      1 from nitrain import transforms as tx\n----> 2 dataset = FolderDataset('~/Desktop/kaggle-liver-ct',\n      3     x={'pattern': 'volumes/volume-{id}.nii'},\n      4     y={'pattern': 'segmentations/segmentation-{id}.nii'},\n      5     x_transforms=[tx.RangeNormalize(0, 1)],\n      6     y_transforms=[tx.CustomFunction(lambda x: x == 1)],\n      7     co_transforms=[tx.Resample((128, 128, 64))])",
"NameError: name 'FolderDataset' is not defined"
]
}
],
@@ -126,7 +130,9 @@
" x={'pattern': 'volumes/volume-{id}.nii'},\n",
" y={'pattern': 'segmentations/segmentation-{id}.nii'},\n",
" x_transforms=[tx.RangeNormalize(0, 1)],\n",
" co_transforms=[tx.Resample((128, 128, 64))])"
" y_transforms=[tx.CustomFunction(lambda x: x == 1)],\n",
" co_transforms=[tx.Resample((128, 128, 64)),\n",
" tx.Reorient('IPR')])"
]
},
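The `y_transforms` entry in the cell above uses `tx.CustomFunction(lambda x: x == 1)` to collapse the three-valued label image into a binary liver mask. A minimal NumPy sketch of what that lambda computes on the dataset's label encoding (0 = background, 1 = liver, 2 = tumor):

```python
import numpy as np

# Toy segmentation labels following the dataset's encoding:
# 0 = background, 1 = liver, 2 = tumor within the liver.
seg = np.array([0, 1, 2, 2, 1, 0])

# The y_transform lambda keeps only label 1, dropping the tumor label
# so the model sees a binary liver mask.
liver_mask = (seg == 1).astype(np.uint8)
print(liver_mask)  # [0 1 0 0 1 0]
```

Note that with this particular lambda, tumor voxels (label 2) are excluded from the mask; including them would require `x >= 1` instead.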
{
@@ -25,7 +25,7 @@ def __call__(self, image, co_image=None):
image = ants.resample_image(image, self.params, not self.use_spacing, interpolation_value)

if co_image is not None:
co_image = ants.resample_image(image, self.params, not self.use_spacing, interpolation_value)
co_image = ants.resample_image(co_image, self.params, not self.use_spacing, interpolation_value)
return image, co_image

return image
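The one-character fix above matters: the old code resampled `image` a second time and returned it in place of the segmentation. The paired-transform pattern can be sketched without ANTs, using a hypothetical `resample` stub in place of `ants.resample_image`:

```python
# Stand-in for ants.resample_image; records which image was resampled
# instead of doing real interpolation.
def resample(image, params):
    return ('resampled', image, params)

class Resample:
    def __init__(self, params):
        self.params = params

    def __call__(self, image, co_image=None):
        image_out = resample(image, self.params)
        if co_image is not None:
            # The bug fixed in this commit: passing `image` here resampled
            # the volume twice and silently dropped the co-image.
            co_image_out = resample(co_image, self.params)
            return image_out, co_image_out
        return image_out

vol, seg = Resample((128, 128, 64))('volume', co_image='segmentation')
print(seg)  # ('resampled', 'segmentation', (128, 128, 64))
```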
@@ -65,9 +65,14 @@ class Reorient(BaseTransform):
def __init__(self, orientation='RAS'):
self.orientation = orientation

def __call__(self, image):
def __call__(self, image, co_image=None):
image = ants.reorient_image2(image, self.orientation)

if co_image is not None:
co_image = ants.reorient_image2(co_image, self.orientation)
return image, co_image

return image

class Slice(BaseTransform):
"""
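Orientation strings such as the default `'RAS'` here, or the `'IPR'` used in the notebook, are standard three-letter anatomical codes: one letter per voxel axis, each naming an anatomical direction. A minimal decoder (the mapping below is the standard letter convention; the helper name is illustrative):

```python
# Standard one-letter anatomical direction codes used in orientation
# strings like 'RAS' or 'IPR'.
DIRECTIONS = {
    'L': 'left', 'R': 'right',
    'A': 'anterior', 'P': 'posterior',
    'S': 'superior', 'I': 'inferior',
}

def describe(code):
    """Spell out the anatomical direction tied to each voxel axis."""
    return [DIRECTIONS[c] for c in code.upper()]

print(describe('IPR'))  # ['inferior', 'posterior', 'right']
```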
