In [None]:
import os
import sys
sys.path.append('..')  # Add parent directory to path

from dlsia.core.helpers import get_device
import matplotlib.pyplot as plt
from tiled.client import from_uri
import torchsummary

from mlex_dlsia.dataset import initialize_tiled_datasets
from mlex_dlsia.inference import run_inference
from mlex_dlsia.network import build_network, load_network
from mlex_dlsia.parameters import IOParameters, TrainingParameters
from mlex_dlsia.train import run_train
from mlex_dlsia.utils.dataloaders import construct_train_dataloaders
from mlex_dlsia.utils.params_validation import validate_parameters
from mlex_dlsia.utils.tiled import prepare_tiled_containers

In [None]:
# Tiled endpoints and credentials are read from the environment so no secrets
# are stored in the notebook itself.
RECON_TILED_URI = os.getenv('RECON_TILED_URI')
RECON_TILED_API_KEY = os.getenv('RECON_TILED_API_KEY')
MASK_TILED_URI = os.getenv('MASK_TILED_URI')
MASK_TILED_API_KEY = os.getenv('MASK_TILED_API_KEY')
SEG_TILED_URI = os.getenv('SEG_TILED_URI')
SEG_TILED_API_KEY = os.getenv('SEG_TILED_API_KEY')
# Run identifier: passed as `uid_save` to IOParameters and used to look up the
# segmentation results in the SEG Tiled server at the end of the notebook.
UID = 'uid0016'

# Fail fast with a clear message instead of passing `None` URIs into the pipeline,
# where the failure would surface much later and far less clearly.
# API keys are intentionally not checked (presumably unauthenticated servers are
# allowed — TODO confirm).
_missing_uris = [
    env_name
    for env_name, env_value in (
        ('RECON_TILED_URI', RECON_TILED_URI),
        ('MASK_TILED_URI', MASK_TILED_URI),
        ('SEG_TILED_URI', SEG_TILED_URI),
    )
    if not env_value
]
if _missing_uris:
    raise RuntimeError(
        f"Missing required environment variable(s): {', '.join(_missing_uris)}"
    )

In [None]:
# I/O configuration: where input data / masks are read from and where
# segmentation results are written.
io_parameters = IOParameters(
    data_tiled_uri=RECON_TILED_URI,
    data_tiled_api_key=RECON_TILED_API_KEY,
    mask_tiled_uri=MASK_TILED_URI,
    mask_tiled_api_key=MASK_TILED_API_KEY,
    seg_tiled_uri=SEG_TILED_URI,
    seg_tiled_api_key=SEG_TILED_API_KEY,
    uid_save=UID,
    uid_retrieve=None,  # presumably: no previously saved model to retrieve — confirm
)

# Training configuration for a 2-class DLSIA TUNet.
training_parameters = TrainingParameters(
    network="DLSIA TUNet",
    num_classes=2,
    # Patch-tiling settings; the window size is also used as the network's input
    # image shape later on. Step/border presumably control patch stride and
    # border width — TODO confirm against the qlty documentation.
    qlty_window=16,
    qlty_step=6,
    qlty_border=1,
    num_epochs=10,
    batch_size_train=16,
    batch_size_inference=16,
    batch_size_val=16,
)

## Get Annotated Data

In [None]:
# Build the training-mode dataset (data paired with annotation masks).
dataset = initialize_tiled_datasets(
    io_parameters,
    training_parameters,
    is_training=True,
)
num_training_items = len(dataset)
print(num_training_items)

In [None]:
# Manual iterator over the training dataset so the next cell can pull one
# (data, mask) sample at a time; re-running the next cell advances to the
# following sample rather than showing the same one again.
iter_data = iter(dataset)

In [None]:
# Pull the next (data, mask) pair from the training iterator and preview it.
data, mask = next(iter_data)
print("Data shape:", data.shape)
print("Mask shape:", mask.shape)

