In [3]:
pip install torch torchvision transformers datasets peft




In [4]:
import torch
from torch.utils.data import DataLoader
from transformers import DebertaV2Tokenizer, DebertaV2ForSequenceClassification, AdamW, DataCollatorWithPadding
from datasets import load_dataset
from tqdm import tqdm
from peft import LoraConfig, get_peft_model

# Hub identifiers used by this cell, named once so they are easy to spot
GLUE_BENCHMARK, GLUE_SUBSET = 'glue', 'mrpc'
TOKENIZER_CHECKPOINT = 'microsoft/deberta-v3-large'

# Fetch the MRPC configuration of the GLUE benchmark (all available splits)
dataset = load_dataset(GLUE_BENCHMARK, GLUE_SUBSET)

# Instantiate the tokenizer that ships with the DeBERTa checkpoint
tokenizer = DebertaV2Tokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

# Tokenize the dataset
def tokenize_function(examples, max_length=128):
    """Tokenize a batch of sentence pairs without padding them.

    Padding is deliberately NOT applied here: the DataLoaders in this notebook
    use ``DataCollatorWithPadding``, which pads each batch to the length of its
    longest member. Padding every example to ``max_length`` up front would make
    that collator a no-op and force every batch to the full length.

    Args:
        examples: Mapping with 'sentence1' and 'sentence2' entries, as handed
            over by ``dataset.map(..., batched=True)``.
        max_length: Maximum number of tokens kept per encoded pair; longer
            pairs are truncated. Defaults to 128.

    Returns:
        The tokenizer's encoding of the sentence pairs.
    """
    return tokenizer(
        examples['sentence1'],
        examples['sentence2'],
        truncation=True,
        max_length=max_length,
    )

# Seed PyTorch's global RNG before anything random happens (DataLoader
# shuffling, the classifier/pooler weights that the load warning reports as
# newly initialized) so repeated runs start from the same state.
# NOTE: some GPU kernels may still be nondeterministic; this fixes the seed only.
SEED = 42
torch.manual_seed(SEED)

# Apply the tokenizer to every split, one batch of examples at a time
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Ensure the labels are present and correctly formatted:
# the training code passes the targets to the model under the name "labels"
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
# Expose only the model inputs and the labels, returned as torch tensors
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

# Data collator that will dynamically pad the inputs
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Create DataLoaders (only the training split is shuffled)
BATCH_SIZE = 8
train_dataloader = DataLoader(
    tokenized_datasets['train'],
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=data_collator,
)
val_dataloader = DataLoader(
    tokenized_datasets['validation'],
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
)

# Device configuration: use the GPU when one is available
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Load a pretrained DeBERTa model for sequence classification (two classes)
model = DebertaV2ForSequenceClassification.from_pretrained('microsoft/deberta-v3-large', num_labels=2)

# Define LoRA configuration
best_r = 4  # Example value, should be tuned
lora_config = LoraConfig(
    # Modules listed here are kept fully trainable alongside the adapters
    modules_to_save=['classifier', 'pooler'],
    task_type="SEQ_CLS",
    r=best_r,
    # Adapters are attached to the attention query and value projections
    target_modules=["query_proj", "value_proj"],
    lora_alpha=64
)

# Integrate LoRA with the model
model = get_peft_model(model, lora_config)
model.to(device)

