In [1]:
import asyncio
import concurrent.futures
import dataclasses
import re
import threading
from collections import Counter
from typing import Dict, List, Tuple, Union, cast

import cv2
import numpy as np
import yt_dlp
from fastapi import WebSocket

from poke_battle_logger.batch.data_builder import DataBuilder
from poke_battle_logger.batch.extractor import Extractor
from poke_battle_logger.batch.frame_compressor import (
    frame_compress,
    message_frame_compress,
)
from poke_battle_logger.batch.frame_detector import FrameDetector
from poke_battle_logger.batch.pokemon_extractor import PokemonExtractor
from poke_battle_logger.database.database_handler import DatabaseHandler
from poke_battle_logger.types import StatusByWebsocket
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import pytesseract
import editdistance
from poke_battle_logger.batch.pokemon_name_window_extractor import (
    EDIT_DISTANCE_THRESHOLD,
    PokemonNameWindowExtractor,
)

In [2]:
import os

# Tesseract reads its language data from the directory named by
# TESSDATA_PREFIX. NOTE(review): this is a machine-specific absolute path;
# adjust it when running on another machine.
TESSDATA_BEST_DIR = "/opt/brew/Cellar/tesseract/5.3.0_1/share/tessdata_best/"
os.environ["TESSDATA_PREFIX"] = TESSDATA_BEST_DIR

In [4]:
# Run configuration for this notebook.
video_id: str = "q6llWtGqj8c"  # the recording is read from video/<video_id>.mp4
language: str = "en"  # passed to FrameDetector and Extractor
trainer_id: int = 1  # forwarded to DataBuilder

In [8]:
# Open the recording and fail fast with an explicit error if it cannot be read,
# rather than carrying an unusable capture into the later steps.
video_path = f"video/{video_id}.mp4"
video = cv2.VideoCapture(video_path)
if not video.isOpened():
    raise RuntimeError(f"Could not open video file: {video_path}")

frame_detector = FrameDetector(language)
extractor = Extractor(language)
pokemon_extractor = PokemonExtractor()

# One list of matching frame indices per kind of screen.
first_ranking_frames = []
select_done_frames = []
standing_by_frames = []
level_50_frames = []
ranking_frames = []
win_or_lost_frames = []
message_window_frames = []

# Each detector is paired with the list that collects its matches. The
# detectors are independent, so a single frame may land in several lists.
frame_collectors = (
    (frame_detector.is_first_ranking_frame, first_ranking_frames),
    (frame_detector.is_select_done_frame, select_done_frames),
    (frame_detector.is_standing_by_frame, standing_by_frames),
    (frame_detector.is_level_50_frame, level_50_frames),
    (frame_detector.is_ranking_frame, ranking_frames),
    (frame_detector.is_win_or_lost_frame, win_or_lost_frames),
    (frame_detector.is_message_window_frame, message_window_frames),
)

print("read_video")
total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
for i in range(total_frames):
    ret, frame = video.read()
    if not ret:
        # The frame could not be read: skip it. The index i still advances.
        continue
    for is_target_frame, matched_frames in frame_collectors:
        if is_target_frame(frame):
            matched_frames.append(i)

# Compress the per-frame hits. The results are indexed later in this cell as
# lists of frame-number groups (e.g. compressed_x[i][-1]).
print("compress")
(
    compressed_first_ranking_frames,
    compressed_select_done_frames,
    compressed_level_50_frames,
    compressed_ranking_frames,
    compressed_win_or_lost_frames,
) = [
    frame_compress(frame_numbers)
    for frame_numbers in (
        first_ranking_frames,
        select_done_frames,
        level_50_frames,
        ranking_frames,
        win_or_lost_frames,
    )
]
# These two use non-default options, so they are called separately.
compressed_standing_by_frames = frame_compress(
    standing_by_frames, ignore_short_frames=True
)
compressed_message_window_frames = message_frame_compress(
    message_window_frames, frame_threshold=3
)

# Detect the rank shown at the start of the session (OCR).
print("rank1")
rank_numbers = {}
# A recording that starts from team selection has no opening rank screen (the
# battle start/end step in this cell has a branch for exactly that case), so
# the opening rank is only read when such a screen was detected. Indexing [0]
# unconditionally would raise IndexError for those recordings.
if compressed_first_ranking_frames:
    # The 5th frame from the end of the group is the one sampled for OCR.
    first_ranking_frame_number = compressed_first_ranking_frames[0][-5]
    video.set(cv2.CAP_PROP_POS_FRAMES, first_ranking_frame_number - 1)
    _, _first_ranking_frame = video.read()
    rank_numbers[first_ranking_frame_number] = extractor.extract_first_rank_number(
        _first_ranking_frame
    )

