In [1]:
import glob
import os
from tqdm.auto import tqdm
from multiprocessing import Pool, cpu_count
import cv2
import time
import argparse
import logging
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
torch.backends.cudnn.benchmark = True
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm


from timm.models import create_model, apply_test_time_pool
from timm.data import ImageDataset, create_loader, resolve_data_config
from timm.utils import AverageMeter, setup_default_logging

KeyboardInterrupt: 

In [None]:
# Which splits this notebook processes; each flag gates its own extraction branch.
VALID = True  # for validation
TEST = True  # for submission

In [None]:
def extract_images(video_path, out_dir, dsize=None):
    """Decode every frame of a video, resize it and save it as a JPEG.

    Frames are written to ``{out_dir}/{video_name}-{frame:06}.jpg`` with
    1-based frame numbers, where ``video_name`` is the file name up to its
    first '.'.

    Args:
        video_path: path to a video readable by ``cv2.VideoCapture``.
        out_dir: directory the frames are written into (must already exist).
        dsize: ``(width, height)`` passed to ``cv2.resize``. Defaults to the
            module-level ``IMG_SIZE``, looked up at call time.

    Returns:
        The number of frames written.

    Raises:
        OSError: if the video cannot be opened or a frame cannot be written.
    """
    if dsize is None:
        dsize = IMG_SIZE
    video_name = os.path.basename(video_path).split('.')[0]
    cam = cv2.VideoCapture(video_path)
    print(video_path)
    try:
        # Without this check an unreadable file silently yields zero frames.
        if not cam.isOpened():
            raise OSError(f'Could not open video: {video_path}')
        frame_count = 1
        while True:
            successed, img = cam.read()
            if not successed:
                break
            outfile = f'{out_dir}/{video_name}-{frame_count:06}.jpg'
            img = cv2.resize(img, dsize=dsize)
            # cv2.imwrite reports failure (e.g. missing out_dir) only via its return value.
            if not cv2.imwrite(outfile, img):
                raise OSError(f'Failed to write frame: {outfile}')
            frame_count += 1
    finally:
        # Release the decoder handle even if reading/writing raised.
        cam.release()
    return frame_count - 1

IMG_SIZE = (456, 456)  # (width, height) passed to cv2.resize by extract_images

# Extract frames for each enabled split. OUT_DIR / IN_DIR / IN_VIDEOS stay bound
# to the last split processed (VALID when both flags are on), in case later cells use them.
# os.makedirs replaces the IPython-only `!mkdir -p $OUT_DIR` shell escape.
if TEST:
    OUT_DIR = '../work/extracted_images_test'
    IN_DIR = '../input/dfl-bundesliga-data-shootout/test'
    IN_VIDEOS = sorted(glob.glob(os.path.join(IN_DIR, '*')))
    os.makedirs(OUT_DIR, exist_ok=True)
    for video_path in IN_VIDEOS:
        extract_images(video_path, OUT_DIR)


if VALID:
    OUT_DIR = '../work/extracted_images_train'
    IN_DIR = '../input/dfl-bundesliga-data-shootout/train'
    # Only these two training videos are extracted for validation.
    IN_VIDEOS = [os.path.join(IN_DIR, name) for name in ('3c993bd2_0.mp4', '3c993bd2_1.mp4')]
    os.makedirs(OUT_DIR, exist_ok=True)
    for video_path in IN_VIDEOS:
        extract_images(video_path, OUT_DIR)

