## 1468. Calculate Salaries
### Table: Salaries

| Column Name   | Type    |
|---------------|---------|
| company_id    | int     |
| employee_id   | int     |
| employee_name | varchar |
| salary        | int     |

(company_id, employee_id) is the primary key for this table.  
This table contains the company ID, the employee ID, the employee name, and the salary.

---

Write an SQL query to find the salaries of the employees after applying taxes.

The tax rate is calculated for each company based on the following criteria:

- 0% if the max salary of any employee in the company is less than $1000.
- 24% if the max salary of any employee in the company is in the range [1000, 10000] inclusive.
- 49% if the max salary of any employee in the company is greater than $10000.
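The three brackets amount to a small lookup on the company's highest salary. Below is a plain-Python sketch of that rule. It is not part of the original notebook: the helper name `tax_rate` is hypothetical, and it returns `Decimal` values so that any later arithmetic stays exact.

```python
from decimal import Decimal


def tax_rate(max_salary: int) -> Decimal:
    """Return the tax rate for a company given its highest salary (hypothetical helper)."""
    if max_salary < 1000:
        return Decimal("0")
    if max_salary <= 10000:  # the [1000, 10000] bracket is inclusive on both ends
        return Decimal("0.24")
    return Decimal("0.49")  # strictly greater than 10000


print(tax_rate(700), tax_rate(7777), tax_rate(21300))  # → 0 0.24 0.49
```

The boundary values 1000 and 10000 both fall in the 24% bracket, which is why the middle test uses `<=` rather than `<`.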

Return the result table in any order. Round the salary to the nearest integer.
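"Nearest integer" leaves the handling of exact halves open. Spark's `round` (the SQL function and `pyspark.sql.functions.round`) uses HALF_UP, sending .5 away from zero. Python's built-in `round()` instead rounds halves to the nearest even integer. If you check results in plain Python, a half-up rounding with the standard `decimal` module matches Spark. A minimal sketch follows; the helper name `round_half_up` is hypothetical.

```python
from decimal import Decimal, ROUND_HALF_UP


def round_half_up(value: Decimal) -> int:
    """Round to the nearest integer, sending .5 away from zero (like Spark's round)."""
    return int(value.quantize(Decimal("1"), rounding=ROUND_HALF_UP))


print(round(2.5), round_half_up(Decimal("2.5")))  # → 2 3
print(round_half_up(Decimal("5910.52")))          # → 5911
```

Ties can occur here. For example, a salary of 50 taxed at 49% leaves exactly 25.5.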

---

### Salaries table:

| company_id | employee_id | employee_name | salary |
|------------|-------------|---------------|--------|
| 1          | 1           | Tony          | 2000   |
| 1          | 2           | Pronub        | 21300  |
| 1          | 3           | Tyrrox        | 10800  |
| 2          | 1           | Pam           | 300    |
| 2          | 7           | Bassem        | 450    |
| 2          | 9           | Hermione      | 700    |
| 3          | 7           | Bocaben       | 100    |
| 3          | 2           | Ognjen        | 2200   |
| 3          | 13          | Nyancat       | 3300   |
| 3          | 15          | Morninngcat   | 7777   |

---

### Result table:

| company_id | employee_id | employee_name | salary |
|------------|-------------|---------------|--------|
| 1          | 1           | Tony          | 1020   |
| 1          | 2           | Pronub        | 10863  |
| 1          | 3           | Tyrrox        | 5508   |
| 2          | 1           | Pam           | 300    |
| 2          | 7           | Bassem        | 450    |
| 2          | 9           | Hermione      | 700    |
| 3          | 7           | Bocaben       | 76     |
| 3          | 2           | Ognjen        | 1672   |
| 3          | 13          | Nyancat       | 2508   |
| 3          | 15          | Morninngcat   | 5911   |

---

**Explanation:**  
For company 1, the max salary is 21300, so employees in company 1 are taxed at 49%.  
For company 2, the max salary is 700, so employees in company 2 are taxed at 0%.  
For company 3, the max salary is 7777, so employees in company 3 are taxed at 24%.  
The salary after taxes = salary - (tax percentage / 100) * salary.  
For example, the salary of Morninngcat (3, 15) after taxes = 7777 - 7777 * (24 / 100) = 7777 - 1866.48 = 5910.52, which is rounded to 5911.
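The whole computation can be checked against the result table without Spark. The plain-Python sketch below assumes the sample rows above, with Morninngcat's salary as 7777 (the value the explanation uses). It first finds each company's maximum salary, which is the role the window function plays in the notebook cells. It then applies the bracket rate and rounds half up with `Decimal`. The names `SALARIES` and `taxed_salaries` are hypothetical.

```python
from collections import defaultdict
from decimal import Decimal, ROUND_HALF_UP

# Sample rows: (company_id, employee_id, employee_name, salary)
SALARIES = [
    (1, 1, "Tony", 2000),
    (1, 2, "Pronub", 21300),
    (1, 3, "Tyrrox", 10800),
    (2, 1, "Pam", 300),
    (2, 7, "Bassem", 450),
    (2, 9, "Hermione", 700),
    (3, 7, "Bocaben", 100),
    (3, 2, "Ognjen", 2200),
    (3, 13, "Nyancat", 3300),
    (3, 15, "Morninngcat", 7777),
]


def taxed_salaries(rows):
    """Return rows with each salary replaced by the rounded after-tax salary."""
    # Per-company maximum (what MAX(salary) OVER (PARTITION BY company_id) supplies);
    # starting from 0 assumes salaries are non-negative.
    max_by_company = defaultdict(int)
    for company_id, _, _, salary in rows:
        max_by_company[company_id] = max(max_by_company[company_id], salary)

    result = []
    for company_id, employee_id, name, salary in rows:
        top = max_by_company[company_id]
        if top < 1000:
            rate = Decimal("0")
        elif top <= 10000:
            rate = Decimal("0.24")
        else:
            rate = Decimal("0.49")
        after_tax = Decimal(salary) - Decimal(salary) * rate
        rounded = int(after_tax.quantize(Decimal("1"), rounding=ROUND_HALF_UP))
        result.append((company_id, employee_id, name, rounded))
    return result


for row in taxed_salaries(SALARIES):
    print(row)
```

The rows keep their input order, and the printed salaries should match the result table.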

```python
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType
from pyspark.sql.functions import col, max as max_, when, round
from pyspark.sql.window import Window

# Start Spark session
spark = SparkSession.builder.appName("TaxedSalaries").getOrCreate()

# Define schema
schema = StructType([
    StructField("company_id", IntegerType(), True),
    StructField("employee_id", IntegerType(), True),
    StructField("employee_name", StringType(), True),
    StructField("salary", IntegerType(), True)
])

# Sample data
data = [
    (1, 1, "Tony", 2000),
    (1, 2, "Pronub", 21300),
    (1, 3, "Tyrrox", 10800),
    (2, 1, "Pam", 300),
    (2, 7, "Bassem", 450),
    (2, 9, "Hermione", 700),
    (3, 7, "Bocaben", 100),
    (3, 2, "Ognjen", 2200),
    (3, 13, "Nyancat", 3300),
    (3, 15, "Morninngcat", 7777)
]

# Create DataFrame
df = spark.createDataFrame(data, schema)

# Register a temp view so the data can be queried with SQL
df.createOrReplaceTempView("Salaries")
```


```sql
%sql
-- Attach each company's max salary to every employee row
with cte as (
  select *, max(salary) over (partition by company_id) as max_sal
  from Salaries
)
select
  company_id,
  employee_id,
  employee_name,
  case
    when max_sal < 1000 then round(salary, 0)
    when max_sal between 1000 and 10000 then round(salary - salary * 0.24, 0)
    when max_sal > 10000 then round(salary - salary * 0.49, 0)
  end as salary
from cte
```

```python
import pyspark.sql.functions as F
from pyspark.sql.window import Window

# The tax bracket depends on the company's max salary, not on each employee's salary
win_spec = Window.partitionBy("company_id")

result_df = (
    df.withColumn("max_salary", F.max("salary").over(win_spec))
    .withColumn(
        "salary",
        F.when(F.col("max_salary") < 1000, F.col("salary"))
        .when(
            F.col("max_salary").between(1000, 10000),  # inclusive on both ends
            F.col("salary") - F.col("salary") * 0.24,
        )
        .otherwise(F.col("salary") - F.col("salary") * 0.49),  # max_salary > 10000
    )
    .withColumn("salary", F.round("salary", 0).cast("int"))
    .select("company_id", "employee_id", "employee_name", "salary")
)

# display() returns None, so call it on the finished DataFrame instead of assigning it
display(result_df)
```

```python
# Define window to get max salary per company
win_spec = Window.partitionBy("company_id")
df_taxed = df.withColumn("max_salary", max_("salary").over(win_spec))

# Determine tax rate
df_taxed = df_taxed.withColumn(
    "tax_rate",
    when(col("max_salary") < 1000, 0)
    .when((col("max_salary") >= 1000) & (col("max_salary") <= 10000), 0.24)
    .otherwise(0.49),
)

# Apply tax and round salary
df_result = df_taxed.withColumn(
    "salary", round(col("salary") * (1 - col("tax_rate"))).cast("int")
).select("company_id", "employee_id", "employee_name", "salary")

# Display result
display(df_result)
```