In [1]:
from flask import Flask,request,jsonify,render_template,redirect
import io
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
# from torch_utils import transform_image,get_prediction

In [2]:
# Human-readable class names; the route looks them up by the predicted index.
LABELS = [
    "Bacterial",
    "Normal",
    "Viral",
]

In [3]:
# Flask application object; the upload route below is registered on it.
app = Flask(__name__)

In [4]:
# Architecture to build; must be one of the names handled by initialize_model.
model_name = "densenet"
# Size of the replacement classification head (should equal len(LABELS)).
num_classes = 3
# True -> freeze the backbone's existing parameters before attaching the head.
feature_extract = True

def set_parameter_requires_grad(model, feature_extracting):
    """Freeze every parameter of ``model`` when ``feature_extracting`` is truthy.

    Args:
        model: Any ``torch.nn.Module``; its ``parameters()`` are updated in place.
        feature_extracting: When falsy, the model is left untouched.
    """
    if not feature_extracting:
        return
    for parameter in model.parameters():
        parameter.requires_grad = False

def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a torchvision model whose classification head outputs ``num_classes`` values.

    When ``feature_extract`` is True the backbone's existing parameters are
    frozen with ``set_parameter_requires_grad`` *before* the new head is
    attached, so only the replacement layer(s) remain trainable.

    Args:
        model_name: One of ``"resnet"``, ``"alexnet"``, ``"vgg"``,
            ``"squeezenet"``, ``"densenet"`` or ``"inception"``.
        num_classes: Number of outputs of the replacement head.
        feature_extract: Freeze the pre-existing backbone parameters if True.
        use_pretrained: Forwarded to the torchvision constructor as ``pretrained``.

    Returns:
        ``(model_ft, input_size)``: the model and the square input side length
        recorded for that architecture.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "resnet":
        # ResNet-18: swap the final fully connected layer.
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        # NOTE(review): alexnet, vgg, squeezenet and densenet all use 224 here;
        # confirm that 280 for ResNet-18 is intentional and not a typo.
        input_size = 280

    elif model_name == "alexnet":
        # AlexNet: swap the last Linear layer of the classifier Sequential.
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "vgg":
        # VGG-11 with batch norm: swap the last Linear layer of the classifier.
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "squeezenet":
        # SqueezeNet 1.0 classifies with a 1x1 convolution rather than a Linear.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224

    elif model_name == "densenet":
        # DenseNet-121: the classifier is a single Linear layer.
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224

    elif model_name == "inception":
        # Inception v3 expects 299x299 images and has an auxiliary classifier,
        # whose head must be resized as well as the main one.
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        # Raise instead of print() + exit(): exit() raises SystemExit, which can
        # terminate the process (or kernel) hosting this notebook/server.
        raise ValueError(
            f"Invalid model name {model_name!r}; expected one of "
            "resnet, alexnet, vgg, squeezenet, densenet, inception"
        )

    return model_ft, input_size


# Build the architecture only: load_state_dict below is strict by default, so a
# successful load overwrites every parameter and buffer. Downloading ImageNet
# weights first (use_pretrained=True) was wasted network I/O.
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=False)

PATH = "Pattern_model/densenet.pt"

# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only host;
# nothing in this notebook moves the model off the CPU.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
model_ft.load_state_dict(torch.load(PATH, map_location="cpu"))
model_ft.eval()

DenseNet(
  (features): Sequential(
    (conv0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (norm0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu0): ReLU(inplace=True)
    (pool0): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (denseblock1): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (norm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      )
      (denselayer2): _DenseLayer(
        (norm1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu

In [5]:
def transform_image(image_bytes, size=224):
    """Decode raw image bytes into a normalized ``(1, 3, size, size)`` tensor.

    Args:
        image_bytes: Encoded image file contents (any format PIL can open).
        size: Side length the image is resized to. The default of 224 matches
            the previous hard-coded value; pass the ``input_size`` returned by
            ``initialize_model`` for other architectures (e.g. 299 for inception).

    Returns:
        A float tensor with a leading batch dimension of 1.

    Raises:
        Whatever ``PIL.Image.open`` raises for bytes it cannot decode.
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        # ImageNet per-channel mean / std used by the torchvision backbones.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    with Image.open(io.BytesIO(image_bytes)) as image:
        # Force 3 channels so grayscale / RGBA uploads match the model input.
        rgb_image = image.convert("RGB")
    return transform(rgb_image).unsqueeze(0)


def get_prediction(image_tensor, model=None):
    """Return the arg-max class index predicted for ``image_tensor``.

    Args:
        image_tensor: Batch of model inputs, presumably ``(N, C, H, W)`` as
            produced by ``transform_image`` (N == 1 for a single upload).
        model: Network to evaluate; defaults to the module-level ``model_ft``.

    Returns:
        For a single-sample batch, a 0-dim integer tensor (``.item()`` gives the
        class index, exactly as before). For ``N > 1``, an ``(N,)`` tensor with
        one class index per sample.
    """
    net = model_ft if model is None else model
    with torch.no_grad():
        logits = net(image_tensor)
    # Reduce over the class dimension only: a bare argmax() flattens the whole
    # batch and returns an index into N * num_classes values when N > 1.
    predictions = logits.argmax(dim=-1)
    if predictions.numel() == 1:
        # Keep the original 0-dim result for the single-image case.
        return predictions.reshape(())
    return predictions

In [6]:
@app.route('/',methods=['GET','POST'])
def index():
    """Render the upload page (GET) or classify an uploaded image (POST).

    POST expects a multipart form field named ``file``. On success the page is
    re-rendered with the predicted label and a path of the form
    ``../chest_xray/test/<LABEL>/<uploaded filename>``.

    Returns:
        The rendered ``index.html`` template, or a JSON ``{'error': ...}`` body
        with status 400 when no file is supplied, or 500 when decoding or
        inference fails.
    """
    if request.method != 'POST':
        return render_template('index.html')

    file = request.files.get('file')
    if file is None or file.filename == "":
        # Client error: report it with a 4xx status instead of 200.
        return jsonify({'error': 'no file'}), 400

    try:
        tensor = transform_image(image_bytes=file.read())
        label = LABELS[get_prediction(tensor).item()]
    except Exception:
        # This handler used to be commented out, so any undecodable upload
        # surfaced as an unhandled error page; log the traceback, answer in JSON.
        app.logger.exception("prediction failed for upload %r", file.filename)
        return jsonify({'error': 'error during prediction'}), 500

    # NOTE(review): file.filename is client-supplied and is not sanitized here;
    # it is only passed to the template.
    filename = '../chest_xray/test/' + label.upper() + '/' + str(file.filename)
    return render_template('index.html', prediction=label, filename=filename)

In [7]:
if __name__ == "__main__":
    # Start Flask's built-in development server; it serves until interrupted.
    app.run()

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:5000/ (Press CTRL+C to quit)
127.0.0.1 - - [28/Oct/2020 00:31:07] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:31:14] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:31:20] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:31:31] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:31:59] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:32:07] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:32:16] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:32:26] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:32:43] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:32:49] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:32:58] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:33:05] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:33:32] "[37mPOST / HTTP/1.1[0m" 200 -
127.0.0.1 - - [28/Oct/2020 00:3