In [None]:
from collections.abc import Iterable, Sequence
from datetime import date
from functools import partial
from itertools import islice
from typing import Any, Literal, Optional, TypeVar

import numpy as np
import torch
from torch import Tensor
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from transformers.modeling_outputs import ModelOutput

import seb
from seb.interfaces.model import Encoder, LazyLoadEncoder, ModelMeta, SebModel
from seb.interfaces.task import Task
from seb.registries import models
import numpy as np

import pickle

#import logging
#logger = logging.getLogger(__name__)


T = TypeVar("T")
EncodeTypes = Literal["query", "passage"]


def batched(iterable: Iterable[T], n: int) -> Iterable[tuple[T, ...]]:
    # batched('ABCDEFG', 3) --> ABC DEF G
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


def batch_to_device(batch_data: dict[str, torch.Tensor], device: str = "cuda") -> dict[str, torch.Tensor]:
    return {key: data.to(device) for key, data in batch_data.items()}


def task_to_instruction(task: Task) -> str:
    if task.task_type in ["STS"]:
        return "Retrieve semantically similar text"
    if task.task_type in ["Summarization"]:
        return "Given a news summary, retrieve other semantically similar summaries"
    if task.task_type in ["BitextMining"]:
        task_name_to_instruct: dict[str, str] = {
            "Bornholm Parallel": "Retrieve parallel sentences in Danish and Bornholmsk",
            "Norwegian courts": "Retrieve parallel sentences in Norwegian Bokmål and Nynorsk",
        }
        default_instruction = "Retrieve parallel sentences."
        return task_name_to_instruct.get(task.name, default_instruction)
    if task.task_type in ["Classification"]:
        task_name_to_instruct: dict[str, str] = {
            "Angry Tweets": "Classify Danish tweets by sentiment. (positive, negative, neutral)",
            "DKHate": "Classify Danish tweets based on offensiveness (offensive, not offensive)",
            "Da Political Comments": "Classify Danish political comments for sentiment",
            "DaLAJ": "Classify texts based on linguistic acceptability in Swedish",
            "LCC": "Classify texts based on sentiment",
            "Language Identification": "Classify texts based on language",
            "Massive Intent": "Given a user utterance as query, find the user intents",
            "Massive Scenario": "Given a user utterance as query, find the user scenarios",
            "NoReC": "Classify Norwegian reviews by sentiment",
            "SweReC": "Classify Swedish reviews by sentiment",
            "Norwegian parliament": "Classify parliament speeches in Norwegian based on political affiliation",
            "ScaLA": "Classify passages in Scandinavian Languages based on linguistic acceptability",
        }
        default_instruction = "Classify user passages"
        return task_name_to_instruct.get(task.name, default_instruction)
    if task.task_type in ["Clustering"]:
        task_name_to_instruct: dict[str, str] = {
            "ArxivClusteringP2P": "Identify the main and secondary category of Arxiv papers based on the titles and abstracts",
            "VG Clustering": "Identify the categories (e.g. sports) of given articles in Norwegian",
            "SNL Clustering": "Identify categories in a Norwegian lexicon",
            "SwednClustering": "Identify news categories in Swedish passages",
        }
        default_instruction = "Identify categories in user passages"
        return task_name_to_instruct.get(task.name, default_instruction)
    if task.task_type in ["Reranking"]:
        return "Retrieve semantically similar passages."
    if task.task_type in ["Retrieval"]:
        task_name_to_instruct: dict[str, str] = {
            "Twitterhjerne": "Retrieve answers to questions asked in Danish tweets",
            "SwednRetrieval": "Given a Swedish news headline retrieve summaries or news articles",
            "TV2Nord Retrieval": "Given a summary of a Danish news article retrieve the corresponding news article",
            "DanFEVER": "Given a claim in Danish, retrieve documents that support the claim",
            "SNL Retrieval": "Given a lexicon headline in Norwegian, retrieve its article",
            "NorQuad": "Given a question in Norwegian, retrieve the answer from Wikipedia articles",
            "SweFAQ": "Retrieve answers given questions in Swedish",
            "ArguAna": "Given a claim, find documents that refute the claim",
            "ClimateFEVER": "Given a claim about climate change, retrieve documents that support or refute the claim",
        }
        default_instruction = "Retrieve text based on user query."
        return task_name_to_instruct.get(task.name, default_instruction)
    return ""


