In [1]:
import torch
from data_loader import CustomDataLoader
from torch.utils.data import DataLoader
import ukis_metrics.seg_metrics as segm

import warnings
warnings.filterwarnings('ignore')

In [2]:
# model_ckpt = './networks/2024_04_16-16_37_36_encoder_efficientnet-b0_weights_imagenet_epochs_100_batchsize_16_lr_0.001/model_0.38835_val_avg.pt'
model_ckpt = './networks/2024_05_01-10_54_41_encoder_efficientnet-b0_weights_imagenet_epochs_100_batchsize_16_lr_0.001/model_0.19832_val_avg.pt'

In [3]:
import pickle

# Take the same path and find the associated parameter file
path = '/'.join(model_ckpt.split('/')[:-1]) + '/params.pkl'
with open(path, 'rb') as f:
    loaded_dict = pickle.load(f)
print(loaded_dict) # Show the parameter file's contents

import ast # Test how string-to-dict works
ast.literal_eval(str(loaded_dict))

channels = loaded_dict['channels']

{'epochs': 100, 'batch_size': 16, 'lr': 0.001, 'encoder': 'efficientnet-b0', 'weights': 'imagenet', 'model_ckpt_path': './networks', 'data_path': '../data_prepped', 'channels': 'r.g.b.nir.swir1.swir2'}


In [4]:
%%time
from model import SegmentationModel

model = SegmentationModel('efficientnet-b0', 'imagenet', channels = channels)
model.load_state_dict(torch.load(model_ckpt))
model.eval()

CPU times: user 2.34 s, sys: 455 ms, total: 2.79 s
Wall time: 2.71 s


SegmentationModel(
  (cnn): Unet(
    (encoder): EfficientNetEncoder(
      (_conv_stem): Conv2dStaticSamePadding(
        6, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
        (static_padding): ZeroPad2d((0, 1, 0, 1))
      )
      (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
      (_blocks): ModuleList(
        (0): MBConvBlock(
          (_depthwise_conv): Conv2dStaticSamePadding(
            32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
            (static_padding): ZeroPad2d((1, 1, 1, 1))
          )
          (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
          (_se_reduce): Conv2dStaticSamePadding(
            32, 8, kernel_size=(1, 1), stride=(1, 1)
            (static_padding): Identity()
          )
          (_se_expand): Conv2dStaticSamePadding(
            8, 32, kernel_size=(1, 1), stride=(1, 1)
            (static_padding): I

In [5]:
data_path = '../data_prepped'
image_path = data_path + '/{}/img/*'
mask_path = data_path + '/{}/msk/*'
 
# Use custom data loaders for S1S2 dataset
test_dataset = CustomDataLoader(image_path.format('val'), mask_path.format('val'), channels = channels)
test_loader = DataLoader(dataset = test_dataset, batch_size = 8, shuffle = False, num_workers = 8)

In [6]:
def metric_for_ratio(masks, pred_masks, ratio):
    pred_masks = (pred_masks > ratio).cpu().bool()
    masks = masks.bool()

    tpfptnfn = segm.tpfptnfn(masks == 1, pred_masks, None)
    metrics = segm.segmentation_metrics(tpfptnfn)
    # print(ratio, metrics)
    return metrics

In [7]:
torch.multiprocessing.set_sharing_strategy('file_system')
from tqdm import tqdm 
import numpy as np 

device = 'cuda'
model = model.to(device)

tpfptnfn = {'tp': 0, 'fp': 0, 'tn': 0, 'fn': 0, 'n_valid_pixel': 0}
ratio = 0.8

i = 0
for ratio in [0.1, 0.2, 0.4, 0.8, 1.0]:
    with torch.no_grad():
        for images, masks in tqdm(test_loader):
            images = images.to(device, dtype = torch.float32)
    
            logits = model(images, None)
            preds = torch.sigmoid(logits)
            preds = (preds > ratio).bool()
    
            masks = masks.bool().to(device)
    
            # Compute the classification performance
            tpfptnfn['tp'] += torch.sum((preds == 1) & (masks == 1)).cpu().numpy()
            tpfptnfn['fn'] += torch.sum((preds == 0) & (masks == 1)).cpu().numpy()
            tpfptnfn['tn'] += torch.sum((preds == 0) & (masks == 0)).cpu().numpy()
            tpfptnfn['fp'] += torch.sum((preds == 1) & (masks == 0)).cpu().numpy()
            tpfptnfn['n_valid_pixel'] += np.prod(masks.shape) # Get the number of pixels
    
        metrics = segm.segmentation_metrics(tpfptnfn)
        print(ratio, metrics)
            
            # Tested different ratios to find the best one for performance
            # for ratio in [0, 0.05, 0.1, 0.2, 0.4, 0.8, 0.85, 0.9, 0.95]: 
            #     metric_for_ratio(masks, pred_masks, ratio)
            
            # ratio = 0.8
            # metric_for_ratio(masks, pred_masks, ratio)
            # i += 1
            # if i == 10:
            #     break

100%|███████████████████████████████████████████████████████████████████████████████| 3107/3107 [04:54<00:00, 10.55it/s]


0.1 {'iou': 0.8356, 'recall': 0.9997, 'precision': 0.8358, 'acc': 0.9768, 'F1': 0.9104, 'kappa': 0.8972}


100%|███████████████████████████████████████████████████████████████████████████████| 3107/3107 [04:52<00:00, 10.61it/s]


0.2 {'iou': 0.8575, 'recall': 0.9992, 'precision': 0.8581, 'acc': 0.9804, 'F1': 0.9233, 'kappa': 0.9121}


100%|███████████████████████████████████████████████████████████████████████████████| 3107/3107 [04:54<00:00, 10.56it/s]


0.4 {'iou': 0.894, 'recall': 0.9954, 'precision': 0.8977, 'acc': 0.9861, 'F1': 0.944, 'kappa': 0.9361}


100%|███████████████████████████████████████████████████████████████████████████████| 3107/3107 [04:54<00:00, 10.55it/s]


0.8 {'iou': 0.8864, 'recall': 0.9619, 'precision': 0.9186, 'acc': 0.9854, 'F1': 0.9398, 'kappa': 0.9315}


100%|███████████████████████████████████████████████████████████████████████████████| 3107/3107 [04:53<00:00, 10.59it/s]

1.0 {'iou': 0.7204, 'recall': 0.7695, 'precision': 0.9186, 'acc': 0.9647, 'F1': 0.8375, 'kappa': 0.8179}





In [8]:
# metrics = segm.segmentation_metrics(tpfptnfn)
# metrics