<a href="https://colab.research.google.com/github/karencfisher/COVID19/blob/main/notebooks/Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import drive
drive.mount('/content/drive')

BASE_PATH = '/content/drive/MyDrive/COVID-19_Radiography_Dataset'
 


!wget https://raw.githubusercontent.com/karencfisher/COVID19/main/tools/util.py

Mounted at /content/drive
--2021-09-01 21:41:24--  https://raw.githubusercontent.com/karencfisher/COVID19/main/tools/util.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5910 (5.8K) [text/plain]
Saving to: ‘util.py’


2021-09-01 21:41:24 (49.1 MB/s) - ‘util.py’ saved [5910/5910]



In [8]:
import os
import random
import shutil
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import cv2
import pandas as pd

from sklearn.metrics import confusion_matrix, roc_curve, roc_auc_score, classification_report
from sklearn.utils import class_weight

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, MaxPool2D, BatchNormalization, Input
from tensorflow.keras.layers import Conv2DTranspose, Concatenate, Activation
from tensorflow.keras.preprocessing.image import ImageDataGenerator, load_img, img_to_array
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras import backend as K

import util

### Define the U-Net building blocks

In [12]:
def conv_block(input, num_filters):
  """Apply two stacked 3x3 Conv2D -> BatchNormalization -> ReLU stages.

  Args:
    input: Keras tensor to transform.
    num_filters: number of filters used by both convolutions.

  Returns:
    Keras tensor with `num_filters` channels; 'same' padding with the default
    stride of 1 keeps the spatial size unchanged.
  """
  features = input
  for _stage in range(2):
    features = Conv2D(num_filters, 3, padding='same')(features)
    features = BatchNormalization()(features)
    features = Activation('relu')(features)
  return features

def encoder_block(input, num_filters):
  """One U-Net encoder stage: conv_block followed by 2x2 max pooling.

  Args:
    input: Keras tensor from the previous stage (or the model input).
    num_filters: number of filters for this stage's conv_block.

  Returns:
    Tuple (skip, pooled). `skip` holds the features before pooling and is
    kept for the decoder's skip connection. `pooled` is the 2x-downsampled
    tensor passed to the next stage.
  """
  skip = conv_block(input, num_filters)
  pooled = MaxPool2D((2, 2))(skip)
  return skip, pooled

def decoder_block(input, skip_features, num_filters):
  """One U-Net decoder stage: upsample, merge the skip connection, convolve.

  Args:
    input: Keras tensor from the deeper stage.
    skip_features: matching encoder output to concatenate after upsampling.
    num_filters: number of filters for the transposed conv and the conv_block.

  Returns:
    Keras tensor with `num_filters` channels at twice the input's spatial size.
  """
  upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(input)
  # Upsampled features come first, followed by the encoder skip features.
  merged = Concatenate()([upsampled, skip_features])
  return conv_block(merged, num_filters)

### Build U-Net model

In [25]:
def build_unet(input_shape, num_layers, min_num_filters=64):
  """Build a U-Net with a single sigmoid output channel.

  The encoder has `num_layers` stages whose filter counts double starting
  from `min_num_filters` (64, 128, 256, ... by default). A bottleneck
  conv_block uses `min_num_filters * 2**num_layers` filters. The decoder
  mirrors the encoder: it halves the filter count at each stage and
  concatenates the matching encoder skip connection. A final 1x1 convolution
  followed by a sigmoid produces the output.

  Args:
    input_shape: shape of one input sample, e.g. (224, 224, 3). Each spatial
      dimension should be divisible by 2**num_layers. Otherwise an upsampled
      tensor and its skip connection will not have matching sizes.
    num_layers: number of encoder (and decoder) stages; must be >= 1.
    min_num_filters: filter count of the first encoder stage.

  Returns:
    A tf.keras Model named 'U-Net'.

  Raises:
    ValueError: if num_layers < 1.
  """
  if num_layers < 1:
    # With zero stages, the first encoder block would still downsample but no
    # decoder stage would upsample back, silently yielding a half-size output.
    raise ValueError(f'num_layers must be >= 1, got {num_layers}')

  inputs = Input(input_shape)
  x = inputs
  skip_features = []
  for i in range(num_layers):
    skip, x = encoder_block(x, min_num_filters * 2 ** i)
    skip_features.append(skip)

  num_filters = min_num_filters * 2 ** num_layers
  x = conv_block(x, num_filters)

  # Consume the skips deepest-first. Integer division keeps the filter counts
  # as ints (`/=` would pass floats such as 512.0 to the conv layers).
  for skip in reversed(skip_features):
    num_filters //= 2
    x = decoder_block(x, skip, num_filters)

  x = Conv2D(1, 1, padding='same')(x)
  outputs = Activation('sigmoid')(x)
  return Model(inputs, outputs, name='U-Net')
   

In [26]:
# 224x224 RGB input with four encoder/decoder stages (224 / 2**4 = 14 at the bottleneck).
input_shape = (224, 224, 3)
model = build_unet(input_shape, num_layers=4)
model.summary()

Model: "U-Net"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_8 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv2d_133 (Conv2D)             (None, 224, 224, 64) 1792        input_8[0][0]                    
__________________________________________________________________________________________________
batch_normalization_126 (BatchN (None, 224, 224, 64) 256         conv2d_133[0][0]                 
__________________________________________________________________________________________________
activation_133 (Activation)     (None, 224, 224, 64) 0           batch_normalization_126[0][0]    
______________________________________________________________________________________________