In [None]:
import torch

# Sanity check: report the name of CUDA device 0 (presumably raises if PyTorch
# cannot see a GPU, which is what we want before building SAM2 below).
gpu_name = torch.cuda.get_device_name(0)
gpu_name

In [5]:
import math
import re
import shutil
import statistics
import subprocess
import sys
import time
from collections import Counter, defaultdict
from itertools import combinations, product
from pathlib import Path
from pprint import pprint

import cv2
import editdistance
import nest_asyncio
import numpy as np
import scrython
from PIL import Image
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from tqdm import tqdm

if True:
    sys.path.append("..")
    from src.functions import extract_cards

In [6]:
# Relative path from the kernel's working directory to the project root;
# the cells below build checkpoint and data paths from it.
PROJECT_ROOT = Path("..")

# Notebooks have no __file__, so anchor on the kernel's working directory
# (assumed to be this notebook's folder, Jupyter's default — verify if the
# notebook is launched from elsewhere).
SRC = Path.cwd().resolve()
ROOT = SRC.parent
DATA = ROOT / "data"

In [None]:
# SAM2 weights and the matching model config name passed to build_sam2.
checkpoint = PROJECT_ROOT / "checkpoints" / "sam2_hiera_tiny.pt"
model_cfg = "sam2_hiera_t.yaml"

# Recordings available under data/; uncomment the one to process.
# video_path = PROJECT_ROOT / "data" / "20240815_172323.mp4"
# video_path = PROJECT_ROOT / "data" / "20240820_203040.mp4"
# video_path = PROJECT_ROOT / "data" / "20240820_203206.mp4"
# video_path = PROJECT_ROOT / "data" / "20240930_093614.mp4"
# video_path = PROJECT_ROOT / "data" / "20240930_094536.mp4"
# video_path = PROJECT_ROOT / "data" / "20240930_115001.mp4"
# video_path = PROJECT_ROOT / "data" / "20240930_115325.mp4"
# video_path = PROJECT_ROOT / "data" / "20240930_120701.mp4"
video_path = PROJECT_ROOT / "data" / "20240930_120950.mp4"

# Frames go into a sibling directory named after the video (data/<stem>/).
video_dir = video_path.parent / video_path.stem

# Set to False to reuse frames extracted by a previous run.
EXTRACT_FRAMES = True

if EXTRACT_FRAMES:
    # Start from an empty directory so frames from an earlier run never mix in.
    if video_dir.is_dir():
        shutil.rmtree(video_dir)

    video_dir.mkdir(parents=True, exist_ok=True)

    # Arguments as a list (not str.split) so paths containing spaces survive.
    # Frames are scaled to 1080 px wide and named 00001.jpg, 00002.jpg, ...
    cmd = [
        "ffmpeg",
        "-i",
        str(video_path),
        "-filter:v",
        "scale=1080:-1",
        "-q:v",
        "2",
        str(video_dir / "%05d.jpg"),
    ]
    print(" ".join(cmd))

    # extract JPEG frames; keep ffmpeg's stderr so a failure can be reported
    result = subprocess.run(
        cmd,
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
        text=True,
        errors="replace",
    )
    if result.returncode != 0:
        # Fail loudly instead of continuing with an empty frame directory;
        # only the tail of ffmpeg's (verbose) log is shown.
        raise RuntimeError(
            f"ffmpeg exited with status {result.returncode}:\n{result.stderr[-2000:]}"
        )

In [9]:
# Collect the extracted JPEG frames (extension matched case-insensitively) in
# playback order. ffmpeg names them with zero-padded integers (%05d), so the
# numeric value of the stem gives the frame order.
frame_names = sorted(
    (p for p in video_dir.iterdir() if p.suffix.lower() in {".jpg", ".jpeg"}),
    key=lambda p: int(p.stem),
)

In [10]:
# first_seq_dir = video_dir / "00"

