In [None]:
%env CUDA_VISIBLE_DEVICES=1

from typing import Union, Sequence, Tuple, Dict, Literal
import json
import os

from tqdm.auto import tqdm
from termcolor import colored
import numpy as np
import pandas as pd

import torch
from torch import nn
import torch.nn.functional as F

import transformers

from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA

import pickle

class args:
    """Notebook-wide settings, read as plain class attributes (``args.<name>``)."""

    # Path handed to ``torch.load``; replace the placeholder with the real file.
    data_path = 'PATH_TO_YOUR_MINED_DATA'
    # Seed passed to ``np.random.seed``.
    random_seed = 52
    # Fraction of the samples that goes into the training split.
    train_size = 0.9

In [None]:
# Load the mined samples. The cells below shuffle, slice and iterate `data`
# as a sequence of per-sample dicts; such a container has no `.float()`
# method, so the loaded object is kept as-is instead of being cast here.
# NOTE(review): if the stored hidden-state tensors are in a dtype NumPy cannot
# represent (e.g. bfloat16), they need a per-tensor cast before `.numpy()` is
# called on them — confirm against the mining script.
data = torch.load(args.data_path)

In [None]:
# Seed NumPy's global RNG so that the `np.random.shuffle` call used for the
# train/test split gives the same order on every fresh run.
np.random.seed(args.random_seed)

In [None]:
# Cosmetic only: inject a CSS rule that sets elements with class `container`
# to 75% width in the rendered notebook page.
from IPython.display import display, HTML
display(HTML("<style>.container { width:75% !important; }</style>"))

In [None]:
def make_X_y(data: list[Dict], hiddens: Literal['current', 'prev', 'both'] = 'current') -> Tuple[np.ndarray, np.ndarray]:
    """Flatten per-sample hidden states into a feature matrix and a label vector.

    Samples that have no ``'hiddens'`` key are skipped and the set of their
    indices is printed. For every other sample, each tensor in
    ``sample['hiddens']`` (and, when requested, ``sample['prev_hiddens']``)
    becomes one row, and the label of row ``k`` is
    ``int(sample['changed_token_indices'][k][1])``.

    Args:
        data: Sequence of sample dicts.
        hiddens: Which features to return: ``'current'`` uses ``'hiddens'``,
            ``'prev'`` uses ``'prev_hiddens'``, ``'both'`` concatenates the
            two along the feature axis (current first). ``'prev_hiddens'`` is
            only read when it is actually needed.

    Returns:
        ``(X, y)`` as NumPy arrays with one row / one label per hidden state.

    Raises:
        ValueError: If ``hiddens`` is not one of the supported modes. Raised
            before any sample is processed.
        AssertionError: If a sample's label list (or requested
            ``'prev_hiddens'`` list) differs in length from its ``'hiddens'``.
    """
    # Fail fast on a bad mode instead of converting the whole dataset first.
    if hiddens not in ('current', 'prev', 'both'):
        raise ValueError(f'{hiddens=} is not supported')
    need_prev = hiddens != 'current'

    X = []
    X_prev = []
    y = []
    skip_indexes = set()

    for sample_idx, sample_dict in enumerate(data):
        if 'hiddens' not in sample_dict:
            skip_indexes.add(sample_idx)
            continue

        n_rows = len(sample_dict['hiddens'])
        assert len(sample_dict['changed_token_indices']) == n_rows
        X.extend(hidden.numpy() for hidden in sample_dict['hiddens'])
        y.extend(int(item[1]) for item in sample_dict['changed_token_indices'])

        if need_prev:
            # Rows of X_prev must stay aligned with y, so check the length too.
            assert len(sample_dict['prev_hiddens']) == n_rows
            X_prev.extend(hidden.numpy() for hidden in sample_dict['prev_hiddens'])

    print(skip_indexes)

    y = np.array(y)
    if hiddens == 'current':
        return np.array(X), y
    if hiddens == 'prev':
        return np.array(X_prev), y
    return np.concatenate((np.array(X), np.array(X_prev)), axis=1), y

