forked from weecology/DeepForest
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_dataset.py
148 lines (121 loc) · 5.03 KB
/
test_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Tests for the deepforest dataset module (TreeDataset) and utilities.collate_fn
from deepforest import get_data
from deepforest import dataset
from deepforest import utilities
import os
import pytest
import torch
import pandas as pd
import numpy as np
import tempfile
def single_class():
    """Return the ``get_data`` lookup for ``example.csv``.

    Used as the single-class annotation CSV when parametrizing
    ``test_TreeDataset``.
    """
    return get_data("example.csv")
def multi_class():
    """Return the ``get_data`` lookup for ``testfile_multi.csv``.

    Used as the multi-class annotation CSV when parametrizing
    ``test_TreeDataset``.
    """
    return get_data("testfile_multi.csv")
@pytest.mark.parametrize(
    "csv_file,label_dict",
    [
        (single_class(), {"Tree": 0}),
        (multi_class(), {"Alive": 0, "Dead": 1}),
    ],
)
def test_TreeDataset(csv_file, label_dict):
    """Compare TreeDataset items against the raw annotation CSV.

    The dataset length must equal the number of unique ``image_path`` values.
    For every item, pixel values must lie in [0, 1], the boxes/labels shapes
    must match the CSV row count, and the number of distinct labels must match
    the CSV's ``label`` column.

    Args:
        csv_file: annotation CSV handed to both TreeDataset and pandas.
        label_dict: label-name to integer mapping passed to TreeDataset.
    """
    image_dir = os.path.dirname(get_data("OSBS_029.png"))
    tree_ds = dataset.TreeDataset(
        csv_file=csv_file,
        root_dir=image_dir,
        label_dict=label_dict,
    )
    annotations = pd.read_csv(csv_file)

    n_images = len(annotations.image_path.unique())
    assert len(tree_ds) == n_images

    n_rows = annotations.shape[0]
    for idx in range(len(tree_ds)):
        _, image, targets = tree_ds[idx]

        # Pixel values are expected to be scaled between 0 and 1
        assert image.max() <= 1
        assert image.min() >= 0

        # Shapes are compared against the full CSV row count
        assert targets["boxes"].shape == (n_rows, 4)
        assert targets["labels"].shape == (n_rows,)

        n_classes_seen = len(np.unique(targets["labels"]))
        assert n_classes_seen == len(annotations.label.unique())
def test_single_class_with_empty(tmpdir):
    """Check that all-zero box coordinates parse as an image without annotations.

    The annotations from ``example.csv`` and ``OSBS_029.csv`` are concatenated,
    and every row whose ``image_path`` is ``"OSBS_029.tif"`` gets its
    xmin/ymin/xmax/ymax overwritten with 0 to fake an empty image. The
    resulting dataset must hold two items: the first with non-zero boxes, the
    second with boxes summing to zero.

    Args:
        tmpdir: pytest temporary-directory fixture; the combined CSV is
            written inside it.
    """
    df = pd.concat([
        pd.read_csv(get_data("example.csv")),
        pd.read_csv(get_data("OSBS_029.csv")),
    ])

    # Zero out all four coordinates in one assignment for the "empty" image
    is_empty_image = df.image_path == "OSBS_029.tif"
    df.loc[is_empty_image, ["xmin", "ymin", "xmax", "ymax"]] = 0

    # Keep the CSV inside the fixture directory so it stays scoped to this test
    csv_file = os.path.join(str(tmpdir), "test_empty.csv")
    df.to_csv(csv_file)

    root_dir = os.path.dirname(get_data("OSBS_029.png"))
    ds = dataset.TreeDataset(csv_file=csv_file,
                             root_dir=root_dir,
                             label_dict={"Tree": 0})

    assert len(ds) == 2

    # The third element of each item is the targets dict.
    # First image has annotations
    assert torch.sum(ds[0][2]["boxes"]) != 0

    # Second image has no annotations
    assert torch.sum(ds[1][2]["boxes"]) == 0
