In [8]:
# Colab only: mount Google Drive into the runtime's filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [18]:
%%capture
!pip install flask-ngrok
!pip install flask_cors
!pip install efficientnet_pytorch
!pip install ratelim

In [19]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import cycle
from scipy import interp
import pickle
from collections import defaultdict, OrderedDict
from tqdm.notebook import tqdm
import warnings
from IPython.display import clear_output
import ratelim


from sklearn.metrics import roc_curve, auc, roc_auc_score, confusion_matrix
from sklearn.model_selection import train_test_split


import skimage.transform
import cv2
from PIL import Image
import scipy.ndimage as ndimage
import scipy.ndimage.filters as filters


import albumentations as A
import albumentations.augmentations.functional as F
import albumentations.augmentations.transforms as T


import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.autograd import Function
from torchvision import models, utils
from efficientnet_pytorch import EfficientNet


from flask_cors import CORS, cross_origin
from flask import Flask, jsonify, request
from flask_ngrok import run_with_ngrok


In [10]:
# Inference device: CUDA when a GPU is visible to torch, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Output labels shared by both classifiers. The order must match the model's
# output columns, because predictions are read back by boolean-masking this array.
CLASS_NAMES = np.array([
    'Atelectasis', 'Cardiomegaly', 'Effusion', 'Infiltration',
    'Mass', 'Nodule', 'Pneumonia', 'Pneumothorax',
    'Consolidation', 'Edema', 'Emphysema', 'Fibrosis',
    'Pleural_Thickening', 'Hernia',
])
N_CLASSES = CLASS_NAMES.size


class DenseNet121(nn.Module):
    """DenseNet-121 with a multi-label classification head.

    The torchvision DenseNet-121 (constructed with ``pretrained=True``) is kept
    as-is except for its ``classifier``, which is replaced by a linear layer
    producing ``out_size`` scores followed by a sigmoid that squashes each score
    into (0, 1) independently.

    Args:
        out_size: number of output classes.
    """

    def __init__(self, out_size):
        super().__init__()
        backbone = torchvision.models.densenet121(pretrained=True)
        in_features = backbone.classifier.in_features
        backbone.classifier = nn.Sequential(
            nn.Linear(in_features, out_size),
            nn.Sigmoid(),
        )
        # Registered under the same attribute name so checkpoint keys still match.
        self.densenet121 = backbone

    def forward(self, x):
        return self.densenet121(x)

torch.cuda.empty_cache()
model = DenseNet121(N_CLASSES).to(device)

torch.cuda.empty_cache()
CKPT_PATH = '/models/densnet_best.pth'


if os.path.isfile(CKPT_PATH):
    print("=> loading checkpoint")
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on GPU also loads on a CPU-only runtime.
    state_dict = torch.load(CKPT_PATH, map_location=device)
    # load_state_dict copies into the existing parameters, which already live
    # on `device`, so no further .to(device) is needed.
    model.load_state_dict(state_dict)
    print("=> loaded checkpoint")
else:
    print("=> no checkpoint found")
    # Without the checkpoint the nn.Linear head is freshly initialised, so any
    # prediction served by this model is meaningless; make that loud.
    warnings.warn('DenseNet121 checkpoint not found at {}; the classifier head is '
                  'untrained and its predictions are meaningless.'.format(CKPT_PATH))

model.eval()
print('densnet is ready')

=> loading checkpoint
=> loaded checkpoint
densnet is ready


In [66]:
class EfficientNetModel(nn.Module):
    """EfficientNet-B3 with a multi-label classification head.

    Loads ``efficientnet-b3`` via ``EfficientNet.from_pretrained`` and replaces
    its final ``_fc`` layer with a linear layer producing ``out_size`` scores
    followed by a sigmoid that squashes each score into (0, 1) independently.

    Args:
        out_size: number of output classes.
    """

    def __init__(self, out_size):
        super().__init__()
        backbone = EfficientNet.from_pretrained('efficientnet-b3')
        head = nn.Sequential(
            nn.Linear(in_features=backbone._fc.in_features,
                      out_features=out_size,
                      bias=True),
            nn.Sigmoid(),
        )
        backbone._fc = head
        # Registered under the same attribute name so checkpoint keys still match.
        self.efficient_net = backbone

    def forward(self, x):
        return self.efficient_net(x)

