In [1]:
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import Dataset, load_metric, DatasetDict
import csv
import sentencepiece
import google.protobuf
import evaluate
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, DatasetDict
import torch
from tqdm import tqdm
import os

In [2]:
# Aspect categories a review can be assigned; a category's position in this
# list doubles as its integer class id.
classes = [
    "Customer Service", "Audio", "Visual", "Bluetooth",
    "Microphone", "Battery Life", "Internet", "Performance/Speed",
    "Value/Price", "Ease of Use", "Comfortability/Fit", "Undetermined",
]

# Lookup tables in both directions: category name -> id and id -> category name.
class2id = dict(zip(classes, range(len(classes))))
id2class = dict(enumerate(classes))

# Checkpoint identifier shared by the tokenizer and the classification model.
model_path = "microsoft/deberta-v3-large"

# Instantiate the tokenizer for the checkpoint named by `model_path`.
tokenizer = AutoTokenizer.from_pretrained(model_path)



In [5]:
# Load the labeled reviews from disk, skipping the header row.
# newline="" hands line-ending handling to the csv module, which it needs in
# order to parse quoted fields that contain embedded newlines correctly.
with open("reviews_labeled.csv", "r", newline="") as f:
    reader = csv.reader(f)
    next(reader, None)  # discard the header row (no-op when the file is empty)
    data = list(reader)

In [6]:
# One-element list so the function below can bump the counter without `global`.
bad_data = [0]

def preprocess_function(example):
    """Turn one labeled review into tokenized model features.

    Args:
        example: mapping expected to hold a 'Review' text and a 'Label' name.

    Returns:
        A dict with 'input_ids', 'attention_mask' and a one-hot float 'labels'
        list with one slot per entry of `classes`, or None when either
        required key is missing (the miss is counted in `bad_data` and printed).
        A label that is not a key of `class2id` yields an all-zero 'labels' list.
    """
    has_required_fields = 'Review' in example and 'Label' in example
    if not has_required_fields:
        bad_data[0] += 1
        print(f"Bad data: {bad_data[0]}")
        return None

    # One-hot target: every slot 0.0 except the slot of the known label.
    one_hot = [0.0] * len(classes)
    hot_index = class2id.get(example['Label'])
    if hot_index is not None:
        one_hot[hot_index] = 1.0

    # Pad/truncate every review to exactly 128 tokens.
    encoding = tokenizer(example['Review'], padding="max_length", truncation=True, max_length=128)

    features = {key: encoding[key] for key in ("input_ids", "attention_mask")}
    features["labels"] = one_hot
    return features

# Tokenize every CSV row whose second column holds one of the known classes;
# rows that are too short or carry an unknown label are skipped.
dataset_arr = []
for row in data:
    if len(row) <= 1 or row[1] not in classes:
        continue
    review_text, label_name = row[0], row[1]
    example = preprocess_function({"Review": review_text, "Label": label_name})
    if example:
        dataset_arr.append(example)

print(len(dataset_arr))

18687


In [7]:
from sklearn.model_selection import train_test_split

# Hold out 5% of the examples, then cut that holdout in half to obtain the
# test and validation subsets (both splits use the same fixed seed).
train_data, holdout_data = train_test_split(dataset_arr, test_size=0.05, random_state=42)
test_data, val_data = train_test_split(holdout_data, test_size=0.5, random_state=42)

for split_name, split_items in (("Train", train_data), ("Validation", val_data), ("Test", test_data)):
    print(f"{split_name}: {len(split_items)}")

Train: 17752
Validation: 468
Test: 467


In [8]:
# Feature columns every processed example carries.
_FEATURE_KEYS = ("input_ids", "attention_mask", "labels")


def _to_columns(records):
    """Pivot a list of per-example dicts into one dict of column lists.

    Args:
        records: iterable of dicts that each hold every key in `_FEATURE_KEYS`.

    Returns:
        A dict mapping each feature key to the list of that feature's values,
        in the same order as `records`.
    """
    records = list(records)
    return {key: [record[key] for record in records] for key in _FEATURE_KEYS}


# Column-oriented copies of each split. Built through the helper rather than
# with `for data in ...` loops, which would rebind the global `data` that holds
# the raw CSV rows and break a re-run of the preprocessing cell.
train_data_dict = _to_columns(train_data)
val_data_dict = _to_columns(val_data)
test_data_dict = _to_columns(test_data)

train_dataset = Dataset.from_dict(train_data_dict)
val_dataset = Dataset.from_dict(val_data_dict)
test_dataset = Dataset.from_dict(test_data_dict)

# Bundle the three splits under their conventional names.
tokenized_datasets = DatasetDict({
    "train": train_dataset,
    "validation": val_dataset,
    "test": test_dataset
})

In [9]:
def collate_fn(batch):
    """Stack a list of example dicts into a dict of batched tensors.

    Args:
        batch: non-empty list of dicts sharing the keys of the first element;
            each value must be convertible by `torch.tensor` once stacked.

    Returns:
        A dict with the first example's keys, each mapped to a tensor whose
        leading dimension indexes the examples of the batch.
    """
    collated = {}
    for field in batch[0]:
        column = [example[field] for example in batch]
        collated[field] = torch.tensor(column)
    return collated

# One DataLoader per split, all batching 8 examples through `collate_fn`;
# only the training loader shuffles its examples.
_loader_options = {"batch_size": 8, "collate_fn": collate_fn}
train_dataloader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_dataloader = DataLoader(val_dataset, **_loader_options)
test_dataloader = DataLoader(test_dataset, **_loader_options)


