In [None]:

from multissl.data.mask_loader import SegmentationDataset
import torch
import pytorch_lightning as pl
import os
from multissl.models.randomforest import RandomForestSegmentation
from torchvision import transforms
from torch.utils.data import DataLoader

In [None]:
# Central configuration for this notebook: every path and hyperparameter a
# reader might tune lives here. Empty strings are placeholders to fill in.
CONFIG = {
    # Data: each split is a subfolder with the same name under both roots.
    "img_dir": "",
    "mask_dir": "",

    # Output: the trained model is saved to <output_dir>/<model_filename>.
    # An empty output_dir means the current working directory.
    "output_dir": "",
    "model_filename": "",

    # Reproducibility
    "random_seed": 42,

    # DataLoader settings (read by the data-loading cell; previously missing,
    # which made that cell fail with KeyError on a fresh kernel).
    "batch_size": 16,
    "num_workers": 0,

    # Random-forest hyperparameters
    "n_estimators": 100,
    "max_depth": None,

    # Feature groups passed to RandomForestSegmentation
    "pixel_features": True,
    "spatial_features": False,
    "texture_features": False,

    # Input geometry
    "img_size": 224,
    "in_channels": 4,
}

In [None]:
# Set random seed for reproducibility
pl.seed_everything(CONFIG["random_seed"])

# Create the output directory. A blank output_dir means "current working
# directory" (os.path.join("", name) == name), but os.makedirs("") raises
# FileNotFoundError, so fall back to "." in that case.
os.makedirs(CONFIG["output_dir"] or ".", exist_ok=True)

# Segmentation transforms.
# NOTE(review): SafeUIntToFloat / ToTensorSafe are not used anywhere in this notebook.
from multissl.data.seg_transforms import SafeUIntToFloat, ToTensorSafe


In [None]:
def _make_split_dataset(split: str) -> SegmentationDataset:
    """Build the SegmentationDataset for one split folder.

    Images come from ``<CONFIG["img_dir"]>/<split>`` and masks from
    ``<CONFIG["mask_dir"]>/<split>``; ``CONFIG["img_size"]`` is passed through.
    """
    return SegmentationDataset(
        img_dir=os.path.join(CONFIG["img_dir"], split),
        mask_dir=os.path.join(CONFIG["mask_dir"], split),
        img_size=CONFIG["img_size"],
    )


def _make_eval_loader(dataset) -> DataLoader:
    """Unshuffled, single-process DataLoader used for the validation/test splits."""
    return DataLoader(
        dataset,
        batch_size=CONFIG["batch_size"],
        shuffle=False,
        num_workers=0,
        pin_memory=True,
    )


# Create datasets (one per split folder)
train_dataset = _make_split_dataset("train1")
val_dataset = _make_split_dataset("val_esac2")
test_dataset_v = _make_split_dataset("test_valdoeiro")
test_dataset_q = _make_split_dataset("test_qbaixo")

# Print dataset sizes, then fail fast with a readable message if any split is
# empty (usually a wrong img_dir/mask_dir) before loaders or the model see it.
split_sizes = {
    "Train dataset size": len(train_dataset),
    "Validation dataset size": len(val_dataset),
    "Test dataset size valdoeiro": len(test_dataset_v),
    "Test dataset size qbaixo": len(test_dataset_q),
}
for label, size in split_sizes.items():
    print(f"{label}: {size}")
empty_splits = [label for label, size in split_sizes.items() if size == 0]
assert not empty_splits, f"Empty dataset(s), check img_dir/mask_dir: {empty_splits}"

# Create dataloaders: the training loader is shuffled and uses the configured
# worker count; evaluation loaders share one fixed configuration.
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    num_workers=CONFIG["num_workers"],
)
val_loader = _make_eval_loader(val_dataset)
test_loader_v = _make_eval_loader(test_dataset_v)
test_loader_q = _make_eval_loader(test_dataset_q)

In [None]:


# Create the random-forest segmentation model; all tunable hyperparameters
# come from CONFIG.
model = RandomForestSegmentation(
    n_estimators=CONFIG["n_estimators"],
    max_depth=CONFIG["max_depth"],
    pixel_features=CONFIG["pixel_features"],
    texture_features=CONFIG["texture_features"],
    spatial_features=CONFIG["spatial_features"],
    img_size=CONFIG["img_size"],
    in_channels=CONFIG["in_channels"],
    class_weight='balanced',
)

# Train model. val_loader is passed as fit's second argument
# (presumably used for validation metrics — TODO confirm against RandomForestSegmentation.fit).
metrics = model.fit(train_loader, val_loader)

# Display whatever fit returned as the cell's output.
metrics




In [None]:


# Evaluate the trained model on both held-out test sets.
test_metrics_v = model.validate(test_loader_v)  # valdoeiro
test_metrics_q = model.validate(test_loader_q)  # qbaixo

# Display both results together as the cell's output.
{"valdoeiro": test_metrics_v, "qbaixo": test_metrics_q}

In [None]:
# Save the trained model to <output_dir>/<model_filename>.
# With a blank model_filename, os.path.join(dir, "") yields a directory path
# (or ""), not a file path, so refuse early with a clear message.
assert CONFIG["model_filename"], "Set CONFIG['model_filename'] before saving the model"
model_path = os.path.join(CONFIG["output_dir"], CONFIG["model_filename"])
model.save(model_path)

# Show where the model was written.
model_path