In [None]:
#@title Flat-bug demo

#@markdown First make sure that you are using a GPU instance: click on `Runtime` > `Change runtime type` and select `T4 GPU` (or any other instance with a `GPU`).

#@markdown Then press the grey arrow on the left of this cell to start the installation.

#@markdown Once the installation is complete (it should take less than a minute), the app should start. You can also click on the link to open it in a new tab.

# flat-bug install

# Install dependencies
# gradio is used for the web app defined further down; rawpy is used to read .dng raw images.
!pip install gradio rawpy

# clone the repo
# Only the `main` branch is fetched, into the local directory `flat-bug/`.
!git clone https://github.com/darsa-group/flat-bug.git --branch main --single-branch flat-bug/

# Bug fix: flat-bug requires Python >= 3.11, but Colab currently runs Python 3.10.
# It was easier to patch flat-bug than to change the Colab Python version.

import re

def find_and_replace_in_file(file_path, search_pattern, replacement_text):
    """
    Find and replace text in a text file using regex.

    The file is read and written as UTF-8. It is only rewritten when the
    pattern matched at least once; when nothing matched, a message saying so
    is printed instead of a misleading success message. Errors are reported
    with ``print`` rather than raised (best-effort patching).

    :param file_path: Path to the file to patch
    :param search_pattern: Regex pattern to search for
    :param replacement_text: Replacement text (interpreted as a regex replacement template)
    """
    try:
        # Read the file content
        with open(file_path, 'r', encoding='utf-8') as file:
            content = file.read()

        # Replace using regex, keeping track of how many substitutions were made
        updated_content, n_replaced = re.subn(search_pattern, replacement_text, content)

        # Nothing matched: leave the file untouched and say so
        if n_replaced == 0:
            print(f"Pattern {search_pattern!r} not found in '{file_path}'; file left unchanged.")
            return

        # Write the updated content back to the file
        with open(file_path, 'w', encoding='utf-8') as file:
            file.write(updated_content)

        print(f"Replaced text in '{file_path}' successfully.")
    except FileNotFoundError:
        print(f"File '{file_path}' not found.")
    except Exception as e:
        print(f"An error occurred: {e}")

# Relax the Python version requirement declared in the package metadata.
find_and_replace_in_file('flat-bug/pyproject.toml', r'requires-python = ">=3.11"', 'requires-python = ">=3.10"')

# In each module below, `Self` is dropped from the `typing` import and
# imported from `typing_extensions` on the following line instead.
self_replace = "from typing_extensions import Self"
_typing_import_patches = (
    ('flat-bug/src/flat_bug/predictor.py',
     r'from typing import Any, List, Optional, Self, Tuple, Union',
     f'from typing import Any, List, Optional, Tuple, Union\n{self_replace}'),
    ('flat-bug/src/flat_bug/augmentations.py',
     r'from typing import Dict, List, Optional, Self, Tuple, Union',
     f'from typing import Dict, List, Optional, Tuple, Union\n{self_replace}'),
    ('flat-bug/src/flat_bug/trainers.py',
     r'from typing import Any, Dict, List, Optional, Self, Tuple, Union',
     f'from typing import Any, Dict, List, Optional, Tuple, Union\n{self_replace}'),
    ('flat-bug/src/flat_bug/datasets.py',
     r'from typing import Dict, List, Optional, Self, Tuple, Union',
     f'from typing import Dict, List, Optional, Tuple, Union\n{self_replace}'),
)
for _patch_path, _patch_pattern, _patch_replacement in _typing_import_patches:
    find_and_replace_in_file(_patch_path, _patch_pattern, _patch_replacement)

# install the package
# (editable install of the patched checkout in `flat-bug/`)
!pip install -e flat-bug

# Add the package source directory to sys.path so that `flat_bug` can be imported below.
# NOTE: the absolute path assumes the repo was cloned from /content (the Colab working directory).
import sys
sys.path.append("/content/flat-bug/src")

# Localization implementation

import os, glob, json, io, zipfile, base64, uuid, re, tempfile

from urllib.request import urlretrieve
from copy import deepcopy
from typing import List, Tuple, Union, Optional
from tqdm import tqdm

import numpy as np
import torch

import rawpy
from PIL import Image

from flat_bug.predictor import Predictor, TensorPredictions

# Input parsing
# File extensions treated as images (case-insensitive): .jpg, .jpeg, .png, .dng
IMG_REGEX = re.compile(r'\.(jp[e]{0,1}g|png|dng)$', re.IGNORECASE)