In [15]:
# Load the pretrained checkpoint with a classification head sized to `classes`
# and the name <-> id mappings attached to its configuration.
# NOTE(review): problem_type is "multi_label_classification" although
# preprocess_function sets at most one hot label per example - confirm this
# is the intended setting.
model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
)

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
# Padding collator built from the tokenizer.
# NOTE(review): `data_collator` is not referenced by any other cell visible
# here (the DataLoaders use `collate_fn`) - confirm it is still needed.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [18]:
def train(model, train_dataloader, val_dataloader, optimizer, training_args):
    """Fine-tune `model` with a manual epoch loop and return the trained model.

    Each epoch runs one pass over `train_dataloader` with cross-entropy loss,
    halves the learning rate (StepLR, gamma=0.5), and - when the evaluation
    strategy is "epoch" - computes the mean validation loss.

    Args:
        model: module whose forward accepts the non-'labels' batch entries as
            keyword arguments and returns an object exposing `.logits`.
        train_dataloader: sized iterable of dict batches containing 'labels'.
        val_dataloader: sized iterable of dict batches containing 'labels'.
        optimizer: optimizer already bound to `model`'s parameters.
        training_args: settings object; the attributes read are
            `num_train_epochs`, `evaluation_strategy` (or `eval_strategy`),
            `save_strategy`, `load_best_model_at_end` and `output_dir`.

    Returns:
        The model, moved to the selected device. When `load_best_model_at_end`
        is set and a checkpoint of the best-validation epoch was written
        (save strategy "epoch" or "best"), that checkpoint's weights are
        restored before returning.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print(f"Training on {device}")

    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

    # Newer transformers releases expose this setting as `eval_strategy`;
    # fall back to the older `evaluation_strategy` attribute name.
    eval_strategy = (getattr(training_args, "eval_strategy", None)
                     or getattr(training_args, "evaluation_strategy", None))
    save_strategy = training_args.save_strategy
    output_dir = training_args.output_dir

    # torch.save does not create missing directories, so make sure the target
    # exists before the first (expensive) epoch finishes.
    if save_strategy == "epoch" or save_strategy == "best":
        os.makedirs(output_dir, exist_ok=True)

    best_val_loss = float('inf')
    best_checkpoint = None  # path of the checkpoint with the lowest validation loss

    for epoch in range(training_args.num_train_epochs):
        model.train()
        train_progress_bar = tqdm(train_dataloader, total=len(train_dataloader),
                                  desc=f"Epoch {epoch + 1}/{training_args.num_train_epochs} Training")

        for i, batch in enumerate(train_progress_bar):
            optimizer.zero_grad()

            # Labels go to the external criterion, everything else to the model.
            inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
            labels = batch['labels'].to(device)
            outputs = model(**inputs)
            loss = criterion(outputs.logits, labels)
            loss.backward()
            optimizer.step()

            if (i + 1) % 40 == 0:
                print(f"Epoch {epoch + 1}, Batch {i + 1}, Loss: {loss.item():.4f}")

        scheduler.step()

        improved = False
        if eval_strategy == 'epoch':
            model.eval()
            total_val_loss = 0
            with torch.no_grad():
                for batch in tqdm(val_dataloader, desc="Validating"):
                    inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
                    labels = batch['labels'].to(device)
                    outputs = model(**inputs)
                    val_loss = criterion(outputs.logits, labels)
                    total_val_loss += val_loss.item()

            num_val_batches = len(val_dataloader)
            if num_val_batches:  # an empty validation set yields no loss to compare
                avg_val_loss = total_val_loss / num_val_batches
                print(f"Epoch {epoch + 1}, Validation Loss: {avg_val_loss:.4f}")

                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    improved = True
                    if save_strategy == "best":
                        best_checkpoint = os.path.join(output_dir, 'best_model.pth')
                        torch.save(model.state_dict(), best_checkpoint)

        if save_strategy == "epoch":
            epoch_checkpoint = os.path.join(output_dir, f'model_epoch_{epoch + 1}.pth')
            torch.save(model.state_dict(), epoch_checkpoint)
            if improved:
                # Remember which per-epoch file holds the best weights so far.
                best_checkpoint = epoch_checkpoint

    # Restore the best-validation weights whenever a checkpoint of them exists,
    # not only under the "best" save strategy.
    if training_args.load_best_model_at_end and best_checkpoint is not None:
        model.load_state_dict(torch.load(best_checkpoint, map_location=device))

    return model


In [19]:
# Settings container for the custom `train` loop. `train` reads only
# output_dir, num_train_epochs, the evaluation/save strategies and
# load_best_model_at_end from it; the per-device batch sizes are not consulted
# because the DataLoaders were built with their own batch size.
training_args = TrainingArguments(
    output_dir="classification_beta",
    learning_rate=2e-5,  # the rate the optimizer actually uses (previously 2e-4 here vs 2e-5 there)
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=5,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

# Build the optimizer from the same settings object so the configured and the
# effective hyper-parameters cannot drift apart.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=training_args.learning_rate,
    weight_decay=training_args.weight_decay,
)

# `train` returns the fine-tuned model itself (not a `Trainer`); the name
# `trainer` is kept so later cells that refer to it keep working. The former
# trailing `trainer.train()` call was removed: on a model it only switches the
# module back into training mode after fine-tuning has finished.
trainer = train(model, train_dataloader, val_dataloader, optimizer, training_args)


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Training on cuda


Epoch 1/5 Training:   2%|▏         | 40/2219 [00:17<13:28,  2.70it/s] 

Epoch 1, Batch 40, Loss: 2.1599


Epoch 1/5 Training:   4%|▎         | 80/2219 [00:33<13:05,  2.72it/s]

Epoch 1, Batch 80, Loss: 2.4659


Epoch 1/5 Training:   5%|▌         | 120/2219 [00:48<12:39,  2.76it/s]

Epoch 1, Batch 120, Loss: 2.1702


Epoch 1/5 Training:   7%|▋         | 160/2219 [01:04<12:26,  2.76it/s]

Epoch 1, Batch 160, Loss: 1.7163


Epoch 1/5 Training:   9%|▉         | 200/2219 [01:19<13:03,  2.58it/s]

Epoch 1, Batch 200, Loss: 2.4010


Epoch 1/5 Training:  11%|█         | 240/2219 [01:34<11:57,  2.76it/s]

Epoch 1, Batch 240, Loss: 1.1180


Epoch 1/5 Training:  13%|█▎        | 280/2219 [01:49<11:56,  2.71it/s]

Epoch 1, Batch 280, Loss: 1.8477


Epoch 1/5 Training:  14%|█▍        | 320/2219 [02:04<12:23,  2.56it/s]

Epoch 1, Batch 320, Loss: 1.6071


Epoch 1/5 Training:  16%|█▌        | 360/2219 [02:19<12:10,  2.54it/s]

Epoch 1, Batch 360, Loss: 1.0986


Epoch 1/5 Training:  18%|█▊        | 400/2219 [02:35<11:02,  2.75it/s]

Epoch 1, Batch 400, Loss: 1.0180


Epoch 1/5 Training:  20%|█▉        | 440/2219 [02:50<10:48,  2.74it/s]

Epoch 1, Batch 440, Loss: 1.0813


Epoch 1/5 Training:  22%|██▏       | 480/2219 [03:05<11:30,  2.52it/s]

Epoch 1, Batch 480, Loss: 0.7113


Epoch 1/5 Training:  23%|██▎       | 520/2219 [03:21<12:09,  2.33it/s]

Epoch 1, Batch 520, Loss: 0.7267


Epoch 1/5 Training:  25%|██▌       | 560/2219 [03:36<10:33,  2.62it/s]

Epoch 1, Batch 560, Loss: 0.8296


Epoch 1/5 Training:  27%|██▋       | 600/2219 [03:52<11:27,  2.35it/s]

Epoch 1, Batch 600, Loss: 0.1642


Epoch 1/5 Training:  29%|██▉       | 640/2219 [04:07<09:44,  2.70it/s]

Epoch 1, Batch 640, Loss: 1.0187


Epoch 1/5 Training:  31%|███       | 680/2219 [04:22<09:27,  2.71it/s]

Epoch 1, Batch 680, Loss: 1.5717


Epoch 1/5 Training:  32%|███▏      | 720/2219 [04:38<09:37,  2.59it/s]

Epoch 1, Batch 720, Loss: 1.0713


Epoch 1/5 Training:  34%|███▍      | 760/2219 [04:55<14:15,  1.71it/s]

Epoch 1, Batch 760, Loss: 1.5650


Epoch 1/5 Training:  36%|███▌      | 800/2219 [05:13<08:38,  2.74it/s]

Epoch 1, Batch 800, Loss: 1.3715


Epoch 1/5 Training:  38%|███▊      | 840/2219 [05:28<08:19,  2.76it/s]

Epoch 1, Batch 840, Loss: 0.5067


Epoch 1/5 Training:  40%|███▉      | 880/2219 [05:44<08:10,  2.73it/s]

Epoch 1, Batch 880, Loss: 0.8411


Epoch 1/5 Training:  41%|████▏     | 920/2219 [06:00<09:22,  2.31it/s]

Epoch 1, Batch 920, Loss: 0.9352


Epoch 1/5 Training:  43%|████▎     | 960/2219 [06:15<08:23,  2.50it/s]

Epoch 1, Batch 960, Loss: 1.1281


Epoch 1/5 Training:  45%|████▌     | 1000/2219 [06:31<08:10,  2.48it/s]

Epoch 1, Batch 1000, Loss: 1.5277


Epoch 1/5 Training:  47%|████▋     | 1040/2219 [06:47<07:16,  2.70it/s]

Epoch 1, Batch 1040, Loss: 1.1595


Epoch 1/5 Training:  49%|████▊     | 1080/2219 [07:03<06:46,  2.80it/s]

Epoch 1, Batch 1080, Loss: 0.3927


Epoch 1/5 Training:  50%|█████     | 1120/2219 [07:17<05:59,  3.05it/s]

Epoch 1, Batch 1120, Loss: 0.8381


Epoch 1/5 Training:  52%|█████▏    | 1160/2219 [07:32<06:46,  2.61it/s]

Epoch 1, Batch 1160, Loss: 0.6052


Epoch 1/5 Training:  54%|█████▍    | 1200/2219 [07:48<06:57,  2.44it/s]

Epoch 1, Batch 1200, Loss: 0.6749


Epoch 1/5 Training:  56%|█████▌    | 1240/2219 [08:04<06:35,  2.47it/s]

Epoch 1, Batch 1240, Loss: 0.2663


Epoch 1/5 Training:  58%|█████▊    | 1280/2219 [08:19<05:54,  2.65it/s]

Epoch 1, Batch 1280, Loss: 0.6941


Epoch 1/5 Training:  59%|█████▉    | 1320/2219 [08:34<05:26,  2.75it/s]

Epoch 1, Batch 1320, Loss: 0.7348


Epoch 1/5 Training:  61%|██████▏   | 1360/2219 [08:50<05:41,  2.51it/s]

Epoch 1, Batch 1360, Loss: 1.7111


Epoch 1/5 Training:  63%|██████▎   | 1400/2219 [09:05<05:23,  2.54it/s]

Epoch 1, Batch 1400, Loss: 0.9065


Epoch 1/5 Training:  65%|██████▍   | 1440/2219 [09:21<05:34,  2.33it/s]

Epoch 1, Batch 1440, Loss: 0.7441


Epoch 1/5 Training:  67%|██████▋   | 1480/2219 [09:36<04:25,  2.79it/s]

Epoch 1, Batch 1480, Loss: 0.7714


Epoch 1/5 Training:  68%|██████▊   | 1520/2219 [09:52<04:22,  2.66it/s]

Epoch 1, Batch 1520, Loss: 0.5721


Epoch 1/5 Training:  70%|███████   | 1560/2219 [10:10<05:32,  1.98it/s]

Epoch 1, Batch 1560, Loss: 0.4366


Epoch 1/5 Training:  72%|███████▏  | 1600/2219 [10:26<03:29,  2.95it/s]

Epoch 1, Batch 1600, Loss: 1.1744


Epoch 1/5 Training:  74%|███████▍  | 1640/2219 [10:40<03:33,  2.71it/s]

Epoch 1, Batch 1640, Loss: 1.0065


Epoch 1/5 Training:  76%|███████▌  | 1680/2219 [10:55<03:32,  2.53it/s]

Epoch 1, Batch 1680, Loss: 0.1682


Epoch 1/5 Training:  78%|███████▊  | 1720/2219 [11:11<03:15,  2.55it/s]

Epoch 1, Batch 1720, Loss: 1.8701


Epoch 1/5 Training:  79%|███████▉  | 1760/2219 [11:26<02:58,  2.57it/s]

Epoch 1, Batch 1760, Loss: 0.6423


Epoch 1/5 Training:  81%|████████  | 1800/2219 [11:41<02:29,  2.81it/s]

Epoch 1, Batch 1800, Loss: 0.5709


Epoch 1/5 Training:  83%|████████▎ | 1840/2219 [11:57<02:13,  2.84it/s]

Epoch 1, Batch 1840, Loss: 0.1201


Epoch 1/5 Training:  85%|████████▍ | 1880/2219 [12:12<02:19,  2.43it/s]

Epoch 1, Batch 1880, Loss: 0.5461


Epoch 1/5 Training:  87%|████████▋ | 1920/2219 [12:28<01:57,  2.53it/s]

Epoch 1, Batch 1920, Loss: 1.3960


Epoch 1/5 Training:  88%|████████▊ | 1960/2219 [12:43<01:45,  2.45it/s]

Epoch 1, Batch 1960, Loss: 1.0787


Epoch 1/5 Training:  90%|█████████ | 2000/2219 [12:59<01:20,  2.73it/s]

Epoch 1, Batch 2000, Loss: 0.4567


Epoch 1/5 Training:  92%|█████████▏| 2040/2219 [13:14<01:03,  2.80it/s]

Epoch 1, Batch 2040, Loss: 0.6481


Epoch 1/5 Training:  94%|█████████▎| 2080/2219 [13:33<01:05,  2.12it/s]

Epoch 1, Batch 2080, Loss: 0.4780


Epoch 1/5 Training:  96%|█████████▌| 2120/2219 [13:49<00:36,  2.73it/s]

Epoch 1, Batch 2120, Loss: 0.7603


Epoch 1/5 Training:  97%|█████████▋| 2160/2219 [14:04<00:22,  2.62it/s]

Epoch 1, Batch 2160, Loss: 1.2654


Epoch 1/5 Training:  99%|█████████▉| 2200/2219 [14:20<00:06,  2.73it/s]

Epoch 1, Batch 2200, Loss: 0.8894


Epoch 1/5 Training: 100%|██████████| 2219/2219 [14:27<00:00,  2.56it/s]
Validating: 100%|██████████| 59/59 [00:06<00:00,  9.38it/s]


Epoch 1, Validation Loss: 0.6192


Epoch 2/5 Training:   2%|▏         | 40/2219 [00:15<13:42,  2.65it/s]

Epoch 2, Batch 40, Loss: 0.0823


Epoch 2/5 Training:   4%|▎         | 80/2219 [00:31<13:01,  2.74it/s]

Epoch 2, Batch 80, Loss: 1.1081


Epoch 2/5 Training:   5%|▌         | 120/2219 [00:48<21:00,  1.67it/s]

Epoch 2, Batch 120, Loss: 1.0181


Epoch 2/5 Training:   7%|▋         | 160/2219 [01:06<12:47,  2.68it/s]

Epoch 2, Batch 160, Loss: 0.2738


Epoch 2/5 Training:   9%|▉         | 200/2219 [01:22<12:18,  2.74it/s]

Epoch 2, Batch 200, Loss: 0.3580


Epoch 2/5 Training:  11%|█         | 240/2219 [01:37<12:34,  2.62it/s]

Epoch 2, Batch 240, Loss: 0.4961


Epoch 2/5 Training:  13%|█▎        | 280/2219 [01:53<13:45,  2.35it/s]

Epoch 2, Batch 280, Loss: 0.2007


Epoch 2/5 Training:  14%|█▍        | 320/2219 [02:09<12:29,  2.53it/s]

Epoch 2, Batch 320, Loss: 0.1070


Epoch 2/5 Training:  16%|█▌        | 360/2219 [02:24<11:44,  2.64it/s]

Epoch 2, Batch 360, Loss: 0.4570


Epoch 2/5 Training:  18%|█▊        | 400/2219 [02:40<11:37,  2.61it/s]

Epoch 2, Batch 400, Loss: 0.0518


Epoch 2/5 Training:  20%|█▉        | 440/2219 [02:55<10:45,  2.75it/s]

Epoch 2, Batch 440, Loss: 0.0504


Epoch 2/5 Training:  22%|██▏       | 480/2219 [03:11<10:31,  2.75it/s]

Epoch 2, Batch 480, Loss: 0.1299


Epoch 2/5 Training:  23%|██▎       | 520/2219 [03:26<10:21,  2.74it/s]

Epoch 2, Batch 520, Loss: 0.0962


Epoch 2/5 Training:  25%|██▌       | 560/2219 [03:41<09:38,  2.87it/s]

Epoch 2, Batch 560, Loss: 0.6155


Epoch 2/5 Training:  27%|██▋       | 600/2219 [04:02<15:01,  1.80it/s]

Epoch 2, Batch 600, Loss: 0.9203


Epoch 2/5 Training:  29%|██▉       | 640/2219 [04:17<10:39,  2.47it/s]

Epoch 2, Batch 640, Loss: 0.1098


Epoch 2/5 Training:  31%|███       | 680/2219 [04:32<10:05,  2.54it/s]

Epoch 2, Batch 680, Loss: 0.4209


Epoch 2/5 Training:  32%|███▏      | 720/2219 [04:48<10:11,  2.45it/s]

Epoch 2, Batch 720, Loss: 0.0686


Epoch 2/5 Training:  34%|███▍      | 760/2219 [05:03<08:56,  2.72it/s]

Epoch 2, Batch 760, Loss: 0.5276


Epoch 2/5 Training:  36%|███▌      | 800/2219 [05:19<10:14,  2.31it/s]

Epoch 2, Batch 800, Loss: 0.0673


Epoch 2/5 Training:  38%|███▊      | 840/2219 [05:34<09:03,  2.54it/s]

Epoch 2, Batch 840, Loss: 0.1688


Epoch 2/5 Training:  40%|███▉      | 880/2219 [05:50<08:17,  2.69it/s]

Epoch 2, Batch 880, Loss: 0.1101


Epoch 2/5 Training:  41%|████▏     | 920/2219 [06:05<07:53,  2.74it/s]

Epoch 2, Batch 920, Loss: 0.0224


Epoch 2/5 Training:  43%|████▎     | 960/2219 [06:21<07:51,  2.67it/s]

Epoch 2, Batch 960, Loss: 0.1476


Epoch 2/5 Training:  45%|████▌     | 1000/2219 [06:36<07:20,  2.77it/s]

Epoch 2, Batch 1000, Loss: 0.4962


Epoch 2/5 Training:  47%|████▋     | 1040/2219 [06:56<08:39,  2.27it/s]

Epoch 2, Batch 1040, Loss: 0.0776


Epoch 2/5 Training:  49%|████▊     | 1080/2219 [07:11<06:59,  2.71it/s]

Epoch 2, Batch 1080, Loss: 0.3102


Epoch 2/5 Training:  50%|█████     | 1120/2219 [07:26<06:48,  2.69it/s]

Epoch 2, Batch 1120, Loss: 0.1186


Epoch 2/5 Training:  52%|█████▏    | 1160/2219 [07:42<06:27,  2.74it/s]

Epoch 2, Batch 1160, Loss: 0.5321


Epoch 2/5 Training:  54%|█████▍    | 1200/2219 [07:57<06:21,  2.67it/s]

Epoch 2, Batch 1200, Loss: 0.0217


Epoch 2/5 Training:  56%|█████▌    | 1240/2219 [08:13<06:22,  2.56it/s]

Epoch 2, Batch 1240, Loss: 0.3715


Epoch 2/5 Training:  58%|█████▊    | 1280/2219 [08:28<05:56,  2.63it/s]

Epoch 2, Batch 1280, Loss: 0.1603


Epoch 2/5 Training:  59%|█████▉    | 1320/2219 [08:43<05:20,  2.80it/s]

Epoch 2, Batch 1320, Loss: 0.4424


Epoch 2/5 Training:  61%|██████▏   | 1360/2219 [08:57<05:30,  2.60it/s]

Epoch 2, Batch 1360, Loss: 0.1503


Epoch 2/5 Training:  63%|██████▎   | 1400/2219 [09:13<05:01,  2.71it/s]

Epoch 2, Batch 1400, Loss: 1.6617


Epoch 2/5 Training:  65%|██████▍   | 1440/2219 [09:28<05:05,  2.55it/s]

Epoch 2, Batch 1440, Loss: 1.4780


Epoch 2/5 Training:  67%|██████▋   | 1480/2219 [09:43<04:40,  2.64it/s]

Epoch 2, Batch 1480, Loss: 0.5561


Epoch 2/5 Training:  68%|██████▊   | 1520/2219 [09:58<04:23,  2.65it/s]

Epoch 2, Batch 1520, Loss: 0.1442


Epoch 2/5 Training:  70%|███████   | 1560/2219 [10:14<04:01,  2.73it/s]

Epoch 2, Batch 1560, Loss: 0.1442


Epoch 2/5 Training:  72%|███████▏  | 1600/2219 [10:29<03:43,  2.77it/s]

Epoch 2, Batch 1600, Loss: 0.1045


Epoch 2/5 Training:  74%|███████▍  | 1640/2219 [10:44<03:41,  2.61it/s]

Epoch 2, Batch 1640, Loss: 0.8079


Epoch 2/5 Training:  76%|███████▌  | 1680/2219 [11:00<04:02,  2.22it/s]

Epoch 2, Batch 1680, Loss: 0.1951


Epoch 2/5 Training:  78%|███████▊  | 1720/2219 [11:16<03:12,  2.59it/s]

Epoch 2, Batch 1720, Loss: 0.2212


Epoch 2/5 Training:  79%|███████▉  | 1760/2219 [11:32<03:03,  2.50it/s]

Epoch 2, Batch 1760, Loss: 0.2899


Epoch 2/5 Training:  81%|████████  | 1800/2219 [11:48<02:48,  2.49it/s]

Epoch 2, Batch 1800, Loss: 0.3825


Epoch 2/5 Training:  83%|████████▎ | 1840/2219 [12:04<02:18,  2.74it/s]

Epoch 2, Batch 1840, Loss: 0.1531


Epoch 2/5 Training:  85%|████████▍ | 1880/2219 [12:23<03:06,  1.82it/s]

Epoch 2, Batch 1880, Loss: 0.3613


Epoch 2/5 Training:  87%|████████▋ | 1920/2219 [12:39<01:51,  2.68it/s]

Epoch 2, Batch 1920, Loss: 0.2808


Epoch 2/5 Training:  88%|████████▊ | 1960/2219 [12:55<01:46,  2.43it/s]

Epoch 2, Batch 1960, Loss: 0.0989


Epoch 2/5 Training:  90%|█████████ | 2000/2219 [13:11<01:29,  2.45it/s]

Epoch 2, Batch 2000, Loss: 0.0836


Epoch 2/5 Training:  92%|█████████▏| 2040/2219 [13:26<01:05,  2.73it/s]

Epoch 2, Batch 2040, Loss: 0.0570


Epoch 2/5 Training:  94%|█████████▎| 2080/2219 [13:42<00:54,  2.57it/s]

Epoch 2, Batch 2080, Loss: 0.1892


Epoch 2/5 Training:  96%|█████████▌| 2120/2219 [13:57<00:42,  2.35it/s]

Epoch 2, Batch 2120, Loss: 0.2925


Epoch 2/5 Training:  97%|█████████▋| 2160/2219 [14:14<00:25,  2.32it/s]

Epoch 2, Batch 2160, Loss: 0.0826


Epoch 2/5 Training:  99%|█████████▉| 2200/2219 [14:29<00:08,  2.24it/s]

Epoch 2, Batch 2200, Loss: 0.2114


Epoch 2/5 Training: 100%|██████████| 2219/2219 [14:37<00:00,  2.53it/s]
Validating: 100%|██████████| 59/59 [00:06<00:00,  9.36it/s]


Epoch 2, Validation Loss: 0.3499


Epoch 3/5 Training:   2%|▏         | 40/2219 [00:14<14:22,  2.53it/s]

Epoch 3, Batch 40, Loss: 0.1091


Epoch 3/5 Training:   4%|▎         | 80/2219 [00:30<13:53,  2.57it/s]

Epoch 3, Batch 80, Loss: 0.3378


Epoch 3/5 Training:   5%|▌         | 120/2219 [00:45<12:44,  2.74it/s]

Epoch 3, Batch 120, Loss: 0.0171


Epoch 3/5 Training:   7%|▋         | 160/2219 [01:01<14:45,  2.32it/s]

Epoch 3, Batch 160, Loss: 0.9156


Epoch 3/5 Training:   9%|▉         | 200/2219 [01:16<12:27,  2.70it/s]

Epoch 3, Batch 200, Loss: 0.1264


Epoch 3/5 Training:  11%|█         | 240/2219 [01:32<12:57,  2.55it/s]

Epoch 3, Batch 240, Loss: 0.7651


Epoch 3/5 Training:  13%|█▎        | 280/2219 [01:47<13:20,  2.42it/s]

Epoch 3, Batch 280, Loss: 0.4818


Epoch 3/5 Training:  14%|█▍        | 320/2219 [02:03<13:18,  2.38it/s]

Epoch 3, Batch 320, Loss: 0.0253


Epoch 3/5 Training:  16%|█▌        | 360/2219 [02:18<12:10,  2.54it/s]

Epoch 3, Batch 360, Loss: 0.1636


Epoch 3/5 Training:  18%|█▊        | 400/2219 [02:37<14:04,  2.16it/s]

Epoch 3, Batch 400, Loss: 0.1589


Epoch 3/5 Training:  20%|█▉        | 440/2219 [02:53<11:48,  2.51it/s]

Epoch 3, Batch 440, Loss: 0.0537


Epoch 3/5 Training:  22%|██▏       | 480/2219 [03:08<10:08,  2.86it/s]

Epoch 3, Batch 480, Loss: 0.8667


Epoch 3/5 Training:  23%|██▎       | 520/2219 [03:23<10:16,  2.76it/s]

Epoch 3, Batch 520, Loss: 0.0689


Epoch 3/5 Training:  25%|██▌       | 560/2219 [03:39<10:32,  2.62it/s]

Epoch 3, Batch 560, Loss: 0.0555


Epoch 3/5 Training:  27%|██▋       | 600/2219 [03:54<10:13,  2.64it/s]

Epoch 3, Batch 600, Loss: 0.5857


Epoch 3/5 Training:  29%|██▉       | 640/2219 [04:10<10:11,  2.58it/s]

Epoch 3, Batch 640, Loss: 0.3543


Epoch 3/5 Training:  31%|███       | 680/2219 [04:25<09:36,  2.67it/s]

Epoch 3, Batch 680, Loss: 0.3365


Epoch 3/5 Training:  32%|███▏      | 720/2219 [04:40<09:11,  2.72it/s]

Epoch 3, Batch 720, Loss: 0.0984


Epoch 3/5 Training:  34%|███▍      | 760/2219 [04:56<09:07,  2.66it/s]

Epoch 3, Batch 760, Loss: 0.0236


Epoch 3/5 Training:  36%|███▌      | 800/2219 [05:11<09:18,  2.54it/s]

Epoch 3, Batch 800, Loss: 0.4492


Epoch 3/5 Training:  38%|███▊      | 840/2219 [05:26<08:36,  2.67it/s]

Epoch 3, Batch 840, Loss: 0.0829


Epoch 3/5 Training:  40%|███▉      | 880/2219 [05:41<08:41,  2.57it/s]

Epoch 3, Batch 880, Loss: 0.0428


Epoch 3/5 Training:  41%|████▏     | 920/2219 [05:57<07:55,  2.73it/s]

Epoch 3, Batch 920, Loss: 0.0381


Epoch 3/5 Training:  43%|████▎     | 960/2219 [06:12<07:27,  2.81it/s]

Epoch 3, Batch 960, Loss: 0.1239


Epoch 3/5 Training:  45%|████▌     | 1000/2219 [06:30<09:56,  2.04it/s]

Epoch 3, Batch 1000, Loss: 0.5444


Epoch 3/5 Training:  47%|████▋     | 1040/2219 [06:47<08:07,  2.42it/s]

Epoch 3, Batch 1040, Loss: 0.6544


Epoch 3/5 Training:  49%|████▊     | 1080/2219 [07:02<07:33,  2.51it/s]

Epoch 3, Batch 1080, Loss: 0.0176


Epoch 3/5 Training:  50%|█████     | 1120/2219 [07:17<06:34,  2.79it/s]

Epoch 3, Batch 1120, Loss: 0.0104


Epoch 3/5 Training:  52%|█████▏    | 1160/2219 [07:32<06:42,  2.63it/s]

Epoch 3, Batch 1160, Loss: 0.1496


Epoch 3/5 Training:  54%|█████▍    | 1200/2219 [07:47<06:25,  2.64it/s]

Epoch 3, Batch 1200, Loss: 0.0527


Epoch 3/5 Training:  56%|█████▌    | 1240/2219 [08:02<06:04,  2.69it/s]

Epoch 3, Batch 1240, Loss: 0.0327


Epoch 3/5 Training:  58%|█████▊    | 1280/2219 [08:17<05:30,  2.84it/s]

Epoch 3, Batch 1280, Loss: 0.0143


Epoch 3/5 Training:  59%|█████▉    | 1320/2219 [08:32<05:44,  2.61it/s]

Epoch 3, Batch 1320, Loss: 0.0149


Epoch 3/5 Training:  61%|██████▏   | 1360/2219 [08:46<04:47,  2.99it/s]

Epoch 3, Batch 1360, Loss: 0.3503


Epoch 3/5 Training:  63%|██████▎   | 1400/2219 [09:01<04:58,  2.74it/s]

Epoch 3, Batch 1400, Loss: 0.6916


Epoch 3/5 Training:  65%|██████▍   | 1440/2219 [09:16<04:42,  2.76it/s]

Epoch 3, Batch 1440, Loss: 0.0388


Epoch 3/5 Training:  67%|██████▋   | 1480/2219 [09:30<04:53,  2.52it/s]

Epoch 3, Batch 1480, Loss: 0.0191


Epoch 3/5 Training:  68%|██████▊   | 1520/2219 [09:45<04:34,  2.55it/s]

Epoch 3, Batch 1520, Loss: 0.0527


Epoch 3/5 Training:  70%|███████   | 1560/2219 [10:00<04:07,  2.67it/s]

Epoch 3, Batch 1560, Loss: 0.0211


Epoch 3/5 Training:  72%|███████▏  | 1600/2219 [10:15<03:47,  2.72it/s]

Epoch 3, Batch 1600, Loss: 0.0301


Epoch 3/5 Training:  74%|███████▍  | 1640/2219 [10:30<03:51,  2.51it/s]

Epoch 3, Batch 1640, Loss: 0.9453


Epoch 3/5 Training:  76%|███████▌  | 1680/2219 [10:45<03:15,  2.76it/s]

Epoch 3, Batch 1680, Loss: 0.3837


Epoch 3/5 Training:  78%|███████▊  | 1720/2219 [10:59<03:04,  2.71it/s]

Epoch 3, Batch 1720, Loss: 0.0172


Epoch 3/5 Training:  79%|███████▉  | 1760/2219 [11:17<04:29,  1.70it/s]

Epoch 3, Batch 1760, Loss: 0.1879


Epoch 3/5 Training:  81%|████████  | 1800/2219 [11:34<02:34,  2.71it/s]

Epoch 3, Batch 1800, Loss: 0.0312


Epoch 3/5 Training:  83%|████████▎ | 1840/2219 [11:49<02:20,  2.69it/s]

Epoch 3, Batch 1840, Loss: 0.1170


Epoch 3/5 Training:  85%|████████▍ | 1880/2219 [12:04<02:07,  2.66it/s]

Epoch 3, Batch 1880, Loss: 0.1268


Epoch 3/5 Training:  87%|████████▋ | 1920/2219 [12:18<01:44,  2.85it/s]

Epoch 3, Batch 1920, Loss: 0.0347


Epoch 3/5 Training:  88%|████████▊ | 1960/2219 [12:33<01:36,  2.69it/s]

Epoch 3, Batch 1960, Loss: 0.0287


Epoch 3/5 Training:  90%|█████████ | 2000/2219 [12:48<01:21,  2.69it/s]

Epoch 3, Batch 2000, Loss: 0.1178


Epoch 3/5 Training:  92%|█████████▏| 2040/2219 [13:03<01:02,  2.88it/s]

Epoch 3, Batch 2040, Loss: 0.0563


Epoch 3/5 Training:  94%|█████████▎| 2080/2219 [13:19<00:55,  2.53it/s]

Epoch 3, Batch 2080, Loss: 0.0232


Epoch 3/5 Training:  96%|█████████▌| 2120/2219 [13:35<00:38,  2.58it/s]

Epoch 3, Batch 2120, Loss: 0.0447


Epoch 3/5 Training:  97%|█████████▋| 2160/2219 [13:50<00:21,  2.74it/s]

Epoch 3, Batch 2160, Loss: 0.2669


Epoch 3/5 Training:  99%|█████████▉| 2200/2219 [14:10<00:06,  2.72it/s]

Epoch 3, Batch 2200, Loss: 0.2184


Epoch 3/5 Training: 100%|██████████| 2219/2219 [14:17<00:00,  2.59it/s]
Validating: 100%|██████████| 59/59 [00:06<00:00,  9.59it/s]


Epoch 3, Validation Loss: 0.4341


Epoch 4/5 Training:   2%|▏         | 40/2219 [00:15<13:23,  2.71it/s]

Epoch 4, Batch 40, Loss: 0.0894


Epoch 4/5 Training:   4%|▎         | 80/2219 [00:29<12:35,  2.83it/s]

Epoch 4, Batch 80, Loss: 0.0420


Epoch 4/5 Training:   5%|▌         | 120/2219 [00:44<12:37,  2.77it/s]

Epoch 4, Batch 120, Loss: 0.0197


Epoch 4/5 Training:   7%|▋         | 160/2219 [00:59<12:46,  2.68it/s]

Epoch 4, Batch 160, Loss: 0.0221


Epoch 4/5 Training:   9%|▉         | 200/2219 [01:14<12:31,  2.69it/s]

Epoch 4, Batch 200, Loss: 0.0117


Epoch 4/5 Training:  11%|█         | 240/2219 [01:29<13:14,  2.49it/s]

Epoch 4, Batch 240, Loss: 0.3190


Epoch 4/5 Training:  13%|█▎        | 280/2219 [01:45<13:45,  2.35it/s]

Epoch 4, Batch 280, Loss: 0.0621


Epoch 4/5 Training:  14%|█▍        | 320/2219 [02:00<13:34,  2.33it/s]

Epoch 4, Batch 320, Loss: 0.0537


Epoch 4/5 Training:  16%|█▌        | 360/2219 [02:15<10:59,  2.82it/s]

Epoch 4, Batch 360, Loss: 0.0402


Epoch 4/5 Training:  18%|█▊        | 400/2219 [02:30<12:19,  2.46it/s]

Epoch 4, Batch 400, Loss: 0.0308


Epoch 4/5 Training:  20%|█▉        | 440/2219 [02:45<10:53,  2.72it/s]

Epoch 4, Batch 440, Loss: 0.0263


Epoch 4/5 Training:  22%|██▏       | 480/2219 [02:59<09:59,  2.90it/s]

Epoch 4, Batch 480, Loss: 0.0111


Epoch 4/5 Training:  23%|██▎       | 520/2219 [03:14<09:54,  2.86it/s]

Epoch 4, Batch 520, Loss: 0.0222


Epoch 4/5 Training:  25%|██▌       | 560/2219 [03:28<09:59,  2.77it/s]

Epoch 4, Batch 560, Loss: 0.0087


Epoch 4/5 Training:  27%|██▋       | 600/2219 [03:44<09:26,  2.86it/s]

Epoch 4, Batch 600, Loss: 0.0144


Epoch 4/5 Training:  29%|██▉       | 640/2219 [03:58<09:04,  2.90it/s]

Epoch 4, Batch 640, Loss: 0.0340


Epoch 4/5 Training:  31%|███       | 680/2219 [04:19<09:18,  2.76it/s]

Epoch 4, Batch 680, Loss: 0.0090


Epoch 4/5 Training:  32%|███▏      | 720/2219 [04:34<09:15,  2.70it/s]

Epoch 4, Batch 720, Loss: 0.0119


Epoch 4/5 Training:  34%|███▍      | 760/2219 [04:49<09:58,  2.44it/s]

Epoch 4, Batch 760, Loss: 0.0387


Epoch 4/5 Training:  36%|███▌      | 800/2219 [05:04<08:21,  2.83it/s]

Epoch 4, Batch 800, Loss: 0.0156


Epoch 4/5 Training:  38%|███▊      | 840/2219 [05:19<08:39,  2.65it/s]

Epoch 4, Batch 840, Loss: 0.1257


Epoch 4/5 Training:  40%|███▉      | 880/2219 [05:35<08:32,  2.61it/s]

Epoch 4, Batch 880, Loss: 0.0095


Epoch 4/5 Training:  41%|████▏     | 920/2219 [05:50<07:32,  2.87it/s]

Epoch 4, Batch 920, Loss: 0.0065


Epoch 4/5 Training:  43%|████▎     | 960/2219 [06:05<07:34,  2.77it/s]

Epoch 4, Batch 960, Loss: 0.0207


Epoch 4/5 Training:  45%|████▌     | 1000/2219 [06:20<07:51,  2.59it/s]

Epoch 4, Batch 1000, Loss: 0.4309


Epoch 4/5 Training:  47%|████▋     | 1040/2219 [06:35<07:57,  2.47it/s]

Epoch 4, Batch 1040, Loss: 0.0366


Epoch 4/5 Training:  49%|████▊     | 1080/2219 [06:50<06:57,  2.73it/s]

Epoch 4, Batch 1080, Loss: 0.3201


Epoch 4/5 Training:  50%|█████     | 1120/2219 [07:05<06:56,  2.64it/s]

Epoch 4, Batch 1120, Loss: 0.0628


Epoch 4/5 Training:  52%|█████▏    | 1160/2219 [07:20<05:54,  2.98it/s]

Epoch 4, Batch 1160, Loss: 0.0784


Epoch 4/5 Training:  54%|█████▍    | 1200/2219 [07:35<06:21,  2.67it/s]

Epoch 4, Batch 1200, Loss: 0.0510


Epoch 4/5 Training:  56%|█████▌    | 1240/2219 [07:50<06:02,  2.70it/s]

Epoch 4, Batch 1240, Loss: 0.0130


Epoch 4/5 Training:  58%|█████▊    | 1280/2219 [08:06<06:47,  2.31it/s]

Epoch 4, Batch 1280, Loss: 0.4660


Epoch 4/5 Training:  59%|█████▉    | 1320/2219 [08:22<05:50,  2.56it/s]

Epoch 4, Batch 1320, Loss: 0.0756


Epoch 4/5 Training:  61%|██████▏   | 1360/2219 [08:37<06:06,  2.34it/s]

Epoch 4, Batch 1360, Loss: 0.0087


Epoch 4/5 Training:  63%|██████▎   | 1400/2219 [08:53<06:49,  2.00it/s]

Epoch 4, Batch 1400, Loss: 0.0075


Epoch 4/5 Training:  65%|██████▍   | 1440/2219 [09:12<04:27,  2.91it/s]

Epoch 4, Batch 1440, Loss: 0.0134


Epoch 4/5 Training:  67%|██████▋   | 1480/2219 [09:27<04:27,  2.76it/s]

Epoch 4, Batch 1480, Loss: 0.0152


Epoch 4/5 Training:  68%|██████▊   | 1520/2219 [09:42<04:12,  2.77it/s]

Epoch 4, Batch 1520, Loss: 0.0032


Epoch 4/5 Training:  70%|███████   | 1560/2219 [09:57<04:00,  2.75it/s]

Epoch 4, Batch 1560, Loss: 0.1362


Epoch 4/5 Training:  72%|███████▏  | 1600/2219 [10:11<03:30,  2.95it/s]

Epoch 4, Batch 1600, Loss: 0.0186


Epoch 4/5 Training:  74%|███████▍  | 1640/2219 [10:26<03:38,  2.65it/s]

Epoch 4, Batch 1640, Loss: 0.0928


Epoch 4/5 Training:  76%|███████▌  | 1680/2219 [10:41<03:33,  2.53it/s]

Epoch 4, Batch 1680, Loss: 0.1775


Epoch 4/5 Training:  78%|███████▊  | 1720/2219 [10:57<03:08,  2.65it/s]

Epoch 4, Batch 1720, Loss: 0.0145


Epoch 4/5 Training:  79%|███████▉  | 1760/2219 [11:12<02:42,  2.83it/s]

Epoch 4, Batch 1760, Loss: 0.0880


Epoch 4/5 Training:  81%|████████  | 1800/2219 [11:27<02:29,  2.80it/s]

Epoch 4, Batch 1800, Loss: 0.0573


Epoch 4/5 Training:  83%|████████▎ | 1840/2219 [11:42<02:21,  2.68it/s]

Epoch 4, Batch 1840, Loss: 0.0947


Epoch 4/5 Training:  85%|████████▍ | 1880/2219 [11:57<02:05,  2.71it/s]

Epoch 4, Batch 1880, Loss: 0.2907


Epoch 4/5 Training:  87%|████████▋ | 1920/2219 [12:12<01:50,  2.71it/s]

Epoch 4, Batch 1920, Loss: 0.0608


Epoch 4/5 Training:  88%|████████▊ | 1960/2219 [12:27<01:39,  2.61it/s]

Epoch 4, Batch 1960, Loss: 0.0285


Epoch 4/5 Training:  90%|█████████ | 2000/2219 [12:42<01:23,  2.63it/s]

Epoch 4, Batch 2000, Loss: 0.6235


Epoch 4/5 Training:  92%|█████████▏| 2040/2219 [13:00<01:27,  2.04it/s]

Epoch 4, Batch 2040, Loss: 0.0151


Epoch 4/5 Training:  94%|█████████▎| 2080/2219 [13:17<00:49,  2.80it/s]

Epoch 4, Batch 2080, Loss: 0.4716


Epoch 4/5 Training:  96%|█████████▌| 2120/2219 [13:33<00:38,  2.58it/s]

Epoch 4, Batch 2120, Loss: 0.0043


Epoch 4/5 Training:  97%|█████████▋| 2160/2219 [13:48<00:21,  2.69it/s]

Epoch 4, Batch 2160, Loss: 0.0041


Epoch 4/5 Training:  99%|█████████▉| 2200/2219 [14:03<00:07,  2.66it/s]

Epoch 4, Batch 2200, Loss: 0.0071


Epoch 4/5 Training: 100%|██████████| 2219/2219 [14:10<00:00,  2.61it/s]
Validating: 100%|██████████| 59/59 [00:06<00:00,  9.08it/s]


Epoch 4, Validation Loss: 0.4115


Epoch 5/5 Training:   2%|▏         | 40/2219 [00:14<13:31,  2.68it/s]

Epoch 5, Batch 40, Loss: 0.0027


Epoch 5/5 Training:   4%|▎         | 80/2219 [00:30<12:46,  2.79it/s]

Epoch 5, Batch 80, Loss: 0.0094


Epoch 5/5 Training:   5%|▌         | 120/2219 [00:45<13:45,  2.54it/s]

Epoch 5, Batch 120, Loss: 0.0061


Epoch 5/5 Training:   7%|▋         | 160/2219 [01:00<12:41,  2.71it/s]

Epoch 5, Batch 160, Loss: 0.0045


Epoch 5/5 Training:   9%|▉         | 200/2219 [01:15<12:38,  2.66it/s]

Epoch 5, Batch 200, Loss: 0.0727


Epoch 5/5 Training:  11%|█         | 240/2219 [01:30<11:57,  2.76it/s]

Epoch 5, Batch 240, Loss: 0.0038


Epoch 5/5 Training:  13%|█▎        | 280/2219 [01:46<12:04,  2.67it/s]

Epoch 5, Batch 280, Loss: 0.0422


Epoch 5/5 Training:  14%|█▍        | 320/2219 [02:01<11:45,  2.69it/s]

Epoch 5, Batch 320, Loss: 0.0632


Epoch 5/5 Training:  16%|█▌        | 360/2219 [02:16<11:54,  2.60it/s]

Epoch 5, Batch 360, Loss: 0.0389


Epoch 5/5 Training:  18%|█▊        | 400/2219 [02:32<12:49,  2.36it/s]

Epoch 5, Batch 400, Loss: 0.0066


Epoch 5/5 Training:  20%|█▉        | 440/2219 [02:48<11:11,  2.65it/s]

Epoch 5, Batch 440, Loss: 0.0054


Epoch 5/5 Training:  22%|██▏       | 480/2219 [03:08<13:45,  2.11it/s]

Epoch 5, Batch 480, Loss: 0.4725


Epoch 5/5 Training:  23%|██▎       | 520/2219 [03:23<10:34,  2.68it/s]

Epoch 5, Batch 520, Loss: 0.0041


Epoch 5/5 Training:  25%|██▌       | 560/2219 [03:39<10:21,  2.67it/s]

Epoch 5, Batch 560, Loss: 0.0921


Epoch 5/5 Training:  27%|██▋       | 600/2219 [03:54<09:46,  2.76it/s]

Epoch 5, Batch 600, Loss: 0.0065


Epoch 5/5 Training:  29%|██▉       | 640/2219 [04:09<09:34,  2.75it/s]

Epoch 5, Batch 640, Loss: 0.0398


Epoch 5/5 Training:  31%|███       | 680/2219 [04:24<09:41,  2.64it/s]

Epoch 5, Batch 680, Loss: 0.0102


Epoch 5/5 Training:  32%|███▏      | 720/2219 [04:40<09:59,  2.50it/s]

Epoch 5, Batch 720, Loss: 0.0019


Epoch 5/5 Training:  34%|███▍      | 760/2219 [04:56<09:12,  2.64it/s]

Epoch 5, Batch 760, Loss: 0.0125


Epoch 5/5 Training:  36%|███▌      | 800/2219 [05:12<09:16,  2.55it/s]

Epoch 5, Batch 800, Loss: 0.0030


Epoch 5/5 Training:  38%|███▊      | 840/2219 [05:28<08:38,  2.66it/s]

Epoch 5, Batch 840, Loss: 0.0030


Epoch 5/5 Training:  40%|███▉      | 880/2219 [05:44<08:35,  2.60it/s]

Epoch 5, Batch 880, Loss: 0.0112


Epoch 5/5 Training:  41%|████▏     | 920/2219 [05:59<08:33,  2.53it/s]

Epoch 5, Batch 920, Loss: 0.0229


Epoch 5/5 Training:  43%|████▎     | 960/2219 [06:14<07:22,  2.84it/s]

Epoch 5, Batch 960, Loss: 0.0099


Epoch 5/5 Training:  45%|████▌     | 1000/2219 [06:29<08:40,  2.34it/s]

Epoch 5, Batch 1000, Loss: 0.0082


Epoch 5/5 Training:  47%|████▋     | 1040/2219 [06:44<07:27,  2.63it/s]

Epoch 5, Batch 1040, Loss: 0.0149


Epoch 5/5 Training:  49%|████▊     | 1080/2219 [06:59<09:06,  2.08it/s]

Epoch 5, Batch 1080, Loss: 0.0043


Epoch 5/5 Training:  50%|█████     | 1120/2219 [07:18<07:00,  2.62it/s]

Epoch 5, Batch 1120, Loss: 0.0041


Epoch 5/5 Training:  52%|█████▏    | 1160/2219 [07:33<06:20,  2.78it/s]

Epoch 5, Batch 1160, Loss: 0.0391


Epoch 5/5 Training:  54%|█████▍    | 1200/2219 [07:48<06:33,  2.59it/s]

Epoch 5, Batch 1200, Loss: 0.5127


Epoch 5/5 Training:  56%|█████▌    | 1240/2219 [08:04<06:36,  2.47it/s]

Epoch 5, Batch 1240, Loss: 0.0322


Epoch 5/5 Training:  58%|█████▊    | 1280/2219 [08:19<05:50,  2.68it/s]

Epoch 5, Batch 1280, Loss: 0.0097


Epoch 5/5 Training:  59%|█████▉    | 1320/2219 [08:34<05:54,  2.53it/s]

Epoch 5, Batch 1320, Loss: 0.0140


Epoch 5/5 Training:  61%|██████▏   | 1360/2219 [08:50<05:15,  2.72it/s]

Epoch 5, Batch 1360, Loss: 0.0253


Epoch 5/5 Training:  63%|██████▎   | 1400/2219 [09:05<04:56,  2.76it/s]

Epoch 5, Batch 1400, Loss: 0.0024


Epoch 5/5 Training:  65%|██████▍   | 1440/2219 [09:21<05:06,  2.54it/s]

Epoch 5, Batch 1440, Loss: 0.0060


Epoch 5/5 Training:  67%|██████▋   | 1480/2219 [09:37<04:17,  2.87it/s]

Epoch 5, Batch 1480, Loss: 0.0161


Epoch 5/5 Training:  68%|██████▊   | 1520/2219 [09:57<04:37,  2.52it/s]

Epoch 5, Batch 1520, Loss: 0.0036


Epoch 5/5 Training:  70%|███████   | 1560/2219 [10:12<03:58,  2.76it/s]

Epoch 5, Batch 1560, Loss: 0.0034


Epoch 5/5 Training:  72%|███████▏  | 1600/2219 [10:27<04:03,  2.54it/s]

Epoch 5, Batch 1600, Loss: 0.0139


Epoch 5/5 Training:  74%|███████▍  | 1640/2219 [10:43<03:43,  2.59it/s]

Epoch 5, Batch 1640, Loss: 0.0086


Epoch 5/5 Training:  76%|███████▌  | 1680/2219 [10:58<03:28,  2.58it/s]

Epoch 5, Batch 1680, Loss: 0.0120


Epoch 5/5 Training:  78%|███████▊  | 1720/2219 [11:13<03:07,  2.66it/s]

Epoch 5, Batch 1720, Loss: 0.0488


Epoch 5/5 Training:  79%|███████▉  | 1760/2219 [11:28<02:47,  2.75it/s]

Epoch 5, Batch 1760, Loss: 0.0089


Epoch 5/5 Training:  81%|████████  | 1800/2219 [11:44<02:52,  2.43it/s]

Epoch 5, Batch 1800, Loss: 0.0753


Epoch 5/5 Training:  83%|████████▎ | 1840/2219 [11:59<02:15,  2.79it/s]

Epoch 5, Batch 1840, Loss: 0.0318


Epoch 5/5 Training:  85%|████████▍ | 1880/2219 [12:14<02:03,  2.73it/s]

Epoch 5, Batch 1880, Loss: 0.0113


Epoch 5/5 Training:  87%|████████▋ | 1920/2219 [12:29<01:49,  2.74it/s]

Epoch 5, Batch 1920, Loss: 0.0280


Epoch 5/5 Training:  88%|████████▊ | 1960/2219 [12:44<01:34,  2.73it/s]

Epoch 5, Batch 1960, Loss: 0.0326


Epoch 5/5 Training:  90%|█████████ | 2000/2219 [12:59<01:20,  2.73it/s]

Epoch 5, Batch 2000, Loss: 0.4260


Epoch 5/5 Training:  92%|█████████▏| 2040/2219 [13:15<01:06,  2.70it/s]

Epoch 5, Batch 2040, Loss: 0.0504


Epoch 5/5 Training:  94%|█████████▎| 2080/2219 [13:30<00:55,  2.50it/s]

Epoch 5, Batch 2080, Loss: 0.0777


Epoch 5/5 Training:  96%|█████████▌| 2120/2219 [13:46<00:38,  2.55it/s]

Epoch 5, Batch 2120, Loss: 0.0041


Epoch 5/5 Training:  97%|█████████▋| 2160/2219 [14:00<00:22,  2.64it/s]

Epoch 5, Batch 2160, Loss: 0.0151


Epoch 5/5 Training:  99%|█████████▉| 2200/2219 [14:15<00:06,  2.98it/s]

Epoch 5, Batch 2200, Loss: 0.0011


Epoch 5/5 Training: 100%|██████████| 2219/2219 [14:24<00:00,  2.57it/s]
Validating: 100%|██████████| 59/59 [00:08<00:00,  7.08it/s]


Epoch 5, Validation Loss: 0.4653


DebertaV2ForSequenceClassification(
  (deberta): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 1024, padding_idx=0)
      (LayerNorm): LayerNorm((1024,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-23): 24 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (key_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (value_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (pos_dropout): StableDropout()
              (dropout): StableDropout()
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-07, element

In [10]:
import torch.nn.functional as F

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Rebuild the classifier architecture from the base checkpoint. from_pretrained
# warns that the classifier/pooler weights are newly initialized; that is
# expected here because the fine-tuned state dict loaded below replaces them
# (load_state_dict is strict by default, so every parameter must be present).
model = AutoModelForSequenceClassification.from_pretrained(
    model_path,
    num_labels=len(classes),
    id2label=id2class,
    label2id=class2id,
    problem_type="multi_label_classification",
)

# map_location remaps the saved tensors onto the active device so that a
# checkpoint written from a GPU run can still be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load('classification_beta/model_epoch_5.pth', map_location=device)
model.load_state_dict(state_dict)
model = model.to(device)

# Switch to inference mode (disables dropout); as the last expression this
# also displays the model summary as the cell output.
model.eval()

Some weights of DebertaV2ForSequenceClassification were not initialized from the model checkpoint at microsoft/deberta-v3-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DebertaV2ForSequenceClassification(
  (deberta): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 1024, padding_idx=0)
      (LayerNorm): LayerNorm((1024,), eps=1e-07, elementwise_affine=True)
      (dropout): StableDropout()
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-23): 24 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (key_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (value_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (pos_dropout): StableDropout()
              (dropout): StableDropout()
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-07, element

In [11]:
# Score the loaded model on the test split: mean per-batch loss plus
# exact-match accuracy between the predicted class and the label vector.
criterion = torch.nn.CrossEntropyLoss()

total_test_loss = 0
correct_predictions = 0
total_predictions = 0

num_classes = len(classes)

with torch.no_grad():
    for batch in tqdm(test_dataloader, desc="Testing"):
        # Split the batch into the label tensor and the model's keyword inputs.
        labels = batch['labels'].to(device)
        inputs = {name: tensor.to(device) for name, tensor in batch.items() if name != 'labels'}

        outputs = model(**inputs)
        test_loss = criterion(outputs.logits, labels)
        total_test_loss += test_loss.item()

        # Highest-scoring class per row, re-encoded as a one-hot vector so it
        # can be compared element-wise with the label vectors.
        predicted_labels = outputs.logits.argmax(dim=1)
        predicted_labels_one_hot = F.one_hot(predicted_labels, num_classes=num_classes)

        # one_hot produces int64, so cast the labels to the same dtype before comparing.
        labels = labels.long()

        # A row counts as correct only when every position of the vector matches.
        exact_matches = (predicted_labels_one_hot == labels).all(dim=1)
        correct_predictions += exact_matches.sum().item()
        total_predictions += labels.size(0)

avg_test_loss = total_test_loss / len(test_dataloader)
accuracy = correct_predictions / total_predictions * 100

print(f"Average test loss: {avg_test_loss:.2f}")
print(f"Accuracy: {accuracy:.2f}%")


Testing: 100%|██████████| 59/59 [00:05<00:00, 10.75it/s]

Average test loss: 0.46
Accuracy: 88.65%





In [12]:
# custom input into the model
# Each sample sentence is echoed in blue (ANSI escape codes), tokenized with the
# same settings used for the training data (padded/truncated to 128 tokens),
# and assigned the class whose logit is highest.
custom_texts = [
    "I wish this device would not run out of power so quickly",
    "I can't hear a thing on this crappy headset",
]

# Inference only: no_grad avoids building an autograd graph for each forward pass.
with torch.no_grad():
    for custom_text in custom_texts:
        print("\033[94m" + custom_text + "\033[0m")

        tokenized_text = tokenizer(custom_text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
        # Move every tokenizer output tensor onto the same device as the model.
        tokenized_text = {k: v.to(device) for k, v in tokenized_text.items()}
        output = model(**tokenized_text)

        # Index of the largest logit -> human-readable class name.
        predicted_labels = torch.max(output.logits, 1)[1]
        predicted_class = id2class[predicted_labels.item()]

        print(f"Predicted class: {predicted_class}")

[94mI wish this device would not run out of power so quickly[0m
Predicted class: Battery Life
[94mI can't hear a thing on this crappy headset[0m
Predicted class: Audio
