In [None]:
1

: 

In [1]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer
import datasets
from collections import defaultdict
from tqdm import tqdm
import json
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# The notebook was written for the second GPU ('cuda:1'). Keep that choice whenever such a
# device exists, but fall back to the first GPU, and finally to the CPU, so that every later
# cell still runs on machines with fewer accelerators.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=1)

In [3]:
# Single source of truth for the checkpoint used by both the tokenizer and the encoder.
EMB_MODEL_ID = 'google/embeddinggemma-300m'

# NOTE(review): trust_remote_code=True allows the Hub repository to run its own Python code
# at load time - confirm it is really needed for this checkpoint.
emb_tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_ID, trust_remote_code=True)

# The encoder is loaded onto `device`; it is only used for inference here, so gradient
# tracking is switched off for all of its parameters.
emb_model = AutoModel.from_pretrained(
    EMB_MODEL_ID,
    device_map=device,
    trust_remote_code=True,
)
emb_model = emb_model.requires_grad_(False)

In [4]:
# Load the IMDB dataset; its 'train' and 'test' splits are the ones embedded in later cells.
imdb = datasets.load_dataset('imdb')

In [None]:
def get_embeddings(loader: DataLoader) -> dict[str, torch.Tensor]:
    """Embed every batch of ``loader`` and return pooled embeddings together with labels.

    Each batch must provide ``batch['text']`` (handed to ``emb_tokenizer``) and
    ``batch['label']``. Texts are tokenized with padding and truncation to 512 tokens, run
    through ``emb_model`` without gradient tracking, and pooled on the CPU.

    Args:
        loader: Iterable of batches, normally a ``DataLoader``.

    Returns:
        A dict of tensors concatenated over all batches, in loader order:
        ``'mean'``  - attention-mask-weighted mean of the last hidden states,
        ``'last'``  - hidden state of the last non-padding token of every text,
        ``'label'`` - the labels taken from the batches.
        An empty ``loader`` yields an empty dict.
    """
    embeddings = defaultdict(list)

    for batch in tqdm(loader):
        inputs = emb_tokenizer(
            batch['text'], padding=True, truncation=True, max_length=512, return_tensors='pt'
        ).to(device)
        with torch.no_grad():
            hidden = emb_model(**inputs).last_hidden_state.cpu()

        attn_mask = inputs['attention_mask'].cpu()

        embeddings['mean'].append(_masked_mean(hidden, attn_mask))
        embeddings['last'].append(_last_token_state(hidden, attn_mask))
        embeddings['label'].append(batch['label'])

    return {key: torch.cat(chunks) for key, chunks in embeddings.items()}


