### Libraries

In [1]:
from pathlib import Path

import huggingface_hub
import numpy as np
import torch
from anime_segmentation import get_model as get_anime_segmentation_model
from diffusers.schedulers import UniPCMultistepScheduler
from libs.colouranga.from_magi_model import MyMagiModel
from libs.colouranga.from_magi_model.config import MagiConfig
from libs.colouranga.utils import (
    character_segment,
    color_inversion,
    creating_pairs,
    finding_samples,
    get_embeddings,
    get_line_art,
    original_bboxes_compare,
    prepreparing_embeddings,
    sample_img,
    upload_pages,
)
from PIL import Image
from stable_diffusion_reference_only.pipelines.pipeline_stable_diffusion_reference_only import (
    StableDiffusionReferenceOnlyPipeline,
)
from transformers.modeling_utils import load_state_dict

### Model

In [2]:
# Pick the compute device once; the pipeline and segmentation model are moved onto it with .to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Is CUDA available: {device == 'cuda'}")

Is CUDA available: True


In [3]:
# Identifier handed to from_pretrained for the reference-only automatic-coloring pipeline.
coloring_checkpoint = "AisingioroHao0/stable-diffusion-reference-only-automatic-coloring-0.1.2"
automatic_coloring_pipeline = StableDiffusionReferenceOnlyPipeline.from_pretrained(coloring_checkpoint)
# Move the loaded pipeline onto the device selected earlier in the notebook.
automatic_coloring_pipeline = automatic_coloring_pipeline.to(device)



Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

In [4]:
# Replace the pipeline's scheduler with UniPCMultistepScheduler, rebuilt from the existing scheduler's config.
current_scheduler_config = automatic_coloring_pipeline.scheduler.config
automatic_coloring_pipeline.scheduler = UniPCMultistepScheduler.from_config(current_scheduler_config)

In [5]:
# Fetch the "isnetis.ckpt" file from the "skytnt/anime-seg" repo, then build the segmentation model from it.
segmentation_ckpt_path = huggingface_hub.hf_hub_download("skytnt/anime-seg", "isnetis.ckpt")
segment_model = get_anime_segmentation_model(model_path=segmentation_ckpt_path).to(device)

In [6]:
# Model initialization: build the Magi model from its JSON config and load the local checkpoint weights.
magi_weights_path = Path("models/magi/pytorch_model.bin").resolve()
magi_config_path = Path("libs/colouranga/from_magi_model/config.json").resolve()

state_dict = load_state_dict(str(magi_weights_path))
config: MagiConfig = MagiConfig.from_json_file(magi_config_path)  # type: ignore
model = MyMagiModel(config)
# strict=False: keys that do not match between checkpoint and model are tolerated instead of raising.
model.load_state_dict(state_dict, strict=False)
# Use the device selected earlier rather than hard-coding CUDA, so this cell also runs on
# CPU-only machines and stays consistent with the other models in this notebook.
model.to(device)  # type: ignore

MyMagiModel(
  (crop_embedding_model): ViTMAEModel(
    (embeddings): ViTMAEEmbeddings(
      (patch_embeddings): ViTMAEPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
    )
    (encoder): ViTMAEEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTMAELayer(
          (attention): ViTMAESdpaAttention(
            (attention): ViTMAESdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTMAESelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTMAEIntermediate(
            (dense): Linear(in_features=768, out_f

### Code

In [7]:
# Input directory; passed as `input_path` (monochrome manga pages) to colorization_pipeline.
directory_path = "data/x_manga"

In [8]:
# Samples directory; passed as `samples_path` (colorized sample pages) to colorization_pipeline.
directory_path_samples = "data/ex_samples"

In [9]:
# Output directory; passed as `output_path` (colorized character images) to colorization_pipeline.
destination_folder = "data/ex_colour"

In [10]:
def colorization_pipeline(
    input_path: str,
    samples_path: str,
    output_path: str,
    num_inference_steps: int = 20,
) -> None:
    """Colorizes all characters from manga pages.

    Relies on the notebook-level ``model``, ``segment_model`` and
    ``automatic_coloring_pipeline`` objects, which must be initialised
    before this function is called.

    Args:
        input_path: directory that contains monochrome manga pages
        samples_path: directory that contains colorized samples of manga pages
        output_path: where to output colorized images of characters; the
            directory is created if it does not exist yet
        num_inference_steps: value forwarded as ``num_inference_steps`` to the
            coloring pipeline for every character (defaults to 20, the
            previously hard-coded value)
    """
    output_dir = Path(output_path)
    # Saving raises FileNotFoundError when the target directory is missing, so create it up front.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Pair every detected character crop with a colored reference sample.
    # NOTE(review): these helpers live in libs.colouranga.utils; their exact semantics are not
    # visible from this notebook, so only the data flow between them is described here.
    pages = upload_pages(input_path)
    list_of_bboxes, original_emb = get_embeddings(model, pages)
    original_embeddings = prepreparing_embeddings(original_emb)
    comparison_list = original_bboxes_compare(original_embeddings)
    sample_crop_embeddings, sample_color_images = sample_img(model, samples_path)
    result_dict, sample_dict = finding_samples(sample_crop_embeddings, comparison_list)
    final_character_list = creating_pairs(
        sample_color_images, sample_dict, list_of_bboxes, result_dict
    )

    for elem in final_character_list:
        # Blueprint: segmented character crop -> line art -> color inversion.
        segmented_crop = character_segment(segment_model, elem.crop_image_bbox)
        blueprint = color_inversion(get_line_art(segmented_crop))

        # Prompt: the matched colored sample, segmented the same way. The context manager
        # closes the underlying file as soon as the pixels have been copied into the array.
        with Image.open(elem.full_sample_file_name) as sample_file:
            sample_rgb = np.array(sample_file.convert("RGB"))
        prompt = character_segment(segment_model, sample_rgb)

        # Free cached GPU memory before each diffusion run.
        torch.cuda.empty_cache()
        colored = automatic_coloring_pipeline(
            prompt=Image.fromarray(prompt),
            blueprint=Image.fromarray(blueprint),  # type: ignore
            num_inference_steps=num_inference_steps,
        )  # type: ignore

        # The file name is the string form of the crop's bounding-box coordinates.
        # NOTE(review): identical coordinates on different pages would overwrite each other — confirm.
        destination_path = output_dir / f"{elem.crop_bboxes_coordinates!s}.png"
        colored.images[0].save(str(destination_path))


In [None]:
# Run the full colorization over the configured input, sample and output directories.
colorization_pipeline(
    input_path=directory_path,
    samples_path=directory_path_samples,
    output_path=destination_folder,
)