forked from MengtingWan/KDEm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CRH.py
58 lines (51 loc) · 1.69 KB
/
CRH.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
"""
CRH.py
@author: Mengting Wan
"""
from __future__ import division
import numpy as np
import numpy.linalg as la
import basic_functions as bsf
def update_w(claim, index, truth, m, n, eps=1e-15):
    """Re-estimate source weights from the current truth estimates.

    For every object, each claiming source is charged the squared distance
    between its claim and the current truth, divided by the standard deviation
    of that object's claims (floored at ``eps`` so identical claims do not
    divide by zero). The accumulated per-source losses are then turned into
    weights ``-log(loss_k / sum(loss))``.

    Parameters
    ----------
    claim : sequence of numpy arrays
        ``claim[i]`` holds the values claimed for object ``i``.
    index : sequence of integer arrays
        ``index[i]`` holds the source ids, element-wise aligned with ``claim[i]``.
    truth : array-like
        Current truth estimate for each object.
    m : int
        Number of sources; length of the returned vector.
    n : int
        Number of objects to visit (``0 .. n-1``).
    eps : float
        Lower bound on the per-object standard deviation.

    Returns
    -------
    numpy.ndarray of length ``m``
        ``-log(loss_k / total)`` for sources with positive loss. Sources with
        zero loss (including sources that made no claims) keep weight 0, and
        the vector is all zeros when the total loss is 0.
    """
    loss = np.zeros(m)
    for obj in range(n):
        values = claim[obj]
        sources = index[obj]
        scale = max(np.std(values), eps)
        # Read-modify-write through fancy indexing: if a source id appeared
        # twice in index[obj], only the last write would survive.
        loss[sources] = loss[sources] + (values - truth[obj]) ** 2 / scale

    total = np.sum(loss)
    if total > 0:
        positive = loss > 0
        loss[positive] = -np.log(loss[positive] / total)
    return loss
def update_truth(claim, index, w_vec, m, n):
    """Estimate each object's truth as the weighted mean of its claims.

    Parameters
    ----------
    claim : sequence of numpy arrays
        ``claim[i]`` holds the values claimed for object ``i``.
    index : sequence of integer arrays
        ``index[i]`` holds the source ids, element-wise aligned with ``claim[i]``.
    w_vec : numpy.ndarray
        Per-source weights (indexed by source id).
    m : int
        Number of sources. Not used here; kept for signature symmetry
        with ``update_w``.
    n : int
        Number of objects; length of the returned vector.

    Returns
    -------
    numpy.ndarray of length ``n``
        ``sum(w * claim) / sum(w)`` per object. If the weights of all sources
        claiming an object sum to exactly 0, the unweighted mean of its claims
        is used instead.
    """
    rtn = np.zeros(n)
    for i in range(n):
        weights = w_vec[index[i]]
        total = np.sum(weights)
        if total != 0:
            rtn[i] = np.dot(weights, claim[i]) / total
        else:
            # update_w gives weight 0 to zero-loss sources, so every claimer of
            # an object can end up with weight 0. A 0/0 here would yield NaN and
            # contaminate every later iteration. With all weights equal the
            # weighted mean reduces to the plain mean, so use that.
            rtn[i] = np.mean(claim[i])
    return rtn
def CRH(data, m, n, tol=1e-3, max_itr=99):
    """Run CRH truth discovery on continuous claims.

    Alternates between estimating truths as weighted means of claims
    (``update_truth``) and re-estimating source weights from the resulting
    losses (``update_w``). Iteration stops when the relative change of the
    truth vector is no longer above ``tol``, or after ``max_itr`` iterations.

    Parameters
    ----------
    data :
        Raw input, converted by ``bsf.extract(data, m, n)`` into
        ``(index, claim, count)``. The exact format is defined in
        basic_functions, which is not shown here.
    m : int
        Number of sources.
    n : int
        Number of objects.
    tol : float
        Convergence threshold on ``||truth - truth_old|| / ||truth_old||``.
    max_itr : int
        Maximum number of iterations.

    Returns
    -------
    list
        ``[truth, w_vec]``: the final truth estimates (length ``n``) and the
        source weights from the last ``update_w`` call (length ``m``).
    """
    index, claim, _count = bsf.extract(data, m, n)
    w_vec = np.ones(m)
    truth = np.zeros(n)
    err = 99
    itr = 0
    while err > tol and itr < max_itr:
        itr += 1
        truth_old = np.copy(truth)
        truth = update_truth(claim, index, w_vec, m, n)
        w_vec = update_w(claim, index, truth, m, n)

        change = la.norm(truth - truth_old)
        scale = la.norm(truth_old)
        if scale > 0:
            err = change / scale
        else:
            # truth_old is the zero vector (always the case on the first pass),
            # so the relative change is undefined. Keep going if the truth moved
            # and stop if it did not. This matches the inf / nan outcome of the
            # plain division, without the divide-by-zero RuntimeWarning.
            err = np.inf if change > 0 else 0.0
    return [truth, w_vec]
def CRH_discret(data, m, n, tol=1e-3, max_itr=99):
    """Run CRH truth discovery, then report a discrete answer per object.

    The iteration is identical to ``CRH``: alternate ``update_truth`` and
    ``update_w`` until the relative change of the truth vector is no longer
    above ``tol`` or ``max_itr`` iterations have run. Afterwards each object's
    truth is replaced by the claim of its highest-weighted source, so the
    answer is always one of the observed claims.

    Parameters
    ----------
    data :
        Raw input, converted by ``bsf.extract(data, m, n)`` into
        ``(index, claim, count)``. The exact format is defined in
        basic_functions, which is not shown here.
    m : int
        Number of sources.
    n : int
        Number of objects.
    tol : float
        Convergence threshold on ``||truth - truth_old|| / ||truth_old||``.
    max_itr : int
        Maximum number of iterations.

    Returns
    -------
    list
        ``[truth, w_vec]``: the selected claim per object (length ``n``) and
        the source weights from the last ``update_w`` call (length ``m``).
    """
    index, claim, _count = bsf.extract(data, m, n)
    w_vec = np.ones(m)
    truth = np.zeros(n)
    err = 99
    itr = 0
    while err > tol and itr < max_itr:
        itr += 1
        truth_old = np.copy(truth)
        truth = update_truth(claim, index, w_vec, m, n)
        w_vec = update_w(claim, index, truth, m, n)

        change = la.norm(truth - truth_old)
        scale = la.norm(truth_old)
        if scale > 0:
            err = change / scale
        else:
            # truth_old is the zero vector (always the case on the first pass),
            # so the relative change is undefined. Keep going if the truth moved
            # and stop if it did not. This matches the inf / nan outcome of the
            # plain division, without the divide-by-zero RuntimeWarning.
            err = np.inf if change > 0 else 0.0

    # Discrete answer: for each object, take the claim made by its
    # highest-weighted source (argmax picks the first such source on ties).
    truth = np.zeros(n)
    for i in range(n):
        truth[i] = claim[i][w_vec[index[i]].argmax()]
    return [truth, w_vec]