## Load and create data

In [None]:
# Shuffle the samples in place with NumPy's global RNG, then cut the sequence
# into a leading training part and a trailing held-out part.
np.random.shuffle(data)

# Number of training samples == index of the first held-out sample. No "+ 1":
# adding one gave the training split an extra sample and left the held-out
# split empty for small datasets (e.g. 10 samples with train_size = 0.9).
first_test_sample_idx = int(len(data) * args.train_size)

train_data = data[:first_test_sample_idx]
test_data = data[first_test_sample_idx:]

assert len(train_data) > 0 and len(test_data) > 0, 'train/test split produced an empty part'
assert len(train_data) + len(test_data) == len(data)

In [None]:
# Turn both splits into (features, labels) arrays using the current-token
# hidden states only, which is what `make_X_y` does by default.
X_train, y_train = make_X_y(train_data, hiddens='current')
X_val, y_val = make_X_y(test_data, hiddens='current')

In [None]:
# Standardise the features with statistics estimated on the training split
# only, and apply the same transform to the held-out split.
# Caution: X_train / X_val are overwritten with their scaled versions, so
# re-running this cell on its own refits the scaler on already-scaled data.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_val = scaler.transform(X_val)

# Quick look at the matrix shapes and the positive-label rate of each split.
print(X_train.shape, X_val.shape)
print(y_train.mean(), y_val.mean())

## Train

In [None]:
def _roc_auc_or_zero(y_true, probs) -> float:
    """Return the ROC AUC of ``probs`` against ``y_true``, or 0.0 when ``y_true`` holds a single class."""
    if len(set(y_true)) == 1:
        return 0.0
    return roc_auc_score(y_true, probs)


def train_head_and_search_best_hparam(hidden_dim_slice: slice, C: float | None = None):
    """Fit logistic-regression heads over a grid of ``C`` and keep the best one.

    Reads the notebook-level arrays ``X_train``, ``y_train``, ``X_val`` and
    ``y_val``; only the feature columns selected by ``hidden_dim_slice`` are
    used. For every ``C`` a ``LogisticRegression`` is fitted on the training
    split, and its positive-class probabilities are thresholded at 20 evenly
    spaced quantiles of the *validation* probabilities. Accuracy, recall and
    precision are recorded for both splits at each threshold, together with
    the threshold, the share of validation rows predicted negative and the
    (threshold-independent) ROC AUC of each split.

    Args:
        hidden_dim_slice: Column slice applied to ``X_train`` and ``X_val``.
        C: Inverse regularisation strength. When ``None``, ten values
            ``10 ** (-k / 2)`` for ``k = 9 .. 0`` are tried.

    Returns:
        A tuple ``(all_metrics, best_metrics, best_model)``: the metrics of
        every ``C`` concatenated into one DataFrame indexed by quantile ``q``,
        the metrics DataFrame of the ``C`` with the highest validation ROC
        AUC, and the model fitted with that ``C``.
    """
    c_grid = [C] if C is not None else (10**0.5) ** -np.arange(10)[::-1]

    # Loop-invariant inputs: the threshold quantiles and the sliced features.
    quantiles = np.linspace(0, 1, 20)
    train_features = X_train[:, hidden_dim_slice]
    val_features = X_val[:, hidden_dim_slice]

    dataframes = []

    best_C = None
    best_metric_roc_auc = -1
    best_dataframe = None
    best_model = None

    for c_idx, c_value in enumerate(tqdm(c_grid)):
        print('#####' * 6, end='')
        print(f' C={c_value:.5f}, {c_idx=} ', end='')
        print('#####' * 6)

        model_ = LogisticRegression(C=c_value)
        model_.fit(train_features, y_train)
        train_probs = model_.predict_proba(train_features)[:, 1]
        val_probs = model_.predict_proba(val_features)[:, 1]

        # Each split is guarded by its own labels (the train score used to be
        # guarded by the validation labels).
        roc_auc_score_train = _roc_auc_or_zero(y_train, train_probs)
        roc_auc_score_val = _roc_auc_or_zero(y_val, val_probs)

        rows = []
        for quantile in quantiles:
            # Thresholds always come from the validation probabilities, also
            # when they are applied to the training predictions.
            thr = np.quantile(val_probs, quantile)

            train_pred = train_probs > thr
            val_pred = val_probs > thr

            # zero_division=0 keeps the value sklearn already reports when
            # nothing is predicted positive (e.g. at q = 1), without the
            # per-call warning.
            rows.append({
                'roc_auc_train': roc_auc_score_train,
                'roc_auc_val': roc_auc_score_val,

                'acc_train': accuracy_score(y_train, train_pred),
                'recall_train': recall_score(y_train, train_pred, zero_division=0),
                'precision_train': precision_score(y_train, train_pred, zero_division=0),

                'acc_val': accuracy_score(y_val, val_pred),
                'recall_val': recall_score(y_val, val_pred, zero_division=0),
                'precision_val': precision_score(y_val, val_pred, zero_division=0),

                'neg_rate_val': 1 - val_pred.mean(),

                'q': quantile,
                'thr': thr,
            })

        metrics_df = pd.DataFrame(rows).set_index('q')
        metrics_df['C'] = c_value
        metrics_df['c_idx'] = c_idx

        print(metrics_df[['roc_auc_train', 'roc_auc_val', 'C']].iloc[0])
        print()

        # Strict comparison: on ties the earliest (smallest) C is kept.
        if roc_auc_score_val > best_metric_roc_auc:
            best_metric_roc_auc = roc_auc_score_val
            best_C = c_value
            best_dataframe = metrics_df
            best_model = model_

        dataframes.append(metrics_df)

    dataframes = pd.concat(dataframes)

    print(f"{best_C=} {best_metric_roc_auc=}")

    return dataframes, best_dataframe, best_model


