In [1]:
import statistics
import gradio as gr
import os
import json
import cv2    

import torch
assert torch.cuda.is_available(), "CUDA not available"
import ultralytics
from ultralytics import YOLO, settings

In [2]:
# Checkpoint from the training run; `model` is loaded once here and reused by
# every scan triggered from the UI below.
weights = "runs/detect/train15/weights/best.pt"
model = YOLO(model=weights, task="detect")

In [3]:
def flip_image(image):
    """Build the four augmented views used for inference.

    Returns a list, in this order: the untouched image, the image flipped
    around the x-axis (cv2 flipCode 0), flipped around the y-axis
    (flipCode 1), and flipped around both axes.
    """
    around_x = cv2.flip(image, 0)
    around_y = cv2.flip(image, 1)
    # Flipping the y-axis view around the x-axis yields the both-axes view.
    around_both = cv2.flip(around_y, 0)
    return [image, around_x, around_y, around_both]

def pred_to_image(pred):
    """Read back the annotated image saved for one prediction, as an RGB array.

    ``pred`` is a result from a ``model.predict(..., save=True)`` call; the
    annotated copy is looked up in ``pred.save_dir`` under the basename of
    ``pred.path``.

    Raises:
        FileNotFoundError: if no readable image exists at that location (for
            example when the prediction was made without ``save=True``).
    """
    # Use only the basename: if ``pred.path`` were absolute, os.path.join would
    # discard save_dir and silently return the original, un-annotated file.
    img_path = os.path.join(pred.save_dir, os.path.basename(pred.path))
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read saved prediction image: {img_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

def _summarize_class_scores(class_scores, views):
    """Return ``(count, max, mean, weighted_avg)`` for one class's box confidences.

    ``weighted_avg`` is ``(count / views) * max``. An empty ``class_scores``
    yields zeros, and ``views == 0`` yields a weighted average of 0.0 instead
    of raising ZeroDivisionError.
    """
    count = len(class_scores)
    top = max(class_scores, default=0.0)
    mean = statistics.mean(class_scores) if class_scores else 0.0
    weighted_avg = (count / views) * top if views else 0.0
    return count, top, mean, weighted_avg


def preds_to_scores(predictions, pos_threshold=0.25, min_pos_boxes=2):
    """Aggregate detections from the negative and positive passes into scores.

    ``predictions`` is the negative-pass results followed by the positive-pass
    results over the same augmented views, so half its length is taken as the
    number of views. Boxes are bucketed by class id: 0 = negative, 1 = positive.

    Args:
        predictions: sequence of detection results; each exposes ``.boxes``,
            whose items have ``.cls`` (class id) and ``.conf`` (confidence).
        pos_threshold: a positive box with confidence strictly above this makes
            the diagnosis positive.
        min_pos_boxes: otherwise the diagnosis is positive when at least this
            many positive boxes were found (capped at the number of views, and
            never fewer than one box).

    Returns:
        dict with ``views``; ``neg_*`` and ``pos_*`` entries for ``count``,
        ``max``, ``avg`` and ``weighted_avg`` (count / views * max); and
        ``diagnosis`` (1 = tumor found, 0 = none found).
    """
    scores_by_class = [[], []]  # [negatives, positives]
    for pred in predictions:
        for box in pred.boxes:
            scores_by_class[int(box.cls)].append(float(box.conf))

    views = len(predictions) // 2

    neg_count, neg_max, neg_avg, neg_weighted_avg = _summarize_class_scores(scores_by_class[0], views)
    pos_count, pos_max, pos_avg, pos_weighted_avg = _summarize_class_scores(scores_by_class[1], views)

    # Always require at least one positive box: with views == 0 a bare
    # `min(min_pos_boxes, views)` is 0, which would flag every empty scan.
    required_boxes = max(1, min(min_pos_boxes, views))
    diagnosis = 1 if pos_max > pos_threshold or pos_count >= required_boxes else 0

    return {
        'views': views,
        'neg_count': neg_count,
        'neg_max': neg_max,
        'neg_avg': neg_avg,
        'neg_weighted_avg': neg_weighted_avg,
        'pos_count': pos_count,
        'pos_max': pos_max,
        'pos_avg': pos_avg,
        'pos_weighted_avg': pos_weighted_avg,
        'diagnosis': diagnosis
    }

def multi_step_inference(model, image, conf_neg, conf_pos, proj, iou=0.4):
    """Run a class-0 pass and a class-1 pass over four flipped views of ``image``.

    Args:
        model: loaded ultralytics ``YOLO`` detector.
        image: image array to scan. For numpy inputs ultralytics expects
            OpenCV's BGR channel order.
        conf_neg: confidence threshold for the class-0 (negative) pass.
        conf_pos: confidence threshold for the class-1 (positive) pass.
        proj: ``project`` directory the annotated images are saved under
            (run names ``negatives`` and ``positives``).
        iou: NMS IoU threshold applied in both passes.

    Returns:
        ``(imgs, scores)``: the annotated views as RGB arrays, negative-pass
        views first and then positive-pass views, and the dict produced by
        ``preds_to_scores``.
    """
    views = flip_image(image)

    # Separate passes let each class use its own confidence threshold.
    # NOTE(review): with a fixed `name` and no `exist_ok=True`, ultralytics
    # likely creates a new incremented run dir per call (negatives2, ...);
    # confirm and consider cleanup for a long-running app.
    negatives = model.predict(views, conf=conf_neg, classes=[0], iou=iou, save=True, project=proj, name='negatives')
    positives = model.predict(views, conf=conf_pos, classes=[1], iou=iou, save=True, project=proj, name='positives')

    # Order matters: callers split the returned images in half (negatives first).
    results = negatives + positives
    scores = preds_to_scores(results)
    imgs = [pred_to_image(pred) for pred in results]

    return (imgs, scores)

In [4]:
def diagnose_patient(upload_filepath, conf_neg, conf_pos):
    """Gradio callback: scan an uploaded image and return scores, verdict, galleries.

    Args:
        upload_filepath: path of the uploaded file from the ``gr.File`` input;
            presumably ``None`` when nothing has been uploaded.
        conf_neg: confidence threshold for the negative (class 0) pass.
        conf_pos: confidence threshold for the positive (class 1) pass.

    Returns:
        ``[scores, diagnosis, neg_images, pos_images]``: the score dict, a
        verdict string, and the annotated views from the negative and the
        positive pass respectively.

    Raises:
        gr.Error: if no file was uploaded or it cannot be decoded as an image,
            so the user sees a readable message instead of a traceback.
    """
    if upload_filepath is None:
        raise gr.Error("Please upload an MRI image before scanning.")

    # Keep OpenCV's native BGR order: ultralytics treats numpy inputs as BGR,
    # so converting to RGB here would feed the model channel-swapped pixels.
    image = cv2.imread(upload_filepath)
    if image is None:
        raise gr.Error("Could not read the uploaded file as an image.")

    images, scores = multi_step_inference(model, image, conf_neg, conf_pos, "runs/detect/final")

    diagnosis = (
        "Unfortunately, we found a tumor :("
        if scores["diagnosis"]
        else "Congratulations, no tumor was detected!"
    )

    # First half comes from the negative pass, second half from the positive pass.
    half = len(images) // 2
    return [scores, diagnosis, images[:half], images[half:]]

with gr.Blocks(fill_height=True) as app:
    gr.Markdown("Get your free diagnosis here!")
    # The verdict comes from an automated detector; make that explicit to users.
    gr.Markdown(
        "_Research demo only: this tool is not a medical device and its output "
        "is not a medical diagnosis. Please consult a qualified doctor._"
    )

    with gr.Row(equal_height=True):
        with gr.Column():
            # type="filepath" pins the callback argument to a path string, which
            # diagnose_patient hands straight to cv2.imread.
            upload = gr.File(
                label="Upload an MRI of your brain",
                file_types=["image"],
                type="filepath",
            )

            conf_neg = gr.Slider(0, 1, value=0.5, label="Neg conf")
            conf_pos = gr.Slider(0, 1, value=0.25, label="Pos conf")

            btn = gr.Button(value="Scan")

        with gr.Column():
            scores = gr.Textbox(label="Scores")
            diagnosis = gr.Textbox(label="Diagnosis")

    with gr.Row(equal_height=True):
        with gr.Column():
            neg_gallery = gr.Gallery(
                label="Looking for negatives",
                columns=[2],
                rows=[2],
                object_fit="contain",
                height="auto"
            )

        with gr.Column():
            pos_gallery = gr.Gallery(
                label="Looking for positives",
                columns=[2],
                rows=[2],
                object_fit="contain",
                height="auto"
            )

    # Output order must match the list returned by diagnose_patient.
    btn.click(
        fn=diagnose_patient,
        inputs=[upload, conf_neg, conf_pos],
        outputs=[scores, diagnosis, neg_gallery, pos_gallery]
    )


In [5]:
# Shut down Gradio servers left over from earlier runs of this notebook so that
# re-executing the UI cells does not keep stacking apps on new ports.
gr.close_all()
app.launch()

* Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.


