<a href="https://colab.research.google.com/github/kowshikasarker/DL4H-Project/blob/main/Notebook/PGM-1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import math
from torch.nn.functional import normalize
import random
from scipy.special import beta

In [None]:
# Seed every random number generator the notebook uses (NumPy, the stdlib
# `random` module and PyTorch) with the same value so runs are repeatable.
seed = 1234
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)

In [None]:
def nCr(n, r):
    """Return the binomial coefficient C(n, r) as an exact integer.

    The value is computed with integer arithmetic only, so it stays exact for
    results larger than 2**53 (where a float quotient of factorials would be
    rounded).

    Args:
        n: Non-negative integer, size of the set.
        r: Non-negative integer, number of items chosen.

    Returns:
        The number of ways to choose ``r`` items out of ``n``; ``0`` when
        ``r > n``.

    Raises:
        ValueError: If ``n`` or ``r`` is negative.
        TypeError: If ``n`` or ``r`` is not an integer.
    """
    return math.comb(n, r)

In [None]:
class EM(nn.Module):
    """EM-style trainer for a model with binary hidden nodes H and Z.

    Observed data are V (features), O (one column per outlier) and G.  H is a
    binary vector with one entry per column of O; all ``2**|O|`` configurations
    of H are enumerated explicitly.  Z is a single binary node.

    Parameters ``W0``, ``W1``, ``W2`` and ``M`` are ``nn.Parameter`` objects
    updated with Adam; ``A`` and ``B`` are plain tensors updated in closed form.
    Everything is kept in double precision.
    """

    def __init__(self, V, O, G):
        """Store the data, initialise the parameters and build the optimizer.

        Args:
            V: 2-D NumPy array, one row per data point, one column per feature.
            O: 2-D NumPy array with one row per data point and one column per
                outlier.  The code evaluates ``log(1 - O)``, so values are
                presumably expected in [0, 1) -- TODO confirm.
            G: NumPy array with one value per data point.  The update code
                treats it as 1-D; ``log(G)`` is evaluated, so values are
                presumably expected in (0, 1] -- TODO confirm.
        """
        super().__init__()
        self.init_nodes(V, O, G)
        self.preprocess_H()
        self.optimizer = optim.Adam(self.parameters())

    def init_nodes(self, V, O, G):
        """Record the dimensions, wrap the data as tensors and draw the parameters."""
        # Dimensions
        self.N = V.shape[0]  # No. of data points (ideally #individuals * #genes)
        self.len_V = V.shape[1]  # No. of sequence annotations
        self.len_O = O.shape[1]  # No. of outliers

        self.dim_V = (self.N, self.len_V)
        self.dim_O = (self.N, self.len_O)
        self.dim_G = (self.N, 1)

        self.dim_H = (self.N, self.len_O)
        self.dim_Z = (self.N, 1)

        self.dim_W0 = (self.len_V, self.len_O)  # row: feature, col: outlier
        self.dim_W1 = (1, self.len_O)  # col: outlier
        # One weight per unordered pair (h_i, h_j); math.comb is exact and
        # yields 0 (an empty weight vector) when there is a single outlier.
        self.dim_W2 = (1, math.comb(self.len_O, 2))
        self.dim_A = (1, self.len_O)  # col: outlier
        self.dim_M = (self.len_O, 1)  # row: outlier
        self.dim_B = (1, 1)

        # Data.  Cast to double so integer or float32 inputs can be combined
        # with the double-precision parameters (matmul requires equal dtypes).
        self.V = torch.as_tensor(V, dtype=torch.double)
        self.O = torch.as_tensor(O, dtype=torch.double)
        self.G = torch.as_tensor(G, dtype=torch.double)

        # Parameters trained by the optimizer.  The order of the torch.rand
        # calls is kept fixed so a given seed reproduces the same start point.
        self.W0 = nn.Parameter(torch.rand(self.dim_W0, dtype=torch.double))
        self.W1 = nn.Parameter(torch.rand(self.dim_W1, dtype=torch.double))
        self.W2 = nn.Parameter(torch.rand(self.dim_W2, dtype=torch.double))
        self.M = nn.Parameter(torch.rand(self.dim_M, dtype=torch.double))

        # Parameters updated in closed form; only their squares are used.
        self.A = torch.rand(self.dim_A, dtype=torch.double)
        self.B = torch.rand(self.dim_B, dtype=torch.double)

    def preprocess_H(self):
        """Enumerate every binary configuration of H and its pairwise products.

        Sets:
            bin_H_all: (2**|O|, |O|) tensor; row ``h`` holds the bits of ``h``,
                most significant bit first (e.g. 3 -> [0, 1, 1]).
            joint_H_all: (2**|O|, C(|O|, 2)) tensor; row ``h`` holds
                ``bit_i * bit_j`` for every pair ``i < j``.
            bin_H_sum: (2**|O|,) tensor with the number of set bits per row.
        """
        uniq_H_count = 2 ** self.len_O
        bin_H_format = "0" + str(self.len_O) + "b"

        bin_rows = []
        joint_rows = []
        for H in range(uniq_H_count):
            bits = [int(ch) for ch in format(H, bin_H_format)]
            bin_rows.append(bits)
            joint_rows.append([
                bits[i] * bits[j]
                for i in range(self.len_O)
                for j in range(i + 1, self.len_O)
            ])

        self.bin_H_all = torch.tensor(bin_rows, dtype=torch.double)
        self.joint_H_all = torch.tensor(joint_rows, dtype=torch.double)
        self.bin_H_sum = torch.sum(self.bin_H_all, dim=1)

    def _config_weight(self):
        """Return expit(bin_H_all @ W1^T + joint_H_all @ W2^T) as a (1, 2**|O|) row."""
        linear = torch.matmul(self.bin_H_all, self.W1.t())  # (uniq_H_count, 1)
        pairwise = torch.matmul(self.joint_H_all, self.W2.t())  # (uniq_H_count, 1)
        return torch.special.expit((linear + pairwise).t())

    def _config_prob(self, prob_on):
        """Multiply per-node probabilities into one value per configuration.

        Args:
            prob_on: (N, |O|) tensor, probability that each node of H is 1.

        Returns:
            (N, 2**|O|) tensor: for every configuration, the product over nodes
            of ``prob_on`` where the bit is 1 and ``1 - prob_on`` where it is 0.
        """
        prob_on = prob_on.unsqueeze(dim=1)  # (N, 1, |O|)
        on_term = prob_on * self.bin_H_all  # (N, uniq_H_count, |O|)
        off_term = (1 - prob_on) * (1 - self.bin_H_all)
        return torch.prod(on_term + off_term, dim=2)

    def prob_H_given_V(self):
        """Return the (N, 2**|O|) distribution over H configurations given V.

        Per-node probabilities are ``expit(V @ W0)``; each configuration is
        additionally weighted by ``_config_weight`` and rows are normalised to
        sum to one.  Gradients flow to W0, W1 and W2.
        """
        prob_on = torch.special.expit(torch.matmul(self.V, self.W0))  # (N, |O|)
        unnormalised = self._config_prob(prob_on) * self._config_weight()
        return unnormalised / torch.sum(unnormalised, dim=1, keepdim=True)

    @torch.no_grad()
    def step_E_prob_H(self):
        """E-step for H: store ``self.prob_H_given_X`` of shape (N, 2**|O|).

        Runs without autograd: the result is only ever used as a constant, and
        tracking it would chain graphs across EM epochs through ``self.A``.
        """
        logits = torch.matmul(self.V, self.W0)  # (N, |O|)

        square_A = torch.square(self.A)
        # A^2 * (1 - O)^(A^2 - 1), i.e. (1 - O)^(a - 1) / Beta(1, a) with a = A^2.
        prob_O_given_H1 = square_A * torch.pow(1 - self.O, square_A - 1)  # (N, |O|)

        # Per-node probability of H_i = 1 given V and O.
        prob_on = prob_O_given_H1 / (prob_O_given_H1 + torch.exp(-logits))

        unnormalised = self._config_prob(prob_on) * self._config_weight()
        self.prob_H_given_X = unnormalised / torch.sum(
            unnormalised, dim=1, keepdim=True
        )

    def prob_Z_given_H(self):
        """Return P(Z | H) for every configuration, repeated for each data point.

        Returns:
            Dict with keys ``'Z1'`` and ``'Z0'``, each an (N, 2**|O|) tensor;
            ``Z1`` is ``expit(bin_H_all @ M)`` and ``Z0`` is ``1 - Z1``.
        """
        logits = torch.matmul(self.bin_H_all, self.M).t()  # (1, uniq_H_count)
        prob_Z1_given_H = torch.special.expit(logits).repeat(self.N, 1)
        prob_Z0_given_H = 1 - prob_Z1_given_H

        return {'Z1': prob_Z1_given_H, 'Z0': prob_Z0_given_H}

    @torch.no_grad()
    def step_E_prob_Z(self):
        """E-step for Z: store ``prob_Z1_given_X`` and ``prob_Z0_given_X``.

        Both are (N, 2**|O|) tensors (one value per data point and H
        configuration) and are computed without autograd.
        """
        square_B = torch.square(self.B)
        # B^2 * G^(B^2 - 1); G is 1-D, so the product is (1, N) and the
        # transpose turns it into an (N, 1) column.
        prob_G_given_Z1 = (square_B * torch.pow(self.G, square_B - 1)).t()
        logits = torch.matmul(self.bin_H_all, self.M).t()  # (1, uniq_H_count)

        self.prob_Z1_given_X = prob_G_given_Z1 / (
            prob_G_given_Z1 + torch.exp(-logits)
        )
        self.prob_Z0_given_X = 1 - self.prob_Z1_given_X

    def step_E(self):
        """Run both E-step updates (H first, then Z)."""
        self.step_E_prob_H()
        self.step_E_prob_Z()

    def compute_Q(self):
        """Return the scalar objective maximised over W0, W1, W2 and M.

        The stored E-step posteriors are treated as constants and multiplied
        with the current ``prob_H_given_V() * prob_Z_given_H()``; the result is
        summed over data points, H configurations and both values of Z.

        NOTE(review): the terms are products of probabilities, not
        log-probabilities -- confirm this is the intended objective.
        """
        z1_const = (self.prob_H_given_X * self.prob_Z1_given_X).detach()
        z0_const = (self.prob_H_given_X * self.prob_Z0_given_X).detach()

        prob_H_given_V = self.prob_H_given_V()
        prob_Z_given_H = self.prob_Z_given_H()

        z1_Q = prob_H_given_V * prob_Z_given_H['Z1']
        z0_Q = prob_H_given_V * prob_Z_given_H['Z0']

        return torch.sum(z1_const * z1_Q + z0_const * z0_Q)

    def step_M_update_WM(self, wm_epoch):
        """Gradient-ascent part of the M-step.

        Takes up to ``wm_epoch`` Adam steps on ``-compute_Q()`` and stops early
        once Q improves by less than 1e-4 over the previous iteration.

        Args:
            wm_epoch: Maximum number of optimizer steps.

        Returns:
            The last evaluated Q as a Python float (evaluated before the final
            optimizer step, or the current Q when ``wm_epoch`` is 0).
        """
        prev_Q = -100
        last_Q = None
        for _ in range(wm_epoch):
            Q = self.compute_Q()
            last_Q = Q.item()
            if last_Q - prev_Q < 0.0001:
                break
            prev_Q = last_Q
            self.optimizer.zero_grad()
            (-Q).backward()
            self.optimizer.step()

        if last_Q is None:
            # No iteration ran; report the current objective instead of failing.
            with torch.no_grad():
                last_Q = self.compute_Q().item()
        return last_Q

    @torch.no_grad()
    def step_M_update_A(self):
        """Closed-form update of ``self.A`` (shape (1, |O|)).

        With weights ``w[n, h] = prob_H_given_X[n, h] * bin_H_sum[h]``, sets
        ``A[i] = sum(w) / sum_{n,h} w[n, h] * log(1 - O[n, i])``.

        NOTE(review): for O in (0, 1) the denominator is negative, so A is
        negative; only ``A**2`` is used elsewhere -- confirm the intended
        estimator.
        """
        # prob_Z1_given_X + prob_Z0_given_X is 1 by construction (up to rounding).
        posterior = self.prob_H_given_X * (
            self.prob_Z1_given_X + self.prob_Z0_given_X
        )  # (N, uniq_H_count)
        weights = posterior * self.bin_H_sum  # (N, uniq_H_count)

        log_one_minus_O = torch.unsqueeze(torch.log(1 - self.O), dim=1)  # (N, 1, |O|)
        weighted_logs = torch.unsqueeze(weights, dim=2) * log_one_minus_O

        numerator = torch.sum(weights)
        denominator = torch.sum(torch.sum(weighted_logs, dim=1), dim=0)  # (|O|,)

        # Keep the declared (1, |O|) shape, as step_M_update_B does for B.
        self.A = (numerator / denominator).reshape(self.dim_A)

    @torch.no_grad()
    def step_M_update_B(self):
        """Closed-form update of ``self.B`` (shape (1, 1)).

        With weights ``w = prob_H_given_X * prob_Z1_given_X``, sets
        ``B = sum(w) / sum_{n,h} w[n, h] * log(G[n])``.
        """
        weights = self.prob_H_given_X * self.prob_Z1_given_X
        weighted_logs = weights * torch.log(self.G.unsqueeze(dim=1))
        self.B = (torch.sum(weights) / torch.sum(weighted_logs)).reshape(1, 1)

    def step_M(self, wm_epoch):
        """Run the M-step: update A, then B, then W0/W1/W2/M; return the last Q."""
        self.step_M_update_A()
        self.step_M_update_B()
        return self.step_M_update_WM(wm_epoch)

    def forward(self, em_epoch=1000, wm_epoch=50000):
        """Run the EM loop and print Q after every epoch.

        Args:
            em_epoch: Number of EM iterations; all of them are always run
                (there is no early stopping at this level).
            wm_epoch: Maximum number of optimizer steps per M-step.
        """
        for epoch in range(em_epoch):
            self.step_E()
            Q = self.step_M(wm_epoch)
            print('Epoch', epoch, 'Q', Q)


