# Extract frames

In [1]:
import cv2
import os


def RealFrameCount(video_path):
    """Count the frames of a video by reading it from start to end.

    The count is obtained by calling ``read()`` until it fails, rather than
    by querying the container's frame-count property.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.

    Returns:
        int: Number of frames that were read successfully; 0 when the
        capture could not be opened.
    """
    cap = cv2.VideoCapture(video_path)
    f_count = 0
    try:
        while cap.isOpened():
            ret, _ = cap.read()
            if not ret:
                break
            f_count += 1
    finally:
        # Always free the capture handle, even if reading raised.
        cap.release()
    return f_count


def ExtractFrameListByInterval(total_frame_count, fps_criteria, real_fps):
    """Return evenly spaced frame indices that approximate a target frame rate.

    The spacing is ``int(real_fps / fps_criteria)``. Index 0 is never part of
    the result: the list starts at the first multiple of the spacing.

    Args:
        total_frame_count: Number of frames available; indices stay below it.
        fps_criteria: Desired number of sampled frames per second.
        real_fps: Actual frame rate of the video.

    Returns:
        list[int]: ``[interval, 2 * interval, ...]`` with every value smaller
        than ``total_frame_count`` (empty when no such value exists).

    Raises:
        ZeroDivisionError: If ``fps_criteria`` is 0.
    """
    # When the video is slower than the requested rate the quotient truncates
    # to 0, which used to fail as a zero slice step; fall back to every frame.
    interval = max(1, int(real_fps / fps_criteria))
    return list(range(interval, total_frame_count, interval))


def GetImageList(video_path, extract_fps=3):
    """Sample frames from a video at roughly ``extract_fps`` frames per second.

    The nominal duration is derived from the container's FPS and frame-count
    properties. The video is then read once by ``RealFrameCount`` to get the
    number of frames that can actually be decoded, and the "real" FPS is that
    count divided by the nominal duration. Target positions come from
    ``ExtractFrameListByInterval``.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        extract_fps: Desired number of sampled frames per second.

    Returns:
        tuple: ``(image_list, video_info)`` where ``image_list`` holds the
        selected frames exactly as returned by ``cap.read()`` and
        ``video_info`` is a dict with the keys ``'fps'``, ``'real_fps'`` and
        ``'target_frame_idx'``.

    Raises:
        ValueError: If the capture reports a non-positive FPS or frame count
            (for example because the file could not be opened).
    """
    image_list = []
    video_info = {'fps':0, 'real_fps':0, 'target_frame_idx':None,}

    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if fps <= 0 or total_frames <= 0:
            # Without this guard the divisions below fail with an opaque
            # ZeroDivisionError.
            raise ValueError(
                f"Cannot read video metadata (fps={fps}, frames={total_frames}): {video_path}"
            )

        video_length = total_frames / fps
        real_frame_count = RealFrameCount(video_path)
        real_fps = real_frame_count / video_length
        interval_list = ExtractFrameListByInterval(
            real_frame_count, fps_criteria=extract_fps, real_fps=real_fps
        )

        video_info['fps'] = fps
        video_info['real_fps'] = real_fps
        video_info['target_frame_idx'] = interval_list

        if interval_list:
            targets = set(interval_list)  # O(1) membership test per frame
            last_target = interval_list[-1]
            # The counter is 1 for the first frame read, so a target value k
            # selects the k-th frame read.
            f_count = 0
            while cap.isOpened():
                ret, image = cap.read()
                if not ret:
                    break
                f_count += 1
                if f_count in targets:
                    image_list.append(image)
                    if f_count == last_target:
                        break  # nothing left to collect
    finally:
        # Always free the capture handle, even on errors.
        cap.release()

    return image_list, video_info

In [2]:
# Join the path components so the sample resolves on every OS; a literal
# backslash is only a path separator on Windows.
VIDEO_PATH = os.path.join("SampleData", "001_s3_c1.mp4")

# Sample about three frames per second, then show the number of frames
# kept and the timing metadata.
image_list, video_info = GetImageList(VIDEO_PATH, extract_fps=3)
print(len(image_list))
print(video_info)

58
{'fps': 59.94005994005994, 'real_fps': 29.773413091126308, 'target_frame_idx': [9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99, 108, 117, 126, 135, 144, 153, 162, 171, 180, 189, 198, 207, 216, 225, 234, 243, 252, 261, 270, 279, 288, 297, 306, 315, 324, 333, 342, 351, 360, 369, 378, 387, 396, 405, 414, 423, 432, 441, 450, 459, 468, 477, 486, 495, 504, 513, 522]}


