In [1]:
%matplotlib inline

In [2]:
from ankisync2 import Apkg
from tqdm import tqdm
import re
from bs4 import BeautifulSoup
import scipy as sp
import numpy as np
import os
import itertools
import math
import sqlite3
from contextlib import closing
import pandas as pd
from time import sleep

In [8]:
class Anki(object):
    THRESHOLD = 0.15  # Only record similarity values above threshold
    PARTIAL = None  # Only look at first n entries
    DELAY = 0.5  # Just for display purposes
    
    def __init__(self, file, overwrite=True, read=True):
        self.file = os.path.abspath(file)
        self.sql = f'{self.file}.sql'
        self.db = []
        self.note_sim = {'images': [], 'tags': [], 'text': [], 'overall': []}  # Empty until we need it
        self.image_sim = {}
        self.tag_sim = {}
        
        if overwrite:
            self.init_sql()
            
        if read:
            self.read_file()
        
    def read_file(self):
        """Parse the package at ``self.file`` into ``self.db``.

        Every item yielded by ``Apkg`` is turned into a dict with the keys
        ``id`` (the note id), ``data`` (cleaned main text and extra text),
        ``images`` (de-duplicated ``src`` values found in the known fields)
        and ``tags`` (expanded through ``telescope_tags``). Notes that have
        neither a ``Text`` nor a ``Header`` field are skipped. Records are
        appended to ``self.db`` and ``self.n_cards`` is set to its length.
        """
        # Fields that are scanned for <img> tags.
        image_fields = ('Text', 'Extra', 'Image', 'Lecture Notes',
                        'Missed Questions', 'Pathoma', 'Boards and Beyond',
                        'First Aid', 'Sketchy', 'Pixorize', 'Physeo',
                        'Additional Resources')

        with Apkg(self.file) as apkg:
            for card in tqdm(apkg, f'Reading `{self.file}`: ',
                             total=sum(1 for _ in apkg),
                             position=0, leave=True):

                note = card['note']
                # Pair the model's field names with this note's field values.
                content = dict(zip(note['model']['flds'], note['flds']))

                # Extracting images (get_images returns [] for missing/empty fields)
                images = []
                for field in image_fields:
                    images += Anki.get_images(content.get(field))

                # Regularizing the 'Text' and 'Extra' fields
                if 'Text' in content:
                    text = Anki.clean_html_tags(Anki.remove_cloze(content['Text']))
                elif 'Header' in content:
                    # The helper is a static method and must be qualified; the
                    # unqualified call raised NameError, which a bare ``except``
                    # then hid, so these notes were silently dropped.
                    text = Anki.clean_html_tags(content['Header'])
                else:
                    # A few of the in-house psych cards used a weird format
                    # Honestly just gonna skip
                    continue

                if 'Extra' in content:
                    extra = Anki.clean_html_tags(content['Extra'])
                else:
                    extra = ''

                self.db.append({
                    'id': note['id'],
                    'data': f'{text} \n {extra}',
                    'images': list(set(images)),
                    'tags': Anki.telescope_tags(note['tags']),
                })
        self.n_cards = len(self.db)
        
