In [None]:
#!git clone https://github.com/space-lab-sk/scss-net.git
#%cd scss-net
#!git pull origin main                # uncomment and start here if the repo is already cloned
!pip install -U pip
!pip install -U setuptools
!pip install -r requirements.txt

In [None]:
import sys, os

# Make the SCSS-Net sources importable so the later `model_scss_net`, `metrics`
# and `utils` imports resolve. Guarded so re-running the cell does not keep
# appending duplicate entries to sys.path.
SCSS_NET_SRC = '../scss-net/src'
if SCSS_NET_SRC not in sys.path:
    sys.path.append(SCSS_NET_SRC)

In [None]:
# Standard library
import glob
import zipfile
from datetime import datetime

# Third-party
import albumentations
import cv2
import matplotlib.pylab as plt
import numpy as np
from matplotlib.gridspec import GridSpec
from mega import Mega
from PIL import Image
from sklearn.model_selection import train_test_split

# The wildcard import stays after the imports above (as in the original order) so any
# names it re-exports keep overriding them exactly as before.
from ImageDataAugmentor.image_data_augmentor import *
from tensorflow.keras.callbacks import ModelCheckpoint

In [None]:
from model_scss_net import scss_net
from metrics import dice_np, iou_np, dice, iou
from utils import plot_imgs, plot_metrics

In [None]:
from PIL import Image, ImageEnhance

In [None]:
IMG_SIZE = 64       # images and masks are resized to IMG_SIZE x IMG_SIZE (here 64x64)
BATCH_SIZE = 20     # training batch size
SEED = 20           # seed for reproducibility (NOTE: train_test_split below uses its own random_state=123)
EPOCHS = 1000       # number of training epochs

MODEL_NAME = "model_galaxie_vsetky_1000_ep_drop_6"  # base name of this run's model/weights file
model_filename = f"{MODEL_NAME}.h5"                  # path where the model/weights are saved

## DATA PREP

In [None]:
# glob() returns files in arbitrary, filesystem-dependent order; sort both listings
# so the i-th image is paired with the i-th mask in the zip() below.
# NOTE(review): if the pretrained weights loaded later were trained with a different
# file order, the train/val split below will not match the one used in training.
imgs = sorted(glob.glob("data/all_galaxies/*.jpg"))
masks = sorted(glob.glob("data/all_masks/*.jpg"))

print(f"Imgs number = {len(imgs)}\nMasks number = {len(masks)}")
if len(imgs) != len(masks):
    print("Warning: image/mask counts differ; zip() below silently drops the unmatched extras.")

# Load each image/mask pair as a grayscale ("L") IMG_SIZE x IMG_SIZE array.
# (Earlier experiment: boost colour with ImageEnhance.Color(img).enhance(25) before converting.)
imgs_list = []
masks_list = []
for image_path, mask_path in zip(imgs, masks):
    # Context managers close the underlying file handles immediately.
    with Image.open(image_path) as img, Image.open(mask_path) as mask_img:
        imgs_list.append(np.array(img.convert("L").resize((IMG_SIZE, IMG_SIZE))))
        masks_list.append(np.array(mask_img.convert("L").resize((IMG_SIZE, IMG_SIZE))))



In [None]:
def _to_nhwc(arrays):
    """Scale 0-255 pixel arrays to [0, 1] float32 and add a trailing channel axis."""
    scaled = np.asarray(arrays, dtype=np.float32) / 255
    return scaled.reshape(scaled.shape[0], scaled.shape[1], scaled.shape[2], 1)

# Resulting shape: (n_imgs, height, width, 1)
x = _to_nhwc(imgs_list)
y = _to_nhwc(masks_list)

In [None]:
# Hold out 20% of the shuffled data for validation.
# random_state is fixed at 123 (not SEED); keep it so the split matches the one the
# pretrained weights were presumably trained on.
x_train, x_val, y_train, y_val = train_test_split(
    x, y,
    shuffle=True,
    random_state=123,
    test_size=0.2,
)

