In [None]:
# !pip install ImageHash

In [None]:
from pymongo import MongoClient
import os
from pathlib import Path
from PIL import Image
import imagehash
from tqdm.notebook import tqdm
import numpy as np
import json

from lib.parallel import parallel

In [None]:
from lib.fix_nested_tqdm import fix_nested_tqdm

# Project helper imported from lib.fix_nested_tqdm; its implementation is not visible here.
# NOTE(review): presumably adjusts how nested tqdm progress bars render in the notebook — confirm.
fix_nested_tqdm()

In [None]:
# Connect to MongoDB at 172.17.0.1:27017 and select the `imagefiles` collection of
# the `bad-vis` database; image digests are inserted into and queried from it below.
# NOTE(review): 172.17.0.1 looks like a Docker bridge gateway address — confirm the
# host is right for the environment this notebook runs in.
mongo = MongoClient('172.17.0.1', 27017)
db = mongo['bad-vis']
imagefiles = db['imagefiles']

In [None]:
# Root directory that is walked (with os.walk) for image files to digest.
images_dir = Path('./images')

In [None]:
class ImageFile ():
    """Metadata and perceptual hashes for a single image file on disk.

    Several attributes are derived from the position of the path components,
    obtained by splitting the path on ``/``::

        <root>/<image_type>/<source_platform>/<source>/[<album>/]<filename>

    Component 0 is ignored, so the path must have exactly one leading component
    before ``image_type``. The album component is only recognised when the path
    has more than five components.

    The instance keeps the underlying PIL image open until ``close()`` is called
    (or the instance is used as a context manager), so build the digest first and
    close afterwards.
    """

    # Attribute names exported by digest(), in the order they appear in the record.
    _attrs = [
        'id',
        'post_id',
        'image_id',
        'short_image_id',
        'album',
        'index_in_album',
        'image_type',
        'file_path',
        'filename',
        'basename',
        'ext',
        'animated',
        'source_platform',
        'source',
        'size',
        'width',
        'height',
        'ahash',
        'phash',
        'pshash',
        'dhash',
        'whash'
    ]

    def __init__ (self, file_path):
        """Open the image at ``file_path`` and read its name and pixel dimensions.

        Args:
            file_path: ``str`` or ``os.PathLike`` path of the image file.

        Raises:
            Whatever ``PIL.Image.open`` raises for a missing or unreadable image.
        """
        # Imported locally on purpose: a later cell of this notebook rebinds the
        # global name ``Image`` to ``IPython.display.Image``, which has no ``open``.
        # Resolving PIL here keeps the class usable whatever order cells are run in.
        from PIL import Image as PILImage

        # Accept os.PathLike as well as str; a str is returned unchanged.
        self.file_path = os.fspath(file_path)
        self.filename = os.path.basename(self.file_path)
        self.basename, self.ext = os.path.splitext(self.filename)
        self._im = PILImage.open(self.file_path)
        self.width, self.height = self._im.size

        # Normalise the separator before splitting (a no-op where os.sep is '/').
        self._file_path_tokens = self.file_path.replace(os.sep, '/').split('/')

    def close (self):
        """Release the file handle held by the underlying PIL image."""
        self._im.close()

    def __enter__ (self):
        return self

    def __exit__ (self, exc_type, exc_value, traceback):
        # Returns None, so any exception raised inside the ``with`` body propagates.
        self.close()

    def digest (self):
        """Return a dict mapping every name in ``_attrs`` to its current value."""
        return {a:getattr(self, a) for a in ImageFile._attrs}

    def _hash_hex (self, cache_attr, hash_func_name):
        """Return ``str()`` of the named imagehash result, computing it at most once.

        The hash object is cached on the instance under ``cache_attr``;
        ``hash_func_name`` is the name of the function to call on ``imagehash``.
        """
        if not hasattr(self, cache_attr):
            setattr(self, cache_attr, getattr(imagehash, hash_func_name)(self._im))
        return str(getattr(self, cache_attr))

    @property
    def size (self):
        """File size in bytes, as reported by ``os.path.getsize``."""
        return os.path.getsize(self.file_path)

    @property
    def animated (self):
        """PIL's ``is_animated`` flag for GIF and WEBP images; ``False`` for other formats."""
        if self._im.format in ('GIF', 'WEBP'):
            return self._im.is_animated
        return False

    @property
    def image_type (self):
        """Path component 1."""
        return self._file_path_tokens[1]

    @property
    def source_platform (self):
        """Path component 2."""
        return self._file_path_tokens[2]

    @property
    def source (self):
        """Path component 3."""
        return self._file_path_tokens[3]

    @property
    def album (self):
        """Path component 4 when the path has more than five components, else ``''``."""
        return self._file_path_tokens[4] if len(self._file_path_tokens) > 5 else ''

    @property
    def index_in_album (self):
        """Integer before the first ``-`` in the basename, or ``0`` if there is no ``-``.

        Raises ``ValueError`` when the basename contains ``-`` but the prefix is
        not an integer.
        """
        return int(self.basename.split('-')[0]) if '-' in self.basename else 0

    @property
    def id (self):
        """The album name when there is one, otherwise the file's basename."""
        return self.album if self.album else self.basename

    @property
    def post_id (self):
        """``<source_platform>/<source>/<id>``."""
        return f"{self.source_platform}/{self.source}/{self.id}"

    @property
    def image_id (self):
        """``<post_id>:<index_in_album>``."""
        return f"{self.post_id}:{self.index_in_album}"

    @property
    def short_image_id (self):
        """``<id>:<index_in_album>``."""
        return f"{self.id}:{self.index_in_album}"

    @property
    def ahash (self):
        """String form of ``imagehash.average_hash`` of the image (cached)."""
        return self._hash_hex('_ahash', 'average_hash')

    @property
    def phash (self):
        """String form of ``imagehash.phash`` of the image (cached)."""
        return self._hash_hex('_phash', 'phash')

    @property
    def pshash (self):
        """String form of ``imagehash.phash_simple`` of the image (cached)."""
        return self._hash_hex('_pshash', 'phash_simple')

    @property
    def dhash (self):
        """String form of ``imagehash.dhash`` of the image (cached)."""
        return self._hash_hex('_dhash', 'dhash')

    @property
    def whash (self):
        """String form of ``imagehash.whash`` of the image (cached)."""
        return self._hash_hex('_whash', 'whash')

