In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import time
from dataclasses import dataclass
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import NDArray
import torch
from PIL import Image
from rich.pretty import pprint as pp
from transformers.modeling_utils import load_state_dict

from libs.lizi.my_magi import MyMagiModel
from libs.lizi.my_magi.config import MagiConfig
from libs.lizi.my_magi.utils import read_image_as_np_array as read_image

In [3]:
# Load every JPEG page from the data folder as an array.
# Path.glob yields files in an unspecified, filesystem-dependent order, so the
# paths are sorted to keep the image order (and therefore the indices used for
# `results[i]` and the `image_{i}.png` file names) identical between runs.
images = [read_image(str(image_path)) for image_path in sorted(Path("data/manga2/").glob("*.jpg"))]

### Модель

In [4]:
# Read the checkpoint from disk into a mapping keyed by parameter name.
# NOTE(review): a `.bin` checkpoint is presumably pickle-based — only load files
# that come from a trusted source.
state_dict = load_state_dict(str(Path("models/magi/pytorch_model.bin").resolve()))
# Show which parameter names the checkpoint contains.
state_dict.keys()

dict_keys(['ocr_model.encoder.embeddings.cls_token', 'ocr_model.encoder.embeddings.position_embeddings', 'ocr_model.encoder.embeddings.patch_embeddings.projection.weight', 'ocr_model.encoder.embeddings.patch_embeddings.projection.bias', 'ocr_model.encoder.encoder.layer.0.attention.attention.query.weight', 'ocr_model.encoder.encoder.layer.0.attention.attention.key.weight', 'ocr_model.encoder.encoder.layer.0.attention.attention.value.weight', 'ocr_model.encoder.encoder.layer.0.attention.output.dense.weight', 'ocr_model.encoder.encoder.layer.0.attention.output.dense.bias', 'ocr_model.encoder.encoder.layer.0.intermediate.dense.weight', 'ocr_model.encoder.encoder.layer.0.intermediate.dense.bias', 'ocr_model.encoder.encoder.layer.0.output.dense.weight', 'ocr_model.encoder.encoder.layer.0.output.dense.bias', 'ocr_model.encoder.encoder.layer.0.layernorm_before.weight', 'ocr_model.encoder.encoder.layer.0.layernorm_before.bias', 'ocr_model.encoder.encoder.layer.0.layernorm_after.weight', 'ocr_mo

In [5]:
# Build the model from the local JSON config only; the checkpoint weights are
# applied afterwards by a separate `model.load_state_dict(...)` call.
config: MagiConfig = MagiConfig.from_json_file(Path("libs/lizi/my_magi/config.json").resolve())  # type: ignore
model = MyMagiModel(config)

In [6]:
# strict=False is required because the checkpoint also carries weights that this
# model does not define (the `ocr_model.*` entries); they are reported as
# unexpected keys and skipped.
load_report = model.load_state_dict(state_dict, strict=False)
# strict=False would also silently accept parameters of the model that received
# no weights at all, so fail fast if any are missing.
assert not load_report.missing_keys, f"weights missing from checkpoint: {load_report.missing_keys}"
# Keep the report as the cell output so the skipped keys stay visible.
load_report

_IncompatibleKeys(missing_keys=[], unexpected_keys=['ocr_model.encoder.embeddings.cls_token', 'ocr_model.encoder.embeddings.position_embeddings', 'ocr_model.encoder.embeddings.patch_embeddings.projection.weight', 'ocr_model.encoder.embeddings.patch_embeddings.projection.bias', 'ocr_model.encoder.encoder.layer.0.attention.attention.query.weight', 'ocr_model.encoder.encoder.layer.0.attention.attention.key.weight', 'ocr_model.encoder.encoder.layer.0.attention.attention.value.weight', 'ocr_model.encoder.encoder.layer.0.attention.output.dense.weight', 'ocr_model.encoder.encoder.layer.0.attention.output.dense.bias', 'ocr_model.encoder.encoder.layer.0.intermediate.dense.weight', 'ocr_model.encoder.encoder.layer.0.intermediate.dense.bias', 'ocr_model.encoder.encoder.layer.0.output.dense.weight', 'ocr_model.encoder.encoder.layer.0.output.dense.bias', 'ocr_model.encoder.encoder.layer.0.layernorm_before.weight', 'ocr_model.encoder.encoder.layer.0.layernorm_before.bias', 'ocr_model.encoder.encoder

In [7]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# torch modules start in training mode unless their constructor changes it;
# eval() switches layers such as dropout to inference behaviour before the
# predictions are made. Both calls return the model, so the cell still displays it.
model.to(device).eval()  # type: ignore

MyMagiModel(
  (crop_embedding_model): ViTMAEModel(
    (embeddings): ViTMAEEmbeddings(
      (patch_embeddings): ViTMAEPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
    )
    (encoder): ViTMAEEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTMAELayer(
          (attention): ViTMAEAttention(
            (attention): ViTMAESelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTMAESelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTMAEIntermediate(
            (dense): Linear(in_features=768, out_features=

### Проверка

In [8]:
# Inference only: gradient tracking is switched off while the model processes
# all loaded images in one call. `results` is indexed per image further down.
with torch.no_grad():
    results = model.predict_detections_and_associations(images)





In [None]:
# Pretty-print the prediction for the first image to inspect its structure.
pp(results[0])

In [None]:
# Visualise every image together with its own prediction. The number in the
# file name is the position of the image in `images`.
# NOTE(review): presumably the method writes the rendering to `filename` in the
# current working directory — confirm against MyMagiModel.
for i, (image, prediction) in enumerate(zip(images, results)):
    model.visualise_single_image_prediction(image, prediction, filename=f"image_{i}.png")


# проверка