# Detect the rank shown on every later ranking screen (OCR).
print("rank2")
for ranking_frame_numbers in compressed_ranking_frames:
    ranking_frame_number = ranking_frame_numbers[-5]
    video.set(cv2.CAP_PROP_POS_FRAMES, ranking_frame_number - 1)
    _, _ranking_frame = video.read()
    rank_numbers[ranking_frame_number] = extractor.extract_rank_number(
        _ranking_frame
    )

# When the rank did not change between two consecutive entries, remove the
# earlier entry from rank_numbers. The key list is a snapshot, so deleting
# from the dict while walking the snapshot is safe.
rank_frames = list(rank_numbers.keys())
for _ranking_frame_number, _next_ranking_frame_number in zip(
    rank_frames, rank_frames[1:]
):
    if (
        rank_numbers[_ranking_frame_number]
        == rank_numbers[_next_ranking_frame_number]
    ):
        del rank_numbers[_ranking_frame_number]

# Define the start and end frame of every battle.
print("start,end")
battle_start_end_frame_numbers: List[Tuple[int, int]] = []
rank_frames = list(rank_numbers.keys())

# Offset 0: the recording starts from team selection (no opening rank shown).
# Offset 1: the recording starts from entering the Battle Stadium, so the
# first rank entry is the opening rank and must be skipped.
rank_offset = 0 if len(compressed_standing_by_frames) == len(rank_numbers) else 1

for i, _standing_by_frames in enumerate(compressed_standing_by_frames):
    _standing_by_frame_number = _standing_by_frames[-1]
    _ranking_frame = rank_frames[i + rank_offset]

    # Only keep pairs where the standing-by screen precedes the rank screen.
    if _standing_by_frame_number < _ranking_frame:
        battle_start_end_frame_numbers.append(
            (_standing_by_frame_number, _ranking_frame)
        )

# Extract the order in which the Pokémon were selected.
pokemon_select_order = {}
print("order")
for _select_done_frames in compressed_select_done_frames:
    # The 5th frame from the end of the group is the one sampled.
    _select_done_frame_number = _select_done_frames[-5]

    video.set(cv2.CAP_PROP_POS_FRAMES, _select_done_frame_number - 1)
    _, _select_done_frame = video.read()

    pokemon_select_order[
        _select_done_frame_number
    ] = extractor.extract_pokemon_select_numbers(_select_done_frame)

# Extract the six Pokémon on each side from the pre-battle screen.
print("6vs6")
pre_battle_pokemons: Dict[int, Dict[str, List[str]]] = {}
is_exist_unknown_pokemon_list1 = []
for _standing_by_frames in compressed_standing_by_frames:
    # Groups consisting of a single frame are skipped.
    if len(_standing_by_frames) == 1:
        continue
    _standing_by_frame_number = _standing_by_frames[-1]

    video.set(cv2.CAP_PROP_POS_FRAMES, _standing_by_frame_number - 1)
    _, _raw_standing_by_frame = video.read()
    _standing_by_frame = cast(np.ndarray, _raw_standing_by_frame)

    _extracted = pokemon_extractor.extract_pre_battle_pokemons(_standing_by_frame)
    your_pokemon_names, opponent_pokemon_names, _is_exist_unknown_pokemon = _extracted

    pre_battle_pokemons[_standing_by_frame_number] = {
        "your_pokemon_names": your_pokemon_names,
        "opponent_pokemon_names": opponent_pokemon_names,
    }
    is_exist_unknown_pokemon_list1.append(_is_exist_unknown_pokemon)

# Extract the Pokémon that are on the field during the battle.
print("3vs3")
battle_pokemons: List[Dict[str, Union[str, int]]] = []
is_exist_unknown_pokemon_list2 = []
for level_50_frame_numbers in compressed_level_50_frames:
    # The last frame of each group is the one sampled.
    _level_50_frame_number = level_50_frame_numbers[-1]
    video.set(cv2.CAP_PROP_POS_FRAMES, _level_50_frame_number - 1)
    _, _level_50_frame = video.read()

    (
        your_pokemon_name,
        opponent_pokemon_name,
        _is_exist_unknown_pokemon,
    ) = extractor.extract_pokemon_name_in_battle(_level_50_frame)

    battle_pokemons.append(
        {
            "frame_number": _level_50_frame_number,
            "your_pokemon_name": your_pokemon_name,
            "opponent_pokemon_name": opponent_pokemon_name,
        },
    )
    is_exist_unknown_pokemon_list2.append(_is_exist_unknown_pokemon)

