# In [None]:
import calendar
import datetime
import math
import os
import urllib.parse

import plotly.express as px
import plotly.graph_objects as go
import requests

def get_city_coordinates(city, api_key=None, timeout=10):
    """Geocode a place name with the AMap (Gaode) geocoding REST API.

    Args:
        city: Free-form address or place name, e.g. a Chinese street address.
        api_key: AMap web-service key. When omitted, the ``AMAP_API_KEY``
            environment variable is used, falling back to the key that is
            embedded below.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        A ``(latitude, longitude)`` tuple of strings taken from the first
        geocode match, or ``None`` if the request failed or nothing matched.
    """
    if api_key is None:
        # NOTE(review): an API key committed to source is easy to leak; prefer
        # providing it through the AMAP_API_KEY environment variable.
        api_key = os.environ.get("AMAP_API_KEY", "fe50b339bcfac1c3f9daf919d031a26e")

    # Let requests build and percent-encode the query: the address is user
    # input and may contain non-ASCII characters, spaces or '&'.
    params = {"key": api_key, "address": city, "output": "JSON"}
    try:
        response = requests.get("https://restapi.amap.com/v3/geocode/geo",
                                params=params, timeout=timeout)
        response.raise_for_status()
        data = response.json()
    except (requests.RequestException, ValueError) as err:
        print(f"Geocoding request failed: {err}")
        return None

    geocodes = data.get("geocodes") or []
    if data.get("status") != "1" or not geocodes:
        return None

    # The location string is "longitude,latitude"; callers expect (lat, lon).
    lon, lat = geocodes[0]["location"].split(",")
    return lat, lon





def get_historical_weather(latitude, longitude, start_date, end_date, timezone, timeout=30):
    """Fetch daily weather for a date range from the Open-Meteo archive API.

    Args:
        latitude, longitude: Coordinates (strings or numbers).
        start_date, end_date: Inclusive dates formatted ``YYYY-MM-DD``.
        timezone: Time-zone name, either raw (``"Asia/Shanghai"``) or already
            percent-encoded (``"Asia%2FShanghai"``, as ``main`` passes it).
        timeout: Seconds to wait for the HTTP response.

    Returns:
        The decoded JSON response, whose ``"daily"`` entry holds the ``time``
        list plus one list per requested variable, or ``None`` when the
        request fails or returns nothing.
    """
    params = {
        "latitude": latitude,
        "longitude": longitude,
        "start_date": start_date,
        "end_date": end_date,
        "daily": "temperature_2m_max,temperature_2m_min,rain_sum,precipitation_sum,"
                 "precipitation_hours,windspeed_10m_max",
        # Decode any pre-encoded value so requests encodes it exactly once.
        "timezone": urllib.parse.unquote(timezone),
        "temperature_unit": "celsius",
    }

    try:
        r = requests.get("https://archive-api.open-meteo.com/v1/archive",
                         params=params, timeout=timeout)
        response_dict = r.json() if r.status_code == 200 else None
    except (requests.RequestException, ValueError) as err:
        print(f"Weather API request failed: {err}")
        return None

    if not response_dict:
        print("No results returned from API")
        return None

    return response_dict


def get_all_historical_data(selected_month, latitude, longitude, timezone, num_years=50):
    """Collect one month of daily weather for each of the last ``num_years`` years.

    The newest year fetched is last calendar year, so every requested month
    is complete. One HTTP request is issued per year; years whose request
    fails or comes back without daily data are skipped.

    Args:
        selected_month: Month number, 1-12.
        latitude, longitude: Coordinates forwarded to ``get_historical_weather``.
        timezone: Time zone forwarded to ``get_historical_weather``.
        num_years: How many years to go back. The chart titles in this script
            say "50 years", matching the default.

    Returns:
        A list of ``{"year": int, "daily": dict}`` entries, newest year first.
    """
    historical_data = []
    last_full_year = datetime.datetime.now().year - 1
    month = f"{selected_month:02d}"

    for year_graph in range(last_full_year, last_full_year - num_years, -1):
        # monthrange accounts for leap years (28 vs 29 days in February).
        days_in_month = calendar.monthrange(year_graph, selected_month)[1]
        start_date = f"{year_graph}-{month}-01"
        end_date = f"{year_graph}-{month}-{days_in_month}"

        response_dict = get_historical_weather(latitude, longitude, start_date, end_date,
                                               timezone)

        daily = response_dict.get("daily") if response_dict else None
        if not daily:
            continue

        historical_data.append(dict(
            year=year_graph,
            daily=daily,
        ))

    return historical_data