# # take a look at a frame of the video
# frame_idx = 100
# plt.figure(figsize=(8, 6))
# plt.title(f"frame {frame_idx}")
# plt.imshow(Image.open(first_seq_dir / frame_names[frame_idx].name))
# plt.show()

# # scan all the JPEG frame names in this directory
# frame_names = [
#     p for p in first_seq_dir.iterdir() if p.suffix in [".jpg", ".jpeg", ".JPG", ".JPEG"]
# ]
# frame_names.sort(key=lambda p: int(p.stem))

In [11]:
# predictor = build_sam2_video_predictor(model_cfg, checkpoint)

# inference_state = predictor.init_state(video_path=first_seq_dir.as_posix())

In [12]:
# Prompt-free SAM2 mask generator used by extract_cards below.
# Parameter meanings follow the SAM2 AutomaticMaskGenerator docs (verify there):
# a 16-per-side point grid, masks filtered by stability score, and small
# regions (< 5 000 px) cleaned up in post-processing.
sam2_model = build_sam2(model_cfg, checkpoint)

mask_generator = SAM2AutomaticMaskGenerator(
    sam2_model,
    points_per_side=16,
    stability_score_thresh=0.9,
    stability_score_offset=0.95,
    min_mask_region_area=5_000,
    # use_m2m=True,
)

In [None]:
# Focus score for every frame: variance of the Laplacian of the grayscale
# image. Frames with few strong edges (e.g. motion blur) score low.
frames_sharpness = []
for frame_path in tqdm(frame_names):
    # PIL decodes to RGB, so convert with RGB2GRAY; BGR2GRAY would apply the
    # red weight to the blue channel and vice versa.
    # NOTE(review): scores shift slightly versus the old BGR conversion —
    # re-check THRESHOLD in the frame-selection cell.
    with Image.open(frame_path) as img:
        frame = np.array(img.convert("RGB"))
    gray_frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)

    # Calculate the Laplacian variance
    frames_sharpness.append(cv2.Laplacian(gray_frame, cv2.CV_64F).var())

frames_sharpness = np.asarray(frames_sharpness)

In [None]:
# Minimum Laplacian variance for a frame to be kept for OCR.
THRESHOLD = 70
# Nominal window sizes in frames; each contributes its own set of picks.
WINDOW_SIZES = (20, 30)

best_frames = set()

