Skip to content

Commit

Permalink
added the heatmap visualization
Browse files Browse the repository at this point in the history
  • Loading branch information
cheyneh committed Nov 23, 2015
1 parent fad0fc6 commit ee061bc
Show file tree
Hide file tree
Showing 7 changed files with 932 additions and 41 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -6,3 +6,4 @@ venv
avoidencegrid.py
OLD_permtools.py
Icon
plot.py
643 changes: 643 additions & 0 deletions .ipynb_checkpoints/Overview-checkpoint.ipynb

Large diffs are not rendered by default.

114 changes: 93 additions & 21 deletions Overview.ipynb

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions permpy/__init__.py
Expand Up @@ -8,5 +8,17 @@
from .InsertionEncoding import *
import permpy.RestrictedContainer

try:
import matplotlib as mpl
mpl.rc('axes', fc='E5E5E5', ec='white', lw='1',
grid='True', axisbelow='True')
mpl.rc('grid', c='white', ls='-')
mpl.rc('figure', fc='white')
mpl_imported = True
except ImportError:
print('Install matplotlib for extra plotting functionality')
pass


Perm = Permutation
Av = AvClass
8 changes: 7 additions & 1 deletion permpy/permclass.py
Expand Up @@ -3,7 +3,7 @@
from math import factorial

import permpy.permutation
import permpy.permset
from permpy.permset import PermSet

class PermClass(list):

Expand Down Expand Up @@ -149,6 +149,12 @@ def plus_one_class(self):
D[l+1] = D[l+1].union(P.all_extensions())
return D

def heatmap(self, **kwargs):
permset = PermSet()
for item in self:
permset.update(item)
permset.heatmap(**kwargs)

def sum_closure(self,length=8, has_syms=False):
return PermClass.class_from_test(lambda P : ((len(P) < len(self) and P in self[len(P)]) or P.sum_decomposable()) and all([Q in self[len(Q)] for Q in P.chom_sum()]), l=length, has_all_syms=has_syms)

Expand Down
166 changes: 159 additions & 7 deletions permpy/permset.py
@@ -1,17 +1,40 @@
import random
import fractions
from functools import reduce

import permpy.permutation
from permpy.permutation import Permutation
import permpy.permclass
# import permpy.permclass

try:
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl_imported = True
except ImportError:
mpl_imported = False


class PermSet(set):
"""Represents a set of permutations, and allows statistics to be computed
across the set."""

def __repr__(self):
# if len(self) > 10:
return 'Set of {} permutations'.format(len(self))
# else:
# return set.__repr__(self)

def __add__(self, other):
"""Returns the union of the two permutation sets. Does not modify in
place.
Example
-------
>>> S = PermSet.all(3) + PermSet.all(4)
>>> len(S)
30
"""
result = PermSet()
result.update(self); result.update(other)
return result


@classmethod
def all(cls, length):
Expand All @@ -31,15 +54,144 @@ def all(cls, length):
return PermSet(Permutation.listall(length))

def get_random(self):
"""Returns a random element from the set."""
"""Returns a random element from the set.
Example
-------
>>> p = PermSet.all(4).get_random()
>>> p in PermSet.all(4) and len(p) == 4
True
"""

return random.sample(self, 1)[0]


def get_length(self, length=None):
"""Returns the subset of permutations which have the specified length.
Parameters
----------
length : int
lenght of permutations to be returned
Example
-------
>>> S = PermSet.all(4) + PermSet.all(3)
>>> S.get_length(3) == PermSet.all(3)
True
"""
return PermSet(p for p in self if len(p) == length)

