In [10]:
# Names of two summary rows as they appear in `DataFrame.describe()` output
# (the mean and the median).
metrics = ['mean', '50%']

# Directory of each raster collection, relative to the notebook's working directory.
data_paths = dict(
    T1='data/dataset/t1',
    T2='data/dataset/t2',
    mask='data/dataset/mask',
    T1_avaliacao='data/avaliacao/t1',
    T2_avaliacao='data/avaliacao/t2',
)

In [11]:
import torch as th
import torch.nn as nn

class Unet_module(nn.Module):
    """Residual double-convolution stage followed by a stride-2 resampling layer.

    The input goes through two conv -> batch-norm -> PReLU steps; a parallel
    "bridge" convolution of the input is added to the result. The sum is
    returned as the skip tensor and is also passed through ``self.sample``.

    Args:
        kernel_size: kernel size of the three stride-1 convolutions; they are
            padded with ``(kernel_size - 1) // 2``.
        de_kernel_size: kernel size of the stride-2 resampling convolution.
        channel_list: ``[in, mid, out]`` channel counts. The bridge convolution
            outputs ``channel_list[-1]`` channels, which must match
            ``channel_list[2]`` for the residual sum to be valid.
        down_up: ``'down'`` resamples with a strided ``Conv2d`` + PReLU; any
            other value uses ``ConvTranspose2d`` + ReLU.
    """

    def __init__(self, kernel_size, de_kernel_size, channel_list, down_up='down'):
        super().__init__()
        same_pad = (kernel_size - 1) // 2
        sample_pad = (de_kernel_size - 1) // 2
        c_in, c_mid, c_out = channel_list[0], channel_list[1], channel_list[2]

        # Attribute names are part of the checkpoint format (state_dict keys).
        self.conv1 = nn.Conv2d(c_in, c_mid, kernel_size, stride=1, padding=same_pad)
        self.conv2 = nn.Conv2d(c_mid, c_out, kernel_size, stride=1, padding=same_pad)
        self.relu1 = nn.PReLU()
        self.relu2 = nn.PReLU()
        self.bn1 = nn.BatchNorm2d(c_mid)
        self.bn2 = nn.BatchNorm2d(c_out)
        self.bridge_conv = nn.Conv2d(c_in, channel_list[-1], kernel_size, stride=1, padding=same_pad)

        if down_up == 'down':
            resample = nn.Conv2d(c_out, c_out, de_kernel_size, stride=2, padding=sample_pad, dilation=1)
            activation = nn.PReLU()
        else:
            resample = nn.ConvTranspose2d(c_out, c_out, de_kernel_size, stride=2, padding=sample_pad)
            activation = nn.ReLU()
        self.sample = nn.Sequential(resample, nn.BatchNorm2d(c_out), activation)

    def forward(self, x):
        """Return ``(resampled, skip)`` where ``skip`` is the pre-resampling residual sum."""
        shortcut = self.bridge_conv(x)
        features = self.relu1(self.bn1(self.conv1(x)))
        features = self.relu2(self.bn2(self.conv2(features)))
        skip = features + shortcut
        return self.sample(skip), skip
    
class de_conv_module(nn.Module):
    """Decoder stage: concatenate a skip tensor, apply a residual double conv, resample.

    ``forward(x, x1)`` concatenates ``x`` and ``x1`` along the channel axis,
    runs two conv -> batch-norm -> PReLU steps, adds a bridge convolution of
    the concatenated input, and returns the output of ``self.sample``.

    Args:
        kernel_size: kernel size of the stride-1 convolutions (padding
            ``(kernel_size - 1) // 2``).
        de_kernel_size: kernel size of the stride-2 resampling layer.
        channel_list: ``[in, mid, out]`` channel counts, where ``in`` is the
            channel count after concatenation. The bridge convolution outputs
            ``channel_list[-1]`` channels.
        down_up: ``'down'`` uses a strided ``Conv2d`` + PReLU; any other value
            uses ``ConvTranspose2d`` + ReLU.
    """

    def __init__(self, kernel_size, de_kernel_size, channel_list, down_up='down'):
        super(de_conv_module, self).__init__()
        pad = (kernel_size - 1) // 2

        def same_conv(n_in, n_out):
            # Stride-1 convolution with (kernel_size - 1) // 2 padding.
            return nn.Conv2d(n_in, n_out, kernel_size, 1, pad)

        # Attribute names are part of the checkpoint format (state_dict keys).
        self.conv1 = same_conv(channel_list[0], channel_list[1])
        self.conv2 = same_conv(channel_list[1], channel_list[2])
        self.relu1, self.relu2 = nn.PReLU(), nn.PReLU()
        self.bn1 = nn.BatchNorm2d(channel_list[1])
        self.bn2 = nn.BatchNorm2d(channel_list[2])
        self.bridge_conv = same_conv(channel_list[0], channel_list[-1])

        width = channel_list[2]
        resample_pad = (de_kernel_size - 1) // 2
        if down_up != 'down':
            stages = [nn.ConvTranspose2d(width, width, de_kernel_size, 2, resample_pad),
                      nn.BatchNorm2d(width),
                      nn.ReLU()]
        else:
            stages = [nn.Conv2d(width, width, de_kernel_size, 2, resample_pad, 1),
                      nn.BatchNorm2d(width),
                      nn.PReLU()]
        self.sample = nn.Sequential(*stages)

    def forward(self, x, x1):
        """Return the resampled residual features of ``cat([x, x1], dim=1)``."""
        merged = th.cat([x, x1], dim=1)
        shortcut = self.bridge_conv(merged)

        out = self.relu1(self.bn1(self.conv1(merged)))
        out = self.relu2(self.bn2(self.conv2(out)))

        return self.sample(out + shortcut)

