In [1]:
%matplotlib inline

import os
import pandas as pd
from sklearn.model_selection import ShuffleSplit
import numpy as np
import torch
import random

from myUtils import visualize_random_image, make_validation, get_model, get_test_predictions, get_sub_list, write_submission, get_aug
from dataset import parse_dataset, RSNADataset, get_dicom_fps, RSNAAlbumentationsDataset, convertAnnotations, visualize

from albumentations import (
    BboxParams,
    HorizontalFlip,
    VerticalFlip,
    Resize,
    CenterCrop,
    RandomCrop,
    Crop,
    Compose,
    Flip,
    Rotate,
    RandomSizedBBoxSafeCrop
)

In [2]:
def set_seed(seed):
    """Seed every random number generator used in this notebook and make cuDNN deterministic.

    Args:
        seed: integer seed applied to Python's ``random`` module, NumPy's legacy global
            RNG, and PyTorch (CPU generator and all CUDA devices).
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    # Give up cuDNN autotuning speed in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

set_seed(42)

## Загрузим датасет

Данные скачивались со страницы соревнования [RSNA Pneumonia Detection Challenge](https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data) на Kaggle. Для дальнейшей работы укажите в переменной `data_path` локальный путь к данным.

In [3]:
# Root folder of the Kaggle RSNA data. Set the RSNA_DATA_PATH environment variable
# (or edit the default below) instead of relying on one machine's absolute path.
data_path = os.environ.get("RSNA_DATA_PATH", "/mnt/ubuntu_hdd/rsna")
train_img_path = os.path.join(data_path, "stage_2_train_images")
annotations = pd.read_csv(os.path.join(data_path, "stage_2_train_labels.csv"))

# parse_dataset is a project helper (dataset.py); presumably it returns the training
# image file paths and the annotations grouped per image — confirm in dataset.py.
image_fps, image_annotations = parse_dataset(train_img_path, anns=annotations)

## Обучение нескольких моделей

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2
img_size = 1024
print("your device: {}".format(device))

# Augmentation pipeline (random flip + rotation up to 15 degrees).
# NOTE(review): `aug` is built but make_validation below receives transformations=None,
# so it is currently unused — confirm whether it should be passed in.
aug = get_aug([Flip(p=0.5), Rotate(limit=15, p=0.5)])

# Use more splits for a more reliable validation estimate.
n_splits = 2
n_images = len(image_fps)
splitter = ShuffleSplit(n_splits=n_splits, random_state=15, test_size=0.2)
# ShuffleSplit only needs the sample count, so placeholder X / y arrays are enough.
cv = list(splitter.split(np.zeros((n_images, 1)), np.zeros(n_images)))

params = dict(img_size=img_size, num_classes=num_classes, num_epochs=5, device=device)

images_files = np.array(image_fps)

make_validation(
    images_files=images_files,
    image_annotations=image_annotations,
    cv=cv,
    params=params,
    dataset="RSNADataset",
    transformations=None,
)

your device: cuda
epoch 0
Epoch: [0]  [ 0/58]  eta: 0:02:02  lr: 0.005000  loss: 0.6860 (0.6860)  loss_classifier: 0.4618 (0.4618)  loss_box_reg: 0.0316 (0.0316)  loss_objectness: 0.1736 (0.1736)  loss_rpn_box_reg: 0.0190 (0.0190)  time: 2.1153  data: 0.5595  max mem: 3854
Epoch: [0]  [10/58]  eta: 0:00:54  lr: 0.005000  loss: 0.3513 (0.3762)  loss_classifier: 0.1597 (0.1879)  loss_box_reg: 0.0985 (0.0994)  loss_objectness: 0.0676 (0.0712)  loss_rpn_box_reg: 0.0190 (0.0177)  time: 1.1257  data: 0.0609  max mem: 4124
Epoch: [0]  [20/58]  eta: 0:00:41  lr: 0.005000  loss: 0.2919 (0.3314)  loss_classifier: 0.1262 (0.1587)  loss_box_reg: 0.1107 (0.1049)  loss_objectness: 0.0319 (0.0511)  loss_rpn_box_reg: 0.0144 (0.0167)  time: 1.0288  data: 0.0121  max mem: 4124
Epoch: [0]  [30/58]  eta: 0:00:29  lr: 0.005000  loss: 0.2890 (0.3209)  loss_classifier: 0.1259 (0.1503)  loss_box_reg: 0.1192 (0.1120)  loss_objectness: 0.0248 (0.0424)  loss_rpn_box_reg: 0.0144 (0.0163)  time: 1.0331  data: 0.01

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

## Test

In [93]:
from torchvision.ops import nms
def _merge_model_predictions(per_model_results):
    """Pool the boxes and scores of every model per patient_id, in first-seen patient order.

    Returns a dict ``patient_id -> (boxes, scores)`` where ``boxes`` has shape (N, 4)
    and ``scores`` has shape (N,). N is 0 when no model predicted anything for the patient.
    """
    collected = {}
    for model_results in per_model_results:
        for res in model_results:
            boxes_parts, scores_parts = collected.setdefault(res['patient_id'], ([], []))
            boxes = np.asarray(res['boxes'])
            # Test emptiness by size, not `.any()`: `.any()` also treats all-zero arrays as
            # empty, and stacking a 1-D empty array onto an (N, 4) one raises in vstack.
            if boxes.size:
                boxes_parts.append(boxes.reshape(-1, 4))
                scores_parts.append(np.asarray(res['scores']).reshape(-1))

    merged = {}
    for patient_id, (boxes_parts, scores_parts) in collected.items():
        if boxes_parts:
            merged[patient_id] = (np.concatenate(boxes_parts), np.concatenate(scores_parts))
        else:
            merged[patient_id] = (np.empty((0, 4)), np.empty(0))
    return merged


def _nms_patient(patient_id, boxes, scores, iou_threshold, max_boxes):
    """Run NMS on one patient's pooled boxes and keep at most ``max_boxes`` survivors."""
    if len(boxes) == 0:
        # torchvision's nms expects an (N, 4) tensor, and there is nothing to suppress anyway.
        return {'patient_id': patient_id, 'boxes': np.array([]), 'scores': []}
    boxes_t = torch.as_tensor(boxes)
    # nms requires boxes and scores to share a dtype.
    scores_t = torch.as_tensor(scores, dtype=boxes_t.dtype)
    # nms returns kept indices sorted by decreasing score, so slicing keeps the best ones.
    keep = nms(boxes_t, scores_t, iou_threshold)[:max_boxes]
    return {
        'patient_id': patient_id,
        'boxes': np.array(boxes_t[keep].tolist()),
        'scores': scores_t[keep].tolist(),
    }