def is_image(file_path):
    """Return True when `file_path` ends with a recognised image extension and is an existing file."""
    has_image_extension = IMG_REGEX.search(file_path) is not None
    return has_image_extension and os.path.isfile(file_path)

def is_txt(file_path):
    """Return True when `file_path` ends with '.txt' and is an existing file."""
    if not file_path.endswith('.txt'):
        return False
    return os.path.isfile(file_path)

def is_dir(file_path):
    """Return True when `file_path` is an existing directory."""
    exists_as_directory = os.path.isdir(file_path)
    return exists_as_directory

def is_glob(file_path):
    """Treat any path that is neither an existing image file nor an existing directory as a glob pattern."""
    if is_image(file_path):
        return False
    return not is_dir(file_path)

def type_of_path(file_path):
    """
    Classify `file_path` as 'image', 'txt', 'dir' or 'glob'.

    The checks are tried in that order and the first one that succeeds wins;
    'unknown' is returned when none of them does.
    """
    classifiers = (
        ('image', is_image),
        ('txt', is_txt),
        ('dir', is_dir),
        ('glob', is_glob),
    )
    for label, predicate in classifiers:
        if predicate(file_path):
            return label
    return 'unknown'

def get_images(input_path_dir_globs : Union[str, List[str]]) -> List[str]:
    if isinstance(input_path_dir_globs, str):
        input_path_dir_globs = [input_path_dir_globs]
    images = []
    for path in input_path_dir_globs:
        match type_of_path(path):
            case 'image':
                images.append(path)
            case 'txt':
                with open(path, 'r') as f:
                    paths = [path.strip() for path in f.readlines() if len(path.strip()) > 0]
                images.extend(get_images(paths))
            case 'dir':
                images.extend(glob.glob(os.path.join(path, '*')))
            case 'glob':
                images.extend(glob.glob(path))
            case _:
                raise ValueError(f"Unknown path type: {path}")
    if len(images) == 0:
        raise ValueError("No images found")
    return images

# Image processing
class ImageEncoder(json.JSONEncoder):
    """JSON encoder that serialises `Base64Image` objects as their base64 string."""

    def default(self, o):
        # Anything that is not a Base64Image is left to the standard encoder
        if not isinstance(o, Base64Image):
            return super().default(o)
        return str(o)

class Base64Image:
    """
    Base64 representation of a file on disk.

    The file at `path` is read once at construction time. The instance keeps
    the original path (`path`), the base64-encoded content as bytes (`bytes`)
    and the same content decoded to an ASCII string (`str`).
    """

    def __init__(self, path : str):
        self.path = path
        # Use a context manager so the file handle is closed right after reading
        with open(path, "rb") as image_file:
            self.bytes = base64.b64encode(image_file.read())
        self.str = self.bytes.decode("ascii")

    def __str__(self):
        return self.str

    def __bytes__(self):
        return self.bytes

    def __repr__(self) -> str:
        return f"Base64Image({self.path})"

def parse_image(images : Optional[Union[np.ndarray, bytes, str, Union[List[Union[np.ndarray, bytes, str]], Tuple[Union[np.ndarray, bytes, str]]]]], device : Union[torch.device, str]="cpu"):
    """
    Load one or more images as channel-first torch tensors on `device`.

    Supported inputs:
      * list / tuple: every element is parsed recursively and a list of
        tensors is returned.
      * str: a file path. Paths ending in ".dng" (any case) are read with
        rawpy, everything else with PIL.
      * bytes: encoded image data, opened with PIL.
      * np.ndarray: used as is.

    Images opened with PIL are converted to RGB so that grayscale, palette
    and RGBA files also yield a three-channel array. Files opened with PIL
    are closed again once their pixel data has been read.

    :param images: The image(s) to parse.
    :param device: Device the resulting tensor(s) are moved to.
    :return: A tensor with the axes permuted from HWC to CHW, or a list of such tensors.
    :raises ValueError: If the input is of an unsupported type.
    """
    # Cases:
    # List: Recursively parse each image
    if isinstance(images, (list, tuple)):
        return [parse_image(image, device) for image in images]
    # String: Open the image from disk
    elif isinstance(images, str):
        if re.search(r"\.dng$", images, re.IGNORECASE):
            with rawpy.imread(images) as raw:
                images = Image.fromarray(raw.postprocess())
        else:
            with Image.open(images) as pil_image:
                # convert() loads the pixel data, so the file can be closed afterwards
                images = pil_image.convert("RGB")
    # Bytes: Open the image with PIL using a BytesIO
    elif isinstance(images, bytes):
        with Image.open(io.BytesIO(images)) as pil_image:
            images = pil_image.convert("RGB")
    # Numpy array: Do nothing
    elif isinstance(images, np.ndarray):
        pass
    # Other: Raise an error
    else:
        raise ValueError(f"Expected image(s) to be a np.ndarray, string or bytes, or list of these, but got {type(images)}")

    # Convert the image to a numpy array
    image = np.array(images)

    # Convert the image to a torch tensor and change from HWC to CHW
    return torch.from_numpy(image).permute(2, 0, 1).to(device)

