In [1]:
import sys
import numpy as np
from scipy.stats import entropy
from pgmpy.factors.discrete import DiscreteFactor


  from .autonotebook import tqdm as notebook_tqdm


### Joint distribution $P(X,Y)$

Stored as `pXY[y, x]`: rows index $Y \in \{y_1, y_2\}$, columns index $X \in \{x_1, x_2, x_3\}$.

In [2]:
# Joint distribution P(X, Y) stored as a (|Y|, |X|) = (2, 3) array:
# rows index Y (y1, y2), columns index X (x1, x2, x3),
# i.e. pXY[y, x] = P(X=x, Y=y).
pXY = np.array([
    [0.25, 0.20, 0.15],  # Y = y1
    [0.20, 0.10, 0.10],  # Y = y2
])
assert pXY.shape == (2, 3)
assert np.isclose(pXY.sum(), 1.0), "joint distribution must sum to 1"
pXY

array([[0.25, 0.2 , 0.15],
       [0.2 , 0.1 , 0.1 ]])

### Marginals $P(X)$, $P(Y)$

In [83]:
# Marginalise Y out by summing down the rows: P(X=x) = sum_y P(x, y).
pX = np.sum(pXY, axis=0)
(pX, pX.sum())

(array([0.45, 0.3 , 0.25]), 1.0)

In [84]:
# Marginalise X out by summing across the columns: P(Y=y) = sum_x P(x, y).
# Note: P(Y) = [0.6, 0.4] here, so the marginal is NOT uniform.
pY = pXY.sum(axis=1)
pY, pY.sum()

(array([0.6, 0.4]), 1.0)

In [85]:
# Shape check before broadcasting: pXY.T is (|X|, |Y|) and pY is (|Y|,),
# so dividing them aligns pY with the trailing (Y) axis.
(pXY.T.shape, pY.shape)

((3, 2), (2,))

In [90]:
# Conditionals from the joint (rows index Y, columns index X):
#   pX_Y[y, x] = P(X=x | Y=y) = P(x, y) / P(y)  -> divide each row by P(y)
#   pY_X[y, x] = P(Y=y | X=x) = P(x, y) / P(x)  -> divide each column by P(x)
# pY[:, np.newaxis] has shape (|Y|, 1), so it broadcasts across the columns;
# this gives the same values as the transpose round-trip (pXY.T / pY).T.
pX_Y = pXY / pY[:, np.newaxis]

# pX has shape (|X|,) and broadcasts along the trailing (column) axis.
pY_X = pXY / pX

# Each conditional must sum to 1 over the variable on the left of the bar.
assert np.allclose(pX_Y.sum(axis=1), 1.0)
assert np.allclose(pY_X.sum(axis=0), 1.0)

print(f'P(X|Y)\n{pX_Y}\n \nP(Y|X)\n{pY_X}')

P(X|Y)
[[0.41666667 0.33333333 0.25      ]
 [0.5        0.25       0.25      ]]
 
P(Y|X)
[[0.55555556 0.66666667 0.6       ]
 [0.44444444 0.33333333 0.4       ]]


In [91]:
# Row sums of P(X|Y): for each fixed y, the distribution over x should total 1.
np.sum(pX_Y, axis=1)

array([1., 1.])

In [67]:
# Column sums of P(Y|X): for each fixed x, the distribution over y should total 1.
np.sum(pY_X, axis=0)

array([1., 1., 1.])

In [50]:
# The same joint as a pgmpy factor. With variables ordered ['X', 'Y'], `values`
# is read in row-major order with Y varying fastest:
#   (x1,y1), (x1,y2), (x2,y1), (x2,y2), (x3,y1), (x3,y2).
# pXY is laid out as [y, x], so flattening its transpose row-major yields exactly
# that order. Building from pXY keeps the two representations from drifting apart.
joint = DiscreteFactor(
    variables=['X', 'Y'],
    cardinality=[3, 2],
    values=pXY.T.ravel(),
    state_names={
        'X': ['1', '2', '3'],
        'Y': ['1', '2']},
)
print(joint)

+------+------+------------+
| X    | Y    |   phi(X,Y) |
| X(1) | Y(1) |     0.2500 |
+------+------+------------+
| X(1) | Y(2) |     0.2000 |
+------+------+------------+
| X(2) | Y(1) |     0.2000 |
+------+------+------------+
| X(2) | Y(2) |     0.1000 |
+------+------+------------+
| X(3) | Y(1) |     0.1500 |
+------+------+------------+
| X(3) | Y(2) |     0.1000 |
+------+------+------------+


In [51]:
# Marginalise X out of the pgmpy joint to get P(Y) as a DiscreteFactor.
# NOTE: this rebinds `pY`, which earlier held the NumPy marginal array. Earlier
# cells that use `pY` as an array will fail if re-run after this one.
pY = joint.marginalize(variables=['X'], inplace=False)
# Cross-check pgmpy against the NumPy marginal (sum over X = the columns of pXY).
np.testing.assert_allclose(pY.values, pXY.sum(axis=1))
print(pY)

+------+----------+
| Y    |   phi(Y) |
| Y(1) |   0.6000 |
+------+----------+
| Y(2) |   0.4000 |
+------+----------+


In [52]:
# P(X|Y) = P(X, Y) / P(Y) via factor division by the Y-only factor.
# The result keeps the variable order ['X', 'Y'], so its values are indexed [x, y].
pX_given_Y = joint / pY
# Cross-check against the NumPy conditional, which is laid out as [y, x].
np.testing.assert_allclose(pX_given_Y.values, pX_Y.T)
print(pX_given_Y)

+------+------+------------+
| X    | Y    |   phi(X,Y) |
| X(1) | Y(1) |     0.4167 |
+------+------+------------+
| X(1) | Y(2) |     0.5000 |
+------+------+------------+
| X(2) | Y(1) |     0.3333 |
+------+------+------------+
| X(2) | Y(2) |     0.2500 |
+------+------+------------+
| X(3) | Y(1) |     0.2500 |
+------+------+------------+
| X(3) | Y(2) |     0.2500 |
+------+------+------------+