@pytest.mark.parametrize("augment", [True, False])
def test_TreeDataset_transform(augment):
    """Check dataset items when a transform from ``get_transform`` is applied.

    For every item the image must stay within [0, 1], boxes and labels must
    have 79 entries, and the image, boxes and labels must all be torch tensors.

    Args:
        augment: forwarded to ``dataset.get_transform(augment=...)``.
    """
    annotations_csv = get_data("example.csv")
    tree_ds = dataset.TreeDataset(
        csv_file=annotations_csv,
        root_dir=os.path.dirname(annotations_csv),
        transforms=dataset.get_transform(augment=augment),
    )

    expected_boxes = 79  # number of boxes asserted for example.csv
    for idx in range(len(tree_ds)):
        _, image, targets = tree_ds[idx]

        # Pixel values are expected to be scaled between 0 and 1
        assert image.max() <= 1
        assert image.min() >= 0

        assert targets["boxes"].shape == (expected_boxes, 4)
        assert targets["labels"].shape == (expected_boxes,)

        assert torch.is_tensor(targets["boxes"])
        assert torch.is_tensor(targets["labels"])
        assert torch.is_tensor(image)
def test_collate():
    """Run ``utilities.collate_fn`` on each dataset item and expect a length-2 result."""
    annotations_csv = get_data("example.csv")
    tree_ds = dataset.TreeDataset(
        csv_file=annotations_csv,
        root_dir=os.path.dirname(annotations_csv),
        transforms=dataset.get_transform(augment=False),
    )

    for idx in range(len(tree_ds)):
        # The item itself (not a list of items) is handed to collate_fn
        collated = utilities.collate_fn(tree_ds[idx])
        assert len(collated) == 2
def test_empty_collate():
    """Check that ``utilities.collate_fn`` copes with a ``None`` entry in a batch.

    Due to data augmentations the dataset class may yield empty bounding box
    annotations. Each dataset item is collated as ``[None, item, item]`` and
    the first element of the result must have length 2 (presumably because
    collate_fn drops the ``None`` entry -- confirm against
    ``utilities.collate_fn``).
    """
    csv_file = get_data("example.csv")
    root_dir = os.path.dirname(csv_file)
    ds = dataset.TreeDataset(csv_file=csv_file,
                             root_dir=root_dir,
                             transforms=dataset.get_transform(augment=False))

    for i in range(len(ds)):
        batch = ds[i]
        collated_batch = utilities.collate_fn([None, batch, batch])
        # Without `assert` this comparison is evaluated and thrown away
        assert len(collated_batch[0]) == 2
def test_predict_dataloader():
    """Check that a ``train=False`` dataset yields a channels-first image.

    The first item produced by iterating the dataset is treated as the image
    and its leading dimension must be 3.
    """
    annotations_csv = get_data("example.csv")
    predict_ds = dataset.TreeDataset(
        csv_file=annotations_csv,
        root_dir=os.path.dirname(annotations_csv),
        train=False,
    )

    first_image = next(iter(predict_ds))

    # Channels-first format: the channel axis comes first
    n_channels = first_image.shape[0]
    assert n_channels == 3
def test_multi_image_warning():
    """Collate items from a dataset built on two concatenated annotation CSVs.

    ``example.csv`` and ``OSBS_029.csv`` are concatenated and written to a
    temporary CSV. Each dataset item is then collated as ``[None, item, item]``
    and the first element of the result must have length 2.

    NOTE(review): despite the name, no warning is captured or asserted here --
    confirm whether a ``pytest.warns`` check was intended.
    """
    csv_file1 = get_data("example.csv")
    csv_file2 = get_data("OSBS_029.csv")
    df = pd.concat([pd.read_csv(csv_file1), pd.read_csv(csv_file2)])
    root_dir = os.path.dirname(csv_file1)

    # A private temporary directory avoids clashing on a fixed file name in the
    # shared system temp dir and removes the CSV once the test finishes.
    with tempfile.TemporaryDirectory() as tmpdir:
        csv_file = os.path.join(tmpdir, "multiple.csv")
        df.to_csv(csv_file)

        ds = dataset.TreeDataset(csv_file=csv_file,
                                 root_dir=root_dir,
                                 transforms=dataset.get_transform(augment=False))

        for i in range(len(ds)):
            batch = ds[i]
            collated_batch = utilities.collate_fn([None, batch, batch])
            # Without `assert` this comparison is evaluated and thrown away
            assert len(collated_batch[0]) == 2