In [1]:
import os, sys
notebook_dir = os.getcwd()
project_dir = os.path.abspath(os.path.join(notebook_dir, '..'))
if project_dir not in sys.path:
    sys.path.append(project_dir)
utils_dir = os.path.abspath(os.path.join(notebook_dir, '..', 'utils'))
if utils_dir not in sys.path:
    sys.path.append(utils_dir)


In [2]:

import warnings
warnings.simplefilter("ignore", (UserWarning, FutureWarning))
from hparams import HParam
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import dataloader_2
import metrics
from core.res_unet import ResUnet
from core.res_unet_plus import ResUnetPlusPlus
from logger import MyWriter
import torch
from torch import nn
import argparse
import glob
import skimage
from torchvision.utils import save_image

def main(hp, name):
    output_dir = hp.pred
    os.makedirs("{}/{}".format(hp.log, name), exist_ok=True)
    writer = MyWriter("{}/{}".format(hp.log, name))
    
    if hp.RESNET_PLUS_PLUS:
        model = ResUnetPlusPlus(3).cuda()
        model = torch.nn.DataParallel(model)
    else:
        model = ResUnet(3, 64).cuda()
    # Wrap model with DataParallel
    # model = torch.nn.DataParallel(model)
    checkpoint = torch.load(hp.checkpoints)
    model.load_state_dict(checkpoint["state_dict"])
    print("=> loaded checkpoint '{}' (epoch {})".format(hp.checkpoints, checkpoint["epoch"]))
    
    output_dir = hp.pred
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    mass_dataset_val = dataloader_2.ImageDataset(
        hp, False, transform=transforms.Compose([dataloader_2.ToTensorTarget()])
    )

    val_dataloader = DataLoader(
        mass_dataset_val, batch_size=1, num_workers=2, shuffle=False
    )

    model.eval()
    print(type(val_dataloader))

    for idx, data in enumerate(tqdm(val_dataloader, desc="validation")):
        print(data)
        inputs = data["sat_img"].cuda()
        prob_map = model(inputs) # last activation was a sigmoid
        outputs = (prob_map > 0.3).float()
        print(outputs.shape)
        image_name = mass_dataset_val.getPath(idx).split("/")[-1]
        print(image_name)
        img = outputs[0] #torch.Size([3,28,28]
        save_image(img, f'{output_dir}/pred_{image_name}')
        
if __name__ == "__main__":
    args = {
        "name": "ResUnetTrain3",
        "config": "/home/manojkumargalla/PostProcess/config/default3_eval.yaml",
        "resume": ""
    }
    
    class Struct:
        def __init__(self, entries):
            self.__dict__.update(entries)
            
    args = Struct(args)
    
    hp = HParam(args.config)
    with open(args.config, "r") as f:
        hp_str = "".join(f.readlines())

    main(hp, name=args.name)


=> loaded checkpoint '/home/manojkumargalla/ResUNet/newcheckpoints/ResUnetTrain/ResUnetTrain_checkpoint_36000.pt' (epoch 75)
Found 57 images in /home/manojkumargalla/ResUNet/data/test2/test/input_crop
<class 'torch.utils.data.dataloader.DataLoader'>


validation:   0%|          | 0/57 [00:00<?, ?it/s]

