In [1]:
from functools import lru_cache

import numpy as np
from onnxruntime import (
    GraphOptimizationLevel,
    InferenceSession,
    SessionOptions,
    get_all_providers,
    get_available_providers,
)
from transformers import AutoTokenizer

In [2]:
def create_model_for_provider(
    model_path: str,
    provider: str = 'CPUExecutionProvider',
    intra_op_num_threads: int = 4,
) -> InferenceSession:
    """Build an ONNX Runtime inference session bound to a single execution provider.

    Args:
        model_path: Path of the ONNX model file to load.
        provider: Name of the execution provider to run on. It must be one
            that this onnxruntime installation can actually use.
        intra_op_num_threads: Value assigned to
            ``SessionOptions.intra_op_num_threads``. Defaults to 4, the value
            that was previously hard-coded.

    Returns:
        An ``InferenceSession`` created with ``ORT_ENABLE_ALL`` graph
        optimisation and with provider fallback disabled.

    Raises:
        AssertionError: If ``provider`` is not available in this installation.
    """
    # get_all_providers() lists every provider onnxruntime knows about, even
    # those this build cannot run, so validate against the available set.
    available = get_available_providers()
    assert provider in available, f"provider {provider} not found, {available}"

    # Few properties that might have an impact on performances (provided by MS)
    options = SessionOptions()
    options.intra_op_num_threads = intra_op_num_threads
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    # Load the model as a graph and bind it to the requested provider only.
    session = InferenceSession(model_path, options, providers=[provider])
    # Fail loudly instead of silently running on a different provider.
    session.disable_fallback()
    return session

@lru_cache(maxsize=2000)
def encode(x):
    """Turn one input into a fixed-size vector using the global model.

    The input is tokenized as a one-element batch with the module-level
    ``tokenizer``, run through the module-level ``sess`` requesting the
    ``'output'`` tensor, and that tensor is averaged over axis 1
    (presumably the token axis -- TODO confirm against the exported model).

    Results are memoised for up to 2000 distinct inputs, so ``x`` must be
    hashable, and the returned object is shared between calls: do not
    mutate it in place.

    Returns:
        The first row of the averaged output, converted to plain Python
        values with ``tolist()``.
    """
    tokens = tokenizer([x], padding=True, truncation=True)
    # The session is fed int64 arrays for both inputs.
    feeds = {
        name: np.array(tokens[name], dtype=np.int64)
        for name in ("input_ids", "attention_mask")
    }
    hidden = sess.run(['output'], feeds)[0]
    pooled = np.sum(hidden, axis=1) / hidden.shape[1]
    return pooled[0].tolist()

@lru_cache(maxsize=2000)
def sim(a, b):
    """Return the cosine similarity between the encodings of ``a`` and ``b``.

    Both arguments are passed to ``encode``; the result is the dot product
    of the two vectors divided by the product of their Euclidean norms,
    returned as a built-in ``float``. Results are memoised per ``(a, b)``
    pair (up to 2000 entries), so both arguments must be hashable.
    """
    vec_a = np.array(encode(a))
    vec_b = np.array(encode(b))
    norm_a = np.sqrt(np.sum(vec_a ** 2))
    norm_b = np.sqrt(np.sum(vec_b ** 2))
    return float(np.dot(vec_a, vec_b) / (norm_a * norm_b))

In [3]:
# Module-level objects that encode() reads as globals: the ONNX session
# (created with the function's default provider) and the matching tokenizer,
# both loaded from paths relative to the working directory.
sess = create_model_for_provider('./stsq.onnx')
tokenizer = AutoTokenizer.from_pretrained('./paraphrase-multilingual-mpnet-base-v2')

In [4]:
import pickle

from tqdm import tqdm

In [5]:
# Load the labelled pairs (later cells read columns `a`, `b` and `t` from it).
# The file is opened in a `with` block so the handle is closed deterministically.
# NOTE(review): pickle.load can execute arbitrary code -- only load 'sim.pkl'
# when it comes from a trusted source.
with open('sim.pkl', 'rb') as pkl_file:
    df = pickle.load(pkl_file)

In [6]:
# Score every (a, b) row with the model; tqdm reports progress over the frame.
sims = [
    sim(row.a, row.b)
    for row in tqdm(df.itertuples(), total=len(df))
]

100%|██████████| 8050/8050 [06:01<00:00, 22.25it/s]


In [7]:
from scipy.optimize import minimize

In [8]:
# Attach the raw model scores as column 's', then snapshot the frame so that
# loss() can always start from the unadjusted values.
df['s'] = list(map(float, sims))
gdf = df.copy()

In [9]:
def loss(w):
    """Objective for fitting an offset and scale on the raw scores in ``gdf``.

    The 's' column of the module-level ``gdf`` is shifted by ``w[0]``,
    clipped to ``[0, 10]`` and multiplied by ``w[1]``. The returned value is
    the negated correlation between 't' and the adjusted 's', plus two
    quadratic penalties that pull the smallest adjusted score towards 0 and
    the largest towards 1. ``gdf`` itself is left untouched.

    Args:
        w: Two-element sequence ``[offset, scale]``.

    Returns:
        The scalar objective value (lower is better).
    """
    offset, scale = w[0], w[1]
    adjusted = (gdf['s'] + offset).clip(0, 10) * scale
    # Only the two columns that enter the correlation are materialised.
    pair = gdf[['t']].assign(s=adjusted)
    correlation = pair.corr().loc['t', 's']
    floor_penalty = adjusted.min() ** 2
    ceiling_penalty = (adjusted.max() - 1) ** 2
    return -correlation + floor_penalty + ceiling_penalty

In [10]:
# Start the search from offset 0.0 and scale 1.0.
ret = minimize(loss, x0=[0.0, 1.0])

In [11]:
# Fitted parameters: ret.x[0] is the additive offset and ret.x[1] the
# multiplicative scale that loss() applies to the raw scores.
ret.x

array([-0.31587389,  1.46171649])

In [12]:
# Apply the fitted offset/scale to the raw scores. The input is read from the
# untouched snapshot `gdf` rather than from `df['s']` itself, which keeps this
# cell idempotent: re-running it does not apply the adjustment a second time.
df['s'] = (gdf['s'] + ret.x[0]).clip(0, 10) * ret.x[1]

In [13]:
# Correlation between the labels 't' and the adjusted scores 's'.
df[['t', 's']].corr()

Unnamed: 0,t,s
t,1.0,0.818625
s,0.818625,1.0


In [14]:
import json

In [15]:
# Persist the fitted parameters as a JSON list.
with open('adjust.json', 'w') as fp:
    fp.write(json.dumps(ret.x.tolist()))

In [16]:
# Median adjusted score among rows labelled t == 5.
df.loc[df['t'] == 5, 's'].median()

0.9418801355474369

In [17]:
# Median adjusted score among rows labelled t == 4.
df.loc[df['t'] == 4, 's'].median()

0.8539239518031861

In [18]:
# Median adjusted score among rows labelled t == 3.
df.loc[df['t'] == 3, 's'].median()

0.7340277823073851

In [19]:
# Median adjusted score among rows labelled t == 2.
df.loc[df['t'] == 2, 's'].median()

0.5605673928889083

In [20]:
# Median adjusted score among rows labelled t == 1.
df.loc[df['t'] == 1, 's'].median()

0.34744278892796626

In [21]:
# Median adjusted score among rows labelled t == 0.
df.loc[df['t'] == 0, 's'].median()

0.0