In [9]:
%load_ext autoreload
%autoreload 2

from dotenv import load_dotenv

load_dotenv("../.env")

True

In [22]:
import mlflow
import onnx
import onnxruntime as ort
import sentence_transformers
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer
from transformers import AutoModel, AutoTokenizer

from such_toxic.text_classifier import TextClassifier
from such_toxic.util import expand, mat_mul, mat_sum, shape, unsqueeze

In [11]:
model_uri = f"runs:/5261022518c9417692ab0d3315ffb9e0/such-toxic"
sentence_transformer_model = "sentence-transformers/all-MiniLM-L6-v2"
output_sentence_transformer_model = "../target/st-all-MiniLM-L6-v2.onnx"
output_such_toxic_model = "../target/such-toxic.onnx"

tokenizer = AutoTokenizer.from_pretrained(sentence_transformer_model)
st_model = AutoModel.from_pretrained(sentence_transformer_model)

In [28]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[
        0
    ]  # First element of model_output contains all token embeddings
    print("token_embeddings: ", token_embeddings.shape)
    print("attention_mask: ", attention_mask.shape)

    unsqueezed_attention_mask = attention_mask.unsqueeze(-1)
    print("unsqueezed_attention_mask: ", unsqueezed_attention_mask.shape)

    input_mask_expanded = unsqueezed_attention_mask.expand(
        token_embeddings.size()
    ).float()

    print("input_mask_expanded: ", input_mask_expanded.shape)
    print("token_x_input_mask: ", (token_embeddings * input_mask_expanded).shape)

    s = torch.sum(token_embeddings * input_mask_expanded, 1)
    print("s: ", s.shape)
    return s / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


st_input = tokenizer(
    ["This is a sample"],
    padding=True,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    st_output = st_model(**st_input)

st_embedding = mean_pooling(st_output, st_input["attention_mask"])
st_embedding = F.normalize(st_embedding, p=2, dim=1)
st_embedding.shape

token_embeddings:  torch.Size([1, 6, 384])
attention_mask:  torch.Size([1, 6])
unsqueezed_attention_mask:  torch.Size([1, 6, 1])
input_mask_expanded:  torch.Size([1, 6, 384])
token_x_input_mask:  torch.Size([1, 6, 384])
s:  torch.Size([1, 384])
s:  torch.Size([1, 384])


torch.Size([1, 384])

In [13]:
torch.onnx.export(
    st_model,
    (st_input["input_ids"], st_input["attention_mask"]),
    output_sentence_transformer_model,
    input_names=["input_ids", "attention_mask"],
    output_names=["output"],
    dynamic_axes={
        "input_ids": {0: "batch_size", 1: "sequence"},
        "attention_mask": {0: "batch_size", 1: "sequence"},
    },
    do_constant_folding=True,
    opset_version=13,
)

In [14]:
such_toxic = mlflow.pytorch.load_model(model_uri, map_location="cpu")
such_toxic(st_embedding)



tensor([[2.1717e-03, 3.3784e-04, 5.4082e-04, 5.4006e-05, 5.1951e-04, 1.3210e-04]],
       grad_fn=<SigmoidBackward0>)

In [15]:
torch.onnx.export(
    such_toxic,
    st_embedding,
    output_such_toxic_model,
    input_names=["embeddings"],
    output_names=["output"],
    dynamic_axes={"embeddings": {0: "batch_size"}},
    do_constant_folding=True,
    opset_version=13,
)

In [27]:
def s_mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0].tolist()
    print("token_embeddings: ", shape(token_embeddings))
    attention_mask = attention_mask.tolist()
    # print("attention_mask: ", shape(attention_mask))

    unsqueezed_attention_mask = unsqueeze(
        attention_mask,
        axis=len(shape(attention_mask)),
    )
    print("unsqueezed_attention_mask: ", shape(unsqueezed_attention_mask))

    input_mask_expanded = expand(unsqueezed_attention_mask, shape(token_embeddings))
    print("input_mask_expanded: ", shape(input_mask_expanded))

    token_embeddings_x_input_mask = mat_mul(token_embeddings, input_mask_expanded)
    print("token_embeddings_x_input_mask: ", shape(token_embeddings_x_input_mask))
    masked_sum = mat_sum(token_embeddings_x_input_mask, dim=1)
    print("masked_sum: ", shape(masked_sum))

    # s =  torch.sum(token_embeddings * input_mask_expanded, 1)
    # print("s: ", s.shape)
    # return s / torch.clamp(
    #     input_mask_expanded.sum(1), min=1e-9
    # )

    # # # Masked mean with division by non-zero count
    # masked_mean = [
    #     sum / (sum(mask) + 1e-9) for sum, mask in zip(masked_sum, attention_mask)
    # ]
    # print(masked_mean)
    # return masked_mean


st_onnx_model = onnx.load(output_sentence_transformer_model)
onnx.checker.check_model(st_onnx_model)

stoxic_onnx_model = onnx.load(output_such_toxic_model)
onnx.checker.check_model(stoxic_onnx_model)

st_session = ort.InferenceSession(output_sentence_transformer_model)
st_output = st_session.run(
    None,
    {
        "input_ids": st_input["input_ids"].numpy(),
        "attention_mask": st_input["attention_mask"].numpy(),
    },
)

st_embedding = s_mean_pooling(st_output, st_input["attention_mask"])
# print(st_embedding)

# st_embedding = F.normalize(st_embedding, p=2, dim=1)


# stoxic_session = ort.InferenceSession(output_such_toxic_model)
# stoxic_output = stoxic_session.run(None, {"embeddings": st_embedding.numpy()})

# print("Toxic: ", stoxic_output[0][0][0])
# print("Severe Toxic: ", stoxic_output[0])

token_embeddings:  [1, 6, 384]
unsqueezed_attention_mask:  [1, 6, 1]
input_mask_expanded:  [1, 6, 384]
token_embeddings_x_input_mask:  [1, 6, 384]
masked_sum:  [1]


In [6]:
import torch

list(torch.tensor([1, 2, 3]).shape)

[3]