class FCN_2D(nn.Module):
    """U-Net-style fully convolutional network producing a single-channel map.

    The encoder halves the spatial size four times (``conv2`` .. ``conv5``) and
    the decoder doubles it four times (``de_conv1`` .. ``de_conv4``), with skip
    connections concatenated at each decoder stage. No activation is applied
    to the output of the final 1x1 convolution.

    Args:
        in_channel: number of channels of the input tensor.
        layers: base channel width; deeper stages use multiples of it.
    """

    def __init__(self, in_channel, layers):
        super(FCN_2D, self).__init__()
        width = layers

        # Full-resolution stem.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, width, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(width),
            nn.PReLU(),
        )
        # First downsampling (kernel 2, stride 2, no padding).
        self.conv2 = nn.Sequential(
            nn.Conv2d(width, width * 2, kernel_size=2, stride=2, padding=0),
            nn.BatchNorm2d(width * 2),
            nn.PReLU(),
        )

        # Encoder stages: each returns (downsampled, skip).
        self.conv3 = Unet_module(5, 2, [width * 2, width * 2, width * 4], 'down')
        self.conv4 = Unet_module(5, 2, [width * 4, width * 4, width * 8], 'down')
        self.conv5 = Unet_module(5, 2, [width * 8, width * 8, width * 16], 'down')

        # Decoder stages: the first has no skip input, the others concatenate one.
        self.de_conv1 = Unet_module(5, 2, [width * 16, width * 32, width * 16], down_up='up')
        self.de_conv2 = de_conv_module(5, 2, [width * 32, width * 8, width * 8], down_up='up')
        self.de_conv3 = de_conv_module(5, 2, [width * 16, width * 4, width * 4], down_up='up')
        self.de_conv4 = de_conv_module(5, 2, [width * 8, width * 2, width], down_up='up')

        # 1x1 projection of [decoder output, stem output] to one channel.
        self.last_conv = nn.Conv2d(width * 2, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Map ``x`` to a one-channel output tensor.

        NOTE(review): the stride-2 layers use no padding, so the skip
        concatenations presumably require height/width divisible by 16 - confirm.
        """
        stem = self.conv1(x)
        feat = self.conv2(stem)
        feat, skip_a = self.conv3(feat)
        feat, skip_b = self.conv4(feat)
        feat, skip_c = self.conv5(feat)

        feat, _ = self.de_conv1(feat)
        feat = self.de_conv2(feat, skip_c)
        feat = self.de_conv3(feat, skip_b)
        feat = self.de_conv4(feat, skip_a)

        return self.last_conv(th.cat([feat, stem], dim=1))


In [12]:
from torch.utils.data import Dataset
import numpy as np
import os
import torch
import rasterio

class WorCapDataset(Dataset):
    """Dataset of paired rasters plus a binary mask, looked up by sample id.

    For each id, the file ``recorte_<id>.tif`` is read from ``T10_dir``,
    ``T20_dir`` and ``mask_dir``. The two images are concatenated along the
    first (band) axis.

    Args:
        T10_dir: directory of the first image of each pair.
        T20_dir: directory of the second image of each pair.
        mask_dir: directory of the mask rasters.
        ids: sequence of sample ids.
        transform: optional callable applied separately to each image and to
            the mask.
    """

    def __init__(self, T10_dir, T20_dir, mask_dir, ids, transform=None):
        self.T10_dir, self.T20_dir, self.mask_dir = T10_dir, T20_dir, mask_dir
        self.ids = ids
        self.transform = transform

    def __len__(self):
        """Number of sample ids."""
        return len(self.ids)

    def read_image(self, path):
        """Read all bands of ``path`` as a float32 tensor min-max scaled to [0, 1].

        NaNs are replaced by 0 before scaling. The minimum and maximum are taken
        over the whole array (all bands together); a constant image yields zeros.
        """
        with rasterio.open(path) as src:
            bands = np.nan_to_num(src.read().astype(np.float32), nan=0.0)
        lowest = bands.min()
        highest = bands.max()
        if highest - lowest > 0:
            bands = (bands - lowest) / (highest - lowest)
        else:
            bands = np.zeros_like(bands)
        return torch.tensor(bands, dtype=torch.float32)

    def read_mask(self, path):
        """Read band 1 of ``path`` as a {0, 1} float32 tensor with a leading channel axis."""
        with rasterio.open(path) as src:
            raw = np.nan_to_num(src.read(1).astype(np.float32), nan=0.0)
        # Any strictly positive value counts as foreground.
        binary = np.where(raw > 0, 1.0, 0.0)
        return torch.tensor(binary, dtype=torch.float32).unsqueeze(0)

    def __getitem__(self, idx):
        """Return ``(stacked_images, mask, sample_id)`` for position ``idx``."""
        sample_id = self.ids[idx]
        filename = f"recorte_{sample_id}.tif"

        first = self.read_image(os.path.join(self.T10_dir, filename))
        second = self.read_image(os.path.join(self.T20_dir, filename))
        target = self.read_mask(os.path.join(self.mask_dir, filename))

        if self.transform:
            # The transform is called once per tensor, in this order.
            first, second, target = (self.transform(item) for item in (first, second, target))

        return torch.cat([first, second], dim=0), target, sample_id

In [13]:
import spyndex
class WorCapDiffDataset(Dataset):
    """Dataset yielding the difference of a spectral index between two rasters.

    For each id, ``recorte_<id>.tif`` is read from ``T10_dir`` and ``T20_dir``,
    the spyndex ``NBRSWIR`` index is computed for both, and their difference
    (first minus second) is returned as a single-channel float32 tensor
    together with the binary mask read from ``mask_dir``.

    Args:
        T10_dir: directory of the first image of each pair.
        T20_dir: directory of the second image of each pair.
        mask_dir: directory of the mask rasters.
        ids: sequence of sample ids.
        transform: optional callable applied separately to the index difference
            (as a float32 tensor) and to the mask.
    """

    def __init__(self, T10_dir, T20_dir, mask_dir, ids, transform=None):
        self.T10_dir = T10_dir
        self.T20_dir = T20_dir
        self.mask_dir = mask_dir
        self.ids = ids
        self.transform = transform

    def __len__(self):
        """Number of sample ids."""
        return len(self.ids)

    def read_image(self, path):
        """Compute the ``NBRSWIR`` index from bands 3 and 4 of the raster at ``path``.

        NOTE(review): bands 3 and 4 are passed as spyndex ``S1`` / ``S2``
        (presumably SWIR1 / SWIR2) - confirm against the raster band layout.
        Unlike ``WorCapDataset.read_image``, NaNs are not replaced here.
        """
        with rasterio.open(path) as src:
            S1 = src.read(3).astype(np.float32)
            S2 = src.read(4).astype(np.float32)

        return spyndex.computeIndex(
            index=["NBRSWIR"],
            params={
                "S1": S1,
                "S2": S2
            }
        )

    def read_mask(self, path):
        """Read band 1 of ``path`` as a {0, 1} float32 tensor with a leading channel axis."""
        with rasterio.open(path) as src:
            mask = src.read(1).astype(np.float32)
            mask = np.nan_to_num(mask, nan=0.0)
            # Any strictly positive value counts as foreground.
            mask = np.where(mask > 0, 1.0, 0.0)
        return torch.tensor(mask, dtype=torch.float32).unsqueeze(0)

    def __getitem__(self, idx):
        """Return ``(index_difference, mask, sample_id)`` for position ``idx``."""
        id_ = self.ids[idx]
        fname = f"recorte_{id_}.tif"
        T10_path = os.path.join(self.T10_dir, fname)
        T20_path = os.path.join(self.T20_dir, fname)
        mask_path = os.path.join(self.mask_dir, fname)

        t1 = self.read_image(T10_path)
        t2 = self.read_image(T20_path)

        # Keep the array under its own name instead of overwriting the `idx`
        # argument, and add the leading channel axis.
        diff = np.expand_dims(t1 - t2, axis=0)
        diff = torch.as_tensor(diff, dtype=torch.float32)

        mask = self.read_mask(mask_path)

        if self.transform:
            diff = self.transform(diff)
            mask = self.transform(mask)
            # `as_tensor` is a no-op for a float32 tensor and avoids the
            # copy-construct warning that `torch.tensor(tensor)` emits.
            diff = torch.as_tensor(diff, dtype=torch.float32)

        return diff, mask, id_

In [14]:
def valid(model, criterion, valid_loader, device, e):
    """Run one validation pass and return the mean per-batch loss.

    Args:
        model: network to evaluate; it is switched to eval mode.
        criterion: callable ``criterion(outputs, label)`` returning a scalar tensor.
        valid_loader: sized iterable of batches whose first two items are the
            input and the label.
        device: device the input and label are moved to.
        e: epoch number, used only in the progress message.

    Returns:
        Sum of the batch losses divided by ``len(valid_loader)``.
    """
    model.eval()
    n_batches = len(valid_loader)
    step_losses = []

    with torch.no_grad():
        for step, batch in enumerate(valid_loader):
            inputs = batch[0].float().to(device)
            targets = batch[1].float().to(device)

            step_loss = criterion(model(inputs), targets).item()
            step_losses.append(step_loss)
            print('Epoch {:<3d}  |Step {:>3d}/{:<3d}  | valid loss {:.4f}'.format(e, step, n_batches, step_loss))

    return sum(step_losses) / n_batches

In [15]:
# Decision thresholds swept during evaluation: 0.05, 0.10, ..., 0.45.
limiar_list = [0.05 * i for i in range(1, 10)]
# Base channel widths (`layers` argument of FCN_2D) of the trained models.
layers_list = [8, 32, 64]
# Input channel counts of the trained models.
channels_list = [8, 1]
# Checkpoint number: models are loaded from 'net_<load_num>.pkl'.
load_num = 100
# Default input channel count; the evaluation loop reassigns this name.
channels = 8

os.environ['CUDA_VISIBLE_DEVICES'] = str(0)
torch.backends.cudnn.enabled = True
# Only select a CUDA device when one exists; `torch.cuda.set_device` raises on
# a CPU-only machine, which defeated the CPU fallback below.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [16]:
import pandas as pd

# Table assigning each sample ID to a split; keep only the validation IDs.
split_df = pd.read_csv('split_ids.csv')
val_ids = split_df.loc[split_df['split'] == 'val', 'ID'].tolist()

In [17]:
def calc_dice(labels, preds):
    """Dice coefficient between two masks.

    Both arrays are flattened and binarised (any non-zero value counts as
    foreground) before comparison.

    Args:
        labels: reference mask array.
        preds: predicted mask array with the same number of elements.

    Returns:
        ``2 * |labels AND preds| / (|labels| + |preds|)``, or ``1.0`` when both
        masks are empty.
    """
    label_mask = labels.flatten().astype(bool)
    pred_mask = preds.flatten().astype(bool)

    overlap = np.logical_and(label_mask, pred_mask).sum()
    total = label_mask.sum() + pred_mask.sum()

    # No foreground in either mask, i.e. a perfect prediction.
    if total == 0:
        return 1.0
    return (2. * overlap) / total

In [18]:
from torch.utils.data import DataLoader
# Per-sample Dice scores: one list per (layers, threshold, channels) column.
dice_results = {}
result_path = os.path.join(os.path.abspath('.'), 'results')


def _dice_column(n_layers, threshold, n_channels):
    """Column name of one configuration in `dice_results` / dice_results.csv."""
    return f'layers{n_layers}_limiar{threshold:.2f}_{n_channels}ch'


for layers in layers_list:
    for channels in channels_list:
        if channels == 1 and layers == 64:
            continue  # this combination is not evaluated
        print(f'Evaluating for layers={layers}, channels={channels}')

        model_save_path = os.path.join(result_path, f'FCN_2D_{channels}ch_{layers}lyr')
        checkpoint_path = os.path.join(model_save_path, 'net_%d.pkl' % load_num)

        # Each checkpoint is loaded and run once; every threshold is then
        # applied to the same predictions instead of repeating inference.
        net = FCN_2D(channels, layers).to(device)
        # `map_location` lets a checkpoint saved on GPU load on a CPU-only machine.
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
        net.eval()

        if channels == 8:
            dataset_val = WorCapDataset(data_paths["T1"], data_paths["T2"], data_paths['mask'], val_ids)
        elif channels == 1:
            dataset_val = WorCapDiffDataset(data_paths["T1"], data_paths["T2"], data_paths['mask'], val_ids)
        else:
            raise ValueError(f'No validation dataset defined for channels={channels}')
        test_loader = DataLoader(dataset_val, batch_size=1, shuffle=False)

        columns = [(limiar, _dice_column(layers, limiar, channels)) for limiar in limiar_list]
        for _, col_name in columns:
            dice_results[col_name] = []

        with torch.no_grad():
            for batch in test_loader:
                img = batch[0].float().to(device)
                # The network outputs logits; sigmoid turns them into probabilities.
                probs = torch.sigmoid(net(img)).squeeze(1).cpu().numpy()
                label_bin = batch[1].float().squeeze(1).cpu().numpy().astype(np.uint8)

                for limiar, col_name in columns:
                    pred_bin = (probs >= limiar).astype(np.uint8)
                    dice_results[col_name].append(calc_dice(label_bin, pred_bin))

# Keep the column order of the exported table: layers -> threshold -> channels.
column_order = [
    _dice_column(n_layers, threshold, n_channels)
    for n_layers in layers_list
    for threshold in limiar_list
    for n_channels in channels_list
    if not (n_channels == 1 and n_layers == 64)
]
df_dice = pd.DataFrame(dice_results)[column_order]
df_dice.index = val_ids
df_dice.to_csv('dice_results.csv', index_label='ID')
            


Evaluating for layers=8, threshold=0.05, channels=8
Evaluating for layers=8, threshold=0.05, channels=1
Evaluating for layers=8, threshold=0.10, channels=8
Evaluating for layers=8, threshold=0.10, channels=1
Evaluating for layers=8, threshold=0.15, channels=8
Evaluating for layers=8, threshold=0.15, channels=1
Evaluating for layers=8, threshold=0.20, channels=8
Evaluating for layers=8, threshold=0.20, channels=1
Evaluating for layers=8, threshold=0.25, channels=8
Evaluating for layers=8, threshold=0.25, channels=1
Evaluating for layers=8, threshold=0.30, channels=8
Evaluating for layers=8, threshold=0.30, channels=1
Evaluating for layers=8, threshold=0.35, channels=8
Evaluating for layers=8, threshold=0.35, channels=1
Evaluating for layers=8, threshold=0.40, channels=8
Evaluating for layers=8, threshold=0.40, channels=1
Evaluating for layers=8, threshold=0.45, channels=8
Evaluating for layers=8, threshold=0.45, channels=1
Evaluating for layers=32, threshold=0.05, channels=8
Evaluating 

In [19]:
# Summary statistics of the per-sample Dice scores, one column per configuration.
describr_df = df_dice.describe()

# Five configurations with the highest mean Dice, best first.
mean_ranking = describr_df.loc['mean'].sort_values(ascending=False)
top5_mean_cols = mean_ranking.iloc[:5]

# Five configurations with the highest median ('50%') Dice, best first.
median_ranking = describr_df.loc['50%'].sort_values(ascending=False)
top5_sec_qart_cols = median_ranking.iloc[:5]



In [20]:
# Display the five configurations with the highest mean Dice.
top5_mean_cols

layers32_limiar0.15_8ch    0.821957
layers32_limiar0.20_8ch    0.821328
layers32_limiar0.10_8ch    0.821246
layers32_limiar0.25_8ch    0.820496
layers32_limiar0.30_8ch    0.820321
Name: mean, dtype: float64

In [21]:
# Display the five configurations with the highest median ('50%') Dice.
top5_sec_qart_cols

layers32_limiar0.25_8ch    0.932297
layers32_limiar0.30_8ch    0.932197
layers32_limiar0.20_8ch    0.929527
layers32_limiar0.15_8ch    0.928064
layers32_limiar0.10_8ch    0.927152
Name: 50%, dtype: float64

In [22]:
# Hand-picked thresholds. NOTE(review): 0.5 and 0.6 lie outside the range swept
# by `limiar_list` (0.05-0.45), so they have no Dice column in `df_dice`.
limiar_selected = [0.15, 0.2, 0.25, 0.1, 0.3, 0.4, 0.5, 0.6]