In [1]:
from datasets import load_dataset

# Load the dataset identified by "smangrul/amazon_esci"; the object returned
# here is what the following cells inspect and tokenize.
dataset = load_dataset("smangrul/amazon_esci")

Found cached dataset parquet (/raid/sourab/.cache/huggingface/datasets/smangrul___parquet/smangrul--amazon_esci-321288cabf0cc045/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
# Display the loaded object to see its splits, columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['query', 'product_title', 'product_id', 'esci_label', 'split', 'relevance_label', '__index_level_0__'],
        num_rows: 839306
    })
    validation: Dataset({
        features: ['query', 'product_title', 'product_id', 'esci_label', 'split', 'relevance_label', '__index_level_0__'],
        num_rows: 363402
    })
})

In [3]:
# Column names of the train split, kept in a variable and displayed for reference.
column_names = dataset["train"].column_names
column_names

['query',
 'product_title',
 'product_id',
 'esci_label',
 'split',
 'relevance_label',
 '__index_level_0__']

In [4]:
from transformers import AutoTokenizer
# Checkpoint identifier; used for the tokenizer here and for the embedding model built later.
model_name = "intfloat/e5-large-v2"

# Tokenizer loaded from the same checkpoint identifier.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [5]:
def preprocess_function(examples):
    """Tokenize a batch of query / product-title pairs.

    Reads the ``query``, ``product_title`` and ``relevance_label`` columns of
    ``examples``. Each text column is tokenized on its own (padded and
    truncated to 70 tokens) and every tokenizer output is stored under a
    ``query_`` or ``product_`` prefixed key. ``relevance_label`` is copied
    unchanged into ``labels``.

    Relies on the module-level ``tokenizer``.
    """
    features = {}
    # (key prefix, source column) -- queries first, then product titles.
    for prefix, column in (("query", "query"), ("product", "product_title")):
        encoded = tokenizer(
            examples[column], padding="max_length", max_length=70, truncation=True
        )
        for key, value in encoded.items():
            features[f"{prefix}_{key}"] = value

    features["labels"] = examples["relevance_label"]
    return features

# Tokenize the dataset in batches. The columns listed for the train split are
# removed from the result, leaving the fields produced by preprocess_function.
processed_datasets = dataset.map(
    function=preprocess_function,
    batched=True,
    desc="Running tokenizer on dataset",
    remove_columns=dataset["train"].column_names,
)

Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/smangrul___parquet/smangrul--amazon_esci-321288cabf0cc045/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-2a66a10e3de89ace.arrow
Loading cached processed dataset at /raid/sourab/.cache/huggingface/datasets/smangrul___parquet/smangrul--amazon_esci-321288cabf0cc045/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7/cache-c1218f26bf89f822.arrow


In [6]:
# Print the processed training example at index 1 to check the tokenized fields.
print(processed_datasets["train"][1])