def train(model, train_dataloader, optimizer, device, criterion=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Callable module. It is invoked as ``model(**inputs, labels=labels)``
            and must return an object exposing ``.loss`` (or ``.logits`` when a
            custom ``criterion`` is supplied).
        train_dataloader: Iterable of dict batches. Each batch must contain a
            'labels' entry; every other entry is forwarded to the model as a
            keyword argument.
        optimizer: Optimizer whose ``zero_grad()`` / ``step()`` are called once
            per batch.
        device: Device every tensor in the batch is moved to.
        criterion: Optional callable ``criterion(logits, labels) -> loss``.
            When ``None`` (the default) the loss computed by the model itself
            (``outputs.loss``) is used. Previously this argument was accepted
            but silently ignored.

    Returns:
        The average of the per-batch loss values as a float, or ``0.0`` when
        the dataloader yields no batches.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0

    for batch in tqdm(train_dataloader, desc="Training"):
        inputs = {key: val.to(device) for key, val in batch.items() if key != 'labels'}
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        if criterion is None:
            # Let the model compute its own loss from the labels.
            loss = model(**inputs, labels=labels).loss
        else:
            # Honour the caller-supplied loss function.
            loss = criterion(model(**inputs).logits, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    # Count batches ourselves so an empty loader cannot divide by zero.
    if num_batches == 0:
        return 0.0
    return running_loss / num_batches

# Optimizer
# Use PyTorch's own AdamW rather than the deprecated transformers.AdamW.
# eps and weight_decay are set explicitly to the values the transformers
# implementation used by default, so the update rule stays the same.
# Only parameters that require gradients are handed to the optimizer, since the
# frozen ones never receive an update.
optimizer = torch.optim.AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=2e-5,
    eps=1e-6,
    weight_decay=0.0,
)

def evaluate(model, val_dataloader, device):
    """Compute classification accuracy of ``model`` over ``val_dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    Each batch is a dict: its 'labels' entry holds the targets and every other
    entry is forwarded to the model as a keyword argument. The predicted class
    is the argmax over the last dimension of ``outputs.logits``.

    Returns:
        Number of correct predictions divided by the number of labels seen.
    """
    model.eval()
    hits, seen = 0, 0
    progress = tqdm(val_dataloader, desc="Evaluating")

    with torch.no_grad():
        for batch in progress:
            targets = batch['labels'].to(device)

            # Everything except the targets is a model input.
            features = {}
            for name, tensor in batch.items():
                if name != 'labels':
                    features[name] = tensor.to(device)

            logits = model(**features).logits
            guesses = logits.argmax(dim=-1)

            hits += (guesses == targets).sum().item()
            seen += targets.shape[0]

    return hits / seen

# Training loop: one optimisation pass over the training split per epoch,
# followed by an accuracy check on the validation split.
num_epochs = 5
for epoch in range(num_epochs):
    train_loss = train(
        model=model,
        train_dataloader=train_dataloader,
        optimizer=optimizer,
        device=device,
    )
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}")

    val_accuracy = evaluate(
        model=model,
        val_dataloader=val_dataloader,
        device=device,
    )
    print(f"Validation Accuracy: {val_accuracy:.4f}")




