In [1]:
# Load IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up when cells are re-run, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns

from model_functions import *

from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics.pairwise import pairwise_distances

import warnings
# NOTE(review): this installs a blanket "ignore" filter, so every warning
# raised for the rest of the session is silenced.
warnings.filterwarnings('ignore')

## Data loading

In [3]:
# load_data is not defined in this notebook (presumably it comes from the
# star import of model_functions); called without arguments it is unpacked
# into a (data, metadata) pair.
data, metadata = load_data()

In [4]:
# Second call with explicit dataset names; unpacked into the additional
# data frame and its metadata.
new_data, new_metadata = load_data(data_name="new_data", metadata_name="new_metadata")

In [5]:
# Append the new samples below the original ones and renumber the index.
# NOTE(review): `data` is overwritten in place, so re-running only this cell
# appends new_data a second time.
data = pd.concat([data, new_data], ignore_index=True)

In [6]:
# Display the new metadata frame as the cell output.
new_metadata

Unnamed: 0,sample,baboon_id,collection_date,sex,age,social_group,group_size,rain_month_mm,season,hydro_year,...,diet_PC6,diet_PC7,diet_PC8,diet_PC9,diet_PC10,diet_PC11,diet_PC12,diet_PC13,month_sin,month_cos
0,sample_11412-TTCTGGTCTTGT-397,Baboon_103,2000-07-10,F,5.426420,g_1.21,21,0.6,dry,2000,...,-11.751633,-1.567519,-1.735464,-1.885510,0.427538,1.209694,-6.315116,1.225756,-0.500000,-8.660254e-01
1,sample_11412-ATCTTGGAGTCG-397,Baboon_103,2000-09-15,F,5.609856,g_1.21,21,0.0,dry,2000,...,-4.245491,-1.719074,5.872374,5.606421,0.519868,-1.399635,0.901593,0.363685,-1.000000,-1.836970e-16
2,sample_12053-GTTGATACGATG-409,Baboon_103,2001-02-07,F,6.006845,g_1.21,20,157.2,wet,2001,...,0.216568,-5.456956,-9.991175,4.687571,2.444637,-1.311342,-0.004980,0.053056,0.866025,5.000000e-01
3,sample_12053-AATGCGCGTATA-409,Baboon_103,2001-03-03,F,6.072553,g_1.21,19,3.2,wet,2001,...,-2.396195,-4.890560,-8.652358,7.925085,1.443745,-0.555187,0.117349,0.051628,1.000000,6.123234e-17
4,sample_11408-CAGTGATACTGC-395,Baboon_103,2001-03-27,F,6.138261,g_1.21,19,36.9,wet,2001,...,0.514945,0.642578,0.343181,1.095459,0.039154,0.402993,0.397106,-0.121340,1.000000,6.123234e-17
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1890,sample_11407-TCCGCTGCTGAC-394,Baboon_599,2012-09-22,M,19.750856,g_1.11,74,0.0,dry,2012,...,-3.677994,-1.167808,0.141555,-0.892291,-0.630581,-0.396363,0.475944,0.163199,-1.000000,-1.836970e-16
1891,sample_11406-GTAATAATGCCG-394,Baboon_599,2012-09-14,M,19.728953,g_1.11,74,0.6,dry,2012,...,-1.311975,0.817112,1.239203,0.111570,-0.723128,-0.540500,0.548988,0.182761,-1.000000,-1.836970e-16
1892,sample_11412-AGGCCTTGGGCG-397,Baboon_599,2012-09-06,M,19.707050,g_1.11,74,0.6,dry,2012,...,-3.251906,0.020369,2.242375,0.915096,-1.009625,-0.199283,0.653121,0.242996,-1.000000,-1.836970e-16
1893,sample_11412-GGTGGCATGGAA-397,Baboon_599,2012-05-08,M,19.375770,g_1.11,73,106.6,wet,2012,...,-6.442805,-0.758193,-2.569987,-2.594860,0.262834,-5.560058,-0.902044,0.543864,0.500000,-8.660254e-01


