In [17]:
# from ipynb.fs.full.rejection_sampling import *
import math
import scipy.optimize as opt
import numpy as np
import matplotlib.pyplot as plt
import random
# q function is unnormalized density of target distribution

def q(x):
    if 0<x and x<1:
        return x**2
    if x>=1:
        return 1/(x**5)
    else:
        raise ValueError("x should be nonnegative")
        
def gamma(n):
    if n>0 and isinstance(n,int):
        return math.factorial(n-1)
    else:
        raise NotImplementedError
        
def p_exponential(x, lamb):
    if x < 0:
        return 0
    return lamb*math.e**(-lamb*x)

def p_gamma(x, k, theta):
    return x**(k-1) * math.e**(-x/theta) * (1/(gamma(k) * theta**k))

def random_exponential(n, lamb):
    beta = 1/lamb
    return np.random.exponential(beta, size=n)

def random_gamma(n, k, theta):
    return np.random.gamma(k,theta,n)

def importance_sampling(g_pdf, g_dist, S, q):
    draws = g_dist(S)
    weights = [q(theta)/g_pdf(theta) for theta in draws]
    values = [theta*weight for theta,weight in zip(draws,weights)]
    return sum(values)/sum(weights)

def importance_resampling(g_pdf, g_dist, S, q, k):
    draws = list(g_dist(S))
    weights = [q(theta)/g_pdf(theta) for theta in draws]
    sir,sir_weights = [],[]
    for i in range(k+1):
        wSum = sum(weights)
        probability_distribution = [w/wSum for w in weights]
        picked_theta = random.choices(draws, probability_distribution)[0]
        index = draws.index(picked_theta)
        sir_weights.append(weights[index])
        sir.append(picked_theta)
        weights[index] = 0
    values = [theta*weight for theta,weight in zip(sir,sir_weights)]
    return sum(values)/sum(sir_weights)

def rejection_sampling(n, M, intv_start, intv_end, q, g):
    '''
    n = number of draws, M = upperbound, q= unnormalized target distribution, g=approximate distribution
    '''
    accepts = []
    for i in range(n+1):
        t = random_exponential(1,1)[0]
        u = np.random.uniform(0,1,1)[0]
        importance_ratio = q(t)/(M*g(t))
        if importance_ratio > 1:
            print(importance_ratio)
            raise ValueError("importance ratio exceeds 1")
        
        if u<importance_ratio:
            accepts.append(t)
    return accepts


if __name__=="__main__":
    # Rejection sampling with g = exp(1), on interval [0,8]
    trials = 100000
    a,b = 0,15
    M = (math.e**(b))/(1/(b**5))
    print(M)
    accepts = rejection_sampling(trials, M, a, b, q, lambda x: p_exponential(x,1))
#     accepts_2 = rejection_sampling(trials, M, a**2, b**2, q, lambda x: p_exponential(x**2,1))
    ev = np.mean(accepts)
#     ev_2 = np.mean(accepts_2)
#     var = ev_2 - ev**2
    print(f"EV Rejection sampling, exp(1) = {ev:.5f}")
#     print(f"Var Rejection sampling, exp(1) = {var:.5f}")
    
    # Importance sampling with g = exp(1)
    g_pdf = lambda x:p_exponential(x,1)
    g_dist = lambda x:random_exponential(x,1)
    ev = importance_sampling(g_pdf, g_dist, trials, q)
    print(f"EV Importance sampling, exp(1) = {ev:.5f}")
    
#     # Importance resampling, g = exp(1)
#     g_pdf = lambda x:p_exponential(x,1)
#     g_dist = lambda x:random_exponential(x,1)
#     ev = importance_resampling(g_pdf, g_dist, trials, q, math.ceil(trials/2))
#     print(f"EV Importance resampling, exp(1) = {ev:.5f}")
    
    # Rejection sampling with g = gamma(2,1), on interval [0,5]
    trials = 100000
#     a,b = 0,5
#     M = 1/p_gamma(2,2,1) #value of mean
#     accepts = rejection_sampling(trials, M, a, b, q, lambda x: p_gamma(x,2,1))
#     ev = np.mean(accepts)
#     print(f"EV Rejection sampling, gamma(2,1) = {ev:.5f}")
    
    # Importance sampling with g = gamma(2,1)
    g_pdf = lambda x:p_gamma(x,2,.5)
    g_dist = lambda x:random_gamma(x,2,.5)
    ev = importance_sampling(g_pdf, g_dist, trials, q)
    print(f"EV Importance sampling, gamma(2,1) = {ev:.5f}")
    
    # Importance resampling, g = gamma(2,1)
    g_pdf = lambda x:p_gamma(x,1,1)
    g_dist = lambda x:random_gamma(x,1,1)
    ev = importance_resampling(g_pdf, g_dist, trials, q, math.ceil(trials/2))
    print(f"EV Importance resampling, gamma(2,1) = {ev:.5f}")

    

2482410067221.007
EV Rejection sampling, exp(1) = nan
EV Importance sampling, exp(1) = 0.99675
EV Importance sampling, gamma(2,1) = 0.99986


KeyboardInterrupt: 