# Abort when any Pokémon could not be identified, in either the pre-battle
# (6vs6) step or the in-battle step. The exception type is still ValueError;
# the message now says what was unrecognised instead of being empty.
unknown_pre_battle_count = sum(1 for flag in is_exist_unknown_pokemon_list1 if flag)
unknown_in_battle_frames = [
    entry["frame_number"]
    for entry, flag in zip(battle_pokemons, is_exist_unknown_pokemon_list2)
    if flag
]
if unknown_pre_battle_count or unknown_in_battle_frames:
    raise ValueError(
        "Unknown Pokemon detected: "
        f"{unknown_pre_battle_count} pre-battle (6vs6) screen(s), "
        f"in-battle frame number(s) {unknown_in_battle_frames}"
    )

# Detect win / loss.
# The per-frame result is error-prone, so up to 10 frames (counting back from
# the last frame of each group) are inspected and the majority decides. When
# every inspected frame is "unknown", the group is recorded as "unknown".
print("win lose")
win_or_lost = {}
for win_or_lost_frame_numbers in compressed_win_or_lost_frames:
    # Groups of 3 frames or fewer are ignored.
    if len(win_or_lost_frame_numbers) <= 3:
        continue

    win_or_lost_frame_number = win_or_lost_frame_numbers[-1]

    # min(), not max(): the intent is "at most 10 frames". max() inspected at
    # least 10 frames, reaching back before short groups and reading every
    # frame of long ones.
    _frames_to_check = min(10, len(win_or_lost_frame_numbers))
    _win_or_lost_results = []
    for idx in range(_frames_to_check):
        _win_or_lost_frame_number = win_or_lost_frame_number - idx
        video.set(cv2.CAP_PROP_POS_FRAMES, _win_or_lost_frame_number - 1)
        _, _win_or_lost_frame = video.read()
        _win_or_lost_results.append(
            extractor.extract_win_or_lost(_win_or_lost_frame)
        )

    _win_or_lost_results = [v for v in _win_or_lost_results if v != "unknown"]
    if len(_win_or_lost_results) == 0:
        win_or_lost[win_or_lost_frame_number] = "unknown"
        continue

    # Majority vote over the recognised results.
    win_or_lost_result = Counter(_win_or_lost_results).most_common()[0][0]
    win_or_lost[win_or_lost_frame_number] = win_or_lost_result

# Read the text of each message window (OCR).
print("message")
messages = {}
for message_frame_numbers in compressed_message_window_frames:
    # The last frame of each group is the one sampled.
    message_frame_number = message_frame_numbers[-1]
    video.set(cv2.CAP_PROP_POS_FRAMES, message_frame_number - 1)
    _, _message_frame = video.read()

    _message = extractor.extract_message(_message_frame)
    if _message is None:
        # Nothing was extracted for this frame; leave it out of the result.
        continue
    messages[message_frame_number] = _message

# Hand everything extracted in this cell to DataBuilder.
data_builder_inputs = dict(
    trainer_id=trainer_id,
    video_id=video_id,
    battle_start_end_frame_numbers=battle_start_end_frame_numbers,
    battle_pokemons=battle_pokemons,
    pre_battle_pokemons=pre_battle_pokemons,
    pokemon_select_order=pokemon_select_order,
    rank_numbers=rank_numbers,
    messages=messages,
    win_or_lost=win_or_lost,
)
data_builder = DataBuilder(**data_builder_inputs)

# build() returns five collections, unpacked here in order.
build_result = data_builder.build()
(
    battles,
    battle_logs,
    modified_pre_battle_pokemons,
    modified_in_battle_pokemons,
    modified_messages,
) = build_result



read_video
compress
rank1
rank2
start,end
order
6vs6
3vs3
win lose
message


In [17]:
# Show the battles returned by data_builder.build().
battles

[Battle(battle_id='b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2', trainer_id=1),
 Battle(battle_id='aa389327-657d-58dc-ba4d-0cd3dda7609c', trainer_id=1)]

In [57]:
import pandas as pd

# Turn the in-battle Pokémon and message objects into DataFrames, one row per
# object, using each object's instance attributes as the columns.
in_battle_pokemon_records = [vars(record) for record in modified_in_battle_pokemons]
message_records = [vars(record) for record in modified_messages]

df_in_battle_pokemon = pd.DataFrame(in_battle_pokemon_records)
df_messages = pd.DataFrame(message_records)

