In [5]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tifffile as tif
from tqdm import tqdm

from virvs.architectures.pix2pix import Generator
from virvs.data.npy_dataloader import NpyDataloader, center_crop

# Single seed for the whole notebook; Keras' helper seeds Python's `random`,
# NumPy and TensorFlow together.
SEED = 42
tf.keras.utils.set_random_seed(SEED)


# Directory holding the trained generator checkpoints listed in WEIGHTS.
BASE_PATH = "/home/wyrzyk93/VIRVS/outputs/weights/"

# Root of the preprocessed datasets; each virus has its own "<VIRUS>/processed/test" split.
DATA_ROOT = "/bigdata/casus/MLID/maria/VIRVS_data"
DATASETS = {
    key: f"{DATA_ROOT}/{folder}/processed/test"
    for key, folder in [
        # Both HADV entries share one test split; presumably they differ only in
        # how many input channels are fed to the model (see the "2ch" handling below).
        ("hadv_2ch", "HADV"),
        ("hadv_1ch", "HADV"),
        ("vacv", "VACV"),
        ("iav", "IAV"),
        ("hsv", "HSV"),
        ("rv", "RV"),
    ]
}

# Checkpoint filename (relative to BASE_PATH) per architecture and dataset key.
WEIGHTS = {
    "pix2pix": {
        "hadv_2ch": "model_100000_316674a2-4299-4a06-b601-5e20f7dd02a6.h5",
        "hadv_1ch": "model_100000_d6792b38-8091-448d-a26a-ef08375b8dbe.h5",
        "vacv": "model_100000_138fd26a-88b9-4d2a-9f96-6a2ffe991364.h5",
        "iav": "model_100000_ef994584-7730-41b9-9a56-75d478bacf02.h5",
        "hsv": "model_100000_52d604f6-25fe-4d42-a212-cef281e3a8b5.h5",
        "rv": "model_100000_a6bd97a2-d889-40a4-a8ec-5c8816797f9d.h5",
    },
    "unet": {
        "hadv_2ch": "model_100000_c0175f01-1e6e-4bec-9f2e-2e5542c16584.h5",
        "hadv_1ch": "model_100000_47b9346c-e143-4bf7-9dce-40aa6f2329e5.h5",
        "vacv": "model_100000_4f91526f-a1b4-4057-926e-073f4ffbef67.h5",
        "iav": "model_100000_26c5b8b9-0c35-4fee-9eac-44a857cebe76.h5",
        "hsv": "model_100000_1b8e6e99-10fa-4221-b53a-b680a65826be.h5",
        "rv": "model_100000_0ca537f4-cbd8-4722-89cb-bdd43107a66b.h5",
    },
}

In [None]:
# IAV norm plot: show one full-size (2048 px, centre-cropped) IAV test sample
# and the U-Net prediction with training=True vs. training=False.
size = 2048
ch_in = [0]

virus = "iav"
model = "unet"
dropout = False

dataloader = NpyDataloader(
    path=DATASETS[virus],
    im_size=size,
    random_jitter=False,
    crop_type="center",
    ch_in=ch_in,
)

x, y = dataloader[1]


def show_image(image, title):
    """Draw a single-channel image on a new figure using the fixed [-1, 1] colour window.

    Opening the figure *before* drawing avoids the stray empty figures that a
    trailing ``plt.figure()`` after each ``imshow`` produces.
    """
    plt.figure()
    plt.imshow(np.squeeze(image), vmin=-1, vmax=1)
    plt.title(title)


show_image(y, f"{virus} target")
show_image(x[..., 0], f"{virus} input (channel 0)")

generator = Generator(size, ch_in=ch_in, ch_out=1, apply_dropout=dropout)
generator.load_weights(f"{BASE_PATH}/{WEIGHTS[model][virus]}")

# The two passes differ only in the `training` flag; presumably this switches
# normalisation/dropout behaviour inside Generator — confirm in virvs.architectures.
output_1 = np.squeeze(generator(np.expand_dims(x, 0), training=True), 0)
show_image(output_1, f"{model} (training=True)")

output_2 = np.squeeze(generator(np.expand_dims(x, 0), training=False), 0)
show_image(output_2, f"{model} (training=False)")

In [1]:
# HADV norm plots
def read_tiff(path: str) -> np.ndarray:
    """Load a TIFF file as an array.

    For a 4-D stack only index 0 along the second axis is kept, giving a 3-D
    result; arrays of any other rank are returned as read.
    """
    stack = tif.imread(path)
    return stack[:, 0] if stack.ndim == 4 else stack


def get_percentiles(x, low=3.0, high=99.8):
    """Return the ``(low, high)`` percentiles computed over all values of ``x``.

    Parameters
    ----------
    x : array-like
        Input values; the array is flattened (``axis=None``).
    low, high : float, optional
        Percentiles in [0, 100]. Defaults (3 and 99.8) match the previously
        hard-coded window.

    Returns
    -------
    tuple of numpy.float64
        ``(lower, upper)`` percentile values.
    """
    # One call computes both percentiles in a single pass over the data.
    mi, ma = np.percentile(x, [low, high], axis=None)
    return mi, ma


