Notes,

* [Hugging Face model card](https://huggingface.co/intfloat/e5-mistral-7b-instruct) recommends last token pooling and then normalize the embedding. But average pooling without embedding normalization chosen following [model's source code](https://github.com/microsoft/unilm/tree/master/e5) for running MTEB.
* NF4 and double quantization were used following explanation and result on paper [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314).

In [1]:
# disable persistant warning shown by tokenizers
%set_env TOKENIZERS_PARALLELISM=false

env: TOKENIZERS_PARALLELISM=false


In [2]:
import gc
import torch
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
)
from sklearn.multiclass import OneVsRestClassifier
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
import torch.nn.functional as F
from torch import Tensor
from tqdm import tqdm

In [3]:
SEED = 42
MODEL_NAME = 'intfloat/e5-mistral-7b-instruct'
BATCH_SIZE = 8
INSTRUCTION = 'Classify the aspect mentioned in the given Steam Review into up to of the eight aspects: recommended, story, gameplay, visual, audio, technical, price, and suggestion.'  # This mimic paper's string instruction

In [4]:
df_train = pd.read_csv('../../dataset/v1/train.csv')
df_test = pd.read_csv('../../dataset/v1/test.csv')

labels = df_train.columns[3:].to_list()
y_train = df_train[labels].to_numpy()
y_test = df_test[labels].to_numpy()

In [5]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16
)
max_length = 4096

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config,
)
model.eval()

# # off-load to CPU without quantization
# model = AutoModel.from_pretrained(
#     MODEL_NAME,
#     device_map='auto',
#     offload_folder='offload',
#     torch_dtype=torch.bfloat16
# )

`low_cpu_mem_usage` was None, now set to True since model is quantized.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MistralModel(
  (embed_tokens): Embedding(32000, 4096, padding_idx=2)
  (layers): ModuleList(
    (0-31): 32 x MistralDecoderLayer(
      (self_attn): MistralSdpaAttention(
        (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): MistralRotaryEmbedding()
      )
      (mlp): MistralMLP(
        (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
        (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
        (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): MistralRMSNorm()
      (post_attention_layernorm): MistralRMSNorm()
    )
  )
  (norm): MistralRMSNorm()
)

In [6]:
def evaluate(X_train, y_train, X_test, y_test, labels):
    clf = LogisticRegression(
        random_state=SEED,
        max_iter=100
    )
    ovr = OneVsRestClassifier(clf, n_jobs=-1)

    ovr.fit(X_train, y_train)
    y_pred = ovr.predict(X_test)

    accuracy = accuracy_score(y_test, y_pred)
    print(f'Overall accuracy: {accuracy}')
    for idx, label in enumerate(labels):
        label_accuracy = accuracy_score(y_test[:, idx], y_pred[:, idx])
        print(f'Accuracy {label}: {label_accuracy}')

    f1 = f1_score(y_test, y_pred, average='macro')
    print(f'F1 macro: {f1}')
    print(
        classification_report(y_test, y_pred, target_names=labels, digits=4, zero_division=0)
    )


def last_token_pool(
        last_hidden_states: Tensor,
        attention_mask: Tensor
    ) -> Tensor:
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    else:
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device),
            sequence_lengths
        ]


def avg_pool(
        last_hidden_states: Tensor,
        attention_mask: Tensor,
    ) -> Tensor:
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
    # handle few torch.inf and -torch.inf occurance
    mask_inf = emb == torch.inf
    mask_min_inf = emb == -torch.inf
    emb[mask_inf] = 0.0
    emb[mask_min_inf] = 0.0

    return emb

def get_text_embedding(df):
    X = np.zeros(shape=(df.shape[0], 4096), dtype=np.float32)

    for i in tqdm(range(0, df.shape[0], BATCH_SIZE)):
        batch_dict = tokenizer(
            [s + INSTRUCTION for s in df.iloc[i:i+BATCH_SIZE, 2]],
            max_length=max_length, padding=True, truncation=True, return_tensors='pt'
        )
        with torch.no_grad() and torch.inference_mode():
            outputs = model(**batch_dict)
            # embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
            embeddings = avg_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

        # embeddings = F.normalize(embeddings, p=2, dim=1)
        X[i:i+BATCH_SIZE, :] = embeddings.detach().cpu().float().numpy()
        
        # maybe clear stuff here
        del batch_dict, outputs, embeddings
        gc.collect()
        torch.cuda.empty_cache()
    return X

In [7]:
X_train = get_text_embedding(df_train)
X_test = get_text_embedding(df_test)

  0%|          | 0/113 [00:00<?, ?it/s]2024-06-04 21:53:34.665451: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
100%|██████████| 113/113 [12:35<00:00,  6.69s/it]
100%|██████████| 25/25 [02:38<00:00,  6.33s/it]


In [8]:
evaluate(X_train, y_train, X_test, y_test, labels)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver opt

Overall accuracy: 0.295
Accuracy label_recommended: 0.91
Accuracy label_story: 0.805
Accuracy label_gameplay: 0.905
Accuracy label_visual: 0.745
Accuracy label_audio: 0.84
Accuracy label_technical: 0.865
Accuracy label_price: 0.83
Accuracy label_suggestion: 0.875
F1 macro: 0.7136813015004153
                   precision    recall  f1-score   support

label_recommended     0.9392    0.9392    0.9392       148
      label_story     0.7604    0.8202    0.7892        89
   label_gameplay     0.9299    0.9481    0.9389       154
     label_visual     0.6957    0.7356    0.7151        87
      label_audio     0.7317    0.5882    0.6522        51
  label_technical     0.7586    0.7719    0.7652        57
      label_price     0.6857    0.5106    0.5854        47
 label_suggestion     0.3750    0.2857    0.3243        21

        micro avg     0.8180    0.8043    0.8111       654
        macro avg     0.7345    0.7000    0.7137       654
     weighted avg     0.8120    0.8043    0.8062       6