Skip to content

Commit

Permalink
Add Image Resizing and Colab Support
Browse files Browse the repository at this point in the history
Fixes qubvel#412 : Resizing Images to avoid errors
Fixes Colab Bug : Using Segmentation Models
  • Loading branch information
khanfarhan10 committed May 18, 2021
1 parent 94f624b commit ff86938
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions examples/binary segmentation (camvid).ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,21 @@
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\n",
"If dataset needs to be resized use these values MAIN_SIZE_X,MAIN_SIZE_Y = 512,512 else put MAIN_SIZE_X,MAIN_SIZE_Y = None,None\n",
"\"\"\"\n",
"MAIN_SIZE_X,MAIN_SIZE_Y = None,None\n",
"\n",
"\n",
"\n",
"# helper function for data visualization\n",
"def visualize(**images):\n",
" \"\"\"PLot images in one row.\"\"\"\n",
" n = len(images)\n",
" plt.figure(figsize=(16, 5))\n",
" plt.figure(figsize=(20, 12))\n",
" for i, (name, image) in enumerate(images.items()):\n",
" plt.subplot(1, n, i + 1)\n",
" plt.axis('off')\n",
" plt.xticks([])\n",
" plt.yticks([])\n",
" plt.title(' '.join(name.split('_')).title())\n",
Expand Down Expand Up @@ -175,7 +183,11 @@
" # read data\n",
" image = cv2.imread(self.images_fps[i])\n",
" image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n",
" if MAIN_SIZE_X,MAIN_SIZE_Y == None,None:\n",
" image = cv2.resize(image,(MAIN_SIZE_Y,MAIN_SIZE_X))\n",
" mask = cv2.imread(self.masks_fps[i], 0)\n",
" if MAIN_SIZE_X,MAIN_SIZE_Y == None,None:\n",
" mask = cv2.resize(mask,(MAIN_SIZE_Y,MAIN_SIZE_X))\n",
" \n",
" # extract certain classes from mask (e.g. cars)\n",
" masks = [(mask == v) for v in self.class_values]\n",
Expand Down Expand Up @@ -216,7 +228,6 @@
" self.batch_size = batch_size\n",
" self.shuffle = shuffle\n",
" self.indexes = np.arange(len(dataset))\n",
"\n",
" self.on_epoch_end()\n",
"\n",
" def __getitem__(self, i):\n",
Expand All @@ -230,8 +241,8 @@
" \n",
" # transpose list of lists\n",
" batch = [np.stack(samples, axis=0) for samples in zip(*data)]\n",
" \n",
" return batch\n",
" return tuple(batch)\n",
"\n",
" \n",
" def __len__(self):\n",
" \"\"\"Denotes the number of batches per epoch\"\"\"\n",
Expand Down Expand Up @@ -447,6 +458,10 @@
}
],
"source": [
"\"\"\"\n",
"While Using in Colab Use\n",
"%env SM_FRAMEWORK=tf.keras\n",
"\"\"\"\n",
"import segmentation_models as sm\n",
"\n",
"# segmentation_models could also use `tf.keras` if you do not have Keras installed\n",
Expand Down

0 comments on commit ff86938

Please sign in to comment.