In [2]:
import torch
import numpy as np
import pandas as pd
from nlp_models.multi_task_model.mtl import MTLInference

# Hugging Face hub model id; passed to MTLInference when the model is built.
HF_MODEL_CARD = 'sentence-transformers/multi-qa-mpnet-base-dot-v1'

# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [3]:
# Both CSV splits are read from the same directory; keep that path in one place
# so relocating the data only requires a single edit.
DATA_DIR = '../data/0_external/google-quest-challenge'
df_train = pd.read_csv(f'{DATA_DIR}/train.csv')
df_test = pd.read_csv(f'{DATA_DIR}/test.csv')

In [4]:
# Category distribution of the training split. The most frequent class also gives
# a majority-class baseline to compare the zero-shot accuracy against.
df_train['category'].value_counts()

TECHNOLOGY       2441
STACKOVERFLOW    1253
CULTURE           963
SCIENCE           713
LIFE_ARTS         709
Name: category, dtype: int64

In [5]:
# Map positional class index -> category name. Categories are numbered in order of
# first appearance in the training data (Series.unique preserves that order), and
# label_list holds the same names in index order, so label_list[k] == label_dict[k].
label_dict = dict(enumerate(df_train['category'].unique()))
label_list = list(label_dict.values())

In [6]:
# Zero-shot classification: embed each question title and each category name with
# the same model, then predict the category with the highest dot-product score.
# No training step appears before this cell, so scoring df_train is an
# out-of-the-box evaluation rather than a train-set fit.
mtl_model = MTLInference(HF_MODEL_CARD, HF_MODEL_CARD, num_labels=len(label_dict), device=DEVICE)

with torch.no_grad():  # inference only: don't build autograd graphs
    # NOTE(review): element [1] of predict()'s output is used as the embedding
    # matrix (one row per input) -- confirm against MTLInference.predict.
    labels_embedding = mtl_model.predict(label_list)
    # Loop-invariant (embed_dim x n_labels) matrix, computed once instead of per title.
    labels_matrix_t = labels_embedding[1].transpose(0, 1)
    scores = [
        torch.mm(mtl_model.predict(title)[1], labels_matrix_t).cpu().tolist()
        for title in df_train['question_title'].to_list()
    ]

# Each entry is assumed to be a 1 x n_labels row. Reshape rather than squeeze so the
# result stays 2-D (n_titles, n_labels) even when there is only a single title.
scores = np.asarray(scores).reshape(len(scores), -1)
max_scores = np.argmax(scores, 1)

# True/False per title: does the predicted category match the gold one?
# Zipping avoids building a full row Series with .iloc on every iteration.
summary = [
    label_dict[pred_idx] == true_label
    for pred_idx, true_label in zip(max_scores, df_train['category'])
]
# For reference, the value counts above show TECHNOLOGY is 2441 of 6079 training
# rows, so always predicting it would score ~40%.
print(f'Prediction Accuracy: {sum(summary) / len(summary):.2%}')

Prediction Accuracy: 29.08%


In [7]:
# Category distribution of the test split, to compare against the training split.
df_test['category'].value_counts()

TECHNOLOGY       204
STACKOVERFLOW    103
CULTURE           64
SCIENCE           58
LIFE_ARTS         47
Name: category, dtype: int64