def read_data(channel, root=None, time_point=49):
    """Load every ``TimePoint_<time_point>/*_w<channel>.tif`` image found under ``root``.

    Parameters
    ----------
    channel : int
        Index appearing in the ``_w<channel>.tif`` filename suffix.
    root : str or path-like, optional
        Directory searched recursively. Defaults to ``BASE_PATH`` for backward
        compatibility. NOTE(review): ``BASE_PATH`` is the model-weights directory
        in this notebook; the raw TIFFs presumably live elsewhere — pass the
        correct data root explicitly.
    time_point : int, optional
        Time-point folder to read (default 49, as previously hard-coded).

    Returns
    -------
    numpy.ndarray
        float32 array stacking one image per file, each with a trailing channel
        axis and divided by 65535 (presumably 16-bit data mapped to [0, 1]).

    Raises
    ------
    FileNotFoundError
        If no file matches. Previously an empty array was returned, and the
        failure only surfaced later as an IndexError.
    """
    from pathlib import Path  # `Path` was used here without ever being imported

    if root is None:
        root = BASE_PATH
    pattern = f"**/TimePoint_{time_point}/*_w{channel}.tif"
    # Sort so that data[i] refers to the same file on every run; glob order is
    # filesystem-dependent.
    paths = sorted(Path(root).glob(pattern))
    if not paths:
        raise FileNotFoundError(f"No files matching {pattern!r} under {str(root)!r}")

    images = [
        np.expand_dims(read_tiff(str(path)), -1) / 65535.0
        for path in tqdm(paths)
    ]
    return np.array(images, dtype=np.float32)


def hist(imgs):
    """Plot a log-scale histogram of ``imgs`` quantised to 8-bit levels.

    Parameters
    ----------
    imgs : array-like
        Intensities expected in [0, 1]. Values outside that range are clipped
        before quantisation, so they land in levels 0 / 255 instead of wrapping
        around in the ``uint8`` cast.

    Returns
    -------
    numpy.ndarray
        Pixel count for every level from 0 up to the highest level present.
    """
    levels = (np.clip(imgs, 0, 1) * 255).astype(np.uint8)
    # bincount gives a dense count per level 0..max(levels), with zeros for
    # absent levels, which is exactly what the stairs plot needs.
    all_counts = np.bincount(levels.ravel())

    plt.figure(figsize=(12, 12))
    plt.stairs(all_counts)
    plt.yscale("log")
    plt.xlim(0, len(all_counts))
    plt.xlabel("8-bit intensity level")
    plt.ylabel("Pixel count")
    return all_counts


# Channel-1 images, float32 in [0, 1].
data = read_data(1)

i = 0  # index of the image to export
d = np.squeeze(data[i])

# Fixed intensity windows (fractions of the full intensity range) for the SVGs.
# NOTE(review): provenance of these numbers is not shown here (possibly from
# get_percentiles on some subset) — confirm before reusing.
wide_min, wide_max = 0.0019836728, 0.06881819
min_p = 0.0035
max_p = 0.0178


def normalize_window(img, lo, hi):
    """Linearly map ``[lo, hi]`` to ``[0, 1]``, clipping values outside the window."""
    return np.clip((img - lo) / (hi - lo), 0, 1)


# plt.imsave writes the array directly without using any pyplot axes, so no
# axis setup (the former plt.axis("off")) is needed.
for svg_name, image in (
    ("test.svg", normalize_window(d, wide_min, wide_max)),
    ("test_1.svg", normalize_window(d, min_p, max_p)),
    ("test_0.svg", d),
):
    plt.imsave(svg_name, image, format="svg", vmin=0, vmax=1)

print(np.max(data[i]), np.min(data[i]))

NameError: name 'np' is not defined

In [8]:
from virvs.utils.evaluation_utils import (
    get_masks_to_show,
    get_mean_per_mask,
)

virus = "hsv"
model = "unet"
size = 512

# Intensity cut-off (images here are displayed in a [-1, 1] window) above which
# a pixel counts as signal; HADV uses a lower cut-off than the other viruses.
# This is no longer overridden further down — previously `threshold = -0.9`
# was re-assigned unconditionally, which made this branch dead code.
threshold = -0.9 if "hadv" in virus else -0.8
# "2ch" dataset variants feed two input channels to the generator, the rest one.
ch_in = [0, 1] if "2ch" in virus else [0]