## Data preprocessing

### Split to train-test

In [7]:
# Order the merged samples chronologically; the training frame is an
# independent, re-indexed copy of them.
data = data.sort_values(by="collection_date")
train_df = data.copy().reset_index(drop=True)

# Test rows: entries of new_metadata whose sample id does not occur in `data`.
is_unseen = ~new_metadata["sample"].isin(data["sample"])
test_df = new_metadata.loc[is_unseen].reset_index(drop=True)

In [8]:
# Metadata column names: meta_features (not defined in this notebook,
# presumably provided by model_functions) extended with the month_sin and
# month_cos columns.
meta_features_union = meta_features + ["month_sin", "month_cos"]

### Aggregation

In [9]:
# aggregate to one sample per week
# aggregate_samples is not defined in this notebook (presumably provided by
# model_functions); it is handed a copy of the training frame and its result
# replaces train_df.
train_df = aggregate_samples(train_df.copy())

In [10]:
# Parse the collection dates of the training frame into datetimes.
train_df["collection_date"] = pd.to_datetime(train_df["collection_date"])

# Every column that is not listed as metadata is kept as a taxa column.
taxa_columns = list(filter(lambda name: name not in meta_features_union, train_df.columns))

## Train

In [None]:
# Run predict (not defined in this notebook, presumably provided by
# model_functions) on the training and test frames, then write the result to
# pred_df.csv so the analysis cells can reload it instead of recomputing.
pred_df = predict(train_df, test_df)
pred_df.to_csv("pred_df.csv")

start trend_pred
finished trend_pred
start seasonal_pred
finished 100 out of 1695
finished 200 out of 1695
finished 300 out of 1695
finished 400 out of 1695
finished 500 out of 1695
finished 600 out of 1695
finished 700 out of 1695
finished 800 out of 1695
finished 900 out of 1695
finished 1000 out of 1695
finished 1100 out of 1695
finished 1200 out of 1695
finished 1300 out of 1695
finished 1400 out of 1695
finished 1500 out of 1695


In [None]:
# Intentionally empty: no code needs to run between the prediction step and
# the performance analysis.

## Performance analysis

In [None]:
# Uncomment the line below to load previously saved predictions from the CSV
# file instead of running the predict function again.
# NOTE(review): pred_df.csv is written together with its index, so reading it
# back this way adds an "Unnamed: 0" column unless index_col=0 is passed.

# pred_df = pd.read_csv("pred_df.csv")

In [None]:
# Bray-Curtis distance between each predicted row and the true row at the
# same position: build the full pairwise matrix and keep only its diagonal.
# NOTE(review): y_test is not defined in the visible cells of this notebook;
# presumably it is provided by model_functions. Confirm before a fresh run.
predicted_taxa = pred_df[y_test.columns]
d_matrix = pairwise_distances(predicted_taxa, y_test, metric="braycurtis").diagonal()

# Save the per-sample distances, then summarise them by their mean.
pd.DataFrame(d_matrix).to_csv("braycurtis_performance.csv")
braycurtis_score = d_matrix.mean()

### Performance Analysis

In [None]:
# Report the mean Bray-Curtis score rounded to five decimals.
rounded_score = round(braycurtis_score, 5)
print(f"Bray-Curtis score: {rounded_score}")

In [None]:
# Analysis frame: the test rows plus their per-sample Bray-Curtis score,
# without the interpolation flag and the month_sin/month_cos columns.
analysis_df = (
    test_df
    .assign(braycurtis_score=d_matrix)
    .drop(columns=["interpolated", "month_sin", "month_cos"])
)

In [None]:
# Histogram (with KDE) of the per-sample Bray-Curtis scores.
fig, ax = plt.subplots()
ax.grid(False)
sns.histplot(analysis_df["braycurtis_score"], color="#46c1db", bins=25, kde=True, ax=ax)
ax.set_title("Bray-Curtis Distribution")
ax.set_xlabel("Bray-Curtis Score")
ax.set_xlim(0)
plt.show()

