Skip to content

Commit

Permalink
Ported demo scripts.
Browse files Browse the repository at this point in the history
  • Loading branch information
nigma committed Jul 11, 2012
1 parent 2787e35 commit 22c56ed
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 85 deletions.
45 changes: 22 additions & 23 deletions demo/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,30 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time, gc, sys, csv, warnings
from __future__ import print_function

import time, gc, sys
import pywt
import numpy
import pylab

#sys.stderr = sys.stdout
#gc.set_debug(gc.DEBUG_LEAK)

if sys.platform == 'win32':
clock = time.clock
else:
clock = time.time

sizes = [20, 50, 100, 120, 150, 200, 250, 300, 400, 500, 600, 750,
1000, 2000, 3000, 4000, 5000, 6000, 7500,
10000, 15000, 20000, 25000, 30000, 40000, 50000, 75000,
100000, 150000, 200000, 250000, 300000, 400000, 500000,
600000, 750000, 1000000, 2000000, 5000000][:-4]
sizes = [
20, 50, 100, 120, 150, 200, 250, 300, 400, 500, 600, 750,
1000, 2000, 3000, 4000, 5000, 6000, 7500,
10000, 15000, 20000, 25000, 30000, 40000, 50000, 75000,
100000, 150000, 200000, 250000, 300000, 400000, 500000,
600000, 750000, 1000000
]

wavelet_names = ['db1', 'db2', 'db3', 'db4', 'db5', 'db6', 'db7',
'db8', 'db9', 'db10', 'sym10', 'coif1', 'coif2',
'coif3', 'coif4', 'coif5']
wavelet_names = [
'db1', 'db2', 'db3', 'db4', 'db5', 'db6', 'db7', 'db8', 'db9', 'db10',
'sym10', 'coif1', 'coif2', 'coif3', 'coif4', 'coif5'
]

dtype = numpy.float64

Expand All @@ -35,36 +37,33 @@
repeat = 5

for j, size in enumerate(sizes):
#if size > 500000:
# warnings.warn("Warning, too big data size may cause page swapping.")

data = numpy.ones((size,), dtype)

print ("%d/%d" % (j+1, len(sizes))).rjust(6), str(size).rjust(9),
print("{0:>2}/{1:<3}{0:>9}".format(j+1, len(sizes), size), end="")

for i, w in enumerate(wavelets):
min_t1, min_t2 = 9999., 9999.
for _ in range(repeat):
t1 = clock()
(a,d) = pywt.dwt(data, w, mode)
(a, d) = pywt.dwt(data, w, mode)
t1 = clock() - t1
min_t1 = min(t1, min_t1)

t2 = clock()
a0 = pywt.idwt(a, d, w, mode)
t2 = clock() - t2
min_t2 = min(t2, min_t2)

times_dwt[i].append(min_t1)
times_idwt[i].append(min_t2)
print '.',
print
print(".", end="")
print()
gc.collect()


for j, (times,name) in enumerate([(times_dwt, 'dwt'), (times_idwt, 'idwt')]):
for j, (times, name) in enumerate([(times_dwt, 'dwt'), (times_idwt, 'idwt')]):
pylab.figure(j)
pylab.title(name)

for i, n in enumerate(wavelet_names):
pylab.loglog(sizes, times[i], label=n)

Expand Down
16 changes: 8 additions & 8 deletions demo/dwt_signal_decomposition.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function

import pywt
import time
import numpy
import pylab

data1 = pylab.array(range(1,400) + range(398, 600) + range(601, 1024))
data1 = pylab.array(list(range(1,400)) + list(range(398, 600)) + list(range(601, 1024)))

x = pylab.arange(612-80, 20, -0.5)/250.
data2 = pylab.sin(40*pylab.log(x)) * pylab.sign((pylab.log(x)))

from sample_data import ecg as data3

mode = pywt.MODES.sp1
def plot(data, w, title):
print title
print(title)

w = pywt.Wavelet(w)
a = data
ca = []
cd = []
for i in range(5):
(a, d) = pywt.dwt(a, w, mode)
(a, d) = pywt.dwt(a, w, "sp1")
ca.append(a)
cd.append(d)

Expand All @@ -42,7 +43,6 @@ def plot(data, w, title):
pylab.xlim(0, len(data)-1)

