In [None]:
from __future__ import division, print_function
import re
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Input files, one per ring of the multi-ring plot; sunburst() draws them in
# list order, the first file on the innermost ring.
# NOTE(review): the second file name is spelled 'Rapollura' while its label
# reads 'Rhopalura', and the third is 'linnei' vs. 'linei' in the label --
# confirm the spelling against the files on disk.
gb_files = [
    './Intoshia_variabili.gb', 
    './Rapollura_litoralis.gb',
    './Intoshia_linnei.gb',
]

# Legend names, index-aligned with gb_files (draw_axis reads labels[level]).
labels = [
    'Intoshia variabili', 
    'Rhopalura litoralis',
    'Intoshia linei',
]

In [None]:
# Colour schemes, looked up by get_color().
#   first level  : feature type (feature[0]); unknown types use 'default'.
#   second level : feature name (feature[2]); an exact key match wins, then the
#                  first key that is a substring of the name, then 'default'.
# Keys of a scheme that the drawing code reads:
#   'bg'   - bar fill colour            (draw_feature)
#   'edge' - bar outline colour         (draw_feature)
#   'fg'   - colour of labels drawn inside the bar (label_feature)
#   'font' / 'font_out' - optional text properties for inside / outside labels;
#            when absent, get_color() falls back to the ones stored under
#            color_map['default']['default'].
# NOTE(review): 'out' is not read by any code visible in this notebook.
color_map = {
    'gene': {
        'cox': {'bg': '#00ff00', 'fg': 'black', 'edge': 'black', 'out': 'red'},
        'atp': {'bg': '#ccff33', 'fg': 'black', 'edge': 'black', 'out': 'black'}, 
        'nad': {'bg': '#00ffff', 'fg': 'black', 'edge': 'black', 'out': 'black'},
        'cob': {'bg': '#1aa3ff', 'fg': 'black', 'edge': 'black', 'out': 'black'},
        'ORF': {'bg': '#e6e6ff', 'fg': 'black', 'edge': 'black', 'out': 'black'},
        'orf': {'bg': '#e6e6ff', 'fg': 'black', 'edge': 'black', 'out': 'black'},
        'rpl': {'bg': '#ffff99', 'fg': 'black', 'edge': 'black', 'out': 'black'},
        'rps': {'bg': '#ccffff', 'fg': 'black', 'edge': 'black', 'out': 'black'},
        'default': {'bg': 'lightgray', 'fg': 'white', 'edge': 'black', 'out': 'black'},
    },
    'tRNA': {
        'default': {'bg': '#277c84', 'fg': 'black', 'edge': 'black', 'out': 'black'},
    },
    'rRNA': {
        'default': {'bg': '#8c8c8c', 'fg': 'white', 'edge': 'black', 'out': 'black'},
    },
    'default': {
        # Global fallback scheme; also the only place the default fonts live.
        'default': {'bg': 'darkgray', 'fg': 'white', 'edge': 'black', 'out': 'black',
                    'font': {
                        'size': 11,
                        'style': 'normal',
                        'weight': 'bold',
                        'va': 'center',
                        'ha': 'center',
                    },
                    'font_out': {
                        'size': 11,
                        'style': 'italic',
                        'weight': 'bold',
                        'va': 'center',
                        'ha': 'center',
                    }},
    }
}

# Text properties that do not depend on the feature being drawn.
settings = {
    # Tick labels ('1k', '2k', ...) written by draw_axis().
    'font_axis': {
        'size': 10,
        'style': 'normal',
        'weight': 'normal',
        'va': 'center',
        'ha': 'center'
    },
    # Legend text (draw_axis) and the ring numbers (draw_main_body).
    'font_legend': {
        'size': 12,
        'style': 'normal',
        'weight': 'bold',
        'va': 'center',
        'ha': 'center'
    }
}

