In [None]:
import numpy as np
import polars as pl

from transformers.pipelines import AutoTokenizer, AutoModel
import torch

In [None]:
# Report how many CUDA devices PyTorch can see.
print(f"Доступно GPU: {torch.cuda.device_count()}")

# Device index: 0 when CUDA is usable, otherwise -1.
device = -1 if not torch.cuda.is_available() else 0
device

In [None]:
# Single checkpoint id shared by the encoder and its tokenizer.
MODEL_NAME = "answerdotai/ModernBERT-base"

# Encoder loaded in bfloat16 and switched to evaluation mode.
model = AutoModel.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,  # bfloat16 weights for faster inference
    device_map="auto",
    attn_implementation="flash_attention_2",  # faster attention kernels
)
model.eval()

# Fast tokenizer for the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

def batch_embed(texts, batch_size=32):
    """Embed texts with the module-level ``model`` / ``tokenizer`` via mean pooling.

    Each batch is tokenized (padded to the longest item, truncated to 8192
    tokens), run through ``model``, and the last hidden states are averaged
    over the non-padding tokens indicated by the attention mask.

    Args:
        texts: Iterable of strings to embed.
        batch_size: Number of texts per forward pass; must be >= 1.

    Returns:
        ``numpy.ndarray`` of dtype float16 with shape
        ``(len(texts), hidden_size)``. For empty input the array has zero rows
        and ``model.config.hidden_size`` columns.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    texts = list(texts)
    if not texts:
        # torch.cat() rejects an empty list, so build the empty result directly.
        return np.empty((0, model.config.hidden_size), dtype=np.float16)

    embeddings = []

    with torch.inference_mode():
        for start in range(0, len(texts), batch_size):
            inputs = tokenizer(
                texts[start:start + batch_size],
                padding="longest",
                truncation=True,
                max_length=8192,
                return_tensors="pt"
            ).to(model.device)

            # Autocast in bfloat16 rather than the CUDA default of float16.
            with torch.autocast(device_type=model.device.type, dtype=torch.bfloat16):
                outputs = model(**inputs)

            # Pool in float32 so the average does not depend on the (possibly
            # low-precision) dtype the model emitted.
            mask = inputs.attention_mask.unsqueeze(-1).float()
            hidden = outputs.last_hidden_state.float()
            emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
            embeddings.append(emb.half().cpu())  # float16 to save memory
            # Per-batch cache release is intentional: the notebook's timing
            # notes report a noticeably faster run with it.
            torch.cuda.empty_cache()

    # Rows are already (len(texts), hidden_size); no hard-coded width needed.
    result = torch.cat(embeddings).numpy()
    torch.cuda.empty_cache()
    return result

In [None]:
# Approximate number of articles to keep, and the seed that makes the random
# draw repeatable across kernel restarts.
SAMPLE_SIZE = 10_000
SAMPLE_SEED = 42

lazy_df = pl.scan_parquet("../data/preprocessed/articles.parquet")
# NOTE(review): recent polars releases deprecate pl.count() in favour of
# pl.len(); switch once the minimum polars version is confirmed.
total = lazy_df.select(pl.count()).collect().item()
# Per-row keep probability: everything when the file is small, otherwise the
# fraction that yields roughly SAMPLE_SIZE rows.
threshold = 1.0 if total <= SAMPLE_SIZE else SAMPLE_SIZE / total

# A dedicated seeded generator instead of the global, unseeded np.random state,
# so data.csv / embeddings.npy do not change on every run.
# Assumes polars evaluates the rows in a stable order — TODO confirm.
rng = np.random.default_rng(SAMPLE_SEED)

lazy_df = lazy_df.with_columns(
    pl.struct([pl.col("assets")]).map_elements(
        lambda _: rng.random(),  # one uniform draw in [0, 1) per row
        skip_nulls=False,
        return_dtype=pl.Float64()
    ).alias("random")
)
sampled_lazy_df = lazy_df.filter(pl.col("random") < threshold).select(["text", "datetime"])

data = sampled_lazy_df.collect().to_pandas()
texts = data.text.to_list()

In [None]:
embeddings = batch_embed(texts, batch_size=32)
# Timing notes (unit not stated, presumably minutes): 20.3 without cache clearing | 10.8 with memory clearing | 10.5 with the Rust tokenizer | 10.6 with mixed precision | 10.6 with synchronization | previously 13 minutes for roughly 10900 texts

In [None]:
# The two artefacts are matched by row position, so refuse to write them when
# they have drifted apart (e.g. the sampling cell was re-run after embedding).
assert len(embeddings) == len(data), (
    f"embeddings ({len(embeddings)}) and data ({len(data)}) are misaligned"
)

# Row i of embeddings.npy corresponds to row i of data.csv.
np.save("embeddings.npy", embeddings)
data.to_csv("data.csv", index=False)

In [None]:
# Release cached GPU memory that PyTorch is no longer using.
torch.cuda.empty_cache()