In [None]:
# Histogram (with KDE) of the mean Bray-Curtis score of each baboon.
mean_per_baboon = analysis_df.groupby("baboon_id")["braycurtis_score"].mean()
fig, ax = plt.subplots()
ax.grid(False)
sns.histplot(mean_per_baboon, color="#46c1db", bins=25, kde=True, ax=ax)
ax.set_title("Average Bray-Curtis per Baboon")
ax.set_xlabel("Average per Baboon")
plt.show()

In [None]:
# Violin plot of the per-sample Bray-Curtis scores, one violin per sex.
fig, ax = plt.subplots()
ax.grid(False)
sns.violinplot(data=analysis_df, hue="sex", y="braycurtis_score", palette="muted", ax=ax)
ax.set(title="Bray-Curtis per Sex", xlabel="Sex", ylabel="Bray-Curtis Score")
plt.show()

In [None]:
# Violin plot of the Bray-Curtis scores, one violin per baboon, with the
# baboons ordered by their median score.
fig, ax = plt.subplots(figsize=(30, 6))
ax.grid(True)

# Despite its name, the "baboon_mean" column holds the per-baboon median.
analysis_df["baboon_mean"] = analysis_df.groupby("baboon_id")["braycurtis_score"].transform("median")
ordered_df = analysis_df.sort_values("baboon_mean")

sns.violinplot(data=ordered_df, hue="baboon_id", y="braycurtis_score", palette="muted", legend=False, ax=ax)
ax.set_title("Bray-Curtis per Baboon")
ax.set_xlabel("Baboon")
ax.set_ylabel("Bray-Curtis Score")
ax.set_xlim((-0.41, 0.41))
plt.show()

In [None]:
# Last known training sample per baboon: the latest collection_date of every
# baboon_id in train_df, as a one-column frame (column 0) indexed by baboon_id.
# A single groupby replaces filtering and sorting the frame once per baboon.
df1 = train_df.groupby("baboon_id")["collection_date"].max().to_frame(name=0)

# Attach that date to every analysed row. merge() defaults to an inner join,
# so rows whose baboon_id has no training sample are dropped.
merged_df = analysis_df.merge(df1, left_on="baboon_id", right_index=True)

In [None]:
# Distance in days between each analysed sample and the last training sample
# of the same baboon.
merged_df = merged_df.rename(columns={0: "last_timepoint"})

# Both columns are parsed explicitly: collection_date of the test rows is not
# converted to datetime anywhere earlier in the notebook, and to_datetime is a
# no-op for values that are already datetimes.
elapsed = pd.to_datetime(merged_df["collection_date"]) - pd.to_datetime(merged_df["last_timepoint"])

# Vectorised equivalent of the row-wise |delta|.total_seconds() / 86400.
merged_df["dist_in_days"] = elapsed.abs().dt.total_seconds() / (60 * 60 * 24)

In [None]:
# Mean Bray-Curtis score at each distance (in days) from the last known
# sample, drawn with a regression fit on a linear x-axis.
mean_by_dist = merged_df.groupby("dist_in_days")[["braycurtis_score"]].mean()

fig, ax = plt.subplots(figsize=(30, 6))
sns.regplot(x=mean_by_dist.index, y=mean_by_dist["braycurtis_score"], color="#46c1db", ax=ax)
ax.set_title("Bray-Curtis Score by Distance in Days From the Last Known Sample")
ax.set_xlabel("Distance From the Last Known Sample (Days)")
ax.set_xlim(0)
plt.show()

In [None]:
# Mean Bray-Curtis score at each distance (in days) from the last known
# sample, drawn with a regression fit on a logarithmic x-axis.
mean_by_dist = merged_df.groupby("dist_in_days")[["braycurtis_score"]].mean()
plt.figure(figsize=(30, 6))
sns.regplot(x=mean_by_dist.index, y=mean_by_dist["braycurtis_score"], color="#46c1db")
plt.xlabel("Distance From the Last Known Sample (Days), log scale")
plt.title("Bray-Curtis Score by Distance in Days From the Last Known Sample")
# No explicit lower x-limit here: a log-scaled axis cannot start at 0, and
# matplotlib ignores a non-positive limit on such an axis.
plt.xscale("log")
plt.show()