In [None]:
def parse_header(header):
    """Extract the sequence length from a header line.

    The stripped header is cut on double spaces; the first piece that ends
    in ``bp`` has that suffix removed and is converted with ``int``.

    Returns the length as an ``int``, or ``None`` when no piece ends in
    ``bp``.  ``int`` raises ``ValueError`` if the matching piece is not
    numeric.
    """
    pieces = header.strip().split('  ')
    for piece in pieces:
        # Empty pieces (from runs of spaces) can never end in 'bp'.
        if piece.endswith('bp'):
            return int(piece[:-2])
    return None

def parse_feature(line1, line2):
    """Parse one two-line feature record.

    ``line1`` holds the feature type as its first space-separated token and
    the location as its last one, either ``start..end`` or
    ``complement(start..end)``.  ``line2`` holds a ``key="value"``
    qualifier; the text after the last ``=`` is returned with its first and
    last character (the quotes) removed.

    Returns ``(type, [start, end], name, complement)``.
    """
    tokens = line1.strip().split(' ')
    kind = tokens[0]

    # 'complement(1..5)' -> ['complement', '1..5)'];  '1..5' -> ['1..5']
    location = tokens[-1].split('(')
    on_complement = location[0] == 'complement'
    span = location[-1].split(')')[0]
    position = list(map(int, span.split('..')))

    quoted = line2.strip().split('=')[-1]
    return kind, position, quoted[1:-1], on_complement
        
def parse_features(file_name):
    """Read a feature-table file and return its label, length and features.

    Layout the code relies on: the first line is a header handed to
    ``parse_header``; the second line carries the label from column 11
    onward; feature records start on the fifth line and occupy two
    consecutive lines each (both handed to ``parse_feature``).  Reading
    stops at the first record-start line containing ``BASE COUNT``, or at
    the end of the file when no such line exists.

    Parameters
    ----------
    file_name : str
        Path of the file to read (it is also printed as a progress message).

    Returns
    -------
    tuple
        ``(label, length, features)`` where ``features`` is the list of
        ``parse_feature`` results in file order.

    Raises
    ------
    IndexError
        If the file has fewer than two lines.
    """
    print(file_name)
    # The context manager closes the handle even if reading fails.
    with open(file_name, 'r') as handle:
        lines = handle.readlines()

    length = parse_header(lines[0])
    label = lines[1][10:].strip()

    features = []
    # Step over the records two lines at a time; bounding the scan by the
    # file length lets a file without a 'BASE COUNT' line end cleanly.
    for i in range(4, len(lines) - 1, 2):
        if 'BASE COUNT' in lines[i]:
            break
        features.append(parse_feature(lines[i], lines[i + 1]))
    return label, length, features

def format_label(label):
    """Tidy a raw label for use in the legend.

    Surrounding whitespace is stripped, trailing ``.,-_:;`` characters are
    removed, and every ``'. '`` or ``', '`` becomes a line break.

    An empty label, or one made only of the trimmed characters, yields an
    empty string.
    """
    # rstrip() removes the same trailing characters as peeling them off one
    # by one, but is safe when nothing is left.
    label = label.strip().rstrip('.,-_:;')
    return label.replace('. ', '\n').replace(', ', '\n')

In [None]:
def polar_sigmoid(theta1, theta2, r1, r2, speed=50, n=100):
    """Return ``n`` points of an S-shaped curve in polar coordinates.

    The radius grows linearly from ``r1`` to ``r2``.  The angle is a
    logistic function of a linear ramp over ``[-(theta2 - theta1),
    theta2 - theta1]``, scaled into the ``theta1``..``theta2`` range;
    ``speed`` sets how sharp the transition is.

    Returns ``(theta, r)`` as two arrays of length ``n``.
    """
    span = theta2 - theta1
    ramp = np.linspace(-span, span, num=n, endpoint=True)
    angles = span / (1 + np.exp(-ramp * speed)) + theta1
    radii = np.linspace(r1, r2, num=n, endpoint=True)
    return angles, radii

In [None]:
def sep(s, thou=',', dec='.'):
    """Insert thousands separators into a numeric string.

    ``s`` is split into integer and decimal parts only when it contains
    exactly one ``.``; otherwise the whole string is treated as the integer
    part.  ``thou`` is inserted between groups of three digits counted from
    the end of the integer part, and ``dec`` joins the two parts when the
    decimal part is non-empty.
    """
    if s.count('.') == 1:
        whole, _, fraction = s.partition('.')
    else:
        whole, fraction = s, ''
    grouped = re.sub(r'\B(?=(?:\d{3})+$)', thou, whole)
    return grouped + dec + fraction if fraction != '' else grouped

