## MangaLineExtraction_Pytorch

_This is an interactive demo of the paper ["Deep Extraction of Manga Structural Lines"](https://www.cse.cuhk.edu.hk/~ttwong/papers/linelearn/linelearn.html)_

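First, run the following cell to set up the environment. Make sure the notebook uses a GPU runtime (in Colab: **Runtime > Change runtime type**, then select a GPU as the hardware accelerator). The code moves the model to CUDA, so it will fail on a CPU-only runtime.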

In [None]:
#@title Environment setup

# Clone the code and download the pretrained weights
%cd ~
! git clone https://github.com/kiligon/MangaLineExtraction_PyTorch.git
%cd MangaLineExtraction_PyTorch
! wget -O erika.pth https://github.com/ljsabc/MangaLineExtraction_PyTorch/releases/download/v1/erika.pth


import torch
import cv2

from google.colab import files
import os
import numpy as np
from google.colab.patches import cv2_imshow

from src.model import MangaLineExtractor
from src.helper import remap_state_dict_keys

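# Build the network and load the pretrained weights
model = MangaLineExtractor()
model.load_state_dict(remap_state_dict_keys(torch.load('erika.pth')))

# Move the model to the GPU and switch to inference mode
# (the trailing semicolons suppress the notebook's echo of the return value)
model.cuda();
model.eval();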

print("Setup Complete")

### Test with your own image

After the environment setup finishes, run this cell to test your own image. When the file upload button appears in the output, select an image from your local device and wait for processing to finish. The original image and the extracted lines are shown below the cell, and the result is also written to the `output/` folder.

Right-click the result to save it. Re-run this cell to upload and process another image.
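The processing cell pads the grayscale image, runs the model, then clips and crops the prediction. The NumPy-only sketch below isolates the padding and post-processing steps so they can be checked without a GPU. The helper names `pad_to_multiple` and `to_output_image` are hypothetical (not part of the repository), and a simple arithmetic transform stands in for the model call. Unlike the notebook, which hands a float array to `cv2.imwrite`, it casts to `uint8` explicitly.

```python
import numpy as np

PAD_MULTIPLE = 16  # the notebook pads height and width to a multiple of 16


def pad_to_multiple(gray: np.ndarray, multiple: int = PAD_MULTIPLE) -> np.ndarray:
    """Return a (1, 1, H', W') float32 batch with H' and W' rounded up to `multiple`.

    The padded area is filled with 1.0, matching the notebook's np.ones buffer.
    """
    h, w = gray.shape
    rows = int(np.ceil(h / multiple)) * multiple
    cols = int(np.ceil(w / multiple)) * multiple
    batch = np.ones((1, 1, rows, cols), dtype=np.float32)
    batch[0, 0, :h, :w] = gray
    return batch


def to_output_image(pred: np.ndarray, height: int, width: int) -> np.ndarray:
    """Clip a (1, 1, H', W') prediction to [0, 255], crop off the padding, cast to uint8."""
    out = np.clip(pred[0, 0], 0, 255)[:height, :width]
    return out.astype(np.uint8)  # truncates any fractional part


gray = np.full((30, 45), 200, dtype=np.uint8)  # stand-in for a grayscale page
batch = pad_to_multiple(gray)
print(batch.shape)  # (1, 1, 32, 48)

fake_pred = batch * 1.5 - 20  # stand-in for the model output (hypothetical)
restored = to_output_image(fake_pred, *gray.shape)
print(restored.shape, restored.dtype, restored.max())  # (30, 45) uint8 255
```

Cropping with the original height and width is what removes the padded border, so the saved image always has the same size as the upload.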

In [None]:
#@title File upload and processing

uploaded = files.upload()
outputLoc = None
with torch.no_grad():
    for imname in uploaded.keys():
        srcc = cv2.imread(imname)
        print("Original Image:")
        cv2_imshow(srcc)

        src = cv2.imread(imname, cv2.IMREAD_GRAYSCALE)

        # Pad height and width up to the next multiple of 16 before feeding the model
        rows = int(np.ceil(src.shape[0]/16))*16
        cols = int(np.ceil(src.shape[1]/16))*16

        # Manually construct a batch of one image. Adjust this for your use case.
        patch = np.ones((1, 1, rows, cols), dtype="float32")
        patch[0, 0, 0 : src.shape[0], 0 : src.shape[1]] = src

        tensor = torch.from_numpy(patch).cuda()
        y = model(tensor)
        print(imname, torch.max(y), torch.min(y))

        # Convert the prediction to a NumPy array and clip it to the 8-bit range [0, 255]
        yc = np.clip(y.cpu().numpy()[0, 0, :, :], 0, 255)

        tail = os.path.basename(imname)
        os.makedirs("output", exist_ok=True)

        print("Output Image:")
        output = yc[0:src.shape[0],0:src.shape[1]]
        cv2_imshow(output)

        # Always save as PNG, whatever the input extension was
        outputLoc = os.path.join("output", os.path.splitext(tail)[0] + ".png")
        cv2.imwrite(outputLoc,output)
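The processing cell feeds the model a batch containing a single image. To process several uploads in one forward pass, one option is to pad every image to a shared size. The sketch below assumes grayscale inputs and reuses the notebook's padding value of 1.0; `make_batch` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def make_batch(images, multiple=16, fill=1.0):
    """Stack grayscale images of different sizes into one (N, 1, H, W) float32 batch.

    H and W are the largest height and width, rounded up to `multiple`.
    Also returns each image's original (height, width) so the padding
    can be cropped from each prediction afterwards.
    """
    sizes = [img.shape for img in images]
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)
    rows = -(-max_h // multiple) * multiple  # ceiling division
    cols = -(-max_w // multiple) * multiple
    batch = np.full((len(images), 1, rows, cols), fill, dtype=np.float32)
    for i, img in enumerate(images):
        h, w = img.shape
        batch[i, 0, :h, :w] = img
    return batch, sizes


pages = [np.zeros((20, 40), dtype=np.uint8), np.full((50, 10), 255, dtype=np.uint8)]
batch, sizes = make_batch(pages)
print(batch.shape)  # (2, 1, 64, 48)
print(sizes)  # [(20, 40), (50, 10)]
```

You would then pass `torch.from_numpy(batch).cuda()` to the model and crop each result with `y[i, 0, :h, :w]` using the stored sizes. Larger batches need more GPU memory, and because smaller images receive more padding than when processed alone, pixels near their right and bottom edges may differ slightly.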