In [None]:
# ! pip install scikit-image
# ! pip install tensorboardX

import warnings

warnings.simplefilter("ignore", (UserWarning, FutureWarning))
from utils.hparams import HParam
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from dataset import dataloader
from utils import metrics
from core.res_unet import ResUnet
from core.res_unet_plus import ResUnetPlusPlus
from utils.logger import MyWriter
import torch
import argparse
import os
import skimage

from torchvision.utils import save_image

def main(hp, name):
    
    output_dir = hp.pred

    os.makedirs("{}/{}".format(hp.log, name), exist_ok=True)
    writer = MyWriter("{}/{}".format(hp.log, name))
    # get model

    if hp.RESNET_PLUS_PLUS:
        model = ResUnetPlusPlus(3).cuda()
    else:
        model = ResUnet(3, 64).cuda()

    checkpoint = torch.load(hp.checkpoints)

    model.load_state_dict(checkpoint["state_dict"])
    print(
        "=> loaded checkpoint '{}' (epoch {})".format(
            hp.checkpoints, checkpoint["epoch"]
        )
    )

    mass_dataset_val = dataloader.ImageDataset(
        hp, False, transform=transforms.Compose([dataloader.ToTensorTarget()])
    )
    
    # creating loaders

    val_dataloader = DataLoader(
        mass_dataset_val, batch_size=1, num_workers=2, shuffle=False
    )

    model.eval()
    print(type(val_dataloader))
    
    for idx, data in enumerate(tqdm(val_dataloader, desc="validation")):
        print(data)

        inputs = data["sat_img"].cuda()

        prob_map = model(inputs) # last activation was a sigmoid
        outputs = (prob_map > 0.3).float()
        
        print(outputs.shape)
        image_name = mass_dataset_val.getPath(idx).split("/")[-1]
        print(image_name)
        
        img = outputs[0] #torch.Size([3,28,28]
        
        save_image(img, f'{output_dir}/pred_img_{idx}_{image_name}')
        
if __name__ == "__main__":
    
    args = {
        "name": "ResUnetTrain3",
        "config": "configs/default3.yaml",
        "resume": ""
    }
    
    class Struct:
        def __init__(self, entries):
            self.__dict__.update(entries)
            
    args = Struct(args)
    
    hp = HParam(args.config)
    with open(args.config, "r") as f:
        hp_str = "".join(f.readlines())

    main(hp, name=args.name)


In [1]:
import os, sys
notebook_dir = os.getcwd()
project_dir = os.path.abspath(os.path.join(notebook_dir, '..'))
if project_dir not in sys.path:
    sys.path.append(project_dir)
utils_dir = os.path.abspath(os.path.join(notebook_dir, '..', 'utils'))
if utils_dir not in sys.path:
    sys.path.append(utils_dir)

import warnings
warnings.simplefilter("ignore", (UserWarning, FutureWarning))
from hparams import HParam
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import dataloader
import metrics
from core.res_unet import ResUnet
from core.res_unet_plus import ResUnetPlusPlus
from logger import MyWriter
import torch
import argparse
import glob
import skimage
from torchvision.utils import save_image

def main(hp, name):
    output_dir = hp.pred
    os.makedirs("{}/{}".format(hp.log, name), exist_ok=True)
    writer = MyWriter("{}/{}".format(hp.log, name))
    
    if hp.RESNET_PLUS_PLUS:
        model = ResUnetPlusPlus(3).cuda()
    else:
        model = ResUnet(3, 64).cuda()

    checkpoint = torch.load(hp.checkpoints)
    model.load_state_dict(checkpoint["state_dict"])
    print("=> loaded checkpoint '{}' (epoch {})".format(hp.checkpoints, checkpoint["epoch"]))
    
    folder_names = [name for name in os.listdir(hp.valid) if os.path.isdir(os.path.join(hp.valid, name))]
    print(folder_names)
    print (f'There are {len(folder_names)} image folders')
    for foldername in folder_names:
        output_dir = hp.pred
        output_dir = os.path.join(output_dir, foldername)
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        mass_dataset_val = dataloader.ImageDataset(
            foldername, hp, False, transform=transforms.Compose([dataloader.ToTensorTarget()])
        )
        print(f'working on {foldername}')
        # print(os.path.join(mass_dataset_val.path, "*.jpg"))
        # image_list = glob.glob(
        #         os.path.join(mass_dataset_val.path, "*.jpg"), recursive=True
        #     )
        # print(len(image_list))
        # # Debugging print statements
        # print(f"Validation path: {hp.valid}")
        # print(f"Number of validation images: {len(mass_dataset_val.image_list)}")


        val_dataloader = DataLoader(
            mass_dataset_val, batch_size=1, num_workers=2, shuffle=False
        )

        model.eval()
        print(type(val_dataloader))

        for idx, data in enumerate(tqdm(val_dataloader, desc="validation")):
            print(data)
            inputs = data["sat_img"].cuda()
            prob_map = model(inputs) # last activation was a sigmoid
            outputs = (prob_map > 0.3).float()
            print(outputs.shape)
            image_name = mass_dataset_val.getPath(idx).split("/")[-1]
            print(image_name)
            img = outputs[0] #torch.Size([3,28,28]
            save_image(img, f'{output_dir}/pred_{foldername}_{idx}_{image_name}')
        