model_net = EfficientNetModel(N_CLASSES)
model_net.to(device)

torch.cuda.empty_cache()
CKPT_PATH = '/models/efficientnet_best.pth'


if os.path.isfile(CKPT_PATH):
    print("=> loading checkpoint")
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint saved on GPU also loads on a CPU-only runtime.
    state_dict = torch.load(CKPT_PATH, map_location=device)
    model_net.load_state_dict(state_dict)
    print("=> loaded checkpoint")
else:
    print("=> no checkpoint found")
    # Without the checkpoint the replaced _fc head is freshly initialised, so
    # any prediction served by this model is meaningless; make that loud.
    warnings.warn('EfficientNet checkpoint not found at {}; the classifier head is '
                  'untrained and its predictions are meaningless.'.format(CKPT_PATH))

model_net.eval()
print('efficientnet is ready')

Loaded pretrained weights for efficientnet-b3
=> loading checkpoint
=> loaded checkpoint
efficientnet is ready


In [67]:
# Per-channel mean/std normalisation (the commonly used ImageNet statistics).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Test-time augmentation: resize, cut ten 224x224 crops, then turn the crops
# into one stacked (10, C, 224, 224) tensor and normalise each crop.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.TenCrop(224),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(piece) for piece in crops])),
    transforms.Lambda(lambda crops: torch.stack([normalize(piece) for piece in crops])),
])

In [68]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    pickle can execute arbitrary code while loading, so only use this on
    files produced locally, never on uploaded or downloaded data.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Decision thresholds compared against each model's averaged outputs at
# prediction time (presumably one value per entry of CLASS_NAMES; TODO confirm).
# Paths are relative to the notebook's working directory.
threshold_sc_14 = _load_pickle('threshold_sc_14_densnet.pkl')
threshold_sc_14_net = _load_pickle('threshold_sc_14_effnet.pkl')

In [69]:
app = Flask(__name__)
cors = CORS(app)
# NOTE(review): flask_cors documents CORS_ALLOW_HEADERS for this purpose;
# confirm whether CORS_HEADERS has any effect here.
app.config.update(CORS_HEADERS='Content-Type')

# Patch app.run() so it also starts an ngrok tunnel to the local server.
run_with_ngrok(app)

def transform_image(image_bytes, transform):
    """Apply ``transform`` to a decoded image and add a leading batch axis.

    Args:
        image_bytes: decoded image as an ``np.ndarray`` accepted by
            ``PIL.Image.fromarray`` (despite the name, not raw file bytes).
        transform: callable applied to the RGB PIL image; must return a tensor.

    Returns:
        The transformed tensor with a new leading dimension of size 1.
    """
    pil_image = Image.fromarray(image_bytes)
    rgb_image = pil_image.convert('RGB')
    batched = transform(rgb_image).unsqueeze(0)
    return batched

def get_prediction(model, image_bytes, transform, threshold):
    """Return the class names whose crop-averaged score exceeds ``threshold``.

    ``transform`` turns the image into a stack of crops. Every crop is scored
    by ``model`` in a single batch, and the per-class scores are averaged over
    the crops before thresholding.

    Args:
        model: network mapping an ``(N, C, H, W)`` batch to one score per class
            per row; it should already be on ``device`` and in eval mode.
        image_bytes: decoded image as an ``np.ndarray`` (despite the name, not
            raw file bytes) accepted by ``PIL.Image.fromarray``.
        transform: callable mapping a PIL image to an ``(n_crops, C, H, W)``
            tensor.
        threshold: value(s) compared element-wise with the averaged scores;
            presumably one per entry of ``CLASS_NAMES`` (TODO confirm).

    Returns:
        ``np.ndarray`` of the ``CLASS_NAMES`` entries predicted present
        (possibly empty).
    """
    tensor = transform_image(image_bytes=image_bytes, transform=transform)
    bs, n_crops, c, h, w = tensor.shape
    input_var = tensor.view(-1, c, h, w).to(device)
    # Inference only: disable autograd so no graph is built and kept alive
    # for the n_crops forward passes on every request.
    with torch.no_grad():
        output = model(input_var)
    output_mean = output.view(bs, n_crops, -1).mean(1)
    # [0]: transform_image always produces a batch containing exactly one image.
    activated_classes = (output_mean.cpu().numpy() > threshold)[0]
    return CLASS_NAMES[activated_classes]