# General file handling
def generate_uuid() -> str:
    """Return a short random identifier made of every third character of a UUID4 string."""
    full_uuid = str(uuid.uuid4())
    return full_uuid[::3]

def save_file(content : Union[str, bytes], name : str, dir : str, ext : str, identifier : Optional[str]=None, dtype : str="text") -> str:
    """
    Write `content` to `{dir}/{identifier}_{name}.{ext}` and return that path.

    :param content: The data to write; a string for text output, bytes for byte output.
    :param name: Base name of the file (placed after the identifier).
    :param dir: Directory to write to; it is created if it does not exist.
    :param ext: File extension, without the leading dot.
    :param identifier: Prefix of the file name; a new one is generated when omitted.
    :param dtype: Kind of data: a value containing "text" writes text (UTF-8),
        a value containing "byte" writes raw bytes.
    :return: The path of the written file.
    :raises ValueError: If `dtype` contains neither "text" nor "byte".
    """
    # Is the data raw bytes or text?
    if "text" in dtype:
        mode, encoding = "w", "utf-8"
    elif "byte" in dtype:
        mode, encoding = "wb", None
    else:
        # Fail with a clear message instead of passing an invalid mode to open()
        raise ValueError(f"Expected dtype to contain 'text' or 'byte', but got {dtype!r}")
    # If the UUID is not specified, generate a new one
    if identifier is None:
        identifier = generate_uuid()
    # Make sure the target directory exists
    if dir:
        os.makedirs(dir, exist_ok=True)
    # Construct the path
    path = f"{dir}/{identifier}_{name}.{ext}"
    # Dump the content to the file
    with open(path, mode, encoding=encoding) as f:
        f.write(content)
    # Return the path
    return path

def zip_files(files : List[str], name : str, dir : str, identifier : Optional[str]=None) -> str:
    """
    Bundle `files` into the archive `{dir}/{identifier}_{name}.zip` and return its path.

    Entries may be plain paths or `Base64Image` objects, in which case the
    file the object was created from is added. A new identifier is generated
    when none is given.
    """
    archive_id = generate_uuid() if identifier is None else identifier
    path = f"{dir}/{archive_id}_{name}.zip"
    with zipfile.ZipFile(path, "w") as archive:
        for entry in files:
            # Base64Image objects carry the path of the file they were read from
            source_path = entry.path if isinstance(entry, Base64Image) else entry
            archive.write(source_path)
    return path

