In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

In [2]:
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

In [3]:
from matplotlib.patches import Patch
import matplotlib.patches as mpatches
from matplotlib.lines import Line2D
import seaborn as sns
from matplotlib.path import Path
import matplotlib.patches as patches

In [None]:
from operator import sub
import sys
import collections
import re
from matplotlib.legend_handler import HandlerPatch

# Chromosome table (tab-separated text, saved as UTF-8): one row per chromosome
# with columns chr / len / pos, where chr names look like "Pp01".
top10_chr_info = pd.read_csv("GWEAStop10.chr.txt", sep="\t")
# Trait table (tab-separated): columns Trait / Chr / Pos, where Chr is the bare
# chromosome number (e.g. 4) rather than the "Pp0N" name.
top10_chr_traits = pd.read_csv("GWEAStop10.trait.0925.txt", sep="\t")

def get_disp_ratio(ax):
    """Return the on-figure height/width ratio of *ax*, measured in inches.

    The figure size (inches) is scaled by the axes' fractional position
    bounds to get the physical width and height of the axes box; the result
    is height divided by width. Data limits play no part in the value.
    """
    fig_width_in, fig_height_in = ax.get_figure().get_size_inches()
    _left, _bottom, width_frac, height_frac = ax.get_position().bounds
    axes_height_in = fig_height_in * height_frac
    axes_width_in = fig_width_in * width_frac
    return axes_height_in / axes_width_in

class HandlerShape(HandlerPatch): # from https://stackoverflow.com/questions/44098362/using-mpatches-patch-for-a-custom-legend
    """Legend handler that renders a handle as an ellipse instead of a rectangle.

    Both diameters are 1.8x the legend entry height (plus the respective
    descent), so with zero descents the marker is a circle in legend-box
    units. The marker's style (colour, hatch, ...) is copied from the
    original handle via ``update_prop``.
    """
    def create_artists(self, legend, orig_handle,
                       xdescent, ydescent, width, height, fontsize, trans):
        """Return a one-element list holding the ellipse marker for *orig_handle*."""
        scale = 1.8
        center_x = 0.5 * width - 0.5 * xdescent
        center_y = 0.5 * height - 0.5 * ydescent
        # The width is derived from the entry *height* on purpose, so the
        # marker stays round regardless of the legend's handle length.
        marker = mpatches.Ellipse(
            xy=(center_x, center_y),
            width=(height + xdescent) * scale,
            height=(height + ydescent) * scale,
        )
        self.update_prop(marker, orig_handle, legend)
        marker.set_transform(trans)
        return [marker]

def assign_color_hatch_to_trait(traits, colors, hatches):
    """Map each distinct trait to a unique ``[color, hatch]`` style pair.

    Traits are de-duplicated and sorted, then paired in order with the
    colour/hatch combinations. Combinations are generated hatch-major: every
    colour with the first hatch, then every colour with the second hatch, etc.

    Parameters
    ----------
    traits : iterable of hashable, sortable values
        Trait names (e.g. a pandas Series); duplicates are ignored.
    colors : sequence of str
        Matplotlib colour specs.
    hatches : sequence of str
        Matplotlib hatch patterns.

    Returns
    -------
    dict
        ``{trait: [color, hatch]}`` with keys in sorted order.

    Raises
    ------
    ValueError
        If there are fewer colour/hatch combinations than distinct traits.
    """
    unique_traits = sorted(set(traits))
    styles = [[color, hatch] for hatch in hatches for color in colors]
    if len(styles) < len(unique_traits):
        # Raise instead of sys.exit(): SystemExit inside a notebook helper is
        # reported awkwardly by IPython and cannot be caught as an Exception.
        raise ValueError(
            f"not enough color/hatch combinations: {len(styles)} available, "
            f"but {len(unique_traits)} distinct traits"
        )
    return dict(zip(unique_traits, styles))