In [None]:
def read_data(excel_path, sheet_name, max_rows=5):
    """Load the V, O and G columns of one Excel sheet as NumPy arrays.

    Columns whose label starts with ``'V'`` form V, columns whose label starts
    with ``'O'`` form O, and the column named ``'G'`` is G.

    Args:
        excel_path: Path of the Excel workbook.
        sheet_name: Name of the sheet to read.
        max_rows: Keep only the first ``max_rows`` rows.  The default of 5
            preserves the previous hard-coded behaviour; pass ``None`` to use
            every row.

    Returns:
        Dict with keys ``'V'`` and ``'O'`` (2-D arrays), ``'G'`` (1-D array)
        and ``'cols_V'`` / ``'cols_O'`` (the selected column labels).

    Raises:
        KeyError: If the sheet has no column named ``'G'``.
    """
    df = pd.read_excel(excel_path, sheet_name)
    if max_rows is not None:
        df = df.head(max_rows)

    # str() guards against non-string column labels (e.g. integer headers).
    cols_V = [col for col in df.columns if str(col).startswith('V')]
    cols_O = [col for col in df.columns if str(col).startswith('O')]

    return {
        'V': df[cols_V].to_numpy(),
        'O': df[cols_O].to_numpy(),
        'G': df['G'].to_numpy(),
        'cols_V': cols_V,
        'cols_O': cols_O,
    }

In [None]:
def main(
    excel_path='/content/drive/MyDrive/Mayo/Rare/Simulation/Simulation-2/simulated-2.xlsx',
    sheet_name='Data',
):
    """Load the simulated data set and run the EM model on it.

    Args:
        excel_path: Path of the Excel workbook to read.
        sheet_name: Name of the sheet holding the data.
    """
    data = read_data(excel_path, sheet_name)
    model = EM(data['V'], data['O'], data['G'])
    # NOTE(review): an unused local `epoch_count = 2` used to be defined here;
    # the model is still run with its default epoch counts -- confirm whether
    # `model(em_epoch=2)` was intended.
    model()

In [None]:
# Run only when executed as a script or notebook cell, not when imported.
if __name__ == "__main__":
    main()