In [None]:
def draw_main_body(ax, level, levels, length, max_length, hole, gap, thickness, opening, draw_complement):
    """Draw the backbone of one ring on the polar axes ``ax``.

    Parameters
    ----------
    ax : polar axes offering ``bar`` and ``text``.
    level : zero-based ring index; ring ``level`` starts at radius
        ``hole + level * (thickness + gap)``.
    levels : total number of rings; with more than one, the ring number
        (``level + 1``) is written at angle 0.
    length, max_length : this ring's sequence length and the length that
        maps onto the full usable circle.
    hole, gap, thickness : radial geometry of the rings.
    opening : angle (radians) left empty at the end of the circle.
    draw_complement : when true only a zero-height bar (a line) is drawn
        through the middle of the ring; otherwise a full-thickness white
        band is drawn.
    """
    inner_radius = hole + level * (thickness + gap)
    # Use the same angular scale as draw_feature and draw_axis,
    # (2*pi - opening) * position / max_length, so that a ring shorter than
    # the longest one ends exactly where its last base would be drawn.
    arc = (2 * np.pi - opening) * length / max_length

    if draw_complement:
        height, bottom = 0, inner_radius + thickness / 2
    else:
        height, bottom = thickness, inner_radius
    ax.bar([0],
           [height],
           [arc],
           [bottom],
           color='white',
           edgecolor='black', align='edge')

    if levels > 1:
        legend_font = settings['font_legend']
        ax.text(0, inner_radius + thickness / 2, '%d' % (level + 1),
                ha=legend_font['ha'],
                va=legend_font['va'],
                size=legend_font['size'], weight=legend_font['weight'],
                bbox=dict(facecolor='white', edgecolor='black', boxstyle='round,pad=0.1'))

In [None]:
def get_color(feature):
    """Pick the colour scheme and fonts for a feature from ``color_map``.

    ``feature[0]`` (the feature type) selects a group of schemes, falling
    back to ``color_map['default']``.  Inside the group, ``feature[2]`` (the
    feature name) is matched exactly first, then against the first key that
    occurs as a substring of the name, and finally the group's ``'default'``
    entry is used.

    Returns ``(scheme, font, font_out)``; fonts missing from the scheme are
    taken from ``color_map['default']['default']``.
    """
    kind, name = feature[0], feature[2]
    group = color_map[kind] if kind in color_map else color_map['default']

    if name in group:
        chosen = group[name]
    else:
        partial = [key for key in group if key in name]
        chosen = group[partial[0]] if partial else group['default']

    def pick(font_key):
        # Only consult the global default when the scheme has no own font.
        if font_key in chosen:
            return chosen[font_key]
        return color_map['default']['default'][font_key]

    return chosen, pick('font'), pick('font_out')

In [None]:
def draw_feature(ax, feature, level, length, max_length, hole, gap, thickness, opening, draw_complement):
    """Draw one feature as a bar on ring ``level`` of the polar axes ``ax``.

    ``feature`` is a ``(type, [start, end], name, complement)`` tuple.
    Positions are mapped to angles with
    ``(2*pi - opening) * position / max_length``.  With ``draw_complement``
    the ring is split in two half-thickness tracks: complement features use
    the inner half, the others the outer half.  Without it every feature
    fills the full ring thickness.  ``length`` is accepted but not used.

    Returns ``(rects, width, color, font, font_out)`` where ``rects`` is
    whatever ``ax.bar`` returned, ``width`` the angular width of the bar and
    the remaining items come from ``get_color``.
    """
    first, last = feature[1][0], feature[1][1]
    usable = 2 * np.pi - opening
    width = usable * (last - first) / max_length
    start = usable * first / max_length
    color, font, font_out = get_color(feature)

    ring_bottom = hole + level * (thickness + gap)
    if not draw_complement:
        height, bottom = thickness, ring_bottom
    elif feature[3]:
        height, bottom = thickness / 2, ring_bottom
    else:
        height, bottom = thickness / 2, ring_bottom + thickness/2

    rects = ax.bar([start], [height],
                   [width], [bottom],
                   color=color['bg'],
                   edgecolor=color['edge'], align='edge')
    return rects, width, color, font, font_out