def init_graph_data():
    """Build a new dict of empty series lists for the chart-preparation helpers.

    Keys, in order: ``days_month``, ``years``, ``high_temps``, ``low_temps``,
    ``total_precipitation``, ``hover_text``. Each call returns new list
    objects, so the result can be mutated without affecting other callers.
    """
    series_names = (
        "days_month",
        "years",
        "high_temps",
        "low_temps",
        "total_precipitation",
        "hover_text",
    )
    return {name: [] for name in series_names}


def prepare_data_bubble_chart(historical_data):
    """Flatten per-year daily records into parallel point lists for the bubble chart.

    Each day of each year becomes one point: the last two characters of its
    date string (day of month), its year, the daily max/min temperature, the
    precipitation sum, and a hover label "<date><br><max temperature> C".

    Args:
        historical_data: List of ``{"year": ..., "daily": {...}}`` entries,
            as produced by ``get_all_historical_data``.

    Returns:
        The dict from ``init_graph_data`` with all six lists filled in
        lockstep.
    """
    graph_data = init_graph_data()

    for record in historical_data:
        year = record["year"]
        daily = record["daily"]

        for idx, date in enumerate(daily["time"]):
            day_high = daily["temperature_2m_max"][idx]

            graph_data["days_month"].append(date[-2:])
            graph_data["high_temps"].append(day_high)
            graph_data["low_temps"].append(daily["temperature_2m_min"][idx])
            graph_data["total_precipitation"].append(daily["precipitation_sum"][idx])
            graph_data["years"].append(year)
            graph_data["hover_text"].append(f"{date}<br>{day_high} C")

    return graph_data



def draw_bubble_chart(historical_data, selected_month, latitude, longitude, var = 'high_temps'):
    """Show a bubble chart of one daily variable: day of month (x) by year (y).

    Bubble area and colour both encode ``var``. Days whose value is ``None``
    are left out of the chart.

    Args:
        historical_data: Output of ``get_all_historical_data``.
        selected_month: Month number 1-12, used for hover dates and the title.
        latitude, longitude: Location shown in the title (str or number).
        var: ``'high_temps'``, ``'low_temps'`` or ``'total_precipitation'``.
    """
    txt = {
        'high_temps': 'Highest Temperatures',
        'low_temps': 'Lowest Temperatures',
        'total_precipitation': 'Total Precipitation',
    }
    is_temperature = var != 'total_precipitation'

    graph_data = prepare_data_bubble_chart(historical_data)

    keep = [i for i, value in enumerate(graph_data[var]) if value is not None]
    if not keep:
        print("No data available to draw the bubble chart")
        return
    days = [graph_data["days_month"][i] for i in keep]
    years = [graph_data["years"][i] for i in keep]
    values = [graph_data[var][i] for i in keep]

    lo, hi = min(values), max(values)
    # Shift so the smallest value gets size 4; marker sizes must be positive.
    sizes = [value - lo + 4 for value in values]
    # Plotly's area-mode scaling rule: the largest bubble is about 20 px.
    # max(sizes) >= 4, so this never divides by zero (even if all values match).
    sizeref = 2. * max(sizes) / (20. ** 2)

    if is_temperature:
        # Fixed -30..40 C colour scale, widened only if the data falls outside it.
        cmin = min(-30, lo)
        cmax = max(40, hi)
    else:
        cmin = lo
        cmax = hi if hi > lo else lo + 1

    # Label each point with the plotted variable, not always the max temperature.
    unit = " C" if is_temperature else ""
    hover_text = [f"{year}-{selected_month:02d}-{day}<br>{value}{unit}"
                  for day, year, value in zip(days, years, values)]

    fig = go.Figure(data=go.Scatter(
        x=days,
        y=years,
        mode="markers",
        marker=dict(size=sizes,
                    sizemode="area",
                    sizeref=sizeref,
                    sizemin=4,
                    color=values,
                    colorscale="Rainbow",
                    showscale=True,
                    colorbar=dict(title=txt[var]),
                    cmax=cmax,
                    cmin=cmin),
        hovertext=hover_text,
        hoverinfo="text",
    ))

    fig.update_layout(
        xaxis=dict(
            title='Day of the month'
        ),
        yaxis=dict(
            title='Years'
        ),
        title=dict(
            text="<b>Daily " + txt[var] + " for " + calendar.month_name[selected_month] +
                 " over the last 50 years<br>" + "Location latitude: " + str(latitude) +
                 ", longitude: " + str(longitude) + "</b><br>",
            font_size=18
        )
    )

    fig.show()

    return