for window_size in WINDOW_SIZES:
    # NOTE(review): the span searched equals the stride (window_size // 2), so
    # each size splits the video into consecutive non-overlapping blocks of
    # window_size // 2 frames and keeps the sharpest frame of each block.
    # Confirm this, rather than a 50%-overlapping full window, is intended.
    block = max(1, window_size // 2)

    # Go all the way to the end (the last block may be shorter) so the final
    # frames of the video are searched too.
    for start in range(0, len(frames_sharpness), block):
        best_idx = start + int(np.argmax(frames_sharpness[start : start + block]))
        if frames_sharpness[best_idx] > THRESHOLD:
            best_frames.add(best_idx)

print(len(best_frames))

In [None]:
# Run segmentation + OCR on each selected frame and keep frames that yielded
# any text, keyed by frame index.
extracted_cards = {}

for idx in sorted(best_frames):
    frame_path = frame_names[idx]

    # Sharpness was already computed for every frame; reuse it instead of
    # decoding the image again just to print it.
    sharpness = frames_sharpness[idx]

    # NOTE(review): extract_cards is project code (src.functions); judging by
    # how the next cell unpacks it, it returns (name_text, number_text) pairs
    # per detected card — confirm.
    outputs = extract_cards(mask_generator, frame_path)

    print(f"{idx:>4}", f"{sharpness:>6.2f}", outputs)

    # Keep the frame only if at least one card produced a non-empty field.
    if outputs and any(map(any, outputs)):
        extracted_cards[idx] = outputs

In [None]:
def _pick_card_name(text):
    """Choose the most plausible card-name line from OCR'd title text.

    Single-line text is returned unchanged. For multi-line text each line is
    stripped and scored by its mean word length (heuristic), and the
    highest-scoring line is returned. Blank lines are ignored; returns None
    if no line contains a word.
    """
    if "\n" not in text:
        return text

    best_line, best_score = None, 0
    for line in map(str.strip, text.splitlines()):
        words = line.split()
        if not words:
            # statistics.mean() raises on an empty sequence.
            continue
        score = statistics.mean(map(len, words))
        if score > best_score:
            best_line, best_score = line, score
    return best_line


def _ocr_digits(text):
    """Replace the letter O/o (which OCR may emit instead of zero) with '0'."""
    return re.sub(r"[Oo]", "0", text)


def _parse_collector_info(text):
    """Split OCR'd collector-info text into ``(set_code, collector_number)``.

    Multi-line text: the number line and set line are the first two non-blank
    lines when there are exactly two, or when the first starts with four
    digit-like characters; otherwise they are the last two. The first 3-4
    digit run of the number line is the number. The first token of the set
    line is the set code, with an "EN" suffix (presumably the card language)
    stripped from tokens longer than three characters.

    Single-line text: the first 3-4 character digit-like run is the number and
    no set code is returned; without such a run the whole text is the set code.

    Either element may be None when it could not be read.
    """
    if "\n" in text:
        lines = [line for line in text.splitlines() if line.strip()]
        if len(lines) >= 2:
            if len(lines) == 2 or re.match(r"[o\d]{4}", lines[0], flags=re.IGNORECASE):
                num_line, set_line = lines[:2]
            else:
                num_line, set_line = lines[-2:]

            numbers = re.findall(r"\d{3,4}", _ocr_digits(num_line))
            cnum = numbers[0] if numbers else None

            set_tokens = set_line.split()
            cset = set_tokens[0] if set_tokens else None
            if cset and len(cset) > 3:
                cset = cset.removesuffix("EN")
            return cset, cnum

        # At most one non-blank line (e.g. a trailing newline): parse it as
        # single-line text rather than failing to unpack two lines.
        text = lines[0] if lines else ""

    numbers = re.findall(r"[o\d]{3,4}", text, flags=re.IGNORECASE)
    if numbers:
        # Normalise O/o so the number is parseable with int() later.
        return None, _ocr_digits(numbers[0])
    return text or None, None


# Per frame index: list of cleaned (set_code, collector_number, name) triples.
processed = defaultdict(list)

for idx, items in extracted_cards.items():
    for card_name, card_num in items:
        cname = _pick_card_name(card_name)
        cset, cnum = _parse_collector_info(card_num)

        print(f"{idx:>4}", "-", cset, cnum, "-", cname)
        processed[idx].append((cset, cnum, cname))

In [None]:
# Cluster the per-frame card readings into groups that (hopefully) each
# correspond to one physical card.

# A card joins its best group only if distance + frame-gap penalty is below this.
# NOTE(review): a DIST_THRESH of 7 used to be declared but never used; 5 is
# the value the comparison actually applied — confirm which is intended.
DIST_THRESH = 5
# A group is a candidate only if it was last extended within this many
# processed frames (positions in `processed`, not video frame numbers).
MAX_CIDX_GAP = 2
# Video frames of gap that add 1.0 to the distance.
FRAME_GAP_SCALE = 30


def _card_distance(c1, c2, weights=(2, 3, 4)):
    """Weighted edit distance between two (set, number, name) triples.

    Each field adds ``weight * min(1, edit_distance / shorter_length)``; a
    field that is empty on either side adds its full weight.
    """
    return sum(
        weight
        * (
            min(1, editdistance.eval(v1, v2) / min(len(v1), len(v2)))
            if v1 and v2
            else 1
        )
        for v1, v2, weight in zip(c1, c2, weights)
    )


card_groups = []
last_cidx = []  # per group: `processed` position of the last card added
last_fidx = []  # per group: video frame index of the last card added

for cidx, (fidx, cards) in enumerate(processed.items()):
    for card in cards:
        min_dist, best_group = math.inf, None

        for gidx, group in enumerate(card_groups):
            if cidx - last_cidx[gidx] > MAX_CIDX_GAP:
                continue
            for member in group:
                dist = _card_distance(card, member)
                if dist < min_dist:
                    min_dist, best_group = dist, gidx

        # With no recent candidate group (or none at all yet) the card
        # always starts a new group.
        if best_group is not None:
            # Penalise joining a group last seen many video frames ago.
            min_dist += (fidx - last_fidx[best_group]) / FRAME_GAP_SCALE

        if best_group is not None and min_dist < DIST_THRESH:
            print(best_group, min_dist)
            card_groups[best_group].append(card)
            last_cidx[best_group] = cidx
            last_fidx[best_group] = fidx
        else:
            print(len(card_groups), min_dist)
            card_groups.append([card])
            last_cidx.append(cidx)
            last_fidx.append(fidx)


pprint(card_groups)

In [None]:
# Fold single-card groups with a missing field into the closest complete group.


def _is_partial_singleton(group):
    """True for a group of exactly one card missing its set, number or name."""
    return len(group) == 1 and not all(group[0])


def _card_distance_maxnorm(c1, c2, weights=(2, 3, 4)):
    """Weighted edit distance between two (set, number, name) triples.

    Each field adds ``weight * min(1, edit_distance / longer_length)``; a
    field that is empty on either side adds its full weight.
    NOTE(review): the grouping cell normalises by the *shorter* length —
    confirm the different normalisation here is intended.
    """
    return sum(
        weight
        * (
            min(1, editdistance.eval(v1, v2) / max(len(v1), len(v2)))
            if v1 and v2
            else 1
        )
        for v1, v2, weight in zip(c1, c2, weights)
    )


one_card_groups = [g for g in card_groups if _is_partial_singleton(g)]
card_groups = [g for g in card_groups if not _is_partial_singleton(g)]

# Singletons with no group to join; re-added after the loop so they are
# never offered as merge targets for other singletons.
unmatched = []

for orphan in one_card_groups:
    card = orphan[0]

    best_dist, best_group = math.inf, None
    for gidx, group in enumerate(card_groups):
        dists = [_card_distance_maxnorm(card, member) for member in group]
        # Score = distance to the closest member + mean distance to all members.
        group_dist = min(dists) + sum(dists) / len(dists)
        if group_dist < best_dist:
            best_dist, best_group = group_dist, gidx

    print(card)
    print(best_group, best_dist)

    if best_group is None:
        # No complete group exists to absorb it; keep it as its own group.
        unmatched.append(orphan)
        print()
        continue

    print(card_groups[best_group])
    print()

    card_groups[best_group].append(card)

card_groups.extend(unmatched)

pprint(card_groups)

In [None]:
# Reduce each group to one consensus (name, number, set) reading. A field
# with no clear winner stays a list of all readings, so the lookup cell can
# try every candidate.


def _majority_or_all(values, required_len=None):
    """Return a consensus value from the non-empty readings of one field.

    - no readings: ``[]``
    - a single distinct value: that value
    - the most common value, if it strictly outnumbers the runner-up and
      (when ``required_len`` is given) has exactly that length
    - otherwise: the full list of readings
    """
    counts = Counter(values).most_common()
    if not counts:
        return []
    top, top_count = counts[0]
    if len(counts) == 1:
        return top
    if (required_len is None or len(top) == required_len) and top_count > counts[1][1]:
        return top
    return values


def _medoid(names):
    """Return the name with the smallest summed edit distance to the others.

    Ties go to the name seen first.
    """
    if len(names) == 1:
        return names[0]
    dists = defaultdict(list)
    for n1, n2 in combinations(names, 2):
        dist = editdistance.eval(n1, n2)
        dists[n1].append(dist)
        dists[n2].append(dist)
    totals = {name: sum(ds) for name, ds in dists.items()}
    return min(totals, key=totals.get)


new_groups = []

for group in card_groups:
    # Transpose the (set, number, name) triples into one list per field,
    # dropping empty readings.
    csets, cnums, cnames = ([value for value in field if value] for field in zip(*group))

    # TODO: filter using list of MTG sets
    cset = _majority_or_all(csets, required_len=3)
    cnum = _majority_or_all(cnums, required_len=4)

    cname = _majority_or_all(cnames)
    if isinstance(cname, list):
        # No clear winner: take the most central reading, or None when the
        # group has no name reading at all.
        cname = _medoid(cname) if cname else None

    new_groups.append((cname, cnum, cset))

pprint(new_groups)

In [None]:
# Resolve each consensus group to a real card via the Scryfall API.

# scrython drives an asyncio loop; nest_asyncio lets it run inside Jupyter's.
nest_asyncio.apply()

# Seconds to wait before each Scryfall request (Scryfall asks clients to
# space requests by roughly 50-100 ms).
REQUEST_DELAY = 0.1


def _as_list(value):
    """Wrap a single consensus string in a list; lists pass through."""
    return [value] if isinstance(value, str) else list(value)


def _collector_numbers(values):
    """Numeric readings with leading zeros stripped, de-duplicated in order.

    Non-numeric OCR readings are skipped instead of crashing int().
    """
    return list(dict.fromkeys(str(int(v)) for v in values if v.isdecimal()))


def _set_codes(values):
    """Lower-cased three-letter set codes, de-duplicated in order."""
    return list(dict.fromkeys(code.lower() for code in values if len(code) == 3))


def _lookup_by_collector(cname, csets, cnums):
    """Try every (set, collector number) pair.

    Returns the card whose name equals ``cname`` if one is found, otherwise
    the candidate with the closest name, or None if no pair exists.
    """
    best_card, best_dist = None, math.inf
    for code, num in product(csets, cnums):
        time.sleep(REQUEST_DELAY)
        try:
            candidate = scrython.cards.Collector(code=code, collector_number=num)
        except scrython.ScryfallError:
            continue

        name = candidate.name()
        if name == cname:
            return candidate

        # Without a name reading, any existing printing is as good as another.
        dist = editdistance.eval(cname, name) if cname else 0
        if dist < best_dist:
            best_card, best_dist = candidate, dist
    return best_card


def _search_by_name(cname, csets, cnums):
    """Full-text search fallback, from most to least specific query.

    Empty set/number clauses are omitted rather than sent as "()". The last
    resort is a name-only search. Returns None if nothing matches.
    """
    set_query = " OR ".join(f"set:{code}" for code in csets)
    num_query = " OR ".join(f"cn:{num}" for num in cnums)

    queries = []
    if set_query and num_query:
        queries.append(f"{cname} ({set_query}) ({num_query})")
    if set_query:
        queries.append(f"{cname} ({set_query})")
    queries.append(cname)

    for query in queries:
        time.sleep(REQUEST_DELAY)
        try:
            results = scrython.cards.Search(q=query).data()
        except scrython.ScryfallError:
            continue
        if results:
            return scrython.cards.Id(id=results[0]["id"])
    return None


for cname, cnum, cset in new_groups:
    cnums = _collector_numbers(_as_list(cnum))
    csets = _set_codes(_as_list(cset))

    card = _lookup_by_collector(cname, csets, cnums)
    if card is None and cname:
        card = _search_by_name(cname, csets, cnums)

    if card is None:
        print("not found:", cname, cnum, cset)
        continue

    print(card.set_code().upper(), card.collector_number().zfill(4), card.name())