In [1]:
%matplotlib inline

import os
import pandas as pd
from sklearn.model_selection import ShuffleSplit
import numpy as np
import torch
import random

from myUtils import visualize_random_image, make_validation, get_model, get_test_predictions, get_sub_list, write_submission, get_aug
from dataset import parse_dataset, RSNADataset, get_dicom_fps, RSNAAlbumentationsDataset, convertAnnotations, visualize

from albumentations import (
    BboxParams,
    HorizontalFlip,
    VerticalFlip,
    Resize,
    CenterCrop,
    RandomCrop,
    Crop,
    Compose,
    Flip,
    Rotate,
    RandomSizedBBoxSafeCrop
)

In [2]:
def set_seed(seed: int):
    """Seed every RNG used in this notebook and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus all CUDA devices), then switches cuDNN to deterministic
    kernels with benchmark autotuning disabled so runs are reproducible.

    Args:
        seed: integer seed applied to all generators.
    """
    # Python and NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch generators: CPU and every visible CUDA device.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN: deterministic algorithms, no per-input-shape autotuning.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)

## Загрузим датасет

Данные скачаны с <https://www.kaggle.com/c/rsna-pneumonia-detection-challenge/data>. Для дальнейшей работы укажите в переменной `data_path` локальный путь к данным.

In [3]:
# Root directory of the Kaggle RSNA pneumonia data. Set the RSNA_DATA_PATH
# environment variable to point elsewhere instead of editing this cell;
# the default is the original author's local path.
data_path = os.environ.get("RSNA_DATA_PATH", "/mnt/ubuntu_hdd/rsna")
train_img_path = os.path.join(data_path, "stage_2_train_images")

# Fail early with an actionable message rather than an obscure error from
# read_csv / parse_dataset when the dataset root is wrong.
if not os.path.isdir(train_img_path):
    raise FileNotFoundError(
        f"Training images not found at {train_img_path!r}; "
        "set data_path (or the RSNA_DATA_PATH env var) to the RSNA dataset root."
    )

annotations = pd.read_csv(os.path.join(data_path, "stage_2_train_labels.csv"))

# parse_dataset (project helper) returns the training image file paths and
# their annotations built from the labels CSV.
image_fps, image_annotations = parse_dataset(train_img_path, anns=annotations)

## Обучение

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2  # presumably background + pneumonia — confirm against get_model
img_size = 1024
print(f"your device: {device}")

# Augmentation pipeline. An alternative tried earlier additionally used
# RandomSizedBBoxSafeCrop(img_size, img_size, p=0.5).
# NOTE(review): `aug` is built but never used — make_validation below gets
# transformations=None. To actually train with augmentation, pass it in
# (presumably together with dataset="RSNAAlbumentationsDataset") — confirm
# against make_validation's signature.
aug = get_aug([Flip(p=0.5), Rotate(limit=15, p=0.5)])

# A single 80/20 shuffle split; use more splits for a more reliable estimate.
n_splits = 1
n_images = len(image_fps)
splitter = ShuffleSplit(n_splits=n_splits, random_state=15, test_size=0.2)
# ShuffleSplit only needs the number of samples, so dummy X / y arrays suffice.
cv = list(splitter.split(np.zeros((n_images, 1)), np.zeros(n_images)))

params = {
    "img_size": img_size,
    "num_classes": num_classes,
    "num_epochs": 5,
    "device": device,
}

images_files = np.array(image_fps)

make_validation(
    images_files=images_files,
    image_annotations=image_annotations,
    cv=cv,
    params=params,
    dataset="RSNADataset",
    transformations=None,
)

your device: cuda
epoch 0
Epoch: [0]  [ 0/58]  eta: 0:01:33  lr: 0.005000  loss: 0.7534 (0.7534)  loss_classifier: 0.6094 (0.6094)  loss_box_reg: 0.0571 (0.0571)  loss_objectness: 0.0642 (0.0642)  loss_rpn_box_reg: 0.0227 (0.0227)  time: 1.6107  data: 0.5740  max mem: 4498
Epoch: [0]  [10/58]  eta: 0:00:48  lr: 0.005000  loss: 0.3647 (0.3961)  loss_classifier: 0.1668 (0.2105)  loss_box_reg: 0.0979 (0.0957)  loss_objectness: 0.0690 (0.0719)  loss_rpn_box_reg: 0.0160 (0.0179)  time: 1.0114  data: 0.0659  max mem: 4759
Epoch: [0]  [20/58]  eta: 0:00:37  lr: 0.005000  loss: 0.3228 (0.3438)  loss_classifier: 0.1374 (0.1711)  loss_box_reg: 0.1030 (0.1048)  loss_objectness: 0.0406 (0.0518)  loss_rpn_box_reg: 0.0150 (0.0161)  time: 0.9517  data: 0.0153  max mem: 4760
Epoch: [0]  [30/58]  eta: 0:00:27  lr: 0.005000  loss: 0.2760 (0.3223)  loss_classifier: 0.1247 (0.1540)  loss_box_reg: 0.1132 (0.1084)  loss_objectness: 0.0237 (0.0440)  loss_rpn_box_reg: 0.0150 (0.0159)  time: 0.9512  data: 0.01

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

HBox(children=(IntProgress(value=0, max=34), HTML(value='')))

## Test

In [9]:
# Weights file for fold 0 — presumably written by make_validation in the
# training cell above (n_splits=1, so only fold 0 exists); verify the name.
model_weights_path = "fold_num_0_model"

model = get_model(num_classes=num_classes)
# map_location lets weights saved during a CUDA run be loaded on whichever
# device this session selected (e.g. a CPU-only machine).
model.load_state_dict(torch.load(model_weights_path, map_location=device))
model.to(device)
# Inference mode: disables dropout / batch-norm updates; harmless if
# get_test_predictions also switches the model to eval.
model.eval()

test_images = get_dicom_fps(os.path.join(data_path, "stage_2_test_images"))

imgs_info = get_test_predictions(model, test_images, device, img_size)

# min_conf: minimum confidence score for a detection to be counted as pneumonia.
min_conf = 0.7
sub_list = get_sub_list(imgs_info, img_size, min_conf=min_conf)

write_submission(sub_list)

HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))


