In [2]:
from torchvision.models import resnet

In [3]:
# Make only GPU 0 visible to CUDA for this kernel session.
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_VISIBLE_DEVICES=0


In [4]:
import torch
from torch import nn

In [7]:
from torch.nn.utils import weight_norm
from torch import nn

class Identity(nn.Module):
    """Pass-through layer; used to replace a head that should be disabled."""

    def forward(self, x):
        """Return ``x`` itself, without copying or modifying it."""
        return x

class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` steps along the final (time) axis.

    Args:
        chomp_size: number of trailing time steps to remove. ``0`` leaves the
            input length unchanged.
    """

    def __init__(self, chomp_size):
        super().__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """Return a contiguous copy of ``x`` without its last ``chomp_size`` steps.

        ``x`` is indexed as a 3-D tensor (two leading axes plus time).
        """
        # ``x[..., :-0]`` is ``x[..., :0]`` (empty), so zero has to be special-cased.
        if self.chomp_size == 0:
            return x.contiguous()
        return x[:, :, : -self.chomp_size].contiguous()


class GAP1d(nn.Module):
    """Adaptive average pooling over the last axis, followed by flattening.

    Args:
        output_size: target length passed to ``nn.AdaptiveAvgPool1d``.
    """

    def __init__(self, output_size=1):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool1d(output_size)
        self.flatten = nn.Flatten()

    def forward(self, x):
        """Pool ``x`` along its last axis, then flatten everything after the batch axis."""
        pooled = self.gap(x)
        return self.flatten(pooled)


class TemporalBlock(nn.Module):
    """Residual block built from two weight-normalised dilated 1-D convolutions.

    Each convolution is followed by ``Chomp1d(padding)``, ReLU and Dropout.
    The block output is ``relu(net(x) + res)`` where ``res`` is ``x`` itself,
    or a 1x1 convolution of ``x`` when the channel count changes.

    Args:
        ni: number of input channels.
        nf: number of output channels of both convolutions.
        ks: convolution kernel size.
        stride: convolution stride (``temporal_conv_net`` always passes 1).
        dilation: convolution dilation.
        padding: padding given to each convolution and also the number of
            trailing steps removed by the following ``Chomp1d``. With
            ``padding == (ks - 1) * dilation`` and ``stride == 1`` (what
            ``temporal_conv_net`` passes) the sequence length is preserved.
        dropout: dropout probability applied after each ReLU.
    """

    def __init__(self, ni, nf, ks, stride, dilation, padding, dropout=0.0):
        super().__init__()
        # First conv stage: ni -> nf channels.
        self.conv1 = weight_norm(
            nn.Conv1d(
                ni, nf, ks, stride=stride, padding=padding, dilation=dilation,
            ),
        )
        self.chomp1 = Chomp1d(padding)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)
        # Second conv stage: nf -> nf channels.
        self.conv2 = weight_norm(
            nn.Conv1d(
                nf, nf, ks, stride=stride, padding=padding, dilation=dilation,
            ),
        )
        self.chomp2 = Chomp1d(padding)
        self.relu2 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout)
        # The same module objects are also registered inside ``net``, so each
        # layer is reachable both as ``self.convN`` and as ``self.net[i]``.
        self.net = nn.Sequential(
            self.conv1,
            self.chomp1,
            self.relu1,
            self.dropout1,
            self.conv2,
            self.chomp2,
            self.relu2,
            self.dropout2,
        )
        # 1x1 conv to match channel counts on the residual path; None when ni == nf.
        self.downsample = nn.Conv1d(ni, nf, 1) if ni != nf else None
        self.relu = nn.ReLU()
        self.init_weights()

    def init_weights(self):
        """Fill the conv weights in place with samples from N(0, 0.01)."""
        # NOTE(review): conv1/conv2 are wrapped in ``weight_norm``; per the
        # torch.nn.utils.weight_norm docs the ``weight`` attribute is recomputed
        # from ``weight_g``/``weight_v`` before every forward, so these two
        # in-place fills probably do not affect the weights actually used — confirm.
        self.conv1.weight.data.normal_(0, 0.01)
        self.conv2.weight.data.normal_(0, 0.01)
        if self.downsample is not None:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """Return ``relu(net(x) + residual)``."""
        out = self.net(x)
        res = x if self.downsample is None else self.downsample(x)
        return self.relu(out + res)


def temporal_conv_net(c_in, layers, ks=2, dropout=0.0):
    """Stack one ``TemporalBlock`` per entry of ``layers`` into an ``nn.Sequential``.

    Block ``i`` maps the previous block's width (``c_in`` for the first one)
    to ``layers[i]`` channels, uses dilation ``2 ** i``, stride 1 and padding
    ``(ks - 1) * dilation``.

    Args:
        c_in: number of input channels of the first block.
        layers: output channel count of each block, in order.
        ks: kernel size shared by all blocks.
        dropout: dropout probability forwarded to every block.
    """
    # Input width of block i is c_in for i == 0, otherwise layers[i - 1].
    input_widths = [c_in, *layers[:-1]]
    blocks = []
    for depth, (ni, nf) in enumerate(zip(input_widths, layers)):
        dilation_size = 2 ** depth
        block = TemporalBlock(
            ni,
            nf,
            ks,
            stride=1,
            dilation=dilation_size,
            padding=(ks - 1) * dilation_size,
            dropout=dropout,
        )
        blocks.append(block)
    return nn.Sequential(*blocks)


class TCN(nn.Module):
    """BatchNorm -> temporal conv stack -> global average pool -> (dropout) -> linear head.

    Args:
        c_in: number of input channels.
        c_out: size of the linear head's output.
        layers: channel count of each temporal block; the last entry is the
            input width of the linear head.
        ks: kernel size of the temporal blocks.
        conv_dropout: dropout probability inside the temporal blocks.
        fc_dropout: dropout probability before the linear head; a falsy value
            disables that dropout layer entirely (``self.dropout`` is None).
    """

    def __init__(
            self,
            c_in,
            c_out,
            layers=[25] * 8,
            ks=7,
            conv_dropout=0.0,
            fc_dropout=0.0,
    ):
        super().__init__()
        self.norm = nn.BatchNorm1d(c_in)
        self.tcn = temporal_conv_net(c_in, layers, ks=ks, dropout=conv_dropout)
        self.gap = GAP1d()
        if fc_dropout:
            self.dropout = nn.Dropout(fc_dropout)
        else:
            self.dropout = None
        self.linear = nn.Linear(layers[-1], c_out)
        self.init_weights()

    def init_weights(self):
        """Re-draw the linear head's weight from N(0, 0.01); the bias is left as is."""
        nn.init.normal_(self.linear.weight, mean=0, std=0.01)

    def forward(self, x):
        """Run ``x`` through norm, temporal stack, pooling, optional dropout and head."""
        features = self.gap(self.tcn(self.norm(x)))
        if self.dropout is not None:
            features = self.dropout(features)
        return self.linear(features)

In [8]:
class ResNetTCN(nn.Module):
    """Per-frame ResNet-18 encoder followed by a TCN over the frame axis.

    The ResNet's ``fc`` head is replaced by ``Identity`` so the encoder emits
    its ``fc.in_features``-wide feature vector for every frame; the TCN then
    maps the resulting sequence to a single output value per clip.
    """

    def __init__(self):
        super().__init__()
        backbone = resnet.resnet18(pretrained=True)
        feat_dim = backbone.fc.in_features
        backbone.fc = Identity()
        self.encoder = backbone
        self.tcn = TCN(feat_dim, 1, layers=[68]*8, ks=3)

    def forward(self, X):
        """Map a batch of clips ``[B, frames, C, H, W]`` to the TCN output."""
        batch, n_frames, channels, height, width = X.shape
        # Fold the frame axis into the batch axis so the 2-D encoder sees single images.
        flat_frames = X.reshape(-1, channels, height, width)
        embeddings = self.encoder(flat_frames)              # [B*frames, in_feats]
        sequence = embeddings.reshape(batch, n_frames, -1)  # [B, frames, in_feats]
        sequence = sequence.transpose(1, 2)                 # [B, in_feats, frames]
        return self.tcn(sequence)

In [6]:
import albumentations as albu
from albumentations.pytorch import ToTensorV2

# Per-frame preprocessing applied by CrashDataset.__getitem__:
#   1. resize the frame to 224x224,
#   2. normalise with the per-channel mean/std below (these look like the usual
#      ImageNet statistics that a pretrained ResNet expects — TODO confirm),
#   3. convert the image to a torch tensor with ToTensorV2.
trans = albu.Compose([
    albu.Resize(224, 224), 
    albu.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
    ToTensorV2()
])

In [7]:
from torch.utils.data import DataLoader
import torch
from os.path import join
import random
import cv2
import numpy as np
import os

# Number of frames in every clip returned by CrashDataset: longer videos are
# read from a random start offset, shorter ones are padded by repeating the last frame.
NFRAMES = 300

def is_matched(block, frame, thr=0.95):
    """Tell whether the top-left template-matching score exceeds ``thr``.

    ``block`` is passed to ``cv2.matchTemplate`` as the searched image and
    ``frame`` as the template (CrashDataset calls this with the video frame
    first and the banner template second). Only the score at position
    ``[0][0]`` of the result map is examined.

    Args:
        block: first argument of ``cv2.matchTemplate``.
        frame: second argument of ``cv2.matchTemplate``.
        thr: threshold the normalised correlation score must exceed.
    """
    score_map = cv2.matchTemplate(block, frame, cv2.TM_CCOEFF_NORMED)
    top_left_score = score_map[0][0]
    return top_left_score > thr

# Banner templates that CrashDataset compares against the leading frames of
# positive videos so that banner frames can be skipped.
# NOTE(review): cv2.imread is documented to return None (not raise) when the
# file cannot be read — consider verifying both templates loaded.
BLOCK1 = cv2.imread('block.png')
BLOCK2 = cv2.imread('block1.png')

class CrashDataset:
    """Map-style dataset that turns a video file into a fixed-length clip tensor.

    Every item is ``(frames, label)`` where ``frames`` stacks ``NFRAMES``
    frames transformed by ``trans`` and ``label`` is a float tensor of shape
    ``[1]``. The class only implements ``__getitem__``/``__len__``; it does
    not subclass ``torch.utils.data.Dataset``.

    Args:
        vid_paths: sequence of video file paths.
        labels: sequence of labels, same length as ``vid_paths``; label ``0``
            disables the banner-skipping step.
    """

    def __init__(self, vid_paths, labels):
        assert len(vid_paths) == len(labels)
        self.vid_paths = vid_paths
        self.labels = labels

    def _read_frames(self, vid_path, skip_banner):
        """Read up to ``NFRAMES`` RGB frames from ``vid_path``.

        When the video is longer than ``NFRAMES`` reading starts at a random
        offset. While ``skip_banner`` is true, frames matching ``BLOCK1`` or
        ``BLOCK2`` are dropped; skipping stops at the first non-matching frame.
        The capture is always released, even if reading raises.
        """
        cap = cv2.VideoCapture(str(vid_path))
        try:
            cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            if cnt > NFRAMES:
                start = random.randrange(0, cnt - NFRAMES + 1)
                cap.set(cv2.CAP_PROP_POS_FRAMES, start)

            frames = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if skip_banner and (is_matched(frame, BLOCK1) or is_matched(frame, BLOCK2)):
                    print('found banner')
                    continue
                # Only leading banner frames are skipped.
                skip_banner = False

                frames.append(frame[..., ::-1])  # BGR -> RGB
                if len(frames) == NFRAMES:
                    break
            return frames
        finally:
            # The original code never released the capture, leaking one
            # decoder/file handle per item in every worker.
            cap.release()

    def __getitem__(self, idx):
        """Return ``(frames, label)`` for item ``idx``.

        If no frame can be read the path is reported and a randomly chosen
        other item is returned instead.
        """
        vid_path = self.vid_paths[idx]
        label = self.labels[idx]

        frames = self._read_frames(vid_path, skip_banner=label != 0)
        if not frames:
            print(vid_path, 'is invalid')
            return self[random.randrange(0, len(self))]

        frames = [trans(image=frame)['image'] for frame in frames]

        # Pad short clips to NFRAMES by repeating the last frame.
        len_ = len(frames)
        if len_ < NFRAMES:
            frames = frames + [frames[-1]] * (NFRAMES - len_)

        frames = torch.stack(frames)
        return frames, torch.tensor([label]).float()

    def __len__(self):
        return len(self.vid_paths)

In [8]:
# Directory prefixes (with trailing slash) for positive (label 1) and
# negative (label 0) video clips.
POS_DIR = 'data_parsed/'
NEG_DIR = 'data_site_small/'

In [9]:
import pandas as pd
# Positive clips: one '<row index>.mp4' under POS_DIR for every CSV row with status == 1.
df = pd.read_csv('urls_marked.csv')
vid_paths_pos = np.array(
    [f'{POS_DIR}{row_id}.mp4' for row_id in df.index[df.status == 1]]
)

In [None]:
# Count the entries in the negative-clip directory (the same path as NEG_DIR).
!ls data_site_small | wc -l

In [10]:
from pathlib import Path

In [11]:
# Hold out the first 20% of the positive clips (in file order) for validation.
split_val = int(0.2 * len(vid_paths_pos))
vid_paths_pos_val = vid_paths_pos[:split_val]
vid_paths_pos_train = vid_paths_pos[split_val:]

In [12]:
# Negative clips: every *.mp4 directly under NEG_DIR, as strings.
vid_paths_neg = np.array(list(map(str, Path(NEG_DIR).glob('*.mp4'))))

In [13]:
from sklearn.model_selection import GroupKFold

def group_kfold(*arrays, groups=None):
    """Split each array by the first fold of a 5-way ``GroupKFold``.

    Args:
        *arrays: index-able arrays forwarded positionally to ``GroupKFold.split``.
        groups: group label per sample, forwarded to ``GroupKFold.split``.

    Returns:
        A generator yielding, for every input array in order, its train part
        followed by its test part.
    """
    splitter = GroupKFold(n_splits=5)
    # Only the first (train_ids, test_ids) pair is used.
    first_fold = next(splitter.split(*arrays, groups=groups))
    return (array[fold_ids] for array in arrays for fold_ids in first_fold)

In [14]:
# Group negatives by the part of the file stem before the first underscore
# (e.g. '80_3.mp4' -> '80') so clips sharing a prefix stay in the same split.
groups = [Path(vid_path).stem.partition('_')[0] for vid_path in vid_paths_neg]
vid_paths_neg_train, vid_paths_neg_val = group_kfold(vid_paths_neg, groups=groups)

In [15]:
# Training set: positives first, then negatives; labels follow the same order.
vid_paths_train = np.concatenate((vid_paths_pos_train, vid_paths_neg_train))
labels_train = [1] * len(vid_paths_pos_train) + [0] * len(vid_paths_neg_train)

In [16]:
# Validation set: positives first, then negatives; labels follow the same order.
vid_paths_val = np.concatenate((vid_paths_pos_val, vid_paths_neg_val))
labels_val = [1] * len(vid_paths_pos_val) + [0] * len(vid_paths_neg_val)

In [17]:
# Wrap the path/label lists in datasets that decode clips on access.
train = CrashDataset(vid_paths_train, labels_train)
val = CrashDataset(vid_paths_val, labels_val)

In [18]:
# Batches of 4 clips, decoded by 20 worker processes. The training loader
# shuffles and drops the last incomplete batch; the validation loader keeps
# dataset order and every sample.
train_loader = DataLoader(train, batch_size=4, num_workers=20, drop_last=True, shuffle=True)
val_loader = DataLoader(val, batch_size=4, num_workers=20)

In [None]:
# Number of clips in the training and validation datasets.
len(train), len(val)

In [19]:
# Build the ResNet-18 + TCN network and move it to the GPU.
model = ResNetTCN().cuda()

In [23]:
import os
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl, utils

# The model emits one raw logit per clip, so the loss takes logits directly.
criterion = nn.BCEWithLogitsLoss()
# Adam over every model parameter (encoder and TCN alike).
optimizer = optim.Adam(model.parameters(), lr=3e-4)

loaders = {
    "train": train_loader,
    "valid": val_loader
}

# Batch keys used by the runner for inputs, model outputs, targets and loss.
# NOTE(review): key names follow the installed catalyst version's
# SupervisedRunner API — verify when upgrading catalyst.
runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)

# Train for 30 epochs, computing AUC from logits vs targets on both loaders.
# The "auc" metric on the "valid" loader is the selection metric, with
# minimize_valid_metric=False (higher is better); load_best_on_end=True
# presumably reloads the best checkpoint into `model` when training
# finishes — confirm against the catalyst docs.
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=30,
    callbacks=[
        dl.auc.AUCCallback(input_key='logits', target_key='targets')
    ],
    logdir="./logs",
    valid_loader="valid",
    valid_metric="auc",
    minimize_valid_metric=False,
    verbose=True,
    load_best_on_end=True,
)

1/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (1/30) auc: 0.9115520119667053 | auc/_macro: 0.9115520119667053 | auc/_micro: 0.9115519852390817 | auc/_weighted: 0.5308008790016174 | loss: 0.37653318254484064 | loss/mean: 0.37653318254484064 | loss/std: 0.364035597655034 | lr: 0.0003 | momentum: 0.9


1/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (1/30) auc: 0.8882462382316589 | auc/_macro: 0.8882462382316589 | auc/_micro: 0.8882462301177443 | auc/_weighted: 0.5132898688316345 | loss: 0.6035901587151107 | loss/mean: 0.6035901587151107 | loss/std: 1.085814934205753 | lr: 0.0003 | momentum: 0.9
* Epoch (1/30) 


2/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (2/30) auc: 0.9398901462554932 | auc/_macro: 0.9398901462554932 | auc/_micro: 0.9398901634493202 | auc/_weighted: 0.5473023056983948 | loss: 0.30667553185542773 | loss/mean: 0.30667553185542773 | loss/std: 0.3025405120807499 | lr: 0.0003 | momentum: 0.9


2/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (2/30) auc: 0.9142739176750183 | auc/_macro: 0.9142739176750183 | auc/_micro: 0.9142739103491015 | auc/_weighted: 0.5283304452896118 | loss: 0.5907588809667553 | loss/mean: 0.5907588809667553 | loss/std: 0.9921837012088393 | lr: 0.0003 | momentum: 0.9
* Epoch (2/30) 


3/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (3/30) auc: 0.9467310309410095 | auc/_macro: 0.9467310309410095 | auc/_micro: 0.9467310135946666 | auc/_weighted: 0.5512858033180237 | loss: 0.29111686122408265 | loss/mean: 0.29111686122408265 | loss/std: 0.3242793970087344 | lr: 0.0003 | momentum: 0.9


3/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (3/30) auc: 0.9366521835327148 | auc/_macro: 0.9366521835327148 | auc/_micro: 0.9366522068443159 | auc/_weighted: 0.541262149810791 | loss: 0.4831096099567467 | loss/mean: 0.4831096099567467 | loss/std: 1.0961582099151466 | lr: 0.0003 | momentum: 0.9
* Epoch (3/30) 


4/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (4/30) auc: 0.9629623293876648 | auc/_macro: 0.9629623293876648 | auc/_micro: 0.9629623187145311 | auc/_weighted: 0.5597466230392456 | loss: 0.24001580206945602 | loss/mean: 0.24001580206945602 | loss/std: 0.31587972671020437 | lr: 0.0003 | momentum: 0.9


4/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (4/30) auc: 0.9464297890663147 | auc/_macro: 0.9464297890663147 | auc/_micro: 0.946429801005302 | auc/_weighted: 0.5469123125076294 | loss: 0.34443403612898643 | loss/mean: 0.34443403612898643 | loss/std: 0.616134324114256 | lr: 0.0003 | momentum: 0.9
* Epoch (4/30) 


5/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (5/30) auc: 0.9766693711280823 | auc/_macro: 0.9766693711280823 | auc/_micro: 0.9766693483507644 | auc/_weighted: 0.5677141547203064 | loss: 0.18899975717642611 | loss/mean: 0.18899975717642611 | loss/std: 0.30718498143871736 | lr: 0.0003 | momentum: 0.9


5/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

found banner
found banner
found banner
found banner
found banner
found banner
found banner
found banner
found banner
found banner
found banner
found banner
found banner
valid (5/30) auc: 0.9506300091743469 | auc/_macro: 0.9506300091743469 | auc/_micro: 0.9506300351167114 | auc/_weighted: 0.5493394732475281 | loss: 0.3633160525474881 | loss/mean: 0.3633160525474881 | loss/std: 0.46297624245315255 | lr: 0.0003 | momentum: 0.9
* Epoch (5/30) 


6/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (6/30) auc: 0.9726902842521667 | auc/_macro: 0.9726902842521667 | auc/_micro: 0.9726903089734947 | auc/_weighted: 0.5654012560844421 | loss: 0.2054316030854528 | loss/mean: 0.2054316030854528 | loss/std: 0.28892090328115894 | lr: 0.0003 | momentum: 0.9


6/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (6/30) auc: 0.9162018895149231 | auc/_macro: 0.9162018895149231 | auc/_micro: 0.9162018866625352 | auc/_weighted: 0.5294445753097534 | loss: 0.7154011157661919 | loss/mean: 0.7154011157661919 | loss/std: 1.7151683298986427 | lr: 0.0003 | momentum: 0.9
* Epoch (6/30) 


7/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (7/30) auc: 0.9604214429855347 | auc/_macro: 0.9604214429855347 | auc/_micro: 0.9604214172570453 | auc/_weighted: 0.5592577457427979 | loss: 0.24844847724259422 | loss/mean: 0.24844847724259422 | loss/std: 0.31317841083848064 | lr: 0.0003 | momentum: 0.9


7/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (7/30) auc: 0.9175789952278137 | auc/_macro: 0.9175789952278137 | auc/_micro: 0.9175790126007024 | auc/_weighted: 0.5302403569221497 | loss: 0.46438969636706207 | loss/mean: 0.46438969636706207 | loss/std: 0.5539423279535323 | lr: 0.0003 | momentum: 0.9
* Epoch (7/30) 


8/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (8/30) auc: 0.9789698123931885 | auc/_macro: 0.9789698123931885 | auc/_micro: 0.9789697984388249 | auc/_weighted: 0.5690513849258423 | loss: 0.17879965437246562 | loss/mean: 0.17879965437246562 | loss/std: 0.27353559758231716 | lr: 0.0003 | momentum: 0.9


8/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (8/30) auc: 0.8986435532569885 | auc/_macro: 0.8986435532569885 | auc/_micro: 0.8986435309509054 | auc/_weighted: 0.5192981362342834 | loss: 0.5732252805543389 | loss/mean: 0.5732252805543389 | loss/std: 1.0072912849737115 | lr: 0.0003 | momentum: 0.9
* Epoch (8/30) 


9/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (9/30) auc: 0.9836681485176086 | auc/_macro: 0.9836681485176086 | auc/_micro: 0.983668123030862 | auc/_weighted: 0.5727944374084473 | loss: 0.14890384106254143 | loss/mean: 0.14890384106254143 | loss/std: 0.3058893742939104 | lr: 0.0003 | momentum: 0.9


9/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (9/30) auc: 0.9222612380981445 | auc/_macro: 0.9222612380981445 | auc/_micro: 0.9222612407904703 | auc/_weighted: 0.5329460501670837 | loss: 0.6928078948680435 | loss/mean: 0.6928078948680435 | loss/std: 1.75494029163485 | lr: 0.0003 | momentum: 0.9
* Epoch (9/30) 


10/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (10/30) auc: 0.9854189157485962 | auc/_macro: 0.9854189157485962 | auc/_micro: 0.985418886303842 | auc/_weighted: 0.5728000998497009 | loss: 0.13656093181041926 | loss/mean: 0.13656093181041926 | loss/std: 0.2790218146969162 | lr: 0.0003 | momentum: 0.9


10/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (10/30) auc: 0.9303174018859863 | auc/_macro: 0.9303174018859863 | auc/_micro: 0.9303174275287476 | auc/_weighted: 0.5376014709472656 | loss: 0.5525134150897818 | loss/mean: 0.5525134150897818 | loss/std: 1.080625942900168 | lr: 0.0003 | momentum: 0.9
* Epoch (10/30) 


11/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (11/30) auc: 0.9917778968811035 | auc/_macro: 0.9917778968811035 | auc/_micro: 0.9917779168404951 | auc/_weighted: 0.5754760503768921 | loss: 0.10197757810201892 | loss/mean: 0.10197757810201892 | loss/std: 0.25048916955014083 | lr: 0.0003 | momentum: 0.9


11/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (11/30) auc: 0.9438821077346802 | auc/_macro: 0.9438821077346802 | auc/_micro: 0.9438821180196929 | auc/_weighted: 0.5454400777816772 | loss: 0.48685687871955924 | loss/mean: 0.48685687871955924 | loss/std: 1.3313201614848877 | lr: 0.0003 | momentum: 0.9
* Epoch (11/30) 


12/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (12/30) auc: 0.9929769039154053 | auc/_macro: 0.9929769039154053 | auc/_micro: 0.9929768867821964 | auc/_weighted: 0.5771933197975159 | loss: 0.10307722918652361 | loss/mean: 0.10307722918652361 | loss/std: 0.215255446983577 | lr: 0.0003 | momentum: 0.9


12/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (12/30) auc: 0.9307305812835693 | auc/_macro: 0.9307305812835693 | auc/_micro: 0.9307305653101976 | auc/_weighted: 0.5378402471542358 | loss: 0.5221901067746881 | loss/mean: 0.5221901067746881 | loss/std: 1.175436836766971 | lr: 0.0003 | momentum: 0.9
* Epoch (12/30) 


13/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (13/30) auc: 0.9962992668151855 | auc/_macro: 0.9962992668151855 | auc/_micro: 0.9962992759452934 | auc/_weighted: 0.5791245698928833 | loss: 0.0537406204338781 | loss/mean: 0.0537406204338781 | loss/std: 0.22461797268453693 | lr: 0.0003 | momentum: 0.9


13/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (13/30) auc: 0.8854231238365173 | auc/_macro: 0.8854231238365173 | auc/_micro: 0.8854231219445018 | auc/_weighted: 0.5116584897041321 | loss: 1.112703480524219 | loss/mean: 1.112703480524219 | loss/std: 1.6425140239188776 | lr: 0.0003 | momentum: 0.9
* Epoch (13/30) 


14/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

train (14/30) auc: 0.9814357161521912 | auc/_macro: 0.9814357161521912 | auc/_micro: 0.981435708193354 | auc/_weighted: 0.571494460105896 | loss: 0.16695293935704525 | loss/mean: 0.16695293935704525 | loss/std: 0.3501161864978833 | lr: 0.0003 | momentum: 0.9


14/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (14/30) auc: 0.9085588455200195 | auc/_macro: 0.9085588455200195 | auc/_micro: 0.9085588377057082 | auc/_weighted: 0.525027871131897 | loss: 0.7592238022260689 | loss/mean: 0.7592238022260689 | loss/std: 1.8872501755220357 | lr: 0.0003 | momentum: 0.9
* Epoch (14/30) 


15/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (15/30) auc: 0.9935986995697021 | auc/_macro: 0.9935986995697021 | auc/_micro: 0.9935986701248062 | auc/_weighted: 0.5785770416259766 | loss: 0.08482493780679816 | loss/mean: 0.08482493780679816 | loss/std: 0.2694467166673854 | lr: 0.0003 | momentum: 0.9


15/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

valid (15/30) auc: 0.9430558681488037 | auc/_macro: 0.9430558681488037 | auc/_micro: 0.9430558424567926 | auc/_weighted: 0.5449626445770264 | loss: 0.533470603554037 | loss/mean: 0.533470603554037 | loss/std: 0.791086236155938 | lr: 0.0003 | momentum: 0.9
* Epoch (15/30) 


16/30 * Epoch (train):   0%|          | 0/243 [00:00<?, ?it/s]

data_site_small/80_3.mp4 is invalid
train (16/30) auc: 0.994073212146759 | auc/_macro: 0.994073212146759 | auc/_micro: 0.994073203126701 | auc/_weighted: 0.5798760056495667 | loss: 0.08919117624086709 | loss/mean: 0.08919117624086709 | loss/std: 0.2080116132768069 | lr: 0.0003 | momentum: 0.9


16/30 * Epoch (valid):   0%|          | 0/61 [00:00<?, ?it/s]

Keyboard Interrupt


KeyboardInterrupt: 

In [7]:
# Re-create the network on the GPU; its weights are replaced by the
# checkpoint loaded in the next cell.
model = ResNetTCN().cuda()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /home/gorodion/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

In [8]:
# Restore the weights stored under 'model_state_dict' in the checkpoint file
# (presumably the best checkpoint written by runner.train with logdir="./logs" — confirm).
model.load_state_dict(torch.load('logs/checkpoints/best.pth')['model_state_dict'])

<All keys matched successfully>

In [None]:
# Attach the restored model to the catalyst runner.
# NOTE(review): `runner` is only created in the training cell, so this
# fails on a fresh kernel unless that cell has been run first.
runner.model = model

In [27]:
from sklearn.metrics import precision_recall_curve
import torch

@torch.no_grad()
def predict(model, loader, steps=None):
    """Run ``model`` over ``loader`` and collect its outputs and the targets.

    The model is switched to eval mode (and left there). Inputs are moved to
    the device the model's parameters live on, so the function works for CPU
    and GPU models alike.

    Args:
        model: ``nn.Module`` mapping an input batch to predictions.
        loader: iterable of ``(X, y)`` batches that supports ``len()``.
        steps: stop after this many batches; ``None`` processes the whole loader.

    Returns:
        ``(y_pred_all, y_all)``: predictions (on CPU) and targets concatenated
        along the first axis, or ``(None, None)`` if the loader yields nothing.
    """
    model.eval()
    # Follow the model's own device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    total_steps = len(loader) if steps is None else steps

    # Gather per-batch results and concatenate once at the end rather than
    # re-concatenating the growing tensor on every step.
    pred_batches, target_batches = [], []
    for step, (X, y) in enumerate(loader, 1):
        pred_batches.append(model(X.to(device)).cpu())
        target_batches.append(y)

        print(f'{step}/{total_steps} valid steps', end='\r')
        if step == total_steps: break
    # Blank out the progress line.
    print(' '*50, end='\r')

    if not pred_batches:
        return None, None
    return torch.cat(pred_batches), torch.cat(target_batches)

@torch.no_grad()
def precision_recall(y_pred, y_true):
    """Report precision, recall and F1 at the threshold that maximises F1.

    Args:
        y_pred: tensor of raw logits; a sigmoid is applied here.
        y_true: tensor of binary targets with the same number of elements.

    Returns:
        ``(precision, recall, f1)`` at the selected threshold. The values and
        the threshold are also printed.
    """
    # reshape(-1) rather than squeeze(): a single-sample input must stay 1-D.
    scores = y_pred.detach().sigmoid().reshape(-1).cpu().numpy()
    targets = y_true.detach().reshape(-1).cpu().numpy()
    precision, recall, thr = precision_recall_curve(targets, scores)

    # F1 is defined as 0 where precision + recall == 0. A plain division
    # gives NaN there, and np.argmax would then select the NaN entry.
    denom = precision + recall
    fscores = np.divide(
        2 * (precision * recall), denom, out=np.zeros_like(denom), where=denom > 0,
    )
    # The final curve point has no threshold, so clamp to the last valid index.
    argmax = min(np.argmax(fscores), len(fscores)-2)
    p, r, f, t = (values[argmax] for values in (precision, recall, fscores, thr))

    print('precision:', p)
    print('recall:', r)
    print('f1:', f)
    print('threshold:', t)
    return p, r, f

def evaluate(clf, dataloader):
    """Predict with ``clf`` over ``dataloader`` and return ``precision_recall`` of the result."""
    return precision_recall(*predict(clf, dataloader))

In [56]:
# Collect raw model outputs (logits) and targets over the validation loader.
y_pred, y_true = predict(model, loaders['valid'])

found banner
                                                  

In [57]:
# Precision / recall / F1 at the F1-maximising threshold on the validation set.
precision_recall(y_pred, y_true)

precision: 0.8819444444444444
recall: 0.900709219858156
f1: 0.8912280701754387
threshold: 0.7558913


(0.8819444444444444, 0.900709219858156, 0.8912280701754387)

In [51]:
@torch.no_grad()
def predict_sample(model, X):
    """Return the model output for one already-batched input, on CPU.

    The model is evaluated in eval mode, so BatchNorm and Dropout behave
    deterministically; its previous train/eval mode is restored afterwards.
    ``X`` is moved to the device the model's parameters live on.

    Args:
        model: ``nn.Module`` to evaluate.
        X: input tensor including the batch axis.
    """
    # Follow the model's own device instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    was_training = model.training
    model.eval()
    try:
        return model(X.to(device)).cpu()
    finally:
        model.train(was_training)

In [46]:
import io 
from IPython.display import HTML
from base64 import b64encode
import cv2
import matplotlib.pyplot as plt

def show_video(file_name, width=640):
    """Return an HTML ``<video>`` player with the file embedded as a base64 data URL.

    The whole file is read into memory and inlined into the markup, so this
    is only suitable for small clips.

    Args:
        file_name: path of the MP4 file to embed.
        width: width attribute of the ``<video>`` element.
    """
    # Context manager so the file handle is closed deterministically.
    with open(file_name, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML("""
    <video width="{0}" controls>
        <source src="{1}" type="video/mp4">
    </video>
    """.format(width, data_url))

In [48]:
# Show the first validation clip. The path is bound to the name the next two
# lines actually use; the original assigned `sample` and then read an
# undefined `sample_vid_path`, which raises NameError on a fresh kernel.
sample_vid_path = val.vid_paths[0]
print(sample_vid_path)
show_video(sample_vid_path)

data_parsed/0.mp4


In [65]:
X, y = val[15]  # a sample from the validation set

In [76]:
# Probability for the single clip: add a batch axis, run the model, apply a sigmoid.
predict_sample(model, X[None]).sigmoid().item()

0.9972987771034241

In [66]:
# Decision threshold on the predicted probability.
# NOTE(review): the F1-maximising threshold printed by precision_recall above
# was ~0.756, not 0.5 — confirm which one is intended here.
THR = 0.5
# Probability of the positive class for the single clip (batch axis added via X[None]).
pred_prob = predict_sample(model, X[None]).sigmoid()[0,0]
# Print the ground-truth label and the thresholded prediction.
print('Реальная метка:', 'ДТП' if int(y[0].item()) else 'не ДТП')
print('Предсказание:', 'ДТП' if (pred_prob > THR).item() else 'не ДТП')

Реальная метка: ДТП
Предсказание: ДТП