{'query_input_ids': [101, 999, 22091, 2078, 5302, 13777, 13310, 2302, 11418, 2015, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'query_token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'query_attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'product_input_ids': [101, 4098, 4887, 3406, 1016, 1011, 5308, 2410, 2595, 2629, 1012, 4002, 1011, 1020, 1016, 22086, 14585, 9587, 13777, 16358, 12824, 2007, 3756, 11418, 1010, 1006, 1017, 1000, 8857, 9594, 1010, 1017, 1013, 1018, 1000, 5747, 8613, 1007, 102, 0,

In [7]:
import torch
from torch import nn
from transformers import AutoModel

class AutoModelForSentenceEmbedding(nn.Module):
    """Wrap a Hugging Face ``AutoModel`` so that ``forward`` returns sentence embeddings.

    Token embeddings produced by the wrapped model are mean-pooled over the
    positions selected by ``attention_mask`` and, optionally, L2-normalized.

    Args:
        model_name: Checkpoint name or path passed to ``AutoModel.from_pretrained``.
        tokenizer: Tokenizer stored on the instance and written out by
            ``save_pretrained``.
        normalize: When true, pooled embeddings are L2-normalized along ``dim=1``.
    """

    def __init__(self, model_name, tokenizer, normalize=True):
        super().__init__()

        self.model = AutoModel.from_pretrained(model_name)
        self.normalize = normalize
        self.tokenizer = tokenizer

    def forward(self, **kwargs):
        """Run the wrapped model and return one pooled embedding per input row.

        All keyword arguments are forwarded to the wrapped model unchanged;
        ``attention_mask`` must be among them because it drives the pooling.
        """
        model_output = self.model(**kwargs)
        embeddings = self.mean_pooling(model_output, kwargs["attention_mask"])
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return embeddings

    def mean_pooling(self, model_output, attention_mask):
        """Average token embeddings over the positions where ``attention_mask`` is set.

        ``model_output[0]`` is taken as the token embeddings. The mask is
        broadcast to the same size, and the denominator is clamped to ``1e-9``
        so that an all-zero mask row does not divide by zero.
        """
        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        summed = torch.sum(token_embeddings * input_mask_expanded, 1)
        counts = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return summed / counts

    def save_pretrained(self, output_path):
        """Write the tokenizer, model config and model weights under ``output_path``.

        The directory is created if it does not exist. Weights are stored with
        ``torch.save`` as ``pytorch_model.bin``. (The earlier version called
        ``xm`` / ``os`` helpers that were never imported in this notebook and
        therefore failed with ``NameError``.)
        """
        # Local import: ``os`` is not imported in the notebook's import cells.
        import os

        os.makedirs(output_path, exist_ok=True)
        self.tokenizer.save_pretrained(output_path)
        self.model.config.save_pretrained(output_path)
        torch.save(self.model.state_dict(), os.path.join(output_path, "pytorch_model.bin"))

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            # If ``model`` itself is missing (e.g. before __init__ has run),
            # falling back to ``self.model`` would recurse without end.
            if name == "model":
                raise
            return getattr(self.model, name)
       



In [8]:
def get_cosing_embeddings(query_embs, product_embs):
    """Return the row-wise dot product of two equally shaped embedding batches.

    When both inputs are L2-normalized along dim 1 this value is their cosine
    similarity.
    """
    elementwise = query_embs * product_embs
    return elementwise.sum(dim=1)

def get_loss(cosine_score, labels):
    """Mean of the squared per-pair penalty between scores and labels.

    The per-pair penalty is
    ``labels * (1 - cosine_score) + max((1 - labels) * cosine_score, 0)``;
    it is squared and averaged over all pairs.
    """
    positive_term = labels * (1 - cosine_score)
    negative_term = ((1 - labels) * cosine_score).clamp(min=0.0)
    return (positive_term + negative_term).square().mean()


In [10]:
# Build the sentence-embedding wrapper for the chosen checkpoint and tokenizer.
model = AutoModelForSentenceEmbedding(model_name, tokenizer)


In [9]:
from peft import get_peft_model, LoraConfig, TaskType

# LoRA configuration (r=8, lora_alpha=16, bias="none") using the EMBEDDING task type.
peft_config = LoraConfig(r=8, lora_alpha=16, bias="none", task_type=TaskType.EMBEDDING)

[2023-06-28 13:42:42,234] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)

Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/bitsandbytes/libbitsandbytes_cuda118.so
CUDA SETUP: CUDA runtime path found: /home/sourab/miniconda3/envs/ml/lib/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /home/sourab/miniconda3/envs/ml/lib/python3.11/site-packages/bitsandbytes/libbitsandbytes_cuda118.so...


Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)


In [11]:
# Wrap the model according to peft_config, then display the resulting module tree.
# NOTE(review): this cell rebinds `model` from itself, so re-running it alone
# would wrap the already-wrapped model; re-run the model construction cell first.
model = get_peft_model(model, peft_config)
model

PeftModelForEmbedding(
  (base_model): LoraModel(
    (model): AutoModelForSentenceEmbedding(
      (model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 1024, padding_idx=0)
          (position_embeddings): Embedding(512, 1024)
          (token_type_embeddings): Embedding(2, 1024)
          (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0-23): 24 x BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(
                    in_features=1024, out_features=1024, bias=True
                    (lora_dropout): ModuleDict(
                      (default): Identity()
                    )
                    (lora_A): ModuleDict(
                      (default): Linear(in_features=1024, out_features=8, bias=False)


In [12]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

batch_size = 128

# Both loaders share every setting except shuffling: only the training split
# is shuffled.
train_dataloader, val_dataloader = (
    DataLoader(
        processed_datasets[split_name],
        shuffle=(split_name == "train"),
        collate_fn=default_data_collator,
        batch_size=batch_size,
        pin_memory=True,
    )
    for split_name in ("train", "validation")
)



In [13]:
# Pull one batch from the training loader to inspect the collated tensors.
next(iter(train_dataloader))

{'query_input_ids': tensor([[  101, 16012,  3207,  ...,     0,     0,     0],
         [  101,  2227,  6099,  ...,     0,     0,     0],
         [  101,  9476,  2303,  ...,     0,     0,     0],
         ...,
         [  101,  9781, 13227,  ...,     0,     0,     0],
         [  101,  1015,  4720,  ...,     0,     0,     0],
         [  101,  2961,  1998,  ...,     0,     0,     0]]),
 'query_token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]),
 'query_attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'product_input_ids': tensor([[  101,  2208,  2227,  ...,     0,     0,     0],
         [  101,  1006,  5308, 

In [14]:
from transformers import get_scheduler

lr = 1e-3
epochs = 3

# Adam over the model's parameters, paired with a "linear" schedule that uses
# zero warmup steps and spans every training step of every epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
lr_scheduler = get_scheduler(
    "linear",
    optimizer,
    num_warmup_steps=0,
    num_training_steps=len(train_dataloader) * epochs,
)

In [15]:
from tqdm import tqdm

# Use CUDA when present; otherwise fall back to CPU instead of failing on a
# hard-coded "cuda" device string.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _select_inputs(batch, prefix):
    """Return the entries of ``batch`` whose key starts with ``prefix``.

    The prefix is stripped from each key and each value is moved to ``device``.
    Matching on the prefix (rather than a substring) keeps unrelated keys that
    merely contain the word from being picked up or mangled.
    """
    return {
        key[len(prefix):]: value.to(device)
        for key, value in batch.items()
        if key.startswith(prefix)
    }


model.to(device)
model.train()

for epoch in range(epochs):
    for i, batch in enumerate(tqdm(train_dataloader)):
        # The same model embeds both sides of each pair.
        query_embs = model(**_select_inputs(batch, "query_"))
        product_embs = model(**_select_inputs(batch, "product_"))
        loss = get_loss(get_cosing_embeddings(query_embs, product_embs), batch["labels"].to(device))
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        model.zero_grad()
        # Report the loss every 20 steps.
        if i%20==0:
            print(loss)
        

  0%|                                                                                       | 1/6558 [00:02<3:53:55,  2.14s/it]

tensor(0.1725, device='cuda:0', grad_fn=<MeanBackward0>)


  0%|▎                                                                                     | 21/6558 [00:32<2:52:35,  1.58s/it]

tensor(0.1570, device='cuda:0', grad_fn=<MeanBackward0>)


  1%|▌                                                                                     | 41/6558 [01:02<2:52:16,  1.59s/it]

tensor(0.1466, device='cuda:0', grad_fn=<MeanBackward0>)


  1%|▊                                                                                     | 61/6558 [01:32<2:51:49,  1.59s/it]

tensor(0.1788, device='cuda:0', grad_fn=<MeanBackward0>)


  1%|█                                                                                     | 81/6558 [02:02<2:51:24,  1.59s/it]

tensor(0.1443, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|█▎                                                                                   | 101/6558 [02:33<2:50:58,  1.59s/it]

tensor(0.1330, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|█▌                                                                                   | 121/6558 [03:03<2:50:32,  1.59s/it]

tensor(0.1324, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|█▊                                                                                   | 141/6558 [03:33<2:49:55,  1.59s/it]

tensor(0.1209, device='cuda:0', grad_fn=<MeanBackward0>)


  2%|█▉                                                                                   | 145/6558 [03:39<2:41:51,  1.51s/it]


KeyboardInterrupt: 