def plot_rect_with_round_corner_not_fancybox(ax, left_low_point_xy, width, height, round_corner_radius, facecolor="Coral"):
    """Draw a filled rectangle with rounded corners on *ax* using plain patches.

    The shape is built from the width x height rectangle with each corner cut
    off at 45 degrees by ``round_corner_radius`` (an octagonal PathPatch),
    plus one circle of that radius tucked into each corner. Together they
    render as a rounded rectangle. FancyBboxPatch is not used.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes the patches are added to.
    left_low_point_xy : (float, float)
        Lower-left corner of the bounding rectangle.
    width, height : float
        Size of the bounding rectangle, in data units.
    round_corner_radius : float
        Corner radius; must satisfy ``2 * radius <= width``.
    facecolor : str
        Fill colour for every patch.

    Returns
    -------
    The same *ax*.

    Raises
    ------
    ValueError
        If ``2 * round_corner_radius > width``.
    """
    x, y = left_low_point_xy
    r = round_corner_radius
    if r * 2 > width:
        raise ValueError(f"round_corner_radius={round_corner_radius} * 2 > width={width}")
    # NOTE(review): height is not checked; if height < 2 * r the corner
    # circles overlap each other and poke out past the rectangle.

    # One circle per corner, centred r inside the corner.
    corner_centers = [
        (x + width - r, y + r),           # bottom-right
        (x + width - r, y + height - r),  # top-right
        (x + r, y + height - r),          # top-left
        (x + r, y + r),                   # bottom-left
    ]
    for center in corner_centers:
        ax.add_patch(mpatches.Circle(center, r, facecolor=facecolor))

    # Octagon: the rectangle with every corner chamfered by r; the circles
    # above fill the chamfers so the outline looks rounded.
    verts = [(x + width - r, y),
             (x + width, y + r),
             (x + width, y + height - r),
             (x + width - r, y + height),
             (x + r, y + height),
             (x, y + height - r),
             (x, y + r),
             (x + r, y),
             (x + width - r, y)]
    codes = [Path.MOVETO] + [Path.LINETO] * (len(verts) - 1)
    path = Path(verts, codes)
    ax.add_patch(patches.PathPatch(path, lw=0, facecolor=facecolor))
    return ax


def get_trait_pos_name(chr_traits, chr_name_short, arm, arm_medium_point_pos):
    """Group the traits of one chromosome arm by position.

    Parameters
    ----------
    chr_traits : pandas.DataFrame
        Table with ``Chr``, ``Pos`` and ``Trait`` columns.
    chr_name_short : scalar
        Value matched against ``chr_traits["Chr"]``.
    arm : {"up", "down"}
        ``"down"`` keeps rows with ``Pos > arm_medium_point_pos``;
        ``"up"`` keeps rows with ``Pos <= arm_medium_point_pos``.
    arm_medium_point_pos : number
        Position that separates the two arms.

    Returns
    -------
    collections.OrderedDict
        ``{pos: [trait, ...]}``, traits in input order. Positions are ascending
        for ``"down"`` and descending for ``"up"``, i.e. ordered moving away
        from ``arm_medium_point_pos``.

    Raises
    ------
    ValueError
        If *arm* is not ``"up"``/``"down"`` (checked even when the chromosome
        has no traits), or if a trait appears twice at the same position.
    """
    # Validate up front: the old in-loop check never fired for chromosomes
    # without trait rows, so a bad `arm` slipped through silently.
    if arm not in ("down", "up"):
        raise ValueError(f"unsupported arm={arm!r} for chr={chr_name_short}; expected 'up' or 'down'")

    rows = chr_traits[chr_traits["Chr"] == chr_name_short]
    traits_by_pos = {}
    for pos, trait in zip(rows["Pos"], rows["Trait"]):
        if arm == "down" and pos <= arm_medium_point_pos:
            continue
        if arm == "up" and pos > arm_medium_point_pos:
            continue
        pos_traits = traits_by_pos.setdefault(pos, [])
        if trait in pos_traits:
            raise ValueError(f"trait={trait} occurs more than once at pos={pos} of chr={chr_name_short}")
        pos_traits.append(trait)
    return collections.OrderedDict(sorted(traits_by_pos.items(), reverse=(arm == "up")))
        

