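# Attention region around path

Predict a path mask for every image classified as `Path`, save the mask and an overlay figure, then paint everything outside the predicted path white to produce the attention-region images.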

Note that the augmentations returned by `get_transforms()` are only applied to the training set; see the [fastai v1 transforms documentation](https://fastai1.fast.ai/vision.transform.html).

See this [fastai forum thread](https://forums.fast.ai/t/saving-segmentation-masks-tutorial/62898) on saving segmentation masks.

## Set-up

In [1]:
# Pin torch/torchvision to the versions this notebook was run with. The Colab
# runtime had torch 1.7.0 preinstalled (see the uninstall lines in the output);
# the downgrade is presumably needed for fastai v1 compatibility — confirm.
!pip install "torch==1.4" "torchvision==0.5.0"

Collecting torch==1.4
[?25l  Downloading https://files.pythonhosted.org/packages/24/19/4804aea17cd136f1705a5e98a00618cb8f6ccc375ad8bfa437408e09d058/torch-1.4.0-cp36-cp36m-manylinux1_x86_64.whl (753.4MB)
[K     |████████████████████████████████| 753.4MB 21kB/s 
[?25hCollecting torchvision==0.5.0
[?25l  Downloading https://files.pythonhosted.org/packages/7e/90/6141bf41f5655c78e24f40f710fdd4f8a8aff6c8b7c6f0328240f649bdbe/torchvision-0.5.0-cp36-cp36m-manylinux1_x86_64.whl (4.0MB)
[K     |████████████████████████████████| 4.0MB 24.0MB/s 
Installing collected packages: torch, torchvision
  Found existing installation: torch 1.7.0+cu101
    Uninstalling torch-1.7.0+cu101:
      Successfully uninstalled torch-1.7.0+cu101
  Found existing installation: torchvision 0.8.1+cu101
    Uninstalling torchvision-0.8.1+cu101:
      Successfully uninstalled torchvision-0.8.1+cu101
Successfully installed torch-1.4.0 torchvision-0.5.0


In [2]:
# Mount Google Drive: every data file, checkpoint and output path below lives
# under /gdrive/My Drive.
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt
from fastai.vision import *
from PIL import Image as PImage

In [4]:
# Re-import edited modules automatically before each cell and render
# matplotlib figures inline.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

## Model

In [5]:
# Dataset layout on Drive: original images plus their integer-encoded masks
# (the 'png_0_1' variant).
path = '/gdrive/My Drive/Labeling/Segmentation'
data_dir = os.path.join(path, 'Data')
image_path = os.path.join(data_dir, 'Original')
mask_path = os.path.join(data_dir, 'EnumMasks', 'png_0_1')


def get_image_label(x):
  """Return the ground-truth mask file for image path `x`: `<mask_path>/<stem>_P.png`."""
  mask_name = f'{x.stem}_P.png'
  return os.path.join(mask_path, mask_name)

In [None]:
# Working resolution: mask.shape[1:] drops the leading (channel) axis of the
# first ground-truth mask, leaving its native (height, width); the DataBunch
# below resizes to half of that.
img_f = get_image_files(image_path)[0]
mask = open_mask(get_image_label(img_f))
src_size = np.array(mask.shape[1:])
size = src_size // 2

bs = 1  # one image per batch

# Class names, read as whitespace-separated strings from codes.txt.
codes_file = os.path.join(path, 'codes.txt')
codes = np.loadtxt(codes_file, dtype=str)

In [None]:
# A DataBunch holding one known image (in the training set). unet_learner below
# needs a DataBunch for its classes and input pipeline; the trained weights are
# loaded from a checkpoint afterwards.
SAMPLE_IMAGE = '964868548s701ms.jpg'

np.random.seed(42)
data_test = (SegmentationItemList
            .from_folder(image_path)
            .filter_by_func(lambda fname: os.path.basename(fname) == SAMPLE_IMAGE)
            .split_none()
            .label_from_func(get_image_label, classes=codes)
            .transform(get_transforms(do_flip=False), size=size, tfm_y=True)
            # Use the configured batch size instead of repeating the literal 1.
            .databunch(bs=bs)
            .normalize(imagenet_stats))

In [None]:
# Count inputs and labels in each split of the one-image DataBunch.
train_x, train_y = data_test.train_ds.x, data_test.train_ds.y
valid_x, valid_y = data_test.valid_ds.x, data_test.valid_ds.y

print(len(train_x), len(train_y))
print(len(valid_x), len(valid_y))

1 1
1 1


In [None]:
def acc_segmentation(input, target):
  """Pixel accuracy: fraction of pixels whose argmax class equals the target.

  `input` holds per-class scores with the class axis at dim 1; `target` has a
  singleton axis at dim 1 that is squeezed away before the comparison.
  """
  labels = target.squeeze(1)
  # Labels are presumably all >= 0 (codes 0/1, no void code), so this keeps
  # every pixel; it mirrors the usual "ignore void pixels" pattern.
  keep = labels >= 0
  predicted = input.argmax(dim=1)
  hits = predicted[keep] == labels[keep]
  return hits.float().mean()


def dice_iou(input, target):
  """fastai's `dice` metric computed in IoU mode."""
  return dice(input, target, iou=True)


metrics = [acc_segmentation, dice_iou]

In [None]:
# U-Net with a ResNet-34 encoder. Presumably this must match the architecture
# the checkpoint loaded in the next cell was trained with — confirm.
learn = unet_learner(data_test, models.resnet34, metrics=metrics, wd=1e-2)

In [None]:
# Load the trained weights: checkpoint 'v4-stage-2' from the v4 folder on Drive.
learn.load(os.path.join(path, 'v4', 'v4-stage-2'))

## Create masks for all images in the dataset

In [6]:
all_image_path = '/gdrive/My Drive/Labeling/Model/Images'

# Per-image categories (presumably from the path-detection classifier); only
# images in the 'Path' category are processed below.
categories_csv = '/gdrive/My Drive/Labeling/PathDetectionClassifier/PredictImages/validated_image_categories_2.csv'
image_categories = pd.read_csv(categories_csv)
is_path_row = image_categories['category'].eq('Path')
path_category = image_categories.loc[is_path_row]
# np.unique de-duplicates and sorts, so images are processed in file-name order.
path_images = np.unique(path_category['image'])

# Output folders for the predicted masks, attention-region images and overlays.
predictions_root = '/gdrive/My Drive/Labeling/Segmentation/Predictions'
pred_path = f'{predictions_root}/Masks'
attention_path = f'{predictions_root}/AttentionRegion'
overlay_path = f'{predictions_root}/Overlay'


def _stem(filename):
  """File name without its final extension ('a.b.jpg' -> 'a.b')."""
  return os.path.splitext(filename)[0]


def get_mask(x):
  """Predicted-mask PNG path for image file name `x`."""
  return f'{pred_path}/{_stem(x)}_Pred.png'


def get_attention(x):
  """Attention-region PNG path (in the '255' sub-folder) for image file name `x`."""
  return f'{attention_path}/255/{_stem(x)}.png'


def get_overlay(x):
  """Overlay-figure PNG path for image file name `x`."""
  return f'{overlay_path}/{_stem(x)}_Overlay.png'

#### Make and save predictions

In [None]:
# Predict a mask for every 'Path' image. Save it as a 0/255 PNG at the original
# image size, and save a figure of the image with the mask overlaid. Images whose
# mask and overlay both already exist are skipped.

def _prediction_to_mask_pixels(class_map, n_classes, out_size, threshold=255/2):
  """Turn a per-pixel class-index map into a binary {0, 255} uint8 mask.

  Args:
    class_map: 2-D integer numpy array of predicted class indices (argmax output).
    n_classes: number of classes the model predicts. Indices are scaled so the
      largest possible index maps to 255 (for 2 classes: 0 -> 0, 1 -> 255).
    out_size: (width, height) to resize the mask to, as PIL expects.
    threshold: resizing interpolates grey values along class borders. Pixels
      above this value become 255, the rest 0.

  Returns:
    2-D uint8 numpy array of shape (height, width) containing only 0 and 255.
  """
  # Scale by the largest possible class index, not this map's own min/max, so an
  # all-class-0 prediction maps cleanly to 0 instead of dividing by zero (NaN).
  scale = 255.0 / max(n_classes - 1, 1)
  scaled = np.round(class_map.astype(np.float64) * scale).astype(np.uint8)
  resized = PImage.fromarray(scaled).resize(out_size)
  return np.where(np.array(resized) > threshold, 255, 0).astype(np.uint8)


for count, filename in enumerate(path_images, start=1):
  if count % 10 == 0:
    print(count, '/', len(path_images))

  infile = os.path.join(all_image_path, filename)
  if os.path.isdir(infile) or os.path.splitext(infile)[1] != '.jpg':
    print('Invalid image:', infile)
    continue

  mask_file = get_mask(filename)
  overlay_file = get_overlay(filename)
  if os.path.isfile(mask_file) and os.path.isfile(overlay_file):
    print('Mask and overlay already generated for', filename)
    continue

  # Make prediction; pred[2] holds the per-class outputs with the class axis first.
  img = open_image(infile)
  pred = learn.predict(img)
  class_map = pred[2].argmax(dim=0).cpu().numpy()

  # The model predicts at the DataBunch's reduced size. Resize back to this
  # image's own (height, width) rather than a fixed 2x, so the mask lines up
  # pixel-for-pixel with the source image in the attention-region step.
  img_height, img_width = img.shape[-2:]
  pixels = _prediction_to_mask_pixels(class_map, pred[2].shape[0], (img_width, img_height))

  # Save mask
  os.makedirs(os.path.dirname(mask_file), exist_ok=True)
  PImage.fromarray(pixels).save(mask_file)

  # Save image with overlay; close the figure so thousands of iterations don't
  # accumulate open figures.
  m = open_mask(mask_file)
  img.show(y=m, figsize=(10, 10))
  fig = plt.gcf()
  os.makedirs(os.path.dirname(overlay_file), exist_ok=True)
  fig.savefig(overlay_file, bbox_inches='tight', pad_inches=0)
  plt.close(fig)

#### Add attention region to path images using the masks

In [7]:
# Per-overlay 'Value' ratings. Only 0 or 1 counts as a clearly predicted path;
# the next cell skips every image not listed here.
overlay_images = pd.read_csv(os.path.join(path, 'Predictions', 'overlay_images.csv'))
has_clear_path = overlay_images['Value'].isin([0, 1])
valid_images = np.unique(overlay_images.loc[has_clear_path, 'Image'])

In [8]:
# Build the attention-region images. For each image with a clearly predicted
# path, copy it and paint white every pixel its predicted mask marks as 255.
# Images whose attention image already exists are skipped.

# Set lookup is O(1); `in` on the numpy array scanned every entry for each image.
valid_image_set = set(valid_images)

for count, filename in enumerate(path_images, start=1):
  if count % 10 == 0:
    print(count, '/', len(path_images))

  if filename not in valid_image_set:
    print('Image', filename, 'did not have a clearly predicted path.')
    continue

  attention_file = get_attention(filename)
  if os.path.isfile(attention_file):
    print('Attention region already generated for', filename)
    continue

  infile = os.path.join(all_image_path, filename)
  if os.path.isdir(infile) or os.path.splitext(infile)[1] != '.jpg':
    print('Invalid image:', infile)
    continue

  # Context managers close both files even if reading one of them fails.
  with PImage.open(infile) as im, PImage.open(get_mask(filename)) as mask:
    im_pixels = np.array(im)
    # Boolean indexing below needs the mask to be exactly the image's size.
    # NEAREST resampling keeps a mismatched mask strictly 0/255.
    if mask.size != im.size:
      mask = mask.resize(im.size, PImage.NEAREST)
    mask_pixels = np.array(mask.convert('L'))

  # 255 in the mask is presumably the non-path ("other") class — confirm
  # against codes.txt. Those pixels are painted white (255); painting them
  # black (0) instead is an experiment worth trying.
  is_other = (mask_pixels == 255)
  im_pixels[is_other] = 255

  os.makedirs(os.path.dirname(attention_file), exist_ok=True)
  PImage.fromarray(im_pixels).save(attention_file)

Image 964868548s701ms.jpg did not have a clearly predicted path.
Image 964868549s702ms.jpg did not have a clearly predicted path.
Image 964868568s721ms.jpg did not have a clearly predicted path.
Image 964868569s722ms.jpg did not have a clearly predicted path.
Image 964868571s724ms.jpg did not have a clearly predicted path.
Image 964868572s725ms.jpg did not have a clearly predicted path.
Image 964868573s726ms.jpg did not have a clearly predicted path.
Image 964868575s728ms.jpg did not have a clearly predicted path.
Image 964868576s729ms.jpg did not have a clearly predicted path.
10 / 7070
Image 964868577s730ms.jpg did not have a clearly predicted path.
Image 964868578s697ms.jpg did not have a clearly predicted path.
Image 964868620s706ms.jpg did not have a clearly predicted path.
Image 964868622s708ms.jpg did not have a clearly predicted path.
Image 964868624s710ms.jpg did not have a clearly predicted path.
Image 964868627s713ms.jpg did not have a clearly predicted path.
Image 964868628