In [1]:
import re
from tqdm import tqdm
from dataclasses import dataclass

import nltk
import torch
from sentence_transformers import SentenceTransformer

In [2]:
class ChunkSplitter:

    @staticmethod
    def split_to_sentence(text:str, at_least_length:int=10):
        sentences = nltk.tokenize.sent_tokenize(text)
        sentences = [sent for sent in sentences if len(sent) > at_least_length]
        return sentences

    @staticmethod
    def split_to_paragraph(text:str, at_least_length:int=10):
        paragraphs = re.split(r"\n\n|\n", text)
        paragraphs = [p for p in paragraphs if len(p) > at_least_length]
        return paragraphs


@dataclass
class SplitGranularity:
    paragraph :str = ChunkSplitter.split_to_paragraph
    sentence  :str = ChunkSplitter.split_to_sentence


class SummaCImagerForEMB:
    def __init__(self,
                 model_name_or_path:str="tals/albert-xlarge-vitaminc-mnli",
                 document_granularity:SplitGranularity="sentence",
                 summary_granularity:SplitGranularity="sentence",
                 max_doc_sents=100,
                 device="cpu"):

        self.device = device
        self.model_name_or_path = model_name_or_path
        self.load_emb()

        self.entailment_idx = 0
        self.contradiction_idx = 1
        self.neutral_idx = 2
        self.channel = sum(x is not None for x in [self.entailment_idx, self.contradiction_idx, self.neutral_idx])

        self.document_granularity = document_granularity
        self.document_splitter = getattr(SplitGranularity, self.document_granularity)

        self.summary_granularity = summary_granularity
        self.summary_splitter = getattr(SplitGranularity, summary_granularity)

        self.max_doc_sents = max_doc_sents
        self.max_input_length = 500


    def load_emb(self):
        self.model = SentenceTransformer(self.model_name_or_path)
        self.model.to(self.device)
        if self.device == "cuda":
            self.model.half()


    def create_pair_dataset(self, document_chunks:list[str], summary_chunks:list[str]):

        def count_generator(i=0):
            while True:
                yield i
                i += 1
        counter = count_generator(0)

        return [
            {
                'document': document_chunks[i],
                'summary': summary_chunks[j],
                'document_idx': i,
                'summary_idx': j,
                'pair_idx': next(counter),
            }
            for i in range(len(document_chunks))
            for j in range(len(summary_chunks))
        ]


    def build_image(self, document:str, summary:str, batch_size=4, return_dict_or_matrix:str='dict'):
        document_chunks = self.document_splitter(document)
        summary_chunks = self.summary_splitter(summary)
        pair_dataset = self.create_pair_dataset(document_chunks, summary_chunks)

        def batch_generator(dataset:list[dict], batch_size:str=20):

            dataset_size = len(dataset)
            chunk_size = (dataset_size // batch_size) + bool((dataset_size % batch_size))

            for i in range(chunk_size):
                batch = dataset[batch_size*i:batch_size*(i+1)]
                # pair_data = [(data['document'], data['summary']) for data in batch]

                yield [data['document'] for data in batch], [data['summary'] for data in batch]


        document_embs, summary_embs = [], []
        for i, batch in enumerate(tqdm(batch_generator(pair_dataset, batch_size=batch_size), desc='Processing')):
            document_batch, summary_batch = batch
            with torch.no_grad():
                document_emb = self.model.encode(document_batch, normalize_embeddings=True)
                summary_emb = self.model.encode(summary_batch, normalize_embeddings=True)
                # similarity = embs1 @ embs2.T

        #     if i:
        #         probs = torch.cat((probs, torch.nn.functional.softmax(model_outputs["logits"], dim=-1)))
        #     else:
        #         probs = torch.nn.functional.softmax(model_outputs["logits"], dim=-1)

        # if return_dict_or_matrix == 'matrix':
        #     _shape = (self.channel, document_chunks.__len__(), summary_chunks.__len__())
        #     image = probs[:, [self.entailment_idx, self.contradiction_idx, self.neutral_idx]].numpy().T.reshape(_shape)
        # elif return_dict_or_matrix == 'dict':
        #     _shape = (document_chunks.__len__(), summary_chunks.__len__())
        #     image = {
        #         'entailment': probs[:, self.entailment_idx].numpy().T.reshape(_shape),
        #         'contradiction': probs[:, self.contradiction_idx].numpy().T.reshape(_shape),
        #         'neutral': probs[:, self.neutral_idx].numpy().T.reshape(_shape),
        #     }
        # else:
        #     raise ValueError(f'Invalid value \'return_dict_or_matrix\'={return_dict_or_matrix}')

        # return probs, image
        return document_emb, summary_emb

In [3]:
document = """Scientists are studying Mars to learn about the Red Planet and find landing sites for future missions.
One possible site, known as Arcadia Planitia, is covered instrange sinuous features.
The shapes could be signs that the area is actually made of glaciers, which are large masses of slow-moving ice.
Arcadia Planitia is in Mars' northern lowlands."""
summary = "There are strange shape patterns on Arcadia Planitia. The shapes could indicate the area might be made of glaciers. This makes Arcadia Planitia ideal for future missions."
summary2 = "There are strange shape patterns on Arcadia Planitia. The shapes could indicate the area might be made of glaciers."

model_name_or_path = "DMetaSoul/Dmeta-embedding-zh"
imager = SummaCImagerForEMB(model_name_or_path)
# batch, model_outputs, batch_probs, image = imager.build_image(document, summary, batch_size=3)
document_emb, summary_emb = imager.build_image(document, summary, batch_size=8, return_dict_or_matrix='dict')
document_emb, summary_emb

Processing: 2it [00:01,  1.65it/s]


(array([[ 0.00809112,  0.03825084,  0.00626019, ...,  0.01670528,
         -0.06983218,  0.00651465],
        [ 0.00821151,  0.02340669, -0.00473241, ..., -0.03436563,
         -0.03132045, -0.01498133],
        [ 0.00821151,  0.02340669, -0.00473241, ..., -0.03436563,
         -0.03132045, -0.01498133],
        [ 0.00821151,  0.02340669, -0.00473241, ..., -0.03436563,
         -0.03132045, -0.01498133]], dtype=float32),
 array([[ 0.00405655,  0.02207707, -0.00563519, ...,  0.01684863,
         -0.08713452, -0.01289776],
        [ 0.00626215,  0.03595928,  0.00767179, ..., -0.01445544,
         -0.0352744 ,  0.00671129],
        [ 0.00926525,  0.0548636 ,  0.00087204, ...,  0.0018177 ,
         -0.06174414,  0.0076586 ],
        [ 0.00405655,  0.02207707, -0.00563519, ...,  0.01684863,
         -0.08713452, -0.01289776]], dtype=float32))

In [4]:
document_emb @ summary_emb.T

array([[0.59722584, 0.6904497 , 0.88714266, 0.59722584],
       [0.6291899 , 0.5766112 , 0.53290033, 0.6291899 ],
       [0.6291899 , 0.5766112 , 0.53290033, 0.6291899 ],
       [0.6291899 , 0.5766112 , 0.53290033, 0.6291899 ]], dtype=float32)

In [6]:
document_emb.shape

(4, 768)