In [None]:
def label_feature(ax, feature, level, thickness, gap, hole,
                  rect, width, color, font, font_out, stair_level,
                  draw_complement, width_threshold=[0.25, 0.1, 0.05]):
    """Write the name of a feature (``feature[2]``) next to its bar.

    The angular ``width`` of the bar decides the placement:

    * wider than ``width_threshold[0]``: inside the bar, rotated by
      ``-angle`` (or ``180 - angle`` on the lower half of the circle);
    * wider than ``width_threshold[1]``: inside the bar, rotated by
      ``90 - angle`` (``270 - angle`` past pi);
    * wider than ``width_threshold[2]``: as above with the font size
      reduced by 2;
    * narrower, and only when ``draw_complement`` is true: outside the
      ring, joined to it by a ``polar_sigmoid`` leader line and shifted in
      angle by an amount proportional to ``stair_level``.

    Returns the updated stair level: 0 after an inside label,
    ``stair_level + 1`` after an outside label, unchanged when nothing was
    drawn.
    """
    def radial(angle):
        if angle < np.pi:
            return 90 - np.rad2deg(angle)
        return 270 - np.rad2deg(angle)

    x = rect.get_x() + rect.get_width() / 2

    if width > width_threshold[0]:
        if x < np.pi / 2 or x > 3 * np.pi / 2:
            rotation = -np.rad2deg(x)
        else:
            rotation = 180 - np.rad2deg(x)
        size = font['size']
    elif width > width_threshold[1]:
        rotation, size = radial(x), font['size']
    elif width > width_threshold[2]:
        rotation, size = radial(x), font['size']-2
    else:
        if not draw_complement:
            return stair_level
        # Too narrow for an inside label: place it above the ring.  Each
        # consecutive outside label is pushed a little further in angle.
        shift = 0.025 * stair_level / (level + 1)
        anchor = x + shift
        ring_top = hole + level * (thickness + gap) + thickness
        label_base = ring_top + gap/4
        # Complement features occupy the inner half of the ring, so their
        # leader line starts at the middle of the ring instead of its top.
        leader_start = ring_top - thickness/2 if feature[3] else ring_top
        theta, r = polar_sigmoid(anchor - shift, anchor, leader_start, label_base)
        rotation = radial(anchor)
        ax.plot(theta, r, c='black', zorder=0, linestyle='-', linewidth=1)
        ax.text(anchor, label_base + gap / 4, feature[2], rotation=rotation,
                ha=font_out['ha'], va=font_out['va'], size=font_out['size'],
                style=font_out['style'], weight=font_out['weight'])
        return stair_level + 1

    y = rect.get_y() + rect.get_height() / 2
    ax.text(x, y, feature[2], rotation=rotation,
            color=color['fg'],
            ha=font['ha'], va=font['va'], size=size,
            style=font['style'], weight=font['weight'])
    return 0

In [None]:
def draw_features(ax, features, level, length, max_length, hole, gap, thickness, opening,
                  draw_complement, width_threshold=[0.25, 0.1, 0.05]):
    """Draw and label every feature of one ring.

    Each feature is drawn with ``draw_feature`` and labelled with
    ``label_feature``.  The stair level returned by ``label_feature`` is
    carried over to the next feature, starting from 0.  Returns ``None``.
    """
    stair = 0
    for item in features:
        drawn = draw_feature(ax, item, level, length, max_length,
                             hole, gap, thickness, opening, draw_complement)
        bars, span, scheme, inner_font, outer_font = drawn
        # Only the first rectangle of the bar container is labelled.
        stair = label_feature(ax, item, level, thickness, gap, hole,
                              bars[0], span, scheme, inner_font, outer_font, stair,
                              draw_complement, width_threshold)

