In [None]:
# Locations of the pre-extracted feature matrices (.npy) on the mounted shared drive.
# The literals are split only for readability; the resulting strings are unchanged.
train_data_path = (
    '/content/drive/Shareddrives/UnlimitedSharedDrive/'
    'KSE-Task1-Word-embed/data_features/'
    'train_data_vnlaw_kse_jac_tfidf.npy'
)
test_data_path = (
    '/content/drive/Shareddrives/UnlimitedSharedDrive/'
    'KSE-Task1-Word-embed/data_features/'
    'test_data_vnlaw_kse_jac_tfidf.npy'
)

# Load Feature-Extracted Data

In [None]:
import numpy as np
# Read both feature matrices from disk, training file first and test file second.
train_data, test_data = (
    np.load(npy_path) for npy_path in (train_data_path, test_data_path)
)

In [None]:
# Quick sanity check of the loaded training matrix's dimensions (rows, columns).
train_data.shape

(16457, 203)

In [None]:
from sklearn.model_selection import train_test_split
# Every column except the last is a feature; the last column is the label.
X, y = train_data[:, :-1], train_data[:, -1]

# Hold out 20% of the rows for validation; the fixed seed keeps the split repeatable.
X_train, X_val, y_train, y_val = train_test_split(
    X,
    y,
    random_state=42,
    test_size=0.2,
)

# Model Training

In [None]:
from sklearn.model_selection import GroupKFold
import xgboost as xgb

# NOTE(review): this splitter is created but never used in the visible notebook;
# kept so that any cell relying on the name `skf` keeps working.
skf = GroupKFold(n_splits=6)

# Gradient-boosted tree classifier scored with AUC.
# n_estimators is only an upper bound: the fit call below uses early stopping.
clf = xgb.XGBClassifier(
    n_estimators=15000,
    max_depth=128,
    learning_rate=0.02,
    subsample=1,
    colsample_bytree=1,
    # Feature values equal to -1 are treated as missing by XGBoost.
    missing=-1,
    eval_metric='auc',
    # USE CPU
    # `n_jobs` is the documented scikit-learn-API parameter for the thread
    # count; it replaces the deprecated `nthread` alias.
    n_jobs=4,
    tree_method='hist'
    # USE GPU
#   tree_method='gpu_hist'
)

# Train on the training split while monitoring the validation split.
# Progress is printed every 100 rounds, and training stops once the validation
# metric has gone 300 rounds without improving.
h = clf.fit(
    X_train,
    y_train,
    eval_set=[(X_val, y_val)],
    early_stopping_rounds=300,
    verbose=100,
)

[0]	validation_0-auc:0.872019
Will train until validation_0-auc hasn't improved in 300 rounds.
[100]	validation_0-auc:0.943189
[200]	validation_0-auc:0.955178
[300]	validation_0-auc:0.954772
[400]	validation_0-auc:0.956533
[500]	validation_0-auc:0.957213
[600]	validation_0-auc:0.956943
Stopping. Best iteration:
[333]	validation_0-auc:0.958058



# Predict on the test set

In [None]:
# Same column layout as the training matrix: features first, label in the last column.
X_test, y_test = test_data[:, :-1], test_data[:, -1]

In [None]:
# Predict per-class probabilities for the test features with the fitted model.
y_prob = h.predict_proba(X_test)

In [None]:
# Write the predicted probabilities to the current working directory.
np.save('result_test_set.npy', y_prob)

In [None]:
# Confirm the prediction array's dimensions (rows, classes).
y_prob.shape

(4150, 2)

# Inference on the private test set

In [None]:
# Save the fitted model to disk (this cell writes the pickle; it does not load it).
import pickle

# A context manager guarantees the file is flushed and closed as soon as the
# dump finishes, instead of leaving the handle open until garbage collection.
with open('xgb_model.pkl', 'wb') as model_file:
    pickle.dump(h, model_file)

