# Crop and Segment

Using a pre-trained segmentation model, crop our OCT images and produce segmented images for downstream classification.

In [9]:
from pathlib import Path

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow import keras
from tensorflow.keras.preprocessing.image import load_img

In [7]:
# Original image size (before downsampling) and the model input size.
IMG_WIDTH_ORIGINAL = 512
IMG_HEIGHT_ORIGINAL = 1024
IMG_WIDTH = 256
IMG_HEIGHT = 512
IMG_CHANNELS = 1  # grayscale images
IMG_SIZE = (IMG_HEIGHT, IMG_WIDTH)  # (height, width), as Keras `target_size` expects
BATCH_SIZE = 4  # try 4, 8, 12, 16, 32

# Segmentation class names; index i presumably matches model output channel i
# (TODO confirm against the training notebook).
CLASS_LABELS = (
    "Above ILM",
    "ILM-IPL/INL",
    "IPL/INL-RPE",
    "RPE-BM",
    "Under BM",
    "PED",
    "SRF",
    "IRF",
)

# Derived from CLASS_LABELS so the count and indices can never drift out of
# sync with the label table (4 for a fluids-only model, 8 for layers + fluids).
NUM_CLASSES = len(CLASS_LABELS)
SRF_CLASS = CLASS_LABELS.index("SRF")
IRF_CLASS = CLASS_LABELS.index("IRF")

In [5]:
# Restore the pre-trained segmentation model from its timestamped local path.
model = keras.models.load_model(filepath="./oct_model_20230125-092315")

In [6]:
OCTID_DATA = Path("./data/OCTID/")

# Collect image paths per diagnosis class. Path.glob returns a one-shot
# generator; materialising each result into a sorted list means it can be
# iterated any number of times (re-running later cells no longer yields empty
# lists) and its order is deterministic across filesystems.
normal = sorted(OCTID_DATA.glob("**/NORMAL*"))
# NOTE(review): this pattern spells "AMRD" while the variable says "armd", and it
# is the only pattern restricted to ".jpeg" -- confirm against the dataset files.
armd = sorted(OCTID_DATA.glob("**/AMRD*.jpeg"))
dr = sorted(OCTID_DATA.glob("**/DR*"))
mh = sorted(OCTID_DATA.glob("**/MH*"))
# NOTE(review): "CSR1*" only matches names starting with "CSR1" (CSR1, CSR10, ...),
# unlike the other classes -- confirm this subset is intended.
csr = sorted(OCTID_DATA.glob("**/CSR1*"))

In [10]:
# Copy each per-class collection of paths into its own list.
normal_img_paths, armd_img_paths, dr_img_paths, mh_img_paths, csr_img_paths = (
    list(normal),
    list(armd),
    list(dr),
    list(mh),
    list(csr),
)

In [8]:
class OCTIDSequence(keras.utils.Sequence):
    """Iterate over OCTID images in batches of model-ready NumPy arrays.

    Each image is loaded as grayscale at its native resolution. A fixed number
    of pixel columns is trimmed from both the left and right edges, and the
    result is resized to ``img_size`` and scaled to ``[0, 1]``.

    Only inputs are produced (no targets), so the sequence is meant for
    inference, e.g. ``model.predict(OCTIDSequence(...))``.
    """

    def __init__(
        self,
        batch_size,
        img_size,
        input_img_paths,
        crop_margin=142,
        drop_remainder=True,
    ):
        """
        Args:
            batch_size: Number of images per batch (must be >= 1).
            img_size: ``(height, width)`` that every cropped image is resized to.
            input_img_paths: Iterable of image file paths; stored as a list so
                it can be sliced and iterated repeatedly.
            crop_margin: Pixel columns removed from each side of the
                native-resolution image before resizing.
            drop_remainder: If True (default), a final incomplete batch is
                skipped. If False, it is yielded with fewer rows.

        Raises:
            ValueError: If ``batch_size`` < 1 or ``crop_margin`` < 0.
        """
        super().__init__()
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if crop_margin < 0:
            raise ValueError(f"crop_margin must be >= 0, got {crop_margin}")
        self.batch_size = batch_size
        self.img_size = tuple(img_size)
        self.input_train_paths = list(input_img_paths)
        self.crop_margin = crop_margin
        self.drop_remainder = drop_remainder

    def __len__(self):
        """Return the number of batches (floor or ceil, per ``drop_remainder``)."""
        n = len(self.input_train_paths)
        if self.drop_remainder:
            return n // self.batch_size
        return -(-n // self.batch_size)  # ceiling division

    def __getitem__(self, idx):
        """Return batch ``idx`` as a float32 array of shape (B, H, W, 1).

        ``B`` equals ``batch_size`` except for a trailing partial batch when
        ``drop_remainder`` is False.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f"batch index {idx} out of range for {len(self)} batches")
        start = idx * self.batch_size
        batch_paths = self.input_train_paths[start : start + self.batch_size]
        # Size the batch from the actual slice and the instance's img_size,
        # not from notebook-global constants.
        x = np.zeros((len(batch_paths),) + self.img_size + (1,), dtype="float32")
        for j, path in enumerate(batch_paths):
            x[j] = self._load_image(path)
        return x

    def _load_image(self, path):
        """Load one image, trim its side margins, resize, and scale to [0, 1]."""
        # load_img returns a PIL image, which cannot be sliced; convert first.
        img = np.asarray(load_img(path, color_mode="grayscale"), dtype="float32")
        width = img.shape[1]
        margin = self.crop_margin
        if width <= 2 * margin:
            raise ValueError(
                f"{path}: width {width} is too small for a {margin}px crop on each side"
            )
        # Crop at native resolution *before* resizing. The slice uses an
        # explicit end so that crop_margin=0 keeps the full width.
        img = img[:, margin : width - margin, np.newaxis]
        img = tf.image.resize(img, self.img_size).numpy()
        return img / 255.0