In [None]:
# Sanity check: path of the first image in the dataset listing.
first_image_path = imgs[0]
print(first_image_path)

In [None]:
# Preview a few image/mask pairs from the full dataset.
preview = plot_imgs(imgs=x, masks=y, n_imgs=8)
preview.show()

In [None]:
# After the resize above, the input shape should be (IMG_SIZE, IMG_SIZE, 1), i.e. (64, 64, 1).
input_shape = x_train[0].shape
assert input_shape == (IMG_SIZE, IMG_SIZE, 1), f"unexpected input shape {input_shape}"
print(f"Input shape: {input_shape}\nTrain shape: {x_train.shape}  Val shape: {x_val.shape}")

## TRAINING SCSS MODEL

In [None]:
# Build the SCSS-Net architecture with the chosen hyper-parameters.
# drop_prob=0.6 is presumably what the "_drop_6" suffix of MODEL_NAME refers to.
model = scss_net(
    input_shape,
    drop_prob=0.6,
    batch_norm=True,
    layers=4,
    filters=32,
)

In [None]:
# Number of full batches in the training set (used as steps per epoch).
STEPS = x_train.shape[0] // BATCH_SIZE

# Checkpoint callback: write to model_filename only when val_loss improves.
callback_checkpoint = ModelCheckpoint(
    model_filename,
    monitor="val_loss",
    save_best_only=True,
    verbose=1,
)

# Binary cross-entropy loss with Adam; track IoU and Dice during training.
model.compile(
    loss="binary_crossentropy",
    optimizer="adam",
    metrics=[iou, dice],
)

In [None]:
# Start from previously trained weights.
# The original Slovak note "TRENUJEME ODZNOVA" means "we train from scratch";
# presumably: skip this cell when training from scratch.
PRETRAINED_WEIGHTS = "model_galaxie_vsetky_1000_ep.h5"
model.load_weights(PRETRAINED_WEIGHTS)

In [None]:
# Train model. Disabled by default because a full run takes roughly 22 hours on a
# MacBook Pro M1; set TRAIN = True to run it. Training continues from whatever
# weights `model` currently holds (e.g. the pretrained ones if loaded above).
TRAIN = False

if TRAIN:
    history = model.fit(
        x_train,
        y_train,
        # Pass batch_size explicitly: STEPS is n_train // BATCH_SIZE, so without it Keras
        # would batch by its default size and run out of data before STEPS steps.
        batch_size=BATCH_SIZE,
        steps_per_epoch=STEPS,
        epochs=EPOCHS,
        validation_data=(x_val, y_val),
        callbacks=[callback_checkpoint],
        verbose=2,
    )

    # Plot training history (metrics and loss)
    plot_metrics(history).show()

cca 22 hodín na MacBook Pro M1 (training takes approximately 22 hours on a MacBook Pro M1)

In [None]:
# Persist the current weights under this run's file name.
# NOTE(review): unless the training cell was run, these are just the loaded pretrained weights.
model.save_weights(filepath=model_filename)

In [None]:
 y_pred = model.predict(x_val)

In [None]:
# Compare ground truth and predictions for a few validation samples.
plot_imgs(
    imgs=x_val,
    masks=y_val,
    predictions=y_pred,
    n_imgs=5,
).show()

In [None]:
# Held-out test set ("testovacia_po_edge_detection" = "test set after edge detection").
# Sort both listings: glob() order is arbitrary, and images/masks are paired by position.
imgs_test = sorted(glob.glob("data/testovacia_po_edge_detection/cropped_improved_galaxies/*.jpg"))
masks_test = sorted(glob.glob("data/testovacia_po_edge_detection/galaxy_improved_masks/*.jpg"))

print(f"Imgs number = {len(imgs_test)}\nMasks number = {len(masks_test)}")
if len(imgs_test) != len(masks_test):
    print("Warning: image/mask counts differ; zip() below silently drops the unmatched extras.")

