In [9]:
import sqlite3
import pandas as pd
import numpy as np 
import torchvision
import torch
from PIL import Image
import asyncio
import io
import os
import json
from torchvision import datasets, transforms

from src.constants import PROJECT_ROOT

# Root directory the extracted frame images are written to.
OUTPUT_DIR = PROJECT_ROOT / 'data/datasets/dataset_ver2_all'
# Source SQLite database. Built from PROJECT_ROOT instead of a hard-coded
# absolute path so the notebook runs on any checkout; kept as a plain string
# so the value has the same type as before.
db_p = str(PROJECT_ROOT / 'data/raw_data/so_deep.db')
OUTPUT_DIR

WindowsPath('C:/work/WestTrade/SecondStep/data/datasets/dataset_ver2_all')

In [6]:
def mark_id_to_class(wtm_mark_id):
    """Translate a raw marker id into a class label.

    Mapping:
        11..14 -> 1
        21..26 -> 2
        41..46 -> 3
        >= 50  -> 4

    Returns:
        The class label (1-4), or None when the id falls outside every
        range above so that the caller can ignore the marker.
    """
    # (lower bound, upper bound, label) for the closed ranges.
    bounded_classes = ((11, 14, 1), (21, 26, 2), (41, 46, 3))
    for lower, upper, label in bounded_classes:
        if lower <= wtm_mark_id <= upper:
            return label

    # Open-ended top class; everything else is not a class of interest.
    return 4 if wtm_mark_id >= 50 else None


In [7]:
# Open the source database. The connection and cursor are left open on purpose:
# the frame-export loop that follows reuses `cursor`.
conn = sqlite3.connect(db_p)
cursor = conn.cursor()

# exist_ok=True makes this idempotent on re-run and removes the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Collapse per-frame marker rows into contiguous periods.
# Within each (wtm_ed_uuid, wtm_mark_id, wtm_focus) partition, the value
# `wtm_wtl_id - ROW_NUMBER()` stays constant across a run of consecutive
# wtm_wtl_id values, so grouping by it yields one row per run with its
# first and last frame id.
# NOTE(review): assumes wtm_wtl_id is an integer that increases by 1 between
# adjacent frames - confirm against the schema. Window functions require
# SQLite 3.25 or newer.
marker_periods_query = """
WITH marker_groups AS (
    SELECT
        wtm_ed_uuid,
        wtm_mark_id,
        wtm_focus,
        wtm_wtl_id,
        wtm_wtl_id - ROW_NUMBER() OVER (
            PARTITION BY wtm_ed_uuid, wtm_mark_id, wtm_focus
            ORDER BY wtm_wtl_id
        ) AS grp
    FROM
        well_timeline_marker
)
SELECT
    wtm_ed_uuid,
    wtm_mark_id,
    wtm_focus,
    MIN(wtm_wtl_id) AS start_frame,
    MAX(wtm_wtl_id) AS end_frame
FROM
    marker_groups
GROUP BY
    wtm_ed_uuid,
    wtm_mark_id,
    wtm_focus,
    grp
ORDER BY
    wtm_ed_uuid,
    start_frame;
"""
cursor.execute(marker_periods_query)
# Rows of (wtm_ed_uuid, wtm_mark_id, wtm_focus, start_frame, end_frame).
marker_periods = cursor.fetchall()

In [8]:
# Parameterized lookup of the stored frames for one marker period.
# Defined once, outside the loop, because it never changes between iterations.
frames_query = """
    SELECT wtf_wtl_id, wtf_frame
    FROM well_timeline_frames
    WHERE wtf_ed_uuid = ?
      AND wtf_wtl_id BETWEEN ? AND ?
      AND wtf_rel_focus = ?
    ORDER BY wtf_wtl_id;
    """

# Export every frame of every labelled period to
# OUTPUT_DIR/<class_label>/<wtm_ed_uuid>/<wtf_wtl_id>_<wtm_focus>.png
for wtm_ed_uuid, wtm_mark_id, wtm_focus, start_frame, end_frame in marker_periods:
    class_label = mark_id_to_class(wtm_mark_id)
    if class_label is None:
        continue  # Skip markers outside the classes of interest.

    cursor.execute(frames_query, (wtm_ed_uuid, start_frame, end_frame, wtm_focus))
    frames = cursor.fetchall()
    if not frames:
        continue  # Nothing to save; do not create an empty directory.

    # The target directory depends only on the period, so create it once per
    # period rather than re-checking it for every frame.
    embryo_dir = os.path.join(OUTPUT_DIR, str(class_label), wtm_ed_uuid)
    os.makedirs(embryo_dir, exist_ok=True)

    for wtf_wtl_id, wtf_frame in frames:
        image_path = os.path.join(embryo_dir, f'{wtf_wtl_id}_{wtm_focus}.png')
        # wtf_frame holds the encoded image bytes; the context manager releases
        # the decoder resources as soon as the file has been written.
        with Image.open(io.BytesIO(wtf_frame)) as image:
            image.save(image_path)