# Model definition
class Localizer(Predictor):
    """`Predictor` subclass that runs localization on images and stores crops and visualizations on disk."""

    def predict(self, images : Optional[Union[np.ndarray, bytes, str, Union[List[Union[np.ndarray, bytes, str]], Tuple[Union[np.ndarray, bytes, str]]]]], do_plot : bool | List[bool]=False, include_crops : bool=False, outdir : str="output") -> dict:
        """
        Run the model on one or more images.

        For every image a fresh identifier is generated, the crops are saved
        to `{outdir}/{identifier}/` and, if requested, a visualization is
        saved to a `visualization` directory next to it.

        :param images: A single image or a list/tuple of images; each one may be
            anything accepted by `parse_image`.
        :param do_plot: Whether to save a visualization. Either a single boolean
            (applied to every image) or one boolean per image.
        :param include_crops: If True, the crops are returned as `Base64Image`
            objects instead of plain paths.
        :param outdir: Directory the results are written to.
        :return: A dict with the keys "uuids", "predictions", "crops" and
            "visualizations"; each value is a list with one entry per image.
        :raises ValueError: If `do_plot` contains non-boolean values or has
            fewer entries than there are images.
        """
        # Initialize the data
        data = {
            "uuids": [],
            "predictions": [],
            "crops" : [],
            "visualizations": []
        }

        if not isinstance(images, (list, tuple)):
            images = [images]

        # Normalize do_plot to a list with one flag per image
        if isinstance(do_plot, (list, tuple)):
            do_plot = list(do_plot)
        else:
            do_plot = [do_plot]
        # A single flag applies to every image
        if len(do_plot) == 1 and len(images) > 1:
            do_plot = do_plot * len(images)
        if not all(isinstance(plot, bool) for plot in do_plot):
            raise ValueError(f"Expected do_plot to be a boolean or list of booleans, but got {do_plot}")
        # Fail before any work is done instead of with an IndexError halfway through
        if len(do_plot) < len(images):
            raise ValueError(f"Expected do_plot to have one entry per image ({len(images)}), but got {len(do_plot)}")

        for i, image in enumerate(tqdm(images, desc="Localizing insects", unit="image", leave=True)):
            if isinstance(image, str):
                image_identifier = os.path.splitext(os.path.basename(image))[0]
            else:
                image_identifier = "DUMMY"
            # Fetch the image
            image = parse_image(image, self._device)

            # Generate uuid for the image
            identifier = generate_uuid()

            # Create the output directory
            this_outdir = os.path.join(outdir, identifier)
            os.makedirs(this_outdir, exist_ok=True)

            # Run the model
            predictions : TensorPredictions = self.pyramid_predictions(image, "DUMMY_PATH_STR", scale_before=1)

            # Save crops
            predictions.save_crops(outdir=this_outdir, basename=image_identifier, mask=True, identifier=identifier)
            crops = [Base64Image(crop) if include_crops else crop for crop in glob.glob(os.path.join(this_outdir, f"crop*{identifier}.png"))]

            # Plot the image if requested
            if do_plot[i]:
                visualization_dir = os.path.join(os.path.dirname(this_outdir), "visualization")
                os.makedirs(visualization_dir, exist_ok=True)
                predict_image = os.path.join(visualization_dir, f'{identifier}_visualization.jpg')
                predictions.plot(outpath=predict_image, scale=1/2)
            else:
                predict_image = None

            # Append the data
            data["uuids"].append(identifier)
            data["visualizations"].append(predict_image)
            data["crops"].append(crops)
            data["predictions"].append(predictions.json_data)

        # Return the collected results
        return data

def get_defaults():
    """Return the default `(device, dtype)`: the first CUDA device if available, else the CPU, and float16."""
    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    return torch.device(device_name), torch.float16

def main(args : dict):
    """
    Run localization on the inputs described by `args` and save the results.

    Recognised keys of `args`:
      * "input"  (required): passed to `get_images` to collect the image paths.
      * "device" (optional): overrides the default device.
      * "dtype"  (optional): name of a torch dtype overriding the default.
      * "output" (optional): output directory; a temporary directory is
        created when it is missing or None.
      * "plot"   (optional): forwarded as `do_plot` (defaults to False).

    :return: The dict returned by `Localizer.predict`.
    """
    # Start from the defaults
    device, dtype = get_defaults()

    # Collect the images to process
    image_paths = get_images(args["input"])

    # Apply the optional overrides
    requested_device = args.get("device")
    if requested_device is not None:
        device = torch.device(requested_device)

    requested_dtype = args.get("dtype")
    if requested_dtype is not None:
        dtype = getattr(torch, requested_dtype)

    # Resolve the output directory
    outdir = args.get("output")
    if outdir is None:
        outdir = tempfile.mkdtemp()
    elif not os.path.exists(outdir):
        os.makedirs(outdir)

    # Create the model
    model = Localizer(model="flat_bug_L.pt", device=device, dtype=dtype)
    model.set_hyperparameters(SCORE_THRESHOLD=0.25, EDGE_CASE_MARGIN=32, MIN_MAX_OBJ_SIZE=(16, 768), TIME=False)

    # Run the model and dump the combined output as JSON into the output directory
    output = model.predict(image_paths, do_plot=args.get("plot", False), include_crops=False, outdir=outdir)
    save_file(json.dumps(output, cls=ImageEncoder), name="instances", dir=outdir, ext="json", dtype="text")

    # Print the result path
    print(f"Results have saved to {os.path.abspath(outdir)}")

    return output


#-------------------------------------------------------------------------------
# Gradio app.


import os, json

from typing import List, Tuple, Optional, Union

import numpy as np
import gradio as gr

# from localize import Localizer, ImageEncoder, save_file, zip_files, get_defaults, Base64Image

# def html_base64_img(b64: Union[str, bytes]) -> str:
#     if isinstance(b64, bytes):
#         b64 = b64.decode("ascii")
#     return f'<img src="data:image/jpeg;base64,{b64}"/>'

def combine(*args):
    """
    Merge the arguments into a single flat list.

    Lists are unpacked one level, `None` values are dropped and every other
    value is added as a single element.
    """
    out = []
    for item in args:
        if item is None:
            continue
        if isinstance(item, list):
            out.extend(item)
        else:
            out.append(item)
    return out