# Load data as grayscale IMG_SIZE x IMG_SIZE arrays; context managers close the files.
imgs_test_list = []
masks_test_list = []
for image_path, mask_path in zip(imgs_test, masks_test):
    with Image.open(image_path) as img, Image.open(mask_path) as mask_img:
        imgs_test_list.append(np.array(img.convert("L").resize((IMG_SIZE, IMG_SIZE))))
        masks_test_list.append(np.array(mask_img.convert("L").resize((IMG_SIZE, IMG_SIZE))))

# Normalization from (0; 255) to (0; 1)
x_test = np.asarray(imgs_test_list, dtype=np.float32) / 255
y_test = np.asarray(masks_test_list, dtype=np.float32) / 255

# Reshape to (n_imgs, height, width, channels)
x_test = x_test.reshape(x_test.shape[0], x_test.shape[1], x_test.shape[2], 1)
y_test = y_test.reshape(y_test.shape[0], y_test.shape[1], y_test.shape[2], 1)

In [None]:
# Predict masks for the test set (this replaces the validation predictions in y_pred)
# and show 20 examples.
y_pred = model.predict(x_test)
test_plot = plot_imgs(imgs=x_test, masks=y_test, predictions=y_pred, n_imgs=20)
test_plot.show()

In [None]:
# Threshold sweep on the VALIDATION set: binarize probabilities at 0.10 ... 1.00 and
# report Dice / IoU for both raw and thresholded predictions.
# Predict on x_val here: at this point `y_pred` holds TEST predictions (set in the cell
# above), which must not be scored against the validation masks.
y_pred_val = model.predict(x_val)

# Scores on raw probabilities do not depend on the threshold, so compute them once.
# Named `dice_score` so the imported `dice` metric function is not shadowed.
dice_score = np.round(dice_np(y_val, y_pred_val), 4)
iou_val = np.round(iou_np(y_val, y_pred_val), 4)

step = 1
for tresh in range(10, 100 + step, step):
    test_tresh = tresh / 100
    y_pred_bin = np.where(y_pred_val > test_tresh, 1, 0)  # Binarize predicted values

    dice_tresh = np.round(dice_np(y_val, y_pred_bin), 4)
    iou_val_tresh = np.round(iou_np(y_val, y_pred_bin), 4)

    print(f"Validation (> {test_tresh}):\nDice: {dice_score} Dice_tresh: {dice_tresh}\n IoU: {iou_val} IoU_tresh: {iou_val_tresh}\n")

In [None]:
# Recompute test-set predictions so `y_pred` holds TEST predictions for the cells below,
# then show 20 examples.
y_pred = model.predict(x_test)
plot_imgs(
    imgs=x_test, masks=y_test, predictions=y_pred, n_imgs=20
).show()

In [None]:
# Validation scores at a fixed threshold of 0.5.
# Predict on x_val: `y_pred` holds TEST predictions at this point and must not be
# compared against the validation masks. `y_pred` itself is left untouched for later cells.
threshold = 0.5
y_pred_val = model.predict(x_val)
y_pred_bin = np.where(y_pred_val > threshold, 1, 0)  # Set threshold for predicted values

# `dice_score` instead of `dice` so the imported metric function is not shadowed.
dice_score = np.round(dice_np(y_val, y_pred_val), 4)
iou_val = np.round(iou_np(y_val, y_pred_val), 4)

dice_tresh = np.round(dice_np(y_val, y_pred_bin), 4)
iou_val_tresh = np.round(iou_np(y_val, y_pred_bin), 4)

print(f"Validation:\nDice: {dice_score} Dice_tresh: {dice_tresh}\n IoU: {iou_val} IoU_tresh: {iou_val_tresh}\n")

In [None]:
# Threshold sweep on the TEST set; `y_pred` must hold predictions for x_test.
assert len(y_pred) == len(y_test), "y_pred does not match the test set; re-run the test prediction cell"

# Raw-probability scores are threshold-independent, so compute them once.
# `dice_score` instead of `dice` so the imported metric function is not shadowed.
dice_score = np.round(dice_np(y_test, y_pred), 4)
iou_test = np.round(iou_np(y_test, y_pred), 4)

