In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification
from transformers import get_scheduler
from torch.optim import AdamW   
import torch
from tqdm.auto import tqdm
# from datasets import load_metric

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset("yelp_review_full", cache_dir='.')

Generating train split: 100%|██████████| 650000/650000 [00:01<00:00, 614866.46 examples/s]
Generating test split: 100%|██████████| 50000/50000 [00:00<00:00, 634840.76 examples/s]


In [3]:
small_train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = dataset["test"].shuffle(seed=42).select(range(1000))


In [4]:
tokenizer = AutoTokenizer.from_pretrained("Llama-encoder-1.0B")
tokenizer.pad_token = tokenizer.eos_token
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_train = small_train_dataset.map(tokenize_function, batched=True)



Map:   0%|          | 0/1000 [00:00<?, ? examples/s]Asking to pad to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no padding.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 1000/1000 [00:00<00:00, 8094.22 examples/s]


In [5]:
tokenized_train = tokenized_train.remove_columns(["text"])
tokenized_train = tokenized_train.rename_column("label", "labels")


In [None]:
train_dataloader = DataLoader(tokenized_train, shuffle=True, batch_size=1)
model = AutoModelForSequenceClassification.from_pretrained("Llama-encoder-1.0B", num_labels=5)
model.config.pad_token_id = model.config.eos_token_id
optimizer = AdamW(model.parameters(), lr=5e-5)
num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
    name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

progress_bar = tqdm(range(num_training_steps))

model.train()
# model.model.gradient_checkpointing = True

for epoch in range(num_epochs):
    for batch in train_dataloader:

        preprocess_batch = {}
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                preprocess_batch[k] = v.to(device)
            elif isinstance(v, list):
                preprocess_batch[k] = torch.stack(v, dim=1).to(device)

        for k, v in preprocess_batch.items():
            print(k,v.shape)
        # batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**preprocess_batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)


Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at Llama-encoder-1.0B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  0%|          | 0/3000 [00:00<?, ?it/s]

labels torch.Size([1])
input_ids torch.Size([1, 110])
attention_mask torch.Size([1, 110])


  0%|          | 1/3000 [00:08<7:11:25,  8.63s/it]

labels torch.Size([1])
input_ids torch.Size([1, 90])
attention_mask torch.Size([1, 90])


  0%|          | 2/3000 [00:13<5:19:01,  6.38s/it]

labels torch.Size([1])
input_ids torch.Size([1, 152])
attention_mask torch.Size([1, 152])


  0%|          | 3/3000 [00:19<5:06:11,  6.13s/it]

labels torch.Size([1])
input_ids torch.Size([1, 206])
attention_mask torch.Size([1, 206])


  0%|          | 4/3000 [00:26<5:25:24,  6.52s/it]

labels torch.Size([1])
input_ids torch.Size([1, 501])
attention_mask torch.Size([1, 501])


  0%|          | 5/3000 [00:40<7:40:53,  9.23s/it]

labels torch.Size([1])
input_ids torch.Size([1, 240])
attention_mask torch.Size([1, 240])


  0%|          | 6/3000 [00:48<7:23:38,  8.89s/it]

labels torch.Size([1])
input_ids torch.Size([1, 62])
attention_mask torch.Size([1, 62])


  0%|          | 7/3000 [00:53<6:16:03,  7.54s/it]

labels torch.Size([1])
input_ids torch.Size([1, 476])
attention_mask torch.Size([1, 476])


  0%|          | 8/3000 [01:06<7:39:16,  9.21s/it]

labels torch.Size([1])
input_ids torch.Size([1, 398])
attention_mask torch.Size([1, 398])


  0%|          | 9/3000 [01:17<8:07:03,  9.77s/it]

labels torch.Size([1])
input_ids torch.Size([1, 411])
attention_mask torch.Size([1, 411])


  0%|          | 10/3000 [01:28<8:32:14, 10.28s/it]

labels torch.Size([1])
input_ids torch.Size([1, 70])
attention_mask torch.Size([1, 70])


  0%|          | 11/3000 [01:33<7:15:29,  8.74s/it]

labels torch.Size([1])
input_ids torch.Size([1, 366])
attention_mask torch.Size([1, 366])


  0%|          | 12/3000 [01:45<8:00:37,  9.65s/it]

labels torch.Size([1])
input_ids torch.Size([1, 29])
attention_mask torch.Size([1, 29])


  0%|          | 13/3000 [01:50<6:42:27,  8.08s/it]

labels torch.Size([1])
input_ids torch.Size([1, 157])
attention_mask torch.Size([1, 157])


  0%|          | 14/3000 [01:57<6:32:45,  7.89s/it]

labels torch.Size([1])
input_ids torch.Size([1, 229])
attention_mask torch.Size([1, 229])


  0%|          | 15/3000 [02:06<6:52:36,  8.29s/it]

labels torch.Size([1])
input_ids torch.Size([1, 66])
attention_mask torch.Size([1, 66])


  1%|          | 16/3000 [02:11<6:06:09,  7.36s/it]

labels torch.Size([1])
input_ids torch.Size([1, 117])
attention_mask torch.Size([1, 117])


  1%|          | 17/3000 [02:18<5:49:34,  7.03s/it]

labels torch.Size([1])
input_ids torch.Size([1, 162])
attention_mask torch.Size([1, 162])


  1%|          | 18/3000 [02:25<5:50:14,  7.05s/it]

labels torch.Size([1])
input_ids torch.Size([1, 473])
attention_mask torch.Size([1, 473])


  1%|          | 19/3000 [02:39<7:43:49,  9.34s/it]

labels torch.Size([1])
input_ids torch.Size([1, 573])
attention_mask torch.Size([1, 573])


  1%|          | 20/3000 [02:57<9:45:37, 11.79s/it]

labels torch.Size([1])
input_ids torch.Size([1, 64])
attention_mask torch.Size([1, 64])


  1%|          | 21/3000 [03:02<8:10:51,  9.89s/it]

labels torch.Size([1])
input_ids torch.Size([1, 75])
attention_mask torch.Size([1, 75])


  1%|          | 22/3000 [03:08<7:03:48,  8.54s/it]

labels torch.Size([1])
input_ids torch.Size([1, 20])
attention_mask torch.Size([1, 20])


  1%|          | 23/3000 [03:12<6:00:04,  7.26s/it]

labels torch.Size([1])
input_ids torch.Size([1, 172])
attention_mask torch.Size([1, 172])


  1%|          | 24/3000 [03:20<6:07:04,  7.40s/it]

labels torch.Size([1])
input_ids torch.Size([1, 353])
attention_mask torch.Size([1, 353])


  1%|          | 25/3000 [03:32<7:12:58,  8.73s/it]

labels torch.Size([1])
input_ids torch.Size([1, 27])
attention_mask torch.Size([1, 27])


  1%|          | 26/3000 [03:36<6:11:51,  7.50s/it]

labels torch.Size([1])
input_ids torch.Size([1, 124])
attention_mask torch.Size([1, 124])


  1%|          | 27/3000 [03:43<6:01:16,  7.29s/it]

labels torch.Size([1])
input_ids torch.Size([1, 377])
attention_mask torch.Size([1, 377])


  1%|          | 28/3000 [03:56<7:31:35,  9.12s/it]

labels torch.Size([1])
input_ids torch.Size([1, 529])
attention_mask torch.Size([1, 529])
