In [None]:
import random
import time
from pyflim import flim, arch, data, metrics, util
import numpy as np
from torch.utils.data import DataLoader
import torch

## Create Dataset

In [None]:
# Locations and file-name conventions of the training inputs.
orig_folder = "data/orig/"
marker_folder = "data/markers/"
label_folder = "data/label/"  # NOTE(review): defined here but not passed to FLIMData below (label_folder=None) — confirm intended
orig_ext = ".png"
label_ext = ".png"
marker_ext = "-seeds.txt"
file_list = "./train.txt"

# Seed for the batch-order shuffle, so that Restart & Run All yields the same
# sequence of batches on every run.
SEED = 42

# Toggle between feeding the raw FLIMData object and a batched DataLoader.
train_by_batch = True
if not train_by_batch:
    # Whole-dataset mode: images are only converted to tensors.
    dataset = data.FLIMData(
        orig_folder,
        images_list=file_list,
        marker_folder=marker_folder,
        orig_ext=orig_ext,
        marker_ext=marker_ext,
        transform=data.transforms.Compose([data.ToTensor()]),
    )
else:
    # Batched mode: Rescale(256) is applied before tensor conversion
    # (presumably so images in a batch share one size — TODO confirm).
    dataset_ = data.FLIMData(
        orig_folder,
        images_list=file_list,
        label_folder=None,
        label_ext=label_ext,
        marker_folder=marker_folder,
        orig_ext=orig_ext,
        marker_ext=marker_ext,
        transform=data.transforms.Compose([
            data.Rescale(256),
            data.ToTensor(),
        ]),
    )
    # An explicit, seeded generator makes the random batch order deterministic;
    # without it RandomSampler draws a fresh seed from the global RNG each epoch.
    sampler = torch.utils.data.sampler.BatchSampler(
        torch.utils.data.sampler.RandomSampler(
            dataset_, generator=torch.Generator().manual_seed(SEED)
        ),
        batch_size=5,
        drop_last=False,  # keep the final, possibly smaller, batch
    )

    dataset = DataLoader(dataset_, batch_sampler=sampler)

## Train model

In [None]:
# Construct the architecture object from "arch.json" and the model that uses it.
architecture = arch.FLIMArchitecture("arch.json")
# NOTE(review): device is "cpu" while track_gpu_stats=True — confirm this combination is intended.
model = flim.FLIMModel(architecture, adaptation_function="robust_weights", device="cpu", filter_by_size=False, track_gpu_stats=True)

# Time the fit with perf_counter(): unlike time.time() it is monotonic, so the
# measured duration cannot be distorted by system clock adjustments.
start = time.perf_counter()
model.fit(dataset)
stop = time.perf_counter()
print('Network trained in:', stop - start, 'seconds')

In [None]:
# Run the trained model over `dataset`, passing "out/" as the output location,
# and report how long the pass took.
# perf_counter() is monotonic, so the elapsed time is immune to wall-clock jumps.
start = time.perf_counter()
model.run(dataset, "out/")
stop = time.perf_counter()
print('Forward pass in:', stop - start, 'seconds')

In [None]:
# Query the parameter count once, then report it.
# The "(M)" suffix presumably means the value is in millions — TODO confirm against util.get_model_n_params.
n_params = util.get_model_n_params(model)
print("Model parameters: ", n_params, "(M)")

### Run on validation

In [None]:
# Validation pass: rebuild the dataset from the validation file list and run the
# already-trained model over it.
# NOTE: this cell rebinds `file_list`, `train_by_batch` and `dataset`; re-running
# the training cell afterwards without re-running the dataset cell would fit on
# the validation data.
file_list = "./val.txt"

train_by_batch = False
if not train_by_batch:
    # Whole-dataset mode: images are only converted to tensors.
    # NOTE(review): the training cell applies Rescale(256) in its batched mode while
    # this path applies no rescale — confirm the size mismatch is intended.
    dataset = data.FLIMData(
        orig_folder,
        images_list=file_list,
        orig_ext=orig_ext,
        transform=data.transforms.Compose([data.ToTensor()]),
    )
else:
    dataset_ = data.FLIMData(
        orig_folder,
        images_list=file_list,
        orig_ext=orig_ext,
        transform=data.transforms.Compose([
            data.Rescale(256),
            data.ToTensor(),
        ]),
    )
    # Sequential (not random) order: batches follow the dataset's own ordering.
    sampler = torch.utils.data.sampler.BatchSampler(
        torch.utils.data.SequentialSampler(dataset_),
        batch_size=5,
        drop_last=False,  # keep the final, possibly smaller, batch
    )

    dataset = DataLoader(dataset_, batch_sampler=sampler)

# perf_counter() is monotonic, so the elapsed time is immune to wall-clock jumps
# that can distort time.time() differences.
start = time.perf_counter()
model.run(dataset, "out/")
stop = time.perf_counter()
print('Forward pass in:', stop - start, 'seconds')

### Compute Metrics

In [None]:
# The import and the paths are repeated here; this looks intended to let the
# metrics cell run without the earlier cells — TODO confirm.
from pyflim import flim, arch, data, metrics, util

file_list = "./val.txt"
label_folder = "data/label/"
results_folder = "out/"

# "metricas" (Portuguese for "metrics") holds the evaluator; the name avoids
# shadowing the imported `metrics` module.
metricas = metrics.FLIMMetrics()

# Read the validation file list, then evaluate the results in `results_folder`
# against the labels in `label_folder` for exactly those entries.
val_entries = util.readFileList(file_list)
metricas.evaluate_saliency_results(
    results_folder,
    label_folder,
    file_list=val_entries,
)

In [None]:
# Print the results gathered by the evaluate_saliency_results() call in the previous cell.
metricas.print_results()