# Define the postprocessing function
def postprocess(out : dict) -> Tuple[List[str], List[str], List[Optional[str]], List[List[str]]]:
    """
    Postprocess the output of the model.

    For every image a JSON file and a ZIP archive are written to the
    "output" directory, both prefixed with the image's uuid. The JSON file
    holds a list with the image's entry from every key of `out`; the ZIP
    archive bundles the crops, the visualization (if any) and the JSON file.

    Arguments:
        out: The output of the model, a dict of lists with one entry per image
            under the keys "uuids", "predictions", "crops" and "visualizations".

    Returns: A tuple of four lists, each with one entry per image:
        the paths of the JSON files, the paths of the ZIP archives, the paths
        of the visualizations (None when there is none) and the lists of crop paths.
    """
    n = len(out["predictions"])
    json_files = [
        save_file(json.dumps([out[k][i] for k in out], cls=ImageEncoder), name="instances", dir="output", ext="json", identifier=out["uuids"][i], dtype="text")
        for i in range(n)
    ]
    zip_paths = [
        zip_files(combine(out["crops"][i], out["visualizations"][i], json_files[i]), name="combined", dir="output", identifier=out["uuids"][i])
        for i in range(n)
    ]
    # Base64Image objects are replaced by the path of the file they were read from
    visualization_paths = [
        out["visualizations"][i].path if isinstance(out["visualizations"][i], Base64Image) else out["visualizations"][i]
        for i in range(n)
    ]
    crop_paths = [
        [crop.path if isinstance(crop, Base64Image) else crop for crop in out["crops"][i]]
        for i in range(n)
    ]
    return (json_files, zip_paths, visualization_paths, crop_paths)

# Build the Gradio app: a single-image interface that returns the JSON file,
# the ZIP archive, the annotated image and the crops.
with gr.Blocks() as demo:
    device, dtype = get_defaults()

    # Create a model loader
    def get_model():
        """Create a `Localizer` for "flat_bug_L.pt" with the default device and dtype."""
        return Localizer(model="flat_bug_L.pt", device=device, dtype=dtype)

    # Load the model in the app state
    # NOTE(review): the loader function itself (not a model instance) is passed to gr.State.
    model = gr.State(get_model)

    # Define the localization function
    def localize(images : Optional[Union[np.ndarray, bytes, str, Union[List[Union[np.ndarray, bytes, str]], Tuple[Union[np.ndarray, bytes, str]]]]],
                #  do_plot : bool=False) -> Tuple[List[str], List[str]]:
                 ) -> Tuple[List[str], List[str]]:
        """
        Run the localizer on `images` and return the result of `postprocess`.

        Results are written to the "output" directory, which is created if
        needed. Visualizations and crops are always requested.

        NOTE(review): `postprocess` returns four lists, while the return
        annotation here declares two — confirm and update the annotation.
        """
        if not os.path.exists("output"):
            os.makedirs("output")
        # predictions = model.value().predict(images, do_plot=do_plot, include_crops=True, outdir="output")
        # NOTE(review): presumably `model.value` is the `get_model` loader, in which case this
        # builds a new Localizer on every call — confirm against Gradio's State semantics.
        predictions = model.value().predict(images, do_plot=True, include_crops=True, outdir="output")
        return postprocess(predictions)


    # Define the input-output format
    # NOTE(review): "example_image1.jpg" is a relative path — verify the file exists in the working directory.
    file_input = gr.Image(value="example_image1.jpg", label="Input image")
    # checkbox = gr.Checkbox(label="Return annotated image")
    output_json = gr.File(label="Output JSON")
    output_zip = gr.File(label="Output ZIP")
    output_image = gr.Image(label="Annotated Image", format="jpg")
    output_crops = gr.Gallery(label="Cropped insects", format="png")
    # test_output = gr.HTML(label="test output")

    # The four outputs are listed in the same order as the four lists returned by `postprocess`.
    gr.Interface(
        fn=localize,
        inputs=[file_input],
        outputs=[output_json, output_zip, output_image, output_crops],
        title="Perform localization on a single image",
        # description="This model localizes bugs in images. Upload an image and it will return the localization as a JSON file and optionally an annotated image.",
        description="This model localizes bugs in images. Upload an image and it will return the localization as a JSON file and an annotated image.",
        batch=True
    )

from IPython.display import clear_output
# Clear the output produced so far by this cell (installation logs etc.) before the app is shown.
clear_output(wait=True)
# Start the Gradio app.
demo.launch(inbrowser=True)
