## Preprocessing Clue dataset

In [2]:
import math
import numpy as np
import pandas as pd
import seaborn as sns
import feather
import pickle
import itertools
from tqdm import tqdm
import matplotlib.pyplot as plt
from datetime import datetime, timedelta, date
from dateutil import relativedelta
from collections import defaultdict
import more_itertools as mit
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.preprocessing import OneHotEncoder
from dateutil.relativedelta import relativedelta
from itertools import chain
from collections import Counter 

# Import data

In [None]:
# Folder (relative path, with trailing slash) that holds the input feather files.
data_folder = "final_datasets/"

In [None]:
# Build the input file paths from the folder configured in `data_folder`.
# (The original referenced `base_path`, which is never defined in this notebook.)
users_path = data_folder + "users.feather"
cycles_path = data_folder + "cycles.feather"
bc_path = data_folder + "birth_control.feather"
tracking_path = data_folder + "tracking.feather"

In [None]:
# Load the users table and display its (rows, columns) shape.
users = feather.read_dataframe(users_path)
users.shape

In [None]:
# Load the birth control table and display its shape.
birth_control = feather.read_dataframe(bc_path)
birth_control.shape

In [None]:
# Load the cycles table and display its shape.
cycles = feather.read_dataframe(cycles_path)
cycles.shape

In [None]:
# Load the tracking table and display its shape.
tracking = feather.read_dataframe(tracking_path)
tracking.shape

# Filtering

### helper functions

In [None]:
def daterange(start_date, end_date):
    """Yield every day from start_date through end_date, both inclusive.

    Yields nothing when end_date lies before start_date.
    """
    span_days = int((end_date - start_date).days)
    offset = 0
    while offset <= span_days:
        yield start_date + timedelta(offset)
        offset += 1

In [None]:
def make_unique_cycle_ids(row):
    """Build the id "<anon_id>_<cycle_id>" from a row's "anon_id" and "cycle_id" fields."""
    id_parts = [row["anon_id"], str(row["cycle_id"])]
    return "_".join(id_parts)

## 1. Keep all dates > January 1st 2016

In [None]:
def filter_df_dates(df, col_name, cutoff=None):
    """Return the rows of `df` whose `col_name` date is strictly after `cutoff`.

    The column is parsed as '%Y-%m-%d' and converted to `datetime.date`
    objects in the returned frame. The input frame is left untouched: the
    original implementation overwrote the column on the caller's object,
    which also meant writing into a slice when called on a filtered frame.

    Parameters
    ----------
    df : pandas.DataFrame
    col_name : str
        Name of the date column to parse and filter on.
    cutoff : datetime.date, optional
        Exclusive lower bound; defaults to January 1st, 2016.
    """
    if cutoff is None:
        cutoff = datetime(2016, 1, 1).date()
    # work on a copy so the caller's frame is not modified
    df = df.copy()
    df[col_name] = pd.to_datetime(df[col_name], format='%Y-%m-%d').dt.date
    return df[df[col_name] > cutoff]

In [None]:
# Keep only cycles whose start and end both fall after the cutoff date.
cycles = filter_df_dates(cycles, "cycle_start")
cycles = filter_df_dates(cycles, "cycle_end")
cycles.shape

In [None]:
# Keep only tracking rows dated after the cutoff date.
tracking = filter_df_dates(tracking, "date")
tracking.shape

In [None]:
# Keep only birth control entries dated after the cutoff date.
birth_control = filter_df_dates(birth_control, "date")
birth_control.shape

### add cycle id to tracking dataset

In [None]:
# Give every cycle a globally unique id of the form "<anon_id>_<cycle_id>".
cycles["unique_cycle_id"] = cycles.apply(make_unique_cycle_ids, axis=1)

In [None]:
def create_cycle_id_mapper(cycles):
    """Map each user to {tuple of all days in a cycle: unique_cycle_id}.

    Returns a defaultdict(dict) keyed by "anon_id"; the inner keys are tuples
    holding every day from "cycle_start" through "cycle_end" (inclusive).
    """
    mapper = defaultdict(dict)
    for user, user_cycles in tqdm(cycles.groupby("anon_id")):
        rows = zip(
            user_cycles["cycle_start"],
            user_cycles["unique_cycle_id"],
            user_cycles["cycle_end"],
        )
        for first_day, cycle_uid, last_day in rows:
            days = tuple(daterange(first_day, last_day))
            mapper[user][days] = cycle_uid
    return mapper

In [None]:
# Per-user lookup from the days of a cycle to that cycle's unique id.
cycle_id_mapper = create_cycle_id_mapper(cycles)

In [None]:
# Keep only tracking rows of users that still have at least one cycle.
tracking = tracking[tracking["anon_id"].isin(cycles["anon_id"].unique())]
tracking.shape

In [None]:
def map_date_to_ID(user, date):
    """Return the unique cycle id whose day range contains `date` for `user`.

    Looks the user up in the module-level `cycle_id_mapper`; returns the
    string "no valid cycle_id" when no cycle of that user contains the date.
    """
    for day_span, cycle_uid in cycle_id_mapper[user].items():
        if date in day_span:
            return cycle_uid
    return "no valid cycle_id"

In [None]:
# Keep only users that still have at least one cycle.
users = users[users["anon_id"].isin(cycles["anon_id"].unique())]
users.shape

In [None]:
# Label every tracking row with the unique id of the cycle its date falls in
# (or the sentinel "no valid cycle_id" when it falls in none).
tracking["cycle_id"] = tracking.apply(lambda x: map_date_to_ID(x["anon_id"], x["date"]), axis=1)

In [None]:
# Drop rows that could not be assigned to a cycle, then keep only rows whose
# cycle id is present in the cycles table.
tracking = tracking[tracking["cycle_id"] != "no valid cycle_id"]
tracking = tracking[tracking["cycle_id"].isin(cycles["unique_cycle_id"].tolist())]
tracking.shape

In [None]:
# Keep only birth control entries of users that still have at least one cycle.
birth_control = birth_control[birth_control["anon_id"].isin(cycles["anon_id"].unique())]
birth_control.shape

## 2. Users that have tracked at least 6 months and 6 cycles (both)

In [None]:
# Sort chronologically per user. reset_index() is called without drop=True,
# so the previous index is kept as an extra "index" column.
tracking = tracking.sort_values(["anon_id", "date"], ascending=True).reset_index()

In [None]:
# Select users who tracked for at least 6 months AND across at least 6 cycles.
users_to_keep = []
# Group by the bare column name: grouping by a one-element list yields tuple
# keys on recent pandas, which would never match "anon_id" in the isin() filters.
for user, values in tqdm(tracking.groupby("anon_id")):
    dates = values["date"].tolist()
    # whole months between the first and last tracked date (rows are sorted by date).
    # `relativedelta` is the class here: the later `from dateutil.relativedelta
    # import relativedelta` rebinds the name, so `relativedelta.relativedelta` fails.
    r = relativedelta(dates[-1], dates[0])
    months = r.years * 12 + r.months
    # number of DISTINCT cycles tracked (the row count would count logged entries)
    n_cycles = values["cycle_id"].nunique()
    if months >= 6 and n_cycles >= 6:
        users_to_keep.append(user)

In [None]:
# Restrict tracking to the retained users.
tracking = tracking[tracking["anon_id"].isin(users_to_keep)]
tracking.shape

In [None]:
# Restrict users to the retained users.
users = users[users["anon_id"].isin(users_to_keep)]
users.shape

In [None]:
# Restrict cycles to the retained users.
cycles = cycles[cycles["anon_id"].isin(users_to_keep)]
cycles.shape

In [None]:
# Restrict birth control entries to the retained users.
birth_control = birth_control[birth_control["anon_id"].isin(users_to_keep)]
birth_control.shape

## 3. Filter out users who only track their period

In [None]:
# Users with at least one tracking row in a category other than "period".
no_period_only_users = tracking[tracking["category"] != "period"]["anon_id"].unique()

In [None]:
# Drop tracking rows of users who only ever tracked their period.
tracking = tracking[tracking["anon_id"].isin(no_period_only_users)]
tracking.shape

In [None]:
# Drop users who only ever tracked their period.
users = users[users["anon_id"].isin(no_period_only_users)]
users.shape

In [None]:
# Drop cycles of users who only ever tracked their period.
cycles = cycles[cycles["anon_id"].isin(no_period_only_users)]
cycles.shape

In [None]:
# Drop birth control entries of users who only ever tracked their period.
birth_control = birth_control[birth_control["anon_id"].isin(no_period_only_users)]
birth_control.shape

## 4. Birth control simplification + re-assignment

We simplify the birth control methods by grouping each one into one of four classes — **ON**, **OFF**, **OTHER-Hormonal**, and **OTHER-Non-Hormonal** — according to the table below:


| Birth Control Method | Class Label | 
| --- | --- | 
| Pill combined alternating | **ON** |
| None | **OFF** |
| IUD, injection, implant, <br> vaginal ring, patch, <br> pill combined continuous, <br> pill minipill continuous | **OTHER-Hormonal** |
| condoms, <br> fertility awareness method (FAM) | **OTHER-Non-Hormonal** |


Here, our main birth control method of interest is *Pill combined alternating*, specified as **ON**

In [None]:
def get_bc_labels(bc_type, pill_type, intake_regimen):
    """Collapse a birth control record into one of the output labels.

    Returns "ON" (alternating combined pill, or a pill whose pill type is
    empty and whose regimen is alternating or empty), "OFF" (type "none"),
    "OTHER-NH" (condoms, fertility awareness method), "OTHER-H" (IUD,
    injection, implant, vaginal ring, patch, continuous combined pill,
    continuous minipill) or "NA" for every other combination.
    """
    if bc_type == "pill":
        regimen = (pill_type, intake_regimen)
        if regimen in (("combined", "alternating"), ("", "alternating"), ("", "")):
            return "ON"
        if regimen in (("combined", "continuous"), ("minipill", "continuous")):
            return "OTHER-H"
        return "NA"
    if bc_type == "none":
        return "OFF"
    if bc_type in ("condoms", "fertility_awareness_method"):
        return "OTHER-NH"
    if bc_type in ("IUD", "injection", "implant", "vaginal_ring", "patch"):
        return "OTHER-H"
    return "NA"

In [None]:
# Derive the simplified label (ON / OFF / OTHER-H / OTHER-NH / NA) for every entry.
birth_control["bc_label"] = birth_control.apply(lambda x: get_bc_labels(x["type"], x["pill_type"], x["intake_regimen"]), axis=1)

In [None]:
# Drop every entry labelled "NA". Users left without any entry are removed
# from the other tables by the `users_to_keep` filters further down.
birth_control = birth_control[birth_control["bc_label"] != "NA"]

In [None]:
# Share of each label among the remaining entries, in percent.
birth_control["bc_label"].value_counts(normalize=True) * 100

In [None]:
# Users that still have at least one labelled birth control entry.
users_to_keep = birth_control["anon_id"].unique()

In [None]:
# Restrict users to those with a labelled birth control entry.
users = users[users["anon_id"].isin(users_to_keep)]
users.shape

In [None]:
# Restrict cycles to users with a labelled birth control entry.
cycles = cycles[cycles["anon_id"].isin(users_to_keep)]
cycles.shape

In [None]:
# Restrict tracking to users with a labelled birth control entry.
tracking = tracking[tracking["anon_id"].isin(users_to_keep)]
tracking.shape

## 5. Users that have at least 3 consecutive cycles on the birth control method of interest (ON)

In [None]:
# users that were ever on the birth control "ON"
# (one list element per matching entry, so ids may repeat)
users_bc_ON = birth_control[birth_control["bc_label"] == "ON"]["anon_id"].tolist()

In [None]:
# Keep all birth control entries of users who were ever ON.
birth_control = birth_control[birth_control["anon_id"].isin(users_bc_ON)]
birth_control.shape

In [None]:
# Keep only users who were ever ON.
users = users[users["anon_id"].isin(users_bc_ON)]
users.shape

In [None]:
# Keep only tracking rows of users who were ever ON.
tracking = tracking[tracking["anon_id"].isin(users_bc_ON)]
tracking.shape

In [None]:
# Keep only cycles of users who were ever ON.
cycles = cycles[cycles["anon_id"].isin(users_bc_ON)]
cycles.shape

### Label the time series as ON/OFF/OTHER-NH/OTHER-H

Let's map the birth controls to the tracking table

In [None]:
def add_BC_end_dates(birth_control):
    """Add a "BC_end_date" column: the day before the user's next entry starts.

    The frame is sorted by user and date first. Each user's last entry has no
    successor, so it ends the day before today's date (an arbitrary
    placeholder). Returns the sorted frame with the new column.
    """
    birth_control = birth_control.sort_values(["anon_id", "date"], ascending=True)
    one_day = timedelta(days=1)
    end_dates = []
    for _, start_dates in tqdm(birth_control.groupby("anon_id")["date"]):
        # the start of entry k+1 marks the end of entry k
        following_starts = start_dates.tolist()[1:]
        # placeholder successor for the user's last entry
        following_starts.append(date.today())
        end_dates.extend(next_start - one_day for next_start in following_starts)
    birth_control["BC_end_date"] = end_dates
    return birth_control

In [None]:
# Sort the entries by user and date and attach a provisional end date to each.
birth_control = add_BC_end_dates(birth_control)

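Fix birth control end dates: when a user has several consecutive entries with the same label, give all of them the end date of the last entry in that run, so the run can be treated as one continuous period.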

In [None]:
def zero_runs(a):
    """Locate the runs of zeros in `a`.

    Returns an array of shape (n_runs, 2) whose rows are [start, stop) index
    pairs: `start` is the first zero of a run and `stop` is one past its last.
    """
    zero_flags = np.equal(a, 0).view(np.int8)
    # pad with non-zero markers so runs touching either end still have two edges
    padded = np.concatenate(([0], zero_flags, [0]))
    edge_positions = np.flatnonzero(np.abs(np.diff(padded)) == 1)
    return edge_positions.reshape(-1, 2)

In [None]:
# Integer code per label, so runs of identical labels show up as zeros in np.diff.
map_bc_labels_to_int = {"OFF": 0, "ON": 1, "OTHER-H": 2, "OTHER-NH": 3}

In [None]:
# For every run of consecutive entries with the same label, give all entries
# of the run the end date of the run's last entry.
BC_end_dates_list = []
for _, user_bc in tqdm(birth_control.sort_values(["anon_id", "date"]).groupby("anon_id")):
    label_codes = [map_bc_labels_to_int[label] for label in user_bc["bc_label"].tolist()]
    end_dates = user_bc["BC_end_date"].tolist()
    # a zero in the diff at position k means entries k and k+1 share a label, so
    # a zero-run [start, stop) in the diff covers the entries start..stop inclusive
    for run_start, run_stop in zero_runs(np.diff(label_codes)):
        shared_end = end_dates[run_stop]
        for pos in range(run_start, run_stop + 1):
            end_dates[pos] = shared_end
    BC_end_dates_list.extend(end_dates)

In [None]:
# Positional assignment: BC_end_dates_list was built in (anon_id, date) order,
# so this assumes birth_control is still sorted that way — TODO confirm.
birth_control["BC_end_date_new"] = BC_end_dates_list

In [None]:
# Collapse each run of same-label entries (they now share one end date) to its first row.
filtered_birth_control = birth_control.sort_values(["anon_id", "date"]).drop_duplicates(["anon_id", "bc_label", "BC_end_date_new"])

In [None]:
# Normalise the corrected end dates to plain datetime.date objects.
filtered_birth_control["BC_end_date_new"] = pd.to_datetime(filtered_birth_control["BC_end_date_new"])
filtered_birth_control["BC_end_date_new"] = filtered_birth_control["BC_end_date_new"].dt.date

In [None]:
def create_birth_control_mapper(birth_controls):
    """Map each user to {tuple of all days on a method: bc_label}.

    Returns a defaultdict(dict) keyed by "anon_id"; the inner keys are tuples
    holding every day from "date" through "BC_end_date_new" (inclusive).
    """
    mapper = defaultdict(dict)
    ordered = birth_controls.sort_values(["anon_id", "date"])
    for user, user_bc in tqdm(ordered.groupby("anon_id")):
        rows = zip(user_bc["date"], user_bc["bc_label"], user_bc["BC_end_date_new"])
        for first_day, label, last_day in rows:
            days = tuple(daterange(first_day, last_day))
            mapper[user][days] = label
    return mapper

In [None]:
# Per-user lookup from the days of a birth control period to its label.
# Built from `birth_control`, not from `filtered_birth_control`.
BC_mapper = create_birth_control_mapper(birth_control)

In [None]:
# Keep only tracking rows of users that have birth control entries.
tracking = tracking[tracking["anon_id"].isin(birth_control["anon_id"].unique())]

In [None]:
def map_dates_to_BC(user, date):
    """Return the birth control label that was active for `user` on `date`.

    Looks the user up in the module-level `BC_mapper`; returns the string
    "no valid BC" when none of the user's periods contains the date.
    """
    for day_span, label in BC_mapper[user].items():
        if date in day_span:
            return label
    return "no valid BC"

In [None]:
# Label every tracking row with the birth control active on its date and drop
# the rows for which no birth control period could be found.
tracking["birth_control"] = tracking.apply(lambda x: map_dates_to_BC(x["anon_id"], x["date"]), axis=1)
tracking = tracking[tracking["birth_control"] != "no valid BC"]

In [None]:
# Users that still have at least one labelled tracking row.
users_to_keep = tracking["anon_id"].unique().tolist()

In [None]:
# Restrict users to those with labelled tracking data.
users = users[users["anon_id"].isin(users_to_keep)]
users.shape

In [None]:
# Restrict birth control entries to users with labelled tracking data.
birth_control = birth_control[birth_control["anon_id"].isin(users_to_keep)]
birth_control.shape

In [None]:
# Restrict cycles to users with labelled tracking data.
cycles = cycles[cycles["anon_id"].isin(users_to_keep)]
cycles.shape

In [None]:
# For every user, map cycle number -> list of the distinct birth control
# labels seen on that cycle's tracked days.
cycle_BC_set_dict = {user: dict() for user in tqdm(tracking["anon_id"].unique())}

for cycle, bc_values in tracking.groupby("cycle_id")["birth_control"]:
    # unique ids are "<anon_id>_<cycle number>"; split on the LAST underscore
    # so anon ids that themselves contain "_" are still parsed correctly
    user, cycle_number = cycle.rsplit("_", 1)
    cycle_BC_set_dict[user][int(cycle_number)] = list(set(bc_values.values.tolist()))

In [None]:
# Keep users with at least 3 consecutively numbered cycles that were entirely ON.
users_to_keep = []
user_cycles_ON_dict = dict()
for user, values in cycle_BC_set_dict.items():
    # cycles whose only observed label is ON
    cycles_list = sorted(cycle for cycle, v in values.items() if v == ['ON'])
    # ">= 3": the requirement is "at least 3" (the original `> 3` demanded 4).
    # Each group is consumed before the next one is requested, as
    # consecutive_groups shares a single underlying iterator.
    has_3_or_more = any(len(list(group)) >= 3 for group in mit.consecutive_groups(cycles_list))
    if has_3_or_more:
        users_to_keep.append(user)

In [None]:
# Restrict birth control entries to users with enough consecutive ON cycles.
birth_control = birth_control[birth_control["anon_id"].isin(users_to_keep)]
birth_control.shape

In [None]:
# Restrict users to those with enough consecutive ON cycles.
users = users[users["anon_id"].isin(users_to_keep)]
users.shape

In [None]:
# Restrict tracking to users with enough consecutive ON cycles; the second
# line repeats the "no valid BC" filter already applied when the column was created.
tracking = tracking[tracking["anon_id"].isin(users_to_keep)]
tracking = tracking[tracking["birth_control"] != 'no valid BC']
tracking.shape

In [None]:
# Restrict cycles to users with enough consecutive ON cycles.
cycles = cycles[cycles["anon_id"].isin(users_to_keep)]
cycles.shape

## 6. Detect stable transitions

Determine whether a user has switched to another birth control method, and has stayed on this new birth control method for at least 90 days. This is considered to be a stable transition. If the three cycles before this stable transition were completely on the birth control method (ON), we keep this time series.

In [None]:
# Normalise the corrected end dates to plain datetime.date objects.
birth_control["BC_end_date_new"] = pd.to_datetime(birth_control["BC_end_date_new"])
birth_control["BC_end_date_new"] = birth_control["BC_end_date_new"].dt.date

In [None]:
# Collapse each run of same-label entries (sharing one end date) to its first row.
birth_control = birth_control.sort_values(["anon_id", "date"]).drop_duplicates(["anon_id", "bc_label", "BC_end_date_new"])

In [None]:
# Distinct birth control labels each user has ever had.
user_BC_dict = {
    user: list(set(labels.values.tolist()))
    for user, labels in birth_control.groupby("anon_id")["bc_label"]
}

In [None]:
# Users with more than one distinct label, i.e. who changed method at some point.
users_BC_change = [user for user, bc in user_BC_dict.items() if len(bc) > 1]

In [None]:
# Split the birth control entries by whether their user ever changed label.
users_without_BC_change = birth_control[~birth_control["anon_id"].isin(users_BC_change)]
users_with_BC_change = birth_control[birth_control["anon_id"].isin(users_BC_change)]

In [None]:
# Number of days between the start of each entry and its corrected end date.
users_with_BC_change["time_on_BC"] = users_with_BC_change.apply(lambda x: (x["BC_end_date_new"] - x["date"]).days, axis=1)

In [None]:
#drop first BC
# NOTE(review): the following cell reassigns `users_with_stable_transitions`
# from `users_with_BC_change`, so the result computed here is never used.
users_with_stable_transitions = users_with_BC_change.groupby("anon_id", as_index=False).apply(lambda x: x.iloc[1:]).reset_index(drop=True)

In [None]:
# users that switch to a birth control method for at least 90 days that is not ON
# (starts again from `users_with_BC_change`, i.e. first entries are included)
users_with_stable_transitions = users_with_BC_change[users_with_BC_change["time_on_BC"] >= 90]
users_with_stable_transitions = users_with_stable_transitions[users_with_stable_transitions["bc_label"] != "ON"]

In [None]:
# Keep every entry of the users that have at least one such non-ON period,
# ordered chronologically per user.
users_with_BC_change = users_with_BC_change[users_with_BC_change["anon_id"].isin(users_with_stable_transitions["anon_id"].tolist())]
users_with_BC_change = users_with_BC_change.sort_values(by=["anon_id", "date"])

In [None]:
# Flag every birth control entry as
#   "bc_transition": lasted >= 90 days and directly follows an ON entry,
#   "bc_ON":         an ON entry directly followed by a non-ON entry of >= 90 days,
#   False:           anything else.
stable_transitions_bools = []
for user, v in users_with_BC_change.groupby("anon_id"):
    bc_values = v["bc_label"].values.tolist()
    time_values = v["time_on_BC"].values.tolist()
    last_idx = len(time_values) - 1
    for i, days_on in enumerate(time_values):
        if i > 0 and days_on >= 90 and bc_values[i - 1] == "ON":
            flag = "bc_transition"
        elif (i != last_idx and time_values[i + 1] >= 90
                and bc_values[i] == "ON" and bc_values[i + 1] != "ON"):
            flag = "bc_ON"
        else:
            flag = False
        stable_transitions_bools.append(flag)

In [None]:
# Attach the flags (positional: relies on the frame being sorted by anon_id
# and date, the order in which the flags were generated) and keep only the
# entries that take part in a stable transition.
users_with_BC_change["stable_transition_bool"] = stable_transitions_bools
users_with_stable_transitions = users_with_BC_change[users_with_BC_change["stable_transition_bool"].isin(["bc_transition", "bc_ON"])]

In [None]:
# Give both rows of every (ON entry, following non-ON entry) pair a shared
# transition id "<anon_id><pair number>"; pairs that are not a contiguous
# ON -> non-ON switch get the id 0.
transition_ids = []
for anon, values in users_with_stable_transitions.groupby("anon_id"):
    bc_values = values["bc_label"].tolist()
    BC_end_date_list = values["BC_end_date_new"].tolist()
    BC_start_date_list = values["date"].tolist()
    # walk the user's rows two at a time: (0, 1), (2, 3), ...
    for t_group, i in enumerate(range(0, len(bc_values) - 1, 2)):
        x, y = bc_values[i], bc_values[i + 1]
        # the second entry must start the day after the first one ends
        if x == "ON" and y != "ON" and (BC_end_date_list[i] + timedelta(1) == BC_start_date_list[i + 1]):
            t_id = anon + str(t_group)
            transition_ids += [t_id, t_id]
        else:
            # the original also appended to an undefined `to_keep` list here,
            # which raised NameError as soon as this branch was reached
            transition_ids += [0, 0]
    if len(bc_values) % 2:
        # an unpaired trailing row still needs an id, otherwise the list is
        # shorter than the frame it gets assigned to
        transition_ids.append(0)

In [None]:
# Positional assignment: one id per row, in per-user groupby order.
users_with_stable_transitions["transition_id"] = transition_ids

In [None]:
# The ON side of each stable transition, limited to ON periods longer than 30 days.
users_with_stable_transitions_ON = users_with_stable_transitions[users_with_stable_transitions["stable_transition_bool"] == "bc_ON"]
users_with_stable_transitions_ON = users_with_stable_transitions_ON[users_with_stable_transitions_ON["time_on_BC"] > 30]

In [None]:
def get_cycles_ON(user, start_date, end_date):
    """List the unique ids of `user`'s cycles lying fully inside [start_date, end_date].

    Reads the module-level `cycles` frame.
    """
    is_user = cycles["anon_id"] == user
    starts_inside = cycles["cycle_start"] >= start_date
    ends_inside = cycles["cycle_end"] <= end_date
    return cycles.loc[is_user & starts_inside & ends_inside, "unique_cycle_id"].tolist()

In [None]:
# {row index: {"transition_id", "anon_id", "date", "BC_end_date_new"}} for each ON period.
users_with_stable_transitions_ON_dict = users_with_stable_transitions_ON[["transition_id", "anon_id", "date", "BC_end_date_new"]].to_dict('index')

In [None]:
# transition id -> unique ids of the cycles lying completely inside that ON period.
# Rows sharing a transition id overwrite each other (the last one wins).
cycles_ON_dict = {v["transition_id"]: get_cycles_ON(v["anon_id"], v["date"], v["BC_end_date_new"])for k, v in tqdm(users_with_stable_transitions_ON_dict.items())}

In [None]:
# Number of cycles inside each ON period, their cycle numbers in ascending
# order, and the last three of them.
users_with_stable_transitions_ON["len_cycles_ON"] = users_with_stable_transitions_ON["transition_id"].apply(lambda x: len(cycles_ON_dict[x]))
# Split on the LAST underscore so anon ids containing "_" are parsed correctly,
# and sort while building the list instead of sorting in place through apply().
users_with_stable_transitions_ON["cycles_ON"] = users_with_stable_transitions_ON["transition_id"].apply(
    lambda x: sorted(int(c.rsplit("_", 1)[1]) for c in cycles_ON_dict[x])
)
users_with_stable_transitions_ON["last_3_cycles_ON"] = users_with_stable_transitions_ON["cycles_ON"].apply(lambda x: x[-3:])

In [None]:
def get_cycle_ids(anon_id, cycle_num_list):
    """Turn cycle numbers into unique cycle ids of the form "<anon_id>_<number>"."""
    cycle_ids = []
    for cycle_number in cycle_num_list:
        cycle_ids.append("_".join((anon_id, str(cycle_number))))
    return cycle_ids

In [None]:
# Convert the last three cycle numbers back into unique cycle ids.
users_with_stable_transitions_ON["last_3_cycles_ON"] = users_with_stable_transitions_ON.apply(lambda x: get_cycle_ids(x["anon_id"], x["last_3_cycles_ON"]), axis=1)

In [None]:
def bool_3_cycles_ON(user, three_cycles, cycle_bc_sets=None):
    """Return the cycles of `three_cycles` whose only birth control label is "ON".

    Parameters
    ----------
    user : str
        The user's anon id.
    three_cycles : list of str
        Unique cycle ids ("<anon_id>_<cycle number>") to check.
    cycle_bc_sets : mapping, optional
        user -> {cycle key -> collection of labels}. Defaults to the
        module-level `cycle_BC_set_dict`.

    Two layouts of that mapping exist in this notebook, and which one is live
    depends on how far the cells have been run: one keyed by integer cycle
    number with list values, one keyed by unique cycle id with set values.
    Both are accepted. The original only worked with the second layout and
    raised KeyError with the first. Unknown users or cycles are treated as
    "not ON".
    """
    if cycle_bc_sets is None:
        cycle_bc_sets = cycle_BC_set_dict
    user_cycles = cycle_bc_sets.get(user, {})
    final_3_cycles = []
    for c in three_cycles:
        labels = user_cycles.get(c)
        if labels is None and isinstance(c, str):
            # fall back to the layout keyed by the integer cycle number
            suffix = c.rsplit("_", 1)[-1]
            if suffix.isdigit():
                labels = user_cycles.get(int(suffix))
        # set() makes the comparison independent of list vs. set values
        if labels is not None and set(labels) == {"ON"}:
            final_3_cycles.append(c)
    return final_3_cycles

In [None]:
# Of the last three cycles of each ON period, keep those that were entirely ON,
# and count them.
users_with_stable_transitions_ON["last_cycles_ON"] = users_with_stable_transitions_ON.apply(lambda x: bool_3_cycles_ON(x["anon_id"], x["last_3_cycles_ON"]), axis=1)
users_with_stable_transitions_ON["len_last_cycles_ON"] = users_with_stable_transitions_ON["last_cycles_ON"].apply(lambda x: len(x))

In [None]:
# Keep the transitions whose three preceding cycles were all ON, restrict the
# transition table to them, and expand each entry into its list of days.
final_users_with_stable_transitions_ON = users_with_stable_transitions_ON[users_with_stable_transitions_ON["len_last_cycles_ON"] == 3]
users_with_stable_transitions = users_with_stable_transitions[users_with_stable_transitions["transition_id"].isin(final_users_with_stable_transitions_ON["transition_id"].unique())]
users_with_stable_transitions["date_range"] = users_with_stable_transitions.apply(lambda x: [d for d in daterange(x["date"], x["BC_end_date_new"])], axis=1)

In [None]:
def map_dates_to_BC(user, date):
    """Look up which birth control label `user` had on `date`.

    Scans the user's periods in the module-level `BC_mapper` and returns the
    label of the first one containing the date, or "no valid BC" if none does.
    """
    matching_labels = (
        label for day_span, label in BC_mapper[user].items() if date in day_span
    )
    return next(matching_labels, "no valid BC")

In [None]:
def cycle_BC_set(cycles):
    """Map each user to {unique_cycle_id: set of labels over all days of the cycle}.

    Every day from "cycle_start" through "cycle_end" is looked up with
    `map_dates_to_BC`, so a set may contain "no valid BC".
    """
    labels_per_cycle = defaultdict(dict)
    for user, user_cycles in tqdm(cycles.groupby("anon_id")):
        rows = zip(
            user_cycles["cycle_start"],
            user_cycles["unique_cycle_id"],
            user_cycles["cycle_end"],
        )
        for first_day, cycle_uid, last_day in rows:
            labels_per_cycle[user][cycle_uid] = {
                map_dates_to_BC(user, day) for day in daterange(first_day, last_day)
            }
    return labels_per_cycle

In [None]:
# Rebinds `cycle_BC_set_dict`: from here on it is keyed by unique cycle id and
# holds sets (the earlier version was keyed by cycle number and held lists).
cycle_BC_set_dict = cycle_BC_set(cycles)

In [None]:
def map_dates_to_cycles(user, dates):
    """Return the distinct unique cycle ids that the given dates fall into.

    Uses the module-level `cycle_id_mapper`. A date outside all of the user's
    cycles contributes the sentinel "no valid cycle_id". The order of the
    returned list is unspecified.
    """
    found_ids = []
    for day in dates:
        hits = [
            cycle_uid
            for day_span, cycle_uid in cycle_id_mapper[user].items()
            if day in day_span
        ]
        found_ids.extend(hits if hits else ["no valid cycle_id"])
    return list(set(found_ids))

In [None]:
# Ids of the users that have a retained stable transition.
users_with_stable_transitions_to_keep = users_with_stable_transitions["anon_id"].unique()

In [None]:
# The non-ON side of every transition, with the cycles its days fall into.
output_df_transition = users_with_stable_transitions[users_with_stable_transitions["bc_label"] != "ON"]
output_df_transition["cycles"] = output_df_transition.apply(lambda x: map_dates_to_cycles(x["anon_id"], x["date_range"]), axis=1)

In [None]:
# Count the output cycles, rename the columns to their "output" names and keep
# the listed columns only ("cycle_count_output" is not among them).
output_df_transition["cycle_count"] = output_df_transition["cycles"].apply(lambda x: len(x))
output_df_transition = output_df_transition.rename(columns={"date": "start_date_transition", "BC_end_date_new": "end_date_transition", "time_on_BC": "time_on_output_BC", "cycles": "output_cycles", "cycle_count": "cycle_count_output", "type": "type_output", "pill_type": "pill_type_output", "intake_regimen": "intake_regimen_output", "bc_label": "bc_label_out"})
output_df_transition = output_df_transition[["anon_id", "start_date_transition", "type_output", "pill_type_output", "intake_regimen_output", "bc_label_out", "end_date_transition", "time_on_output_BC", "output_cycles", "transition_id"]]

In [None]:
# The ON (input) side of every retained transition, with "input" column names.
new_df = final_users_with_stable_transitions_ON[["anon_id", "date", "type", "pill_type", "intake_regimen", "bc_label", "BC_end_date_new", "last_cycles_ON", "transition_id"]]
new_df = new_df.rename(columns={"date": "input_start_date", "bc_label": "bc_label_in", "BC_end_date_new": "end_date_ON", "last_cycles_ON": "input_cycles", "type": "type_input", "pill_type": "pill_type_input", "intake_regimen": "intake_regimen_input"})

In [None]:
# Join the input side with its output side on the transition id. Both frames
# carry "anon_id", so pandas suffixes them — presumably the later sort needs
# the plain "anon_id" name; verify the resulting columns.
new_df = pd.merge(new_df, output_df_transition, how="left", left_on="transition_id", right_on="transition_id")
new_df = new_df.sort_values(by = ["anon_id", "input_start_date"])

In [None]:
# keep first transition (rows are sorted by input start date within each user)
new_df = new_df.groupby('anon_id').head(1)

In [None]:
# Look up the end date of the last input cycle and store it as "third_cycle_end".
new_df["third_input_cycle"] = new_df["input_cycles"].apply(lambda x: x[-1])
new_df = pd.merge(new_df, cycles[["cycle_end", "unique_cycle_id"]], how="left", left_on="third_input_cycle", right_on="unique_cycle_id")
new_df = new_df.rename(columns={"cycle_end": "third_cycle_end"})

In [None]:
# For users who never changed method: expand each entry into its days, find
# the cycles those days fall into, and keep the cycle numbers in ascending order.
users_without_BC_change["date_range"] = users_without_BC_change.apply(lambda x: list(daterange(x["date"], x["BC_end_date_new"])), axis=1)
users_without_BC_change["cycles"] = users_without_BC_change.apply(lambda x: map_dates_to_cycles(x["anon_id"], x["date_range"]), axis=1)
# Skip the "no valid cycle_id" sentinel explicitly and split on the LAST
# underscore, so anon ids containing "_" are parsed correctly.
users_without_BC_change["cycles"] = users_without_BC_change["cycles"].apply(
    lambda x: sorted(int(c.rsplit("_", 1)[1]) for c in x if c != "no valid cycle_id")
)

In [None]:
# Input cycles: positions 1..3 of the sorted cycle list (position 0 is
# skipped), converted back to unique cycle ids.
users_without_BC_change["input_cycles"] = users_without_BC_change["cycles"].apply(lambda x: x[1:4])
users_without_BC_change["input_cycles"] = users_without_BC_change.apply(lambda x: get_cycle_ids(x["anon_id"], x["input_cycles"]), axis=1)

In [None]:
# Output cycles: everything after the three input cycles, as unique cycle ids.
users_without_BC_change["output_cycles"] = users_without_BC_change["cycles"].apply(lambda x: x[4:])
users_without_BC_change["output_cycles"] = users_without_BC_change.apply(lambda x: get_cycle_ids(x["anon_id"], x["output_cycles"]), axis=1)
# Look up the end date of the last input cycle BEFORE renaming: the original
# renamed "cycle_end" first, when that column did not exist yet, so the column
# selection below failed with a KeyError on "start_date_transition".
users_without_BC_change["third_input_cycle"] = users_without_BC_change["input_cycles"].apply(lambda x: x[-1])
users_without_BC_change = pd.merge(users_without_BC_change, cycles[["cycle_end", "unique_cycle_id"]], how="left", left_on="third_input_cycle", right_on="unique_cycle_id")
# These users have no real transition; the end of the third input cycle is
# used as the (pseudo) transition start.
users_without_BC_change = users_without_BC_change.rename(columns={"date": "input_start_date", "type": "type_input", "pill_type": "pill_type_input", "intake_regimen":"intake_regimen_input", "bc_label": "bc_label_in", "BC_end_date_new": "end_date_ON", "cycle_end": "start_date_transition"})
# Keep the same date under "third_cycle_end" as well, the column the
# transition frame provides and that the 120-day output window is computed from.
users_without_BC_change["third_cycle_end"] = users_without_BC_change["start_date_transition"]
users_without_BC_change = users_without_BC_change[["anon_id", "input_start_date", "type_input", "pill_type_input", "intake_regimen_input", "bc_label_in", "end_date_ON", "start_date_transition", "third_cycle_end", "input_cycles", "output_cycles"]]

In [None]:
# Combine the three output method fields into one tuple column, then drop the
# helper columns that are no longer needed.
new_df["output_BC"] = new_df.apply(lambda x: (x["type_output"], x["pill_type_output"], x["intake_regimen_output"]), axis=1)
new_df.drop(["transition_id", "type_output", "pill_type_output", "intake_regimen_output", "unique_cycle_id"], axis=1, inplace=True)

In [None]:
# Users without a change are labelled ON on the output side as well.
users_without_BC_change["bc_label_out"] = "ON"
users_without_BC_change["output_BC"] = "ON"

In [None]:
# Placeholder values for users without a transition: the current date as the
# end of the output period and -1 as the time on the output method.
users_without_BC_change["end_date_transition"] = pd.to_datetime('now')
users_without_BC_change["end_date_transition"] = users_without_BC_change["end_date_transition"].dt.date
users_without_BC_change["time_on_output_BC"] = -1
# last of the input cycles
users_without_BC_change["third_input_cycle"] = users_without_BC_change["input_cycles"].apply(lambda x: x[-1])

In [None]:
# create final dataset: transition users stacked on top of users without a change
dataset = pd.concat([new_df, users_without_BC_change])

In [None]:
# Flatten the per-row cycle lists into one list of every input and output cycle id.
all_input_cycles = list(chain.from_iterable(dataset["input_cycles"].tolist()))
all_output_cycles = list(chain.from_iterable(dataset["output_cycles"].tolist()))
all_final_cycles = all_input_cycles + all_output_cycles

In [None]:
# Ids of all users present in the final dataset.
final_users = dataset["anon_id"].unique()

In [None]:
# Keep only tracking rows of final users that belong to an input or output cycle.
tracking = tracking[tracking["anon_id"].isin(final_users)]
tracking = tracking[tracking["cycle_id"].isin(all_final_cycles)]
tracking.shape

In [None]:
# Restrict users to the final users.
users = users[users["anon_id"].isin(final_users)]
users.shape

In [None]:
# Restrict cycles to the final users.
cycles = cycles[cycles["anon_id"].isin(final_users)]
cycles.shape

In [None]:
# Keep only the cycles used as input or output cycles.
cycles = cycles[cycles["unique_cycle_id"].isin(all_final_cycles)]
cycles.shape

## 7. Keep users that keep tracking ≥ 90 days after their birth control transition

In [None]:
# Sort by user and date first, so that tail(1) of each group is that user's
# most recent tracking row.
tracking = tracking.sort_values(["anon_id", "date"], ascending=True)
g = tracking.groupby("anon_id")
users_last_tracking = g.tail(1)

In [None]:
# Attach each user's last tracked date as "final_tracked_date".
dataset = pd.merge(dataset, users_last_tracking[["anon_id", "date"]], how="left", left_on="anon_id", right_on="anon_id")
dataset = dataset.rename(columns={"date": "final_tracked_date"})

In [None]:
def get_days_tracked_after_transition(final_tracked_date, start_date_transition):
    """Return the number of days from the transition start to the last tracked date."""
    elapsed = final_tracked_date - start_date_transition
    return elapsed.days

In [None]:
# Days each user kept tracking after the start of the transition.
dataset["days_tracked_after_transition"] = dataset.apply(lambda x: get_days_tracked_after_transition(x["final_tracked_date"], x["start_date_transition"]), axis=1)

In [None]:
# Normalise the last tracked date to plain datetime.date objects.
dataset["final_tracked_date"] = pd.to_datetime(dataset["final_tracked_date"])
dataset["final_tracked_date"] = dataset["final_tracked_date"].dt.date

In [None]:
# Keep users who tracked for at least 90 days after the transition start.
dataset = dataset[dataset["days_tracked_after_transition"] >= 90]

In [None]:
# Share of each output label in the dataset, in percent.
dataset["bc_label_out"].value_counts(normalize=True)*100

In [None]:
# Ids of the users remaining after the 90-day filter.
final_users = dataset["anon_id"].unique().tolist()

In [None]:
# Restrict users to the final users.
users = users[users["anon_id"].isin(final_users)]
users.shape

In [None]:
# Restrict cycles to the final users.
cycles = cycles[cycles["anon_id"].isin(final_users)]
cycles.shape

In [None]:
# Restrict tracking to the final users.
tracking = tracking[tracking["anon_id"].isin(final_users)]
tracking.shape

## add user-specific data

In [None]:
# Attach the user-level attributes needed for BMI, country and age.
dataset = pd.merge(dataset, users[["anon_id", "height", "weight", "last_country", "birth_year"]], how="left", left_on="anon_id", right_on="anon_id")

### BMI

In [None]:
def get_average_height_weight(h_w):
    """Convert a binned height/weight string into a single representative number.

    "A-B" gives the truncated midpoint of A and B, "<A" gives A and ">=A"
    gives A. Anything else — NaN, None, other non-string values, or a string
    in none of these formats — gives np.nan. The original left the result
    unbound in those cases (UnboundLocalError) or failed inside np.isnan on
    non-numeric input.

    Raises
    ------
    ValueError
        If a recognised format contains a part that is not an integer.
    """
    if not isinstance(h_w, str):
        return np.nan
    if "-" in h_w:
        bounds = h_w.split("-")
        return int((int(bounds[0]) + int(bounds[1])) / 2)
    if "<" in h_w:
        return int(h_w.split("<")[1])
    if ">=" in h_w:
        return int(h_w.split(">=")[1])
    return np.nan

In [None]:
# Reduce the height bin to one number and divide by 100
# (presumably centimetres to metres — confirm the unit of the raw data).
dataset["height"] = dataset["height"].apply(lambda x: get_average_height_weight(x))
dataset["height"] = dataset["height"] / 100

In [None]:
# Reduce the weight bin to one number.
dataset["weight"] = dataset["weight"].apply(lambda x: get_average_height_weight(x))

In [None]:
def get_BMI(height, weight):
    """Return weight / height**2, or np.nan when it cannot be computed.

    np.nan is returned when either value is NaN (the original tested
    `and`, so the check only fired when both were missing) and when the
    height is not positive, which would otherwise divide by zero.
    """
    if np.isnan(height) or np.isnan(weight) or height <= 0:
        return np.nan
    return weight / np.square(height)

In [None]:
# Body mass index from the representative height and weight.
dataset["BMI"] = dataset.apply(lambda x: get_BMI(x["height"], x["weight"]), axis=1)

### Country

In [None]:
# Use the user's last recorded country as their country.
dataset = dataset.rename(columns={"last_country": "country"})

### Age

In [None]:
def get_current_age(start_date_transition, birth_year):
    """Return the user's largest possible age in whole years at the transition.

    `birth_year` is a string whose part before the first "-" is the earliest
    possible birth year (e.g. "1990-1994"). The birthday is assumed to be
    January 1st of that year, which gives the oldest age consistent with the
    range.

    Returns np.nan when `birth_year` is not a string or its first part is not
    an integer, when `start_date_transition` is not a date, or when the
    resulting age is not positive.
    """
    if not isinstance(birth_year, str) or not isinstance(start_date_transition, date):
        return np.nan
    try:
        earliest_birth_year = int(birth_year.split("-")[0])
    except ValueError:
        return np.nan
    # With a January 1st birthday every date of a year is on or after the
    # birthday, so the age in whole years is simply the difference of the years.
    # This replaces `relativedelta.relativedelta(...)`, which fails here because
    # the name `relativedelta` is rebound to the class by a later import.
    max_age = start_date_transition.year - earliest_birth_year
    if max_age > 0:
        return max_age
    return np.nan

In [None]:
# Age at the start of the transition. The original line had unbalanced
# parentheses (a SyntaxError) and passed both values as a single tuple.
dataset["current_age"] = dataset.apply(lambda x: get_current_age(x["start_date_transition"], x["birth_year"]), axis=1)

### add cycle-specific data

In [None]:
# Split the three input cycle ids into one column each.
dataset["first_input_cycle"] = dataset["input_cycles"].apply(lambda x: x[0])
dataset["second_input_cycle"] = dataset["input_cycles"].apply(lambda x: x[1])
dataset["third_input_cycle"] = dataset["input_cycles"].apply(lambda x: x[2])

In [None]:
# Attach the length of each of the three input cycles as
# first_input_CL / second_input_CL / third_input_CL.
# The join key of `cycles` is dropped right after every merge. The original
# kept it, so the second merge produced "unique_cycle_id_x"/"unique_cycle_id_y"
# columns that stayed in the dataset and clash with later merges.
for ordinal in ("first", "second", "third"):
    dataset = pd.merge(
        dataset,
        cycles[["unique_cycle_id", "cycle_length"]],
        how="left",
        left_on=ordinal + "_input_cycle",
        right_on="unique_cycle_id",
    )
    dataset = dataset.rename(columns={"cycle_length": ordinal + "_input_CL"})
    dataset = dataset.drop(columns="unique_cycle_id")

In [None]:
# Row-wise median, mean and variance of the three input cycle lengths.
dataset['median_CL_inputs'] = dataset[['first_input_CL', 'second_input_CL', 'third_input_CL']].median(axis=1)
dataset['mean_CL_inputs'] = dataset[['first_input_CL', 'second_input_CL', 'third_input_CL']].mean(axis=1)
dataset['var_CL_inputs'] = dataset[['first_input_CL', 'second_input_CL', 'third_input_CL']].var(axis=1)

In [None]:
def get_days_on_current_BC(user, start_date_transition, bc_mapper=None):
    """Return how many days the user had been on their method before the transition.

    The day before `start_date_transition` is looked up among the user's
    birth control periods; the result is the number of days between the
    first day of the period containing it and that day.

    Parameters
    ----------
    user : str
        The user's anon id.
    start_date_transition : datetime.date
    bc_mapper : mapping, optional
        user -> {tuple of days: label}. Defaults to the module-level `BC_mapper`.

    Returns np.nan when no period of the user contains that day (the original
    raised IndexError in that case).
    """
    if bc_mapper is None:
        bc_mapper = BC_mapper
    last_day_before = start_date_transition - timedelta(days=1)
    for day_span in bc_mapper[user].keys():
        if last_day_before in day_span:
            return (last_day_before - min(day_span)).days
    return np.nan

In [None]:
# Days spent on the current method up to the day before the transition.
dataset["days_on_current_BC"] = dataset.apply(lambda x: get_days_on_current_BC(x["anon_id"], x["start_date_transition"]), axis=1)

In [None]:
# Share of each output label in the dataset, in percent.
dataset["bc_label_out"].value_counts(normalize=True)*100

## prepare inputs

In [None]:
# Attach the start and end date of each of the three input cycles as
# start_<n>_cycle / end_<n>_cycle.
# The join key of `cycles` is dropped right after every merge: keeping it makes
# the next merge suffix it to "unique_cycle_id_x"/"_y", which collides with
# columns of the same name left by earlier merges and raises a MergeError on
# recent pandas.
for ordinal in ("first", "second", "third"):
    dataset = pd.merge(
        dataset,
        cycles[["unique_cycle_id", "cycle_start", "cycle_end"]],
        how="left",
        left_on=ordinal + "_input_cycle",
        right_on="unique_cycle_id",
    )
    dataset = dataset.rename(columns={
        "cycle_start": "start_" + ordinal + "_cycle",
        "cycle_end": "end_" + ordinal + "_cycle",
    })
    dataset = dataset.drop(columns="unique_cycle_id")

In [None]:
def get_input_cycles_25_dates(start_first_cycle, end_first_cycle, start_second_cycle, end_second_cycle, start_third_cycle, end_third_cycle):
    """Return the standardized dates of the three input cycles as one flat list.

    For every cycle, in order, the list holds the first 7 days (cycle start
    through start + 6) followed by the last 18 days (cycle end - 17 through
    cycle end), i.e. 25 dates per cycle.
    """
    cycle_bounds = (
        (start_first_cycle, end_first_cycle),
        (start_second_cycle, end_second_cycle),
        (start_third_cycle, end_third_cycle),
    )
    all_dates = []
    for cycle_start, cycle_end in cycle_bounds:
        # head of the cycle: days 1..7
        all_dates.extend(daterange(cycle_start, cycle_start + timedelta(days=6)))
        # tail of the cycle: days -18..-1, counted backward from the cycle end
        all_dates.extend(daterange(cycle_end - timedelta(days=17), cycle_end))
    return all_dates

In [None]:
dataset["input_cycles_25_dates"] = dataset.apply(lambda x: get_input_cycles_25_dates(x["start_first_cycle"], x["end_first_cycle"], x["start_second_cycle"], x["end_second_cycle"], x["start_third_cycle"], x["end_third_cycle"]), axis=1)

In [None]:
def get_end_120_days_output(third_cycle_end):
    """Return the date lying 120 days after `third_cycle_end`."""
    output_window = timedelta(days=120)
    return third_cycle_end + output_window

In [None]:
dataset["end_120_days_output"] = dataset["third_cycle_end"].apply(lambda x: get_end_120_days_output(x))

In [None]:
dataset["third_cycle_end"] = pd.to_datetime(dataset["third_cycle_end"])
dataset["third_cycle_end"] = dataset["third_cycle_end"].dt.date

In [None]:
def map_dates_to_BC(user, dates):
    """Return the birth-control labels that `BC_mapper[user]` assigns to `dates`.

    For each date, one label is emitted per key of `BC_mapper[user]` that
    contains the date; a date matched by no key contributes the placeholder
    "no valid BC". The result can therefore be longer than `dates` when a
    date falls inside several keys.
    """
    labels = []
    for day in dates:
        hits = [BC_mapper[user][span] for span in BC_mapper[user].keys() if day in span]
        labels.extend(hits if hits else ["no valid BC"])
    return labels

In [None]:
def get_BC_timeseries_input(user, input_dates):
    """Build the birth-control label sequence for `input_dates` of one user.

    Every key of `BC_mapper[user]` containing a date adds its label to the
    sequence; dates covered by no key add "no valid BC" instead.
    """
    timeline = []
    for day in input_dates:
        matched = False
        for span, label in BC_mapper[user].items():
            if day in span:
                timeline.append(label)
                matched = True
        if not matched:
            timeline.append("no valid BC")
    return timeline

In [None]:
dataset["BC_input_timeseries"] = dataset.apply(lambda x: get_BC_timeseries_input(x["anon_id"], x["input_cycles_25_dates"]), axis=1)

In [None]:
dataset["n_output_cycles"] = dataset["output_cycles"].apply(lambda x: len(x))

In [None]:
dataset["final_output_cycles"] = dataset["output_cycles"].apply(lambda x: x[:3] if len(x) > 3 else x)
dataset["final_output_cycles"] = dataset.apply(lambda x: get_cycle_ids(x["anon_id"], x["final_output_cycles"]), axis=1)

In [None]:
def get_output_cycles_dates(three_cycles):
    """Return every calendar date from the start of the first cycle to the end of the last one.

    `three_cycles` is a sequence of `unique_cycle_id` values; start and end
    are looked up in the global `cycles` frame.
    """
    def _cycle_bound(cycle_id, column):
        # first matching row of `cycles`, normalised to a Python datetime
        raw = cycles[cycles["unique_cycle_id"] == cycle_id][column].values[0]
        return pd.Timestamp(raw).to_pydatetime()

    first_cycle, last_cycle = three_cycles[0], three_cycles[-1]
    start = _cycle_bound(first_cycle, "cycle_start")
    end = _cycle_bound(last_cycle, "cycle_end")
    return [moment.date() for moment in daterange(start, end)]

In [None]:
def get_output_dates(start_date, end_date):
    """Return the dates from the day after `start_date` through `end_date` (inclusive)."""
    first_output_day = start_date + timedelta(days=1)
    return list(daterange(first_output_day, end_date))

In [None]:
dataset["output_dates"] = dataset.apply(lambda x: get_output_dates(x["third_cycle_end"], x["end_120_days_output"]), axis=1)

In [None]:
final_users = dataset["anon_id"].unique().tolist()

In [None]:
def create_inverse_cycle_id_mapper(cycles):
    """Map user -> unique cycle id -> [cycle_start, cycle_end].

    `cycles` must provide the columns "anon_id", "unique_cycle_id",
    "cycle_start" and "cycle_end". Returns a `defaultdict(dict)`.
    """
    cycle_id_mapper = defaultdict(dict)
    for user, user_cycles in tqdm(cycles.groupby("anon_id")):
        starts = user_cycles["cycle_start"]
        ids = user_cycles["unique_cycle_id"]
        ends = user_cycles["cycle_end"]
        for start, ID, end in zip(starts, ids, ends):
            cycle_id_mapper[user][ID] = [start, end]
    return cycle_id_mapper

In [None]:
inv_C_ID_mapper = create_inverse_cycle_id_mapper(cycles)

In [None]:
def get_output_cycles_dates_standardized(user, output_cycles):
    """Return the standardized dates of a user's output cycles as one flat list.

    For each cycle id (looked up in `inv_C_ID_mapper[user]`) the first 7 days
    are followed by the last 18 days of the cycle.
    """
    final_output_dates = []
    for cycle_id in output_cycles:
        cycle_start, cycle_end = inv_C_ID_mapper[user][cycle_id]
        final_output_dates.extend(daterange(cycle_start, cycle_start + timedelta(days=6)))
        final_output_dates.extend(daterange(cycle_end - timedelta(days=17), cycle_end))
    return final_output_dates

In [None]:
dataset["final_output_dates"] = dataset.apply(lambda x: get_output_cycles_dates_standardized(x["anon_id"], x["final_output_cycles"]), axis=1)

In [None]:
dataset["BC_final_output_timeseries"] = dataset.apply(lambda x: get_BC_timeseries_input(x["anon_id"], x["final_output_dates"]), axis=1)

In [None]:
all_output_BC_labels = [item for sublist in dataset["BC_final_output_timeseries"].tolist() for item in sublist]

### one-hot encoding the birth control labels

In [None]:
le = preprocessing.LabelEncoder()
bc_s = dataset["bc_label_out"].unique().tolist()
bc_s.extend(['no valid BC'])
le.fit(bc_s)
le.classes_
# creating instance of one-hot-encoder
enc = OneHotEncoder(handle_unknown='ignore')
enc.fit(le.transform(bc_s).reshape(-1, 1))

In [None]:
le.transform(bc_s)

In [None]:
def transform_BC_timeseries_to_onehot(BC_timeseries):
    """One-hot encode a list of birth-control labels.

    Labels are first mapped to integers with the fitted label encoder `le`
    and then expanded with the fitted one-hot encoder `enc`; the dense array
    is returned. An empty list is returned unchanged as `[]`.
    """
    if BC_timeseries == []:
        return []
    encoded_labels = le.transform(BC_timeseries)
    label_column = encoded_labels.reshape(-1, 1)
    return enc.transform(label_column).toarray()

In [None]:
dataset["BC_timeseries_onehot"] = dataset["BC_input_timeseries"].apply(lambda x: transform_BC_timeseries_to_onehot(x))

In [None]:
dataset["BC_output_timeseries_onehot"] = dataset["BC_final_output_timeseries"].apply(lambda x: transform_BC_timeseries_to_onehot(x))

In [None]:
def get_new_cat(cat, type_):
    """Return the category to use for a tracking entry.

    Motivation types ("motivated"/"unmotivated") are moved to the category
    "motivation" and productivity types ("productive"/"unproductive") to
    "productivity"; every other entry keeps its original `cat`.
    """
    if type_ in ("motivated", "unmotivated"):
        return "motivation"
    if type_ in ("unproductive", "productive"):
        return "productivity"
    return cat

In [None]:
tracking["category"] = tracking.apply(lambda x: get_new_cat(x["category"], x["type"]), axis=1)

**Ordinal:**
- period (bleeding) 
    - spotting
    - light
    - medium
    - heavy

- energy
    - exhausted
    - low energy
    - energized
    - high energy
    
- motivation 
    - unmotivated
    - motivated

- productivity
    - unproductive
    - productive
    
**Categorical:**
- pain
    - cramps
    - headache
    - tender breasts
    - ovulation pain
 
- mental
    - calm
    - focused
    - distracted
    - stressed
    
- emotion
    - happy
    - sad
    - sensitive_emotion
    - pms
    
- social
    - conflict_social
    - withdrawn_social
    - sociable
    - supportive_social
    
- pain medication 

In [None]:
ordinal_variable_encoder = defaultdict(dict)
## let's add the ordinal categories

#period
ordinal_variable_encoder['period']['spotting'] = 1.0
ordinal_variable_encoder['period']['light'] = 2.0
ordinal_variable_encoder['period']['medium'] = 3.0
ordinal_variable_encoder['period']['heavy'] = 4.0

#energy
ordinal_variable_encoder['energy']['exhausted'] = 1.0
ordinal_variable_encoder['energy']['low_energy'] = 2.0
ordinal_variable_encoder['energy']['energized'] = 3.0
ordinal_variable_encoder['energy']['high_energy'] = 4.0

#motivation
ordinal_variable_encoder['motivation']['unmotivated'] = 1.0
ordinal_variable_encoder['motivation']['motivated'] = 2.0

#productivity
ordinal_variable_encoder['productivity']['unproductive'] = 1.0
ordinal_variable_encoder['productivity']['productive'] = 2.0

In [None]:
# Pair every tracking row's category with its type as a (category, type) tuple.
new_cat_list = zip(tqdm(tracking["category"].tolist()), tracking["type"].tolist())
tracking["new_cat"] = list(new_cat_list)

In [None]:
tracking["new_cat"].value_counts(normalize=True) * 100

In [None]:
unique_categories = tracking["new_cat"].unique().tolist()
k = [x[0] for x in unique_categories]
unique_global_cats = []
for item in k:
    if item not in unique_global_cats:
        unique_global_cats.append(item)

In [None]:
unique_global_cats = ["pill_hbc", 
                      "period", 
                      "pain", 
                      "productivity", 
                      "social", 
                      "energy", 
                      "medication", 
                      "emotion", 
                      "mental",
                      "motivation"]

In [None]:
subcategories = {
    'pain': ['cramps', 'headache', 'ovulation_pain', 'tender_breasts'],
    'emotion': ['happy', 'sad', 'sensitive_emotion', 'pms'],
    'mental': ['calm', 'focused', 'distracted', 'stressed'],
    'social': ['conflict_social', 'withdrawn_social', 'sociable', 'supportive_social'],
    'pill_hbc': ['taken', 'missed', 'late', 'double']
}

In [None]:
missing_global_cats = {
    'emotion': [0.0, 0.0, 0.0, 0.0, 0.0],
    'energy': [0.0, 1.0],
    'medication': [0.0],
    'mental': [0.0, 0.0, 0.0, 0.0, 0.0],
    'motivation': [0.0, 1.0],
    'pain': [0.0, 0.0, 0.0, 0.0, 0.0],
    'period': [0.0, 1.0],
    'productivity': [0.0, 1.0],
    'social': [0.0, 0.0, 0.0, 0.0, 0.0],
    'pill_hbc': [0.0, 0.0, 0.0, 0.0, 0.0]
}

In [None]:
missing_one_hot_day = [missing_global_cats[x] for x in unique_global_cats]
missing_one_hot_day = [item for sublist in missing_one_hot_day for item in sublist]

In [None]:
all_input_cycles = [item for sublist in dataset["input_cycles"].tolist() for item in sublist]

In [None]:
input_cycles_df = cycles[cycles["unique_cycle_id"].isin(all_input_cycles)]

In [None]:
def get_25_standardized_days(cycle_start, cycle_end):
    """Return the 25 standardized dates of a cycle: its first 7 days, then its last 18 days."""
    leading_days = list(daterange(cycle_start, cycle_start + timedelta(days=6)))
    trailing_days = list(daterange(cycle_end - timedelta(days=17), cycle_end))
    return leading_days + trailing_days

In [None]:
input_cycles_df["standardized_25_cycle_days"] = input_cycles_df.apply(lambda x: get_25_standardized_days(x["cycle_start"], x["cycle_end"]), axis=1)

In [None]:
standardized_cycle_days = dict(zip(input_cycles_df.unique_cycle_id, input_cycles_df.standardized_25_cycle_days))

In [None]:
input_cycles_tracking = tracking[tracking["cycle_id"].isin(all_input_cycles)]

### 8. Filter out input-output pairs where the user only tracks "period" in the 3 input cycles.

In [None]:
logging_fractions_check = input_cycles_tracking[input_cycles_tracking["category"] != "period"]
input_cycles_logging_fractions_dict = logging_fractions_check.groupby("cycle_id")["date"].count().to_dict()

In [None]:
def get_logging_counts_3_input_cycles(input_cycles):
    """Sum the non-period logging counts of the given cycle ids.

    Counts come from `input_cycles_logging_fractions_dict`; a cycle id that
    is absent from that dict counts as 0.
    """
    return sum(input_cycles_logging_fractions_dict.get(cycle_id, 0) for cycle_id in input_cycles)

In [None]:
dataset["logging_counts_input_cycles"] = dataset["input_cycles"].apply(lambda x: get_logging_counts_3_input_cycles(x))

In [None]:
dataset = dataset[dataset["logging_counts_input_cycles"] > 0]

In [None]:
final_users = dataset["anon_id"].unique().tolist()

In [None]:
cycles = cycles[cycles["anon_id"].isin(final_users)]
cycles.shape

In [None]:
tracking = tracking[tracking["anon_id"].isin(final_users)]
tracking.shape

In [None]:
users = users[users["anon_id"].isin(final_users)]
users.shape

In [None]:
birth_control = birth_control[birth_control["anon_id"].isin(final_users)]
birth_control.shape

In [None]:
all_input_cycles = dataset["input_cycles"].to_list()
unique_input_cycles = [x for sublist in all_input_cycles for x in sublist]

In [None]:
all_output_cycles = dataset["final_output_cycles"].to_list()
unique_output_cycles = [x for sublist in all_output_cycles for x in sublist]

In [None]:
# Unique ids of every cycle that appears as an input OR an output cycle.
# The concatenated list used to be overwritten by the unique *output* ids only,
# which silently dropped all input cycles from everything derived from
# `final_unique_cycles` (final_cycles_df, standardized_cycle_days,
# final_cycles_tracking).
final_unique_cycles = np.unique(np.array(unique_input_cycles + unique_output_cycles))

In [None]:
final_cycles_df = cycles[cycles["unique_cycle_id"].isin(final_unique_cycles)]

In [None]:
final_cycles_df["standardized_25_cycle_days"] = final_cycles_df.apply(lambda x: get_25_standardized_days(x["cycle_start"], x["cycle_end"]), axis=1)

In [None]:
standardized_cycle_days = dict(zip(final_cycles_df.unique_cycle_id, final_cycles_df.standardized_25_cycle_days))

In [None]:
final_cycles_tracking = tracking[tracking["cycle_id"].isin(final_unique_cycles)]

In [None]:
BC_timeseries_cycles = defaultdict()

In [None]:
dataset["BC_timeseries_1"] = dataset["BC_timeseries_onehot"].apply(lambda x: x[:25])
dataset["BC_timeseries_1"] = dataset.apply(lambda x: (x["first_input_cycle"], x["BC_timeseries_1"]), axis=1)
list_1 = dataset["BC_timeseries_1"].tolist()
for d in tqdm(list_1):
    BC_timeseries_cycles[d[0]] = d[1]

In [None]:
dataset["BC_timeseries_2"] = dataset["BC_timeseries_onehot"].apply(lambda x: x[25:50])
dataset["BC_timeseries_2"] = dataset.apply(lambda x: (x["second_input_cycle"], x["BC_timeseries_2"]), axis=1)
list_2 = dataset["BC_timeseries_2"].tolist()
for d in tqdm(list_2):
    BC_timeseries_cycles[d[0]] = d[1]

In [None]:
dataset["BC_timeseries_3"] = dataset["BC_timeseries_onehot"].apply(lambda x: x[50:])
dataset["BC_timeseries_3"] = dataset.apply(lambda x: (x["third_input_cycle"], x["BC_timeseries_3"]), axis=1)
list_3 = dataset["BC_timeseries_3"].tolist()
for d in tqdm(list_3):
    BC_timeseries_cycles[d[0]] = d[1]

In [None]:
l = BC_timeseries_cycles

In [None]:
def get_tuples_output_BC_cycles(output_cycles_list, output_BC_timeseries_one_hot):
    return [(c, output_BC_timeseries_one_hot[i*25:(i+1)*25]) for i, c in enumerate(output_cycles_list)]

In [None]:
dataset["tuples_ouput_BC_cycles"] = dataset.apply(lambda x: get_tuples_output_BC_cycles(x["final_output_cycles"], x["BC_output_timeseries_onehot"]), axis=1)

In [None]:
list_output_cycles = dataset["tuples_ouput_BC_cycles"].tolist()
list_output_cycles = [item for sublist in list_output_cycles for item in sublist]
for d in tqdm(list_output_cycles):
    BC_timeseries_cycles[d[0]] = d[1]

In [None]:
l = BC_timeseries_cycles

In [None]:
# Collect the "index" values of tracking rows whose date is not one of the
# standardized days of their cycle.
idx_to_del = []
for c, values in tqdm(final_cycles_tracking.groupby("cycle_id")):
    cycle_tracking_dates = values["date"].tolist()
    final_idx_list = values["index"].tolist()
    wanted_dates = standardized_cycle_days[c]
    idx_list = [
        row_idx
        for row_idx, tracked_date in zip(final_idx_list, cycle_tracking_dates)
        if tracked_date not in wanted_dates
    ]
    idx_to_del.extend(idx_list)

In [None]:
final_cycles_tracking = final_cycles_tracking[~final_cycles_tracking["index"].isin(idx_to_del)]

In [None]:
input_cycles_timeseries_onehot = defaultdict()

In [None]:
output_cycles_timeseries_onehot = defaultdict()

In [None]:
def transform_input_per_day(subset_tracking, day, BC_onehot):
    """Append the tracking encoding of `day` to `BC_onehot` and return it.

    `BC_onehot` (a list) is extended in place. When `day` has no row in
    `subset_tracking`, the precomputed `missing_one_hot_day` is appended.
    Otherwise each category of `unique_global_cats` contributes, in order:

    - not logged that day: `missing_global_cats[category]`;
    - ordinal category: `[ordinal value, 0.0]`;
    - "medication": `[1.0]`;
    - any other category: `[1.0]` followed by a one-hot over
      `subcategories[category]`.
    """
    if day not in subset_tracking["date"].tolist():
        BC_onehot.extend(missing_one_hot_day)
        return BC_onehot

    day_rows = subset_tracking[subset_tracking["date"] == day]
    # NOTE(review): entries are (category, type) pairs; when a category is logged
    # more than once on the same day only the last type is kept -- confirm intended.
    logged = {}
    for entry in day_rows["new_cat"].tolist():
        logged[entry[0]] = entry[1]

    day_one_hot = []
    for t in unique_global_cats:
        if t not in logged:
            day_one_hot.extend(missing_global_cats[t])
        elif t in ordinal_variable_encoder:
            day_one_hot.extend([ordinal_variable_encoder[t][logged[t]], 0.0])
        elif t != "medication":
            day_one_hot.append(1.0)
            day_one_hot.extend(1.0 if s == logged[t] else 0.0 for s in subcategories[t])
        else:
            day_one_hot.append(1.0)

    BC_onehot.extend(day_one_hot)
    return BC_onehot

In [None]:
# Encode every standardized cycle day as BC one-hot + tracking features.
#
# `input_cycles_timeseries_onehot` was created above but never filled, although it
# is pickled and later read by `get_input_cycles_timeseries` (for both input and
# output cycle ids). Every encoded cycle is therefore stored there; the output
# dict keeps receiving the cycles listed in `unique_output_cycles`.
# NOTE(review): confirm this is the intended split between the two dicts.
output_cycle_ids = set(unique_output_cycles)
for cycle, values in tqdm(standardized_cycle_days.items()):
    input_cycle_25_days = values
    subset_tracking = final_cycles_tracking[final_cycles_tracking["cycle_id"] == cycle]
    BC_timeseries_onehot = BC_timeseries_cycles[cycle]
    # row i of the BC one-hot array belongs to the i-th standardized day of the cycle
    final_list = [transform_input_per_day(subset_tracking, d, BC_timeseries_onehot[i].tolist()) for i, d in enumerate(input_cycle_25_days)]
    input_cycles_timeseries_onehot[cycle] = final_list
    if cycle in output_cycle_ids:
        output_cycles_timeseries_onehot[cycle] = final_list

In [None]:
# make path to save transformed datasets
# `Path` is not among the notebook's top-level imports shown at the start of the
# file, so it is imported here to keep this cell runnable on its own.
from pathlib import Path

input_data_path = "input_data_experiments/clue/"
Path(input_data_path).mkdir(parents=True, exist_ok=True)

In [None]:
final_output_cycles_timeseries_df = pd.DataFrame.from_dict(output_cycles_timeseries_onehot, orient='index')
final_output_cycles_timeseries_df.to_pickle(input_data_path + "final_output_cycles_timeseries_df.pkl")

In [None]:
final_cycles_timeseries_df = pd.DataFrame.from_dict(input_cycles_timeseries_onehot, orient='index')
final_cycles_timeseries_df.to_pickle(input_data_path + "final_cycles_timeseries_df.pkl")

In [None]:
def get_input_cycles_timeseries(three_input_cycles):
    """Concatenate the per-day encodings of the given cycles, in order, into one list.

    Encodings are read from `input_cycles_timeseries_onehot`, keyed by cycle id.
    """
    return [
        day_vector
        for cycle_id in three_input_cycles
        for day_vector in input_cycles_timeseries_onehot[cycle_id]
    ]

In [None]:
dataset["input_one_hot"] = dataset["input_cycles"].apply(lambda x: get_input_cycles_timeseries(x))
dataset["input_onehot_len_check"] = dataset.apply(lambda x: len(x["input_one_hot"]), axis=1)
dataset["input_onehot_len_check"].value_counts()

In [None]:
dataset["APC_cycles"] = dataset.apply(lambda x: x["input_cycles"] + x["final_output_cycles"], axis=1)

In [None]:
dataset["APC_input_one_hot"] = dataset["APC_cycles"].apply(lambda x: get_input_cycles_timeseries(x))

### add baseline symptoms - when the user is OFF 

- median CL
- var CL
- average occurrences of symptoms over all cycles OFF

In [None]:
dataset["starting_date_3_cycles"] = dataset["input_cycles_25_dates"].apply(lambda x: x[0])

In [None]:
tracking = pd.merge(tracking, cycles[["unique_cycle_id", "cycle_length"]], how="left", left_on="cycle_id", right_on="unique_cycle_id")

In [None]:
OFF_tracking = tracking[tracking["birth_control"] == "OFF"]

In [None]:
def get_max_cycle_symptoms_OFF(value_counts_list):
    """Return, per key, the largest value found across a list of count dicts.

    Keys appear in the result in the order in which they are first seen.
    """
    max_dict = {}
    for counts in value_counts_list:
        for symptom, count in counts.items():
            if symptom in max_dict:
                max_dict[symptom] = max(max_dict[symptom], count)
            else:
                max_dict[symptom] = count
    return max_dict

In [None]:
def get_min_cycle_symptoms_OFF(value_counts_list):
    """Return, per key, the smallest value found across a list of count dicts.

    Only dicts that contain a key contribute to its minimum; keys appear in
    the result in the order in which they are first seen.
    """
    min_dict = {}
    for counts in value_counts_list:
        for symptom, count in counts.items():
            if symptom in min_dict:
                min_dict[symptom] = min(min_dict[symptom], count)
            else:
                min_dict[symptom] = count
    return min_dict

In [None]:
# Per-user baseline statistics computed from the tracking rows logged while the
# user was OFF birth control:
#   median_CL_OFF / var_CL_OFF   - median and variance of the OFF cycle lengths
#   avg/min/max_cycle_symptoms_OFF - per (category, type) counts per cycle
#   global_avg_symptoms_OFF      - per-category average count per cycle
#   dates_OFF                    - (earliest, latest) OFF tracking date
baseline_OFF_per_users = defaultdict()
# Group by the column name rather than a one-element list: with a list-like `by`,
# pandas >= 2.0 yields 1-tuples as group keys, which would no longer match the
# plain user ids looked up later (OFF_users / get_baseline_OFF).
for u, values in tqdm(OFF_tracking.groupby("anon_id")):
    OFF_dates = values["date"].tolist()

    # `values` holds one row per tracked item, so cycle ids repeat. All per-cycle
    # statistics are computed over the distinct cycles; iterating over the raw
    # column weighted each cycle by its number of rows and divided by the row
    # count instead of the cycle count.
    no_duplicates = values.drop_duplicates(subset="cycle_id", keep="first")
    cycles_OFF = no_duplicates["cycle_id"].tolist()
    n_cycles_OFF = len(cycles_OFF)

    median_CL_OFF = np.median(no_duplicates["cycle_length"])
    var_CL_OFF = np.var(no_duplicates["cycle_length"])

    # (category, type) counts, one dict per distinct cycle
    value_counts_list = [values[values["cycle_id"] == x]["new_cat"].value_counts().to_dict() for x in cycles_OFF]
    z = sum((Counter(d) for d in value_counts_list), Counter())
    avg_cycle_symptoms_OFF = {k: v / n_cycles_OFF for k, v in z.items()}
    max_cycle_symptoms_OFF = get_max_cycle_symptoms_OFF(value_counts_list)
    min_cycle_symptoms_OFF = get_min_cycle_symptoms_OFF(value_counts_list)

    # category-level counts, one Counter per distinct cycle
    global_counter_list = [Counter(values[values["cycle_id"] == x]["category"].value_counts().to_dict()) for x in cycles_OFF]
    z_global = sum(global_counter_list, Counter())
    global_avg_symptoms_OFF = {k: v / n_cycles_OFF for k, v in z_global.items()}

    baseline_OFF_per_users[u] = defaultdict()
    baseline_OFF_per_users[u]["median_CL_OFF"] = median_CL_OFF
    baseline_OFF_per_users[u]["var_CL_OFF"] = var_CL_OFF
    baseline_OFF_per_users[u]["max_cycle_symptoms_OFF"] = max_cycle_symptoms_OFF
    baseline_OFF_per_users[u]["min_cycle_symptoms_OFF"] = min_cycle_symptoms_OFF
    baseline_OFF_per_users[u]["avg_cycle_symptoms_OFF"] = avg_cycle_symptoms_OFF
    baseline_OFF_per_users[u]["global_avg_symptoms_OFF"] = global_avg_symptoms_OFF
    baseline_OFF_per_users[u]["dates_OFF"] = (min(OFF_dates), max(OFF_dates))

In [None]:
OFF_users = list(baseline_OFF_per_users.keys())

In [None]:
not_OFF_users = np.unique(dataset[~dataset["anon_id"].isin(OFF_users)]["anon_id"].tolist())

In [None]:
# Users without any OFF tracking get a baseline record whose fields are all NaN.
for u in tqdm(not_OFF_users):
    baseline_OFF_per_users[u] = defaultdict(
        None,
        dict.fromkeys(
            (
                "median_CL_OFF",
                "var_CL_OFF",
                "max_cycle_symptoms_OFF",
                "min_cycle_symptoms_OFF",
                "avg_cycle_symptoms_OFF",
                "global_avg_symptoms_OFF",
                "dates_OFF",
            ),
            np.nan,
        ),
    )

In [None]:
OFF_unique_cats = [x for x in unique_categories if x[0] != 'pill_hbc']

In [None]:
def get_vector_avg_sympt_input_cycles(avg_sympt_dict, unique_categories):
    """Return the values of `avg_sympt_dict` ordered like `OFF_unique_cats`, using 0 for absent keys.

    Note: the `unique_categories` argument is not used; the ordering always
    comes from the module-level `OFF_unique_cats`.
    """
    return [avg_sympt_dict.get(symptom, 0) for symptom in OFF_unique_cats]

In [None]:
def get_baseline_OFF(user, third_cycle_end):
    """Return a user's OFF baseline if all their OFF tracking precedes `third_cycle_end`.

    Returns the 6-tuple (median_CL_OFF, var_CL_OFF, avg_cycle_symptoms_OFF,
    min_cycle_symptoms_OFF, max_cycle_symptoms_OFF, global_avg_symptoms_OFF)
    read from `baseline_OFF_per_users[user]`. All six entries are NaN when
    the user has no OFF dates (stored as NaN) or when the latest OFF date is
    not strictly before `third_cycle_end`.
    """
    record = baseline_OFF_per_users[user]
    dates_OFF = record["dates_OFF"]
    no_baseline = (np.nan,) * 6

    if isinstance(dates_OFF, tuple):
        # dates_OFF is (earliest, latest); only a baseline that ended earlier is used
        if dates_OFF[1] < third_cycle_end:
            baseline = (
                record["median_CL_OFF"],
                record["var_CL_OFF"],
                record["avg_cycle_symptoms_OFF"],
                record["min_cycle_symptoms_OFF"],
                record["max_cycle_symptoms_OFF"],
                record["global_avg_symptoms_OFF"],
            )
        else:
            baseline = no_baseline
    elif np.isnan(dates_OFF):
        baseline = no_baseline

    return baseline

In [None]:
dataset['median_CL_base_OFF'],  dataset['var_CL_base_OFF'], dataset['avg_cycle_symptoms_base_OFF'], dataset['min_cycle_symptoms_OFF'], dataset['max_cycle_symptoms_OFF'], dataset['global_avg_symptoms_OFF'] = zip(*dataset.apply(lambda x: get_baseline_OFF(x["anon_id"], x["third_cycle_end"]), axis=1))

In [None]:
missingness_vector = [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]

In [None]:
all_subcats = {
    'period': ['spotting', 'light', 'medium', 'heavy'],
    'pain': ['cramps', 'headache', 'ovulation_pain', 'tender_breasts'],
    'energy': ['exhausted', 'low_energy', 'energized', 'high_energy'],
    'emotion': ['happy', 'sad', 'sensitive_emotion', 'pms'],
    'motivation': ['unmotivated', 'motivated'],
    'mental': ['calm', 'focused', 'distracted', 'stressed'],
    'social': ['conflict_social', 'withdrawn_social', 'sociable', 'supportive_social'],
    'productivity': ['unproductive', 'productive']
}

In [None]:
missing_cats = {
    'emotion': [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
    'energy': [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
    'medication': [0.0, 0.0, 0.0, 1.0],
    'mental': [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
    'motivation': [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
    'pain': [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
    'period': [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
    'productivity': [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
    'social': [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
}

In [None]:
OFF_unique_global_cats = [x for x in unique_global_cats if x != "pill_hbc"]

In [None]:
def _append_OFF_subcat_stats(OFF_vector, row, key):
    """Append [avg, min, max, missing-flag] for one (category, subcategory) key to `OFF_vector`."""
    if key in row["avg_cycle_symptoms_base_OFF"]:
        OFF_vector.append(float(row["avg_cycle_symptoms_base_OFF"][key]))
        OFF_vector.append(float(row["min_cycle_symptoms_OFF"][key]))
        OFF_vector.append(float(row["max_cycle_symptoms_OFF"][key]))
        OFF_vector.append(0.0)
    else:
        # never logged while OFF: zero statistics plus missingness flag
        OFF_vector.extend([0.0, 0.0, 0.0, 1.0])


def get_baseline_OFF_vector(row, OFF_unique_global_cats):
    """Flatten a row's OFF baseline into one numeric list.

    Layout: [median CL, var CL] followed, for each category in
    `OFF_unique_global_cats`, by
    - `missing_cats[category]` when the category has no global average;
    - for "medication": avg/min/max/missing-flag of ("medication", "pain_medication");
    - otherwise: the category's global average, then avg/min/max/missing-flag
      for every subcategory in `all_subcats[category]`.

    Rows without a baseline (null "median_CL_base_OFF") yield
    `[0.0, 0.0] + missingness_vector`. Returns None when a baseline exists
    but "global_avg_symptoms_OFF" is not a dict.
    """
    if not pd.notnull(row["median_CL_base_OFF"]):
        return [0.0, 0.0] + missingness_vector
    if not isinstance(row["global_avg_symptoms_OFF"], dict):
        return None

    OFF_vector = [row["median_CL_base_OFF"], row["var_CL_base_OFF"]]
    for cat in OFF_unique_global_cats:
        if cat not in row["global_avg_symptoms_OFF"]:
            OFF_vector.extend(missing_cats[cat])
        elif cat == "medication":
            _append_OFF_subcat_stats(OFF_vector, row, (cat, "pain_medication"))
        else:
            OFF_vector.append(float(row["global_avg_symptoms_OFF"][cat]))
            for sub_cat in all_subcats[cat]:
                _append_OFF_subcat_stats(OFF_vector, row, (cat, sub_cat))
    return OFF_vector

In [None]:
dataset["baseline_OFF_vector"] = dataset.apply(lambda x: get_baseline_OFF_vector(x, OFF_unique_global_cats), axis=1)

In [None]:
dataset["baseline_OFF_vector_len_check"] = dataset.apply(lambda x: len(x["baseline_OFF_vector"]), axis=1)

### Turn baseline OFF vector into static input categories

In [None]:
# Expand baseline_OFF_vector into one static column per vector position.
# Positions 0 and 1 hold the cycle-length median and variance; after that each
# category contributes an optional category-level "<category>_avg_OFF" entry
# followed by avg/min/max/missing for every feature of that category.
_off_vector_stats = ("avg", "min", "max", "missing")
_off_vector_layout = (
    # period
    ("period_avg_OFF",
     ("period_spotting", "period_light", "period_medium", "period_heavy")),
    # pain
    ("pain_avg_OFF",
     ("pain_cramps", "pain_headache", "pain_ovulation", "pain_tender_breasts")),
    # productivity
    ("productivity_avg_OFF", ("unproductive", "productive")),
    # social
    ("social_avg_OFF",
     ("conflict_social", "withdrawn_social", "sociable", "supportive_social")),
    # energy
    ("energy_avg_OFF",
     ("energy_exhausted", "low_energy", "energized", "high_energy")),
    # medication (no category-level average in the vector)
    (None, ("pain_medication",)),
    # emotion
    ("emotion_avg_OFF",
     ("emotion_happy", "emotion_sad", "sensitive_emotion", "emotion_pms")),
    # mental
    ("mental_avg_OFF",
     ("mental_calm", "mental_focused", "mental_distracted", "mental_stressed")),
    # motivation
    ("motivation_avg_OFF", ("unmotivated", "motivated")),
)

_off_vector_columns = ["median_CL_OFF", "var_CL_OFF"]
for _category_col, _features in _off_vector_layout:
    if _category_col is not None:
        _off_vector_columns.append(_category_col)
    for _feature in _features:
        for _stat in _off_vector_stats:
            _off_vector_columns.append(_feature + "_" + _stat + "_OFF")

# Column i of the layout is element i of the vector (126 columns in total).
for _position, _column in enumerate(_off_vector_columns):
    dataset[_column] = dataset["baseline_OFF_vector"].apply(
        lambda vector, _position=_position: vector[_position]
    )

In [None]:
# we want to scale on the non-zero values for standard scaling
def scale_column_nonzeros(column):
    """Standard-scale the non-zero entries of ``column`` and leave zeros at zero.

    The scaler is fitted on the non-zero values only, so zeros do not
    influence the mean or the standard deviation.

    Parameters
    ----------
    column : sequence of numbers (e.g. a pandas Series)
        Values of one static feature.

    Returns
    -------
    list or the original ``column``
        A list with the non-zero entries replaced by their standardised value.
        When every entry is zero the input is returned unchanged.
    """
    # Work in float: writing scaled values back into an integer array would
    # silently truncate them.
    arr = np.stack(column).astype(float)
    nonzero_mask = arr != 0

    if not nonzero_mask.any():
        return column

    # StandardScaler expects a 2-D (n_samples, n_features) array.
    scaled_nonzeros = StandardScaler().fit_transform(arr[nonzero_mask].reshape(-1, 1))
    # One masked assignment instead of a Python loop over 1-element arrays.
    arr[nonzero_mask] = scaled_nonzeros.ravel()

    return arr.tolist()

In [None]:
# Names of all baseline ("OFF") static feature columns, in the order they are
# fed to downstream code. Each group is an optional category-level average
# column followed by avg/min/max/missing for every feature of the category.
_OFF_STAT_SUFFIXES = ("avg", "min", "max", "missing")
_OFF_COLUMN_GROUPS = (
    ("period_avg_OFF",
     ("period_spotting", "period_light", "period_medium", "period_heavy")),
    ("pain_avg_OFF",
     ("pain_cramps", "pain_headache", "pain_ovulation", "pain_tender_breasts")),
    ("energy_avg_OFF",
     ("energy_exhausted", "low_energy", "energized", "high_energy")),
    ("emotion_avg_OFF",
     ("emotion_happy", "emotion_sad", "sensitive_emotion", "emotion_pms")),
    ("motivation_avg_OFF", ("unmotivated", "motivated")),
    ("mental_avg_OFF",
     ("mental_calm", "mental_focused", "mental_distracted", "mental_stressed")),
    ("social_avg_OFF",
     ("conflict_social", "withdrawn_social", "sociable", "supportive_social")),
    (None, ("pain_medication",)),
    ("productivity_avg_OFF", ("unproductive", "productive")),
)

BASE_OFF_COLS = ["median_CL_OFF", "var_CL_OFF"]
for _category_col, _features in _OFF_COLUMN_GROUPS:
    if _category_col is not None:
        BASE_OFF_COLS.append(_category_col)
    for _feature in _features:
        for _suffix in _OFF_STAT_SUFFIXES:
            BASE_OFF_COLS.append(_feature + "_" + _suffix + "_OFF")

In [None]:
# Every baseline OFF column except the "*_missing_*" ones.
base_OFF_cols_without_missing = [
    col_name for col_name in BASE_OFF_COLS if "missing" not in col_name
]

In [None]:
# Standard-scale the non-zero values of each non-"missing" baseline column.
for col in base_OFF_cols_without_missing:
    scaled_values = scale_column_nonzeros(dataset[col])
    dataset[col] = scaled_values

### Encode static input categories + output labels

In [None]:
# Fit a label encoder on the country column and store the integer codes.
le = LabelEncoder()
country_values = dataset["country"]
dataset["country_encoded"] = le.fit(country_values).transform(country_values)

In [None]:
# Share of each birth-control output label, in percent.
dataset["bc_label_out"].value_counts(normalize=True).mul(100)

In [None]:
# The label column is called "output" from here on.
output_column_rename = {"bc_label_out": "output"}
dataset = dataset.rename(columns=output_column_rename)

In [None]:
# Hand-picked integer code for each output label (the code is the position in the tuple).
label_encoder = {
    label: code
    for code, label in enumerate(("ON", "OFF", "OTHER-H", "OTHER-NH"))
}

In [None]:
# Translate the string labels into their integer codes; labels that are not
# keys of label_encoder become NaN.
output_labels = dataset["output"]
dataset["output_cat"] = output_labels.map(label_encoder)

In [None]:
# Share of each encoded output class, in percent.
dataset["output_cat"].value_counts(normalize=True).mul(100)

#### fill missing static values with mean

In [None]:
# Replace missing numeric static values with the mean of their column.
numeric_static_cols = ["BMI", "current_age", "median_CL_inputs", "var_CL_inputs", "days_on_current_BC"]
aux_df = dataset[numeric_static_cols]
aux_df = aux_df.fillna(aux_df.mean())
dataset[numeric_static_cols] = aux_df

In [None]:
# All static model inputs: the baseline OFF columns followed by the user-level ones.
static_input_cols = [
    *BASE_OFF_COLS,
    "BMI",
    "country_encoded",
    "current_age",
    "median_CL_inputs",
    "var_CL_inputs",
    "days_on_current_BC",
]

### prepare APC input data

In [None]:
def get_full_25_day_one_hot(row):
    """Collect the values stored under the keys "0" .. "24" of ``row``, in order."""
    return [row[str(day)] for day in range(25)]

In [None]:
# Gather the 25 per-day columns of every output cycle into a single list column.
final_output_cycles_timeseries_df["input_one_hot"] = final_output_cycles_timeseries_df.apply(
    lambda cycle_row: get_full_25_day_one_hot(cycle_row),
    axis=1,
)

In [None]:
def get_output_cycles_timeseries(output_cycles):
    """Concatenate the per-day one-hot lists of the given cycles, in order.

    Each id in ``output_cycles`` is looked up by "unique_cycle_id" in
    final_output_cycles_timeseries_df; the "input_one_hot" list of the first
    matching row is appended to the result.
    """
    timeseries = []
    for cycle_id in output_cycles:
        is_cycle = final_output_cycles_timeseries_df["unique_cycle_id"] == cycle_id
        cycle_rows = final_output_cycles_timeseries_df[is_cycle]
        timeseries.extend(cycle_rows["input_one_hot"].values[0])
    return timeseries

In [None]:
# Per-day one-hot timeseries of the output cycles of every sample.
dataset["output_one_hot"] = dataset["final_output_cycles"].apply(
    lambda cycle_ids: get_output_cycles_timeseries(cycle_ids)
)

In [None]:
# APC sequence = input days followed by output days (element-wise "+" of the
# two columns; presumably list concatenation per row -- verify the cell types).
input_days = dataset["input_one_hot"]
output_days = dataset["output_one_hot"]
dataset["APC_one_hot"] = input_days + output_days

In [None]:
# Number of entries in each APC sequence.
dataset["len_APC_days"] = dataset["APC_one_hot"].apply(len)

In [None]:
# Columns carried over into the modelling frame.
# Take an explicit copy: new columns are assigned to `data` later in the
# notebook, and writing into a slice of `dataset` is chained assignment
# (SettingWithCopyWarning, with no guarantee which frame is modified).
data = dataset[
    [
        "BMI",
        "country_encoded",
        "current_age",
        "median_CL_inputs",
        "var_CL_inputs",
        "APC_one_hot",
        "len_APC_days",
        "days_on_current_BC",
        "input",
        "output_cat",
    ]
].copy()

In [None]:
# Bring the baseline OFF feature columns over into the modelling frame.
base_off_values = dataset[BASE_OFF_COLS].copy()
data[BASE_OFF_COLS] = base_off_values

### missing values

In [None]:
# (columns of the one-hot timestep, pattern that encodes "nothing tracked")
# for each of the 11 tracked features.
_TIMESTEP_MISSING_PATTERNS = (
    (slice(0, 5), [0.0, 0.0, 0.0, 0.0, 1.0]),    # BC method
    (slice(5, 10), [0.0, 0.0, 0.0, 0.0, 0.0]),   # pill_hbc
    (slice(10, 12), [0.0, 1.0]),                 # period
    (slice(12, 17), [0.0, 0.0, 0.0, 0.0, 0.0]),  # pain
    (slice(17, 19), [0.0, 1.0]),                 # productivity
    (slice(19, 24), [0.0, 0.0, 0.0, 0.0, 0.0]),  # social
    (slice(24, 26), [0.0, 1.0]),                 # energy
    (slice(26, 27), [0.0]),                      # medication
    (slice(27, 32), [0.0, 0.0, 0.0, 0.0, 0.0]),  # emotion
    (slice(32, 37), [0.0, 0.0, 0.0, 0.0, 0.0]),  # mental
    (slice(37, None), [0.0, 1.0]),               # motivation
)


def get_missing_features_per_timestep(timestep):
    """Count how many of the 11 features are missing in one one-hot timestep.

    A feature counts as missing when its columns equal the "nothing tracked"
    pattern listed in ``_TIMESTEP_MISSING_PATTERNS``.

    Parameters
    ----------
    timestep : sequence of numbers
        One day of the one-hot encoding (list or 1-D array).

    Returns
    -------
    int
        Number of missing features, between 0 and 11.
    """
    # Plain list so that slice == pattern is a single boolean, also for arrays.
    values = list(timestep)
    # The medication column is compared as a one-element slice; the previous
    # scalar-vs-list comparison (``timestep[26] == [0.0]``) was never true, so
    # missing medication was never counted.
    return sum(
        1
        for columns, missing_pattern in _TIMESTEP_MISSING_PATTERNS
        if values[columns] == missing_pattern
    )

In [None]:
def get_missing_features_per_input_cycles(input_cycles):
    """Return the missing-feature count of every timestep in ``input_cycles``."""
    return list(map(get_missing_features_per_timestep, input_cycles))

In [None]:
# Per-timestep count of missing features for every input sequence.
dataset["missing_features_per_input"] = dataset["input_one_hot"].apply(
    lambda input_days: get_missing_features_per_input_cycles(input_days)
)

In [None]:
# Fraction of missing (feature, timestep) cells per sample: the summed
# per-timestep counts divided by a fixed grid of 11 features x 75 time steps.
n_features = 11
n_time_steps = 75
n_total = n_features * n_time_steps
dataset["total_missingness"] = dataset["missing_features_per_input"].apply(
    lambda per_step_counts: sum(per_step_counts)
)
dataset["total_frac_missingness"] = dataset["total_missingness"].div(n_total)

In [None]:
# mean share of missing values across samples, in percent
100 * dataset["total_frac_missingness"].mean()

In [None]:
# smallest share of missing values over all samples, in percent
100 * dataset["total_frac_missingness"].min()

In [None]:
# largest share of missing values over all samples, in percent
100 * dataset["total_frac_missingness"].max()

### prepare dataset for GRU-D

In [None]:
def get_daily_mask_and_inputs_GRUD(input_one_hot_day, output_type):
    """Build the 34-value GRU-D input or observation mask for one day.

    Parameters
    ----------
    input_one_hot_day : 1-D numpy array
        One-hot encoding of a single day.
    output_type : str
        "input" returns the kept values, "mask" returns the observation mask
        (1 = observed, 0 = missing). Any other value returns None.

    Returns
    -------
    numpy.ndarray or None
        Array of length 34 for "input" / "mask".
    """
    nothing_selected = [0.0, 0.0, 0.0, 0.0, 0.0]
    flag_only = [0.0, 1.0]
    # (raw columns kept, raw columns tested, "missing" pattern, mask columns)
    layout = (
        (slice(0, 4), slice(0, 5), [0.0, 0.0, 0.0, 0.0, 1.0], slice(0, 4)),  # BC method
        (slice(5, 10), slice(5, 10), nothing_selected, slice(4, 9)),         # pill_hbc
        (slice(10, 11), slice(10, 12), flag_only, slice(9, 10)),             # period
        (slice(12, 17), slice(12, 17), nothing_selected, slice(10, 15)),     # pain
        (slice(17, 18), slice(17, 19), flag_only, slice(15, 16)),            # productivity
        (slice(19, 24), slice(19, 24), nothing_selected, slice(16, 21)),     # social
        (slice(24, 25), slice(24, 26), flag_only, slice(21, 22)),            # energy
        (slice(26, 27), slice(26, 27), [0.0], slice(22, 23)),                # medication
        (slice(27, 32), slice(27, 32), nothing_selected, slice(23, 28)),     # emotion
        (slice(32, 37), slice(32, 37), nothing_selected, slice(28, 33)),     # mental
        (slice(37, 38), slice(37, None), flag_only, slice(33, 34)),          # motivation
    )

    kept_values = []
    daily_mask = np.ones((34,))
    for kept_cols, tested_cols, missing_pattern, mask_cols in layout:
        kept_values.extend(input_one_hot_day[kept_cols])
        is_missing = (input_one_hot_day[tested_cols] == np.array(missing_pattern)).all()
        daily_mask[mask_cols] = 0 if is_missing else 1

    kept_values = np.array(kept_values)

    if output_type == "input":
        return kept_values
    if output_type == "mask":
        return daily_mask
    return None

In [None]:
def apply_mask_GRUD(input_one_hot, output_type="input"):
    """Apply get_daily_mask_and_inputs_GRUD to every day (row) of a sequence.

    Parameters
    ----------
    input_one_hot : 2-D array-like
        One row per day, one column per one-hot entry.
    output_type : str, default "input"
        Forwarded to get_daily_mask_and_inputs_GRUD: "input" for the kept
        values, "mask" for the observation mask.

    Returns
    -------
    numpy.ndarray
        One transformed row per day.
    """
    # The per-day function requires output_type; without forwarding it here
    # every call failed with a TypeError (callers already pass two arguments).
    masked_input = np.apply_along_axis(
        get_daily_mask_and_inputs_GRUD, 1, input_one_hot, output_type
    )
    return masked_input

In [None]:
# GRU-D input values for every sample.
data["X"] = data["input"].apply(apply_mask_GRUD, args=("input",))

In [None]:
# GRU-D observation mask for every sample.
data["M"] = data["input"].apply(apply_mask_GRUD, args=("mask",))

### prepare dataset for APC

In [None]:
# (columns of the one-hot day, pattern that encodes "nothing tracked") per feature.
_APC_FEATURE_PATTERNS = (
    (slice(0, 5), [0.0, 0.0, 0.0, 0.0, 1.0]),    # BC method
    (slice(5, 10), [0.0, 0.0, 0.0, 0.0, 0.0]),   # pill_hbc
    (slice(10, 12), [0.0, 1.0]),                 # period
    (slice(12, 17), [0.0, 0.0, 0.0, 0.0, 0.0]),  # pain
    (slice(17, 19), [0.0, 1.0]),                 # productivity
    (slice(19, 24), [0.0, 0.0, 0.0, 0.0, 0.0]),  # social
    (slice(24, 26), [0.0, 1.0]),                 # energy
    (slice(26, 27), [0.0]),                      # medication
    (slice(27, 32), [0.0, 0.0, 0.0, 0.0, 0.0]),  # emotion
    (slice(32, 37), [0.0, 0.0, 0.0, 0.0, 0.0]),  # mental
    (slice(37, None), [0.0, 1.0]),               # motivation
)
# Second column of each two-column feature (period, productivity, energy,
# motivation), i.e. the "not tracked" indicator; always masked.
_APC_INDICATOR_COLUMNS = [11, 18, 25, 38]


def get_daily_mask(input_one_hot_day):
    """Return a copy of one one-hot day with missing features set to -1.

    A feature whose columns equal its "nothing tracked" pattern is overwritten
    with -1. In addition, the indicator column of every two-column feature is
    always set to -1.

    Parameters
    ----------
    input_one_hot_day : 1-D array-like with at least 39 entries
        One-hot encoding of a single day. The input is not modified.

    Returns
    -------
    numpy.ndarray
        Masked copy of the day.
    """
    day = np.asarray(input_one_hot_day)
    daily_mask = day.copy()

    for columns, missing_pattern in _APC_FEATURE_PATTERNS:
        if (day[columns] == np.array(missing_pattern)).all():
            daily_mask[columns] = -1

    # Previously indices 12, 19 and 26 were overwritten unconditionally, which
    # wiped the first pain column, the first social column and the medication
    # column. The indicator columns sit one position earlier (11, 18, 25);
    # index 38 was already correct for motivation.
    daily_mask[_APC_INDICATOR_COLUMNS] = -1

    return daily_mask

In [None]:
def apply_mask(input_one_hot):
    """Run get_daily_mask over every row (axis 1 slices) of ``input_one_hot``."""
    return np.apply_along_axis(get_daily_mask, 1, input_one_hot)

In [None]:
# Masked version of every APC sequence (missing features set to -1 by apply_mask).
dataset["APC_masked_one_hot"] = dataset["APC_one_hot"].apply(
    lambda apc_days: apply_mask(apc_days)
)

### save final dataset

In [None]:
# Serialise the final dataset to clue.pickle with the newest pickle protocol.
with open('clue.pickle', 'wb') as pickle_file:
    pickle.dump(dataset, pickle_file, protocol=pickle.HIGHEST_PROTOCOL)