In [0]:
# load table into spark dataframe
cows_bff = spark.read.table("db.cows_bff")

In [0]:
display(cows_bff.limit(10)) 

cow_name,meal_start,meal_end,meal_start_time,meal_end_time,duration,date,day
Butterscotch,33575,33936,9:19:35,9:25:36,361,2023-05-20T00:00:00.000+0000,Saturday
Hershey,32998,34215,9:09:58,9:30:15,1217,2023-05-20T00:00:00.000+0000,Saturday
Mocha,33671,34788,9:21:11,9:39:48,1117,2023-05-20T00:00:00.000+0000,Saturday
Cocoa,33671,34788,9:21:11,9:39:48,1117,2023-05-20T00:00:00.000+0000,Saturday
Nutella,31935,32628,8:52:15,9:03:48,693,2023-05-20T00:00:00.000+0000,Saturday
Brandy,33595,33997,9:19:55,9:26:37,402,2023-05-20T00:00:00.000+0000,Saturday
Peaches,33004,33113,9:10:04,9:11:53,109,2023-05-20T00:00:00.000+0000,Saturday
Marshmallow,32478,34034,9:01:18,9:27:14,1556,2023-05-20T00:00:00.000+0000,Saturday
Popcorn,33156,33564,9:12:36,9:19:24,408,2023-05-20T00:00:00.000+0000,Saturday
Muffin,34536,34580,9:35:36,9:36:20,44,2023-05-20T00:00:00.000+0000,Saturday


In [0]:
from pyspark.sql.functions import udf, countDistinct, sum, col
from pyspark.sql.types import IntegerType

@udf('integer')
def calculate_time_overlap(start_interval_1, end_interval_1, start_interval_2, end_interval_2):
  # Check if there is no overlap
  if end_interval_1 <= start_interval_2 or end_interval_2 <= start_interval_1:
    return 0  # No overlap
    
  # Calculate the overlap duration
  overlap_start = min(start_interval_1, start_interval_2)
  overlap_end = max(end_interval_1, end_interval_2)
  overlap_duration = overlap_end - overlap_start
    
  return overlap_duration

cow1 = cows_bff\
    .withColumnRenamed('cow_name','cow1')\
    .withColumnRenamed('meal_start','meal_start1')\
    .withColumnRenamed('meal_end','meal_end1')\
    .withColumnRenamed('date','date1')\
    .select('cow1','meal_start1','meal_end1','date1')


cow2 = cows_bff\
    .withColumnRenamed('cow_name','cow2')\
    .withColumnRenamed('meal_start','meal_start2')\
    .withColumnRenamed('meal_end','meal_end2')\
    .withColumnRenamed('date','date2')\
    .select('cow2','meal_start2','meal_end2','date2')

df = cow1.crossJoin(cow2)\
  .where((cow1.cow1 != cow2.cow2) & (cow1.date1 == cow2.date2))

df = df\
  .withColumn('overlap', calculate_time_overlap("meal_start1", "meal_end1", "meal_start2", "meal_end2"))\
  .select('cow1','cow2','date1','overlap')

df = df.groupBy('cow1', 'cow2').agg(sum('overlap').alias('total_overlap'), countDistinct('date1').alias('distinct_days'))
df = df\
  .withColumn('avg_overlap', df.total_overlap / df.distinct_days)\
  .select('cow1','cow2','avg_overlap')\
  .withColumnRenamed('avg_overlap','distance')\
  .sort(col('cow1').asc(), col('cow2').asc())

display(df.toPandas())

cow1,cow2,distance
Brandy,Buttercup,1415.4615384615386
Brandy,Butterscotch,1300.076923076923
Brandy,Cocoa,1425.5384615384614
Brandy,Daisy,1166.8461538461538
Brandy,Dottie,1337.7692307692307
Brandy,Hershey,1737.076923076923
Brandy,Magic,1258.0
Brandy,Marshmallow,1290.2307692307693
Brandy,Mocha,1425.5384615384614
Brandy,Muffin,1372.3846153846157


In [0]:
pdf = df.toPandas()
pdf = pdf.pivot(index='cow1', columns='cow2', values='distance').fillna(0)

import plotly.express as px
fig = px.imshow(pdf, x=pdf.columns, y=pdf.index, labels=dict(x="Cow 2", y="Cow 1", color="Distance"),color_continuous_scale='redor')
fig.update_layout(width=800,height=500)
fig.show()