In [None]:
class DFLDataset(Dataset):
    """2.5D frame dataset: stacks grayscale neighbours of a middle frame as channels.

    Each sample is built from the image at ``img_path[index]`` plus the
    ``n_25d_shift`` frames before and after it. The neighbours are located by
    rewriting the frame number in the file name, which gives
    ``2 * n_25d_shift + 1`` channels.

    Args:
        img_path: sequence of paths to middle frames. ``load_3d_slice`` parses
            names of the form ``<prefix>_<frame:06>_mid.jpg``.
        img_label: sequence of labels, indexed in parallel with ``img_path``.
        transform: optional callable invoked as ``transform(image=img)['image']``.
        img_size: ``dsize`` passed to ``cv2.resize``. When None, the global
            ``CFG.img_size`` is used (looked up lazily, as before).
        n_25d_shift: number of neighbouring frames taken on each side
            (default 3, i.e. 7 channels).
    """

    def __init__(self, img_path, img_label, transform=None, img_size=None, n_25d_shift=3):
        self.img_path = img_path
        self.img_label = img_label
        self.transform = transform
        self.img_size = img_size
        self.n_25d_shift = n_25d_shift

    def __getitem__(self, index):
        """Return ``(image, label_tensor)`` for ``index``.

        ``image`` is the float32 stack from ``load_3d_slice``, passed through
        ``transform`` when one is set.
        """
        img = self.load_3d_slice(self.img_path[index]).astype(np.float32)
        if self.transform is not None:
            img = self.transform(image=img)['image']
        return img, torch.from_numpy(np.array(self.img_label[index]))

    def __len__(self):
        return len(self.img_path)

    def load_3d_slice(self, middle_img_path):
        """Load the middle frame and its neighbours as a max-normalised float32 stack.

        Missing neighbour files are filled by copying the adjacent frame that
        is closer to the center. The result is divided by its maximum value
        unless that maximum is 0.

        Raises:
            FileNotFoundError: if none of the frames in the window exist.
        """
        n = self.n_25d_shift
        img_size = self.img_size if self.img_size is not None else CFG.img_size

        # Step 1: parse the middle frame number from the file name.
        # e.g. './work_25d/split_images/train\play\1606b0e6_1_012923_mid.jpg'
        #      -> middle_slice '1606b0e6_1_012923', middle_slice_num '012923'
        # The split on '\\' strips Windows-style directories that os.path.basename
        # leaves intact on POSIX.
        # NOTE(review): if the name has no '_mid' suffix, the replace() below is a
        # no-op and every offset resolves to the middle file itself. Names like
        # '<video>-000001.jpg' make int() fail. Confirm the naming the caller uses.
        middle_slice = os.path.basename(middle_img_path).split('_mid')[0].split('\\')[-1].split('.jpg')[0]
        middle_slice_num = middle_slice.split('_')[-1]

        # Step 2: load frames at offsets -n..n; a missing file becomes None.
        frames = []
        for i in range(-n, n + 1):
            if i != 0:
                shift_slice_str = str(int(middle_slice_num) + i).zfill(6)
                shift_img_path = middle_img_path.replace(middle_slice_num + '_mid', shift_slice_str)
            else:
                shift_img_path = middle_img_path
            if os.path.exists(shift_img_path):
                shift_img = cv2.imread(shift_img_path, cv2.IMREAD_UNCHANGED)
                # NOTE(review): cv2.imread returns BGR, so COLOR_BGR2GRAY would be the
                # matching code (RGB2GRAY swaps the R/B weights). Kept unchanged so the
                # inputs match whatever model was trained with this loader.
                shift_img = cv2.cvtColor(shift_img, cv2.COLOR_RGB2GRAY)
                frames.append(cv2.resize(shift_img, img_size))
            else:
                frames.append(None)

        # Step 3: fill gaps working outward from the center so every None copies
        # an already-filled neighbour. E.g. n = 2 -> indices [0..4], center 2,
        # processing order 1, 3, 0, 4. (Previously the left side started at the
        # center itself, so index 0 was never filled and np.stack crashed.)
        center = n
        if frames[center] is None:
            # Borrow the nearest existing frame, preferring the following one.
            nearest = next(
                (frames[center + sign * d]
                 for d in range(1, n + 1)
                 for sign in (1, -1)
                 if frames[center + sign * d] is not None),
                None,
            )
            if nearest is None:
                raise FileNotFoundError(f'No frames found around {middle_img_path}')
            frames[center] = nearest
        for d in range(1, n + 1):
            if frames[center - d] is None:
                frames[center - d] = frames[center - d + 1]
            if frames[center + d] is None:
                frames[center + d] = frames[center + d - 1]

        stacked = np.stack(frames, axis=2).astype('float32')  # [h, w, c]
        mx_pixel = stacked.max()
        if mx_pixel != 0:
            stacked /= mx_pixel
        return stacked