In [1]:
# !pip install pydash

In [2]:
from pymongo import MongoClient
from functools import cmp_to_key
from pathlib import Path
import imagehash
from tqdm.notebook import tqdm
import pydash as _
import numpy as np
import json
import shelve

from IPython.display import display, Image
from ipywidgets import widgets, HBox, VBox, Box, Layout

from lib.parallel import parallel
from lib.sort_things import sort_images, simple_sort_images
from lib.image_dedup import make_hashes, hashes_diff, is_duplicated
from lib.PersistentSet import PersistentSet

In [3]:
# Open the 'bad-vis' database and bind the four collections this notebook works with.
mongo = MongoClient(host='172.17.0.1', port=27017)
db = mongo['bad-vis']
posts, imagefiles, imagevalidfiles, imagemeta = (
    db[name]
    for name in ('posts', 'imagefiles', 'imagevalidfiles', 'imagemeta')
)

In [4]:
# Image directory, relative to the notebook's working directory.
# NOTE(review): presumably where the downloaded image files live — confirm against the downloader.
images_dir = Path('./images')

In [5]:
# Persistent image_id -> phash mapping of hand-picked main files. ImageMeta reads it to
# choose the main image and the review widget's "Select" button writes to it.
# The shelf is kept open for the whole session; no visible cell closes it.
main_image_phashes = shelve.open('handmade/main_image_phashes')

In [6]:
# Hand-curated exclusion lists kept under handmade/.
# Path.read_text() opens and closes the file itself, so no file handle is left
# dangling the way an open(...) passed straight to json.load() would leave one.
invalid_post_ids = set(json.loads(Path('handmade/invalid_post_ids.json').read_text()))
invalid_image_phashes = set(json.loads(Path('handmade/invalid_image_phashes.json').read_text()))
duplicated_image_phash_pairs = PersistentSet.load_set('handmade/duplicated_image_phash_pairs.json')

In [7]:
# Start from empty collections: the cells below insert into imagevalidfiles and
# upsert into imagemeta, so dropping first keeps a re-run from carrying over old documents.
imagevalidfiles.drop()
imagemeta.drop()

In [8]:
# discard reddit preview if (Imgur) albums or manual downloaded images exist
# A post qualifies as soon as one of its files has an index_in_album other than 0.
# Only post_id is projected and the cursor is consumed directly, so neither the
# full documents nor an intermediate list are held in memory.
discard_preview_post_ids = {
    f['post_id']
    for f in imagefiles.find({'index_in_album': {'$ne': 0}}, {'post_id': 1})
}
len(discard_preview_post_ids)

105

In [9]:
# Copy every usable image file into imagevalidfiles. A file is skipped when, in this order:
#   1. its post is listed in invalid_post_ids,
#   2. its phash is listed in invalid_image_phashes,
#   3. it is the index-0 file of a post whose preview is overridden
#      (post_id in discard_preview_post_ids).
for f in tqdm(imagefiles.find()):
    if (f['post_id'] in invalid_post_ids
            or f['phash'] in invalid_image_phashes
            or (f['post_id'] in discard_preview_post_ids and f['index_in_album'] == 0)):
        continue
    imagevalidfiles.insert_one(f)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




# Group image files into images

