In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import get_scheduler
from torch.optim import AdamW   
import torch
from tqdm.auto import tqdm
from llm2vec.models import LlamaBiModel
from peft import get_peft_model, LoraConfig, TaskType

# from datasets import load_metric

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset("yelp_review_full", cache_dir='.')

Generating train split: 2600000 examples [00:01, 1486024.04 examples/s]
Generating test split: 200000 examples [00:00, 1591666.56 examples/s]


In [13]:
small_train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = dataset["test"].shuffle(seed=42).select(range(1000))


In [14]:
tokenizer = AutoTokenizer.from_pretrained("Llama-encoder-1.0B")
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
    return tokenizer(examples["text"], return_tensors="pt", padding=True, truncation=True, max_length=512)

tokenized_train = small_train_dataset.map(tokenize_function, batched=True)



In [15]:
tokenized_train = tokenized_train.remove_columns(["text"])
# tokenized_train = tokenized_train.remove_columns(["label"])
tokenized_train = tokenized_train.rename_column("label", "labels")


In [18]:
train_dataloader = DataLoader(tokenized_train, shuffle=True, batch_size=1)
# model = AutoModelForSequenceClassification.from_pretrained("Llama-encoder-1.0B", num_labels=5)
model = LlamaBiModel.from_pretrained("Llama-encoder-1.0B")
# model.config.pad_token_id = model.config.eos_token_id
optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        # target_modules=["query", "value"]
    )
# print(model)
model = get_peft_model(model, lora_config)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.print_trainable_parameters()
progress_bar = tqdm(range(num_training_steps))

# model.train()
# model.model.gradient_checkpointing = True

for epoch in range(num_epochs):
    for batch in train_dataloader:
        preprocess_batch = {}
        for k, v in batch.items():
            # print(k)
            # if k == 'labels':
            #     pass
            # elif isinstance(v, torch.Tensor):
            #     preprocess_batch[k] = v.to(device)
            # elif isinstance(v, list):
            #     preprocess_batch[k] = torch.stack(v, dim=1).to(device)

            if isinstance(v, torch.Tensor):
                preprocess_batch[k] = v.to(device)
            elif isinstance(v, list):
                preprocess_batch[k] = torch.stack(v, dim=1).to(device)

            print(preprocess_batch)
        # print(preprocess_batch)
        # with torch.no_grad():
        outputs = model(**preprocess_batch)

        # last_hidden_states = outputs.last_hidden_state
        # print(last_hidden_states, last_hidden_states.shape)
        # print(outputs.last_hidden_state.shape)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)


Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at Llama-encoder-1.0B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 1,136,640 || all params: 1,035,659,264 || trainable%: 0.1098


  0%|          | 2/3000 [06:45<168:57:14, 202.88s/it]


{'labels': tensor([4])}
{'labels': tensor([4]), 'input_ids': tensor([[    1,  3118,   310,   278,  1900,  8230,  1078,  1921,  2363,   306,
           505,  3926,  1063,   304,   322,   306,   626,   263,  8230,  1078,
         13524,   577,   306,   505,  1063,   304,   263,  2846,   975,   278,
          2440, 29889,   450,  8693,  4359,   756,   263,  1302,  1537,  4459,
           304,   372,   448,   366,  4459,  5476,   510,   287,   278,   937,
           931,   366,  6686,   964,   278,  2058, 29889,   450, 13925,   338,
         14154,  1532, 16370,   322,  1073,   825,   896,   526,  2599, 29892,
           920,   304,  1959,   596,   883,   856,  3204,   304,   278, 11015,
          9493, 29889,   512, 21499,   727,  2833,   304,   367,   343, 14895,
          7600,  1432,  2908,   448,  6060,   393,   338,  1363,  1784,  1016,
         29915, 29873,  1073,  1048,  8230,  1078,   322,   920,  7795,  5611,
           372,   338, 29973,   306,  1016, 29915, 29873,  1073, 29889



{'labels': tensor([4])}
{'labels': tensor([4]), 'input_ids': tensor([[    1,  1128,   508,   278,  1900,   360,  2960,   262,  3872,  8842,
           306,  3926,  1063,   304,   367,   297, 29715, 29973, 29871,   306,
           505,  2360,  1539,   263,  5121,  4926, 13925, 29889, 29871,  2216,
           763,   278,  1383,   514, 29914,  7858,   297, 12321,  4006, 29892,
          5625, 29876,   366,  2305,   297, 29715,   526,  7575, 29889, 29871,
         15992,   590, 29871,  7612, 26935,   304,   639, 20309,   313, 11884,
          1754,  1854,   372,   471,   577, 11410, 12115,   315,  1745,   267,
         29892,   263,  3761,   372,   471,   577,  1781, 29889, 29871,   306,
          3512,  1250,   363,   901,   313,   974,  3236, 29892,   306, 29915,
         29885,  1401,  1150,   287,   472,  3064, 29897,   322,   263,  8455,
          3614,   975,   746,   372,  2355, 19587, 29892,   541,  1584,   750,
           931,   304, 13563,   304,   592, 29889, 29871, 23350,  1090



{'labels': tensor([2])}
{'labels': tensor([2]), 'input_ids': tensor([[    1, 10791, 29901,  2428,  8444,  4480,   261, 29892,  1781,   282,
           449,   457,   313,  4187,   474,  2355,   941,  2649,   366, 29901,
           474,  1016, 29915, 29873,  1073,   825,   515,   282,   449,   457,
         29889,   541,   474,  3282, 29915, 29873, 24817,  1009, 29879,  1283,
           278,  1591, 29892,  9343,  1073, 29973,   467,  9360, 29991,   322,
          7575, 26552,   289,  1338,   314,   293,   373,   278,  4497,   328,
         29889,   320, 29876, 29905, 29876,  3200, 29901,  1922, 29892, 26072,
           414,   871,  2041,  1532,  2309, 29973,   393,  1838, 29915, 29873,
          2289,   664,   363,   592, 29889,   363,   263,  1532,  2309,  6866,
           914,   372,   471,  3117,   714, 11235,   313, 29873, 28470,   263,
          2586,   763,   727,   471,   373,   291, 22300,  6837,   297,   278,
         27654, 29892,   541,   321, 29882,   511,   541,  2289, 29892

KeyboardInterrupt: 