step = 1
for tresh in range(10, 100 + step, step):
    test_tresh = tresh / 100
    y_pred_bin = np.where(y_pred > test_tresh, 1, 0)  # Binarize predicted values

    dice_tresh = np.round(dice_np(y_test, y_pred_bin), 4)
    iou_test_tresh = np.round(iou_np(y_test, y_pred_bin), 4)

    # Labelled "Test": these scores are computed against y_test, not the validation set.
    print(f"Test (> {test_tresh}):\nDice: {dice_score} Dice_tresh: {dice_tresh}\n IoU: {iou_test} IoU_tresh: {iou_test_tresh}\n")

In [None]:
# Resize each predicted mask back to the size of its source image.
# `original_shapes` was previously never defined in this notebook (hidden state from a
# deleted cell). Here it is rebuilt from imgs_test because y_pred holds predictions for
# x_test, whose i-th entry was loaded from imgs_test[i]. zip() with masks_test may have
# dropped trailing files, so only the first len(y_pred) paths are used.
# NOTE(review): if y_pred is meant to come from a different image set, build
# original_shapes from that set instead.
original_shapes = []
for image_path in imgs_test[:len(y_pred)]:
    with Image.open(image_path) as img:
        # PIL's img.size is (width, height), which is the dsize order cv2.resize expects.
        original_shapes.append(img.size)

y_pred_resized = []
for pred, orig_size in zip(y_pred, original_shapes):
    # Nearest-neighbour interpolation copies existing pixel values instead of blending them.
    resized_mask = cv2.resize(pred, orig_size, interpolation=cv2.INTER_NEAREST)
    y_pred_resized.append(resized_mask)

In [None]:
def save_predicted_masks(y_pred, input_folder, output_folder):
    """
    Save predicted masks as 8-bit grayscale PNGs named after the images in a folder.

    Image filenames in ``input_folder`` (.png/.jpg/.jpeg, case-insensitive) are sorted
    alphabetically and paired by position with ``y_pred``; the i-th mask is written to
    ``output_folder/<stem of i-th filename>.png``. If the counts differ, a warning is
    printed and only the first ``min(len(filenames), len(y_pred))`` pairs are saved.

    :param y_pred: Sequence of predicted masks, expected to be in [0, 1]; values
        outside that range are clipped before conversion to 0-255.
    :param str input_folder: Folder containing the original images (only names are used).
    :param str output_folder: Folder to save predicted masks; created if missing.
    :raises OSError: If OpenCV fails to write a mask file.
    """
    os.makedirs(output_folder, exist_ok=True)

    # Image filenames from the input folder, in sorted order (pairing is positional).
    image_filenames = sorted(
        f for f in os.listdir(input_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )

    if len(image_filenames) != len(y_pred):
        print(f"Warning: {len(image_filenames)} images but {len(y_pred)} masks!")

    for pred_mask, img_filename in zip(y_pred, image_filenames):
        filename = os.path.splitext(img_filename)[0]  # Remove extension
        mask_filename = os.path.join(output_folder, f"{filename}.png")

        # Clip before scaling so out-of-range values cannot wrap around in the uint8 cast.
        mask_uint8 = (np.clip(pred_mask, 0.0, 1.0) * 255).astype(np.uint8)

        # cv2.imwrite reports failure via its return value instead of raising.
        if not cv2.imwrite(mask_filename, mask_uint8):
            raise OSError(f"Could not write mask to {mask_filename}")
        print(f"Saved: {mask_filename}")

In [None]:
# Masks are named after the images in input_folder and written to output_folder.
# NOTE(review): save_predicted_masks pairs masks with the sorted filenames by position,
# so input_folder should contain exactly the images the predictions were made for.
input_folder = "../../data/new_data_subset/masks_from_cut_fits/images_dump/"
output_folder = "../../data/new_data_subset/masks_from_cut_fits/masks_scss"

In [None]:
# Write the resized predictions to disk.
save_predicted_masks(
    y_pred=y_pred_resized,
    input_folder=input_folder,
    output_folder=output_folder,
)