def heatmap(self, only_length=None, ax=None, blur=False, gray=False, **kwargs):
"""Visalization of a set of permutations, which, for each length, shows
the relative frequency of each value in each position.
Paramters
---------
only_length : int or None
if given, restrict to the permutations of this length
"""
if not mpl_imported:
err = 'heatmap requires matplotlib to be imported'
raise NotImplementedError(err)
try:
import numpy as np
except ImportError as e:
err = 'heatmap function requires numpy'
raise e(err)
# first group permutations by length
total_size = len(self)
perms_by_length = {}
for perm in self:
n = len(perm)
if n in perms_by_length:
perms_by_length[n].add(perm)
else:
perms_by_length[n] = PermSet([perm])
# if given a length, ignore all other lengths
if only_length:
perms_by_length = {only_length: perms_by_length[only_length]}
lengths = list(perms_by_length.keys())
def lcm(l):
"""Returns the least common multiple of the list l."""
lcm = reduce(lambda x,y: x*y // fractions.gcd(x,y), l)
return lcm
grid_size = lcm(lengths)
grid = np.zeros((grid_size, grid_size))
def inflate(a, n):
"""Inflates a k x k array A by into a nk x nk array by inflating
each entry from A into a n x n matrix."""
ones = np.ones((n, n))
c = np.multiply.outer(a, ones)
c = np.concatenate(np.concatenate(c, axis=1), axis=1)
return c
for length, permset in perms_by_length.items():
small_grid = np.zeros((length, length))
for p in permset:
for idx, val in enumerate(p):
small_grid[length-val-1, idx] += 1
mul = grid_size // length
inflated = inflate(small_grid, mul)
num_perms = len(permset)
inflated /= inflated.max()
grid += inflated

if not ax:
ax = plt.gca()
if blur:
interpolation = 'bicubic'
else:
interpolation = 'nearest'
def get_cubehelix(gamma=1, start=1, rot=-1, hue=1, light=1, dark=0):
"""Get a cubehelix palette."""
cdict = mpl._cm.cubehelix(gamma, start, rot, hue)
cmap = mpl.colors.LinearSegmentedColormap("cubehelix", cdict)
x = np.linspace(light, dark, 256)
pal = cmap(x)
cmap = mpl.colors.ListedColormap(pal)
return cmap
if gray:
cmap = get_cubehelix(start=.5, rot=1, light=1, dark=.2, hue=0)
else:
cmap = get_cubehelix(start=.5, rot=-.5, light=1, dark=.2)

ax.imshow(grid, cmap=cmap, interpolation=interpolation)
ax.set_aspect('equal')
ax.set(**kwargs)
ax.axis('off')
return ax











def show_all(self):
"""The default representation doesn't print the entire set, function
allows this."""
"""The default representation doesn't print the entire set, this
function does."""
return set.__repr__(self)

def __add__(self, other):
result = PermSet()
result.update(self)
result.update(other)
return result

def minimal_elements(self):
"""Returns the elements of the set which are minimal with respect to
the permutation pattern order.
Examples
--------
"""



B = list(self)
B = sorted(B, key=len)
C = B[:]
Expand Down
29 changes: 17 additions & 12 deletions permpy/permutation.py
Expand Up @@ -11,18 +11,23 @@
# python 2/3 compatibility
from functools import reduce

mpl_imported = False
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', fc='E5E5E5', ec='white', lw='1',
grid='True', axisbelow='True')
mpl.rc('grid', c='white', ls='-')
mpl.rc('figure', fc='white')
mpl_imported = True
except ImportError:
print('Install matplotlib for extra plotting functionality')
pass
mpl_imported = False

# try:
# import matplotlib as mpl
# import matplotlib.pyplot as plt
# mpl.rc('axes', fc='E5E5E5', ec='white', lw='1',
# grid='True', axisbelow='True')
# mpl.rc('grid', c='white', ls='-')
# mpl.rc('figure', fc='white')
# mpl_imported = True
# except ImportError:
# print('Install matplotlib for extra plotting functionality')
# pass



Expand Down Expand Up @@ -142,15 +147,15 @@ def standardize(cls, L):
ordered.sort()
return [ordered.index(x) for x in L]

@staticmethod
@classmethod
def change_repr(cls, representation=None):
"""Toggles globally between cycle notation or one-line notation. Note
that internal representation is still one-line."""
L = ['oneline', 'cycles', 'both']
L = ['oneline', 'cycle', 'both']
if representation in L:
cls._REPR = representation
else:
k = int(input('1 for oneline, 2 for cycles, 3 for both\n '))
k = int(input('1 for oneline, 2 for cycle, 3 for both\n '))
k -= 1
cls._REPR = L[k]

Expand Down Expand Up @@ -271,7 +276,7 @@ def __repr__(self):
"""Tells python how to display a permutation object."""
if Permutation._REPR == 'oneline':
return self.oneline()
if Permutation._REPR == 'cycles':
if Permutation._REPR == 'cycle':
return self.cycles()
else:
return '\n'.join([self.oneline(), self.cycles()])
Expand Down

0 comments on commit ee061bc

Please sign in to comment.