# 目的

$P(P_{11}, \cdots, P_{33} | Q_{11}, \cdots, Q_{22} )$を調べたい。

# 原理
ベイズの定理より

$P(P_{11}, \cdots, P_{33} | Q_{11}, \cdots, Q_{22} ) = \frac{P( Q_{11}, \cdots, Q_{22}  | P_{11}, \cdots, P_{33})}{\sum_{P_{ij}} P( Q_{11}, \cdots, Q_{22}  | P_{11}, \cdots, P_{33}) }$
が成り立つ。

$(P_{11},\cdots, P_{33}) = (1,\cdots, 1)$から$(0, \cdots, 0)$までの$2^{9}$通り

$(A_{11},\cdots, A_{22}) = (1,\cdots, 1)$から$(-1, \cdots, -1)$までの$3^{4}$通り

## ライブラリのインポート

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import itertools as iter
import sqlite3

## データベース構築

FileName:data_set 

Colomn=[original, observe]

In [14]:
con = sqlite3.connect("data.db")
cursor = con.cursor()
cursor.execute("CREATE TABLE data_set(p_11, p_12, p_13, p_21, p_22, p_23, p_31, p_32, p_33, q_11, q_12, q_21, q_22)")
p = "INSERT INTO data_set(p_11, p_12, p_13, p_21, p_22, p_23, p_31, p_32, p_33, q_11, q_12, q_21, q_22) VALUES(?,?,?,?,?,?,?,?,?,?,?,?,?)"

データベース削除用 

In [11]:
#DELETE ALL DATA
cursor.execute('delete from data_set')
cursor.execute('select * from data_set')
cursor.fetchall()

[]

## 真値(P行列)から観測値(Q行列)を計算

In [15]:
def calculation(p_11, p_12, p_13, p_21, p_22, p_23, p_31, p_32, p_33):
    for a_11,a_12,a_21,a_22 in iter.product([-1,0,1],[-1,0,1],[-1,0,1],[-1,0,1]):
        q_11 = p_11 * a_12 + p_12 * a_21 + p_21 * a_22 + p_22 * a_22
        q_12 = p_12 * a_12 + p_13 * a_21 + p_22 * a_22 + p_23 * a_22
        q_21 = p_21 * a_12 + p_22 * a_21 + p_31 * a_22 + p_32 * a_22
        q_22 = p_22 * a_12 + p_23 * a_21 + p_32 * a_22 + p_33 * a_22
        cursor.execute(p, (int(p_11), int(p_12), int(p_13), int(p_21), int(p_22), int(p_23), int(p_31), int(p_32), int(p_33), int(q_11), int(q_12), int(q_21), int(q_22)))
        con.commit()

Pのすべての組み合わせを計算し、データベースに保存する。

In [16]:
for p_11,p_12,p_13,p_21,p_22,p_23,p_31,p_32,p_33 in iter.product([0,1],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1],[0,1]):
    calculation(p_11, p_12, p_13, p_21, p_22, p_23, p_31, p_32, p_33)

In [17]:
cursor.execute('select * from data_set')
cursor.fetchall()

[(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0),
 (0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

## 特定のQ行列(観測データ)からP行列になる回数

In [18]:
def visualize_fixedQ(q_11,q_12,q_21,q_22):
    Q = np.array([[q_11,q_12],[q_21,q_22]])
    cursor.execute("select original, observe, count(original) from data_set where observe = '" + str(Q) + "' group by original;")
    data_set = cursor.fetchall()
    total = len(data_set)    
    
    if total >= 10:
        print("TOO MANY LABELS")
    else:
        pass
    
    names = [data_set[i][0] for i in range(total)]
    values = [data_set[i][2] for i in range(total)]
    
    print(values)
    plt.figure(dpi = 100,figsize=(20,4))
    
    plt.xlabel("P Matrix")
    plt.ylabel("Number of observed times")
    plt.bar(names, values)
    plt.show()
    
    

## 特定のP行列(真値)からQ行列(観測データ)になる回数

In [19]:
def visualize_fixedP(p_11,p_12,p_13,p_21,p_22,p_23,p_31,p_32,p_33):
    P = np.array([[p_11,p_12,p_13],[p_21,p_22,p_23],[p_31,p_32,p_33]])
    cursor.execute("select original, observe, count(original) from data_set where original = '" + str(P) + "' group by observe;")
    data_set = cursor.fetchall()
    total = len(data_set)
    
    
    if total >= 10:
        print("TOO MANY LABELS")
    else:
        pass
    
    names = [data_set[i][1] for i in range(total)]
    values = [data_set[i][2] for i in range(total)]
    
    print(values)
    plt.figure(dpi = 100,figsize=(20,4))
    
    plt.xlabel("Q Matrix")
    plt.ylabel("Number of observed times")
    plt.bar(names, values)
    plt.show()

In [20]:
visualize_fixedQ(0,2,2,2)

visualize_fixedP(0,0,1,0,0,1,1,1,0)
visualize_fixedP(0,1,0,0,0,1,1,1,1)
visualize_fixedP(0,1,1,0,0,1,1,1,1)
visualize_fixedP(1,0,0,0,1,1,0,1,1)
visualize_fixedP(1,0,1,0,0,1,1,1,0)

OperationalError: no such column: original

少なくとも来週までには、ベイズの定理が成り立っているか上2つの関数を使って確かめる。

この総当たり戦は、ピクセル数が増えると指数的に計算量も増えるため、現実的ではない。ので、ベイジアンネットを使うとうまくいくかも。。？ by橋爪

In [62]:
cursor.execute("select count(*) from data_set;")
cursor.fetchall()


[(41472,)]

In [79]:
def p11_Given_q11():
    cursor.execute("select p_11, q_11, count(p_11) from data_set where q_11 = 0 group by p_11;")
    temp = cursor.fetchall()
    print(temp)

def p12_Given_q11q12():
    cursor.execute("select p_12, q_11, q_12, count(p_11) from data_set where q_11 = 0 and q_12 = 2 group by p_11;")
    temp = cursor.fetchall()
    print(temp)
    
p11_Given_q11()
p12_Given_q11q12()

[(0, 0, 8064), (1, 0, 5568)]
[(0, 0, 2, 240), (0, 0, 2, 216)]