In [None]:
def draw_axis(ax, hole, thickness, gap, lengths, max_length, opening,
              labels, levels, centre_legend, ticks_offset):
    """Configure the polar axes and draw the ticks and the legend.

    * A white bar of radius ``hole - ticks_offset`` is drawn over the full
      circle, angles are set to run clockwise from the top and the default
      axis decorations are switched off.
    * For every whole multiple of 1000 up to ``max_length`` a dotted spoke
      is drawn from the centre to ``hole + thickness/2`` and labelled
      ``'<n>k'`` at radius ``hole - 2 * ticks_offset``.
    * With several rings the legend lists ``'<ring number>: <label>'`` per
      line, in the top-left corner or (``centre_legend``) in the centre.
      With a single ring the centre shows the label and ``lengths[0]``
      formatted by ``sep`` followed by ``' bp'``.

    ``gap`` is accepted but not used.
    """
    ax.bar([0], [hole - ticks_offset], [2 * np.pi], [0],
           color='white', edgecolor='white', align='edge')
    ax.set_theta_direction(-1)
    ax.set_theta_zero_location('N')
    ax.set_axis_off()

    axis_font = settings['font_axis']
    kilobases = np.arange(1, 1 + max_length // 1000)
    angles = (2 * np.pi - opening) * (kilobases * 1000) / max_length
    for angle, kb in zip(angles, kilobases):
        ax.plot((0, angle), (0, hole + thickness/2),
                c='black', zorder=0, linestyle=':', linewidth=1)
        # Keep tick labels upright on both halves of the circle.
        turn = (90 if angle < np.pi else 270) - np.rad2deg(angle)
        ax.text(angle, hole - 2 * ticks_offset, '%dk' % kb, rotation=turn,
                ha=axis_font['ha'], va=axis_font['va'],
                size=axis_font['size'],
                style=axis_font['style'], weight=axis_font['weight'])

    legend_font = settings['font_legend']
    if levels > 1:
        legend = '\n'.join('%d: %s' % (idx + 1, labels[idx]) for idx in range(levels))
        if centre_legend:
            x, y, ha, edgecolor = 0.5, 0.5, 'center', 'none'
        else:
            x, y, ha, edgecolor = 0.0, 1.0, 'left', 'black'
    else:
        length_text = sep('%d' % lengths[0])
        legend = '%s\n%s bp' % (labels[0], length_text)
        x, y, ha, edgecolor = 0.5, 0.5, legend_font['ha'], 'none'

    ax.text(x, y, legend,
            ha=ha,
            va=legend_font['va'],
            size=legend_font['size'], weight=legend_font['weight'],
            bbox=dict(facecolor='white', edgecolor=edgecolor, boxstyle='round,pad=0.1'),
            transform = ax.transAxes)

In [None]:
def sunburst(files, labels=None, level=0, lengths=None, max_length=None, opening=0.2,
             hole=1, thickness=0.5, gap=0.2,  ticks_offset=0.1, levels=None,
             draw_complement=True, width_threshold=[0.25, 0.1, 0.05], centre_legend=False, ax=None):
    """Draw one or several feature files as concentric rings on polar axes.

    Parameters
    ----------
    files : str or list/tuple/ndarray of str
        A single file draws one ring.  A sequence draws one ring per file,
        the first file at ring ``level``, each following file one ring
        further out.
    labels, lengths : list, optional
        Legend names and sequence lengths, one per file.  Missing or empty
        lists are filled from the files (labels via ``format_label``).  A
        list passed in empty is filled in place.
    level : int
        Ring index of the (first) file.
    max_length : int, optional
        Length mapped onto the full usable circle; defaults to the longest
        file.
    opening, hole, thickness, gap, ticks_offset :
        Geometry, forwarded to the ``draw_*`` helpers.
    levels : int, optional
        Total number of rings; set automatically for a sequence of files.
    draw_complement : bool
        Draw complement-strand features on a separate half-track.
    width_threshold : sequence of three floats
        Angular widths that decide label placement (see ``label_feature``).
    centre_legend : bool
        Put the multi-ring legend in the centre instead of the corner.
    ax : polar axes, optional
        Axes to draw on; created with ``plt.subplot`` when omitted.

    Returns
    -------
    tuple
        ``(ax, labels, lengths)``

    Raises
    ------
    ValueError
        If ``files`` is neither a string nor a list/tuple/ndarray.
    """
    if isinstance(files, (list, tuple, np.ndarray)):
        levels = len(files)
        if labels is None:
            labels = []
        if lengths is None:
            lengths = []
        fill_labels = not labels
        fill_lengths = not lengths
        fill_max = max_length is None
        # The axis and legend are drawn while the FIRST ring is drawn, so
        # every label has to be known before any drawing starts.  Scan the
        # files whenever something is missing, also when the caller supplied
        # max_length but no labels/lengths.
        # NOTE: each file is parsed again when its ring is drawn below.
        if fill_labels or fill_lengths or fill_max:
            longest = 0
            for file in files:
                label, length, features = parse_features(file)
                if fill_labels:
                    labels.append(format_label(label))
                if fill_lengths:
                    lengths.append(length)
                if length > longest:
                    longest = length
            if fill_max:
                max_length = longest
        for file in files:
            ax, labels, lengths = sunburst(file, labels=labels, level=level, lengths=lengths, max_length=max_length,
                                           opening=opening, hole=hole, thickness=thickness, gap=gap,
                                           ticks_offset=ticks_offset, levels=levels,
                                           draw_complement=draw_complement, width_threshold=width_threshold,
                                           centre_legend=centre_legend, ax=ax)
            level += 1
        return ax, labels, lengths
    elif isinstance(files, str):
        file = files
        if levels is None:
            levels = 1
        label, length, features = parse_features(file)
        label = format_label(label)
        if labels is None:
            labels = [label]
        elif not labels:
            labels.append(label)
        if lengths is None:
            lengths = [length]
        elif not lengths:
            lengths.append(length)
        if max_length is None:
            max_length = length
        if ax is None:
            ax = plt.subplot(111, projection='polar')
        draw_main_body(ax, level, levels, length, max_length, hole, gap, thickness, opening, draw_complement)
        draw_features(ax, features, level, length, max_length, hole, gap, thickness, opening,
                      draw_complement, width_threshold)
        # Ticks and legend are drawn once, together with the innermost ring.
        if level == 0:
            draw_axis(ax, hole, thickness, gap, lengths, max_length, opening, labels, levels, centre_legend, ticks_offset)
        return ax, labels, lengths
    else:
        raise ValueError('wrong FileName argument value')

In [None]:
# Figure 1: all input files as concentric rings, complement strand not
# separated, legend in the centre; written to 'temp.eps'.
# NOTE(review): W and H look like the target size in pixels (they are divided
# by the dpi to obtain the figsize) -- confirm.
my_dpi=600
W = 8000
H = 8000
fig = plt.figure(figsize=(W / my_dpi, H / my_dpi), dpi=my_dpi)
# gb_files[:] hands sunburst a copy of the file list.
ax, _, _ = sunburst(gb_files[:], labels=labels, hole=4 ,thickness=2, gap=0.3, ticks_offset=0.3, opening=0.1,
                 draw_complement=False, width_threshold=[0.25, 0.1, 0.05], centre_legend=True)
fig.savefig('temp.eps', dpi=fig.dpi)

In [None]:
# Figure 2: the first input file alone, with complement-strand features on
# their own half-track and no opening in the circle; written to 'temp2.eps'.
# This cell rebinds my_dpi, W, H, fig and ax.
my_dpi=600
W = 8000
H = 8000
fig = plt.figure(figsize=(W / my_dpi, H / my_dpi), dpi=my_dpi)
# A single file name (a str) makes sunburst draw exactly one ring.
ax, _, _ = sunburst(gb_files[0], labels=[labels[0]], hole=4 ,thickness=3, gap=2, ticks_offset=0.3, opening=0.0,
                 draw_complement=True, width_threshold=[0.2, 0.1, 0.07], centre_legend=True)
fig.savefig('temp2.eps', dpi=fig.dpi)