def _masked_mean(hidden: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    """Average ``hidden`` over dim 1, counting only positions where ``attn_mask`` is 1."""
    mask_expanded = attn_mask.unsqueeze(-1).expand_as(hidden)
    masked_sum = (hidden * mask_expanded).sum(dim=1)
    # clamp(min=1) avoids a 0/0 = NaN row if a mask row happens to be all zeros.
    token_count = attn_mask.sum(dim=1).clamp(min=1)
    return masked_sum / token_count.unsqueeze(-1)


def _last_token_state(hidden: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
    """Pick, for every row, the hidden state at the last position where ``attn_mask`` is 1."""
    seq_len = attn_mask.shape[1]
    # argmax on the reversed mask finds the first 1 from the right, i.e. the last real token.
    # Indexing with `token_count - 1` instead is only correct for right-padded batches; this
    # form gives the same index there and also works when the tokenizer pads on the left.
    last_index = seq_len - 1 - attn_mask.flip(dims=[1]).argmax(dim=1)
    row_index = torch.arange(hidden.shape[0])
    return hidden[row_index, last_index]

In [None]:
def get_and_save_embeddings(name: str, loader: DataLoader) -> None:
    """Compute embeddings for ``loader`` and dump them to ``data/gemma_{name}.json``.

    The tensors returned by ``get_embeddings`` are converted to nested Python lists and
    written as one JSON object keyed like the embeddings dict. A one-line summary is printed
    once the file has been written.

    Args:
        name: Suffix of the output file name.
        loader: Batches to embed; passed straight to ``get_embeddings``.
    """
    import os  # local import so this cell does not depend on the notebook's import cell

    full_path = f'data/gemma_{name}.json'
    # Create the target directory *before* the slow embedding pass: otherwise a missing
    # 'data/' folder would only surface as FileNotFoundError after all the work is done.
    os.makedirs(os.path.dirname(full_path), exist_ok=True)

    embeddings = get_embeddings(loader)
    data = {key: value.tolist() for key, value in embeddings.items()}

    with open(full_path, 'w') as out:
        json.dump(data, out)

    # Report after the `with` block so the message is only printed once the file is closed.
    first_table = next(iter(embeddings.values()), None)  # None when nothing was embedded
    first_shape = first_table.shape if first_table is not None else None
    print(
        f'Written {len(data)} tables of shape {first_shape} '
        f'into {full_path}'
    )

In [7]:
# One loader per split of the unmodified reviews; no shuffling is requested, so batches
# follow the dataset order.
train_loader, test_loader = (
    DataLoader(imdb[split], batch_size=256)  # type: ignore
    for split in ('train', 'test')
)

In [8]:
# Embed the full reviews of both splits, train first and then test; each call writes one
# JSON file named after the split.
for split_name, split_loader in (('train', train_loader), ('test', test_loader)):
    get_and_save_embeddings(split_name, split_loader)

100%|██████████| 98/98 [07:53<00:00,  4.83s/it]


Written 3 tables of shape torch.Size([25000, 768]) into data/gemma_train.json


100%|██████████| 98/98 [07:53<00:00,  4.83s/it]


Written 3 tables of shape torch.Size([25000, 768]) into data/gemma_test.json


In [None]:
def random_crop(item: dict, rng: random.Random, min_len=10, max_len=100):
    """Replace ``item['text']`` by a random contiguous run of its words.

    The text is split on whitespace. If it has at most ``min_len`` words it is returned
    exactly as it was. Otherwise a crop length is drawn with ``rng.randint`` from
    ``[min_len, min(max_len, word count)]``, a start offset is drawn among the offsets that
    keep the crop inside the text, and the selected words are joined with single spaces.

    Args:
        item: Mapping with a ``'text'`` entry.
        rng: Random generator used for both draws.
        min_len: Smallest crop length in words; shorter texts are left alone.
        max_len: Largest crop length in words.

    Returns:
        A new dict holding only the ``'text'`` key.
    """
    original = item['text']
    tokens = original.split()
    n_tokens = len(tokens)

    # Nothing to crop: hand the text back untouched (original whitespace included).
    if n_tokens <= min_len:
        return {'text': original}

    crop_len = rng.randint(min_len, min(max_len, n_tokens))
    last_start = n_tokens - crop_len
    # Only draw an offset when there is a real choice, so a full-length crop does not
    # consume a random number.
    if last_start > 0:
        offset = rng.randint(0, last_start)
    else:
        offset = 0

    return {'text': ' '.join(tokens[offset : offset + crop_len])}

In [10]:
# Every split gets its own seeded generator. With a single shared generator the test crops
# depend on how far the train pass advanced it, so re-running only the test line (or a train
# pass that does not call random_crop for every row) silently yields different test data.
# NOTE(review): `datasets` may reuse a cached `map` result instead of calling the function -
# confirm; in that case a shared generator would not be advanced at all.
# The train crops are unchanged by this; the test crops differ from the shared-generator run.
rng = random.Random(42)
test_rng = random.Random(43)

# min_len keeps random_crop's default of 10, max_len is lowered to 50.
train_random_crop = imdb['train'].map(random_crop, fn_kwargs={'rng': rng, 'max_len': 50})
test_random_crop = imdb['test'].map(random_crop, fn_kwargs={'rng': test_rng, 'max_len': 50})

train_random_crop_loader = DataLoader(train_random_crop, batch_size=256)  # type: ignore
test_random_crop_loader = DataLoader(test_random_crop, batch_size=256)  # type: ignore

Map: 100%|██████████| 25000/25000 [00:01<00:00, 15450.18 examples/s]
Map: 100%|██████████| 25000/25000 [00:01<00:00, 15568.01 examples/s]


In [11]:
# Embed the randomly cropped reviews of both splits, train first and then test; each call
# writes one JSON file named after the given tag.
for crop_name, crop_loader in (
    ('train_random_crop_10_50', train_random_crop_loader),
    ('test_random_crop_10_50', test_random_crop_loader),
):
    get_and_save_embeddings(crop_name, crop_loader)

100%|██████████| 98/98 [01:08<00:00,  1.44it/s]


Written 3 tables of shape torch.Size([25000, 768]) into data/gemma_train_random_crop_10_50.json


100%|██████████| 98/98 [01:08<00:00,  1.43it/s]


Written 3 tables of shape torch.Size([25000, 768]) into data/gemma_test_random_crop_10_50.json
