In [1]:
%matplotlib notebook

from __future__ import division
import math
import sys
import os
import csv
import sqlite3
import pandas
import numpy as np
from scipy.interpolate import interp1d
from scipy.optimize import minimize


from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib import cm

sys.path.append('.')
import scripts2d.utils as u
from wde.estimator import WaveletDensityEstimator
from wde.simple_estimator import SimpleWaveletDensityEstimator
from wde.common import np_mult
from wde.thresholding import soft_threshold, hard_threshold, block_threshold


In [2]:
def contour_plot_it(dists, data, title='Contour', fname=None):
    """Draw contour lines of one or more densities over a scatter of ``data``.

    Each density is evaluated with ``dist.pdf((XX, YY))`` on a 75x75 mesh of
    the unit square, and every density is contoured against the same set of
    levels so the curves are directly comparable.

    Parameters
    ----------
    dists : object or tuple/list of objects
        Each object must expose ``pdf((XX, YY))`` returning an array that
        supports ``.min()`` and ``.max()``. A single object is wrapped into a
        one-element tuple.
    data : array
        Indexed as ``data[:,0]`` and ``data[:,1]`` for the scatter overlay.
    title : str
        Figure title.
    fname : str or None
        When given, the figure is saved as ``'data/%s' % fname`` before being
        shown.
    """
    # Function-scope imports: these modules are not in the notebook's import cell.
    import copy
    import itertools

    plt.figure(figsize=(4, 4), dpi=144)
    X = np.linspace(0.0,1.0, num=75)
    Y = np.linspace(0.0,1.0, num=75)
    XX, YY = np.meshgrid(X, Y)
    # Accept a list as well as a tuple; anything else is a single distribution.
    if not isinstance(dists, (tuple, list)):
        dists = (dists,)
    minz, maxz = float('inf'), float('-inf')
    Zs = []
    for dist in dists:
        Z = dist.pdf((XX, YY))
        Zs.append(Z)
        minz = min(minz, Z.min())
        maxz = max(maxz, Z.max())
    levels = np.linspace(minz, maxz, num=10)
    # Work on a private copy: the object returned by get_cmap may be the shared
    # registered colormap, and set_under() below would otherwise change 'BuGn'
    # for every later plot (plot_it uses the same colormap).
    cmap = copy.copy(cm.get_cmap('BuGn'))
    if minz == 0:
        # Start the levels 1% of the range above zero and give values below
        # the lowest level their own colour.
        levels = np.linspace(minz + (maxz-minz)/100, maxz, num=10)
        cmap.set_under("magenta")
    # Cycle so that more than four distributions do not exhaust the styles.
    linestyles = itertools.cycle(['solid','dashed', 'dashdot', 'dotted'])
    alphas = itertools.cycle([0.4,1.0,0.2,0.2])
    for Z in Zs:
        linestyle = next(linestyles)
        cs = plt.contour(XX, YY, Z, alpha=next(alphas), linestyles=linestyle, levels=levels, extend='min', cmap=cmap)
        # Only the dashed curves (the second distribution) get inline labels.
        if linestyle == 'dashed':
            plt.clabel(cs, inline=1, fontsize=10)
    plt.scatter(data[:,0], data[:,1], s=1, alpha=0.4)
    plt.title(title)
    if fname is not None:
        plt.savefig('data/%s' % fname, pad_inches=0.0, orientation='portrait', frameon=False)
    plt.show()
    
def plot_it(dist, fname=None):
    """Show ``dist.pdf`` over the unit square as a 3-D surface.

    ``dist`` must expose ``pdf((XX, YY))``; it is evaluated on a 75x75 mesh
    of [0, 1] x [0, 1]. When ``fname`` is given the figure is also written
    to ``'data/%s' % fname`` before being displayed.
    """
    figure = plt.figure()
    axes3d = figure.gca(projection='3d')
    # Both axes use the same 75 sample points, so one vector serves for both.
    ticks = np.linspace(0.0, 1.0, num=75)
    grid_x, grid_y = np.meshgrid(ticks, ticks)
    density = dist.pdf((grid_x, grid_y))
    axes3d.plot_surface(
        grid_x, grid_y, density,
        edgecolors='k', linewidth=0.5, cmap=cm.get_cmap('BuGn')
    )
    if fname is not None:
        plt.savefig('data/%s' % fname, pad_inches=0.0, orientation='portrait', frameon=False)
    plt.show()