def plot_traits_for_arm(ax, arm_width, chr_name_short, all_patches_dict,all_patches_label, traits_style_dict, arm_trait_pos_name_dict, x_limit_max, 
                        arm_right_x, space_arm_with_trait, trait_circle_radius, arm, arm_medium_point_pos, y_start_point=1,free_width_cursor=0):
    """Draw the trait markers of one chromosome arm and register legend proxies.

    For every position in ``arm_trait_pos_name_dict`` (visited in the dict's
    iteration order) this draws:

    * a horizontal tick from ``arm_right_x - arm_width`` to ``arm_right_x`` at
      the position's y coordinate ``mutation_y``;
    * a thin linker line from ``(arm_right_x, mutation_y)`` to
      ``(arm_right_x + space_arm_with_trait, y_trait)``;
    * one circle of radius ``trait_circle_radius`` per trait at that position,
      styled with the ``[color, hatch]`` pair from ``traits_style_dict`` and
      laid out left to right starting at ``arm_right_x + space_arm_with_trait``.

    ``mutation_y = y_start_point + yshift_unit * |arm_medium_point_pos - pos|``
    with ``yshift_unit`` = -1 for ``arm == "down"`` and +1 for ``arm == "up"``.
    Circle rows are stacked along the arm direction so consecutive positions
    do not overlap. When there is still horizontal room and both the previous
    and the current position carry a single trait, the new circle may instead
    be placed on the previous row, to the right of the previous circle. A
    circle that would extend past ``x_limit_max`` is not drawn: an error line
    is printed and the remaining traits of that position are skipped.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    arm_width : arm width; only used for the length of the position tick.
    chr_name_short : chromosome number; only feeds the unused ``flag`` below.
    all_patches_dict, all_patches_label : legend bookkeeping shared across
        calls. For each trait not yet in ``all_patches_label`` a proxy Circle
        is stored in ``all_patches_dict[trait]`` and the trait is appended to
        ``all_patches_label``. Both are mutated in place and also returned.
    traits_style_dict : ``{trait: [color, hatch]}``.
    arm_trait_pos_name_dict : ``{pos: [trait, ...]}``, e.g. from
        ``get_trait_pos_name``.
    x_limit_max : right edge of the drawable x range.
    arm_right_x : x coordinate of the arm's right edge.
    space_arm_with_trait : horizontal gap between the arm and the first circle.
    trait_circle_radius : radius of every trait circle.
    arm : ``"up"`` or ``"down"``; anything else prints an error and calls
        ``sys.exit(1)``.
    arm_medium_point_pos : position that maps to ``y_start_point``.
    y_start_point : y coordinate where the arm starts.
    free_width_cursor : has no effect. The first loop iteration overwrites it
        before it is ever read, because ``pos_index >= 2`` is False there.

    Returns
    -------
    tuple
        ``(ax, all_patches_dict, all_patches_label, y_last_trait)``.
        ``y_last_trait`` is the row-stacking cursor after the last position
        (roughly the outer edge of the last circle row), or ``y_start_point``
        if there were no positions.
    """
    # Direction away from the split point: the down arm grows toward negative y.
    if arm == "down":
        yshift_unit = -1
    elif arm == "up":
         yshift_unit = 1
    else:
        print(f"error: arm={arm}, but only support down or up in plot_traits_for_arm")
        sys.exit(1)
    y_last_trait = y_start_point
    x_last_trait = arm_right_x + space_arm_with_trait  # left edge for the next circle in the current row
    y_cursor = y_start_point # outer edge (y) of the previous trait row

    free_width_space = trait_circle_radius * 2
    zorder = 0
    
    # NOTE(review): `flag` is only assigned, never read (its only use is the commented-out print below).
    flag = 0
    if chr_name_short == 3 and arm == "down":
        flag = 1
    
    #print(f"chr_name_short={chr_name_short} ,typeof={type(chr_name_short)}")
    last_pos_trait_num = 0  # number of traits at the previous position
    pos_index = 0  # number of positions already processed
    for pos in arm_trait_pos_name_dict: # pos -> traits_list
        traits_list = arm_trait_pos_name_dict[pos]
     
        mutation_y = y_start_point + yshift_unit * abs(arm_medium_point_pos - pos)
        # inner edge (the side facing y_start_point) of a circle centred on mutation_y
        y_trait_cursor = mutation_y + yshift_unit * -1 * trait_circle_radius
        
        zorder +=1  # later positions are drawn on top of earlier ones
        

        if arm == "down":
            if y_trait_cursor <= y_cursor: # no overlap with the previous trait row: centre on the mutation, new row
                y_trait = y_trait_cursor + yshift_unit * trait_circle_radius
                free_width_space = 0
                free_width_cursor = 0
                y_cursor = y_trait + yshift_unit * trait_circle_radius
            else: # would overlap the previous trait row
                free_width_space = trait_circle_radius * 2
                # Enough room left on the previous row and only single traits involved:
                # reuse the previous row (its centre is one radius back from y_last_trait).
                if pos_index >=2 and (free_width_space + 2*trait_circle_radius * len(traits_list)) <= free_width_cursor and len(traits_list) <= 1 and last_pos_trait_num <=1:
                    y_cursor = y_last_trait
                    y_trait = y_cursor + yshift_unit * -1 * trait_circle_radius
                    free_width_space = trait_circle_radius * 2
                else:
                        #x_limit_max - free_width_cursor - (arm_right_x+space_arm_with_trait) +
                    # Start a new row just beyond the previous one.
                    y_trait = y_cursor + yshift_unit * trait_circle_radius
                    y_cursor = y_trait + yshift_unit * trait_circle_radius
                    free_width_space = 0
                    free_width_cursor = 0
        elif arm == "up":
            if y_trait_cursor > y_cursor: # no overlap with the previous trait row: centre on the mutation, new row
                y_trait = y_trait_cursor + yshift_unit * trait_circle_radius
                free_width_space = 0
                free_width_cursor = 0
                y_cursor = y_trait + yshift_unit * trait_circle_radius
            else: # would overlap the previous trait row
                free_width_space = trait_circle_radius * 2
                # Same previous-row reuse rule as for the down arm.
                if pos_index >=2 and (free_width_space + 2*trait_circle_radius * len(traits_list)) <= free_width_cursor and len(traits_list) <= 1 and last_pos_trait_num <=1:
                    y_cursor = y_last_trait
                    y_trait = y_cursor + yshift_unit * -1 * trait_circle_radius
                    free_width_space = trait_circle_radius * 2
                else:
                    # Start a new row just beyond the previous one.
                    y_trait = y_cursor + yshift_unit * trait_circle_radius    
                    y_cursor = y_trait + yshift_unit * trait_circle_radius
                    free_width_space = 0
                    free_width_cursor = 0
            
        #if flag: print(f"pos={pos}, free_width_cursor={free_width_cursor}, free_width_space={free_width_space}")
            
        
        pos_index +=1
        # A fresh row starts at the left, right next to the arm.
        if free_width_cursor == 0 :
            x_last_trait = arm_right_x + space_arm_with_trait
            
        ## plot linker line between mutation pos and trait
        # If the row ended up on the arm-start side of the mutation, push it one
        # diameter further along the arm and restart it at the left.
        back_x = 0
        if arm == "down":
            if mutation_y < y_trait:
                y_trait -= 2* trait_circle_radius
                back_x = 1
        else:
            if mutation_y > y_trait:
                y_trait += 2* trait_circle_radius 
                back_x = 1
        if back_x == 1:
            x_last_trait = arm_right_x + space_arm_with_trait
            y_cursor = y_cursor + yshift_unit * 2* trait_circle_radius
        else:
            # 2r when reusing the previous row (one diameter gap), 0 otherwise.
            x_last_trait += free_width_space
        #if free_width_cursor
        y_last_trait = y_cursor  # remembered so the next position can reuse this row
        
        ## linker between pos and trait
        # The linker ends at the row's left start x, not at the circle itself.
        verts=[(arm_right_x,mutation_y),(arm_right_x + space_arm_with_trait,y_trait)]
        #verts=[(arm_right_x,mutation_y),(x_last_trait,y_trait)]
        codes = [
            Path.MOVETO,
            Path.LINETO,
        ]
        path = Path(verts, codes)
        patch = patches.PathPatch(path, lw=0.3, zorder=zorder)
        ax.add_patch(patch)
        
        # add mutation position tick across the arm
        verts=[(arm_right_x-arm_width, mutation_y), (arm_right_x,mutation_y)]
        codes = [
            Path.MOVETO,
            Path.LINETO,
        ]
        path = Path(verts, codes)
        patch = patches.PathPatch(path, lw=0.5)
        ax.add_patch(patch)
    
        last_pos_trait_num = len(traits_list)    
        
        for j,trait in enumerate(traits_list):
            color, hatch = traits_style_dict[trait]
            #patch_x_center  = arm_right_x + space_arm_with_trait + free_width_space + trait_circle_radius  + trait_circle_radius * 2 * j
            patch_x_center  = x_last_trait + trait_circle_radius
            patch_y_center = y_trait
            # horizontal room left to the right of this circle
            free_width_cursor = x_limit_max - (patch_x_center + trait_circle_radius)
            x_last_trait = patch_x_center + trait_circle_radius
            if patch_x_center + trait_circle_radius > x_limit_max: 
                # Circle does not fit: reset the row and drop the remaining traits of this position.
                x_last_trait = arm_right_x + space_arm_with_trait
                free_width_cursor = 0
                print(f"error: patch_x_center={patch_x_center}+ trait_circle_radius={trait_circle_radius} > x_limit_max={x_limit_max}")
                #sys.exit(1)
                break
            # https://matplotlib.org/3.1.1/gallery/shapes_and_collections/hatch_demo.html
            p = mpatches.Circle((patch_x_center, patch_y_center), trait_circle_radius, facecolor=color, hatch=hatch, zorder=zorder)
            ax.add_patch(p)
            # Unit circle with the same style, not added to ax: used only as a legend handle.
            p2 = mpatches.Circle((0, 0), 1, facecolor=color, hatch=hatch)
            p_label = trait
            if p_label not in all_patches_label:
                all_patches_dict[p_label] = p2
                all_patches_label.append(p_label)
            
    return ax, all_patches_dict, all_patches_label, y_last_trait
        
