# Creating Better Figures

Using the design principles we have learned, how could we improve on Python's default (matplotlib) plots?

In [None]:
# standard library: used to create the output folder before saving figures
import os

# importing the libraries for data processing
import numpy as np
import pandas as pd

#These two modules will be used to create some basic visualizations
import seaborn as sns
import matplotlib.pyplot as plt

#This is a jupyter magic command that embeds the image generated by matplotlib right after the code cell
%matplotlib inline

In [None]:
# Load the merged dataset from disk and preview its first rows
streams_csv = 'data/merged_chart_tracks.csv'
df_streams = pd.read_csv(streams_csv)
df_streams.head()

In [None]:
# Parse the 'date' column into datetimes and make it the index of df_streams.
#
# The guard keeps this cell re-runnable: once 'date' has been moved into the
# index it is no longer a column, so running the unguarded conversion a second
# time would raise KeyError: 'date'.
if 'date' in df_streams.columns:
    df_streams = (
        df_streams
        .assign(date=pd.to_datetime(df_streams['date']))
        .set_index('date')
    )
df_streams.head()

## Line plot of top-streamed artists

In [None]:
# Aggregate streams to one row per artist per calendar month.
# resample('M') labels every monthly bin with the last day of that month.
mon_df_streams = (
    df_streams.groupby('artist')[['streams']]
    .resample('M')
    .sum()
    .reset_index()
    .set_index('date')
)

# Remove April 2020 by keeping only months ending on or before 2020-03-31.
# NOTE(review): presumably April 2020 is dropped because it is an incomplete
# month in the data -- confirm.
#
# The rows are ordered by artist first, so the date index is NOT monotonic.
# A boolean mask is used instead of the label slice .loc[:'2020-03-31'] because
# string slicing on a non-monotonic DatetimeIndex is deprecated and can raise
# KeyError in recent pandas versions. The mask selects the same rows.
last_month_end = pd.Timestamp('2020-03-31')
mon_df_streams = mon_df_streams[mon_df_streams.index <= last_month_end]
mon_df_streams

In [None]:
# Total streams per artist over the whole period, ranked from most to least streamed
streams_per_artist = mon_df_streams.groupby('artist')['streams'].sum()
total_df_streams = streams_per_artist.reset_index().sort_values('streams', ascending=False)
total_df_streams

In [None]:
# Names of the first ten artists in the ranked totals table, as an array
top10_artists = total_df_streams['artist'].iloc[:10].values
top10_artists

### Basic Plot

In [None]:
# Default-styled line chart: one cumulative-streams line per top-10 artist
fig, ax = plt.subplots()

for rank, artist_name in enumerate(top10_artists):
    # Select this artist's monthly rows and accumulate them over time
    is_artist = mon_df_streams['artist'] == artist_name
    cumulative_millions = mon_df_streams[is_artist]['streams'].cumsum() / 1e6
    # Each artist gets its own colour from the default cycle (C0, C1, ...)
    cumulative_millions.plot(ax=ax, color='C' + str(rank), label=artist_name)

ax.set_ylabel('streams (in millions)')
ax.set_title("Spotify top-streamed artists in the Philippines")
ax.legend()

### Tidy it up

In [None]:
# Same chart as the basic plot, tidied up: ticks only on the bottom/left axes,
# no redundant x label, a raised title, and the legend moved below the plot.
fig, ax = plt.subplots()

# Ensure that the axis ticks only show up on the bottom and left of the plot.
# Ticks on the right and top of the plot are generally unnecessary chartjunk.
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()

for i, artist_name in enumerate(top10_artists):
    # Plot each artist's cumulative streams (in millions) with its own color
    data = mon_df_streams[mon_df_streams['artist'] == artist_name]['streams'].cumsum()
    data = data / 1e6
    data.plot(ax=ax, color='C' + str(i), label=artist_name)

# x axis entries are obviously dates so we can remove its label
ax.set_xlabel('')

ax.set_ylabel('streams (in millions)')
# y=1.03 lifts the title slightly above the top of the axes
ax.set_title("Spotify top-streamed artists in the Philippines", x=0.5, y=1.03)

# Place the legend outside the plotting area:
# shrink the current axis's height by 15% from the bottom to make room for it
box = ax.get_position()
ax.set_position([box.x0, box.y0 + box.height * 0.15,
                 box.width, box.height * 0.85])

# Put a legend below the current axis, laid out in three columns
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), ncol=3)


### Better-designed plot

In [None]:
# Final, better-designed version of the chart: no frame, direct line labels
# instead of a legend, light horizontal guide lines, large readable ticks and a
# descriptive title. The figure is also written to figs/plot.png.

# Make plots big for better resolution and easier control of their parts
fig, ax = plt.subplots(figsize=(10, 10))

# Remove the plot frame lines. They are unnecessary chartjunk.
for side in ("top", "bottom", "right", "left"):
    ax.spines[side].set_visible(False)

# Ensure that the axis ticks only show up on the bottom and left of the plot.
# Ticks on the right and top of the plot are generally unnecessary chartjunk.
ax.get_xaxis().tick_bottom()
ax.get_yaxis().tick_left()

# Hand-tuned multipliers applied to a label's y position, keyed by the artist's
# position in top10_artists, so that labels of lines ending close together do
# not overlap. These values are specific to the current data.
label_nudges = {1: 1.02, 7: 0.985, 8: 0.98, 9: 0.975}

for i, artist_name in enumerate(top10_artists):
    line_color = 'C' + str(i)

    # Plot each artist's cumulative streams (in millions) with its own color
    data = mon_df_streams[mon_df_streams['artist'] == artist_name]['streams'].cumsum()
    data = data / 1e6
    data.plot(ax=ax, color=line_color)

    # Label the line directly: one day to the right of its last point, at the
    # height of its last value (nudged where needed), in the line's own color
    x_pos = data.index.values[-1] + np.timedelta64(1, 'D')
    y_pos = data.values[-1] * label_nudges.get(i, 1)
    ax.text(x_pos, y_pos, artist_name, fontsize=14, color=line_color)

# Tick positions shared by the guide lines and the y axis labels: 0, 50, ..., 450
y_ticks = np.arange(0, 500, 50)

# Draw faint dashed horizontal guide lines across the current x range
x_left, x_right = ax.get_xlim()
xgrid = np.linspace(x_left, x_right, 10)
for y in y_ticks:
    ax.plot(xgrid, [y] * len(xgrid),
            ls="--", lw=0.5, color="black", alpha=0.3)

# Make sure you trim excess whitespace by changing axis limits
ax.set_ylim([0, 450])

# x axis entries are obviously dates so we can remove its label
ax.set_xlabel('')

# Make sure your axis ticks are large enough to be easily read.
# You don't want your viewers squinting to read your plot.
# Values are in millions, so every non-zero tick gets an 'M' suffix.
plt.yticks(y_ticks, [str(t) + 'M' if t > 0 else str(t) for t in y_ticks], fontsize=14)
plt.xticks(fontsize=14)

# Make the title big enough so it spans the entire plot, but don't make it
# so big that it requires two lines to show.

# Note that if the title is descriptive enough, it is unnecessary to include
# axis labels; they are self-evident, in this plot's case.
# The lines are running (cumulative) sums of the monthly streams, so the title
# says "Cumulative" rather than "Total monthly".
ax.set_title("Cumulative Spotify streams of top-streamed artists in the Philippines",
             fontsize=20,
             x=0.55, y=1.03)

# Create the output folder if needed so saving works on a fresh checkout
os.makedirs('figs', exist_ok=True)
plt.savefig('figs/plot.png', bbox_inches='tight')