In [None]:
def digest_file (name, root=''):
    """Insert the digest of one image file into ``imagefiles`` unless already recorded.

    Args:
        name: File name within ``root``.
        root: Directory containing the file; joined with ``name`` to form the
            path that is stored as ``file_path``.

    A file whose path is already present in the collection is skipped without
    being opened or hashed.
    """
    file_path = os.path.join(root, name)

    # ImageFile stores the path it is given as ``file_path``, so the lookup can
    # be done before paying for opening the image.
    if imagefiles.find_one({'file_path': file_path}, {'_id': 1}):
        return

    imageFile = ImageFile(file_path)
    try:
        imagefiles.insert_one(imageFile.digest())
    finally:
        # The PIL image opened by ImageFile keeps its file handle open; release
        # it as soon as the digest has been built (or has failed).
        imageFile._im.close()

# Walk the image tree one directory at a time. The outer tqdm bar advances per
# directory and shows the directory currently being processed.
with tqdm(os.walk(images_dir)) as t:
    for root, dirs, files in t:
        t.set_postfix(root=root)
        # `parallel` is a project helper (lib.parallel) whose implementation is not visible here.
        # NOTE(review): presumably calls digest_file(name, root=root) for each name in `files` — confirm.
        parallel(digest_file, files, params_dict={'root': root}, tqdm_postfix=root, leave=None)

In [None]:
# import warnings
# warnings.filterwarnings("error")
# %%capture cap_out --no-stderr
# for root, dirs, files in os.walk(images_dir):
#     for name in files:
#         try:
#             imageFile = ImageFile(os.path.join(root, name))
#             imageFile.digest()
#         except Exception as inst:
#             print(f"Error digesting image: {os.path.join(root, name)} {inst}")