### Milestone 3 Plots - Parameter Tuning

In [None]:
def plot_fig(df, line, log=False, x="Weight", x_col=None):
    """Plot the average Bray-Curtis score against a tuned hyperparameter.

    Parameters
    ----------
    df : pandas.DataFrame
        Tuning results. Must contain an "avg_bray_curtis" column and the
        column used for the x-axis.
    line : float
        x position of the dashed vertical marker line.
    log : bool, default False
        If True, the x-axis is drawn on a logarithmic scale.
    x : str, default "Weight"
        Label of the x-axis.
    x_col : str, optional
        Name of the column plotted on the x-axis. When omitted,
        "seasonal weight" is used if the frame has such a column and
        "weight" otherwise, so frames keyed by either name can be plotted.
    """
    if x_col is None:
        # Some tuning frames are keyed by "seasonal weight", others by "weight".
        x_col = "seasonal weight" if "seasonal weight" in df.columns else "weight"

    plt.figure(figsize=(8, 6))
    plt.plot(df[x_col], df["avg_bray_curtis"])
    if log:
        plt.xscale("log")
    # Dashed vertical line at the x position given by `line`.
    plt.axvline(x=line, color='black', linestyle='--')

    plt.ylabel("Bray-Curtis Score")
    plt.xlabel(x)
    plt.show()

In [None]:
# Seasonal-model tuning results, one frame per tuned component, ordered by
# the tuned weight. Forward slashes are used as path separators: they work on
# every OS, whereas backslashes in a normal string literal form invalid
# escape sequences.
seasonal_tuning_dir = "./hyperparam tuning/hyperparam tuning - seasonal"
season_param_df = pd.read_csv(f"{seasonal_tuning_dir}/seasonal_hyperparam_season.csv").sort_values(by="weight")
identity_param_df = pd.read_csv(f"{seasonal_tuning_dir}/seasonal_hyperparam_identity.csv").sort_values(by="weight")
diet_param_df = pd.read_csv(f"{seasonal_tuning_dir}/seasonal_hyperparam_diet.csv").sort_values(by="weight")

In [None]:
# One log-scale tuning curve per seasonal component, each with a dashed
# marker line at the given weight.
for tuning_df, marked_weight in (
    (season_param_df, 0.35),
    (identity_param_df, 0.1),
    (diet_param_df, 0.025),
):
    plot_fig(tuning_df, line=marked_weight, log=True)

In [None]:
# Trend-model tuning results, ordered by the tuned value.
trend_param_df = pd.read_csv("./hyperparam tuning/trend_hyperparam.csv")
trend_param_df = trend_param_df.sort_values(by="weight")

In [None]:
# Trend tuning curve on a linear axis, with the marker line at 80.
plot_fig(df=trend_param_df, line=80, log=False, x="Number of Samples")

In [None]:
# Seasonal-vs-trend weighting results for the default and the extrapolation
# setting, each ordered by the seasonal weight.
seasonal_vs_trend_df = pd.read_csv("./hyperparam tuning/trend_seasonal_default.csv")
seasonal_vs_trend_df = seasonal_vs_trend_df.sort_values(by="seasonal weight")

seasonal_vs_trend_extrapolation_df = pd.read_csv("./hyperparam tuning/trend_seasonal_extrapolation.csv")
seasonal_vs_trend_extrapolation_df = seasonal_vs_trend_extrapolation_df.sort_values(by="seasonal weight")

In [None]:
# Seasonal-weight curves for both settings, each with its own marker line.
for weighting_df, marked_weight in (
    (seasonal_vs_trend_df, 0.5),
    (seasonal_vs_trend_extrapolation_df, 0.8),
):
    plot_fig(weighting_df, line=marked_weight, log=False, x="Seasonal Weight")