def xx():
    """Dead scratch code; never called, and it cannot run as written.

    NOTE(review): looks like an early draft of the trait-circle loop in
    ``plot_traits_for_arm``. It always fails:

    * ``i`` is looked up as a global, and nothing at module level in the
      visible code defines it, so this raises NameError.
    * ``colors`` is assigned inside the function, so Python treats it as a
      local for the whole body. Reading it (``colors[1:]`` or ``len(colors)``)
      before that assignment raises UnboundLocalError.
    * ``circle_radius``, ``x_limit_max``, ``ax``, ``all_patches_dict`` and
      ``all_patches_label`` only exist as locals of ``pytrait``, not as
      module globals.

    Candidate for deletion.
    """
    if i >=2: colors = colors[1:]
    for j in range(len(colors)):
        hatch = "*" #
        patch_x =  circle_radius  + circle_radius * (2*j)
        patch_y = 0
        if patch_x + circle_radius > x_limit_max: continue
        # https://matplotlib.org/3.1.1/gallery/shapes_and_collections/hatch_demo.html
        p = mpatches.Circle((patch_x, patch_y), circle_radius, facecolor=colors[j], hatch=hatch)
        ax.add_patch(p)
        p2 = mpatches.Circle((0, 0), 1, facecolor=colors[j], hatch=hatch)
        p_label = colors[j]
        if p_label not in all_patches_label:
            all_patches_dict[p_label] = p2
            all_patches_label.append(p_label)
    