In [10]:
class ImageMeta ():
    """All stored files of one image, plus the file chosen as its "main" one.

    An instance wraps the image file documents passed to the constructor. The
    fields of the main file are copied onto the instance as plain attributes,
    and ``digest()`` flattens the instance into a dict keyed by ``_attrs``.
    """

    # Attribute names exported by digest().
    _attrs = [
        'id',
        'post_id',
        'image_id',
        'short_image_id',
        'album',
        'index_in_album',
        'image_type',
        'file_path',
        'source_platform',
        'source',
        'ext',
        'animated',
        'size',
        'width',
        'height',
        'pixels',
        'thumbnail',
        'preview',
        'external_link',
        'archive',
        'manual',
        'ahash',
        'phash',
        'pshash',
        'dhash',
        'whash',
        'related_images',
        'image_order'
    ]

    # Image types probed by available_image_types, in this order.
    _image_types = [
        'manual', 'archive', 'external_link', 'external_link_alt',
        'preview', 'preview_alt', 'thumbnail', 'thumbnail_alt'
    ]

    # Image types left out of the hash comparison.
    _thumbnail_types = ('thumbnail', 'thumbnail_alt')

    # Hash types that must agree for is_hash_consistent.
    # 'ahash', 'dhash' and 'whash' can be appended to make the check stricter.
    _consistency_hash_types = ['phash']

    # Distances strictly below this value count as consistent.
    # magic number threshold, by experiment when trying out the imagehash library
    _hash_distance_threshold = 5

    def __init__ (self, imageFiles=None):
        """Wrap the image file dicts of one image.

        Args:
            imageFiles: non-empty list of image file dicts; the first one
                supplies ``image_id``.

        Raises:
            Exception: if ``imageFiles`` is None or empty.
        """
        # None is used as the default instead of a shared mutable [] literal.
        if imageFiles is None or len(imageFiles) == 0:
            raise Exception('Empty imageFiles array.')
        self._imageFiles = imageFiles
        self.image_id = imageFiles[0]['image_id']
        # Image types present for this image, in the order chosen by sort_images.
        candidates = [getattr(self, t) for t in self.available_image_types]
        self.image_order = [f['image_type'] for f in sort_images(candidates)]
        self._adopt_main_image()

    def _adopt_main_image (self):
        """Copy every field of the current main image onto the instance."""
        for k, v in self.main_image.items():
            setattr(self, k, v)

    def digest (self):
        """Return a dict of all ``_attrs`` values.

        The main image fields are re-copied first, so a main-image override
        recorded after construction is reflected in the result.
        """
        self._adopt_main_image()
        return {a: getattr(self, a) for a in ImageMeta._attrs}

    @property
    def is_hash_consistent (self):
        """True when every hash type in ``_consistency_hash_types`` is consistent."""
        return all(self.hash_consistent(h) for h in self._consistency_hash_types)

    def hash_distance (self, hash_type):
        """Distances between the main image hash and each non-thumbnail file's hash.

        Distances follow ``image_order`` and are computed by subtracting the
        parsed imagehash values; the main image itself is included.
        """
        main_hash = imagehash.hex_to_hash(self.main_image[hash_type])
        image_types = [t for t in self.image_order if t not in self._thumbnail_types]
        return [imagehash.hex_to_hash(self.find_image_type(t)[hash_type]) - main_hash for t in image_types]

    def hash_consistent (self, hash_type):
        """True when every distance from hash_distance() is below the threshold."""
        threshold = self._hash_distance_threshold
        return all(abs(d) < threshold for d in self.hash_distance(hash_type))

    @property
    def main_image (self):
        """The file dict treated as the main image.

        A file whose phash equals the manual override (``main_image_phash``)
        wins; otherwise, or when no file matches the override, the first type
        in ``image_order`` is used.
        """
        override = self.main_image_phash
        if override:
            for image_type in self.image_order:
                image = getattr(self, image_type)
                if image['phash'] == override:
                    return image
        return getattr(self, self.image_order[0])

    @property
    def main_image_phash (self):
        """Manually selected phash stored for this image_id, or None."""
        return main_image_phashes.get(self.image_id, None)

    @property
    def related_images (self):
        """Distinct image_ids of all files in imagefiles that share this post_id."""
        return list({f['image_id'] for f in imagefiles.find({'post_id': self.post_id}, {'image_id': 1})})

    @property
    def thumbnail (self):
        return self.find_image_type('thumbnail')

    @property
    def thumbnail_alt (self):
        return self.find_image_type('thumbnail_alt')

    @property
    def preview (self):
        return self.find_image_type('preview')

    @property
    def preview_alt (self):
        return self.find_image_type('preview_alt')

    @property
    def external_link (self):
        return self.find_image_type('external_link')

    @property
    def external_link_alt (self):
        return self.find_image_type('external_link_alt')

    @property
    def archive (self):
        return self.find_image_type('archive')

    @property
    def manual (self):
        return self.find_image_type('manual')

    def find_image_type (self, image_type):
        """Return the first wrapped file dict with the given image_type, or None."""
        return next((f for f in self._imageFiles if f['image_type'] == image_type), None)

    @property
    def available_image_types (self):
        """Image types (in ``_image_types`` order) for which a file exists."""
        return [image_type
                for image_type in self._image_types
                if getattr(self, image_type)]

In [11]:
def make_imageMeta (image_id):
    """Build the ImageMeta for one image and upsert its digest into imagemeta.

    Args:
        image_id: image_id used to look up files in imagevalidfiles.

    Returns:
        The ImageMeta instance, or None when the image is skipped: no files
        were found, its post is listed in invalid_post_ids, its post is missing
        from posts, or every file has a phash listed in invalid_image_phashes.
    """
    imageFiles = list(imagevalidfiles.find({'image_id': image_id}))
    if not imageFiles:
        # Without this guard imageFiles[0] below would raise IndexError.
        print(f"No image files found: {image_id}")
        return None

    post_id = imageFiles[0]['post_id']
    if post_id in invalid_post_ids:
        return None

    if not posts.find_one({'post_id': post_id}, {'_id': 1}):
        print(f"Cannot find corresponding post: {post_id}")
        return None

    valid_imageFiles = [i for i in imageFiles if i['phash'] not in invalid_image_phashes]
    if not valid_imageFiles:
        print(f"All image files are invalid: {image_id}")
        return None

    # The file dicts end up nested inside the imagemeta document; strip the
    # _id each one carries from its source collection.
    for i in valid_imageFiles:
        i.pop('_id', None)
    imageMeta = ImageMeta(valid_imageFiles)
    imagemeta.replace_one({'image_id': imageMeta.image_id}, imageMeta.digest(), upsert=True)
    return imageMeta

In [12]:
# Hand every distinct image_id found in imagevalidfiles to make_imageMeta through the
# project's parallel helper.
# To limit the run to reddit files, collect the ids from
# imagefiles.find({'source_platform': 'reddit'}, {'image_id': 1}) instead.
imageMetas = parallel(
    make_imageMeta,
    set(f['image_id'] for f in imagevalidfiles.find({}, {'image_id': 1}))
)

