## Create Interactive Plots with Plotly
In this lecture, we will learn how to use the Plotly library to build interactive plots.

Plotly's Python graphing library makes interactive, publication-quality graphs. It supports line plots, scatter plots, area charts, bar charts, error bars, box plots, histograms, heatmaps, subplots, multiple axes, polar charts, and bubble charts.

In [1]:
# Run this if plotly is not installed yet
#!pip install plotly==5.10.0
# !pip install plotly.express --quiet


In [2]:
import numpy as np
import pandas as pd
import warnings; warnings.filterwarnings("ignore")


# plotly library
from plotly.offline import init_notebook_mode, iplot, plot
import plotly as py
init_notebook_mode(connected=True)
import plotly.graph_objs as go
import plotly.express as px

# matplotlib
import matplotlib.pyplot as plt

In [3]:
# If running plotly in Google Colab, you will need to run this custom initialization function
# in each offline plotting cell
def configure_plotly_browser_state():
  import IPython
  display(IPython.core.display.HTML('''
        <script src="/static/components/requirejs/require.js"></script>
        <script>
          requirejs.config({
            paths: {
              base: '/static/base',
              plotly: 'https://cdn.plot.ly/plotly-latest.min.js?noext',
            },
          });
        </script>
        '''))


The dataset used below has the following columns:

- `world_rank` - world rank of the university
- `university_name` - name of the university
- `country` - country of each university
- `teaching` - university score for teaching (the learning environment)
- `international` - university score for international outlook (staff, students, research)
- `research` - university score for research (volume, income and reputation)
- `citations` - university score for citations (research influence)
- `income` - university score for industry income (knowledge transfer)
- `total_score` - total score for the university, used to determine rank
- `num_students` - number of students at the university
- `student_staff_ratio` - number of students divided by number of staff
- `international_students` - percentage of students who are international
- `female_male_ratio` - female student to male student ratio
- `year` - year of the ranking (2011 to 2016 included)

In [4]:
# Dataset
dataset = "https://raw.githubusercontent.com/csbfx/advpy122-data/master/timesData.csv"
uni = pd.read_csv(dataset)

In [5]:
uni.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2603 entries, 0 to 2602
Data columns (total 14 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   world_rank              2603 non-null   object 
 1   university_name         2603 non-null   object 
 2   country                 2603 non-null   object 
 3   teaching                2603 non-null   float64
 4   international           2603 non-null   object 
 5   research                2603 non-null   float64
 6   citations               2603 non-null   float64
 7   income                  2603 non-null   object 
 8   total_score             2603 non-null   object 
 9   num_students            2544 non-null   object 
 10  student_staff_ratio     2544 non-null   float64
 11  international_students  2536 non-null   object 
 12  female_male_ratio       2370 non-null   object 
 13  year                    2603 non-null   int64  
dtypes: float64(4), int64(1), object(9)
memor

In [6]:
uni[uni.world_rank.str.contains('=')]

Unnamed: 0,world_rank,university_name,country,teaching,international,research,citations,income,total_score,num_students,student_staff_ratio,international_students,female_male_ratio,year
1841,=39,"University of California, San Diego",United States of America,56.9,42.9,69.8,98.7,56.7,72.2,27233,6.5,11%,48 : 52,2016
1842,=39,"University of California, Santa Barbara",United States of America,52.6,61.5,66.0,99.2,90.4,72.2,22020,27.3,11%,52 : 48,2016
1846,=44,"University of California, Davis",United States of America,60.1,58.4,72.7,84.3,57.3,71.0,35364,13.9,13%,54 : 46,2016
1847,=44,University of Hong Kong,Hong Kong,64.6,99.5,72.8,70.1,53.7,71.0,19835,17.6,38%,53 : 47,2016
1849,=47,Tsinghua University,China,73.3,39.5,83.0,58.8,100.0,70.0,39763,13.7,10%,32 : 68,2016
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1993,=190,Boston College,United States of America,34.1,60.0,29.3,83.3,46.8,49.6,13216,17.4,19%,54 : 46,2016
1995,=193,University of Luxembourg,Luxembourg,25.0,99.8,26.7,84.8,38.1,49.4,5144,15.9,52%,50 : 50,2016
1996,=193,Texas A&M University,United States of America,49.4,47.8,52.4,47.1,46.4,49.4,50657,21.4,9%,47 : 53,2016
1998,=196,Newcastle University,United Kingdom,30.9,84.3,27.5,81.5,34.7,49.2,20174,15.2,29%,50 : 50,2016


In [7]:
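# Clean the world_rank data, since some values start with an "=" sign (tied ranks).
# Any value that is still not a plain number after cleaning becomes NaN (errors='coerce').
# .copy() avoids pandas' SettingWithCopyWarning when assigning to a filtered frame.
duni = uni[~uni.world_rank.str.contains('=')].copy()
dunib = uni[uni.world_rank.str.contains('=')].copy()
dunib["world_rank"] = dunib.world_rank.str.replace('=', '', regex=False)
dunib["world_rank"] = pd.to_numeric(dunib.world_rank, errors='coerce')
duni["world_rank"] = pd.to_numeric(duni.world_rank, errors='coerce')
uni = pd.concat([duni, dunib], sort=False)
uni = uni.sort_index(ascending=True)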

In [8]:
uni.info()

<class 'pandas.core.frame.DataFrame'>
Index: 2603 entries, 0 to 2602
Data columns (total 14 columns):
 #   Column                  Non-Null Count  Dtype  
---  ------                  --------------  -----  
 0   world_rank              1201 non-null   float64
 1   university_name         2603 non-null   object 
 2   country                 2603 non-null   object 
 3   teaching                2603 non-null   float64
 4   international           2603 non-null   object 
 5   research                2603 non-null   float64
 6   citations               2603 non-null   float64
 7   income                  2603 non-null   object 
 8   total_score             2603 non-null   object 
 9   num_students            2544 non-null   object 
 10  student_staff_ratio     2544 non-null   float64
 11  international_students  2536 non-null   object 
 12  female_male_ratio       2370 non-null   object 
 13  year                    2603 non-null   int64  
dtypes: float64(5), int64(1), object(8)
memory usa
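
The same cleaning can also be written as one vectorized expression. The following is a minimal sketch on a made-up toy Series (the values are illustrative, not taken from the dataset): it strips the `=` prefix and lets `pd.to_numeric` turn anything that is still not a plain number into `NaN`.

```python
import pandas as pd

# Toy rank strings (hypothetical values for illustration only)
ranks = pd.Series(["1", "=39", "=39", "201-225"])

# Remove the "=" tie marker, then convert; unparseable entries become NaN
clean = pd.to_numeric(ranks.str.replace("=", "", regex=False), errors="coerce")

print(clean.tolist())
```

Entries that cannot be parsed are kept as `NaN` rather than dropped, which is why the non-null count of `world_rank` can be lower than the number of rows after cleaning.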

## RGB Colors
You can provide RGB colors for most plots. An easy way is to select a color using Google's color picker (just google "Color Picker"), or you can randomly generate an RGB color as shown below.

In [9]:
# Random RGB color generator
import numpy as np
color = list(np.random.choice(range(256), size=3))
print(color)

[45, 27, 241]
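
The traces later in this lecture pass colors as strings such as `'rgba(172, 135, 11, 0.8)'`. As a small sketch, the helper below (`to_rgba` is a hypothetical name, not part of plotly) formats an RGB triple like the one generated above into that string form, with an optional opacity.

```python
def to_rgba(rgb, alpha=1.0):
    """Format an (r, g, b) triple as an 'rgba(r, g, b, a)' color string."""
    r, g, b = (int(v) for v in rgb)
    if not all(0 <= v <= 255 for v in (r, g, b)):
        raise ValueError("each channel must be in the range 0-255")
    return f"rgba({r}, {g}, {b}, {alpha})"

color_string = to_rgba([45, 27, 241], alpha=0.8)
print(color_string)
```

The resulting string can be used wherever a color is expected, for example `marker=dict(color=color_string)`.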


## Line Chart
Let's plot `citations` and `teaching` against `world_rank` for the top 100 universities.

In [10]:
# Run configure_plotly_browser_state() if you are in Google colab
configure_plotly_browser_state()

# get the top 100 universities
df = uni.iloc[:100,:]

# import graph objects as "go"
import plotly.graph_objs as go

# Creating trace1
trace1 = go.Scatter(
                    x = df.world_rank,
                    y = df.citations,
                    mode = "lines",
                    name = "citations",
                    marker = dict(color = 'rgba(172, 135, 11, 0.8)'),
                    text= df.university_name)
# Creating trace2
trace2 = go.Scatter(
                    x = df.world_rank,
                    y = df.teaching,
                    mode = "lines+markers",
                    name = "teaching",
                    marker = dict(color = 'rgba(127, 207, 159, 0.8)'),
                    text= df.university_name)
data = [trace1, trace2]
layout = dict(title = 'Citation and Teaching vs World Rank of Top 100 Universities',
              xaxis= dict(title= 'World Rank',ticklen= 5,zeroline= False)
             )
fig = dict(data = data, layout = layout)
iplot(fig)

In [11]:
# Plotly express way
configure_plotly_browser_state()
line_plot = px.line(df,
                    x="world_rank",
                    y="citations",
                    hover_data=["university_name"],  # a list works across plotly versions
                    title="Enter your titles in the plotting function itself")

trace2 = go.Scatter(
                    x = df.world_rank,
                    y = df.teaching,
                    mode = "lines+markers",
                    name = "teaching",
                    marker = dict(color = 'rgba(127, 207, 159, 0.8)'),
                    text= df.university_name)
line_plot.add_trace(trace2)

line_plot.show()

In [12]:
# Using plotly express only: reshape the data to long form so that
# color and line_dash can encode the score category

df = uni.iloc[:100,:]
for_color_test = df.melt(id_vars=[col for col in df.columns if col not in ["teaching", "citations"]], var_name="Score Category", value_name="score")
configure_plotly_browser_state()
line_plot = px.line(for_color_test,
                    x="world_rank",
                    y="score",
                    hover_data=["university_name"],  # a list works across plotly versions
                    color="Score Category",
                    title="Melt is so useful sometimes",
                    line_dash = "Score Category")

line_plot.show()
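
To see what `melt` does in the cell above, here is a minimal sketch on a made-up two-row table (the names and scores are hypothetical). Each score column becomes its own set of rows, with the column name stored in `Score Category` and the value in `score`, which is the long form that `px.line` needs for `color="Score Category"`.

```python
import pandas as pd

# Toy wide table (hypothetical scores for illustration only)
wide = pd.DataFrame({
    "university_name": ["A", "B"],
    "teaching": [90.0, 80.0],
    "citations": [99.0, 95.0],
})

# One row per (university, score category) pair
long = wide.melt(
    id_vars=["university_name"],
    value_vars=["teaching", "citations"],
    var_name="Score Category",
    value_name="score",
)
print(long)
```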

## Scatter Plot
Let's create a scatter plot for `citation` vs `world rank` of top 100 universities for 2012, 2014 and 2016.

In [13]:
# Run configure_plotly_browser_state() if you are in Google colab
configure_plotly_browser_state()

df2012 = uni[uni.year == 2012].iloc[:100,:]
df2014 = uni[uni.year == 2014].iloc[:100,:]
df2016 = uni[uni.year == 2016].iloc[:100,:]

# creating trace1
trace1 =go.Scatter(
                    x = df2012.world_rank,
                    y = df2012.citations,
                    mode = "markers",
                    name = "2012",
                    marker = dict(color = 'rgba(31, 237, 108, 0.8)'),
                    text= df2012.university_name)
# creating trace2
trace2 =go.Scatter(
                    x = df2014.world_rank,
                    y = df2014.citations,
                    mode = "markers",
                    name = "2014",
                    marker = dict(color = 'rgba(245, 73, 197, 0.8)'),
                    text= df2014.university_name)
# creating trace3
trace3 =go.Scatter(
                    x = df2016.world_rank,
                    y = df2016.citations,
                    mode = "markers",
                    name = "2016",
                    marker = dict(color = 'rgba(94, 29, 211, 0.8)'),
                    text= df2016.university_name)
data = [trace1, trace2, trace3]
layout = dict(title = 'Citation vs world rank of top 100 universities for 2012, 2014 and 2016',
              xaxis= dict(title= 'World Rank',ticklen=10,zeroline= False),
              yaxis= dict(title= 'Citation',ticklen= 5,zeroline= False)
             )
fig = dict(data = data, layout = layout)
iplot(fig)

In [14]:
# plotly express way
configure_plotly_browser_state()
px.scatter(uni, x="world_rank", y="citations", color="year")

## Bar Charts
Let's plot `citations` and `teaching` of top 10 universities in 2016

In [15]:
# Run configure_plotly_browser_state() if you are in Google colab
configure_plotly_browser_state()

df2016 = uni[uni.year == 2016].iloc[:10,:]

# create trace1
trace1 = go.Bar(
                x = df2016.university_name,
                y = df2016.citations,
                name = "citations",
                marker = dict(color = 'rgba(123, 153, 23, 0.5)',
                             line=dict(color='rgb(0,0,0)',width=1.5)),
                text = df2016.country)
# create trace2
trace2 = go.Bar(
                x = df2016.university_name,
                y = df2016.teaching,
                name = "teaching",
                marker = dict(color = 'rgba(94, 29, 211, 0.5)',
                              line=dict(color='rgb(0,0,0)',width=1.5)),
                text = df2016.country)
data = [trace1, trace2]
layout = go.Layout(title = 'Citation and teaching for top 10 universities in 2016',
                   barmode = "group",
                   margin=dict(r=200, l=110, b=150, t=60))
fig = go.Figure(data = data, layout = layout)
iplot(fig)

In [16]:
configure_plotly_browser_state()
# df2016 = df2016.sort_values(by="citations", ascending=False)
px.bar(df2016, x="university_name", y="citations", hover_data=["country"])

## Stacked Bar Chart

In [17]:
# Run configure_plotly_browser_state() if you are in Google colab
configure_plotly_browser_state()

# Stack plots
x = df2016.university_name

trace1 = {
  'x': x,
  'y': df2016.citations,
  'name': 'citation',
  'type': 'bar'
}
trace2 = {
  'x': x,
  'y': df2016.teaching,
  'name': 'teaching',
  'type': 'bar'
}
data = [trace1, trace2]
layout = {
  'barmode': 'relative', # makes it a stacked bar plot
  'title': 'citations and teaching of top 10 universities in 2016',
  'margin': dict(r=200, l=110, b=180, t=70)
}

fig = go.Figure(data = data, layout = layout)
iplot(fig)
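
The only difference between the grouped and the stacked bar chart is the `barmode` layout setting. As a sketch, the hypothetical helper below builds the same kind of figure dictionary as the cell above and lets you switch the mode. `relative` stacks positive values upward and negative values downward from zero, so with non-negative scores it looks the same as `stack`. The values used are made up.

```python
def bar_figure(categories, series, barmode="relative"):
    """Build a plotly-style figure dict with one bar trace per named series."""
    allowed = {"group", "stack", "relative", "overlay"}
    if barmode not in allowed:
        raise ValueError(f"barmode must be one of {sorted(allowed)}")
    traces = [
        {"type": "bar", "x": list(categories), "y": list(values), "name": name}
        for name, values in series.items()
    ]
    return {"data": traces, "layout": {"barmode": barmode}}

# Hypothetical values for illustration only
fig_dict = bar_figure(
    ["Univ A", "Univ B"],
    {"citations": [99.8, 98.8], "teaching": [95.6, 86.5]},
    barmode="stack",
)
# In a notebook with plotly installed: go.Figure(fig_dict).show()
```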

## Pie Chart
Let's make a pie chart of the number of students (`num_students`) at the top 5 universities in 2016.

In [18]:
df2016 = uni[uni.year == 2016].iloc[:5,:]
df2016

Unnamed: 0,world_rank,university_name,country,teaching,international,research,citations,income,total_score,num_students,student_staff_ratio,international_students,female_male_ratio,year
1803,1.0,California Institute of Technology,United States of America,95.6,64.0,97.6,99.8,97.8,95.2,2243,6.9,27%,33 : 67,2016
1804,2.0,University of Oxford,United Kingdom,86.5,94.4,98.9,98.8,73.1,94.2,19919,11.6,34%,46 : 54,2016
1805,3.0,Stanford University,United States of America,92.5,76.3,96.2,99.9,63.3,93.9,15596,7.8,22%,42 : 58,2016
1806,4.0,University of Cambridge,United Kingdom,88.2,91.5,96.7,97.0,55.0,92.8,18812,11.8,34%,46 : 54,2016
1807,5.0,Massachusetts Institute of Technology,United States of America,89.4,84.0,88.6,99.7,95.4,92.0,11074,9.0,33%,37 : 63,2016


In [19]:
# Run configure_plotly_browser_state() if you are in Google colab
configure_plotly_browser_state()

df2016 = uni[uni.year == 2016].iloc[:5,:]
# num_students is stored as text with a thousands separator: "2,243" -> "2243" -> 2243.0
pie1_list = [float(each.replace(',', '')) for each in df2016.num_students]
labels = df2016.university_name
# figure
fig = {
  "data": [
    {
      "values": pie1_list,
      "labels": labels,
      "domain": {"x": [0, .5]},
      "name": "Student info",
      "text": pie1_list,
      "hovertemplate":"%{label}: <br>Percent: %{percent} <br> # of Students: %{text}",
      "hole": .3, # adding a hole in the middle
      "type": "pie"
    },],
  "layout": {
        "title":"Number of students from the top 5 universities",
        "annotations": [
            { "font": { "size": 12},
              "showarrow": False,
              "text": " ",
                "x": 0.5,
                "y": 0.8
            },
        ]
    }
}
iplot(fig)
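
The `%{percent}` placeholder in the hover template is each slice's share of the total. The sketch below reproduces that arithmetic by hand for the five `num_students` values shown in the table above. It assumes the raw strings carry a thousands separator, which is the form the cleaning line in the cell expects.

```python
def parse_count(text):
    """Convert a string such as '2,243' to the number 2243.0."""
    return float(text.replace(",", ""))

# num_students of the top 5 universities in 2016 (written with thousands separators)
raw = ["2,243", "19,919", "15,596", "18,812", "11,074"]

counts = [parse_count(t) for t in raw]
total = sum(counts)
shares = [100 * c / total for c in counts]  # percent of the pie per university

for t, s in zip(raw, shares):
    print(f"{t:>7} students -> {s:.1f}%")
```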

## Bubble plots
Let's create a bubble plot for the `University world rank` (first 20) vs `teaching score` with `number of students`(size) and `international score`(color) in 2016

In [20]:
# Run configure_plotly_browser_state() if you are in Google colab
configure_plotly_browser_state()

df2016 = uni[uni.year == 2016].iloc[:20,:]
# Marker size = number of students in thousands, e.g. "19,919" -> 19.919
num_students_size  = [float(each.replace(',', '')) / 1000 for each in df2016.num_students]
international_color = [float(each) for each in df2016.international]

data = [
   {
       'y': df2016.teaching,
       'x': df2016.world_rank,
       'mode': 'markers',
       'marker': {
            'color': international_color,
            'size': num_students_size,
            'showscale': True
       },
       "text" :  df2016.university_name # additional info that you see when hovering over
    }
]
layout = {
  'title': 'Top 20 Universities statistics - # of Students(size) and International score(color) in 2016',
  'xaxis' : dict(title = "University World Rank"),
  'yaxis' : dict(title = "Teaching scores")
}
fig = go.Figure(data = data, layout = layout)
iplot(fig)
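
Instead of rescaling the values by hand, bubble sizes can be controlled with the marker's `sizemode` and `sizeref` settings. The sketch below uses the scaling rule suggested in the plotly documentation, `sizeref = 2 * max(sizes) / max_px**2` with `sizemode="area"`. The helper name `bubble_marker` and the sample numbers are assumptions for illustration.

```python
def bubble_marker(sizes, colors, max_px=40):
    """Return a marker dict whose largest bubble is about max_px pixels wide."""
    sizeref = 2.0 * max(sizes) / (max_px ** 2)
    return {
        "size": list(sizes),
        "color": list(colors),
        "sizemode": "area",   # bubble area (not diameter) is proportional to the value
        "sizeref": sizeref,
        "sizemin": 4,         # keep very small bubbles visible
        "showscale": True,
    }

# Hypothetical student counts and international scores
marker = bubble_marker([2243, 19919, 15596], [64.0, 94.4, 76.3])
print(marker["sizeref"])
```

The returned dictionary can be passed as the `'marker'` entry of the scatter trace in the cell above, with the raw student counts as sizes.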

## Choropleth Maps with plotly.express

Plotly Express is the easy-to-use, high-level interface to Plotly, which operates on a variety of types of data and produces easy-to-style figures.

### GeoJSON with feature.id
Here we load a GeoJSON file containing the geometry information for US counties, where `feature.id` is a [FIPS code](https://en.wikipedia.org/wiki/FIPS_county_code).



In [21]:
from urllib.request import urlopen
import json
with urlopen('https://raw.githubusercontent.com/plotly/datasets/master/geojson-counties-fips.json') as response:
    counties = json.load(response)

counties["features"][0]

{'type': 'Feature',
 'properties': {'GEO_ID': '0500000US01001',
  'STATE': '01',
  'COUNTY': '001',
  'NAME': 'Autauga',
  'LSAD': 'County',
  'CENSUSAREA': 594.436},
 'geometry': {'type': 'Polygon',
  'coordinates': [[[-86.496774, 32.344437],
    [-86.717897, 32.402814],
    [-86.814912, 32.340803],
    [-86.890581, 32.502974],
    [-86.917595, 32.664169],
    [-86.71339, 32.661732],
    [-86.714219, 32.705694],
    [-86.413116, 32.707386],
    [-86.411172, 32.409937],
    [-86.496774, 32.344437]]]},
 'id': '01001'}

### Data indexed by id
Here we load unemployment data by county, also indexed by FIPS code.

In [22]:
import pandas as pd
df = pd.read_csv("https://raw.githubusercontent.com/plotly/datasets/master/fips-unemp-16.csv",
                   dtype={"fips": str})
df.head()


Unnamed: 0,fips,unemp
0,1001,5.3
1,1003,5.4
2,1005,8.6
3,1007,6.6
4,1009,5.5


In [23]:
df.unemp.describe()

count    3219.000000
mean        5.465642
std         2.344429
min         1.700000
25%         4.000000
50%         5.000000
75%         6.300000
max        23.500000
Name: unemp, dtype: float64
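
Reading `fips` with `dtype={"fips": str}` matters because FIPS codes can start with a zero: parsed as an integer, `01001` becomes `1001` and no longer matches the `id` of any GeoJSON feature. The sketch below uses a hypothetical helper and made-up sample codes to show how to restore the zero padding and check which codes have no matching feature.

```python
def normalize_fips(code):
    """Return a county FIPS code as a zero-padded five-character string."""
    return str(code).strip().zfill(5)

# Hypothetical sample: ids as they appear in the GeoJSON vs. codes parsed as integers
feature_ids = {"01001", "01003", "06085"}
parsed_codes = [1001, 1003, 6085, 99999]

normalized = [normalize_fips(c) for c in parsed_codes]
unmatched = [c for c in normalized if c not in feature_ids]

print(normalized)
print(unmatched)
```

Locations without a matching feature id are not drawn on the map, so a check like this is a quick way to find out why counties are missing.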

### Choropleth map using GeoJSON
Note: In this example we set `layout.geo.scope` to `usa` (through the `scope="usa"` argument) to automatically configure the map to display USA-centric data in an appropriate projection. See the [Geo map configuration documentation](https://plotly.com/python/map-configuration/) for more information on scopes.



In [24]:
# import plotly.express as px
configure_plotly_browser_state()


fig = px.choropleth(df, geojson=counties, locations='fips', color='unemp',
                           color_continuous_scale="Viridis",
                           range_color=(0, 12), # essentially (vmin, vmax)
                           scope="usa",
                           labels={'unemp':'unemployment rate'}
                          )
fig.update_layout(margin={"r":0,"t":0,"l":0,"b":0})
fig.show() # if the plot doesn't show, try Restarting Runtime of the Colab Notebook


## Combining Different Plots in One Figure Using Different Y-Axes




In [25]:
import plotly.graph_objects as go
import pandas as pd
from datetime import datetime
from plotly.subplots import make_subplots
import numpy as np

configure_plotly_browser_state()


df1=pd.read_csv('https://raw.githubusercontent.com/csbfx/cs133/main/dem_rep_18-29_supporters.csv')
df2=pd.read_csv('https://raw.githubusercontent.com/csbfx/cs133/main/dem_rep_30-44_supporters.csv')
df3=pd.read_csv('https://raw.githubusercontent.com/csbfx/cs133/main/dem_rep_45-59_supporters.csv')
df4=pd.read_csv('https://raw.githubusercontent.com/csbfx/cs133/main/dem_rep_60plus_supporters.csv')

# Add a color column for the bars. To highlight negative gaps, change the first
# color in each np.where call (for example to 'red'); it is used where the gap is < 0.
# For now both branches use the same color, so every bar in a group looks the same.
df1["Color"] = np.where(df1["dem percent gap"]<0, 'darkcyan', 'darkcyan')
df2["Color"] = np.where(df2["dem percent gap"]<0, 'yellow', 'yellow')
df3["Color"] = np.where(df3["dem percent gap"]<0, 'black', 'black')
df4["Color"] = np.where(df4["dem percent gap"]<0, 'mediumpurple', 'mediumpurple')

# Create figure with secondary y-axis
fig = make_subplots(specs=[[{"secondary_y": True}]])

#create lines/traces
fig.add_trace(go.Scatter(x=df1['year'], y=df1['dem'],
                    mode='lines',
                    name='dem 18-29',
                    line=dict(color="Blue", width=2),))
fig.add_trace(go.Scatter(x=df2['year'], y=df2['dem'],
                    mode='lines',
                    name='dem 30-44',
                    opacity=.2,
                    line=dict(color="slateblue", width=2),))
fig.add_trace(go.Scatter(x=df3['year'], y=df3['dem'],
                    mode='lines',
                    name='dem 45-59',
                    opacity=.2,
                    line=dict(color="skyblue", width=2),))
fig.add_trace(go.Scatter(x=df4['year'], y=df4['dem'],
                    mode='lines',
                    name='dem 60+',
                    opacity=.2,
                    line=dict(color="darkblue", width=2),))

fig.add_trace(go.Scatter(x=df1['year'], y=df1['rep'],
                    mode='lines',
                    name='rep 18-29',
                    line=dict(color="Red", width=2),))
fig.add_trace(go.Scatter(x=df2['year'], y=df2['rep'],
                    mode='lines',
                    name='rep 30-44',
                    opacity=.2,
                    line=dict(color="mediumvioletred", width=2),))
fig.add_trace(go.Scatter(x=df3['year'], y=df3['rep'],
                    mode='lines',
                    name='rep 45-59',
                    opacity=.2,
                    line=dict(color="maroon", width=2),))
fig.add_trace(go.Scatter(x=df4['year'], y=df4['rep'],
                    mode='lines',
                    name='rep 60+',
                    opacity=.2,
                    line=dict(color="darkRed", width=2),))
# Create the percent gap bars (plotted against the secondary y-axis)
fig.add_trace(go.Bar(x=df1['year'], y=df1['dem percent gap']
                    , opacity=.7
                    , marker_color=df1['Color']
                    , name='dem percent gap 18-29',),secondary_y=True)
#                     bar=dict(color="green", width=1),),secondary_y=True)
fig.add_trace(go.Bar(x=df2['year'], y=df2['dem percent gap']
                    , opacity=.7
                    , marker_color=df2['Color']
                    , name='dem percent gap 30-44',),secondary_y=True)
#                     bar=dict(color="teal", width=1),),secondary_y=True)
fig.add_trace(go.Bar(x=df3['year'], y=df3['dem percent gap']
                    , opacity=.7
                    , marker_color=df3['Color']
                    , name='dem percent gap 45-59',),secondary_y=True)
#                     bar=dict(color="limegreen", width=1),),secondary_y=True)
fig.add_trace(go.Bar(x=df4['year'], y=df4['dem percent gap']
                    , opacity=.7
                    , marker_color=df4['Color']
                    , name='dem percent gap 60+',),secondary_y=True)
#                     bar=dict(color="darkgreen", width=1),),secondary_y=True)

#update layout
fig.update_layout(title="<b>Political Views by Age Group</b>"
                 , height = 900
                 , width = 1650
                 , xaxis_title='<b>Date</b>'
                 , yaxis_title='Percentage'
                 , template = "plotly" # other templates ['plotly_white','plotly_dark','simple_white','seaborn']
                 )

#fig.update_layout(barmode='stack')

fig.update_yaxes(title_text="<b>Dem Percent Difference</b>"
                , range=[-.15,2]
#                 , tickformat = "%"
                , secondary_y=True # this is what adds this to the right side label
                )

fig.update_yaxes(title_text="<b>Percent of Party</b>"
                , range=[0.05,.7]
                , tickformat = "%"
                , secondary_y=False
                )

#update legend
fig.update_layout(
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=1,
        xanchor="right",
        x=.9)
    , title={
        'y':.865,
        'x':0.455,
        'xanchor': 'center',
        'yanchor': 'top'}
)

fig.show()
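
Under the hood, a secondary y-axis is a second axis object (`yaxis2`) that overlays the first, and each trace states which axis it uses. The following is a rough sketch in plain dictionary form of what `make_subplots` with `secondary_y=True` sets up. The helper name, trace names, and numbers are made up for illustration.

```python
def dual_axis_figure(x, line_y, bar_y):
    """Figure dict with a line on the left y-axis and bars on the right y-axis."""
    return {
        "data": [
            # No "yaxis" key: the trace uses the default (left) axis "y"
            {"type": "scatter", "mode": "lines", "x": list(x), "y": list(line_y),
             "name": "share"},
            # "yaxis": "y2" binds the bars to the second (right) axis
            {"type": "bar", "x": list(x), "y": list(bar_y),
             "name": "gap", "opacity": 0.7, "yaxis": "y2"},
        ],
        "layout": {
            "yaxis": {"title": {"text": "Percent of Party"}, "tickformat": "%"},
            "yaxis2": {"title": {"text": "Dem Percent Difference"},
                       "overlaying": "y", "side": "right"},
        },
    }

# Hypothetical numbers for illustration only
fig_dict = dual_axis_figure([2016, 2017, 2018], [0.50, 0.52, 0.55], [0.10, 0.12, 0.15])
# In a notebook with plotly installed: go.Figure(fig_dict).show()
```

Because the two axes have independent ranges, lines and bars measured on different scales can share one x-axis without one flattening the other.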


## Principal Component Analysis

The dimensionality reduction technique we will use is Principal Component Analysis (PCA), a powerful technique that arises from linear algebra and probability theory. In essence, it computes the covariance matrix of the data, which represents how the features vary together, finds the eigenvectors of that matrix (the principal components), and ranks them by their relevance ([explained variance/eigenvalues](https://stats.stackexchange.com/questions/2691/making-sense-of-principal-component-analysis-eigenvectors-eigenvalues#:~:text=As%20it%20is%20a%20square%20symmetric%20matrix%2C%20it%20can%20be%20diagonalized%20by%20choosing%20a%20new%20orthogonal%20coordinate%20system%2C%20given%20by%20its%20eigenvectors%20(incidentally%2C%20this%20is%20called%20spectral%20theorem)%3B%20corresponding%20eigenvalues%20will%20then%20be%20located%20on%20the%20diagonal.%20In%20this%20new%20coordinate%20system%2C%20the%20covariance%20matrix%20is%20diagonal%20and%20looks%20like%20that%3A)). For a video tutorial, see this segment on PCA from the [Coursera ML course](https://youtu.be/rng04VJxUt4?t=98).
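
The steps described above can be sketched directly with NumPy. This is a minimal illustration on synthetic data, not the implementation used by scikit-learn (which relies on a singular value decomposition and may flip the sign of a component): center the data, compute the covariance matrix, take its eigenvectors, and sort them by eigenvalue.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic data: 200 samples, 3 features, with most variance along the first feature
X = rng.normal(size=(200, 3)) * np.array([5.0, 1.0, 0.2])

Xc = X - X.mean(axis=0)                 # center each feature
cov = np.cov(Xc, rowvar=False)          # 3x3 covariance matrix
eigvals, eigvecs = np.linalg.eigh(cov)  # eigh handles symmetric matrices, ascending order

order = np.argsort(eigvals)[::-1]       # rank by explained variance, largest first
eigvals, eigvecs = eigvals[order], eigvecs[:, order]

explained_ratio = eigvals / eigvals.sum()
components = Xc @ eigvecs[:, :2]        # project onto the first two principal components

print(explained_ratio)
```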




#### **2D PCA Scatter Plot**
Here we simply visualize the first two principal components of a PCA, reducing the dataset from 4 dimensions to 2 dimensions.

In [26]:
import plotly.express as px
from sklearn.decomposition import PCA

df = px.data.iris()
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']]

X

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width
0,5.1,3.5,1.4,0.2
1,4.9,3.0,1.4,0.2
2,4.7,3.2,1.3,0.2
3,4.6,3.1,1.5,0.2
4,5.0,3.6,1.4,0.2
...,...,...,...,...
145,6.7,3.0,5.2,2.3
146,6.3,2.5,5.0,1.9
147,6.5,3.0,5.2,2.0
148,6.2,3.4,5.4,2.3


In [27]:
# Perform PCA
configure_plotly_browser_state()

pca = PCA(n_components=2)
components = pca.fit_transform(X)

fig = px.scatter(components, x=0, y=1, color=df['species'])
fig.show()

#### **Visualize all the principal components**
Now we apply PCA to the same dataset and retrieve all the components. We use `px.scatter_matrix` to display the results; this time the features are the resulting principal components, ordered by how much variance they are able to explain.

The importance of explained variance is demonstrated in the example below. The subplot between PC3 and PC4 is clearly unable to separate each class, whereas the subplot between PC1 and PC2 shows a clear separation between each species.

In this example, we will use Plotly Express, Plotly's high-level API for building figures.

In [28]:
from sklearn.decomposition import PCA
import plotly.express as px

configure_plotly_browser_state()


df = px.data.iris()
features = ["sepal_width", "sepal_length", "petal_width", "petal_length"]

pca = PCA()
components = pca.fit_transform(df[features])
labels = {
    str(i): f"PC {i+1} ({var:.1f}%)"
    for i, var in enumerate(pca.explained_variance_ratio_ * 100)
}

fig = px.scatter_matrix(
    components,
    labels=labels,
    dimensions=range(4),
    color=df["species"]
)
fig.update_traces(diagonal_visible=False)
fig.show()
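
A common follow-up question is how many components to keep. As a sketch, the hypothetical helper below takes explained-variance ratios (such as `pca.explained_variance_ratio_`, sorted from largest to smallest) and returns the smallest number of leading components whose cumulative share reaches a chosen threshold. The ratios used here are made up for illustration.

```python
from itertools import accumulate

def components_needed(ratios, threshold=0.95):
    """Smallest number of leading components whose cumulative ratio reaches threshold."""
    for k, total in enumerate(accumulate(ratios), start=1):
        if total >= threshold:
            return k
    return len(ratios)

# Hypothetical ratios, sorted from largest to smallest
ratios = [0.90, 0.06, 0.03, 0.01]
print(components_needed(ratios, 0.95))
```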

With `px.scatter_3d`, you can visualize an additional dimension, which lets you capture even more variance.

In [29]:
import plotly.express as px
from sklearn.decomposition import PCA

configure_plotly_browser_state()

df = px.data.iris()
X = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']]

pca = PCA(n_components=3)
components = pca.fit_transform(X)

total_var = pca.explained_variance_ratio_.sum() * 100

fig = px.scatter_3d(
    components, x=0, y=1, z=2, color=df['species'],
    title=f'Total Explained Variance: {total_var:.2f}%',
    labels={'0': 'PC 1', '1': 'PC 2', '2': 'PC 3'}
)
fig.show()