#     def compute_note_similarity(self):
#         self.note_image_similarity()
#         self.note_tag_similarity()
#         self.note_text_similarity()
#         self.note_overall_similarity()

    def note_image_similarity(self):
        res = self.note_list_similarity('images')
        return res
    
    def note_tag_similarity(self):
        res = self.note_list_similarity('tags')
        return res
    
    def note_text_similarity(self):
        pass # TODO
    
    def note_overall_similarity(self):
        pass # TODO
    
    def note_list_similarity(self, list_type):
        assert list_type.lower() in ['images', 'tags']
        res = []

        if self.PARTIAL:
            pairs = itertools.permutations(self.db[0:self.PARTIAL], 2)
            perms = int(math.factorial(self.PARTIAL)/math.factorial(self.PARTIAL-2))
        else:
            pairs = itertools.permutations(self.db, 2)
            perms = int(math.factorial(self.n_cards)/math.factorial(self.n_cards-2))
        
        for pair in tqdm(pairs, f'Calculating similarity of cards by their {list_type}: ',
                         total = perms, position=0, leave=True):
            sim = Anki.jaccard_similarity(pair[0][list_type], pair[1][list_type])
            if sim > self.THRESHOLD:
                res.append((pair[0]['id'], pair[1]['id'], sim))

        print(f'At theshold {self.THRESHOLD}, storing {len(res)} of {perms} combos '
              f'of {list_type} ({100*(1-len(res)/perms):.2f}% reduction)')
        
        tbl = list_type[:-1]
        print(f'Serializing results into table note_{tbl}_sim')
        self.serialize_note_similarity(tbl, res)
        return res
    
    # idc if I'm repeating myself in these next two funcs
    def tag_similarity(self):
        all_tags = list(set(tag for taglist in [card['tags'] for card in self.db] for tag in taglist))
        n_tags = len(all_tags)
        
        pairs = itertools.permutations(all_tags, 2)
        perms = int(math.factorial(n_tags)/math.factorial(n_tags-2))
        
        res = []
        for pair in tqdm(pairs, f'Calculating similarity of tags: ',
                         total = perms, position=0, leave=True):
            tag1 = pair[0]
            tag2 = pair[1]
            tag1_cards = [card['id'] for card in self.db if tag1 in card['tags']]
            tag2_cards = [card['id'] for card in self.db if tag2 in card['tags']]
            sim = Anki.jaccard_similarity(tag1_cards, tag2_cards)
            
            if sim != 0:
                res.append((tag1, tag2, sim))
                
        print(f'Storing {len(res)} of {perms} combinations of tags '
              f'({100*(1-len(res)/perms):.2f}% reduction).')
        print('Serializing tag similarity')
        self.insert_multiple_vals(f'''
                                  INSERT INTO tag_sim (tag_a, tag_b, value)
                                  VALUES (?, ?, ?)
                                   ''', res)
        return res
    
    def image_similarity(self):
        all_imgs = list(set(img for imglist in [card['images'] for card in self.db] for img in imglist))
        n_imgs = len(all_imgs)
        
        pairs = itertools.permutations(all_imgs, 2)
        perms = int(math.factorial(n_imgs)/math.factorial(n_imgs-2))
        
        res = []
        for pair in tqdm(pairs, f'Calculating similarity of images: ',
                         total = perms, position=0, leave=True):
            img1 = pair[0]
            img2 = pair[1]
            img1_cards = [card['id'] for card in self.db if img1 in card['tags']]
            img2_cards = [card['id'] for card in self.db if img2 in card['tags']]
            sim = Anki.jaccard_similarity(img1_cards, img2_cards)
            
            if sim != 0:
                res.append((img1, img2, sim))
                
        print(f'Storing {len(res)} of {perms} combinations of images '
              f'({100*(1-len(res)/perms):.2f}% reduction).')
        print('Serializing image similarity')
        self.insert_multiple_vals(f'''
                                  INSERT INTO image_sim (image_a, image_b, value)
                                  VALUES (?, ?, ?)
                                   ''', res)
        return res
            
    
    def init_sql(self):
        print(f'Creating new storage db at {self.sql}')
        conn = sqlite3.connect(self.sql)
        cursor = conn.cursor()
        for tbl in ['note_image_sim', 'note_tag_sim', 'note_text_sim', 'note_overall_sim',
                    'tag_sim', 'image_sim']:
            cursor.execute(f'DROP TABLE IF EXISTS {tbl}')
            
        cursor.execute('''
                       CREATE TABLE note_image_sim
                       (post_a integer, 
                       post_b integer,
                       value real);
                       ''')
        cursor.execute('CREATE INDEX note_image_sim_idx ON note_image_sim (post_a, post_b);')
        cursor.execute('''
                       CREATE TABLE note_tag_sim
                       (post_a integer, 
                       post_b integer,
                       value real);
                       ''')
        cursor.execute('CREATE INDEX note_tag_sim_idx ON note_tag_sim (post_a, post_b);')
        cursor.execute('''
                       CREATE TABLE note_text_sim
                       (post_a integer, 
                       post_b integer,
                       value real);
                       ''')
        cursor.execute('CREATE INDEX note_text_sim_idx ON note_text_sim (post_a, post_b);')
        cursor.execute('''
                       CREATE TABLE note_overall_sim
                       (post_a integer, 
                       post_b integer,
                       value real);
                       ''')
        cursor.execute('CREATE INDEX note_overall_sim_idx ON note_overall_sim (post_a, post_b);')
        cursor.execute('''
                       CREATE TABLE tag_sim
                       (tag_a string, 
                       tag_b string,
                       value real);
                       ''')
        cursor.execute('CREATE INDEX tag_sim_idx ON tag_sim (tag_a, tag_b);')
        cursor.execute('''
                       CREATE TABLE image_sim
                       (image_a string, 
                       image_b string,
                       value real);
                       ''')
        cursor.execute('CREATE INDEX image_sim_idx ON image_sim (image_a, image_b);')
        conn.commit()
        conn.close()

    def execute_sql(self, sql, params=()):
        with closing(sqlite3.connect(self.sql)) as conn:
            with conn:
                cur = conn.cursor()
                try:
                    cur.execute(sql, params)
                    res = cur.fetchall()
                except sqlite3.ProgrammingError:
                    print(f'SQL:\n{sql}\n{params}')

                if res:
                    df = pd.DataFrame(res)
                    df.columns = [d[0] for d in cur.description]
                else:
                    df = pd.DataFrame({})
        return df
    
    def insert_multiple_vals(self, sql, param_list):
        with closing(sqlite3.connect(self.sql)) as conn:
            with conn:
                cur = conn.cursor()
                try:
                    sleep(self.DELAY)
                    for params in tqdm(param_list, 'Serializing', position=0, leave=True):
                        cur.execute(sql, params)
                    res = cur.fetchall()
                except sqlite3.ProgrammingError:
                    print(f'SQL:\n{sql}\n{params}')

                if res:
                    df = pd.DataFrame(res)
                    df.columns = [d[0] for d in cur.description]
                else:
                    df = pd.DataFrame({})
        return df
        
    def serialize_note_similarity(self, table, params):
        self.insert_multiple_vals(f'''
                                  INSERT INTO note_{table}_sim (post_a, post_b, value)
                                  VALUES (?, ?, ?)
                                   ''', params)
    
    def load_from_db(self):
        print(f'Loading from database {self.sql}')
        self.note_sim['images'] = self.execute_sql('SELECT * from note_image_sim;')
        self.note_sim['tags'] = self.execute_sql('SELECT * from note_tag_sim;')
        self.note_sim['text'] = self.execute_sql('SELECT * from note_text_sim;')
        self.note_sim['overall'] = self.execute_sql('SELECT * from note_overall_sim;')
        self.tag_sim = self.execute_sql('SELECT * from tag_sim;')
        self.image_sim = self.execute_sql('SELECT * from image_sim;')
    
    def set_threshold(self, threshold=0.15):
        self.THRESHOLD = threshold

    @staticmethod
    def clean_html_tags(markup):
        soup = BeautifulSoup(markup, 'html.parser')
        for br in soup.find_all('br'):
            br.replace_with('\n')
        return soup.get_text()

    @staticmethod
    def get_images(markup):
        out = []
        if markup:
            soup = BeautifulSoup(markup, 'html.parser')
            images = soup.findAll('img')
            for image in images:
                out.append(image['src'])  
        return out

    @staticmethod
    def remove_cloze(markup):
        # txt = '<a href="blah"> Hello {{c1::world}} once {{c2::again::hint}} lol </a>'
        return re.sub('{{.*?::(.*?)(::.*?){0,}}}', '\\1', markup)
    
    @staticmethod
    def telescope_tags(taglist):
        out = []
        for tag in taglist:
            splt = tag.lower().split('::')
            for i in range(1, len(splt)):
                out.append('::'.join(splt[0:i]))
        return list(set(out))

    @staticmethod
    def jaccard_similarity(list1, list2):
        s1 = set(list1)
        s2 = set(list2)
        try:
            jaccard = float(len(s1.intersection(s2)) / len(s1.union(s2)))
        except ZeroDivisionError:
            jaccard = 0
        return jaccard