class E5Instruct(Encoder):
    def __init__(self, model_name: str, max_length: int, max_batch_size: Optional[int] = None, **kwargs: Any):
        #logger.info("Started loading e5 instruct model")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModel.from_pretrained(model_name, **kwargs)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)
        self.max_length = max_length
        self.max_batch_size = max_batch_size
        

    def preprocess(self, sentences: Sequence[str], instruction: str, encode_type: EncodeTypes) -> BatchEncoding:
        if encode_type == "query":
            sentences = [f"Instruction: {instruction}\nQuery: {sentence}" for sentence in sentences]

        batch_dict = self.tokenizer(
            sentences,  # type: ignore
            max_length=512,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        return batch_dict.to(self.model.device)

    def get_embedding_from_output(self, output: ModelOutput, batch_dict: BatchEncoding) -> torch.Tensor:
        return self.average_pool(output.last_hidden_state, batch_dict["attention_mask"])  # type: ignore

    @staticmethod
    def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> Tensor:
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    def encode(
        self,
        sentences: list[str],
        *,
        task: Optional[Task] = None,
        batch_size: int = 128,
        encode_type: EncodeTypes = "query",
        **kwargs: Any,  # noqa
    ) -> np.ndarray:
        if self.max_batch_size and batch_size > self.max_batch_size:
            batch_size = self.max_batch_size
        batched_embeddings = []
        if task is not None:  # noqa
            instruction = task_to_instruction(task)
        else:
            instruction = ""
            
        for batch in tqdm(batched(sentences, batch_size)):
            with torch.inference_mode():
                batch_dict = self.preprocess(batch, instruction=instruction, encode_type=encode_type)
                outputs = self.model(**batch_dict)
                embeddings = self.get_embedding_from_output(outputs, batch_dict)
            batched_embeddings.append(embeddings.detach().cpu())

        return torch.cat(batched_embeddings).to("cpu").detach().numpy()

    def encode_corpus(self, corpus: list[dict[str, str]], **kwargs: Any) -> np.ndarray:
        sep = " "
        if isinstance(corpus, dict):
            sentences = [
                (corpus["title"][i] + sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip()  # type: ignore
                for i in range(len(corpus["text"]))  # type: ignore
            ]
        else:
            sentences = [(doc["title"] + sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
        return self.encode(sentences, encode_type="passage", **kwargs)

    def encode_queries(self, queries: list[str], **kwargs: Any) -> np.ndarray:
        return self.encode(queries, encode_type="query", **kwargs)


class Llama8BInstruct(E5Instruct):
    def __init__(self):
        super().__init__("meta-llama/Meta-Llama-3-8B-Instruct", max_length=512, max_batch_size=32, torch_dtype=torch.float16)

    @staticmethod
    def last_token_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
        if left_padding:
            return last_hidden_states[:, -1]
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_states.shape[0]
        return last_hidden_states[
            torch.arange(batch_size, device=last_hidden_states.device),
            sequence_lengths,
        ]

    def get_embedding_from_output(self, output: ModelOutput, batch_dict: BatchEncoding) -> torch.Tensor:
        return self.last_token_pool(output.last_hidden_state, batch_dict["attention_mask"])  # type: ignore

    def preprocess(self, sentences: Sequence[str], instruction: str, encode_type: EncodeTypes) -> BatchEncoding:
        if encode_type == "query":
            sentences = [f"Instruction: {instruction}\nQuery: {sentence}" for sentence in sentences]
        
        batch_dict: BatchEncoding = self.tokenizer(
            sentences,  # type: ignore
            max_length=self.max_length - 1,
            return_attention_mask=False,
            padding=False,
            truncation=True,
        )

        # append eos_token_id to every input_ids
        batch_dict["input_ids"] = [
            [*input_ids, self.tokenizer.eos_token_id]
            for input_ids in batch_dict["input_ids"]  # type: ignore
        ]
        batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors="pt")

        return batch_dict.to(self.model.device)


@models.register("llama-8b-instruct")
def create_llama_8b_instruct() -> SebModel:
    hf_name = "meta-llama/Meta-Llama-3-8B-Instruct"
    meta = ModelMeta(
        name=hf_name.split("/")[-1],
        huggingface_name=hf_name,
        reference=f"https://huggingface.co/{hf_name}",
        languages=[],
        open_source=True,
        embedding_size=4096,
        architecture="Llama",
        release_date=date(2023, 12, 20),
    )
    return SebModel(
        encoder=LazyLoadEncoder(Llama8BInstruct),
        meta=meta,
    )

model_name = "llama-8b-instruct"
#def run_benchmark():
models = [seb.get_model(model_name)]
benchmark = seb.Benchmark(languages=['da'])
results = benchmark.evaluate_models(models=models)
#Save pickle
with open(f"SEB_eval_{model_name}.pkl", 'wb') as f:
    pickle.dump(results, f)
avg_score = np.mean([res.get_main_score() for res in results])
print(f'\nAverage results: {avg_score}')
print(f'\nFull Results:\n{results}')

#if __name__ == "__main__":
    #run_benchmark()

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         0.49%       1.894ms        48.21%     184.645ms     721.268us       0.000us         0.00%     100.307ms     391.823us           0 b           0 b       4.34 Gb           0 



-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         0.07%     977.095us        10.66%     143.343ms     559.935us       0.000us         0.00%     988.227ms       3.860ms           0 b           0 b      42.00 Gb           0 



-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         0.09%       1.122ms        24.96%     324.276ms       1.267ms       0.000us         0.00%        1.000s       3.906ms           0 b           0 b      42.00 Gb           0 



-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         0.28%       1.475ms        19.26%      99.952ms     390.438us       0.000us         0.00%     403.140ms       1.575ms           0 b           0 b      16.50 Gb           0 

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         0.85%     940.865us         6.74%       7.467ms      29.166us       0.000us         0.00%      88.664ms     346.345us           0 b           0 b       3.61 Gb           0 

Running Meta-Llama-3-8B-Instruct:   0%|          | 0/1 [00:36<?, ?it/s]

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                           aten::matmul         1.28%       1.085ms        15.85%      13.485ms      52.677us       0.000us         0.00%      65.930ms     257.537us           0 b           0 b       2.63 Gb           0 




ValueError: Found input variables with inconsistent numbers of samples: [32, 476]

In [2]:
from typing import List, Optional, Dict, Any, Union
import numpy as np
import torch
from torch import Tensor, nn
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from tqdm import tqdm

class OptimizedLlama8BInstruct(Encoder):
    def __init__(self, model_name: str, pooling_mode: str = "mean", max_length: int = 512, skip_instruction: bool = True):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.max_length = max_length
        self.pooling_mode = pooling_mode
        self.skip_instruction = skip_instruction

        # Handle PAD token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def preprocess(
        self, 
        sentences: List[str], 
        instruction: str = "", 
        encode_type: str = "query"
    ) -> BatchEncoding:
        # Attach instruction dynamically
        if encode_type == "query" and instruction:
            sentences = [f"Instruction: {instruction}\nQuery: {sentence}" for sentence in sentences]
        
        # Tokenize sentences in a single call
        tokens = self.tokenizer(
            sentences,
            max_length=self.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        return tokens

    def forward(self, batch_encoding: BatchEncoding) -> Tensor:
        # Perform the forward pass and return embeddings
        outputs = self.model(**batch_encoding)
        return self.pool_embeddings(outputs.last_hidden_state, batch_encoding["attention_mask"])

    def pool_embeddings(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
        # Handle pooling based on the chosen strategy
        if self.pooling_mode == "mean":
            # Mask unwanted tokens and compute mean
            mask_expanded = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
            summed = torch.sum(hidden_states * mask_expanded, dim=1)
            summed_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)  # Avoid divide-by-zero
            return summed / summed_mask

        elif self.pooling_mode == "eos_token":
            # Return the last hidden states for tokens
            eos_positions = attention_mask.sum(dim=1) - 1
            return hidden_states[torch.arange(hidden_states.size(0)), eos_positions]

        elif self.pooling_mode == "last_token":
            # Simply take the last hidden state
            return hidden_states[:, -1]

        else:
            raise ValueError(f"Pooling mode '{self.pooling_mode}' is not recognized")

    def encode(
        self,
        sentences: List[str],
        batch_size: int = 32,
        encode_type: str = "query",
        instruction: str = "",
        show_progress_bar: bool = True,
        device: str = "cuda",
    ) -> np.ndarray:
        self.model.to(device)
        self.eval()

        all_embeddings = []
        
        with torch.no_grad():
            for batch in tqdm([sentences[i:i + batch_size] for i in range(0, len(sentences), batch_size)], disable=not show_progress_bar):
                batch_tokens = self.preprocess(batch, instruction=instruction, encode_type=encode_type)
                batch_tokens = {k: v.to(device) for k, v in batch_tokens.items()}  # Move to GPU
                embeddings = self.forward(batch_tokens)
                all_embeddings.append(embeddings.cpu())  # Move to CPU after computation

        return torch.cat(all_embeddings, dim=0).numpy()

    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
        return self.encode(queries, encode_type="query", **kwargs)

    def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray:
        sentences = [
            (doc["title"] + " " + doc["text"]).strip() if "title" in doc else doc["text"].strip()
            for doc in corpus
        ]
        return self.encode(sentences, encode_type="passage", **kwargs)



In [3]:
@models.register("llama-8b-instruct-optimized")
def create_llama_8b_instruct() -> SebModel:
    hf_name = "meta-llama/Meta-Llama-3-8B-Instruct"
    meta = ModelMeta(
        name=hf_name.split("/")[-1],
        huggingface_name=hf_name,
        reference=f"https://huggingface.co/{hf_name}",
        languages=[],
        open_source=True,
        embedding_size=4096,
        architecture="Llama",
        release_date=date(2023, 12, 20),
    )
    return SebModel(
        encoder=LazyLoadEncoder(OptimizedLlama8BInstruct),
        meta=meta,
    )

AttributeError: 'list' object has no attribute 'register'

In [None]:
model_name = "llama-8b-instruct"
#def run_benchmark():
models = [seb.get_model(model_name)]
benchmark = seb.Benchmark(languages=['da'])
results = benchmark.evaluate_models(models=models)
#Save pickle
with open(f"SEB_eval_{model_name}.pkl", 'wb') as f:
    pickle.dump(results, f)
avg_score = np.mean([res.get_main_score() for res in results])
print(f'\nAverage results: {avg_score}')
print(f'\nFull Results:\n{results}')