pytorch_model.bin:  41%|####      | 357M/874M [00:00<?, ?B/s]

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Training:   0%|          | 0/459 [00:00<?, ?it/s][A
Training:   0%|          | 1/459 [00:02<16:23,  2.15s/it][A
Training:   0%|          | 2/459 [00:02<08:55,  1.17s/it][A
Training:   1%|          | 3/459 [00:03<06:32,  1.16it/s][A
Training:   1%|          | 4/459 [00:03<05:24,  1.40it/s][A
Training:   1%|          | 5/459 [00:04<04:46,  1.58it/s][A
Training:   1%|▏         | 6/459 [00:04<04:23,  1.72it/s][A
Training:   2%|▏         | 7/459 [00:05<04:09,  1.81it/s][A
Training:   2%|▏         | 8/459 [00:05<03:59,  1.88it/s][A
Training:   2%|▏         | 9/459 [00:06<03:53,  1.93it/s][A
Training:   2%|▏         | 10/459 [00:06<03:49

Epoch 1/5, Train Loss: 0.6294



Evaluating:   0%|          | 0/51 [00:00<?, ?it/s][A
Evaluating:   2%|▏         | 1/51 [00:00<00:13,  3.63it/s][A
Evaluating:   4%|▍         | 2/51 [00:00<00:12,  3.88it/s][A
Evaluating:   6%|▌         | 3/51 [00:00<00:12,  3.93it/s][A
Evaluating:   8%|▊         | 4/51 [00:01<00:11,  3.92it/s][A
Evaluating:  10%|▉         | 5/51 [00:01<00:11,  3.88it/s][A
Evaluating:  12%|█▏        | 6/51 [00:01<00:11,  3.90it/s][A
Evaluating:  14%|█▎        | 7/51 [00:01<00:11,  3.92it/s][A
Evaluating:  16%|█▌        | 8/51 [00:02<00:10,  3.92it/s][A
Evaluating:  18%|█▊        | 9/51 [00:02<00:10,  3.89it/s][A
Evaluating:  20%|█▉        | 10/51 [00:02<00:10,  3.91it/s][A
Evaluating:  22%|██▏       | 11/51 [00:02<00:10,  3.93it/s][A
Evaluating:  24%|██▎       | 12/51 [00:03<00:09,  3.93it/s][A
Evaluating:  25%|██▌       | 13/51 [00:03<00:09,  3.92it/s][A
Evaluating:  27%|██▋       | 14/51 [00:03<00:09,  3.92it/s][A
Evaluating:  29%|██▉       | 15/51 [00:03<00:09,  3.93it/s][A
Evaluatin

Validation Accuracy: 0.6985



Training:   0%|          | 0/459 [00:00<?, ?it/s][A
Training:   0%|          | 1/459 [00:00<03:51,  1.97it/s][A
Training:   0%|          | 2/459 [00:01<03:54,  1.95it/s][A
Training:   1%|          | 3/459 [00:01<03:55,  1.94it/s][A
Training:   1%|          | 4/459 [00:02<03:55,  1.94it/s][A
Training:   1%|          | 5/459 [00:02<03:55,  1.93it/s][A
Training:   1%|▏         | 6/459 [00:03<03:55,  1.93it/s][A
Training:   2%|▏         | 7/459 [00:03<03:55,  1.92it/s][A
Training:   2%|▏         | 8/459 [00:04<03:54,  1.93it/s][A
Training:   2%|▏         | 9/459 [00:04<03:52,  1.93it/s][A
Training:   2%|▏         | 10/459 [00:05<03:52,  1.93it/s][A
Training:   2%|▏         | 11/459 [00:05<03:51,  1.94it/s][A
Training:   3%|▎         | 12/459 [00:06<03:51,  1.93it/s][A
Training:   3%|▎         | 13/459 [00:06<03:50,  1.94it/s][A
Training:   3%|▎         | 14/459 [00:07<03:49,  1.94it/s][A
Training:   3%|▎         | 15/459 [00:07<03:49,  1.93it/s][A
Training:   3%|▎         

Epoch 2/5, Train Loss: 0.4673



Evaluating:   0%|          | 0/51 [00:00<?, ?it/s][A
Evaluating:   2%|▏         | 1/51 [00:00<00:12,  4.13it/s][A
Evaluating:   4%|▍         | 2/51 [00:00<00:12,  4.06it/s][A
Evaluating:   6%|▌         | 3/51 [00:00<00:11,  4.00it/s][A
Evaluating:   8%|▊         | 4/51 [00:01<00:11,  3.94it/s][A
Evaluating:  10%|▉         | 5/51 [00:01<00:11,  3.95it/s][A
Evaluating:  12%|█▏        | 6/51 [00:01<00:11,  3.95it/s][A
Evaluating:  14%|█▎        | 7/51 [00:01<00:11,  3.98it/s][A
Evaluating:  16%|█▌        | 8/51 [00:02<00:10,  3.95it/s][A
Evaluating:  18%|█▊        | 9/51 [00:02<00:10,  3.96it/s][A
Evaluating:  20%|█▉        | 10/51 [00:02<00:10,  3.97it/s][A
Evaluating:  22%|██▏       | 11/51 [00:02<00:10,  3.97it/s][A
Evaluating:  24%|██▎       | 12/51 [00:03<00:09,  3.97it/s][A
Evaluating:  25%|██▌       | 13/51 [00:03<00:09,  3.97it/s][A
Evaluating:  27%|██▋       | 14/51 [00:03<00:09,  3.98it/s][A
Evaluating:  29%|██▉       | 15/51 [00:03<00:09,  3.96it/s][A
Evaluatin

Validation Accuracy: 0.8505



Training:   0%|          | 0/459 [00:00<?, ?it/s][A
Training:   0%|          | 1/459 [00:00<03:50,  1.99it/s][A
Training:   0%|          | 2/459 [00:01<03:54,  1.95it/s][A
Training:   1%|          | 3/459 [00:01<03:53,  1.95it/s][A
Training:   1%|          | 4/459 [00:02<03:52,  1.96it/s][A
Training:   1%|          | 5/459 [00:02<03:52,  1.95it/s][A
Training:   1%|▏         | 6/459 [00:03<03:51,  1.95it/s][A
Training:   2%|▏         | 7/459 [00:03<03:51,  1.95it/s][A
Training:   2%|▏         | 8/459 [00:04<03:51,  1.95it/s][A
Training:   2%|▏         | 9/459 [00:04<03:51,  1.94it/s][A
Training:   2%|▏         | 10/459 [00:05<03:51,  1.94it/s][A
Training:   2%|▏         | 11/459 [00:05<03:50,  1.95it/s][A
Training:   3%|▎         | 12/459 [00:06<03:48,  1.95it/s][A
Training:   3%|▎         | 13/459 [00:06<03:50,  1.94it/s][A
Training:   3%|▎         | 14/459 [00:07<03:49,  1.94it/s][A
Training:   3%|▎         | 15/459 [00:07<03:49,  1.94it/s][A
Training:   3%|▎         

Epoch 3/5, Train Loss: 0.3451



Evaluating:   0%|          | 0/51 [00:00<?, ?it/s][A
Evaluating:   2%|▏         | 1/51 [00:00<00:12,  4.12it/s][A
Evaluating:   4%|▍         | 2/51 [00:00<00:12,  4.04it/s][A
Evaluating:   6%|▌         | 3/51 [00:00<00:12,  4.00it/s][A
Evaluating:   8%|▊         | 4/51 [00:01<00:11,  3.96it/s][A
Evaluating:  10%|▉         | 5/51 [00:01<00:11,  3.96it/s][A
Evaluating:  12%|█▏        | 6/51 [00:01<00:11,  3.98it/s][A
Evaluating:  14%|█▎        | 7/51 [00:01<00:11,  3.99it/s][A
Evaluating:  16%|█▌        | 8/51 [00:02<00:10,  3.97it/s][A
Evaluating:  18%|█▊        | 9/51 [00:02<00:10,  3.97it/s][A
Evaluating:  20%|█▉        | 10/51 [00:02<00:10,  3.97it/s][A
Evaluating:  22%|██▏       | 11/51 [00:02<00:10,  3.98it/s][A
Evaluating:  24%|██▎       | 12/51 [00:03<00:09,  3.97it/s][A
Evaluating:  25%|██▌       | 13/51 [00:03<00:09,  3.97it/s][A
Evaluating:  27%|██▋       | 14/51 [00:03<00:09,  3.96it/s][A
Evaluating:  29%|██▉       | 15/51 [00:03<00:09,  3.97it/s][A
Evaluatin

Validation Accuracy: 0.8946



Training:   0%|          | 0/459 [00:00<?, ?it/s][A
Training:   0%|          | 1/459 [00:00<03:50,  1.98it/s][A
Training:   0%|          | 2/459 [00:01<03:53,  1.96it/s][A
Training:   1%|          | 3/459 [00:01<03:52,  1.96it/s][A
Training:   1%|          | 4/459 [00:02<03:52,  1.96it/s][A
Training:   1%|          | 5/459 [00:02<03:52,  1.95it/s][A
Training:   1%|▏         | 6/459 [00:03<03:53,  1.94it/s][A
Training:   2%|▏         | 7/459 [00:03<03:52,  1.95it/s][A
Training:   2%|▏         | 8/459 [00:04<03:52,  1.94it/s][A
Training:   2%|▏         | 9/459 [00:04<03:52,  1.93it/s][A
Training:   2%|▏         | 10/459 [00:05<03:51,  1.94it/s][A
Training:   2%|▏         | 11/459 [00:05<03:50,  1.94it/s][A
Training:   3%|▎         | 12/459 [00:06<03:50,  1.94it/s][A
Training:   3%|▎         | 13/459 [00:06<03:50,  1.93it/s][A
Training:   3%|▎         | 14/459 [00:07<03:50,  1.93it/s][A
Training:   3%|▎         | 15/459 [00:07<03:49,  1.93it/s][A
Training:   3%|▎         

Epoch 4/5, Train Loss: 0.2950



Evaluating:   0%|          | 0/51 [00:00<?, ?it/s][A
Evaluating:   2%|▏         | 1/51 [00:00<00:12,  4.14it/s][A
Evaluating:   4%|▍         | 2/51 [00:00<00:12,  4.05it/s][A
Evaluating:   6%|▌         | 3/51 [00:00<00:12,  4.00it/s][A
Evaluating:   8%|▊         | 4/51 [00:01<00:11,  3.93it/s][A
Evaluating:  10%|▉         | 5/51 [00:01<00:11,  3.96it/s][A
Evaluating:  12%|█▏        | 6/51 [00:01<00:11,  3.96it/s][A
Evaluating:  14%|█▎        | 7/51 [00:01<00:11,  3.97it/s][A
Evaluating:  16%|█▌        | 8/51 [00:02<00:10,  3.95it/s][A
Evaluating:  18%|█▊        | 9/51 [00:02<00:10,  3.97it/s][A
Evaluating:  20%|█▉        | 10/51 [00:02<00:10,  3.95it/s][A
Evaluating:  22%|██▏       | 11/51 [00:02<00:10,  3.94it/s][A
Evaluating:  24%|██▎       | 12/51 [00:03<00:09,  3.95it/s][A
Evaluating:  25%|██▌       | 13/51 [00:03<00:09,  3.95it/s][A
Evaluating:  27%|██▋       | 14/51 [00:03<00:09,  3.95it/s][A
Evaluating:  29%|██▉       | 15/51 [00:03<00:09,  3.95it/s][A
Evaluatin

Validation Accuracy: 0.8824



Training:   0%|          | 0/459 [00:00<?, ?it/s][A
Training:   0%|          | 1/459 [00:00<03:52,  1.97it/s][A
Training:   0%|          | 2/459 [00:01<03:55,  1.94it/s][A
Training:   1%|          | 3/459 [00:01<03:55,  1.94it/s][A
Training:   1%|          | 4/459 [00:02<03:54,  1.94it/s][A
Training:   1%|          | 5/459 [00:02<03:53,  1.95it/s][A
Training:   1%|▏         | 6/459 [00:03<03:53,  1.94it/s][A
Training:   2%|▏         | 7/459 [00:03<03:53,  1.93it/s][A
Training:   2%|▏         | 8/459 [00:04<03:53,  1.93it/s][A
Training:   2%|▏         | 9/459 [00:04<03:52,  1.94it/s][A
Training:   2%|▏         | 10/459 [00:05<03:51,  1.94it/s][A
Training:   2%|▏         | 11/459 [00:05<03:51,  1.94it/s][A
Training:   3%|▎         | 12/459 [00:06<03:51,  1.93it/s][A
Training:   3%|▎         | 13/459 [00:06<03:50,  1.94it/s][A
Training:   3%|▎         | 14/459 [00:07<03:50,  1.93it/s][A
Training:   3%|▎         | 15/459 [00:07<03:49,  1.93it/s][A
Training:   3%|▎         

Epoch 5/5, Train Loss: 0.2608



Evaluating:   0%|          | 0/51 [00:00<?, ?it/s][A
Evaluating:   2%|▏         | 1/51 [00:00<00:12,  4.16it/s][A
Evaluating:   4%|▍         | 2/51 [00:00<00:12,  4.06it/s][A
Evaluating:   6%|▌         | 3/51 [00:00<00:11,  4.03it/s][A
Evaluating:   8%|▊         | 4/51 [00:00<00:11,  3.98it/s][A
Evaluating:  10%|▉         | 5/51 [00:01<00:11,  3.99it/s][A
Evaluating:  12%|█▏        | 6/51 [00:01<00:11,  3.98it/s][A
Evaluating:  14%|█▎        | 7/51 [00:01<00:11,  3.98it/s][A
Evaluating:  16%|█▌        | 8/51 [00:02<00:10,  3.97it/s][A
Evaluating:  18%|█▊        | 9/51 [00:02<00:10,  3.96it/s][A
Evaluating:  20%|█▉        | 10/51 [00:02<00:10,  3.95it/s][A
Evaluating:  22%|██▏       | 11/51 [00:02<00:10,  3.98it/s][A
Evaluating:  24%|██▎       | 12/51 [00:03<00:09,  3.97it/s][A
Evaluating:  25%|██▌       | 13/51 [00:03<00:09,  3.96it/s][A
Evaluating:  27%|██▋       | 14/51 [00:03<00:09,  3.96it/s][A
Evaluating:  29%|██▉       | 15/51 [00:03<00:09,  3.96it/s][A
Evaluatin

Validation Accuracy: 0.8971



