## First we initialize our globals

In [1]:
# load table into spark dataframe
cows_bff = spark.read.table("db.cows_bff")

## Some graphs and charts

In [2]:
display(cows_bff.limit(10))

Unnamed: 0,cow_name,meal_start,meal_end,meal_start_time,meal_end_time,duration,date,day
0,Butterscotch,33575,33936,9:19:35,9:25:36,361,2023-05-20,Saturday
1,Hershey,32998,34215,9:09:58,9:30:15,1217,2023-05-20,Saturday
2,Mocha,33671,34788,9:21:11,9:39:48,1117,2023-05-20,Saturday
3,Cocoa,33671,34788,9:21:11,9:39:48,1117,2023-05-20,Saturday
4,Nutella,31935,32628,8:52:15,9:03:48,693,2023-05-20,Saturday
5,Brandy,33595,33997,9:19:55,9:26:37,402,2023-05-20,Saturday
6,Peaches,33004,33113,9:10:04,9:11:53,109,2023-05-20,Saturday
7,Marshmallow,32478,34034,9:01:18,9:27:14,1556,2023-05-20,Saturday
8,Popcorn,33156,33564,9:12:36,9:19:24,408,2023-05-20,Saturday
9,Muffin,34536,34580,9:35:36,9:36:20,44,2023-05-20,Saturday


## Lets compute meal overlap

In [3]:
from pyspark.sql.functions import udf, countDistinct, sum, col
from pyspark.sql.types import IntegerType

calculate_time_overlap = lambda start_interval_1, end_interval_1, start_interval_2, end_interval_2: \
    0 if end_interval_1 <= start_interval_2 or end_interval_2 <= start_interval_1 else \
    max(end_interval_1, end_interval_2) - min(start_interval_1, start_interval_2) 

calculate_time_overlap_udf = udf(calculate_time_overlap, IntegerType())

cow1 = cows_bff\
    .withColumnRenamed('cow_name','cow1')\
    .withColumnRenamed('meal_start','meal_start1')\
    .withColumnRenamed('meal_end','meal_end1')\
    .withColumnRenamed('date','date1')\
    .select('cow1','meal_start1','meal_end1','date1')


cow2 = cows_bff\
    .withColumnRenamed('cow_name','cow2')\
    .withColumnRenamed('meal_start','meal_start2')\
    .withColumnRenamed('meal_end','meal_end2')\
    .withColumnRenamed('date','date2')\
    .select('cow2','meal_start2','meal_end2','date2')

df = cow1.crossJoin(cow2)\
  .where((cow1.cow1 != cow2.cow2) & (cow1.date1 == cow2.date2))

df = df\
  .withColumn('overlap', calculate_time_overlap_udf("meal_start1", "meal_end1", "meal_start2", "meal_end2"))\
  .select('cow1','cow2','date1','overlap')

df = df.groupBy('cow1', 'cow2').agg(sum('overlap').alias('total_overlap'), countDistinct('date1').alias('distinct_days'))
df = df\
  .withColumn('avg_overlap', df.total_overlap / df.distinct_days)\
  .select('cow1','cow2','avg_overlap')\
  .withColumnRenamed('avg_overlap','distance')\
  .sort(col('cow1').asc(), col('cow2').asc())

display(df.limit(10))

Unnamed: 0,cow1,cow2,distance
0,Brandy,Buttercup,1415.461538
1,Brandy,Butterscotch,1300.076923
2,Brandy,Cocoa,1425.538462
3,Brandy,Daisy,1166.846154
4,Brandy,Dottie,1337.769231
5,Brandy,Hershey,1737.076923
6,Brandy,Magic,1258.0
7,Brandy,Marshmallow,1290.230769
8,Brandy,Mocha,1425.538462
9,Brandy,Muffin,1372.384615


# Display Heatmap

In [4]:
pdf = df.toPandas()
pdf = pdf.pivot(index='cow1', columns='cow2', values='distance').fillna(0)

import plotly.express as px
fig = px.imshow(pdf, x=pdf.columns, y=pdf.index, labels=dict(x="Cow 2", y="Cow 1", color="Distance"), title="Cow BFFs", color_continuous_scale='redor')
px.imshow(pdf, x=pdf.columns, y=pdf.index, labels=dict(x="Cow 2", y="Cow 1", color="Distance"),)

fig.update_layout(width=800,height=500)
fig.show()