HBox(children=(FloatProgress(value=0.0, max=6738.0), HTML(value='')))




# Find images with inconsistent hash

In [13]:
def make_link (url, text):
    """Return an HTML widget containing an anchor to `url`, labelled `text`, that opens in a new tab."""
    anchor = f"<a href='{url}' target='_blank'>{text}</a>"
    return widgets.HTML(value=anchor)

def make_main_image_box (imageMeta):
    """Build a widget row for choosing the main file of one image by hand.

    The row starts with an info box (link to the reddit post plus a log area),
    followed by one column per file in ``imageMeta.image_order`` showing its
    type, extension, dimensions, size, phash, the picture itself and a
    "Select" button.

    Clicking "Select" stores that file's phash in ``main_image_phashes`` under
    the image_id and re-upserts the image's digest into ``imagemeta``.

    Args:
        imageMeta: the ImageMeta to review.

    Returns:
        An HBox containing the info box and the per-file columns.
    """

    layoutArgs = {
        'padding': '10px',
        'margin': '5px',
        'border': '3px solid lightblue'
    }

    output = widgets.Output()

    def select (phash):
        # Everything printed here lands in the info box's log area.
        with output:
            print(f"-{main_image_phashes.get(imageMeta.image_id, '')}")
            main_image_phashes[imageMeta.image_id] = phash
            # The shelf stays open for the whole session; flush now so the
            # manual choice is on disk even if the kernel dies later.
            main_image_phashes.sync()
            imagemeta.replace_one({'image_id': imageMeta.image_id}, imageMeta.digest(), upsert=True)
            print(f"+{phash}")
            print(f"+: {imageMeta.main_image['image_type']}")

    def makeSelectBtn (imageFile):
        btn = widgets.Button(description="Select", button_style='')
        btn.on_click(lambda b: select(imageFile['phash']))
        return btn

    def makeImageBox (imageFile):
        # read_bytes() closes the file itself; a bare open(...).read() would
        # leave one handle open per displayed image.
        image_bytes = Path(imageFile['file_path']).read_bytes()
        return VBox([
            widgets.Label(value=f"{imageFile['image_type']} {imageFile['ext']}"),
            widgets.Label(value=f"{imageFile['width']} {imageFile['height']} {imageFile['size']}"),
            widgets.Label(value=f"{imageFile['phash']}"),
            widgets.Image(value=image_bytes, width=200),
            makeSelectBtn(imageFile)
        ])

    # Initial state: stored manual override (if any) and the current main image.
    with output:
        print(f"manual: {main_image_phashes.get(imageMeta.image_id, '')}")
        print(f"phash: {imageMeta.phash}")
        print(f"image_type: {imageMeta.image_type}")

    info_box = VBox([
        make_link(f"https://www.reddit.com/r/{imageMeta.source}/comments/{imageMeta.id}", imageMeta.image_id),
        output
    ], layout=Layout(**layoutArgs))
    image_boxes = [makeImageBox(imageMeta.find_image_type(i)) for i in imageMeta.image_order]

    return HBox([info_box] + image_boxes)

In [14]:
# Show a review box for at most 10 images whose files disagree on their hash.
cnt = 0
for imageMeta in tqdm(imageMetas):
    if not imageMeta:
        continue
    if imageMeta.is_hash_consistent:
        continue
    if imageMeta.image_order[0] == 'manual':
        continue
    # Remove this guard to also revisit images whose main file was already picked by hand.
    if imageMeta.main_image_phash:
        continue
    display(make_main_image_box(imageMeta))
    cnt += 1
    if cnt >= 10:
        break

HBox(children=(FloatProgress(value=0.0, max=6738.0), HTML(value='')))




In [15]:
# m = ImageMeta([f for f in imagevalidfiles.find({'id': '32rnec'})])
# display(make_main_image_box(m))

# Visually check invalid images

In [16]:
# invalids = []
# for h in invalid_image_phashes:
#     invalid_images = [f for f in imagefiles.find({'phash': h})]
#     if not invalid_images:
#         continue
# #     print(f"{h} {[i['image_id'] for i in invalid_images]}")
# #     display(HBox([
# #             widgets.Image(value=open(i['file_path'], 'rb').read(), width=100, height=100)
# #             for i in invalid_images]))
#     invalids.append(invalid_images[0])

In [17]:
# display(Box([widgets.Image(value=open(i['file_path'], 'rb').read(), width=100, height=100) for i in invalids],
#                 layout=Layout(display='flex', flex_flow='row wrap')))

In [18]:
# duplicated_image_ids = [c
#                      for c in nx.components.connected_components(nx.Graph(distance <= 1))
#                      if len(c) > 1]
# len(duplicated_image_ids)
# for idxs in duplicated_image_ids:
#     print(f"{[imageDedup[i]['image_id'] for i in idxs]}")
#     if len(idxs) >= 4:
#         display(HBox([
#             widgets.Image(value=open(imageDedup[i]['file_path'], 'rb').read(), width=100, height=100)
#             for i in idxs]))