def _decode_upload(upload):
    """Decode one uploaded file into an RGB ``np.ndarray``; None if undecodable."""
    raw = upload.read()
    if not raw:
        return None
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    # IMREAD_COLOR always yields 3-channel 8-bit BGR: grayscale scans are
    # replicated across channels and any alpha channel is dropped.
    bgr = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    if bgr is None:
        return None
    # OpenCV orders channels BGR, while the PIL pipeline downstream expects RGB.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _predict_uploaded_files(net, threshold):
    """Classify every file in the ``file`` form field of the current request.

    Response JSON:
      * several files: ``{'Labels': [...], 'id': [...]}``, with one entry per
        file in both lists, so ``Labels[i]`` always describes ``id[i]``.
        An empty string means nothing exceeded the threshold.
      * one file: ``id`` is the bare filename. ``Labels`` holds the joined
        findings, or stays empty when nothing exceeded the threshold.
      * no files: ``{'Labels': [], 'id': []}``.
    An undecodable upload yields HTTP 400 with an ``error`` message.
    """
    uploads = request.files.getlist('file')
    predict_dict = {'Labels': [], 'id': []}
    multiple = len(uploads) > 1
    for upload in uploads:
        img = _decode_upload(upload)
        if img is None:
            return jsonify({'error': 'could not decode image: {}'.format(upload.filename)}), 400
        class_names = get_prediction(model=net,
                                     image_bytes=img,
                                     transform=transform,
                                     threshold=threshold)
        label = ', '.join(class_names.tolist())
        if multiple:
            # Always append, even when empty, to keep Labels index-aligned with id.
            predict_dict['Labels'].append(label)
            predict_dict['id'].append(upload.filename)
        else:
            # Single-file responses keep their original shape.
            if label:
                predict_dict['Labels'].append(label)
            predict_dict['id'] = upload.filename
    return jsonify(predict_dict)


@app.route('/submit_densnet', methods=['POST'])
@ratelim.greedy(10, 5)
@cross_origin()
def predict_densnet():
    """POST /submit_densnet: classify uploaded images with the DenseNet model."""
    return _predict_uploaded_files(model, threshold_sc_14)


@app.route('/submit_effnet', methods=['POST'])
@ratelim.greedy(10, 5)
@cross_origin()
def predict_effnet():
    """POST /submit_effnet: classify uploaded images with the EfficientNet model."""
    return _predict_uploaded_files(model_net, threshold_sc_14_net)

In [70]:
# Start the Flask development server; this call blocks the cell until it is
# interrupted. Because run_with_ngrok(app) was applied when the app was
# created, running the app also opens a public ngrok tunnel.
app.run()

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)


 * Running on http://dc4c11f6a859.ngrok.io
 * Traffic stats available on http://127.0.0.1:4040


127.0.0.1 - - [19/Oct/2020 14:56:37] "[37mPOST /submit_densnet HTTP/1.1[0m" 200 -
127.0.0.1 - - [19/Oct/2020 14:56:42] "[37mPOST /submit_effnet HTTP/1.1[0m" 200 -
127.0.0.1 - - [19/Oct/2020 14:59:08] "[37mPOST /submit_densnet HTTP/1.1[0m" 200 -
127.0.0.1 - - [19/Oct/2020 14:59:24] "[37mPOST /submit_effnet HTTP/1.1[0m" 200 -