# Use the explicit Figure/Axes API instead of creating axes with plt.subplots()
# and then re-selecting them through the pyplot state machine.
fig, (ax_data, ax_mask) = plt.subplots(1, 2, figsize=(10, 5))
ax_data.set_title("Data Slice")
# `data` is indexed as 3-D here; index 0 along the first axis is shown
# (presumably the first channel — confirm against the dataset layout).
ax_data.imshow(data[0, :, :])
ax_mask.set_title("Mask Slice")
ax_mask.imshow(mask)
plt.show()

## Get Network

In [None]:
qlty_window = training_parameters.qlty_window
last_channel = dataset.data_client.shape[-1]
# Treat a small trailing axis (<= 4) as image channels; otherwise assume
# single-channel data. Computed once so the network and the summary agree.
in_channels = last_channel if last_channel <= 4 else 1
networks = build_network(
    network_name=training_parameters.network,
    in_channels=in_channels,
    image_shape=(qlty_window, qlty_window),
    num_classes=training_parameters.num_classes,
    parameters=training_parameters.dict(),  # Pass the raw parameters dictionary for network construction
)

# Summarize the first network with the same channel count it was built with;
# a hard-coded 1 would fail for multi-channel inputs.
# NOTE(review): torchsummary.summary defaults to device="cuda"; if CUDA is
# available but the network lives on CPU this may fail — confirm.
torchsummary.summary(networks[0], input_size=(in_channels, qlty_window, qlty_window))

## Train Segmentation Model

In [None]:
device = get_device()

# Run-specific overrides applied on top of the parameter objects defined earlier.
io_parameters.models_dir = "./models"
# Two values for num_classes=2 — presumably per-class loss weights in the string
# form expected by the training code (TODO confirm format).
training_parameters.weights = "[0.8,1]"

train_loader, val_loader = construct_train_dataloaders(dataset, training_parameters)

# Train the network(s) and keep the returned model for inference below.
net = run_train(
    train_loader,
    val_loader,
    io_parameters,
    networks,
    training_parameters,
    device,
    use_dvclive=False,
)

## Prepare Dataset for Inference

In [None]:
# Rebuild the dataset in inference mode (is_training=False).
# NOTE: this rebinds `dataset`, `iter_data` and `data`. Re-running the earlier
# training / network cells after this one would silently use the inference
# dataset — run the notebook top to bottom.
dataset = initialize_tiled_datasets(
    io_parameters,
    training_parameters,
    is_training=False,
)
iter_data = iter(dataset)
data = next(iter_data)
print("Data shape for inference:", data.shape)

In [None]:
# Preview one item along axis 0 of the inference batch. Clamp the requested
# index so a batch with 50 or fewer items does not raise IndexError.
PREVIEW_INDEX = 50
indx = min(PREVIEW_INDEX, data.shape[0] - 1)

fig, ax = plt.subplots()
ax.set_title("Data Slice for Inference")
# `data` is indexed as 4-D here; index 0 on the second axis is shown
# (presumably the first channel — confirm).
ax.imshow(data[indx, 0, :, :])
plt.show()

In [None]:
# Presumably creates/opens the Tiled container(s) that the segmentation output
# is written into for this network — see prepare_tiled_containers.
seg_client = prepare_tiled_containers(io_parameters, dataset, training_parameters.network)

# Segment the inference dataset with the trained network
# (presumably writing results through `seg_client`).
run_inference(dataset, net, seg_client, training_parameters, device)

In [None]:
tiled_client = from_uri(SEG_TILED_URI, api_key=SEG_TILED_API_KEY)
# Keep a handle to the remote array rather than reading it with `[:]`: the shape
# comes from metadata, and slicing fetches only the requested slice instead of
# downloading the whole segmentation volume.
seg_array = tiled_client[UID]["seg_result"]
print(f"Segmented data shape: {seg_array.shape}")

seg_slice = seg_array[0, :, :]
fig, ax = plt.subplots()
seg_image = ax.imshow(seg_slice)
fig.colorbar(seg_image, ax=ax)
ax.set_title("Segmented Data Slice")
plt.show()