for i, y in enumerate(rec_a):
#print len(data), len(x), len(data) / (2**(i+1))
ax = pylab.subplot(len(rec_a)+1, 2, 3+i*2)
ax.plot(y, 'r')
pylab.xlim(0, len(y)-1)
Expand All @@ -56,7 +56,7 @@ def plot(data, w, title):
pylab.ylabel("D%d" % (i+1))


print "Signal decomposition (S = An + Dn + Dn-1 + ... + D1)"
print("Signal decomposition (S = An + Dn + Dn-1 + ... + D1)")
plot(data1, 'coif5', "DWT: Signal irregularity")
plot(data2, 'sym5', "DWT: Frequency and phase change - Symmlets5")
plot(data3, 'sym5', "DWT: Ecg sample - Symmlets5")
Expand Down
15 changes: 6 additions & 9 deletions demo/dwt_swt_show_coeffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,25 @@
from __future__ import absolute_import

import pywt
import time
import pylab

#r = pywfdb.Record('d:/mitdb/101')
#data = r.read(0, 5050, 1024)

data1 = pylab.array(range(1,400) + range(398, 600) + range(601, 1024))/1024.
data1 = pylab.array(list(range(1,400)) + list(range(398, 600)) + list(range(601, 1024)))/1024.
data2 = pylab.arange(612-80, 20, -0.5)/250.
data2 = pylab.sin(40*pylab.log(data2)) * pylab.sign((pylab.log(data2)))
from .sample_data import ecg as data3

mode = pywt.MODES.sp1
from sample_data import ecg as data3

DWT = 1

def plot(data, w, title):
w = pywt.Wavelet(w)
a = data
ca = []
cd = []

if DWT:
for i in range(5):
(a, d) = pywt.dwt(a, w, mode)
(a, d) = pywt.dwt(a, w, "sp1")
ca.append(a)
cd.append(d)
else:
Expand All @@ -39,7 +37,6 @@ def plot(data, w, title):
pylab.xlim(0, len(data)-1)

