In [1]:
import logging
import os
from glob import glob
from time import time

import cv2
import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from PIL import Image
from tqdm import tqdm
from src.dataset import get_image_transforms
from src.model import TIMMModel

log = logging.getLogger(__name__)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
infer_config = {
    "work_dir": "/code",
    "name": "baseline-swin_large_patch4_window12_384",
    "checkpoint_s": "./weights/s/ckpts/last.ckpt",
    "checkpoint_h": "./weights/h/ckpts/checkpoint-epoch27-step8764-val_acc0.988-val_loss0.062.ckpt",
    "crop_size": 384,
    "videos_dir": "/data/*",
    "result_dir": "/result/"
}

In [3]:
def get_model(infer_config, checkpoint):
    checkpoint_path = checkpoint

    output_folder = "/".join(checkpoint_path.split("/")[:-2])
    cfg = OmegaConf.load(
        os.path.join(infer_config["work_dir"], output_folder, ".hydra", "config.yaml")
    )
    model = TIMMModel(cfg.model)

    ckpt = torch.load(os.path.join(infer_config["work_dir"], checkpoint_path))
    model.load_state_dict(ckpt["state_dict"])

    return model

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

model_s = get_model(infer_config, infer_config["checkpoint_s"]).eval().to(device)
model_h = get_model(infer_config, infer_config["checkpoint_h"]).eval().to(device)

transforms = get_image_transforms(infer_config["crop_size"], False, None)
dummy = torch.zeros((1, 3, 384, 384)).to(device)
model_s(dummy)
model_h(dummy)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


tensor([[0.2933, 0.0405]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [5]:
list_video = sorted(
    glob(infer_config["videos_dir"])
)

In [6]:
result_dir = infer_config["result_dir"]
time_submission_file = open(os.path.join(result_dir, "time_submission.csv"), "w")
time_submission_file.write("fname,time\n")

submission_file = open(os.path.join(result_dir, "jupyter_submission.csv"), "w")
submission_file.write("fname,liveness_score\n")

for video_path in list_video:
    t1 = time()
    name = video_path.split('/')[-1]
    cap = cv2.VideoCapture(video_path)
    frames = []
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not fps:
        fps = 25
    count = 0

    images = []
    while cap.isOpened():
        ret, frame = cap.read()
        if isinstance(frame, np.ndarray):
            if int(count % round(fps)) == 0:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                image = Image.fromarray(frame)
                image = transforms(image)
                images.append(image)


            count += 1
        else:
            break
    images = torch.stack(images)
    images = images.to(device)
    with torch.no_grad():
        logits_s = model_s(images)
        logits_s = F.softmax(logits_s, dim=-1)
        outputs_s = logits_s[:, 1].mean().item()

        logits_h = model_h(images)
        logits_h = F.softmax(logits_h, dim=-1)
        outputs_h = logits_h[:, 1].mean().item()

        outputs = (outputs_s + outputs_h)/2
    t2 = time()
    predicted_time = int(t2*1000 - t1*1000)
    print(f"{video_path}: pred {outputs}, time {predicted_time}")
    submission_file.write(f"{name},{outputs}\n")
    time_submission_file.write(f"{name},{predicted_time}\n")
submission_file.close()
time_submission_file.close()

/data/0.mp4: pred 8.728980276373477e-07, time 264
/data/1.mp4: pred 0.9999999403953552, time 228
/data/10.mp4: pred 0.9952429831027985, time 264
/data/100.mp4: pred 1.381251252041693e-06, time 163
/data/101.mp4: pred 7.696585271332879e-07, time 180
/data/102.mp4: pred 0.9999879598617554, time 244
/data/103.mp4: pred 0.9994657933712006, time 181
/data/104.mp4: pred 0.9999972283840179, time 319
/data/105.mp4: pred 0.9999998211860657, time 243
/data/106.mp4: pred 0.999691367149353, time 288
/data/107.mp4: pred 2.8925641515797906e-05, time 272
/data/108.mp4: pred 0.007888829193234415, time 262
/data/109.mp4: pred 5.432146821249262e-08, time 224
/data/11.mp4: pred 2.1718092284572776e-05, time 345
/data/110.mp4: pred 0.9999939501285553, time 313
/data/111.mp4: pred 0.9999999105930328, time 182
/data/112.mp4: pred 0.9999993145465851, time 295
/data/113.mp4: pred 2.0024368723170483e-06, time 313
/data/114.mp4: pred 0.9999715983867645, time 340
/data/115.mp4: pred 0.10336876288056374, time 494


/data/247.mp4: pred 0.999999463558197, time 254
/data/248.mp4: pred 0.9999995529651642, time 375
/data/249.mp4: pred 6.0113797189842444e-05, time 241
/data/25.mp4: pred 0.0006040294483682374, time 199
/data/250.mp4: pred 1.1527880090511644e-06, time 248
/data/251.mp4: pred 0.7516434788703918, time 284
/data/252.mp4: pred 1.8277914023201447e-05, time 551
/data/253.mp4: pred 0.03309551056008786, time 318
/data/254.mp4: pred 4.220825303491438e-06, time 311
/data/255.mp4: pred 9.096828108567934e-07, time 170
/data/256.mp4: pred 6.385232325101242e-07, time 324
/data/257.mp4: pred 1.592933259075835e-06, time 344
/data/258.mp4: pred 0.999976396560669, time 235
/data/259.mp4: pred 1.0076807157588519e-06, time 203
/data/26.mp4: pred 0.9999995529651642, time 134
/data/260.mp4: pred 0.9999999701976776, time 162
/data/261.mp4: pred 0.17621389031410217, time 217
/data/262.mp4: pred 0.9999994039535522, time 240
/data/263.mp4: pred 0.9999969601631165, time 142
/data/264.mp4: pred 0.999829888343811, t

/data/395.mp4: pred 0.7652354836463928, time 528
/data/396.mp4: pred 0.20274228655034676, time 350
/data/397.mp4: pred 4.4223388329101e-07, time 350
/data/398.mp4: pred 0.07136414083652198, time 226
/data/399.mp4: pred 0.18887927196919918, time 256
/data/4.mp4: pred 0.03964916756376624, time 352
/data/40.mp4: pred 0.007065375801175833, time 214
/data/400.mp4: pred 0.705162525177002, time 226
/data/401.mp4: pred 0.9999983608722687, time 168
/data/402.mp4: pred 0.9999988973140717, time 516
/data/403.mp4: pred 0.9999998807907104, time 224
/data/404.mp4: pred 0.33386696939123794, time 514
/data/405.mp4: pred 0.9999790489673615, time 175
/data/406.mp4: pred 0.999966949224472, time 305
/data/407.mp4: pred 0.9997240900993347, time 169
/data/408.mp4: pred 0.04976258799433708, time 421
/data/409.mp4: pred 0.9999996423721313, time 246
/data/41.mp4: pred 0.09448429383337498, time 342
/data/410.mp4: pred 0.9999997317790985, time 204
/data/411.mp4: pred 0.9999582469463348, time 177
/data/412.mp4: p