In [None]:
# Fetch the project repository that provides the inference script run at the end of this notebook.
!git clone https://github.com/legal-qa-research/legal-qa-retrieval.git

Cloning into 'legal-qa-retrieval'...
remote: Enumerating objects: 591, done.[K
remote: Counting objects: 100% (65/65), done.[K
remote: Compressing objects: 100% (48/48), done.[K
remote: Total 591 (delta 21), reused 45 (delta 16), pack-reused 526[K
Receiving objects: 100% (591/591), 92.66 KiB | 797.00 KiB/s, done.
Resolving deltas: 100% (372/372), done.


In [None]:
# Work from inside the clone and switch to the `kse-support` branch.
%cd /content/legal-qa-retrieval/
!git checkout kse-support

/content/legal-qa-retrieval
Branch 'kse-support' set up to track remote branch 'kse-support' from 'origin'.
Switched to a new branch 'kse-support'


In [None]:
# Refresh the clone: stage everything, then hard-reset so the working tree matches HEAD
# (this discards any local modifications), and finally pull the latest `kse-support` commits.
%cd /content/legal-qa-retrieval/
!git add .
!git reset --hard
!git pull origin kse-support

/content/legal-qa-retrieval
HEAD is now at e21a733 not require args
remote: Enumerating objects: 7, done.[K
remote: Counting objects: 100% (7/7), done.[K
remote: Compressing objects: 100% (1/1), done.[K
remote: Total 4 (delta 3), reused 4 (delta 3), pack-reused 0[K
Unpacking objects: 100% (4/4), done.
From https://github.com/legal-qa-research/legal-qa-retrieval
 * branch            kse-support -> FETCH_HEAD
   e21a733..4f49db8  kse-support -> origin/kse-support
Updating e21a733..4f49db8
Fast-forward
 traditional_ml/constant.py | 4 [32m++[m[31m--[m
 1 file changed, 2 insertions(+), 2 deletions(-)


In [None]:
%%capture
# Install extra packages, presumably required by the repository's inference code;
# %%capture suppresses pip's output.
# NOTE(review): versions are not pinned, so a rerun may install different releases.
!pip3 install fasttext
!pip3 install vncorenlp
!pip3 install sentence_transformers
!pip3 install scipy

In [None]:
import os

# Expose the cloned repository on PYTHONPATH for processes launched from this notebook.
# Unlike a bare `+=`, this works when PYTHONPATH is not set yet and does not
# append a duplicate entry when the cell is re-run.
_repo_dir = "/content/legal-qa-retrieval"
_current_pythonpath = os.environ.get('PYTHONPATH', '')
if _repo_dir not in _current_pythonpath.split(os.pathsep):
    if _current_pythonpath:
        os.environ['PYTHONPATH'] = _current_pythonpath + os.pathsep + _repo_dir
    else:
        os.environ['PYTHONPATH'] = _repo_dir

In [None]:
# Copy pickled artifacts, data files and the pretrained fastText binary from the shared drive into the clone.
! cp /content/drive/Shareddrives/UnlimitedSharedDrive/legal-qa-retrieval/pkl_file/* /content/legal-qa-retrieval/pkl_file/
! cp /content/drive/Shareddrives/UnlimitedSharedDrive/legal-qa-retrieval/data/* /content/legal-qa-retrieval/data
! cp /content/drive/Shareddrives/UnlimitedSharedDrive/pretrained_fastext/vnlaw_ft.bin /content/legal-qa-retrieval/traditional_ml/pretrained_fasttext

In [None]:
# Run the repository's feature builder and inference script on the private test set.
# NOTE(review): the meaning of --infer_threshold / --infer_top_k is defined by the script, not shown here; confirm before changing.
%cd /content/legal-qa-retrieval/
!python /content/legal-qa-retrieval/traditional_ml/infer_private_test.py --infer_threshold 0.06 --infer_top_k=1

/content/legal-qa-retrieval
Building Features: 100% 8800/8800 [11:09<00:00, 13.15it/s]
Predicting ... 
Predict done, start record result