def get_test_predictions_nms(models, test_images, device, img_size, iou_threshold=0.3, max_boxes=3):
    """Ensemble several detection models on the test images and merge their boxes with NMS.

    Each model is run through ``get_test_predictions``. For every patient, the boxes and
    scores of all models are pooled, and non-maximum suppression keeps the highest-scoring,
    mutually non-overlapping boxes.

    Args:
        models: models passed one by one to ``get_test_predictions``.
        test_images: test image paths passed to ``get_test_predictions``.
        device: torch device passed to ``get_test_predictions``.
        img_size: image size passed to ``get_test_predictions``.
        iou_threshold: IoU above which a lower-scoring box is suppressed (default 0.3).
        max_boxes: maximum number of boxes kept per patient (default 3); ``None`` keeps all.

    Returns:
        A list of dicts ``{'patient_id', 'boxes', 'scores'}``, one per patient, in first-seen
        order. ``boxes`` is a numpy array and ``scores`` a list of floats, both sorted by
        decreasing score.
    """
    per_model_results = [get_test_predictions(m, test_images, device, img_size) for m in models]
    merged = _merge_model_predictions(per_model_results)
    return [
        _nms_patient(patient_id, boxes, scores, iou_threshold, max_boxes)
        for patient_id, (boxes, scores) in merged.items()
    ]

In [97]:
def load_fold_model(checkpoint_path):
    """Build a detection model, load a saved state_dict into it and move it to `device`.

    map_location lets a checkpoint saved on GPU also load on a CPU-only machine.
    """
    model = get_model(num_classes=num_classes)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.to(device)
    return model

# Checkpoints are expected at "fold_num_<i>_model" — presumably written by make_validation
# above; confirm in myUtils.
model1 = load_fold_model("fold_num_0_model")
model2 = load_fold_model("fold_num_1_model")

# Only the first N_TEST_IMAGES test images are scored (quick smoke run);
# set to None to predict on the whole test set before writing a submission.
N_TEST_IMAGES = 2
test_images = get_dicom_fps(os.path.join(data_path, "stage_2_test_images"))

imgs_info = get_test_predictions_nms([model1, model2], test_images[:N_TEST_IMAGES], device, img_size)

# min_conf: minimum confidence for a predicted box to be counted as pneumonia.
sub_list = get_sub_list(imgs_info, img_size, min_conf=0.7)

WRITE_SUBMISSION = False
if WRITE_SUBMISSION:
    write_submission(sub_list)

imgs_info

HBox(children=(IntProgress(value=0, max=2), HTML(value='')))





HBox(children=(IntProgress(value=0, max=2), HTML(value='')))

[{'patient_id': '04384b80-0b9e-4744-85e4-f5184fc073e9',
  'boxes': [[171.83824157714844,
    572.594482421875,
    366.2283630371094,
    763.4937133789062],
   [578.368408203125, 382.0631103515625, 822.3856201171875, 811.136962890625],
   [125.1860122680664, 312.7394714355469, 494.4402160644531, 924.36083984375]],
  'scores': [0.7624905705451965, 0.23498092591762543, 0.07064270973205566]},
 {'patient_id': '1189f742-0450-455c-8311-192da30f23b8',
  'boxes': [[250.8720245361328,
    495.53759765625,
    458.2784423828125,
    769.0703125],
   [642.4140625, 600.7378540039062, 842.4464111328125, 821.7581176757812],
   [616.894775390625,
    251.45193481445312,
    866.589599609375,
    754.3848876953125]],
  'scores': [0.9274036288261414, 0.8253685235977173, 0.5534019470214844]}]