In [9]:
import torch
import tqdm
# !pip install polars
import polars as pl
from transformers import DistilBertTokenizer, DistilBertModel
from torch.utils.data import Dataset, DataLoader
# !mkdir -p ./data/lvl1_data

In [10]:
# Checkpoint used for both the tokenizer and the encoder
model_name = 'distilbert-base-cased'

# Prefer the GPU when one is visible to torch, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Fetch tokenizer and encoder weights, placing the encoder on the chosen device
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
model = DistilBertModel.from_pretrained(model_name).to(device)

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/263M [00:00<?, ?B/s]

In [13]:
# Number of items per DataLoader batch, i.e. how many texts are tokenized and
# pushed through the model in a single forward pass.
batch_size = 128

# Create a dataset class for the parquet data
class TextDataset(Dataset):
    """Map-style dataset that serves ``(item_id, text)`` pairs from a table.

    ``df`` must support ``len()`` and ``row(idx, named=True)`` returning a
    mapping that contains the keys ``'item_id'`` and ``'text'``.
    """

    def __init__(self, df):
        # Only a reference is stored; rows are looked up on demand.
        self.df = df

    def __len__(self):
        """Return the number of rows in the wrapped table."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the ``(item_id, text)`` tuple for row ``idx``."""
        record = self.df.row(idx, named=True)
        return tuple(record[key] for key in ('item_id', 'text'))

# Load the level-1 item table and wrap it for batched, in-order iteration
df = pl.read_parquet("./data/lvl1_data/items.parquet")
dataset = TextDataset(df)
dataloader = DataLoader(
    dataset,
    shuffle=False,  # keep rows in file order
    batch_size=batch_size,
)

In [None]:
# Cutoff for this partial run: the loop stops right after finishing the first
# batch whose index exceeds this value, so batch indices 0..4194 are embedded.
# NOTE(review): the saved progress bar reports 12579 batches in total, so this
# looks like roughly one third of the data — confirm the boundary lines up with
# wherever the next part is meant to start.
MAX_BATCH_IDX = 4193

# Output schema shared by every per-batch frame
EMB_SCHEMA = {"item_id": pl.Int64, "cls_emb": pl.List(pl.Float32)}

# Per-batch frames are collected here and concatenated once at the end instead
# of rebuilding the full result frame on every iteration.
batch_frames = []
embs_df = None  # stays None if no batch was processed

# Set model to evaluation mode
model.eval()

try:
    for batch_idx, (batch_ids, batch_texts) in enumerate(tqdm.tqdm(dataloader)):
        # Tokenize the batch (padded to the longest text, truncated to 512
        # tokens) and move the inputs to the model's device
        inputs = tokenizer(list(batch_texts), padding=True, truncation=True, return_tensors='pt', max_length=512).to(device)

        # Infer embeddings without computing gradients
        with torch.no_grad():
            outputs = model(**inputs)

        # Hidden state at sequence position 0 of every text, as nested Python lists
        cls_vectors = outputs.last_hidden_state[:, 0, :].cpu().tolist()

        # The DataLoader may hand ids back as a tensor; give polars plain
        # Python values. Non-tensor id batches are passed through as a list.
        ids = batch_ids.tolist() if hasattr(batch_ids, "tolist") else list(batch_ids)

        batch_frames.append(
            pl.from_dicts(
                [{"item_id": i, "cls_emb": e} for i, e in zip(ids, cls_vectors)],
                schema_overrides=EMB_SCHEMA,
            )
        )

        if batch_idx > MAX_BATCH_IDX:
            break
finally:
    # Build the result even if the loop was interrupted or raised, so the
    # batches processed so far remain available in `embs_df`.
    if batch_frames:
        embs_df = pl.concat(batch_frames)

  5%|▍         | 603/12579 [26:04<8:42:22,  2.62s/it]

In [11]:
# Display the embeddings frame.
# NOTE(review): the saved output under this cell shows a `roberta_emb` column,
# while the inference loop names the column `cls_emb` — the output appears to be
# stale; re-run this cell to refresh it.
embs_df

item_id,roberta_emb
i64,list[f32]
0,"[0.151659, 0.118043, … -0.304602]"
1,"[0.136803, 0.0675, … -0.302022]"


In [None]:
# Colab-only import: this cell fails outside Google Colab.
from google.colab import files

# Write the embeddings to parquet, then hand the file to `files.download`.
embs_path = "./bert_item_infers_pt1.parquet"
embs_df.write_parquet(embs_path)
files.download(embs_path)

# Number of embedded items, shown as the cell output
len(embs_df)

In [None]:
!mkdir -p ./data/lvl2_data
df.join(embs_df, on = "item_id", how = "left").write_parquet("./data/lvl2_data/items_pt1.parquet")

In [None]:
# Download the joined level-2 parquet. `files` comes from the
# `from google.colab import files` import in an earlier cell, so that cell must
# have been run first.
files.download("./data/lvl2_data/items_pt1.parquet")