def prepare_data_heatmap_chart(historical_data, var = 'high_temps'):
    """Arrange one daily variable as a year-by-day grid for the heatmap.

    Args:
        historical_data: List of ``{"year": ..., "daily": {...}}`` entries.
        var: ``'high_temps'``, ``'low_temps'`` or ``'total_precipitation'``.

    Returns:
        The dict from ``init_graph_data``, where:

        - ``days_month`` holds the x labels (two-character day strings);
        - ``years`` holds the y labels;
        - ``graph_data[var]`` holds one row per year.

        Rows are padded with NaN to the longest month so the grid is
        rectangular. Missing (``None``) API values also become NaN, keeping
        the grid numeric. Empty input returns the empty dict.
    """
    graph_data = init_graph_data()

    mapping = {
        'high_temps': 'temperature_2m_max',
        'low_temps': 'temperature_2m_min',
        'total_precipitation': 'precipitation_sum',
    }
    field = mapping[var]

    if not historical_data:
        return graph_data

    # February has 29 days in leap years and 28 otherwise, so years can differ
    # in length. Label the x axis from the longest month and pad shorter rows,
    # since px.imshow needs a rectangular array.
    longest = max((entry["daily"]["time"] for entry in historical_data), key=len)
    graph_data["days_month"] = [date[-2:] for date in longest]
    width = len(longest)

    for entry in historical_data:
        daily = entry["daily"]
        graph_data["years"].append(entry["year"])

        row = [float("nan") if value is None else value
               for value in daily[field][:len(daily["time"])]]
        row.extend([float("nan")] * (width - len(row)))
        graph_data[var].append(row)

    return graph_data


def draw_heatmap_chart(historical_data, selected_month, latitude, longitude, var = 'high_temps'):
    """Show a heatmap of one daily variable: day of month (x) by year (y).

    Args:
        historical_data: Output of ``get_all_historical_data``.
        selected_month: Month number 1-12, used in the title.
        latitude, longitude: Location shown in the title (str or number).
        var: ``'high_temps'``, ``'low_temps'`` or ``'total_precipitation'``.
    """
    if not historical_data:
        print("No data available to draw the heatmap chart")
        return

    graph_data = prepare_data_heatmap_chart(historical_data, var = var)

    txt = {
        'high_temps': 'Highest Temperatures',
        'low_temps': 'Lowest Temperatures',
        'total_precipitation': 'Total Precipitation',
    }

    # Missing cells (None, or NaN padding) do not count toward the colour range.
    values = [v for row in graph_data[var] for v in row
              if v is not None and not math.isnan(v)]
    if not values:
        print("No data available to draw the heatmap chart")
        return

    if var == 'total_precipitation':
        color_label = "Precipitation"
        zmin = min(values)
        zmax = max(values) if max(values) > zmin else zmin + 1
    else:
        color_label = "Temperature, C"
        # Fixed -30..40 C scale, widened only if the data falls outside it.
        zmin = min(-30, min(values))
        zmax = max(40, max(values))

    fig = px.imshow(img=graph_data[var],
                    labels=dict(x="Day of the month", y="Years", color=color_label),
                    x=graph_data["days_month"],
                    y=graph_data["years"],
                    aspect="auto",
                    color_continuous_scale="Rainbow",
                    title="<b>Daily " + txt[var] + " for " +
                          calendar.month_name[selected_month] +
                          " over the last 50 years<br>" + "Location latitude: " +
                          str(latitude) + ", longitude: " +
                          str(longitude) + "</b><br></b><br></b><br>",
                    zmax=zmax,
                    zmin=zmin,
                    origin="lower"
                    )

    fig.show()

    return



def prepare_data_line_plot(historical_data):
    """Summarise each year's month into the three series of the line plot.

    For every year it computes:

    - ``high_temps``: the highest daily maximum temperature;
    - ``average_temps``: the mean of the daily maxima;
    - ``low_temps``: the mean of the daily minima.

    Both means are rounded to 2 decimals. Days reported as ``None`` are
    ignored. A year with no usable values yields ``None`` for that series,
    which Plotly draws as a gap.

    Args:
        historical_data: List of ``{"year": ..., "daily": {...}}`` entries.

    Returns:
        Dict of equal-length lists keyed ``years``, ``high_temps``,
        ``average_temps`` and ``low_temps``, in input order.
    """
    graph_data = {
        "years": [],
        "high_temps": [],
        "average_temps": [],
        "low_temps": []
    }

    def mean2(values):
        # Mean rounded to 2 decimals, or None when there is nothing to average.
        return round(sum(values) / len(values), 2) if values else None

    for entry in historical_data:
        daily = entry["daily"]
        n_days = len(daily["time"])
        highs = [t for t in daily["temperature_2m_max"][:n_days] if t is not None]
        lows = [t for t in daily["temperature_2m_min"][:n_days] if t is not None]

        graph_data["years"].append(entry["year"])
        graph_data["high_temps"].append(max(highs) if highs else None)
        graph_data["average_temps"].append(mean2(highs))
        # Plotted as "Average of the Lowest Day Temp", so this is the mean of
        # the daily minima (not their maximum).
        graph_data["low_temps"].append(mean2(lows))

    return graph_data