if __name__ == "__main__":
    args = {
        "name": "ResUnetTrain3",
        "config": "/home/manojkumargalla/PostProcess/config/default3.yaml",
        "resume": ""
    }
    
    class Struct:
        def __init__(self, entries):
            self.__dict__.update(entries)
            
    args = Struct(args)
    
    hp = HParam(args.config)
    with open(args.config, "r") as f:
        hp_str = "".join(f.readlines())

    main(hp, name=args.name)


=> loaded checkpoint '/home/manojkumargalla/PostProcess/models/ResUnetTrain5_checkpoint_1000.pt' (epoch 45)
['S-1905-018731_PAS_2of2', 'S-1904-007293_PAS_1of2', 'S-1905-017738_PAS_1of2', 'S-2106-003588_PAS_1of2', 'S-1908-010066_PAS_1of2', 'S-2001-005357_PAS_1of2', 'S-2103-004857_PAS_2of2', 'S-1909-007149_PAS_1of2', 'S-1910-000089_PAS_2of2', '18-162_PAS_4of6']
There are 10 image folders
Found 548 images
working on S-1905-018731_PAS_2of2
<class 'torch.utils.data.dataloader.DataLoader'>


validation:   0%|          | 0/548 [00:00<?, ?it/s]

{'sat_img': tensor([[[[0.7922, 0.7294, 0.6941,  ..., 0.6627, 0.6980, 0.7255],
          [0.6824, 0.6353, 0.6471,  ..., 0.5922, 0.6549, 0.6667],
          [0.7373, 0.7216, 0.7373,  ..., 0.5647, 0.5843, 0.5804],
          ...,
          [0.4706, 0.4549, 0.3373,  ..., 0.7333, 0.7020, 0.6863],
          [0.4314, 0.4510, 0.4196,  ..., 0.7333, 0.7098, 0.6667],
          [0.4431, 0.4510, 0.4157,  ..., 0.6980, 0.6627, 0.6588]],

         [[0.6784, 0.6157, 0.5373,  ..., 0.3490, 0.4275, 0.4549],
          [0.6039, 0.5569, 0.5216,  ..., 0.3216, 0.4000, 0.4118],
          [0.6941, 0.6784, 0.6627,  ..., 0.3529, 0.3569, 0.3529],
          ...,
          [0.1647, 0.1490, 0.1765,  ..., 0.6078, 0.5725, 0.5569],
          [0.1412, 0.1608, 0.2078,  ..., 0.6000, 0.5804, 0.5373],
          [0.1490, 0.1569, 0.1843,  ..., 0.5647, 0.5255, 0.5216]],

         [[0.7686, 0.7059, 0.7843,  ..., 0.6706, 0.7137, 0.7412],
          [0.6667, 0.6196, 0.7333,  ..., 0.6471, 0.6941, 0.7059],
          [0.7373, 0.7216, 0.8

validation:   0%|          | 0/548 [00:41<?, ?it/s]

torch.Size([1, 1, 512, 512])
13312x_09728y.png





NameError: name 'folder_name' is not defined

In [None]:
from utils.hparams import HParam
from dataset.dataloader import ImageDataset

hp = HParam('configs/default3.yaml')
mass_dataset_val = ImageDataset(hp, train=False, transform=None)

print(f"Validation path: {hp.valid}")
print(f"Number of validation images in dataset: {len(mass_dataset_val.image_list)}")
print(f"Number of validation images using __len__: {len(mass_dataset_val)}")


In [None]:
print(mass_dataset_val.image_list[0])