Skip to content

Commit

Permalink
dev(narugo): add prediction && use float as tagger output
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed May 14, 2024
1 parent 1b9e2c0 commit 5fb64ae
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
22 changes: 15 additions & 7 deletions imgutils/tagging/wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,16 @@ def get_wd14_tags(
:return: A tuple containing dictionaries for rating, general, and character tags with their probabilities.
:rtype: Tuple[Dict[str, float], Dict[str, float], Dict[str, float]]
.. note::
About ``fmt`` argument, these are the available names:
* ``rating``, a dict containing ratings and their confidences
* ``general``, a dict containing general tags and their confidences
* ``character``, a dict containing character tags and their confidences
* ``tag``, a dict containing all tags (including general and character, not including rating) and their confidences
* ``embedding``, a 1-dim embedding of image, recommended for index building after L2 normalization
* ``prediction``, a 1-dim prediction result of image
Example:
Here are some images for example
Expand Down Expand Up @@ -202,16 +212,14 @@ def get_wd14_tags(
preds, embeddings = model.run([label_name, emb_name], {input_name: image})
labels = list(zip(tag_names, preds[0].astype(float)))

ratings_names = [labels[i] for i in rating_indexes]
rating = dict(ratings_names)
rating = {labels[i][0]: labels[i][1].item() for i in rating_indexes}

general_names = [labels[i] for i in general_indexes]
if general_mcut_enabled:
general_probs = np.array([x[1] for x in general_names])
general_threshold = _mcut_threshold(general_probs)

general_res = [x for x in general_names if x[1] > general_threshold]
general_res = dict(general_res)
general_res = {x: v.item() for x, v in general_names if v > general_threshold}
if drop_overlap:
general_res = drop_overlap_tags(general_res)

Expand All @@ -221,8 +229,7 @@ def get_wd14_tags(
character_threshold = _mcut_threshold(character_probs)
character_threshold = max(0.15, character_threshold)

character_res = [x for x in character_names if x[1] > character_threshold]
character_res = dict(character_res)
character_res = {x: v.item() for x, v in character_names if v > character_threshold}

return vreplace(
fmt,
Expand All @@ -231,6 +238,7 @@ def get_wd14_tags(
'general': general_res,
'character': character_res,
'tag': {**general_res, **character_res},
'embedding': embeddings[0],
'embedding': embeddings[0].astype(np.float32),
'prediction': preds[0].astype(np.float32),
}
)
5 changes: 5 additions & 0 deletions test/tagging/test_wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ def test_get_wd14_tags(self):
assert rating['general'] > 0.9
assert tags['cat_girl'] >= 0.8
assert not chars
assert isinstance(rating['general'], float)
assert isinstance(tags['cat_girl'], float)

rating, tags, chars = get_wd14_tags(get_testfile('6125785.jpg'))
assert 0.6 <= rating['general'] <= 0.8
assert tags['1girl'] >= 0.95
assert chars['hu_tao_(genshin_impact)'] >= 0.95
assert isinstance(rating['general'], float)
assert isinstance(tags['1girl'], float)
assert isinstance(chars['hu_tao_(genshin_impact)'], float)

def test_wd14_tags_sample(self):
rating, tags, chars = get_wd14_tags(get_testfile('nude_girl.png'))
Expand Down

0 comments on commit 5fb64ae

Please sign in to comment.