In [4]:
import codecs
import typing
from itertools import islice

import spacy
import os
os.environ["TOKENIZERS_PARALLELISM"] = "true"
from tqdm import tqdm

# Input TSV. Each line is split on tabs by iterate_passages() into an id,
# the content and an optional third field (entities).
FILENAME = "../split.tsv"
# Output TSV. Written by the export cell as: id, content, base64-encoded
# embedding and, when present, the entities field.
OUTPUT_FILENAME = "../embeddings.tsv"
# Not referenced by any of the visible cells.
# SPACY_MODEL = "../data/v3-model.cpu/model-best"


In [5]:

def line_count(filename):
    """Return the number of lines in the UTF-8 text file ``filename``.

    The file is streamed line by line rather than loaded with
    ``readlines()``, so memory use stays flat regardless of file size.

    Args:
        filename: Path of the file to count.

    Returns:
        The number of lines; ``0`` for an empty file. A final line without
        a trailing newline still counts as one line.
    """
    with open(filename, "r", encoding="utf-8") as f:
        return sum(1 for _ in f)
    

class Paragraph:
    """One record of the input TSV, optionally carrying its embedding.

    Attributes are restricted through ``__slots__``:
        id: First tab-separated field of the line.
        content: Second field; the text that gets embedded.
        embedding: ``None`` until an embedding is assigned to the passage.
        entities: Optional third field, or ``None`` when the line has none.
    """

    __slots__ = ['id', 'content', 'embedding', 'entities']

    def __init__(self, id: str, content: str, entities: typing.Optional[str]):
        self.id = id
        self.content = content
        # Slots have no default value: without this assignment, reading
        # ``embedding`` before it is set would raise AttributeError.
        self.embedding = None
        self.entities = entities


def iterate_passages(lines: typing.Iterable[str]) -> typing.Iterable[Paragraph]:
    """Parse tab-separated lines into ``Paragraph`` objects, lazily.

    Each line is stripped of surrounding whitespace and split on at most
    two tabs into ``id``, ``content`` and an optional ``entities`` field
    (which therefore keeps any further tabs it contains).

    Blank or whitespace-only lines are skipped instead of failing with an
    ``IndexError``; a non-blank line with a single field still raises
    ``IndexError`` because it has no content column.

    Args:
        lines: Iterable of raw text lines.

    Yields:
        One ``Paragraph`` per non-blank line, in input order.
    """
    for line in lines:
        record = line.strip()
        if not record:
            # Nothing to parse, e.g. an empty line at the end of the file.
            continue
        fields = record.split('\t', 2)
        entities = fields[2] if len(fields) > 2 else None
        yield Paragraph(fields[0], fields[1], entities)


def split_every(n, iterable):
    """Yield consecutive lists of up to ``n`` items taken from ``iterable``.

    The input is consumed lazily, one chunk at a time. Every chunk has
    exactly ``n`` items except possibly the last one; nothing is yielded
    for an empty input.
    """
    source = iter(iterable)

    def next_chunk():
        return list(islice(source, n))

    # The two-argument form of iter() keeps calling next_chunk() until it
    # returns the sentinel, i.e. an empty list once the source is exhausted.
    yield from iter(next_chunk, [])


def calculate_passages(n: int, client, lines: typing.Iterable[str]) -> typing.Iterable[Paragraph]:
    """Attach embeddings to the passages parsed from ``lines``, batch by batch.

    Passages are parsed with ``iterate_passages`` and grouped into batches
    of up to ``n``. The contents of each batch are passed to
    ``client.encode_embeddings`` in a single call, and the results are
    paired with the passages positionally.

    Args:
        n: Maximum number of passages sent to the client per call.
        client: Object exposing ``encode_embeddings(list_of_str)``.
        lines: Iterable of raw tab-separated lines.

    Yields:
        Each passage, in input order, with ``embedding`` set.
    """
    batches = split_every(n, iterate_passages(lines))
    for batch in batches:
        texts = [paragraph.content for paragraph in batch]
        vectors = client.encode_embeddings(texts)
        for paragraph, vector in zip(batch, vectors):
            paragraph.embedding = vector
            yield paragraph
    

In [6]:
import base64
from chategw import AiClient

# Number of passages sent to the client per encode_embeddings() call.
BATCH_SIZE = 10_000

total_count = line_count(FILENAME)
# NOTE(review): the meaning of 1024 is defined by AiClient (presumably the
# embedding dimension) — confirm against chategw before changing it.
client = AiClient(1024)

# The built-in open() replaces codecs.open() for three reasons:
#   * it splits lines the same way line_count() does, so the progress total
#     matches the records read (codecs readers also break on extra Unicode
#     line boundaries such as U+2028);
#   * codecs.open() opens the file in binary mode, where ``buffering=True``
#     (line buffering) is not supported;
#   * ``with`` closes both files even if embedding fails midway.
# newline='\n' keeps the output line endings as written, with no platform
# translation, matching what the binary-mode codecs writer produced.
with open(FILENAME, 'r', encoding='utf-8') as f, \
        open(OUTPUT_FILENAME, 'w', encoding='utf-8', newline='\n') as out_f:
    progress = tqdm(f, total=total_count, desc="Embedding", mininterval=2)
    for passage in calculate_passages(BATCH_SIZE, client, progress):
        # The embedding's raw bytes are stored as base64 text so they fit in
        # a single TSV column.
        bytes_embedding = base64.b64encode(passage.embedding.tobytes()).decode('utf-8')
        fields = [passage.id, passage.content, bytes_embedding]
        if passage.entities:
            # The entities column is only written when the input had one.
            fields.append(passage.entities)
        out_f.write('\t'.join(fields))
        out_f.write('\n')

Embedding: 100%|██████████| 1920774/1920774 [21:12<00:00, 1509.43it/s]