def pytrait(ylim_for_all_traits=None):
    """Plot all chromosomes of ``top10_chr_info`` side by side with their traits.

    Each row of ``top10_chr_info`` (columns ``chr``, ``len``, ``pos``; ``chr``
    like ``"Pp01"``) gets one subplot. The ``pos`` bp before the split point
    are drawn as a Coral arm going up from y=1. The remaining ``len - pos`` bp
    are drawn as a DeepSkyBlue arm going down from y=1. The traits of
    ``top10_chr_traits`` (columns ``Trait``, ``Chr``, ``Pos``) are drawn next to
    the arms by ``plot_traits_for_arm``.

    Meant to be called twice:

    1. ``pytrait()`` is the layout pass. It draws with a provisional y range
       and returns ``(ymin, ymax)`` covering the trait stacks of every
       chromosome. Its figure is closed and nothing is saved.
    2. ``pytrait(ylim_for_all_traits=(ymin, ymax))`` is the final pass. It
       redraws with that shared y range, adds the legend and saves
       ``plot2_0925.pdf``.

    Parameters
    ----------
    ylim_for_all_traits : (float, float) or None
        Shared ``(ymin, ymax)`` for every subplot. ``None`` or an empty
        sequence (the previous default was ``[]``) selects the layout pass.

    Returns
    -------
    tuple of int or None
        ``(ymin, ymax)`` on the layout pass, ``None`` on the final pass.
    """
    # Avoid a mutable default; an empty sequence still means "layout pass".
    layout_pass = ylim_for_all_traits is None or len(ylim_for_all_traits) == 0

    ## get chr data
    chr_info = top10_chr_info  # columns: chr, len, pos (chr like "Pp01")
    chr_traits = top10_chr_traits  # columns: Trait, Chr, Pos (Chr is the bare number, e.g. 4)

    traits = chr_traits['Trait']
    ncols = len(chr_info.index)
    bottom = 0.13  # figure fraction kept free below the subplots for the legend

    if layout_pass:
        # Provisional range: twice the longest up arm / down arm.
        chr_ylimit_max = max(chr_info['pos']) * 2
        chr_ylimit_min = -1 * max(chr_info['len'] - chr_info['pos']) * 2
    else:
        chr_ylimit_min, chr_ylimit_max = ylim_for_all_traits

    # 12-colour ColorBrewer "Paired" palette.
    colors = ['#a6cee3', '#1f78b4', '#b2df8a', '#33a02c', '#fb9a99', '#e31a1c', '#fdbf6f', '#ff7f00', '#cab2d6', '#6a3d9a', '#ffff99', '#b15928']
    hatches = ['', '--', '++', '////', '\\\\']
    traits_style_dict = assign_color_hatch_to_trait(traits, colors, hatches)

    # squeeze=False keeps `axes` 2-D even for a single chromosome, so the loop
    # below always iterates over Axes (a lone Axes object is not iterable).
    fig, axes = plt.subplots(ncols=ncols, sharey=True, figsize=(16, 16), squeeze=False)
    fig.subplots_adjust(bottom=bottom)
    all_patches_dict = {}
    all_patches_label = []
    y_last_trait_down_list = []
    y_last_trait_up_list = []

    for i, ax in enumerate(axes[0]):
        chr_row = chr_info.iloc[i]
        chr_name = chr_row['chr']
        ax.set_title(chr_name)
        ax.xaxis.set_visible(False)

        # Size the x range so x and y data units have about the same on-screen
        # length, so the box stays filled once aspect='equal' is applied.
        xy_scale_ratio = get_disp_ratio(ax)
        x_limit_max = (chr_ylimit_max - chr_ylimit_min) / xy_scale_ratio
        x_limit_min = 1
        ax.set_ylim(chr_ylimit_min, chr_ylimit_max)
        ax.set_xlim(x_limit_min, x_limit_max)
        ax.set_aspect('equal')

        # Keep only the left spine, and the y axis only on the first subplot.
        for side in ('right', 'top', 'bottom'):
            ax.spines[side].set_visible(False)
        if i >= 1:
            ax.spines['left'].set_visible(False)
            ax.yaxis.set_visible(False)

        ## plot chr arms (both arms share the same x position and width)
        x_limit_min, x_limit_max = ax.get_xlim()
        x_limit_len = x_limit_max - x_limit_min
        arm_width = x_limit_len * 0.08
        arm_x_start = x_limit_len * 0.03
        arm_x_max = arm_x_start + arm_width
        round_corner_radius = arm_width * 0.3
        chr_len = int(chr_row['len'])
        # Split point between the up and down arm. Read from the row already
        # selected, instead of int() on a filtered Series (deprecated in pandas 2.1).
        chr_medium_point_pos = int(chr_row['pos'])

        ### plot down arm: from y=1 down to 1 - (len - pos)
        down_arm_height = chr_len - chr_medium_point_pos
        ax = plot_rect_with_round_corner_not_fancybox(
            ax, [arm_x_start, 1 - down_arm_height], arm_width, down_arm_height,
            round_corner_radius, facecolor="DeepSkyBlue")

        ### plot up arm: from y=1 up to 1 + pos
        up_arm_height = chr_medium_point_pos
        ax = plot_rect_with_round_corner_not_fancybox(
            ax, [arm_x_start, 1], arm_width, up_arm_height,
            round_corner_radius, facecolor="Coral")

        ## plot traits next to each arm (down first, then up)
        trait_circle_radius = x_limit_len * 0.04
        space_arm_with_trait = trait_circle_radius * 2
        chr_name_short = int(re.sub("Pp0", "", chr_name))  # "Pp01" -> 1, matches chr_traits['Chr']

        y_last_trait_by_arm = {}
        for arm in ("down", "up"):
            arm_trait_pos_name_dict = get_trait_pos_name(chr_traits, chr_name_short, arm, chr_medium_point_pos)
            ax, all_patches_dict, all_patches_label, y_last_trait_by_arm[arm] = plot_traits_for_arm(
                ax, arm_width, chr_name_short, all_patches_dict, all_patches_label, traits_style_dict,
                arm_trait_pos_name_dict, x_limit_max, arm_x_max, space_arm_with_trait,
                trait_circle_radius, arm, chr_medium_point_pos, y_start_point=1)

        y_last_trait_down = y_last_trait_by_arm["down"]
        y_last_trait_up = y_last_trait_by_arm["up"]
        y_last_trait_down_list.append(y_last_trait_down)
        y_last_trait_up_list.append(y_last_trait_up)
        print(f"chr {i}, y_last_trait_down={y_last_trait_down} ,y_last_trait_up={y_last_trait_up}")

    if layout_pass:
        print(y_last_trait_down_list)
        print(y_last_trait_up_list)
        # The layout figure is only used for measuring; close it so it is not
        # displayed inline next to the final figure.
        plt.close(fig)
        # Pad by one unit beyond the lowest / highest trait stack.
        return -1 + int(min(y_last_trait_down_list)), 1 + int(max(y_last_trait_up_list))

    ### sort legend by label
    all_patches_label = sorted(all_patches_label)
    all_patches_list = [all_patches_dict[plabel] for plabel in all_patches_label]

    # HandlerShape draws each proxy Circle as a round legend marker.
    fig.legend(all_patches_list, all_patches_label, bbox_to_anchor=(0.08, bottom * 1.2), loc='upper left', ncol=5, handletextpad=-0.3,
               handler_map={mpatches.Circle: HandlerShape()})  # from https://stackoverflow.com/questions/44098362/using-mpatches-patch-for-a-custom-legend

    fig.savefig("plot2_0925.pdf", dpi=300, bbox_inches='tight')
    
# Two-pass drawing: the first call only lays everything out and returns the
# (ymin, ymax) needed to fit every chromosome's trait stack. The second call
# redraws all subplots with that shared y range, adds the legend and saves the PDF.
ylim_for_all_traits = pytrait()
pytrait(ylim_for_all_traits=ylim_for_all_traits)