
Commit

Add test_mode in dataloader to enable testing w/o compiling training sets
meetps committed Jan 8, 2019
1 parent 4654882 commit bb561cf
Showing 11 changed files with 51 additions and 55 deletions.
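The gist of the change: each loader now accepts a test_mode flag, and the loaders that build per-split file lists in __init__ skip that work when the flag is set, so a loader can be instantiated purely for its metadata (n_classes, img_size, decode_segmap) on a machine that has no training data on disk. A minimal sketch of the intended usage, mirroring the updated test.py further down; the "camvid" choice and the explicit keyword values are only illustrative:

# Build a loader without scanning any dataset directories;
# only its metadata is needed for single-image inference.
from ptsemseg.loader import get_loader

data_loader = get_loader("camvid")
loader = data_loader(root=None, is_transform=True, img_norm=True, test_mode=True)
n_classes = loader.n_classes   # 12 for CamVid
img_size = loader.img_size     # [360, 480] for CamVid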
10 changes: 0 additions & 10 deletions ptsemseg/loader/__init__.py
@@ -25,13 +25,3 @@ def get_loader(name):
"sunrgbd": SUNRGBDLoader,
"vistas": mapillaryVistasLoader,
}[name]


def get_data_path(name, config_file="config.json"):
"""get_data_path
:param name:
:param config_file:
"""
data = json.load(open(config_file))
return data[name]["data_path"]
13 changes: 8 additions & 5 deletions ptsemseg/loader/ade20k_loader.py
@@ -19,22 +19,25 @@ def __init__(
img_size=512,
augmentations=None,
img_norm=True,
test_mode=False,
):
self.root = root
self.split = split
self.is_transform = is_transform
self.augmentations = augmentations
self.img_norm = img_norm
self.test_mode = test_mode
self.n_classes = 150
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.mean = np.array([104.00699, 116.66877, 122.67892])
self.files = collections.defaultdict(list)

for split in ["training", "validation"]:
file_list = recursive_glob(
rootdir=self.root + "images/" + self.split + "/", suffix=".jpg"
)
self.files[split] = file_list
if not self.test_mode:
for split in ["training", "validation"]:
file_list = recursive_glob(
rootdir=self.root + "images/" + self.split + "/", suffix=".jpg"
)
self.files[split] = file_list

def __len__(self):
return len(self.files[self.split])
13 changes: 8 additions & 5 deletions ptsemseg/loader/camvid_loader.py
@@ -6,7 +6,7 @@
import matplotlib.pyplot as plt

from torch.utils import data
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, raw_input
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate


class camvidLoader(data.Dataset):
@@ -18,20 +18,23 @@ def __init__(
img_size=None,
augmentations=None,
img_norm=True,
test_mode=False
):
self.root = root
self.split = split
self.img_size = [360, 480]
self.is_transform = is_transform
self.augmentations = augmentations
self.img_norm = img_norm
self.test_mode = test_mode
self.mean = np.array([104.00699, 116.66877, 122.67892])
self.n_classes = 12
self.files = collections.defaultdict(list)

for split in ["train", "test", "val"]:
file_list = os.listdir(root + "/" + split)
self.files[split] = file_list
if not self.test_mode:
for split in ["train", "test", "val"]:
file_list = os.listdir(root + "/" + split)
self.files[split] = file_list

def __len__(self):
return len(self.files[self.split])
@@ -132,7 +135,7 @@ def decode_segmap(self, temp, plot=False):
axarr[j][0].imshow(imgs[j])
axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
plt.show()
a = raw_input()
a = input()
if a == "ex":
break
else:
5 changes: 3 additions & 2 deletions ptsemseg/loader/cityscapes_loader.py
@@ -6,7 +6,7 @@
from torch.utils import data

from ptsemseg.utils import recursive_glob
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale, raw_input
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale


class cityscapesLoader(data.Dataset):
@@ -59,6 +59,7 @@ def __init__(
augmentations=None,
img_norm=True,
version="cityscapes",
test_mode=False
):
"""__init__
@@ -245,7 +246,7 @@ def encode_segmap(self, mask):
axarr[j][0].imshow(imgs[j])
axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
plt.show()
a = raw_input()
a = input()
if a == "ex":
break
else:
5 changes: 3 additions & 2 deletions ptsemseg/loader/mapillary_vistas_loader.py
@@ -4,14 +4,15 @@
import numpy as np

from torch.utils import data
from PIL import Image

from ptsemseg.utils import recursive_glob
from ptsemseg.augmentations import Compose, Image, RandomHorizontallyFlip, RandomRotate
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate


class mapillaryVistasLoader(data.Dataset):
def __init__(
self, root, split="training", img_size=(640, 1280), is_transform=True, augmentations=None
self, root, split="training", img_size=(640, 1280), is_transform=True, augmentations=None, test_mode=False
):
self.root = root
self.split = split
1 change: 1 addition & 0 deletions ptsemseg/loader/mit_sceneparsing_benchmark_loader.py
@@ -32,6 +32,7 @@ def __init__(
img_size=512,
augmentations=None,
img_norm=True,
test_mode=False
):
"""__init__
6 changes: 4 additions & 2 deletions ptsemseg/loader/nyuv2_loader.py
@@ -7,7 +7,7 @@
from torch.utils import data

from ptsemseg.utils import recursive_glob
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale, raw_input
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale


class NYUv2Loader(data.Dataset):
@@ -31,12 +31,14 @@ def __init__(
img_size=(480, 640),
augmentations=None,
img_norm=True,
test_mode=False
):
self.root = root
self.is_transform = is_transform
self.n_classes = 14
self.augmentations = augmentations
self.img_norm = img_norm
self.test_mode = test_mode
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.mean = np.array([104.00699, 116.66877, 122.67892])
self.files = collections.defaultdict(list)
@@ -156,7 +158,7 @@ def decode_segmap(self, temp):
axarr[j][0].imshow(imgs[j])
axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
plt.show()
a = raw_input()
a = input()
if a == "ex":
break
else:
37 changes: 15 additions & 22 deletions ptsemseg/loader/pascal_voc_loader.py
@@ -15,20 +15,6 @@
from torchvision import transforms


def get_data_path(name):
"""Extract path to data from config file.
Args:
name (str): The name of the dataset.
Returns:
(str): The path to the root directory containing the dataset.
"""
js = open("config.json").read()
data = json.loads(js)
return os.path.expanduser(data[name]["data_path"])


class pascalVOCLoader(data.Dataset):
"""Data loader for the Pascal VOC semantic segmentation dataset.
@@ -58,27 +44,34 @@ class pascalVOCLoader(data.Dataset):
def __init__(
self,
root,
sbd_path=None,
split="train_aug",
is_transform=False,
img_size=512,
augmentations=None,
img_norm=True,
test_mode=False,
):
self.root = os.path.expanduser(root)
self.root = root
self.sbd_path = sbd_path
self.split = split
self.is_transform = is_transform
self.augmentations = augmentations
self.img_norm = img_norm
self.test_mode = test_mode
self.n_classes = 21
self.mean = np.array([104.00699, 116.66877, 122.67892])
self.files = collections.defaultdict(list)
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
for split in ["train", "val", "trainval"]:
path = pjoin(self.root, "ImageSets/Segmentation", split + ".txt")
file_list = tuple(open(path, "r"))
file_list = [id_.rstrip() for id_ in file_list]
self.files[split] = file_list
self.setup_annotations()

if not self.test_mode:
for split in ["train", "val", "trainval"]:
path = pjoin(self.root, "ImageSets/Segmentation", split + ".txt")
file_list = tuple(open(path, "r"))
file_list = [id_.rstrip() for id_ in file_list]
self.files[split] = file_list
self.setup_annotations()

self.tf = transforms.Compose(
[
transforms.ToTensor(),
@@ -199,7 +192,7 @@ def setup_annotations(self):
function also defines the `train_aug` and `train_aug_val` data splits
according to the description in the class docstring
"""
sbd_path = get_data_path("sbd")
sbd_path = self.sbd_path
target_path = pjoin(self.root, "SegmentationClass/pre_encoded")
if not os.path.exists(target_path):
os.makedirs(target_path)
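With get_data_path and config.json removed, the Pascal VOC loader now receives the SBD root directly through the new sbd_path argument (setup_annotations reads self.sbd_path instead of calling get_data_path("sbd")), and root is no longer passed through os.path.expanduser, so it should already be a resolved path. A sketch of a call under those assumptions, with hypothetical local paths:

# Hypothetical paths — point them at your VOC2012 root and the SBD release.
from ptsemseg.loader.pascal_voc_loader import pascalVOCLoader

loader = pascalVOCLoader(
    root="/datasets/VOCdevkit/VOC2012",
    sbd_path="/datasets/benchmark_RELEASE",
    split="train_aug",
    is_transform=True,
)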
6 changes: 4 additions & 2 deletions ptsemseg/loader/sunrgbd_loader.py
@@ -6,7 +6,7 @@
from torch.utils import data

from ptsemseg.utils import recursive_glob
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale, raw_input
from ptsemseg.augmentations import Compose, RandomHorizontallyFlip, RandomRotate, Scale


class SUNRGBDLoader(data.Dataset):
@@ -30,12 +30,14 @@ def __init__(
img_size=(480, 640),
augmentations=None,
img_norm=True,
test_mode=False
):
self.root = root
self.is_transform = is_transform
self.n_classes = 38
self.augmentations = augmentations
self.img_norm = img_norm
self.test_mode = test_mode
self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
self.mean = np.array([104.00699, 116.66877, 122.67892])
self.files = collections.defaultdict(list)
@@ -161,7 +163,7 @@ def decode_segmap(self, temp):
axarr[j][0].imshow(imgs[j])
axarr[j][1].imshow(dst.decode_segmap(labels.numpy()[j]))
plt.show()
a = raw_input()
a = input()
if a == "ex":
break
else:
2 changes: 1 addition & 1 deletion ptsemseg/models/frrn.py
@@ -26,7 +26,7 @@ class frrn(nn.Module):
2) TF implementation by @kiwonjoon: https://github.com/hiwonjoon/tf-frrn
"""

def __init__(self, n_classes=21, model_type=None, group_norm=False, n_groups=16):
def __init__(self, n_classes=21, model_type='B', group_norm=False, n_groups=16):
super(frrn, self).__init__()
self.n_classes = n_classes
self.model_type = model_type
8 changes: 4 additions & 4 deletions test.py
@@ -6,7 +6,7 @@


from ptsemseg.models import get_model
from ptsemseg.loader import get_loader, get_data_path
from ptsemseg.loader import get_loader
from ptsemseg.utils import convert_state_dict

try:
@@ -30,8 +30,7 @@ def test(args):
img = misc.imread(args.img_path)

data_loader = get_loader(args.dataset)
data_path = get_data_path(args.dataset)
loader = data_loader(data_path, is_transform=True, img_norm=args.img_norm)
loader = data_loader(root=None, is_transform=True, img_norm=args.img_norm, test_mode=True)
n_classes = loader.n_classes

resized_img = misc.imresize(img, (loader.img_size[0], loader.img_size[1]), interp="bicubic")
@@ -55,7 +54,8 @@
img = torch.from_numpy(img).float()

# Setup Model
model = get_model(model_name, n_classes, version=args.dataset)
model_dict = {"arch": model_name}
model = get_model(model_dict, n_classes, version=args.dataset)
state = convert_state_dict(torch.load(args.model_path)["model_state"])
model.load_state_dict(state)
model.eval()
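As the test.py hunk above shows, the model is now described by a dict keyed on "arch" rather than a bare name string when calling get_model. A minimal sketch under that assumption; the architecture name and class count below are hypothetical placeholders:

from ptsemseg.models import get_model

model_dict = {"arch": "fcn8s"}            # hypothetical architecture choice
n_classes = 21                            # e.g. Pascal VOC
model = get_model(model_dict, n_classes)  # version=... can be passed as in test.py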
