In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are reloaded
# before each cell executes and edits to project code are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import datetime

import pandas as pd
import tentaclio
import sklearn

from phoenix.common import artifacts, run_datetime
from phoenix.common import utils
from phoenix.custom_models.tension_classifier import process_annotations
from phoenix.custom_models.tension_classifier.tension_classifier import VectorizerTopicsTensionClassifier

In [None]:
# Parameters
# See phoenix/common/run_datetime.py for the expected format of this parameter.
# When left as None (or any falsy value), a later cell creates the run datetime itself.
RUN_DATETIME = None

# See phoenix/common/artifacts/registry_environment.py for the expected format of this parameter.
ARTIFACTS_ENVIRONMENT_KEY = "local"

# Input
# Folder the annotated CSVs are read from; built on top of the local artifacts URL.
FOLDER_ANNOTATIONS = f"{artifacts.urls.get_local()}input_csvs/annotated_data/"

## Only try to classify a tension if there are at least this many objects
# NOTE(review): this constant is not referenced by any cell visible in this notebook —
# confirm whether it is still needed or whether the filtering now happens upstream.
MIN_NUM_OBJECTS_PER_TENSION = 60

In [None]:
# Resolve the run datetime: parse the RUN_DATETIME parameter when one was supplied,
# otherwise let the run_datetime module create one for this run.
run_dt = (
    run_datetime.from_file_safe_str(RUN_DATETIME)
    if RUN_DATETIME
    else run_datetime.create_run_datetime_now()
)

# Look up the base URL for the tension classifier in the artifact URL registry
# of the selected environment.
url_config = {}
art_url_reg = artifacts.registry.ArtifactURLRegistry(
    run_dt,
    ARTIFACTS_ENVIRONMENT_KEY,
)
STATIC_URL_CUSTOM_MODELS_TENSION_CLASSIFIER_BASE = art_url_reg.get_url(
    "static-custom_models_tension_classifier_base",
    url_config,
)

In [None]:
# Apply the shared notebook output and logging setup from phoenix.common.utils
# (see that module for what each helper configures).
utils.setup_notebook_output()
utils.setup_notebook_logging()

In [None]:
# Display params, one per line.
print(
    "\n".join(
        str(param)
        for param in (
            STATIC_URL_CUSTOM_MODELS_TENSION_CLASSIFIER_BASE,
            FOLDER_ANNOTATIONS,
            run_dt.dt,
        )
    )
)

In [None]:
# Training annotations. This csv already contains only data for the tensions that we
# want to classify (and have enough information for).
df = pd.read_csv(f"{FOLDER_ANNOTATIONS}phoenix_tensions.csv")

In [None]:
# Hold-out annotations, read from a separate csv in the same folder; used further down
# to analyse the trained classifier and as the input for the prediction example.
test_df = pd.read_csv(f"{FOLDER_ANNOTATIONS}phoenix_tensions_holdout.csv")

## Train stemmed_count vectorizer on 'full' dataset to get a complete vectorizer.
This shouldn't cause model leakage; however, we'll need to find a way to handle new words that the count_vectorizer hasn't seen yet.


### Because some labels have very little data, we only take labels with roughly 60 or more examples (see `MIN_NUM_OBJECTS_PER_TENSION`)

In [None]:
# Collect the label columns, i.e. the columns whose name starts with "is_".
# The match is anchored with a regex on purpose: `like="is_"` is a plain substring test,
# so it would also select unrelated columns that merely contain "is_" (for example
# "this_..." or "analysis_...") and hand them to the classifier as class labels.
# NOTE(review): assumes every label column is prefixed with "is_" — confirm against the csv header.
class_labels = df.filter(regex=r"^is_").columns.tolist()

In [None]:
# Create the classifier for the label columns collected in `class_labels`.
vt_tension_classifier = VectorizerTopicsTensionClassifier(class_labels)

In [None]:
# Train on the training annotations. The random state is pinned to a fixed integer,
# presumably so reruns give the same result — confirm how `train` uses it.
vt_tension_classifier.train(df, random_state_int=2021)

In [None]:
# Show the trained classifier object as the cell output.
vt_tension_classifier

In [None]:
# Analyse the trained classifier against the hold-out annotations.
# NOTE(review): what `analyse` reports or stores is defined in the classifier class — see there.
vt_tension_classifier.analyse(test_df)

In [None]:
# Persist the trained model under the base URL resolved from the artifact registry.
vt_tension_classifier.persist_model(STATIC_URL_CUSTOM_MODELS_TENSION_CLASSIFIER_BASE)

In [None]:
# Copy the hold-out frame so the prediction example works on a separate object from test_df.
eval_df = test_df.copy()

In [None]:
# Show the evaluation frame before it is passed to the classifier.
eval_df

In [None]:
# Run the trained classifier on the evaluation copy; the call's return value is the cell output.
# NOTE(review): not visible from here whether `predict` also modifies eval_df in place — confirm.
vt_tension_classifier.predict(eval_df)