## Importing Libraries

In [0]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window

**1501. Countries You Can Safely Invest In (Medium)**

**Table Person:**

| Column Name    | Type    |
|----------------|---------|
| id             | int     |
| name           | varchar |
| phone_number   | varchar |

id is the column of unique values for this table.
Each row of this table contains the name of a person and their phone number.
Each phone number has the form 'xxx-yyyyyyy', where xxx is the country code (3 digits) and yyyyyyy is the local number (7 digits). Both parts can contain leading zeros.
 
**Table Country:**

| Column Name    | Type    |
|----------------|---------|
| name           | varchar |
| country_code   | varchar |

country_code is the column of unique values for this table.
Each row of this table contains the country name and its code. country_code will be in the form 'xxx', where each x is a digit.
 
**Table Calls:**

| Column Name | Type |
|-------------|------|
| caller_id   | int  |
| callee_id   | int  |
| duration    | int  |

This table may contain duplicate rows.
Each row of this table contains the caller id, the callee id, and the duration of the call in minutes. caller_id and callee_id are never equal (caller_id != callee_id).
 
A telecommunications company wants to invest in new countries. It intends to invest in the countries whose average call duration is strictly greater than the global average call duration. A call counts toward a country once for each participant (caller or callee) who belongs to that country.

**Write a solution to find the countries where this company can invest.**

Return the result table in any order.

The result format is in the following example.

**Example 1:**

**Input:** 

**Person table:**

| id | name     | phone_number |
|----|----------|--------------|
| 3  | Jonathan | 051-1234567  |
| 12 | Elvis    | 051-7654321  |
| 1  | Moncef   | 212-1234567  |
| 2  | Maroua   | 212-6523651  |
| 7  | Meir     | 972-1234567  |
| 9  | Rachel   | 972-0011100  |

**Country table:**

| name     | country_code |
|----------|--------------|
| Peru     | 051          |
| Israel   | 972          |
| Morocco  | 212          |
| Germany  | 049          |
| Ethiopia | 251          |

**Calls table:**

| caller_id | callee_id | duration |
|-----------|-----------|----------|
| 1         | 9         | 33       |
| 2         | 9         | 4        |
| 1         | 2         | 59       |
| 3         | 12        | 102      |
| 3         | 12        | 330      |
| 12        | 3         | 5        |
| 7         | 9         | 13       |
| 7         | 1         | 3        |
| 9         | 7         | 1        |
| 1         | 7         | 7        |

**Output:**

| country  |
|----------|
| Peru     |

**Explanation:** 
- The average call duration for Peru is (102 + 102 + 330 + 330 + 5 + 5) / 6 = 145.666667
- The average call duration for Israel is (33 + 4 + 13 + 13 + 3 + 1 + 1 + 7) / 8 = 9.375
- The average call duration for Morocco is (33 + 4 + 59 + 59 + 3 + 7) / 6 = 27.5
- Global call duration average = (2 * (33 + 4 + 59 + 102 + 330 + 5 + 13 + 3 + 1 + 7)) / 20 = 55.7
- Since Peru is the only country where the average call duration is strictly greater than the global average, it is the only recommended country.
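
The arithmetic above can be checked with a few lines of plain Python. This is only a sketch, independent of the PySpark solution that follows; the names `person_country`, `calls`, `durations_by_country`, `country_avg`, `global_avg`, and `invest_in` are illustrative and not part of the problem statement, and the person-to-country mapping was filled in by hand from the phone-number prefixes.

```python
from collections import defaultdict

# Example data from the problem statement.
# person id -> country name (looked up by hand from the phone-number prefix).
person_country = {
    3: "Peru", 12: "Peru",
    1: "Morocco", 2: "Morocco",
    7: "Israel", 9: "Israel",
}
# (caller_id, callee_id, duration in minutes)
calls = [
    (1, 9, 33), (2, 9, 4), (1, 2, 59), (3, 12, 102), (3, 12, 330),
    (12, 3, 5), (7, 9, 13), (7, 1, 3), (9, 7, 1), (1, 7, 7),
]

# Every call counts once for the caller's country and once for the callee's.
durations_by_country = defaultdict(list)
for caller_id, callee_id, duration in calls:
    durations_by_country[person_country[caller_id]].append(duration)
    durations_by_country[person_country[callee_id]].append(duration)

country_avg = {c: sum(d) / len(d) for c, d in durations_by_country.items()}

# Counting each call twice does not change the mean, so the plain mean
# over the calls equals the global average used in the explanation.
global_avg = sum(duration for _, _, duration in calls) / len(calls)

# Countries whose average is strictly greater than the global average.
invest_in = sorted(c for c, a in country_avg.items() if a > global_avg)
print(country_avg, global_avg, invest_in)
```

Germany and Ethiopia never appear in `country_avg` because no person in the example has their country code.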

In [0]:
person_data_1501 = [
    (3, "Jonathan", "051-1234567"),
    (12, "Elvis", "051-7654321"),
    (1, "Moncef", "212-1234567"),
    (2, "Maroua", "212-6523651"),
    (7, "Meir", "972-1234567"),
    (9, "Rachel", "972-0011100"),
]

person_columns_1501 = ["id", "name", "phone_number"]
person_df_1501 = spark.createDataFrame(person_data_1501, person_columns_1501)
person_df_1501.show()

country_data_1501 = [
    ("Peru", "051"),
    ("Israel", "972"),
    ("Morocco", "212"),
    ("Germany", "049"),
    ("Ethiopia", "251"),
]

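# The problem statement calls this column `name`; naming it `country` here
# avoids a clash with Person.name after the join.
country_columns_1501 = ["country", "country_code"]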
country_df_1501 = spark.createDataFrame(country_data_1501, country_columns_1501)
country_df_1501.show()

calls_data_1501 = [
    (1, 9, 33),
    (2, 9, 4),
    (1, 2, 59),
    (3, 12, 102),
    (3, 12, 330),
    (12, 3, 5),
    (7, 9, 13),
    (7, 1, 3),
    (9, 7, 1),
    (1, 7, 7),
]

calls_columns_1501 = ["caller_id", "callee_id", "duration"]
calls_df_1501 = spark.createDataFrame(calls_data_1501, calls_columns_1501)
calls_df_1501.show()



+---+--------+------------+
| id|    name|phone_number|
+---+--------+------------+
|  3|Jonathan| 051-1234567|
| 12|   Elvis| 051-7654321|
|  1|  Moncef| 212-1234567|
|  2|  Maroua| 212-6523651|
|  7|    Meir| 972-1234567|
|  9|  Rachel| 972-0011100|
+---+--------+------------+

+--------+------------+
| country|country_code|
+--------+------------+
|    Peru|         051|
|  Israel|         972|
| Morocco|         212|
| Germany|         049|
|Ethiopia|         251|
+--------+------------+

+---------+---------+--------+
|caller_id|callee_id|duration|
+---------+---------+--------+
|        1|        9|      33|
|        2|        9|       4|
|        1|        2|      59|
|        3|       12|     102|
|        3|       12|     330|
|       12|        3|       5|
|        7|        9|      13|
|        7|        1|       3|
|        9|        7|       1|
|        1|        7|       7|
+---------+---------+--------+



In [0]:
# Derive each person's country code: the part of the phone number before the dash.
person_df_1501 = person_df_1501.withColumn(
    "country_code", split(col("phone_number"), "-")[0]
)
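
Country codes such as `051` and `049` start with a zero, so the code has to stay a string for the join against `country_code` to match. As a plain-Python sketch of the same split (not the PySpark call itself; `country_code_of` is a hypothetical helper name used only for illustration):

```python
def country_code_of(phone_number: str) -> str:
    """Return the part before the dash, kept as a string."""
    return phone_number.split("-")[0]


code = country_code_of("051-1234567")
print(code)       # 051
print(int(code))  # 51; casting to int drops the leading zero
```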

In [0]:
# Attach the country name to each person via the derived country code.
df_person_country_1501 = person_df_1501.join(
    country_df_1501, on="country_code", how="left"
)

In [0]:
# Resolve the country of both the caller and the callee for every call.
df_calls_country_1501 = (
    calls_df_1501
    .join(
        df_person_country_1501.select(col("id").alias("caller_id"), col("country").alias("caller_country")),
        on="caller_id", how="left"
    )
    .join(
        df_person_country_1501.select(col("id").alias("callee_id"), col("country").alias("callee_country")),
        on="callee_id", how="left"
    )
)

In [0]:
# Stack caller-side and callee-side rows so that every call counts once
# for each participant's country.
df_country_duration_1501 = (
    df_calls_country_1501
    .select(col("caller_country").alias("country"), col("duration"))
    .union(
        df_calls_country_1501
        .select(col("callee_country").alias("country"), col("duration"))
    )
)
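
`DataFrame.union` keeps duplicate rows (it behaves like SQL `UNION ALL`), which is what this step needs because the Calls table may contain duplicate rows. The plain-Python sketch below uses made-up durations, not taken from the example, to show how deduplicating would shift an average; `durations` is a hypothetical list for a single country.

```python
# Hypothetical durations for one country; the two 10-minute rows are
# distinct calls that happen to look identical.
durations = [10, 10, 40]

# Keeping duplicates: every call contributes to the mean.
avg_keep_duplicates = sum(durations) / len(durations)

# Dropping duplicates: one of the 10-minute calls is silently lost.
deduplicated = set(durations)
avg_drop_duplicates = sum(deduplicated) / len(deduplicated)

print(avg_keep_duplicates)  # 20.0
print(avg_drop_duplicates)  # 25.0
```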

In [0]:
# Average call duration per country.
df_country_avg_1501 = df_country_duration_1501.groupBy("country").agg(
    avg("duration").alias("avg_duration")
)

In [0]:
# Global average over the same stacked rows, pulled back to the driver as a Python float.
global_avg_1501 = df_country_duration_1501.agg(
    avg("duration").alias("global_avg")
).collect()[0]["global_avg"]

In [0]:
# Keep only the countries whose average is strictly greater than the global average.
df_country_avg_1501.filter(col("avg_duration") > global_avg_1501).select("country").show()

+-------+
|country|
+-------+
|   Peru|
+-------+