# Inference

In [3]:
import glob
import monai
import torch
import albumentations as A
import numpy as np 

def LoadModel(CFG, check_point_path=False):
    """Build a 2-D DenseNet121 classifier in eval mode on ``CFG.Device``.

    Args:
        CFG: Configuration object providing ``ImgInfo['Channel']``,
            ``NumOutputClass`` and ``Device``.
        check_point_path: Optional checkpoint to load into the network
            (passed to ``torch.load``). When falsy, the network keeps the
            pretrained weights requested from MONAI.

    Returns:
        The model, moved to ``CFG.Device`` and switched to evaluation mode.
    """
    model = monai.networks.nets.DenseNet121(
        spatial_dims=2,
        in_channels=CFG.ImgInfo['Channel'],
        out_channels=CFG.NumOutputClass,
        # load_state_dict below is strict and replaces every weight, so the
        # pretrained weights are only requested when no checkpoint is given.
        pretrained=not check_point_path,
    )
    if check_point_path:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(check_point_path, map_location=CFG.Device)
        model.load_state_dict(state_dict)
    # Inputs are moved to CFG.Device at prediction time, so the weights must
    # live on the same device.
    return model.to(CFG.Device).eval()


def LoadData(CFG, img=None, input_path=None):
    """Turn one image into a batched, channel-first tensor for the model.

    The image is resized to ``CFG.ImgInfo['Size']`` on both sides and
    normalized with ``CFG.ImgInfo['Mean']`` / ``CFG.ImgInfo['Std']``.

    Args:
        CFG: Configuration object providing ``ImgInfo`` with the keys
            ``'Size'``, ``'Mean'`` and ``'Std'``.
        img: Image array to preprocess. Used as given (no color conversion).
        input_path: Optional image file. When set, it takes precedence over
            ``img``; the file is read with OpenCV and converted BGR -> RGB.

    Returns:
        torch.Tensor: The augmented image with its last axis moved to the
        front and a leading batch axis of size 1.

    Raises:
        FileNotFoundError: If ``input_path`` cannot be read as an image.
        ValueError: If neither ``img`` nor ``input_path`` is provided.
    """
    # NOTE(review): only the file branch converts BGR -> RGB; arrays passed
    # through `img` keep their channel order. Confirm which order the models
    # were trained with.
    if input_path:
        bgr = cv2.imread(input_path)
        if bgr is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {input_path}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    if img is None:
        raise ValueError("LoadData needs either `img` or `input_path`.")

    Resize = [A.Resize(CFG.ImgInfo['Size'], CFG.ImgInfo['Size'])]
    Normalize = [A.Normalize(mean=CFG.ImgInfo['Mean'], std=CFG.ImgInfo['Std'], max_pixel_value=255.0, p=1.0)]
    valid_aug = A.Compose(Resize + Normalize, is_check_shapes=False)

    img = valid_aug(image=img)["image"]
    img = torch.tensor(img).permute(2, 0, 1)  # [H, W, C] -> [C, H, W]
    return torch.unsqueeze(img, dim=0)


def EnsemblePredict(models, input_list, CFG):
    """Combine per-model decisions over a set of images into one score.

    For each model, the softmax outputs of all inputs are averaged, the
    averaged value at index ``[0][1]`` is compared with that model's cutoff
    ``CFG.ModelCutoff['<i>Fold']`` (``i`` being the model's position in
    ``models``), and the comparison yields a 0/1 decision. The result is the
    mean of those decisions.

    Args:
        models: Sequence of callables mapping an input tensor to logits;
            their order must match the ``'<i>Fold'`` cutoff keys.
        input_list: Input tensors; each is moved to ``CFG.Device`` before the
            forward pass. Must not be empty (``torch.stack`` would fail).
        CFG: Configuration object providing ``Device`` and ``ModelCutoff``.

    Returns:
        The mean of the per-model binary decisions, i.e. the fraction of
        models whose averaged score reached their cutoff.
    """
    preds_by_models = []
    with torch.no_grad():
        for cnt, model in enumerate(models):
            # Start from an empty list for every model: sharing one list
            # across models mixes earlier models' outputs into this model's
            # average before it is compared with this model's own cutoff.
            preds_by_images = [
                torch.nn.functional.softmax(model(batch.to(CFG.Device)), dim=-1)
                for batch in input_list
            ]
            ensemble_output_by_images = torch.mean(torch.stack(preds_by_images), dim=0)
            # .cpu() makes the numpy conversion valid when running on a GPU.
            ensemble_output_by_images = ensemble_output_by_images.cpu().numpy()[0][1]

            pred_binary = np.where(ensemble_output_by_images >= CFG.ModelCutoff[f'{cnt}Fold'], 1, 0)
            preds_by_models.append(pred_binary)

        ensemble_output_by_models = np.mean(preds_by_models)
    return ensemble_output_by_models