# TODO: Calculate similarity based on text
# TODO: Calculate network fusion

# TODO: combine all the whitespace?
# TODO: similarity of tags/images rather than notes 
# (easier to compute: simply count how many notes any 2 tags have in common)

In [None]:
# Parse the package, compute every similarity table and store it on disk.
x = Anki("Subset.apkg")  # Selected Notes.apkg
x.set_threshold(0.15)
x.note_image_similarity()
x.note_tag_similarity()
# x.note_text_similarity()
# x.note_overall_similarity()
x.tag_similarity()
x.image_similarity()

# Reload the stored results without recomputing. A separate name is used so
# that `x` keeps its parsed notes: an instance built with read=False has an
# empty `db`, which would break later cells that index into `x.db`.
x_loaded = Anki("Subset.apkg", overwrite=False, read=False)
x_loaded.load_from_db()
print('Done')

Creating new storage db at /mnt/c/Users/edrid.EDRIDGE-DSOUZA-/Documents/GitHub/anki-network/Subset.apkg.sql


Reading `/mnt/c/Users/edrid.EDRIDGE-DSOUZA-/Documents/GitHub/anki-network/Subset.apkg`: 100%|██████████| 572/572 [00:01<00:00, 573.70it/s]
Calculating similarity of cards by their images: 100%|██████████| 284622/284622 [00:00<00:00, 693303.63it/s]