In [None]:
# with open('error.txt', 'w') as f:
#     f.write(cap_out.stdout)

# Match up Reddit previews with Imgur albums

In [None]:
from IPython.display import Image, display

In [None]:
# Distinct post ids of all records whose index_in_album is not 0.
# Only `post_id` is projected, so full documents are not transferred just to be
# discarded, and the cursor is consumed directly instead of through a throwaway list.
post_ids = {
    f['post_id']
    for f in imagefiles.find({'index_in_album': {'$ne': 0}}, {'post_id': 1, '_id': 0})
}
len(post_ids)

In [None]:
# Names of the stored hash fields that are compared, in comparison order.
hashes = ['phash', 'whash', 'dhash', 'ahash']

def make_hashes (image):
    """Decode the hex hash strings stored on ``image`` into hash objects.

    Returns one object per entry of ``hashes``, in the same order.
    """
    decoded = []
    for field in hashes:
        decoded.append(imagehash.hex_to_hash(image[field]))
    return decoded

def hashes_diff (hashes_x, hashes_y):
    """Combine the per-hash distances of two hash lists into one score.

    The absolute difference is taken position by position (one per entry of
    ``hashes``), the differences are sorted, and the mean of the entries at
    sorted positions 1 and 2 is returned — with four hashes, that drops the
    smallest and the largest distance.
    """
    ordered = sorted(
        abs(hashes_x[k] - hashes_y[k]) for k, _ in enumerate(hashes)
    )
    middle = ordered[1:3]
    return sum(middle) / 2

In [None]:
# For every post that has album images, find the album image that the post's
# preview (index_in_album == 0) most closely resembles, then re-label all of the
# post's index-0 records with that image's album and index.
for i in post_ids:
    related_image_files = list(imagefiles.find({'post_id': i}))
    previews = [r for r in related_image_files if r['index_in_album'] == 0]

    if len(previews) < 1:
        print('No preview image', i)
        continue

    # Prefer the 'preview' rendition and fall back to the 'thumbnail'.
    preview = [p for p in previews if p['image_type'] == 'preview']

    if len(preview) == 0:
        preview = [p for p in previews if p['image_type'] == 'thumbnail']

    if len(preview) == 0:
        # Index-0 records exist but none is a preview or a thumbnail; without
        # this guard `preview[0]` below raises IndexError.
        print('No preview or thumbnail image', i)
        continue

    if len(preview) > 1:
        print('More than 1 preview image', i)
        continue

    preview = preview[0]

    album = [r for r in related_image_files if r['index_in_album'] != 0]

    if len(album) == 0:
        # Can happen when `post_ids` is stale; np.argmin raises on an empty sequence.
        print('No album image', i)
        continue

    preview_hashes = make_hashes(preview)
    album_hashes = [make_hashes(a) for a in album]

    distance = [hashes_diff(preview_hashes, ah) for ah in album_hashes]
    min_index = np.argmin(np.asarray(distance))
    if distance[min_index] >= 4:
        # Best candidate is still too far away: show both images and stop.
        # NOTE(review): `break` ends the whole loop, so the remaining posts are
        # not processed — presumably intentional, to inspect the mismatch by hand.
        print(min_index, distance[min_index], preview['image_id'], preview['image_type'], distance)
        display(Image(filename=preview['file_path']), Image(filename=album[min_index]['file_path']))
        break

    # The matched album image is the same for every preview record of this post.
    new_album = album[min_index]['album']
    new_index_in_album = album[min_index]['index_in_album']
    for p in previews:
        imagefiles.update_one({'file_path': p['file_path']},
                              {'$set': {
                                  'album': new_album,
                                  'index_in_album': new_index_in_album,
                                  # Rebuild the ids from their parts (same layout as
                                  # ImageFile.image_id / short_image_id) instead of
                                  # replacing every ':0' substring, which would also
                                  # rewrite a ':0' occurring inside the id itself.
                                  'image_id': f"{p['post_id']}:{new_index_in_album}",
                                  'short_image_id': f"{p['id']}:{new_index_in_album}"
                              }})