In [None]:
# Head trained on the first 2048 * 3 feature columns.
(
    all_dataframes_1_token,
    best_datarame_1_token,
    best_head_1_token,
) = train_head_and_search_best_hparam(hidden_dim_slice=slice(0, 2048 * 3))

In [None]:
# Head trained on the first 2048 * 3 * 2 feature columns.
(
    all_dataframes_2_tokens,
    best_datarame_2_tokens,
    best_head_2_tokens,
) = train_head_and_search_best_hparam(hidden_dim_slice=slice(0, 2048 * 3 * 2))

### Save checkpoint

In [None]:
# Checkpoint destinations, one pickle per head.
checkpoint_dir = 'checkpoints'
checkpoint_path_1_token = os.path.join(checkpoint_dir, 'checkpoint_head_1_token.pkl')
checkpoint_path_2_tokens = os.path.join(checkpoint_dir, 'checkpoint_head_2_tokens.pkl')
checkpoint_path_draft_token_both_from_both_models_with_target_model_draft_token = os.path.join(checkpoint_dir, 'checkpoint_head_draft_token_both_from_both_models_with_target_model_draft_token.pkl')

# `open(..., 'wb')` does not create missing directories.
os.makedirs(checkpoint_dir, exist_ok=True)

# Only heads that this notebook actually trains are written.
# NOTE: no cell defines
# `best_head_draft_token_both_from_both_models_with_target_model_draft_token`,
# so saving it raised NameError (after already truncating its file). Train that
# head and add it to this mapping to write its checkpoint.
heads_to_save = {
    checkpoint_path_1_token: best_head_1_token,
    checkpoint_path_2_tokens: best_head_2_tokens,
}

# Each checkpoint bundles the head with the fitted scaler. The scaler was fit
# on all feature columns, while each head was trained on a leading column
# slice, so a consumer has to scale first and slice afterwards.
for checkpoint_path, head in heads_to_save.items():
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(dict(model=head, scaler=scaler), f)