def draw_line_plot(historical_data, selected_month, latitude, longitude):
    """Plot per-year summary lines for the selected month.

    Draws the ``high_temps``, ``average_temps`` and ``low_temps`` series
    from ``prepare_data_line_plot`` against the year. Each line uses a warm
    colour when any of its own values is above 0 C, and a cool colour
    otherwise.

    Args:
        historical_data: Output of ``get_all_historical_data``.
        selected_month: Month number 1-12, used in the title.
        latitude, longitude: Location shown in the title (str or number).
    """
    graph_data = prepare_data_line_plot(historical_data)
    if not graph_data["years"]:
        print("No data available to draw the line chart")
        return

    def pick_color(series, warm, cool):
        # None entries (years without data) are ignored when checking the sign.
        values = [v for v in series if v is not None]
        return warm if values and max(values) > 0 else cool

    fig = go.Figure()
    # Create and style traces; each line is coloured by its own series.
    fig.add_trace(go.Scatter(x=graph_data["years"], y=graph_data["high_temps"],
                             name="Highest Temp, C", mode='lines+markers',
                             line=dict(color=pick_color(graph_data["high_temps"],
                                                        "firebrick", "royalblue"),
                                       width=2)))
    fig.add_trace(go.Scatter(x=graph_data["years"], y=graph_data["average_temps"],
                             mode='lines+markers', name="Average of the Highest Day Temp, C",
                             line=dict(color=pick_color(graph_data["average_temps"],
                                                        "#db82b4", "#3779ff"),
                                       width=2, dash="dash")))
    fig.add_trace(go.Scatter(x=graph_data["years"], y=graph_data["low_temps"],
                             mode='lines+markers', name="Average of the Lowest Day Temp, C",
                             line=dict(color=pick_color(graph_data["low_temps"],
                                                        "#f9ab00", "#75a99c"),
                                       width=2, dash="dot")))

    # Edit the layout
    fig.update_layout(title='<b>Highest Day Temperature and Average of the Highest(lowest) Day '
                            'Temperatures for ' +
                            calendar.month_name[selected_month] + " over the last 50 years<br>"
                            + "Location latitude: "
                            + str(latitude) + ", longitude: " + str(longitude) + "</b><br>",
                      xaxis_title="Year",
                      yaxis_title="Temperature (degrees C)")

    fig.update_layout(hovermode='x unified')

    fig.show()

    return


def main():
    """Interactive entry point: ask for a place and a month, then draw three charts.

    The place is geocoded, 50 years of daily weather for the chosen month
    are downloaded, and bubble, heatmap and line charts are shown. The
    function returns early with a message when the place cannot be
    geocoded, the month is invalid, or no weather data comes back.
    """
    city = input("Please enter any location in China:")

    # Geocode once and reuse the result for both coordinates.
    coordinates = get_city_coordinates(city)
    if coordinates is None:
        print(f"Could not find coordinates for {city!r}.")
        return
    latitude, longitude = coordinates

    # Percent-encoded form of "Asia/Shanghai".
    timezone = "Asia%2FShanghai"

    try:
        selected_month = int(input("Please enter the selected month(number):"))
    except ValueError:
        selected_month = 0
    if not 1 <= selected_month <= 12:
        print("The month must be a whole number from 1 to 12.")
        return

    # ANSI escape sequences for bold, coloured terminal output.
    bold_red = "\033[1m\033[91m"
    bold_blue = "\033[1m\033[94m"
    bold_green = "\033[1m\033[92m"
    reset = "\033[0m"

    # Print the progress message in bold red.
    print(bold_red + f"Obtaining the weather data of 【{city}】 for {calendar.month_name[selected_month]} in the past 50 years." + reset)
    historical_data = get_all_historical_data(selected_month, latitude, longitude, timezone)
    if not historical_data:
        print("No weather data was returned; nothing to draw.")
        return
    print(bold_green + 'Finished!' + reset)
    print(bold_blue + 'Draw Bubble Chart' + reset)
    draw_bubble_chart(historical_data, selected_month, latitude, longitude, var = 'high_temps')
    print(bold_blue + 'Draw Heatmap Chart' + reset)
    draw_heatmap_chart(historical_data, selected_month, latitude, longitude, var = 'high_temps')
    print(bold_blue + 'Draw Line Chart' + reset)
    draw_line_plot(historical_data, selected_month, latitude, longitude)

    return
if __name__ == "__main__":
    main()

# Example inputs:
#   first prompt (location): 北京市朝阳区阜通东大街6号
#   second prompt (month):   any whole number from 1 to 12