Permalink
Browse files

added the heatmap visualization

  • Loading branch information...
cheyneh committed Nov 23, 2015
1 parent fad0fc6 commit ee061bcb8981f87f9efdb1040eae05b65c03ffc2
Showing with 932 additions and 41 deletions.
  1. +1 −0 .gitignore
  2. +643 −0 .ipynb_checkpoints/Overview-checkpoint.ipynb
  3. +93 −21 Overview.ipynb
  4. +12 −0 permpy/__init__.py
  5. +7 −1 permpy/permclass.py
  6. +159 −7 permpy/permset.py
  7. +17 −12 permpy/permutation.py
View
@@ -6,3 +6,4 @@ venv
avoidencegrid.py
OLD_permtools.py
Icon
plot.py

Large diffs are not rendered by default.

Oops, something went wrong.
View

Large diffs are not rendered by default.

Oops, something went wrong.
View
@@ -8,5 +8,17 @@
from .InsertionEncoding import *
import permpy.RestrictedContainer
try:
import matplotlib as mpl
mpl.rc('axes', fc='E5E5E5', ec='white', lw='1',
grid='True', axisbelow='True')
mpl.rc('grid', c='white', ls='-')
mpl.rc('figure', fc='white')
mpl_imported = True
except ImportError:
print('Install matplotlib for extra plotting functionality')
pass
Perm = Permutation
Av = AvClass
View
@@ -3,7 +3,7 @@
from math import factorial
import permpy.permutation
import permpy.permset
from permpy.permset import PermSet
class PermClass(list):
@@ -149,6 +149,12 @@ def plus_one_class(self):
D[l+1] = D[l+1].union(P.all_extensions())
return D
def heatmap(self, **kwargs):
permset = PermSet()
for item in self:
permset.update(item)
permset.heatmap(**kwargs)
def sum_closure(self,length=8, has_syms=False):
return PermClass.class_from_test(lambda P : ((len(P) < len(self) and P in self[len(P)]) or P.sum_decomposable()) and all([Q in self[len(Q)] for Q in P.chom_sum()]), l=length, has_all_syms=has_syms)
View
@@ -1,17 +1,40 @@
import random
import fractions
from functools import reduce
import permpy.permutation
from permpy.permutation import Permutation
import permpy.permclass
# import permpy.permclass
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl_imported = True
except ImportError:
mpl_imported = False
class PermSet(set):
"""Represents a set of permutations, and allows statistics to be computed
across the set."""
def __repr__(self):
# if len(self) > 10:
return 'Set of {} permutations'.format(len(self))
# else:
# return set.__repr__(self)
def __add__(self, other):
"""Returns the union of the two permutation sets. Does not modify in
place.
Example
-------
>>> S = PermSet.all(3) + PermSet.all(4)
>>> len(S)
30
"""
result = PermSet()
result.update(self); result.update(other)
return result
@classmethod
def all(cls, length):
@@ -31,15 +54,144 @@ def all(cls, length):
return PermSet(Permutation.listall(length))
def get_random(self):
"""Returns a random element from the set."""
"""Returns a random element from the set.
Example
-------
>>> p = PermSet.all(4).get_random()
>>> p in PermSet.all(4) and len(p) == 4
True
"""
return random.sample(self, 1)[0]
def get_length(self, length=None):
"""Returns the subset of permutations which have the specified length.
Parameters
----------
length : int
lenght of permutations to be returned
Example
-------
>>> S = PermSet.all(4) + PermSet.all(3)
>>> S.get_length(3) == PermSet.all(3)
True
"""
return PermSet(p for p in self if len(p) == length)
def heatmap(self, only_length=None, ax=None, blur=False, gray=False, **kwargs):
"""Visalization of a set of permutations, which, for each length, shows
the relative frequency of each value in each position.
Paramters
---------
only_length : int or None
if given, restrict to the permutations of this length
"""
if not mpl_imported:
err = 'heatmap requires matplotlib to be imported'
raise NotImplementedError(err)
try:
import numpy as np
except ImportError as e:
err = 'heatmap function requires numpy'
raise e(err)
# first group permutations by length
total_size = len(self)
perms_by_length = {}
for perm in self:
n = len(perm)
if n in perms_by_length:
perms_by_length[n].add(perm)
else:
perms_by_length[n] = PermSet([perm])
# if given a length, ignore all other lengths
if only_length:
perms_by_length = {only_length: perms_by_length[only_length]}
lengths = list(perms_by_length.keys())
def lcm(l):
"""Returns the least common multiple of the list l."""
lcm = reduce(lambda x,y: x*y // fractions.gcd(x,y), l)
return lcm
grid_size = lcm(lengths)
grid = np.zeros((grid_size, grid_size))
def inflate(a, n):
"""Inflates a k x k array A by into a nk x nk array by inflating
each entry from A into a n x n matrix."""
ones = np.ones((n, n))
c = np.multiply.outer(a, ones)
c = np.concatenate(np.concatenate(c, axis=1), axis=1)
return c
for length, permset in perms_by_length.items():
small_grid = np.zeros((length, length))
for p in permset:
for idx, val in enumerate(p):
small_grid[length-val-1, idx] += 1
mul = grid_size // length
inflated = inflate(small_grid, mul)
num_perms = len(permset)
inflated /= inflated.max()
grid += inflated
if not ax:
ax = plt.gca()
if blur:
interpolation = 'bicubic'
else:
interpolation = 'nearest'
def get_cubehelix(gamma=1, start=1, rot=-1, hue=1, light=1, dark=0):
"""Get a cubehelix palette."""
cdict = mpl._cm.cubehelix(gamma, start, rot, hue)
cmap = mpl.colors.LinearSegmentedColormap("cubehelix", cdict)
x = np.linspace(light, dark, 256)
pal = cmap(x)
cmap = mpl.colors.ListedColormap(pal)
return cmap
if gray:
cmap = get_cubehelix(start=.5, rot=1, light=1, dark=.2, hue=0)
else:
cmap = get_cubehelix(start=.5, rot=-.5, light=1, dark=.2)
ax.imshow(grid, cmap=cmap, interpolation=interpolation)
ax.set_aspect('equal')
ax.set(**kwargs)
ax.axis('off')
return ax
def show_all(self):
"""The default representation doesn't print the entire set, function
allows this."""
"""The default representation doesn't print the entire set, this
function does."""
return set.__repr__(self)
def __add__(self, other):
result = PermSet()
result.update(self)
result.update(other)
return result
def minimal_elements(self):
"""Returns the elements of the set which are minimal with respect to
the permutation pattern order.
Examples
--------
"""
B = list(self)
B = sorted(B, key=len)
C = B[:]
View
@@ -11,18 +11,23 @@
# python 2/3 compatibility
from functools import reduce
mpl_imported = False
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', fc='E5E5E5', ec='white', lw='1',
grid='True', axisbelow='True')
mpl.rc('grid', c='white', ls='-')
mpl.rc('figure', fc='white')
mpl_imported = True
except ImportError:
print('Install matplotlib for extra plotting functionality')
pass
mpl_imported = False
# try:
# import matplotlib as mpl
# import matplotlib.pyplot as plt
# mpl.rc('axes', fc='E5E5E5', ec='white', lw='1',
# grid='True', axisbelow='True')
# mpl.rc('grid', c='white', ls='-')
# mpl.rc('figure', fc='white')
# mpl_imported = True
# except ImportError:
# print('Install matplotlib for extra plotting functionality')
# pass
@@ -142,15 +147,15 @@ def standardize(cls, L):
ordered.sort()
return [ordered.index(x) for x in L]
@staticmethod
@classmethod
def change_repr(cls, representation=None):
"""Toggles globally between cycle notation or one-line notation. Note
that internal representation is still one-line."""
L = ['oneline', 'cycles', 'both']
L = ['oneline', 'cycle', 'both']
if representation in L:
cls._REPR = representation
else:
k = int(input('1 for oneline, 2 for cycles, 3 for both\n '))
k = int(input('1 for oneline, 2 for cycle, 3 for both\n '))
k -= 1
cls._REPR = L[k]
@@ -271,7 +276,7 @@ def __repr__(self):
"""Tells python how to display a permutation object."""
if Permutation._REPR == 'oneline':
return self.oneline()
if Permutation._REPR == 'cycles':
if Permutation._REPR == 'cycle':
return self.cycles()
else:
return '\n'.join([self.oneline(), self.cycles()])

0 comments on commit ee061bc

Please sign in to comment.