In [22]:
# Last estimator fitted by p11(), kept at module level for interactive inspection.
wde_glob = None
def p11(n, wv, **kwargs):
    """Fit the simple wavelet estimator twice on one 'mix2' sample and print both ISEs.

    Draws ``n`` points from ``u.dist_from_code('mix2')``, fits
    ``SimpleWaveletDensityEstimator(wv, **kwargs)`` and prints its ISE against
    the true pdf on ``u.mise_mesh()``. The resolution levels are then changed
    and the fit is repeated on the same sample:

    * if ``j0 <= j1``: ``j0`` becomes ``j1 + 1``;
    * otherwise: ``j1`` becomes ``j0 - 1`` and ``j0`` becomes that new ``j1 - 1``.

    Parameters
    ----------
    n : int
        Number of points passed to ``dist.rvs``.
    wv : str
        Wavelet name forwarded to the estimator (the call below uses ``'db2'``).
    **kwargs
        Forwarded to ``SimpleWaveletDensityEstimator``; must contain ``j0``
        and ``j1`` (a missing key raises ``KeyError`` when the ISE is printed).

    Side effects
    ------------
    Prints two ISE lines and stores the second fitted estimator in the
    module-level ``wde_glob``. Returns ``None``.
    """
    global wde_glob
    dist = u.dist_from_code('mix2')
    data = dist.rvs(n)
    XX, YY = u.mise_mesh()

    def fit_and_report(params):
        # Fit a fresh estimator on the shared sample and print its ISE.
        est = SimpleWaveletDensityEstimator(wv, **params)
        est.fit(data)
        # The true pdf is re-evaluated for each fit, as before; calc_ise is
        # not visible here, so it is not assumed to leave its input untouched.
        Z = dist.pdf((XX, YY))
        ise = u.calc_ise(est.pdf, Z)
        label = 'ISE j0=%d, j1=%d:' % (params['j0'], params['j1'])
        # Single-argument print: same text as `print label, ise`.
        print('%s %s' % (label, ise))
        return est

    fit_and_report(kwargs)
    if kwargs['j0'] <= kwargs['j1']:
        kwargs['j0'] = kwargs['j1'] + 1
    else:
        kwargs['j1'] = kwargs['j0'] - 1
        kwargs['j0'] = kwargs['j1'] - 1
    # Previously this assignment sat after a bare `return` and never ran,
    # so the global stayed None.
    wde_glob = fit_and_report(kwargs)
p11(1024, 'db2', j0=3, j1=0)

ISE j0=3, j1=0: 0.369463442573
ISE j0=1, j1=2: 0.369342915195


In [4]:
import scipy.stats as stats
def xx():
    """Scatter 1000 transformed Dirichlet(2, 3, 3) draws on the unit square.

    Only the first two components of each draw are used; the first one is
    raised to the power ``1/4`` before plotting. The shape of the plotted
    array is printed.
    """
    samples = stats.dirichlet(alpha=[2,3,3]).rvs(1000)
    warped_first = np.power(samples[:, 0], 1/4)
    second = samples[:, 1]
    points = np.stack([warped_first, second], axis=1)
    print(points.shape)
    plt.figure()
    plt.scatter(points[:, 0], points[:, 1], alpha=0.2)
    plt.xlim([0,1])
    plt.ylim([0,1])
    plt.show()
#xx()

In [5]:
def pp():
    """Numerically check that the Dirichlet(2, 3, 5) pdf sums to about 1 on a grid.

    Builds a 255x255 mesh of (x, y) with step 1/255, sets ``z = 1 - x - y``,
    keeps the points where ``0 < z <= 1``, evaluates the pdf there, scatters
    the values back into a full-size array (zero elsewhere) and prints the
    grid sum divided by ``255*255``. Intermediate shapes, the minima of the
    kept coordinates and the number of kept points are printed along the way.
    """
    X = np.linspace(1/255,1.0, num=255)
    Y = np.linspace(1/255,1.0, num=255)
    XX, YY = np.meshgrid(X, Y) # X,Y
    ZZ = 1 - XX - YY
    print(ZZ.shape)
    # Mask of mesh points where the third component is in (0, 1].
    zpos = (ZZ > 0) & (ZZ <= 1)
    print(zpos.shape)
    ZZm = ZZ[zpos]
    XXm = XX[zpos]
    YYm = YY[zpos]
    dist = stats.dirichlet(alpha=[2,3,5])
    print(ZZm.shape)
    print('%s %s %s' % (np.min(XXm), np.min(YYm), np.min(ZZm)))
    vals = dist.pdf((XXm, YYm, ZZm))
    resp = np.zeros(ZZ.shape)
    # Boolean-mask assignment fills the True positions in the same C order in
    # which ZZ[zpos] extracted them, replacing the element-by-element loop.
    resp[zpos] = vals
    print('%s %s' % (vals.size, vals.shape))
    print(resp.shape)
    print(resp.sum()/(255*255))
#pp()


In [6]:
def qq():
    """Print the shape of the 'mult' density evaluated on the MISE mesh.

    Uses ``u.dist_from_code('mult')`` for the distribution and
    ``u.mise_mesh()`` for the evaluation grid.
    """
    mult_dist = u.dist_from_code('mult')
    mesh_x, mesh_y = u.mise_mesh()
    density = mult_dist.pdf((mesh_x, mesh_y))
    print(density.shape)
#qq()