dataloader = NpyDataloader(
    path=DATASETS[virus],
    im_size=size,
    random_jitter=False,
    ch_in=ch_in,
    crop_type="center",
)
# Mask files are named by virus only (both HADV variants share "hadv").
masks = np.load(
    f"/bigdata/casus/MLID/maria/VIRVS_data/masks/masks_nuc_{virus.split('_')[0]}_test.npy"
)
generator = Generator(size, ch_in=ch_in, ch_out=1, apply_dropout=False)
generator.load_weights(f"{BASE_PATH}/{WEIGHTS[model][virus]}")


def select_masks(mask_image, values, cutoff, min_fraction=0.5):
    """Boolean image marking the nuclei selected by ``get_masks_to_show``.

    Each pixel scores 1.0 if its value exceeds ``cutoff`` (else 0.0). The
    scores are averaged per mask id with ``get_mean_per_mask``, and those
    per-mask means are filtered with ``get_masks_to_show(..., min_fraction)``.
    ``values`` must have the same number of elements as ``mask_image``.
    """
    weights = (np.asarray(values).flatten() > cutoff).astype(np.float32)
    mean_per_mask = get_mean_per_mask(mask_image.flatten(), weights)
    return np.isin(mask_image, get_masks_to_show(mean_per_mask, min_fraction))


i = 0
x, y = dataloader[i]
masks_pred = center_crop(masks[i], x.shape[0])
# NOTE(review): inference uses training=True, as elsewhere in this notebook;
# the IAV cell shows this can differ from training=False — confirm intent.
pred = np.squeeze(generator(np.expand_dims(x, 0), training=True), 0)

new_mask = select_masks(masks_pred, y, threshold)  # nuclei positive in the target
pred_new_mask = select_masks(masks_pred, pred, threshold)  # nuclei positive in the prediction

Data shape: (96, 2160, 2160, 1)


  mean_per_mask = sum_per_mask / count_per_mask


In [19]:
# Export the evaluated sample (input, target, prediction) and its nucleus masks as SVGs.
svg_exports = {
    "mask_gt.svg": new_mask,
    "all_masks.svg": masks_pred,
    "x.svg": np.squeeze(x),
    "y.svg": np.squeeze(y),
    "pred.svg": np.squeeze(pred),
    "mask_pred.svg": pred_new_mask,
}
for svg_name, image in svg_exports.items():
    plt.imsave(svg_name, image, format="svg")

In [3]:
import pandas as pd
import matplotlib.pyplot as plt
import os

# Directory containing the exported validation-metric CSV files.
base_path = "/home/wyrzyk93/DeepStain/notebooks/final/metrics/"

# (CSV filename, plot title) pairs; the title also determines the SVG filename.
metrics_files = dict(
    [
        ("val_acc.csv", "Validation Accuracy"),
        ("val_acc_only.csv", "Validation Accuracy (Only Nuclei)"),
        ("val_iou.csv", "Validation IoU"),
        ("val_mse.csv", "Validation MSE"),
        ("val_prec.csv", "Validation Precision"),
        ("val_psnr.csv", "Validation PSNR"),
        ("val_rec.csv", "Validation Recall"),
        ("val_ssim.csv", "Validation SSIM"),
        ("val_f1.csv", "Validation F1"),
    ]
)


def plot_metric(file_name, title, metrics_dir=None, output_dir="."):
    """Plot one validation-metric CSV as a line chart and save it as an SVG.

    Parameters
    ----------
    file_name : str
        CSV inside ``metrics_dir``. It is read without a header row; the first
        column is plotted on the x-axis ("Step") and the last column as the
        metric value. NOTE(review): confirm the CSVs really have no header row —
        a header would be read as data.
    title : str
        Plot title and legend label. Spaces are replaced by underscores to form
        the SVG filename.
    metrics_dir : str, optional
        Directory containing the CSV; defaults to the notebook-level ``base_path``.
    output_dir : str, optional
        Directory the SVG is written to (default: current directory, as before).

    Returns
    -------
    str
        Path of the written SVG file.
    """
    if metrics_dir is None:
        metrics_dir = base_path
    data = pd.read_csv(os.path.join(metrics_dir, file_name), header=None)
    steps = data.iloc[:, 0]
    metric_values = data.iloc[:, -1]

    fig, ax = plt.subplots()
    try:
        ax.plot(steps, metric_values, label=title, marker="o", linestyle="-", color="b")
        ax.set_title(title)
        ax.set_xlabel("Step")
        ax.set_ylabel("Value")
        ax.grid(True)
        ax.legend()
        output_file = os.path.join(output_dir, f"{title.replace(' ', '_')}.svg")
        fig.savefig(output_file, format="svg")
    finally:
        # Always release the figure, even if saving fails, so figures do not
        # accumulate when this is called in a loop.
        plt.close(fig)
    return output_file


# Render one SVG per metric CSV registered in metrics_files.
for csv_name in metrics_files:
    plot_metric(csv_name, metrics_files[csv_name])

print("Plots generated and saved as SVG files.")

Plots generated and saved as SVG files.