In [4]:
class CFG:
    """Inference settings, read directly as class attributes."""

    # Use the GPU when torch can see one, otherwise the CPU.
    Device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Input side length, channel count and normalization constants.
    ImgInfo = dict(Size=256, Channel=3, Mean=0, Std=1)
    NumOutputClass = 2

    # Decision threshold per fold model, keyed '<index>Fold'.
    ModelCutoff = {
        '0Fold': 0.4946,
        '1Fold': 0.6218,
        '2Fold': 0.4114,
        '3Fold': 0.6081,
        '4Fold': 0.3817,
    }
    # Mean of the per-fold thresholds.
    AvgCutoff = np.mean([*ModelCutoff.values()])

In [5]:
MODEL_ROOT_PATH = r"Model"
CFG.Device = 'cpu'  # force CPU inference, replacing the device picked in CFG


# One model per *.pth checkpoint. The list is sorted because EnsemblePredict
# pairs the model at position i with the '<i>Fold' cutoff -- presumably the
# file names sort in fold order; verify.
Model_Path_List = sorted(glob.glob(os.path.join(MODEL_ROOT_PATH, '*.pth')))
MODELS = [LoadModel(CFG, check_point_path=Model_Path) for Model_Path in Model_Path_List]
print('Loaded model number :', len(MODELS))


Loaded model number : 5


In [6]:
# Preprocess every sampled frame into a model-ready tensor, then ensemble.
input_list = [LoadData(CFG, img=frame) for frame in image_list]
prob = EnsemblePredict(MODELS, input_list, CFG)
# `prob` is the mean of the per-model 0/1 decisions from EnsemblePredict.
print(f"The probability of it being cirrhosis is {prob*100:.2f}%")


The probability of it being cirrhosis is 0.00%


# Test in batch

In [7]:
# Videos to score in one run. The paths are joined from their components so
# they resolve on every OS; a literal backslash is only a separator on Windows.
VIDEO_PATH_LIST = [
    os.path.join("SampleData", "001_s3_c1.mp4"),
    os.path.join("SampleData", "012_s5_c3.mp4"),
]


class CFG:
    """Inference settings for the batch run, read as class attributes.

    Declares the name ``CFG`` again with the same values as the earlier
    configuration cell.
    """

    # Use the GPU when torch can see one, otherwise the CPU.
    Device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Input side length, channel count and normalization constants.
    ImgInfo = dict(Size=256, Channel=3, Mean=0, Std=1)
    NumOutputClass = 2

    # Decision threshold per fold model, keyed '<index>Fold'.
    ModelCutoff = {
        '0Fold': 0.4946,
        '1Fold': 0.6218,
        '2Fold': 0.4114,
        '3Fold': 0.6081,
        '4Fold': 0.3817,
    }
    # Mean of the per-fold thresholds.
    AvgCutoff = np.mean([*ModelCutoff.values()])

MODEL_ROOT_PATH = r"Model"
CFG.Device = 'cpu'  # force CPU inference, replacing the device picked in CFG

# One model per *.pth checkpoint. The list is sorted because EnsemblePredict
# pairs the model at position i with the '<i>Fold' cutoff -- presumably the
# file names sort in fold order; verify.
Model_Path_List = sorted(glob.glob(os.path.join(MODEL_ROOT_PATH, '*.pth')))
MODELS = [LoadModel(CFG, check_point_path=Model_Path) for Model_Path in Model_Path_List]
print('Loaded model number :', len(MODELS))


# Score each video: sample frames, preprocess them, then ensemble the models.
for VIDEO_PATH in VIDEO_PATH_LIST:
    image_list, video_info = GetImageList(VIDEO_PATH, extract_fps=3)
    input_list = [LoadData(CFG, img=frame) for frame in image_list]
    prob = EnsemblePredict(MODELS, input_list, CFG)
    # `prob` is the mean of the per-model 0/1 decisions from EnsemblePredict.
    print(f"The probability of it being cirrhosis is {prob*100:.2f}%")

Loaded model number : 5
The probability of it being cirrhosis is 0.00%
The probability of it being cirrhosis is 100.00%
