In [60]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go

In [70]:
#read the data
data = pd.read_csv("FastFoodNutritionMenuV2.csv")#use your own path
#print the first few rows of the dataframe
print(data.head(10))

      Company                                 Item Calories  \
0  McDonald’s                            Hamburger      250   
1  McDonald’s                         Cheeseburger      300   
2  McDonald’s                  Double Cheeseburger      440   
3  McDonald’s                             McDouble      390   
4  McDonald’s         Quarter Pounder® with Cheese      510   
5  McDonald’s  Double Quarter Pounder® with Cheese      740   
6  McDonald’s                             Big Mac®      540   
7  McDonald’s                        Big N’ Tasty®      460   
8  McDonald’s            Big N’ Tasty® with Cheese      510   
9  McDonald’s                 Angus Bacon & Cheese      790   

  Calories from\nFat Total Fat\n(g) Saturated Fat\n(g) Trans Fat\n(g)  \
0                 80              9                3.5            0.5   
1                110             12                  6            0.5   
2                210             23                 11            1.5   
3             

In [62]:
#data cleaning
#check for missing values
print(data.isnull().sum())
#注意到如果全部删除缺失值会丢失大量数据，所以只在进行相关运算的时候对缺失值进行处理
# For example, when calculating average calories, we can drop rows with missing Calories values
# Similarly, when analyzing by Company, we can drop rows with missing Company values

Company                    0
Item                       0
Calories                   1
Calories from\nFat       506
Total Fat\n(g)            57
Saturated Fat\n(g)        57
Trans Fat\n(g)            57
Cholesterol\n(mg)          1
Sodium \n(mg)              1
Carbs\n(g)                57
Fiber\n(g)                57
Sugars\n(g)                1
Protein\n(g)              57
Weight Watchers\nPnts    261
dtype: int64


In [63]:
#check the data types of each column
print(data.dtypes)
#all dtype are object, so we need to convert them to numeric to perform calculations

Company                  object
Item                     object
Calories                 object
Calories from\nFat       object
Total Fat\n(g)           object
Saturated Fat\n(g)       object
Trans Fat\n(g)           object
Cholesterol\n(mg)        object
Sodium \n(mg)            object
Carbs\n(g)               object
Fiber\n(g)               object
Sugars\n(g)              object
Protein\n(g)             object
Weight Watchers\nPnts    object
dtype: object


这里为什么要检查数据类型？因为在后续的求平均，求分布的时候需要数据是数值类型的(numeric类型)，比如int，float这类。而在读取csv文件时python可能会将其中的数据识别成object类型（这是dataframe对非数值类的统称，比如string，list,这些）。如果不做检查我们后续做数值计算的时候会报错

In [64]:
#calculate the average calories for each restaurant
#get a slice of data that only includes company and calories, and drop rows with missing values in these columns
data_calories = data[["Company", "Calories"]].dropna()

#change the Calories column to numeric, coercing errors to NaN
data_calories["Calories"] = pd.to_numeric(data_calories["Calories"], errors='coerce')

print(data_calories)


         Company  Calories
0     McDonald’s     250.0
1     McDonald’s     300.0
2     McDonald’s     440.0
3     McDonald’s     390.0
4     McDonald’s     510.0
...          ...       ...
1143   Pizza Hut     230.0
1144   Pizza Hut     310.0
1145   Pizza Hut     120.0
1146   Pizza Hut     200.0
1147   Pizza Hut     260.0

[1147 rows x 2 columns]


下面这一步就是先对data_calories按照公司名进行分类（逻辑和SQL的groupby是一致的），然后再在各小的group内对Calories列求平均。最后.reset_index()是将结果转为一个新的dataframe方便后续可视化（当然，这不是必要的，只是加上后会更加方便）

In [65]:
#group by company and caculate the mean calories
avg_calories = data_calories.groupby("Company")["Calories"].mean().reset_index()

#display the average calories for each restaurant
print(avg_calories)

       Company    Calories
0  Burger King  359.189944
1          KFC  215.229358
2   McDonald’s  284.618902
3    Pizza Hut  253.378378
4    Taco Bell  292.166667
5      Wendy’s  322.500000


In [66]:
#plot the average calories for each restaurant
fig = px.bar(avg_calories, x = "Company", y = "Calories", title = "Average Calories by Restaurant")

fig.show()

后续这里是准备对各个公司的产品的营养元素的分布进行可视化，思路为：先对dataframe按照公司进行分类，然后对各个营养元素列进行求和，最后按照公司情况绘制饼状图即可。注意，需要手动提取需要的列，而且列名存在换行符需要额外注意。可以通过下面这个单元格的方法获取所有列名，直观的看到是否有换行符(\n)等格式问题。

In [67]:
columns = data.columns.tolist()
print(columns)

['Company', 'Item', 'Calories', 'Calories from\nFat', 'Total Fat\n(g)', 'Saturated Fat\n(g)', 'Trans Fat\n(g)', 'Cholesterol\n(mg)', 'Sodium \n(mg)', 'Carbs\n(g)', 'Fiber\n(g)', 'Sugars\n(g)', 'Protein\n(g)', 'Weight Watchers\nPnts']


In [68]:
# get those nutrition columns
nutrient_cols = ["Total Fat\n(g)", 'Cholesterol\n(mg)', 'Sodium \n(mg)', 'Carbs\n(g)', 'Fiber\n(g)', 'Sugars\n(g)', 'Protein\n(g)']

for company in data["Company"].unique():
    row = data[data["Company"] == company]
    # Convert nutrient columns to numeric and sum up all the elements
    values = [pd.to_numeric(row[col], errors='coerce').sum() for col in nutrient_cols]
    #label them
    labels = ["Fat", "Cholesterol", "Sodium", "Carbs", "Fiber", "Sugars", "Protein"]
    #plot the pie chart
    fig = go.Figure(data=[go.Pie(labels=labels, values=values, hole=0.3)])
    fig.update_layout(title=f"{company} distribution of nutrition elements")
    fig.show()


In [69]:
#notice that the sodium intake is very high, which may be a health concern
#get rid of the sodium, and replot the pie charts
for company in data["Company"].unique():
    #catch the data for each company
    row = data[data["Company"] == company]
    # Convert nutrient columns to numeric and sum up all the elements, excluding Sodium
    values = [pd.to_numeric(row[col], errors='coerce').sum() for col in nutrient_cols if col != 'Sodium \n(mg)']
    #label them, excluding Sodium
    labels = ["Fat", "Cholesterol", "Carbs", "Fiber", "Sugars", "Protein"]
    #plot the pie chart
    fig = go.Figure(data=[go.Pie(labels=labels, values=values, hole=0.3)])
    fig.update_layout(title=f"{company} distribution of nutrition elements (excluding Sodium)")
    fig.show()