At theshold 0.15, storing 2098 of 284622 combos of images (99.26% reduction)
Serializing results into table note_image_sim


Serializing: 100%|██████████| 2098/2098 [00:00<00:00, 62891.48it/s]
Calculating similarity of cards by their tags: 100%|██████████| 284622/284622 [00:00<00:00, 309967.67it/s]


At theshold 0.15, storing 136468 of 284622 combos of tags (52.05% reduction)
Serializing results into table note_tag_sim


Serializing: 100%|██████████| 136468/136468 [00:00<00:00, 170592.13it/s]
Calculating similarity of tags:   0%|          | 0/193160 [00:00<?, ?it/s]

In [None]:
# Peek at one parsed note record. `x.db` is only filled by `read_file`, so
# `x` must have been built with read=True for this index to exist.
x.db[101]

In [None]:
# Doc2Vec for similarity scores:
# https://medium.com/red-buffer/doc2vec-computing-similarity-between-the-documents-47daf6c828cd
# https://stackoverflow.com/questions/53503049/measure-similarity-between-two-documents-using-doc2vec
# https://github.com/jhlau/doc2vec#pre-trained-doc2vec-models

# https://github.com/rmarkello/snfpy to fuse similarity networks
# https://github.com/maxconway/SNFtool has more visualization options
# sklearn to do general network stuff
# https://towardsdatascience.com/visualising-similarity-clusters-with-interactive-graphs-20a4b2a18534

In [None]:
# Sketch: compare two texts with a previously saved Doc2Vec model.
# NOTE(review): `model_file` is not assigned in any cell visible in this
# notebook; it has to be defined before this cell can run.
from gensim.models import doc2vec
from scipy import spatial

d2v_model = doc2vec.Doc2Vec.load(model_file)

# Placeholder inputs; replace with the two texts to compare.
fisrt_text = '..'
second_text = '..'

# Each text is split on whitespace and turned into a vector by the model.
vec1 = d2v_model.infer_vector(fisrt_text.split())
vec2 = d2v_model.infer_vector(second_text.split())

cos_distance = spatial.distance.cosine(vec1, vec2)
# cos_distance indicates how much the two texts differ from each other:
# higher values mean more distant (i.e. different) texts