In [14]:
import os
import warnings

# Process-wide environment switches, applied before the heavier libraries are
# imported in the next cell (presumably so CUDA and the tokenizers backend pick
# them up at import time — TODO confirm).
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0,1,2,3",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

# Blanket filter: every warning category is silenced for the whole session.
warnings.filterwarnings(action="ignore")

# Seed constant for the notebook; it is not referenced in the cells shown here.
SEED = 42

In [15]:
import torch
import pickle
import pandas as pd
from tqdm import tqdm
from sklearn.neural_network import MLPClassifier
from sklearn.svm import LinearSVC, SVC
from sklearn.model_selection import KFold, GridSearchCV
from sklearn.metrics import f1_score
from sklearn.ensemble import RandomForestClassifier
from sentence_transformers import SentenceTransformer
from sklearn.base import clone

# Report the PyTorch build and the CUDA devices visible to this kernel,
# one item per line.
print(
    f"pytorch version: {torch.__version__}",
    f"cuda available: {torch.cuda.is_available()}",
    f"devices count: {torch.cuda.device_count()}",
    sep="\n",
)

pytorch version: 2.2.2
cuda available: False
devices count: 0


In [16]:
# Annotator ids that have a dedicated model; note that A006 and A011 are not
# part of this list.
annotators = [
    "A001", "A002", "A003", "A004", "A005",
    "A007", "A008", "A009", "A010", "A012",
]

# Names used to build the embedding, model and result file paths below.
embedder = 't-gbert-lpc'
classifier = 'svc'

In [17]:
def output_st2(predictions: pd.DataFrame) -> pd.DataFrame:
    """Turn per-annotator label predictions into per-row label distributions.

    Parameters
    ----------
    predictions : pd.DataFrame
        One row per item and one column per annotator. Each cell holds a
        label in ``{0, 1, 2, 3, 4}`` or a missing value when that annotator
        has no prediction for the item.

    Returns
    -------
    pd.DataFrame
        Indexed like ``predictions`` with the columns ``dist_bin_0``,
        ``dist_bin_1`` (labels 1-4 collapsed into 1) and ``dist_multi_0`` ...
        ``dist_multi_4``. Every value is the share of the row's non-missing
        predictions that carry the given label; a row without any
        prediction yields NaN (0 / 0).

    Raises
    ------
    KeyError
        If a non-missing cell holds a value outside ``{0, 1, 2, 3, 4}``.
    """
    QUANT_TO_QUAL = {0: 0, 1: 1, 2: 1, 3: 1, 4: 1}

    # DataFrame.applymap is deprecated since pandas 2.1 in favour of
    # DataFrame.map; pick whichever this pandas version provides.
    elementwise = predictions.map if hasattr(pd.DataFrame, "map") else predictions.applymap
    # The dict lookup doubles as validation: unknown labels raise KeyError.
    predictions_qual = elementwise(lambda x: x if pd.isna(x) else QUANT_TO_QUAL[x])

    # Number of annotators that actually produced a prediction for each row.
    n_predictions = predictions.notna().sum(axis="columns")

    distribution = pd.DataFrame(index=predictions.index)
    for label in (0, 1):
        hits = (predictions_qual == label).sum(axis="columns")
        distribution[f"dist_bin_{label}"] = hits / n_predictions
    for label in range(5):
        hits = (predictions == label).sum(axis="columns")
        distribution[f"dist_multi_{label}"] = hits / n_predictions

    return distribution

In [18]:
# One fitted model per annotator, keyed by annotator id.
models = {}
embeddings = pd.read_pickle(f"created_data/embeddings/{embedder}.pkl")

for annotator in annotators:
    train_labels_path = f"created_data/training_data/y_train_{annotator}.jsonl"
    val_labels_path = f"created_data/training_data/y_val_{annotator}.jsonl"
    model_path = f"models/{embedder}_{classifier}_{annotator}.pkl"

    # Train and validation labels are stacked so the final fit sees both splits.
    y_train_split = pd.read_json(train_labels_path, lines=True).set_index('id')
    y_val_split = pd.read_json(val_labels_path, lines=True).set_index('id')
    y_train_all = pd.concat([y_train_split, y_val_split])

    # Pick the embedding rows for exactly the labelled ids, in label order.
    X_train_all = embeddings.loc[y_train_all.index]

    # The pickled estimator is loaded and then fitted again on the merged data.
    # NOTE(review): pickle.load executes arbitrary code; only load trusted files.
    with open(model_path, 'rb') as f:
        model = pickle.load(f)

    model.fit(
        X_train_all['Embedding'].tolist(),
        y_train_all[annotator].tolist(),
    )
    models[annotator] = model

In [19]:
embeddings_test = pd.read_pickle(f"created_data/embeddings/{embedder}_test.pkl")

# Each test row lists, under 'annotators', the annotators to predict for.
annotator_rows = pd.read_json("created_data/training_data/X_test.jsonl", lines=True).set_index('id')

# Cells stay missing for annotators that are not listed for a row.
predictions = pd.DataFrame(index=annotator_rows.index, columns=annotators)

for idx, annotator_row in annotator_rows.iterrows():
    correct_annos = annotator_row['annotators']
    embedding = embeddings_test.loc[idx]['Embedding']
    for anno in correct_annos:
        # Assign through a single .loc call. The chained form
        # `predictions.loc[idx][anno] = ...` writes to an intermediate Series
        # and only reaches `predictions` when pandas happens to return a view;
        # under Copy-on-Write it never does.
        predictions.loc[idx, anno] = models[anno].predict([embedding])[0]

output = output_st2(predictions)

In [20]:
# Write the distributions as a tab-separated file named after the
# embedder/classifier pair, then show the frame as the cell result.
results_path = f'created_data/results/st2_{embedder}_{classifier}.tsv'
output.to_csv(results_path, sep="\t")

output

Unnamed: 0_level_0,dist_bin_0,dist_bin_1,dist_multi_0,dist_multi_1,dist_multi_2,dist_multi_3,dist_multi_4
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
f3b81af2f6852bf1b9896629525d2f41,0.6,0.4,0.6,0.0,0.2,0.2,0.0
cf8b8bac7165144bb62b399a98843366,1.0,0.0,1.0,0.0,0.0,0.0,0.0
0c45cdf4cca5eec566d6dd53653b532b,0.1,0.9,0.1,0.0,0.5,0.4,0.0
3a60877d2c04ba65f457f7cc3e003169,0.6,0.4,0.6,0.0,0.3,0.1,0.0
f389b63364d8da93860e3c7e6569bf5b,0.7,0.3,0.7,0.0,0.2,0.1,0.0
...,...,...,...,...,...,...,...
2f7322c62b63ff74ec945bb38ed9f258,1.0,0.0,1.0,0.0,0.0,0.0,0.0
ec5fe35f542aac2f3155177dbf2731c2,1.0,0.0,1.0,0.0,0.0,0.0,0.0
6674986a02bab67b011df90cc7396a96,1.0,0.0,1.0,0.0,0.0,0.0,0.0
2a3774eba33afe18af2f0d312d081bb3,1.0,0.0,1.0,0.0,0.0,0.0,0.0