{'sat_img': tensor([[[[0.7059, 0.7216, 0.7059,  ..., 0.7961, 0.7569, 0.7020],
          [0.7373, 0.7412, 0.7216,  ..., 0.7804, 0.7412, 0.6941],
          [0.7451, 0.7451, 0.7255,  ..., 0.7451, 0.7255, 0.6980],
          ...,
          [0.9451, 0.9412, 0.9333,  ..., 0.8902, 0.8039, 0.7216],
          [0.9294, 0.9333, 0.9294,  ..., 0.8784, 0.8157, 0.7294],
          [0.9333, 0.9373, 0.9333,  ..., 0.8000, 0.8000, 0.7843]],

         [[0.5569, 0.5725, 0.5451,  ..., 0.6392, 0.6078, 0.5529],
          [0.5882, 0.5882, 0.5608,  ..., 0.6235, 0.5922, 0.5529],
          [0.5961, 0.5922, 0.5647,  ..., 0.5882, 0.5765, 0.5569],
          ...,
          [0.9490, 0.9451, 0.9373,  ..., 0.8078, 0.7294, 0.6510],
          [0.9451, 0.9373, 0.9333,  ..., 0.7961, 0.7333, 0.6588],
          [0.9490, 0.9412, 0.9373,  ..., 0.7176, 0.7176, 0.7020]],

         [[0.7176, 0.7333, 0.7176,  ..., 0.7765, 0.7529, 0.6980],
          [0.7490, 0.7608, 0.7333,  ..., 0.7608, 0.7373, 0.6941],
          [0.7569, 0.7647, 0.7

validation:   5%|▌         | 3/57 [00:01<00:24,  2.17it/s]

torch.Size([1, 1, 512, 512])
S-1908-009781_PAS_2of2 [x=7680,y=15872,w=512,h=512].jpg
{'sat_img': tensor([[[[0.8275, 0.8157, 0.7255,  ..., 0.9451, 0.9451, 0.9451],
          [0.9176, 0.8196, 0.7412,  ..., 0.9451, 0.9451, 0.9451],
          [0.8392, 0.7569, 0.7529,  ..., 0.9451, 0.9451, 0.9451],
          ...,
          [0.8431, 0.9020, 0.9647,  ..., 0.9529, 0.9569, 0.9569],
          [0.8275, 0.8863, 0.9569,  ..., 0.9412, 0.9412, 0.9412],
          [0.8157, 0.8784, 0.9490,  ..., 0.9451, 0.9451, 0.9451]],

         [[0.7373, 0.7255, 0.6353,  ..., 0.9451, 0.9451, 0.9451],
          [0.8118, 0.7137, 0.6392,  ..., 0.9451, 0.9451, 0.9451],
          [0.7020, 0.6196, 0.6235,  ..., 0.9451, 0.9451, 0.9451],
          ...,
          [0.7922, 0.8667, 0.9333,  ..., 0.9529, 0.9569, 0.9569],
          [0.7765, 0.8471, 0.9216,  ..., 0.9412, 0.9412, 0.9412],
          [0.7647, 0.8392, 0.9137,  ..., 0.9451, 0.9451, 0.9451]],

         [[0.8510, 0.8392, 0.7451,  ..., 0.9373, 0.9373, 0.9373],
          [

validation:  12%|█▏        | 7/57 [00:01<00:08,  5.84it/s]

{'sat_img': tensor([[[[0.9373, 0.9451, 0.9569,  ..., 0.5922, 0.6549, 0.7294],
          [0.9412, 0.9412, 0.9451,  ..., 0.6078, 0.6275, 0.6706],
          [0.9490, 0.9451, 0.9412,  ..., 0.6118, 0.5961, 0.6627],
          ...,
          [0.9373, 0.9373, 0.9373,  ..., 0.7922, 0.7647, 0.7020],
          [0.9373, 0.9373, 0.9373,  ..., 0.8157, 0.8000, 0.7412],
          [0.9373, 0.9373, 0.9373,  ..., 0.8471, 0.8471, 0.8000]],

         [[0.9412, 0.9490, 0.9608,  ..., 0.4784, 0.5686, 0.6588],
          [0.9451, 0.9451, 0.9490,  ..., 0.5059, 0.5412, 0.6000],
          [0.9529, 0.9490, 0.9451,  ..., 0.5098, 0.5216, 0.5922],
          ...,
          [0.9373, 0.9373, 0.9373,  ..., 0.6824, 0.6549, 0.5922],
          [0.9373, 0.9373, 0.9373,  ..., 0.7059, 0.6902, 0.6314],
          [0.9373, 0.9373, 0.9373,  ..., 0.7373, 0.7373, 0.6902]],

         [[0.9216, 0.9294, 0.9412,  ..., 0.5882, 0.6588, 0.7451],
          [0.9255, 0.9255, 0.9294,  ..., 0.6118, 0.6314, 0.6863],
          [0.9333, 0.9294, 0.9

validation:  19%|█▉        | 11/57 [00:02<00:04,  9.61it/s]

{'sat_img': tensor([[[[0.9725, 0.9725, 0.9725,  ..., 0.7882, 0.7961, 0.8196],
          [0.9647, 0.9647, 0.9647,  ..., 0.8863, 0.9059, 0.9373],
          [0.9647, 0.9647, 0.9647,  ..., 0.9490, 0.9686, 0.9882],
          ...,
          [0.9490, 0.9490, 0.9490,  ..., 0.7333, 0.7569, 0.7412],
          [0.9490, 0.9490, 0.9490,  ..., 0.7176, 0.7686, 0.7686],
          [0.9490, 0.9490, 0.9490,  ..., 0.8745, 0.8980, 0.8667]],

         [[0.9412, 0.9412, 0.9412,  ..., 0.7882, 0.8078, 0.8353],
          [0.9451, 0.9451, 0.9451,  ..., 0.8784, 0.9020, 0.9412],
          [0.9412, 0.9412, 0.9412,  ..., 0.9255, 0.9451, 0.9765],
          ...,
          [0.9569, 0.9569, 0.9569,  ..., 0.6471, 0.6627, 0.6471],
          [0.9569, 0.9569, 0.9569,  ..., 0.6314, 0.6784, 0.6667],
          [0.9569, 0.9569, 0.9569,  ..., 0.7922, 0.8078, 0.7686]],

         [[0.9333, 0.9333, 0.9333,  ..., 0.8275, 0.8275, 0.8471],
          [0.9333, 0.9333, 0.9333,  ..., 0.9216, 0.9333, 0.9608],
          [0.9412, 0.9412, 0.9

validation:  26%|██▋       | 15/57 [00:02<00:03, 12.65it/s]

{'sat_img': tensor([[[[0.9333, 0.9333, 0.9333,  ..., 0.7725, 0.6510, 0.6000],
          [0.9373, 0.9373, 0.9373,  ..., 0.7647, 0.6824, 0.6745],
          [0.9333, 0.9333, 0.9333,  ..., 0.6902, 0.6431, 0.6941],
          ...,
          [0.9451, 0.9451, 0.9451,  ..., 0.8784, 0.8431, 0.8157],
          [0.9451, 0.9451, 0.9451,  ..., 0.9373, 0.9176, 0.9137],
          [0.9451, 0.9451, 0.9451,  ..., 0.9490, 0.9529, 0.9608]],

         [[0.9569, 0.9569, 0.9569,  ..., 0.6667, 0.5451, 0.4824],
          [0.9608, 0.9608, 0.9608,  ..., 0.6588, 0.5647, 0.5569],
          [0.9569, 0.9569, 0.9569,  ..., 0.5804, 0.5255, 0.5686],
          ...,
          [0.9451, 0.9451, 0.9451,  ..., 0.8549, 0.8118, 0.7843],
          [0.9451, 0.9451, 0.9451,  ..., 0.9020, 0.8863, 0.8745],
          [0.9451, 0.9451, 0.9451,  ..., 0.9137, 0.9216, 0.9216]],

         [[0.9412, 0.9412, 0.9412,  ..., 0.7961, 0.6745, 0.6157],
          [0.9451, 0.9451, 0.9451,  ..., 0.7961, 0.7059, 0.6980],
          [0.9412, 0.9412, 0.9

validation:  33%|███▎      | 19/57 [00:02<00:02, 14.66it/s]

{'sat_img': tensor([[[[0.9529, 0.9490, 0.9451,  ..., 0.5961, 0.6118, 0.6510],
          [0.9529, 0.9490, 0.9412,  ..., 0.6000, 0.5922, 0.6000],
          [0.9451, 0.9451, 0.9412,  ..., 0.6353, 0.6235, 0.6157],
          ...,
          [0.7490, 0.8667, 0.9333,  ..., 0.4980, 0.4667, 0.4549],
          [0.7216, 0.7725, 0.7765,  ..., 0.4824, 0.4588, 0.4549],
          [0.7020, 0.7059, 0.6510,  ..., 0.4745, 0.4549, 0.4549]],

         [[0.9529, 0.9490, 0.9451,  ..., 0.5294, 0.5333, 0.5647],
          [0.9529, 0.9490, 0.9412,  ..., 0.5176, 0.4980, 0.5059],
          [0.9451, 0.9451, 0.9412,  ..., 0.5333, 0.5098, 0.4941],
          ...,
          [0.6667, 0.7843, 0.8510,  ..., 0.2549, 0.2235, 0.2118],
          [0.6392, 0.6980, 0.7020,  ..., 0.2510, 0.2275, 0.2235],
          [0.6275, 0.6314, 0.5765,  ..., 0.2431, 0.2275, 0.2314]],

         [[0.9216, 0.9176, 0.9137,  ..., 0.6627, 0.6706, 0.7059],
          [0.9216, 0.9176, 0.9098,  ..., 0.6392, 0.6235, 0.6314],
          [0.9137, 0.9137, 0.9

validation:  40%|████      | 23/57 [00:02<00:02, 15.83it/s]

{'sat_img': tensor([[[[0.8510, 0.8510, 0.8431,  ..., 0.6549, 0.6471, 0.6706],
          [0.8549, 0.8784, 0.8706,  ..., 0.7647, 0.6863, 0.6275],
          [0.8118, 0.8549, 0.8549,  ..., 0.9098, 0.8039, 0.6902],
          ...,
          [0.7333, 0.7490, 0.7373,  ..., 0.7059, 0.6471, 0.6275],
          [0.7137, 0.7451, 0.7412,  ..., 0.6863, 0.6667, 0.6627],
          [0.6941, 0.7333, 0.7294,  ..., 0.7020, 0.7137, 0.7216]],

         [[0.8392, 0.8392, 0.8235,  ..., 0.5098, 0.4863, 0.5059],
          [0.8353, 0.8588, 0.8549,  ..., 0.6078, 0.5333, 0.4863],
          [0.7765, 0.8196, 0.8196,  ..., 0.7451, 0.6627, 0.5804],
          ...,
          [0.5451, 0.5647, 0.5608,  ..., 0.5765, 0.5098, 0.4824],
          [0.5255, 0.5647, 0.5647,  ..., 0.5569, 0.5216, 0.5098],
          [0.5059, 0.5529, 0.5529,  ..., 0.5647, 0.5686, 0.5686]],

         [[0.9216, 0.9216, 0.9020,  ..., 0.6314, 0.6196, 0.6471],
          [0.9137, 0.9373, 0.9216,  ..., 0.7451, 0.6549, 0.5882],
          [0.8510, 0.8941, 0.8

validation:  47%|████▋     | 27/57 [00:03<00:01, 16.49it/s]

{'sat_img': tensor([[[[0.6078, 0.6039, 0.6196,  ..., 0.8353, 0.8196, 0.8000],
          [0.6588, 0.6549, 0.6196,  ..., 0.7608, 0.8392, 0.8784],
          [0.7176, 0.7059, 0.6431,  ..., 0.8078, 0.8863, 0.9412],
          ...,
          [0.8549, 0.8157, 0.7451,  ..., 0.9451, 0.9373, 0.9333],
          [0.8980, 0.8667, 0.7922,  ..., 0.9373, 0.9412, 0.9412],
          [0.7922, 0.7804, 0.7412,  ..., 0.9373, 0.9490, 0.9569]],

         [[0.5373, 0.5333, 0.5294,  ..., 0.7882, 0.8078, 0.8000],
          [0.6000, 0.5843, 0.5412,  ..., 0.7137, 0.7961, 0.8353],
          [0.6627, 0.6471, 0.5725,  ..., 0.7529, 0.8039, 0.8392],
          ...,
          [0.7529, 0.7137, 0.6353,  ..., 0.9451, 0.9373, 0.9333],
          [0.8157, 0.7804, 0.7020,  ..., 0.9373, 0.9412, 0.9412],
          [0.7176, 0.7020, 0.6588,  ..., 0.9373, 0.9490, 0.9569]],

         [[0.6392, 0.6353, 0.6431,  ..., 0.8745, 0.8353, 0.7922],
          [0.6902, 0.6863, 0.6471,  ..., 0.8078, 0.8627, 0.8902],
          [0.7529, 0.7373, 0.6

validation:  54%|█████▍    | 31/57 [00:03<00:01, 16.79it/s]

{'sat_img': tensor([[[[0.9216, 0.9176, 0.8941,  ..., 0.9451, 0.9490, 0.9490],
          [0.8549, 0.8980, 0.9137,  ..., 0.9412, 0.9490, 0.9569],
          [0.8510, 0.8824, 0.8588,  ..., 0.9176, 0.9294, 0.9412],
          ...,
          [0.2667, 0.2706, 0.2314,  ..., 0.8314, 0.7255, 0.6627],
          [0.3216, 0.2588, 0.1804,  ..., 0.8392, 0.7608, 0.7647],
          [0.2745, 0.2941, 0.3098,  ..., 0.8549, 0.8235, 0.8706]],

         [[0.8431, 0.8392, 0.8275,  ..., 0.9412, 0.9451, 0.9451],
          [0.7765, 0.8196, 0.8353,  ..., 0.9373, 0.9451, 0.9529],
          [0.7686, 0.8000, 0.7725,  ..., 0.9137, 0.9255, 0.9373],
          ...,
          [0.2980, 0.2863, 0.2235,  ..., 0.7059, 0.5961, 0.5373],
          [0.3529, 0.2706, 0.1725,  ..., 0.6902, 0.6118, 0.6157],
          [0.3059, 0.3059, 0.3020,  ..., 0.7020, 0.6627, 0.7059]],

         [[0.8784, 0.8745, 0.8667,  ..., 0.9216, 0.9255, 0.9255],
          [0.8118, 0.8549, 0.8784,  ..., 0.9176, 0.9255, 0.9333],
          [0.8039, 0.8353, 0.8

validation:  61%|██████▏   | 35/57 [00:03<00:01, 16.92it/s]

{'sat_img': tensor([[[[0.7804, 0.7922, 0.7569,  ..., 0.4078, 0.4118, 0.4235],
          [0.9098, 0.9333, 0.9373,  ..., 0.4039, 0.4078, 0.4235],
          [0.8902, 0.9216, 0.9725,  ..., 0.4039, 0.4039, 0.4353],
          ...,
          [0.7098, 0.7647, 0.7451,  ..., 0.5137, 0.6000, 0.6627],
          [0.7686, 0.8078, 0.8588,  ..., 0.5843, 0.6275, 0.6431],
          [0.7843, 0.8118, 0.9098,  ..., 0.7137, 0.7098, 0.6667]],

         [[0.7412, 0.7529, 0.7137,  ..., 0.2627, 0.2941, 0.3137],
          [0.8824, 0.8941, 0.8980,  ..., 0.2588, 0.2824, 0.3137],
          [0.8627, 0.8941, 0.9412,  ..., 0.2588, 0.2784, 0.3255],
          ...,
          [0.6667, 0.7216, 0.7020,  ..., 0.4745, 0.5647, 0.6392],
          [0.7216, 0.7608, 0.8078,  ..., 0.5451, 0.6039, 0.6235],
          [0.7294, 0.7569, 0.8549,  ..., 0.6745, 0.6863, 0.6471]],

         [[0.7765, 0.7882, 0.7686,  ..., 0.4235, 0.4431, 0.4627],
          [0.9098, 0.9294, 0.9412,  ..., 0.4196, 0.4353, 0.4627],
          [0.8902, 0.9255, 0.9

validation:  68%|██████▊   | 39/57 [00:03<00:01, 16.91it/s]

{'sat_img': tensor([[[[0.9255, 0.9255, 0.9294,  ..., 0.7647, 0.8157, 0.8196],
          [0.9529, 0.9412, 0.9373,  ..., 0.7882, 0.8824, 0.9412],
          [0.9373, 0.9255, 0.9137,  ..., 0.7961, 0.9059, 1.0000],
          ...,
          [0.8588, 0.7961, 0.7373,  ..., 0.8275, 0.8784, 0.9255],
          [0.8157, 0.7608, 0.7176,  ..., 0.8627, 0.9137, 0.9255],
          [0.7608, 0.7176, 0.6980,  ..., 0.8745, 0.8941, 0.8549]],

         [[0.9412, 0.9412, 0.9373,  ..., 0.6667, 0.7255, 0.7294],
          [0.9686, 0.9569, 0.9451,  ..., 0.6902, 0.7922, 0.8510],
          [0.9529, 0.9412, 0.9216,  ..., 0.7059, 0.8157, 0.9137],
          ...,
          [0.7059, 0.6431, 0.5804,  ..., 0.7765, 0.8039, 0.8314],
          [0.6627, 0.6078, 0.5647,  ..., 0.8235, 0.8431, 0.8353],
          [0.6078, 0.5647, 0.5451,  ..., 0.8431, 0.8314, 0.7843]],

         [[0.9373, 0.9373, 0.9333,  ..., 0.7843, 0.8392, 0.8431],
          [0.9647, 0.9529, 0.9412,  ..., 0.8078, 0.9059, 0.9647],
          [0.9490, 0.9373, 0.9

validation:  75%|███████▌  | 43/57 [00:04<00:00, 16.58it/s]

{'sat_img': tensor([[[[0.9647, 0.9882, 0.9843,  ..., 0.6941, 0.6667, 0.6314],
          [0.9569, 0.9804, 0.9765,  ..., 0.6745, 0.6549, 0.6314],
          [0.9490, 0.9647, 0.9686,  ..., 0.6627, 0.6745, 0.6667],
          ...,
          [1.0000, 0.8941, 0.8314,  ..., 0.6392, 0.6745, 0.6784],
          [1.0000, 0.9647, 0.8784,  ..., 0.6667, 0.6941, 0.6980],
          [0.9686, 0.9765, 0.9451,  ..., 0.6627, 0.6784, 0.6706]],

         [[0.9333, 0.9569, 0.9529,  ..., 0.5059, 0.4784, 0.4431],
          [0.9255, 0.9490, 0.9451,  ..., 0.4863, 0.4667, 0.4431],
          [0.9176, 0.9333, 0.9373,  ..., 0.4745, 0.4863, 0.4824],
          ...,
          [0.9176, 0.7961, 0.7333,  ..., 0.4588, 0.4941, 0.4980],
          [0.9451, 0.8588, 0.7725,  ..., 0.4902, 0.5176, 0.5216],
          [0.8588, 0.8667, 0.8353,  ..., 0.4902, 0.5020, 0.4941]],

         [[0.9216, 0.9451, 0.9451,  ..., 0.7020, 0.6745, 0.6471],
          [0.9137, 0.9373, 0.9373,  ..., 0.6824, 0.6706, 0.6471],
          [0.9059, 0.9216, 0.9

validation:  79%|███████▉  | 45/57 [00:04<00:00, 13.58it/s]

torch.Size([1, 1, 512, 512])
S-1908-009781_PAS_2of2 [x=9728,y=10752,w=512,h=512].jpg
{'sat_img': tensor([[[[0.8392, 0.7765, 0.7059,  ..., 0.7255, 0.7569, 0.8314],
          [0.7686, 0.7216, 0.6745,  ..., 0.8078, 0.7882, 0.7843],
          [0.6941, 0.6549, 0.6118,  ..., 0.8392, 0.7922, 0.7333],
          ...,
          [0.7686, 0.8118, 0.8510,  ..., 0.6706, 0.5922, 0.5098],
          [0.7804, 0.8275, 0.8627,  ..., 0.6314, 0.5882, 0.5529],
          [0.7843, 0.8275, 0.8627,  ..., 0.6667, 0.6431, 0.6275]],

         [[0.6980, 0.6353, 0.5647,  ..., 0.6549, 0.6863, 0.7608],
          [0.6353, 0.5882, 0.5333,  ..., 0.7373, 0.7176, 0.7137],
          [0.5608, 0.5216, 0.4784,  ..., 0.7686, 0.7216, 0.6588],
          ...,
          [0.6235, 0.6745, 0.7216,  ..., 0.6000, 0.5333, 0.4549],
          [0.6353, 0.6824, 0.7333,  ..., 0.5529, 0.5294, 0.4980],
          [0.6275, 0.6824, 0.7255,  ..., 0.5882, 0.5765, 0.5725]],

         [[0.8392, 0.7765, 0.7059,  ..., 0.7490, 0.7882, 0.8627],
          [

validation:  82%|████████▏ | 47/57 [00:04<00:00, 12.98it/s]

torch.Size([1, 1, 512, 512])
S-1908-009781_PAS_2of2 [x=7168,y=17408,w=512,h=512].jpg
{'sat_img': tensor([[[[0.9882, 0.9020, 0.7765,  ..., 0.7137, 0.6706, 0.6706],
          [1.0000, 0.9451, 0.8588,  ..., 0.6784, 0.6824, 0.7373],
          [0.9804, 0.8784, 0.8157,  ..., 0.6706, 0.6902, 0.7529],
          ...,
          [0.9608, 0.8863, 0.7686,  ..., 0.7373, 0.6549, 0.6588],
          [0.9882, 0.9255, 0.8314,  ..., 0.7725, 0.6863, 0.6824],
          [0.9843, 0.9137, 0.8745,  ..., 0.7961, 0.7020, 0.6980]],

         [[0.8706, 0.7804, 0.6706,  ..., 0.6627, 0.6196, 0.6196],
          [0.9098, 0.8275, 0.7529,  ..., 0.6235, 0.6275, 0.6824],
          [0.8784, 0.7765, 0.7137,  ..., 0.6157, 0.6353, 0.6980],
          ...,
          [0.8078, 0.7608, 0.6706,  ..., 0.6000, 0.5176, 0.5176],
          [0.8392, 0.7922, 0.7373,  ..., 0.6000, 0.5098, 0.4980],
          [0.8314, 0.7843, 0.7804,  ..., 0.6078, 0.5059, 0.4902]],

         [[0.8784, 0.8000, 0.6980,  ..., 0.7020, 0.6588, 0.6588],
          [

validation:  89%|████████▉ | 51/57 [00:04<00:00, 13.37it/s]

torch.Size([1, 1, 512, 512])
S-1908-009781_PAS_2of2 [x=6144,y=14848,w=512,h=512].jpg
{'sat_img': tensor([[[[0.9843, 0.9608, 0.9529,  ..., 0.7059, 0.7176, 0.8471],
          [0.9608, 0.9412, 0.9333,  ..., 0.6824, 0.7451, 0.8627],
          [0.9490, 0.9333, 0.9255,  ..., 0.6627, 0.7529, 0.8118],
          ...,
          [0.7490, 0.7804, 0.7882,  ..., 0.9725, 0.9176, 0.8392],
          [0.6745, 0.7020, 0.7569,  ..., 0.9020, 0.8549, 0.7686],
          [0.6196, 0.6275, 0.6824,  ..., 0.8667, 0.8549, 0.8000]],

         [[0.9451, 0.9216, 0.9137,  ..., 0.5922, 0.6039, 0.7333],
          [0.9294, 0.9059, 0.8980,  ..., 0.5686, 0.6431, 0.7608],
          [0.9137, 0.8980, 0.8902,  ..., 0.5608, 0.6510, 0.7098],
          ...,
          [0.6745, 0.6902, 0.6824,  ..., 0.8902, 0.8431, 0.7647],
          [0.6588, 0.6588, 0.6667,  ..., 0.8196, 0.7725, 0.6941],
          [0.6314, 0.6039, 0.6078,  ..., 0.7843, 0.7725, 0.7176]],

         [[0.9412, 0.9255, 0.9176,  ..., 0.7020, 0.7137, 0.8431],
          [

validation:  93%|█████████▎| 53/57 [00:04<00:00, 13.32it/s]

torch.Size([1, 1, 512, 512])
S-1908-009781_PAS_2of2 [x=5632,y=16896,w=512,h=512].jpg
{'sat_img': tensor([[[[0.8000, 0.9922, 1.0000,  ..., 0.2627, 0.2510, 0.3059],
          [0.9020, 1.0000, 0.9804,  ..., 0.2863, 0.3294, 0.3765],
          [0.8980, 1.0000, 1.0000,  ..., 0.3020, 0.3804, 0.4196],
          ...,
          [0.4824, 0.6078, 0.8000,  ..., 0.6392, 0.6980, 0.7098],
          [0.4902, 0.6980, 0.8980,  ..., 0.5647, 0.6157, 0.7137],
          [0.5216, 0.7647, 0.9608,  ..., 0.5333, 0.5882, 0.7490]],

         [[0.7529, 0.9333, 0.9412,  ..., 0.2196, 0.2078, 0.2627],
          [0.8588, 0.9686, 0.9137,  ..., 0.2510, 0.2863, 0.3333],
          [0.8706, 0.9765, 0.9765,  ..., 0.2706, 0.3490, 0.3882],
          ...,
          [0.4431, 0.5804, 0.7765,  ..., 0.4980, 0.5529, 0.5765],
          [0.4627, 0.6745, 0.8824,  ..., 0.4314, 0.4824, 0.5882],
          [0.4980, 0.7490, 0.9529,  ..., 0.4000, 0.4627, 0.6314]],

         [[0.7686, 0.9529, 0.9647,  ..., 0.3882, 0.3765, 0.4314],
          [

validation: 100%|██████████| 57/57 [00:05<00:00, 14.43it/s]

torch.Size([1, 1, 512, 512])
S-1908-009781_PAS_2of2 [x=7680,y=11264,w=512,h=512].jpg
{'sat_img': tensor([[[[0.7961, 0.8157, 0.8353,  ..., 0.6078, 0.6157, 0.6510],
          [0.8275, 0.8314, 0.8275,  ..., 0.5804, 0.5882, 0.6157],
          [0.8471, 0.8471, 0.8431,  ..., 0.4706, 0.4627, 0.4667],
          ...,
          [0.7922, 0.7922, 0.7529,  ..., 0.6863, 0.6549, 0.6431],
          [0.8000, 0.8078, 0.7882,  ..., 0.6824, 0.6667, 0.6667],
          [0.7412, 0.7529, 0.7647,  ..., 0.7098, 0.7333, 0.7569]],

         [[0.7333, 0.7529, 0.7765,  ..., 0.5373, 0.5451, 0.5765],
          [0.7686, 0.7725, 0.7686,  ..., 0.5059, 0.5137, 0.5412],
          [0.7882, 0.7882, 0.7843,  ..., 0.4039, 0.3961, 0.4000],
          ...,
          [0.6667, 0.6824, 0.6667,  ..., 0.5451, 0.5137, 0.5020],
          [0.6745, 0.6941, 0.6941,  ..., 0.5529, 0.5373, 0.5373],
          [0.6078, 0.6353, 0.6706,  ..., 0.5804, 0.6039, 0.6275]],

         [[0.8353, 0.8549, 0.8667,  ..., 0.7020, 0.7098, 0.7529],
          [

validation: 100%|██████████| 57/57 [00:05<00:00, 10.98it/s]


In [4]:
import os
import numpy as np
from skimage.io import imread

In [5]:
def calculate_scores(ground_truth, predicted):
    # print(np.shape(ground_truth))
    # print(np.shape(predicted))
    ground_truth = np.asarray(ground_truth)
    predicted = np.asarray(predicted)
    
    # Flatten the arrays to simplify the calculations
    ground_truth_flat = ground_truth.flatten()
    predicted_flat = predicted.flatten()
    
    # Calculate the intersection and union
    intersection = np.sum(ground_truth_flat * predicted_flat)
    union = np.sum(ground_truth_flat) + np.sum(predicted_flat)
    
    # Dice coefficient
    dice_score = (2 * intersection) / union if union != 0 else 1.0
    
    # Intersection over Union
    iou_score = intersection / (union - intersection) if (union - intersection) != 0 else 1.0
    return dice_score, iou_score

def read_images_and_calculate_scores(gt_folder, pred_folder):
    dice_scores = []
    iou_scores = []
    files = os.listdir(gt_folder)
    for gt_filename in files:
        if gt_filename.endswith('.png'):
            # Construct the predicted image filename
            pred_filename = 'pred_' + gt_filename
            
            # Construct the full paths to the images
            gt_path = os.path.join(gt_folder, gt_filename)
            pred_path = os.path.join(pred_folder, pred_filename)
            
            # Read the images
            gt_image = imread(gt_path)
            pred_image = imread(pred_path, as_gray=True)
            
            # Ensure the images are binary
            gt_image = gt_image > 0
            pred_image = pred_image > 0
            
            # Calculate the scores
            dice_score, iou_score = calculate_scores(gt_image, pred_image)
            # Append the scores to the list
            dice_scores.append(dice_score)
            iou_scores.append(iou_score)
    return dice_scores, iou_scores
    
gt_folder = '/home/manojkumargalla/ResUNet/data/test/mask_crop/'
pred_folder = '/home/manojkumargalla/ResUNet/data/test/predicted_mask_crop_newmodel/'
dice_scores, iou_scores = read_images_and_calculate_scores(gt_folder, pred_folder)
avg_dice_score = sum(dice_scores)/len(dice_scores)
avg_iou_score = sum(iou_scores)/len(iou_scores)
print(round(avg_dice_score, 2))
print(round(avg_iou_score, 2))

0.63
0.48


In [None]:
0.471480057879705
0.3260011468578666

In [None]:
0.5365197075794806
0.387148930603562