for i, x in enumerate(ca):
#print len(data), len(x), len(data) / (2**(i+1))
lims = -(len(data) / (2.**(i+1)) - len(x)) / 2.
ax = pylab.subplot(len(ca)+1, 2, 3+i*2)
ax.plot(x, 'r')
Expand Down
6 changes: 3 additions & 3 deletions demo/image_blender.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
(details only)
"""

from __future__ import print_function
import sys, optparse

import Image # http://effbot.org/downloads/#PIL
Expand Down Expand Up @@ -168,10 +169,9 @@ def main():
im = blend_images(base, texture, options.wavelet, options.level, options.mode, options.base_gain, options.texture_gain)

if options.timeit:
print "%.3fs" % (clock() - t)
print("{0:.3f}s".format(clock() - t))

im.save(options.output)

if __name__ == '__main__':
main()

12 changes: 7 additions & 5 deletions demo/plot_wavelets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

# Plot scaling and wavelet functions for db, sym, coif, bior and rbio families

from __future__ import print_function

import pywt
import pylab
import itertools
Expand All @@ -16,17 +18,17 @@
colors = itertools.cycle('bgrcmyk')

wnames = pywt.wavelist(family)
print wnames
print(wnames)
i = iter(wnames)
for col in range(cols):
for row in range(rows):
try:
wavelet = pywt.Wavelet(i.next())
wavelet = pywt.Wavelet(next(i))
except StopIteration:
break
phi, psi, x = wavelet.wavefun(iterations)

color = colors.next()
color = next(colors)
ax = pylab.subplot(rows, 2*cols, 1 + 2*(col + row*cols))
pylab.title(wavelet.name + " phi")
pylab.plot(x, phi, color)
Expand All @@ -48,13 +50,13 @@
for col in range(cols):
for row in range(rows):
try:
wavelet = pywt.Wavelet(i.next())
wavelet = pywt.Wavelet(next(i))
except StopIteration:
break
phi, psi, phi_r, psi_r, x = wavelet.wavefun(iterations)
row *= 2

color = colors.next()
color = next(colors)
ax = pylab.subplot(2*rows, 2*cols, 1 + 2*(col + row*cols))
pylab.title(wavelet.name + " phi")
pylab.plot(x, phi, color)
Expand Down
31 changes: 18 additions & 13 deletions demo/user_filter_banks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function

import pywt

class FilterBank(object):
Expand All @@ -16,31 +18,34 @@ def __init__(self):
data = [1,2,3,4,5,6]

############################################################################
print "Case 1 (custom filter bank - Haar wavelet)"
print("Case 1 (custom filter bank - Haar wavelet)")

myBank = FilterBank()

# pass the user supplied filter bank as argument
myWavelet = pywt.Wavelet(name="UserSuppliedWavelet", filter_bank=myBank)
#print myWavelet.get_filters_coeffs()

print "data:", data
print("data:", data)

a, d = pywt.dwt(data, myWavelet)
print "a:", a
print "d:", d
print "rec:", pywt.idwt(a, d, myWavelet)
print("a:", a)
print("d:", d)
print("rec:", pywt.idwt(a, d, myWavelet))

############################################################################
print "-" * 75
print "Case 2 (Wavelet object as filter bank - db2 wavelet)"
print("-" * 75)
print("Case 2 (Wavelet object as filter bank - db2 wavelet)")

# builtin wavelets can also be treated as filter banks with theirs
# filter_bank attribute

builtinWavelet = pywt.Wavelet('db2')
builtinWavelet = pywt.Wavelet("db2")
myWavelet = pywt.Wavelet(name="UserSuppliedWavelet", filter_bank=builtinWavelet)

print "data:", data
print("data:", data)

a, d = pywt.dwt(data, myWavelet)
print "a:", a
print "d:", d
print "rec:", pywt.idwt(a, d, myWavelet)

print("a:", a)
print("d:", d)
print("rec:", pywt.idwt(a, d, myWavelet))
27 changes: 14 additions & 13 deletions demo/wavedec.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,31 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function

import pywt

data = range(16)
wavelet = 'db4'
wavelet = "db4"
level = 2
mode = 'cpd'
mode = "cpd"

print "original data:"
print data
print
print("original data:")
print(data)

# dec = [cA(n-1) cD(n-1) cD(n-2) ... cD(2) cD(1)]
dec = pywt.wavedec(data, wavelet, mode, level)

print "decomposition:"
print("decomposition:")

print "cA%d:" % (len(dec)-1)
print ' '.join([("%.3f" % val) for val in dec[0]])
print("cA{0}".format(len(dec) - 1))
print(" ".join("{0:.3f}".format(val) for val in dec[0]))

for i,d in enumerate(dec[1:]):
print "cD%d:" % (len(dec)-1-i)
print ' '.join([("%.3f" % val) for val in d])
print("cD{0}:".format(len(dec) - 1 - i))
print(" ".join("{0:.3f}".format(val) for val in d))

print
print "reconstruction:"
print()
print("reconstruction:")

print ' '.join([("%.3f" % val) for val in pywt.waverec(dec, wavelet, mode)])
print(" ".join("{0:.3f}".format(val) for val in pywt.waverec(dec, wavelet, mode)))
8 changes: 5 additions & 3 deletions demo/waveinfo.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import print_function

import sys, os.path
import pywt
import pylab
Expand All @@ -11,13 +13,13 @@
except IndexError:
level = 10
except ValueError:
print "Unknown wavelet"
print("Unknown wavelet")
raise SystemExit
except IndexError:
print usage
print(usage)
raise SystemExit

print wavelet
print(wavelet)

data = wavelet.wavefun(level)
funcs, x = data[:-1], data[-1]
Expand Down
4 changes: 2 additions & 2 deletions demo/wp_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pylab
import numpy
import Image # PIL
import pywt

from pywt import WaveletPacket2D

im = Image.open("data/aero.png").convert('L')
Expand Down Expand Up @@ -42,7 +42,7 @@
for row in wp2.get_level(2, 'freq'):
for node in row:
pylab.subplot(len(row),len(row),i)
pylab.title("%s=(%s row, %s col)" % ((node.path,)+ wp2.expand_2d_path(node.path)))
pylab.title("{0}=({1} row, {2} col)".format(node.path, *wp2.expand_2d_path(node.path)))
pylab.imshow(mod(node.data), origin='image', interpolation="nearest", cmap=pylab.cm.gray)
i += 1

Expand Down
Loading

0 comments on commit 22c56ed

Please sign in to comment.