Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 83 additions & 16 deletions src/python/turicreate/test/test_image_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,21 +251,25 @@ def get_psnr(x, y):
)

# Get model distances for comparison
img = data[0:1][self.feature][0]
img_fixed = tc.image_analysis.resize(img, *reversed(self.input_image_shape))
tc_ret = self.model.query(img_fixed, k=data.num_rows())

if _mac_ver() >= (10, 13):
from PIL import Image as _PIL_Image

pil_img = _PIL_Image.fromarray(img_fixed.pixel_data)
coreml_ret = coreml_model.predict({"awesome_image": pil_img})

# Compare distances
coreml_distances = np.array(coreml_ret["distance"])
tc_distances = tc_ret.sort("reference_label")["distance"].to_numpy()
psnr_value = get_psnr(coreml_distances, tc_distances)
self.assertTrue(psnr_value > 50)
if self.feature == "awesome_image":
img = data[0:1][self.feature][0]
img_fixed = tc.image_analysis.resize(img, *reversed(self.input_image_shape))
tc_ret = self.model.query(img_fixed, k=data.num_rows())

if _mac_ver() >= (10, 13):
from PIL import Image as _PIL_Image

pil_img = _PIL_Image.fromarray(img_fixed.pixel_data)
coreml_ret = coreml_model.predict({"awesome_image": pil_img})

# Compare distances
coreml_distances = np.array(coreml_ret["distance"])
tc_distances = tc_ret.sort("reference_label")["distance"].to_numpy()
psnr_value = get_psnr(coreml_distances, tc_distances)
self.assertTrue(psnr_value > 50)
else:
# Broad else clause to ignore features not supported in coreml
pass

def test_save_and_load(self):
with test_util.TempDirectory() as filename:
Expand All @@ -287,6 +291,60 @@ def test_save_and_load(self):
print("Export coreml passed")


class ImageSimilarityTestWithKwargs(unittest.TestCase):
@classmethod
def setUpClass(self, input_image_shape=(3, 224, 224), model="resnet-50"):
"""
The setup class method for the basic test case with all default values.
"""
self.feature = "awesome_image"
self.label = None
self.input_image_shape = input_image_shape
self.pre_trained_model = model

# Create the model
self.def_opts = {
"model": "resnet-50",
"verbose": True,
}

# Model
self.model = tc.image_similarity.create(
data, feature=self.feature, label=None, model=self.pre_trained_model,
method='lsh', distance='squared_euclidean'
)
self.nn_model = self.model.feature_extractor
self.lm_model = self.model.similarity_model
self.opts = self.def_opts.copy()

# Answers
self.get_ans = {
"similarity_model": lambda x: type(x)
== tc.nearest_neighbors.NearestNeighborsModel,
"feature": lambda x: x == self.feature,
"training_time": lambda x: x > 0,
"input_image_shape": lambda x: x == self.input_image_shape,
"label": lambda x: x == self.label,
"feature_extractor": lambda x: callable(x.extract_features),
"num_features": lambda x: x == self.lm_model.num_features,
"num_examples": lambda x: x == self.lm_model.num_examples,
"model": lambda x: (
x == self.pre_trained_model
or (
self.pre_trained_model == "VisionFeaturePrint_Screen"
and x == "VisionFeaturePrint_Scene"
)
),
}
self.fields_ans = self.get_ans.keys()

def assertModelWorks(self):
self.assertEqual(self.model.similarity_model.distance[0][1],
'squared_euclidean'
)



class ImageSimilaritySqueezeNetTest(ImageSimilarityTest):
@classmethod
def setUpClass(self):
Expand All @@ -306,7 +364,7 @@ def setUpClass(self):
)


# A test to gaurantee that old code using the incorrect name still works.
# A test to guarantee that old code using the incorrect name still works.
@unittest.skipIf(
_mac_ver() < (10, 14), "VisionFeaturePrint_Scene only supported on macOS 10.14+"
)
Expand All @@ -316,3 +374,12 @@ def setUpClass(self):
super(ImageSimilarityVisionFeaturePrintSceneTest_bad_name, self).setUpClass(
model="VisionFeaturePrint_Screen", input_image_shape=(3, 299, 299)
)


# A test to ensure kwargs are still accepted in create()
class ImageSimilarityCreateKwargsTest(ImageSimilarityTest):
@classmethod
def setUpClass(self):
super(ImageSimilarityCreateKwargsTest, self).setUpClass(
model="resnet-50", input_image_shape=(3, 300, 300)
)