In [None]:
!pip install mecab-python3
!pip install fugashi[unidic-lite]

In [None]:
#@title
def initial_setup():
    necessary_files = [
        # "20170201.tar.bz2",
        # "cc.ja.300.vec.gz",
        "wnjpn.db",
        "jawiki.all_vectors.100d.txt.bz2",
        "corpus/ldcc-20140209.tar.gz",
    ]
    if all(list(map(lambda path: os.path.exists(os.path.join(BASE_PATH, path)),
                   necessary_files ))):pass
    else:
        %cd /content/drive/MyDrive/DemoLesson/
        # WordNet
        !wget https://github.com/bond-lab/wnja/releases/download/v1.1/wnjpn.db.gz
        with gzip.open(os.path.join(BASE_PATH, 'wnjpn.db.gz'), 'rb') as f_in:
            with open(os.path.join(BASE_PATH, 'wnjpn.db'), 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
        # word2vecのダウンロード
        # !wget https://dl.fbaipublicfiles.com/fasttext/vectors-crawl/cc.ja.300.vec.gz
        # !wget http://www.cl.ecei.tohoku.ac.jp/~m-suzuki/jawiki_vector/data/20170201.tar.bz2
        !wget https://github.com/singletongue/WikiEntVec/releases/download/20190520/jawiki.all_vectors.100d.txt.bz2


In [None]:
# Mount Google Drive at /content/drive so the notebook can use files stored there.
from google import colab
colab.drive.mount("/content/drive")

# Directory on the mounted Drive that the rest of the notebook joins file names onto.
BASE_PATH = "/content/drive/MyDrive/DemoLesson"

In [None]:
import gzip
import shutil
import os
import pandas as pd
from collections import Counter
import random

# Make sure the data files exist under BASE_PATH; initial_setup() and BASE_PATH
# must already have been defined by the cells above.
initial_setup()

from gensim.models import KeyedVectors
# Load the word vectors from the bz2-compressed text-format (binary=False) file.
wv = KeyedVectors.load_word2vec_format(os.path.join(BASE_PATH, 'jawiki.all_vectors.100d.txt.bz2'), binary=False)

import sqlite3
import numpy as np

# Load the candidate words for the game: one word per line, first line is a header.
with open(os.path.join(BASE_PATH, "words_for_game.csv"), "r", encoding="utf-8") as f:
    next(f, None)  # skip the header row (no error if the file is empty)
    # Iterating the file handles both \n and \r\n line endings. Blank lines, such as
    # the one a trailing newline used to produce, are dropped so that an empty
    # string can never be drawn as a quiz word and passed to the vector lookup.
    words = [line.strip() for line in f if line.strip()]

In [None]:
# SQL周りの関数定義
def get_hype_list(target_word: str) -> list:
    """Return the Japanese hypernym lemmas of ``target_word``.

    Looks the word up in the WordNet SQLite database ``wnjpn.db`` under
    ``BASE_PATH`` and follows every ``'hype'`` link from a synset containing the
    word to the synsets it points at. Only senses and words whose language is
    ``'jpn'`` are considered on both sides of the link.

    Args:
        target_word: Lemma to look up.

    Returns:
        list: Lemmas of the linked synsets, one entry per matching row (duplicates
        are possible); empty when the word is unknown or has no such link.
    """
    # The lemma is bound as a parameter rather than formatted into the SQL text, so
    # words containing quotes can neither break nor alter the statement.
    query = """
    SELECT
      w2.lemma
    FROM synlink AS sl
    INNER JOIN synset AS sy1 ON sy1.synset = sl.synset1
    INNER JOIN synset AS sy2 ON sy2.synset = sl.synset2
    INNER JOIN sense AS se1 ON se1.synset = sy1.synset
    INNER JOIN sense AS se2 ON se2.synset = sy2.synset
    INNER JOIN word AS w1 ON w1.wordid = se1.wordid
    INNER JOIN word AS w2 ON w2.wordid = se2.wordid
    WHERE w1.lemma = ? AND sl.link = 'hype'
      AND se1.lang = 'jpn' AND w1.lang = 'jpn' AND se2.lang = 'jpn' AND w2.lang = 'jpn'
    """

    conn = sqlite3.connect(os.path.join(BASE_PATH, 'wnjpn.db'))
    try:
        rows = conn.execute(query, (target_word,)).fetchall()
    finally:
        # Release the connection even when the query raises.
        conn.close()

    return [lemma for (lemma,) in rows]


def get_hypo_list(hype: str) -> list:
    """Return the Japanese hyponym lemmas of the word ``hype``.

    Looks the word up in the WordNet SQLite database ``wnjpn.db`` under
    ``BASE_PATH`` and follows every ``'hypo'`` link from a synset containing the
    word to the synsets it points at. Only senses and words whose language is
    ``'jpn'`` are considered on both sides of the link.

    Args:
        hype: Lemma whose hyponyms are wanted.

    Returns:
        list: Lemmas of the linked synsets, one entry per matching row (duplicates
        are possible); empty when the word is unknown or has no such link.
    """
    # The lemma is bound as a parameter rather than formatted into the SQL text, so
    # words containing quotes can neither break nor alter the statement.
    query = """
    SELECT
      w2.lemma
    FROM synlink AS sl
    INNER JOIN synset AS sy1 ON sy1.synset = sl.synset1
    INNER JOIN synset AS sy2 ON sy2.synset = sl.synset2
    INNER JOIN sense AS se1 ON se1.synset = sy1.synset
    INNER JOIN sense AS se2 ON se2.synset = sy2.synset
    INNER JOIN word AS w1 ON w1.wordid = se1.wordid
    INNER JOIN word AS w2 ON w2.wordid = se2.wordid
    WHERE w1.lemma = ? AND sl.link = 'hypo'
      AND se1.lang = 'jpn' AND w1.lang = 'jpn' AND se2.lang = 'jpn' AND w2.lang = 'jpn'
    """

    conn = sqlite3.connect(os.path.join(BASE_PATH, 'wnjpn.db'))
    try:
        rows = conn.execute(query, (hype,)).fetchall()
    finally:
        # Release the connection even when the query raises.
        conn.close()

    return [lemma for (lemma,) in rows]

def get_options(ans, w1, w2, n_options=8):
    """Build the shuffled list of answer choices for one quiz round.

    Candidates are the hyponyms of every hypernym of ``ans``, ``w1`` and ``w2``
    (looked up in that order through ``get_hype_list`` / ``get_hypo_list``).
    Collection stops as soon as at least ``n_options`` distinct candidates have
    been gathered.

    Args:
        ans: The correct answer; it is always part of the returned choices.
        w1: First word used in the formula.
        w2: Second word used in the formula.
        n_options: Maximum number of choices to return (default 8).

    Returns:
        list: At most ``n_options`` distinct words in random order. Fewer are
        returned when the lookups yield too few candidates, but ``ans`` is
        included in every case.
    """
    pool = set()
    for word in (ans, w1, w2):
        for hype in get_hype_list(word):
            pool.update(get_hypo_list(hype))
            if len(pool) >= n_options:
                break
        if len(pool) >= n_options:
            break

    # Reserve one slot for the answer and fill the rest with distractors. Adding
    # the answer unconditionally also covers the case where too few candidates
    # were found and none of the lookups happened to return the answer itself.
    pool.discard(ans)
    options = list(pool)[:max(n_options - 1, 0)] + [ans]
    random.shuffle(options)
    return options

def get_formula(w1, w2, ans):
    """Compose a quiz string of the form ``"X - Y + Z = ?"`` for the answer ``ans``.

    One of two variants is picked at random. In both, a helper word is taken as
    the top hit of ``wv.most_similar`` (with any ``#`` characters removed):

    * variant 1: helper = most_similar(w1 + w2 - ans); the formula subtracts the
      helper from ``w1`` and ``w2``.
    * variant 2: with ``w1``/``w2`` in random order as (first, second),
      helper = most_similar(first + ans - second); the formula subtracts
      ``first`` from ``second`` and the helper.

    The two added terms are shuffled before the string is built.

    Args:
        w1: First given word.
        w2: Second given word.
        ans: Word the formula is meant to lead to.

    Returns:
        str: The formula text, ending in ``" = ?"``.
    """
    operands = {"w1": w1, "w2": w2}
    order = list(operands)
    random.shuffle(order)
    variant = np.random.randint(1, 3)

    if variant == 1:
        hits = wv.most_similar(positive=[w1, w2], negative=[ans], topn=1)
        addends = [w1, w2]
        subtrahend = hits[0][0].replace("#", "")
    else:
        first = operands[order[0]]
        second = operands[order[1]]
        hits = wv.most_similar(positive=[first, ans], negative=[second], topn=1)
        addends = [second, hits[0][0].replace("#", "")]
        subtrahend = first

    random.shuffle(addends)
    return f"{addends[0]} - {subtrahend} + {addends[1]} = ?"

def get_formula_and_ans():
    """Draw three distinct entries of ``words`` and build a formula from them.

    The first two draws become ``w1`` and ``w2``; the third becomes the answer.

    Returns:
        tuple: ``(formula, ans, w1, w2)`` where ``formula`` is the result of
        ``get_formula(w1, w2, ans)``.
    """
    picked = np.random.choice(np.arange(len(words)), 3, replace=False)
    w1, w2, ans = (words[idx] for idx in picked)
    return get_formula(w1, w2, ans), ans, w1, w2

def remove_hash(*args):
    """Return the given strings as a list with every ``#`` character removed."""
    return [text.replace("#", "") for text in args]

def init_game():
    """Set up one quiz round.

    Returns:
        tuple: ``(formula, options, ans)`` - the formula text, the answer choices
        from ``get_options`` and the correct answer.
    """
    formula, ans, first_word, second_word = get_formula_and_ans()
    return formula, get_options(ans, first_word, second_word), ans

In [None]:
# for _ in range(100):
#     formula, options, ans = init_game()
#     # print(formula)
#     # print("選択肢:", options)
#     if ans not in options: raise "回答が選択肢に存在しません"

In [None]:
# Generate one quiz round and print the formula together with the answer choices.
formula, options, ans = init_game()
print(formula)
print("選択肢:", options)

In [None]:
# Display the correct answer of the round generated in the previous cell.
ans

In [None]:
# Analogy example: the five entries most similar to 東京 + フランス - 日本.
wv.most_similar(positive=["東京", "フランス"], negative=["日本"], topn=5)