In [56]:
# Preview the first 50 rows of the message table.
df_messages.head(50)

Unnamed: 0,battle_id,frame_number,message
0,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,3890,ま よ sent out ヘイ ラ ツ シ ヤ ゼ !
1,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,3985,Gol Gallade!
2,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,4289,Gallade used Leaf Blade!
3,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,4347,It's super effective!
4,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,4416,The opposing ヘイ ラ ツ シ ャ ヤ used Yawn!
5,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,4526,Gallade grew drowsy!
6,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,4908,"Gallade, come back!"
7,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,5067,"You're in charge, Rotom!"
8,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,5184,The opposing \-{ 5WZ/¥ used Yawn!
9,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,5295,Rotom grew drowsy!


In [58]:
# Add a 'next_frame_number' column to the battle log
df_in_battle_pokemon = df_in_battle_pokemon.sort_values(by=['battle_id', 'turn'])
df_in_battle_pokemon['next_frame_number'] = df_in_battle_pokemon.groupby('battle_id')['frame_number'].shift(-1)

# Join the message log to the battle log
df_messages2 = pd.merge_asof(
    df_messages.sort_values('frame_number'), 
    df_in_battle_pokemon.sort_values('frame_number'), 
    left_on='frame_number', 
    right_on='frame_number',
    by='battle_id',
    direction='backward'
)

In [53]:
# Preview the first 30 rows of the messages joined with the battle log.
df_messages2.head(30)

Unnamed: 0,battle_id,frame_number,message,turn,your_pokemon_name,opponent_pokemon_name,next_frame_number,fainted_pokemon_type
34,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,10732,Gallade fainted!,9.0,ロトム,ヘイラッシャ,11591.0,Your Pokemon Fainted
38,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,11788,The opposing ヘイ ラッ シャ fainted!,10.0,ロトム,ヘイラッシャ,12837.0,Opponent Pokemon Fainted
49,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,15575,The opposing テツ ノド ク ガ fainted!,13.0,ハバタクカミ,テツノドクガ,16150.0,Opponent Pokemon Fainted
99,aa389327-657d-58dc-ba4d-0cd3dda7609c,32993,The opposing コノ ヨ ザ ル fainted!,5.0,ハバタクカミ,コノヨザル,34451.0,Opponent Pokemon Fainted
118,aa389327-657d-58dc-ba4d-0cd3dda7609c,37780,Flutter Mane fainted!,9.0,ハバタクカミ,ウインディ,39315.0,Your Pokemon Fainted
123,aa389327-657d-58dc-ba4d-0cd3dda7609c,39666,The opposing ウイ ン デ ィ fainted!,10.0,ウォッシュロトム,ウインディ,40740.0,Opponent Pokemon Fainted
137,aa389327-657d-58dc-ba4d-0cd3dda7609c,44109,Gallade fainted!,13.0,エルレイド,ウルガモス,44556.0,Your Pokemon Fainted
140,aa389327-657d-58dc-ba4d-0cd3dda7609c,44893,Rotom fainted!,14.0,ウォッシュロトム,ウルガモス,,Your Pokemon Fainted


In [59]:
# Find fainted pokemon messages
df_messages2['fainted_pokemon_type'] = None
df_messages2.loc[df_messages2.message.str.contains('.* fainted!'), 'fainted_pokemon_type'] = 'Your Pokemon Fainted'
df_messages2.loc[df_messages2.message.str.contains('The opposing .* fainted!'), 'fainted_pokemon_type'] = 'Opponent Pokemon Fainted'

In [60]:
# Keep only rows with fainted pokemon
df_messages2 = df_messages2.dropna(subset=['fainted_pokemon_type'])

# Join fainted pokemon messages to the battle log
df_in_battle_pokemon = df_in_battle_pokemon.merge(
    df_messages2[['battle_id', 'turn', 'fainted_pokemon_type']], 
    on=['battle_id', 'turn'],
    how='left'
)

In [52]:
# Show the joined message table (only the faint rows remain once the
# filtering cell has been run).
df_messages2

Unnamed: 0,battle_id,frame_number,message,turn,your_pokemon_name,opponent_pokemon_name,next_frame_number,fainted_pokemon_type
34,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,10732,Gallade fainted!,9.0,ロトム,ヘイラッシャ,11591.0,Your Pokemon Fainted
38,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,11788,The opposing ヘイ ラッ シャ fainted!,10.0,ロトム,ヘイラッシャ,12837.0,Opponent Pokemon Fainted
49,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,15575,The opposing テツ ノド ク ガ fainted!,13.0,ハバタクカミ,テツノドクガ,16150.0,Opponent Pokemon Fainted
99,aa389327-657d-58dc-ba4d-0cd3dda7609c,32993,The opposing コノ ヨ ザ ル fainted!,5.0,ハバタクカミ,コノヨザル,34451.0,Opponent Pokemon Fainted
118,aa389327-657d-58dc-ba4d-0cd3dda7609c,37780,Flutter Mane fainted!,9.0,ハバタクカミ,ウインディ,39315.0,Your Pokemon Fainted
123,aa389327-657d-58dc-ba4d-0cd3dda7609c,39666,The opposing ウイ ン デ ィ fainted!,10.0,ウォッシュロトム,ウインディ,40740.0,Opponent Pokemon Fainted
137,aa389327-657d-58dc-ba4d-0cd3dda7609c,44109,Gallade fainted!,13.0,エルレイド,ウルガモス,44556.0,Your Pokemon Fainted
140,aa389327-657d-58dc-ba4d-0cd3dda7609c,44893,Rotom fainted!,14.0,ウォッシュロトム,ウルガモス,,Your Pokemon Fainted


In [61]:
# Add a 'fainted_pokemon_side' column
df_in_battle_pokemon['fainted_pokemon_side'] = 'Unknown'
df_in_battle_pokemon.loc[df_in_battle_pokemon.fainted_pokemon_type == 'Your Pokemon Fainted', 'fainted_pokemon_side'] = 'Opponent Pokemon Win'
df_in_battle_pokemon.loc[df_in_battle_pokemon.fainted_pokemon_type == 'Opponent Pokemon Fainted', 'fainted_pokemon_side'] = 'Your Pokemon Win'

In [40]:
# Show the battle log with the columns added so far.
df_in_battle_pokemon

Unnamed: 0,battle_id,turn,frame_number,your_pokemon_name,opponent_pokemon_name,next_frame_number,fainted_pokemon_type,fainted_pokemon_side
0,aa389327-657d-58dc-ba4d-0cd3dda7609c,1,26493,エルレイド,ウインディ,28206.0,,Unknown
1,aa389327-657d-58dc-ba4d-0cd3dda7609c,2,28206,ウォッシュロトム,ウインディ,29350.0,,Unknown
2,aa389327-657d-58dc-ba4d-0cd3dda7609c,3,29350,ウォッシュロトム,コノヨザル,31356.0,,Unknown
3,aa389327-657d-58dc-ba4d-0cd3dda7609c,4,31356,ウォッシュロトム,コノヨザル,32811.0,,Unknown
4,aa389327-657d-58dc-ba4d-0cd3dda7609c,5,32811,ハバタクカミ,コノヨザル,34451.0,,Unknown
5,aa389327-657d-58dc-ba4d-0cd3dda7609c,6,34451,ハバタクカミ,ウインディ,35064.0,Opponent Pokemon Fainted,Your Pokemon Win
6,aa389327-657d-58dc-ba4d-0cd3dda7609c,7,35064,ハバタクカミ,ウインディ,36694.0,,Unknown
7,aa389327-657d-58dc-ba4d-0cd3dda7609c,8,36694,ハバタクカミ,ウインディ,37607.0,,Unknown
8,aa389327-657d-58dc-ba4d-0cd3dda7609c,9,37607,ハバタクカミ,ウインディ,39315.0,,Unknown
9,aa389327-657d-58dc-ba4d-0cd3dda7609c,10,39315,ウォッシュロトム,ウインディ,40740.0,Your Pokemon Fainted,Opponent Pokemon Win


In [63]:
# Filter for a specific battle
battle_id = 'b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2'
df_battle = df_in_battle_pokemon[df_in_battle_pokemon.battle_id == battle_id]
df_battle.query("fainted_pokemon_side != 'Unknown'")

Unnamed: 0,battle_id,turn,frame_number,your_pokemon_name,opponent_pokemon_name,next_frame_number,fainted_pokemon_type,fainted_pokemon_side
22,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,9,9983,ロトム,ヘイラッシャ,11591.0,Your Pokemon Fainted,Opponent Pokemon Win
23,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,10,11591,ロトム,ヘイラッシャ,12837.0,Opponent Pokemon Fainted,Your Pokemon Win
26,b6a566d3-2707-5e3b-9c9b-97a8e4d54ec2,13,15357,ハバタクカミ,テツノドクガ,16150.0,Opponent Pokemon Fainted,Your Pokemon Win
