In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import torchmetrics
import pytorch_lightning as pl
from torchmetrics.classification import MultilabelPrecision, MultilabelRecall, MultilabelF1Score
from pytorch_lightning.loggers import CSVLogger



splits = {'java_train': 'data/java_train-00000-of-00001.parquet', 'java_test': 'data/java_test-00000-of-00001.parquet', 'python_train': 'data/python_train-00000-of-00001.parquet', 'python_test': 'data/python_test-00000-of-00001.parquet', 'pharo_train': 'data/pharo_train-00000-of-00001.parquet', 'pharo_test': 'data/pharo_test-00000-of-00001.parquet'}

java_train = pd.read_parquet("hf://datasets/NLBSE/nlbse25-code-comment-classification/" + splits["java_train"])
java_test = pd.read_parquet("hf://datasets/NLBSE/nlbse25-code-comment-classification/" + splits["java_test"])

python_train = pd.read_parquet("hf://datasets/NLBSE/nlbse25-code-comment-classification/" + splits["python_train"])
python_test = pd.read_parquet("hf://datasets/NLBSE/nlbse25-code-comment-classification/" + splits["python_test"])

pharo_train = pd.read_parquet("hf://datasets/NLBSE/nlbse25-code-comment-classification/" + splits["pharo_train"])
pharo_test = pd.read_parquet("hf://datasets/NLBSE/nlbse25-code-comment-classification/" + splits["pharo_test"])

# Split Java dataset
java_train_data, java_val_data = train_test_split(java_train, test_size=0.2, random_state=42)

# Split Python dataset
python_train_data, python_val_data = train_test_split(python_train, test_size=0.2, random_state=42)

# Split Pharo dataset
pharo_train_data, pharo_val_data = train_test_split(pharo_train, test_size=0.2, random_state=42)

print(f"Java train size: {len(java_train_data)}, Java val size: {len(java_val_data)}")
print(f"Python train size: {len(python_train_data)}, Python val size: {len(python_val_data)}")
print(f"Pharo train size: {len(pharo_train_data)}, Pharo val size: {len(pharo_val_data)}")

#print(java_train.iloc[0, :])
#print(python_train.iloc[0, :])
#print(pharo_train.iloc[0, :])

  from .autonotebook import tqdm as notebook_tqdm


Java train size: 6091, Java val size: 1523
Python train size: 1507, Python val size: 377
Pharo train size: 1038, Pharo val size: 260


In [6]:
class JavaCommentDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.comments = dataframe['combo'].tolist()
        self.labels = dataframe['labels'].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.comments)
    
    def __getitem__(self, idx):
        # Tokenize the text
        text = self.comments[idx]
        tokens = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # Process labels
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        
        # Reshape input for CNN
        input_ids = tokens['input_ids'].squeeze(0)
        
        # Reshape embeddings to match CNN input format [batch_size, channels, sequence_length, embedding_dim]
        cnn_input = input_ids.unsqueeze(0)
        
        return {
            'input_ids': cnn_input,
            'labels': label
        }

In [7]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
max_len = 512 

# Prepare Dataset
train_dataset = JavaCommentDataset(java_train_data, tokenizer, max_len)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

val_dataset = JavaCommentDataset(java_val_data, tokenizer, max_len)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

test_dataset = JavaCommentDataset(java_test, tokenizer, max_len)
test_loader = DataLoader(test_dataset, batch_size=32)


In [8]:
from pytorch_lightning.callbacks import EarlyStopping
early_stopping = EarlyStopping(
    monitor="val_loss",  # Metric to monitor
    mode="min",          # Use "max" if monitoring accuracy
    patience=10,          # Number of epochs with no improvement after which training will be stopped
    verbose=True         # Prints information about early stopping when triggered
)

In [9]:
class PyTorchCNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_classes):
        super(PyTorchCNN, self).__init__()
        # Embedding layer
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_dim)
        # CNN layers definition
        self.cnn_layers = torch.nn.Sequential(
            torch.nn.Conv2d(1, 3, kernel_size=(5, embed_dim)),  # Example embedding size
            torch.nn.BatchNorm2d(3),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 1)), 
            
            torch.nn.Conv2d(3, 16, kernel_size=(3, 1)),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 1)), 
            
            torch.nn.Conv2d(16, 32, kernel_size=(3, 1)),
            torch.nn.BatchNorm2d(32),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=(2, 1))
        )
        # Dynamically calculate flattened size
        self.flattened_size = self._get_flattened_size(embed_dim)
        # Fully connected layers
        self.fc_layers = torch.nn.Sequential(
            torch.nn.Linear(self.flattened_size, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.ReLU(),

            torch.nn.Linear(512, 256),                  # Second layer
            torch.nn.BatchNorm1d(256),                  
            torch.nn.ReLU(),

            torch.nn.Linear(256, 128),                  # Third layer
            torch.nn.BatchNorm1d(128),                  
            torch.nn.ReLU(),

            torch.nn.Linear(128, num_classes)  # Output layer
        )
    def _get_flattened_size(self, embed_dim):
        """
        Computes the size of the flattened output after the CNN layers.
        """
        with torch.no_grad():
            # Create a dummy input of shape [batch_size=1, channels=1, sequence_length, embed_dim]
            dummy_input = torch.zeros(1, 1, 512, embed_dim)
            cnn_out = self.cnn_layers(dummy_input)
            return cnn_out.numel()
        
    def forward(self, input_ids):
        embeddings = self.embedding(input_ids)  # Shape: [batch_size, sequence_length, embed_dim]

        #Pass through CNN layers
        cnn_out = self.cnn_layers(embeddings) 

        cnn_out = torch.flatten(cnn_out, 1) 

        output = self.fc_layers(cnn_out) 

        return output


In [10]:
class LightningModel(pl.LightningModule):
    def __init__(self, model, learning_rate, num_classes):
        super().__init__()

        self.learning_rate = learning_rate
        self.model = model
        self.num_classes = num_classes
        # Metrics
        self.train_acc = torchmetrics.Accuracy(task="multilabel", num_labels=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multilabel", num_labels=num_classes)
        self.test_acc = torchmetrics.Accuracy(task="multilabel", num_labels=num_classes)

        # Initialize class-wise accuracy tracking
        self.class_wise_train_acc = {i: torchmetrics.Accuracy(task="multilabel", num_labels=num_classes) for i in range(num_classes)}
        self.class_wise_val_acc = {i: torchmetrics.Accuracy(task="multilabel", num_labels=num_classes) for i in range(num_classes)}
        self.class_wise_test_acc = {i: torchmetrics.Accuracy(task="multilabel", num_labels=num_classes) for i in range(num_classes)}
        # Precision Metrics
        self.train_precision = MultilabelPrecision(num_labels=num_classes, average="none")
        self.val_precision = MultilabelPrecision(num_labels=num_classes, average="none")
        self.test_precision = MultilabelPrecision(num_labels=num_classes, average="none")
        # Recall Metrics
        self.train_recall = MultilabelRecall(num_labels=num_classes, average="none")
        self.val_recall = MultilabelRecall(num_labels=num_classes, average="none")
        self.test_recall = MultilabelRecall(num_labels=num_classes, average="none")

        # F1 Metrics
        self.train_f1 = MultilabelF1Score(num_labels=num_classes, average="none")
        self.val_f1 = MultilabelF1Score(num_labels=num_classes, average="none")
        self.test_f1 = MultilabelF1Score(num_labels=num_classes, average="none")

        
    def forward(self, x):
        return self.model(x)

    def _shared_step(self, batch):
        input_ids = batch['input_ids']  # Tokenized input
        true_labels = batch['labels']  # Multi-hot encoded labels

        logits = self.model(input_ids)
        
        # Compute the loss
        loss = F.cross_entropy(logits, true_labels)
        
        # Compute the predicted labels by applying a threshold
        probabilities = F.softmax(logits, dim=-1) 

        predicted_labels_idx = torch.argmax(probabilities, dim=-1)
        
        batch_size = predicted_labels_idx.size(0)
        num_classes = probabilities.size(1)  # Number of classes
        one_hot_predictions = torch.zeros(batch_size, num_classes, device=logits.device)
        one_hot_predictions.scatter_(1, predicted_labels_idx.unsqueeze(1), 1)
        
        return loss, true_labels, one_hot_predictions

    def training_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("train_loss", loss)
        self.train_acc(predicted_labels, true_labels)
        self.log("train_acc", self.train_acc, prog_bar=True, on_epoch=True, on_step=False)
        # Class-wise accuracy logging
        class_accuracies = self._calculate_class_accuracy(predicted_labels, true_labels)

        precision_values = self.train_precision(predicted_labels, true_labels)
        for i, precision in enumerate(precision_values):
            self.log(f"train_precision_class_{i}", precision, prog_bar=True)

        # Recall Logging
        recall_values = self.train_recall(predicted_labels, true_labels)
        for i, recall in enumerate(recall_values):
            self.log(f"train_recall_class_{i}", recall, prog_bar=True)

        # F1 Logging
        f1_values = self.train_f1(predicted_labels, true_labels)
        for i, f1 in enumerate(f1_values):
            self.log(f"train_f1_class_{i}", f1, prog_bar=True)

        # for class_idx, accuracy in class_accuracies.items():
        #     self.log(f"class_{class_idx}_train_acc", accuracy, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)

        self.log("val_loss", loss, prog_bar=True)
        self.val_acc(predicted_labels, true_labels)
        self.log("val_acc", self.val_acc, prog_bar=True)
        # Class-wise accuracy logging
        class_accuracies = self._calculate_class_accuracy(predicted_labels, true_labels)

        # for class_idx, accuracy in class_accuracies.items():
        #     self.log(f"class_{class_idx}_val_acc", accuracy, prog_bar=True)

        # Precision
        precision_values = self.val_precision(predicted_labels, true_labels)
        for i, precision in enumerate(precision_values):
            self.log(f"val_precision_class_{i}", precision, prog_bar=True)

        # Recall Logging
        recall_values = self.val_recall(predicted_labels, true_labels)
        for i, recall in enumerate(recall_values):
            self.log(f"val_recall_class_{i}", recall, prog_bar=True)

        # F1 Logging
        f1_values = self.val_f1(predicted_labels, true_labels)
        for i, f1 in enumerate(f1_values):
            self.log(f"val_f1_class_{i}", f1, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, true_labels, predicted_labels = self._shared_step(batch)
        self.test_acc(predicted_labels, true_labels)
        self.log("test_acc", self.test_acc)

        class_accuracies = self._calculate_class_accuracy(predicted_labels, true_labels)
        
        # for class_idx, accuracy in class_accuracies.items():
        #     self.log(f"class_{class_idx}_test_acc", accuracy)
        # Precision
        precision_values = self.test_precision(predicted_labels, true_labels)
        for i, precision in enumerate(precision_values):
            self.log(f"test_precision_class_{i}", precision)

        # Recall Logging
        recall_values = self.test_recall(predicted_labels, true_labels)
        for i, recall in enumerate(recall_values):
            self.log(f"test_recall_class_{i}", recall)

        # F1 Logging
        f1_values = self.test_f1(predicted_labels, true_labels)
        for i, f1 in enumerate(f1_values):
            self.log(f"test_f1_class_{i}", f1)


    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
    
    def _calculate_class_accuracy(self, predicted_labels, true_labels):
        class_accuracies = {}

        # Compute correct predictions per class
        correct_per_class = (predicted_labels * true_labels).sum(dim=0)  # Element-wise AND followed by sum across batch
        total_per_class = true_labels.sum(dim=0)  # Total true instances per class

        # Calculate accuracy for each class, avoiding division by zero
        for i in range(self.num_classes):
            correct = correct_per_class[i].item()
            total = total_per_class[i].item()

            if total > 0:
                class_accuracies[i] = correct / total
            else:
                class_accuracies[i] = 0.0 

        return class_accuracies
    


In [11]:
pl.seed_everything(123)
vocab_size = 30522  #For BERT tokenizer
embed_dim = 768     #embedding dimension

Seed set to 123


In [8]:
pytorch_model = PyTorchCNN(vocab_size=vocab_size, embed_dim=embed_dim,num_classes=7)  
lightning_model = LightningModel(model=pytorch_model, learning_rate=0.01,num_classes=7 )

# Setup PyTorch Lightning trainer
trainer = pl.Trainer(
    max_epochs=200,
    accelerator="gpu",  # Change to "gpu" if you want to use a GPU
    devices="auto",
    callbacks=[early_stopping],
    logger=CSVLogger(save_dir="logs/", name="my-model"),
    deterministic=True
)

# Train the model
trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name            | Type                | Params | Mode 
-----------------------------------------------------------------
0  | model           | PyTorchCNN          | 24.6 M | train
1  | train_acc       | MultilabelAccuracy  | 0      | train
2  | val_acc         | MultilabelAccuracy  | 0      | train
3  | test_acc        | MultilabelAccuracy  | 0      | train
4  | train_precision | MultilabelPrecision | 0      | train
5  | val_precision   | MultilabelPrecision | 0     

Sanity Checking:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 191/191 [02:08<00:00,  1.49it/s, v_num=75, train_precision_class_0=1.000, train_precision_class_1=0.000, train_precision_class_2=0.000, train_precision_class_3=0.250, train_precision_class_4=0.000, train_precision_class_5=0.000, train_precision_class_6=0.000, train_recall_class_0=0.625, train_recall_class_1=0.000, train_recall_class_2=0.000, train_recall_class_3=0.333, train_recall_class_4=0.000, train_recall_class_5=0.000, train_recall_class_6=0.000, train_f1_class_0=0.769, train_f1_class_1=0.000, train_f1_class_2=0.000, train_f1_class_3=0.286, train_f1_class_4=0.000, train_f1_class_5=0.000, train_f1_class_6=0.000, val_loss=1.190, val_acc=0.871, val_precision_class_0=0.643, val_precision_class_1=0.042, val_precision_class_2=0.084, val_precision_class_3=0.285, val_precision_class_4=0.418, val_precision_class_5=0.000, val_precision_class_6=0.000, val_recall_class_0=0.940, val_recall_class_1=0.0315, val_recall_class_2=0.0375, val_recall_class_3=0.061, val_recall

Metric val_loss improved. New best score: 1.190


Epoch 1: 100%|██████████| 191/191 [02:11<00:00,  1.45it/s, v_num=75, train_precision_class_0=1.000, train_precision_class_1=0.000, train_precision_class_2=0.000, train_precision_class_3=0.000, train_precision_class_4=0.000, train_precision_class_5=0.000, train_precision_class_6=0.000, train_recall_class_0=0.667, train_recall_class_1=0.000, train_recall_class_2=0.000, train_recall_class_3=0.000, train_recall_class_4=0.000, train_recall_class_5=0.000, train_recall_class_6=0.000, train_f1_class_0=0.800, train_f1_class_1=0.000, train_f1_class_2=0.000, train_f1_class_3=0.000, train_f1_class_4=0.000, train_f1_class_5=0.000, train_f1_class_6=0.000, val_loss=0.925, val_acc=0.928, val_precision_class_0=0.741, val_precision_class_1=0.000, val_precision_class_2=0.021, val_precision_class_3=0.920, val_precision_class_4=0.634, val_precision_class_5=0.000, val_precision_class_6=0.000, val_recall_class_0=0.990, val_recall_class_1=0.000, val_recall_class_2=0.00525, val_recall_class_3=0.722, val_recall

Metric val_loss improved by 0.266 >= min_delta = 0.0. New best score: 0.925


Epoch 2: 100%|██████████| 191/191 [02:12<00:00,  1.45it/s, v_num=75, train_precision_class_0=0.833, train_precision_class_1=0.000, train_precision_class_2=0.000, train_precision_class_3=1.000, train_precision_class_4=0.333, train_precision_class_5=0.000, train_precision_class_6=0.000, train_recall_class_0=1.000, train_recall_class_1=0.000, train_recall_class_2=0.000, train_recall_class_3=1.000, train_recall_class_4=0.500, train_recall_class_5=0.000, train_recall_class_6=0.000, train_f1_class_0=0.909, train_f1_class_1=0.000, train_f1_class_2=0.000, train_f1_class_3=1.000, train_f1_class_4=0.400, train_f1_class_5=0.000, train_f1_class_6=0.000, val_loss=0.628, val_acc=0.947, val_precision_class_0=0.863, val_precision_class_1=0.336, val_precision_class_2=0.404, val_precision_class_3=0.881, val_precision_class_4=0.754, val_precision_class_5=0.000, val_precision_class_6=0.000, val_recall_class_0=0.953, val_recall_class_1=0.288, val_recall_class_2=0.325, val_recall_class_3=0.871, val_recall_c

Metric val_loss improved by 0.297 >= min_delta = 0.0. New best score: 0.628


Epoch 4: 100%|██████████| 191/191 [02:11<00:00,  1.45it/s, v_num=75, train_precision_class_0=0.833, train_precision_class_1=0.000, train_precision_class_2=1.000, train_precision_class_3=1.000, train_precision_class_4=1.000, train_precision_class_5=0.000, train_precision_class_6=0.000, train_recall_class_0=1.000, train_recall_class_1=0.000, train_recall_class_2=0.500, train_recall_class_3=1.000, train_recall_class_4=1.000, train_recall_class_5=0.000, train_recall_class_6=0.000, train_f1_class_0=0.909, train_f1_class_1=0.000, train_f1_class_2=0.667, train_f1_class_3=1.000, train_f1_class_4=1.000, train_f1_class_5=0.000, train_f1_class_6=0.000, val_loss=0.608, val_acc=0.951, val_precision_class_0=0.895, val_precision_class_1=0.625, val_precision_class_2=0.417, val_precision_class_3=0.930, val_precision_class_4=0.788, val_precision_class_5=0.000, val_precision_class_6=0.182, val_recall_class_0=0.927, val_recall_class_1=0.674, val_recall_class_2=0.343, val_recall_class_3=0.855, val_recall_c

Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 0.608


Epoch 14: 100%|██████████| 191/191 [02:11<00:00,  1.45it/s, v_num=75, train_precision_class_0=1.000, train_precision_class_1=0.000, train_precision_class_2=1.000, train_precision_class_3=1.000, train_precision_class_4=1.000, train_precision_class_5=0.000, train_precision_class_6=0.000, train_recall_class_0=1.000, train_recall_class_1=0.000, train_recall_class_2=1.000, train_recall_class_3=1.000, train_recall_class_4=1.000, train_recall_class_5=0.000, train_recall_class_6=0.000, train_f1_class_0=1.000, train_f1_class_1=0.000, train_f1_class_2=1.000, train_f1_class_3=1.000, train_f1_class_4=1.000, train_f1_class_5=0.000, train_f1_class_6=0.000, val_loss=0.703, val_acc=0.948, val_precision_class_0=0.931, val_precision_class_1=0.643, val_precision_class_2=0.389, val_precision_class_3=0.919, val_precision_class_4=0.863, val_precision_class_5=0.202, val_precision_class_6=0.268, val_recall_class_0=0.835, val_recall_class_1=0.639, val_recall_class_2=0.389, val_recall_class_3=0.881, val_recall_

Monitored metric val_loss did not improve in the last 10 records. Best score: 0.608. Signaling Trainer to stop.


Epoch 14: 100%|██████████| 191/191 [02:11<00:00,  1.45it/s, v_num=75, train_precision_class_0=1.000, train_precision_class_1=0.000, train_precision_class_2=1.000, train_precision_class_3=1.000, train_precision_class_4=1.000, train_precision_class_5=0.000, train_precision_class_6=0.000, train_recall_class_0=1.000, train_recall_class_1=0.000, train_recall_class_2=1.000, train_recall_class_3=1.000, train_recall_class_4=1.000, train_recall_class_5=0.000, train_recall_class_6=0.000, train_f1_class_0=1.000, train_f1_class_1=0.000, train_f1_class_2=1.000, train_f1_class_3=1.000, train_f1_class_4=1.000, train_f1_class_5=0.000, train_f1_class_6=0.000, val_loss=0.703, val_acc=0.948, val_precision_class_0=0.931, val_precision_class_1=0.643, val_precision_class_2=0.389, val_precision_class_3=0.919, val_precision_class_4=0.863, val_precision_class_5=0.202, val_precision_class_6=0.268, val_recall_class_0=0.835, val_recall_class_1=0.639, val_recall_class_2=0.389, val_recall_class_3=0.881, val_recall_

In [9]:
trainer.test(model=lightning_model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 54/54 [00:02<00:00, 21.09it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.9122152924537659
     test_f1_class_0        0.6910209059715271
     test_f1_class_1        0.49542027711868286
     test_f1_class_2        0.09800131618976593
     test_f1_class_3        0.6750164031982422
     test_f1_class_4        0.5982286930084229
     test_f1_class_5        0.10512077808380127
     test_f1_class_6        0.1029297485947609
 test_precision_class_0     0.8559073805809021
 test_precision_class_1     0.4991304278373718
 test_precision_class_2     0.10679098218679428
 test_precision_class_3     0.6961787939071655
 test_precision_class_4     0.6044206619262695
 test_precision_class_5     0

[{'test_acc': 0.9122152924537659,
  'test_precision_class_0': 0.8559073805809021,
  'test_precision_class_1': 0.4991304278373718,
  'test_precision_class_2': 0.10679098218679428,
  'test_precision_class_3': 0.6961787939071655,
  'test_precision_class_4': 0.6044206619262695,
  'test_precision_class_5': 0.11130435019731522,
  'test_precision_class_6': 0.09328364580869675,
  'test_recall_class_0': 0.6058998107910156,
  'test_recall_class_1': 0.49294689297676086,
  'test_recall_class_2': 0.10322607308626175,
  'test_recall_class_3': 0.7043520212173462,
  'test_recall_class_4': 0.6157073974609375,
  'test_recall_class_5': 0.102028988301754,
  'test_recall_class_6': 0.15026088058948517,
  'test_f1_class_0': 0.6910209059715271,
  'test_f1_class_1': 0.49542027711868286,
  'test_f1_class_2': 0.09800131618976593,
  'test_f1_class_3': 0.6750164031982422,
  'test_f1_class_4': 0.5982286930084229,
  'test_f1_class_5': 0.10512077808380127,
  'test_f1_class_6': 0.1029297485947609}]

# Python Comments

## 1) Modify Dataset Class for Python

In [13]:
class PythonCommentDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.comments = dataframe['combo'].tolist()
        self.labels = dataframe['labels'].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.comments)
    
    def __getitem__(self, idx):
        # Tokenize the text
        text = self.comments[idx]
        tokens = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # Process labels
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        
        # Reshape input for CNN
        input_ids = tokens['input_ids'].squeeze(0)
        
        # Reshape embeddings to match CNN input format [batch_size, channels, sequence_length, embedding_dim]
        cnn_input = input_ids.unsqueeze(0)
        
        return {
            'input_ids': cnn_input,
            'labels': label
        }


## 2) Load Python Data

In [14]:
# re-set tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
max_len = 512 

# Prepare Dataset
python_train_dataset = PythonCommentDataset(python_train_data, tokenizer, max_len)
python_val_dataset = PythonCommentDataset(python_val_data, tokenizer, max_len)
python_test_dataset = PythonCommentDataset(python_test, tokenizer, max_len)

# Dataloaders
python_train_loader = DataLoader(python_train_dataset, batch_size=32, shuffle=True)
python_val_loader = DataLoader(python_val_dataset, batch_size=32, shuffle=False)
python_test_loader = DataLoader(python_test_dataset, batch_size=32)


## 3) Update Models for Python

In [15]:
# Initialize model for Python comments
pytorch_python_model = PyTorchCNN(vocab_size=vocab_size, embed_dim=embed_dim, num_classes=5)
lightning_python_model = LightningModel(model=pytorch_python_model, learning_rate=0.01, num_classes=5)


## 4) Train model for Python comments

In [16]:
# Setup PyTorch Lightning trainer for CPU
trainer = pl.Trainer(
    max_epochs=200,
    accelerator="gpu",  # Use CPU instead of GPU
    devices="auto",
    callbacks=[early_stopping],  
    logger=CSVLogger(save_dir="logs/", name="my-model"),
    deterministic=True
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [17]:

# Train the model on Python dataset
trainer.fit(lightning_python_model, train_dataloaders=python_train_loader, val_dataloaders=python_val_loader)

# Test the model on Python test dataset
trainer.test(model=lightning_python_model, dataloaders=python_test_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name            | Type                | Params | Mode 
-----------------------------------------------------------------
0  | model           | PyTorchCNN          | 24.6 M | train
1  | train_acc       | MultilabelAccuracy  | 0      | train
2  | val_acc         | MultilabelAccuracy  | 0      | train
3  | test_acc        | MultilabelAccuracy  | 0      | train
4  | train_precision | MultilabelPrecision | 0      | train
5  | val_precision   | MultilabelPrecision | 0      | train
6  | test_precision  | MultilabelPrecision | 0      | train
7  | train_recall    | MultilabelRecall    | 0 

Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\loops\fit_loop.py:298: The number of training batches (48) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 48/48 [00:31<00:00,  1.51it/s, v_num=77, train_precision_class_0=0.000, train_precision_class_1=0.000, train_precision_class_2=0.000, train_precision_class_3=0.000, train_precision_class_4=0.000, train_recall_class_0=0.000, train_recall_class_1=0.000, train_recall_class_2=0.000, train_recall_class_3=0.000, train_recall_class_4=0.000, train_f1_class_0=0.000, train_f1_class_1=0.000, train_f1_class_2=0.000, train_f1_class_3=0.000, train_f1_class_4=0.000, val_loss=1.630, val_acc=0.736, val_precision_class_0=0.530, val_precision_class_1=0.344, val_precision_class_2=0.0849, val_precision_class_3=0.000, val_precision_class_4=0.271, val_recall_class_0=0.459, val_recall_class_1=0.546, val_recall_class_2=0.0283, val_recall_class_3=0.000, val_recall_class_4=0.379, val_f1_class_0=0.486, val_f1_class_1=0.412, val_f1_class_2=0.0424, val_f1_class_3=0.000, val_f1_class_4=0.289, train_acc=0.710]

Metric val_loss improved. New best score: 1.628


Epoch 2: 100%|██████████| 48/48 [00:33<00:00,  1.41it/s, v_num=77, train_precision_class_0=1.000, train_precision_class_1=0.500, train_precision_class_2=0.000, train_precision_class_3=0.000, train_precision_class_4=0.000, train_recall_class_0=1.000, train_recall_class_1=1.000, train_recall_class_2=0.000, train_recall_class_3=0.000, train_recall_class_4=0.000, train_f1_class_0=1.000, train_f1_class_1=0.667, train_f1_class_2=0.000, train_f1_class_3=0.000, train_f1_class_4=0.000, val_loss=1.430, val_acc=0.785, val_precision_class_0=0.666, val_precision_class_1=0.571, val_precision_class_2=0.000, val_precision_class_3=0.153, val_precision_class_4=0.354, val_recall_class_0=0.708, val_recall_class_1=0.641, val_recall_class_2=0.000, val_recall_class_3=0.124, val_recall_class_4=0.313, val_f1_class_0=0.678, val_f1_class_1=0.593, val_f1_class_2=0.000, val_f1_class_3=0.125, val_f1_class_4=0.308, train_acc=0.801]    

Metric val_loss improved by 0.202 >= min_delta = 0.0. New best score: 1.427


Epoch 12: 100%|██████████| 48/48 [00:33<00:00,  1.42it/s, v_num=77, train_precision_class_0=0.000, train_precision_class_1=1.000, train_precision_class_2=0.000, train_precision_class_3=0.000, train_precision_class_4=0.000, train_recall_class_0=0.000, train_recall_class_1=0.333, train_recall_class_2=0.000, train_recall_class_3=0.000, train_recall_class_4=0.000, train_f1_class_0=0.000, train_f1_class_1=0.500, train_f1_class_2=0.000, train_f1_class_3=0.000, train_f1_class_4=0.000, val_loss=1.670, val_acc=0.806, val_precision_class_0=0.793, val_precision_class_1=0.698, val_precision_class_2=0.198, val_precision_class_3=0.340, val_precision_class_4=0.407, val_recall_class_0=0.646, val_recall_class_1=0.551, val_recall_class_2=0.171, val_recall_class_3=0.324, val_recall_class_4=0.481, val_f1_class_0=0.703, val_f1_class_1=0.602, val_f1_class_2=0.181, val_f1_class_3=0.311, val_f1_class_4=0.433, train_acc=0.894] 

Monitored metric val_loss did not improve in the last 10 records. Best score: 1.427. Signaling Trainer to stop.


Epoch 12: 100%|██████████| 48/48 [00:34<00:00,  1.39it/s, v_num=77, train_precision_class_0=0.000, train_precision_class_1=1.000, train_precision_class_2=0.000, train_precision_class_3=0.000, train_precision_class_4=0.000, train_recall_class_0=0.000, train_recall_class_1=0.333, train_recall_class_2=0.000, train_recall_class_3=0.000, train_recall_class_4=0.000, train_f1_class_0=0.000, train_f1_class_1=0.500, train_f1_class_2=0.000, train_f1_class_3=0.000, train_f1_class_4=0.000, val_loss=1.670, val_acc=0.806, val_precision_class_0=0.793, val_precision_class_1=0.698, val_precision_class_2=0.198, val_precision_class_3=0.340, val_precision_class_4=0.407, val_recall_class_0=0.646, val_recall_class_1=0.551, val_recall_class_2=0.171, val_recall_class_3=0.324, val_recall_class_4=0.481, val_f1_class_0=0.703, val_f1_class_1=0.602, val_f1_class_2=0.181, val_f1_class_3=0.311, val_f1_class_4=0.433, train_acc=0.894]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.



Testing DataLoader 0: 100%|██████████| 13/13 [00:00<00:00, 24.25it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.7674877047538757
     test_f1_class_0        0.3780786395072937
     test_f1_class_1        0.4347551167011261
     test_f1_class_2        0.15749779343605042
     test_f1_class_3        0.2125399112701416
     test_f1_class_4        0.3562993109226227
 test_precision_class_0     0.5348981022834778
 test_precision_class_1     0.6206597685813904
 test_precision_class_2     0.2048635482788086
 test_precision_class_3     0.1990695297718048
 test_precision_class_4     0.3929506242275238
   test_recall_class_0      0.3130086064338684
   test_recall_class_1      0.38730278611183167
   test_recall_class_2      0.

[{'test_acc': 0.7674877047538757,
  'test_precision_class_0': 0.5348981022834778,
  'test_precision_class_1': 0.6206597685813904,
  'test_precision_class_2': 0.2048635482788086,
  'test_precision_class_3': 0.1990695297718048,
  'test_precision_class_4': 0.3929506242275238,
  'test_recall_class_0': 0.3130086064338684,
  'test_recall_class_1': 0.38730278611183167,
  'test_recall_class_2': 0.1707717627286911,
  'test_recall_class_3': 0.29173511266708374,
  'test_recall_class_4': 0.36346864700317383,
  'test_f1_class_0': 0.3780786395072937,
  'test_f1_class_1': 0.4347551167011261,
  'test_f1_class_2': 0.15749779343605042,
  'test_f1_class_3': 0.2125399112701416,
  'test_f1_class_4': 0.3562993109226227}]

# Pharo Comments

# 1) Setup Dataset for Pharo 

The Pharo dataset follows the java dataset very closely as it contains the same number of classes

In [13]:
class PharoCommentDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len):
        self.comments = dataframe['combo'].tolist()
        self.labels = dataframe['labels'].tolist()
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.comments)
    
    def __getitem__(self, idx):
        # Tokenize the text
        text = self.comments[idx]
        tokens = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_len,
            return_tensors='pt'
        )

        # Process labels
        label = torch.tensor(self.labels[idx], dtype=torch.float)
        
        # Reshape input for CNN
        input_ids = tokens['input_ids'].squeeze(0)
        
        # Reshape embeddings to match CNN input format [batch_size, channels, sequence_length, embedding_dim]
        cnn_input = input_ids.unsqueeze(0)
        
        return {
            'input_ids': cnn_input,
            'labels': label
        }

The Pharo dataset follows the java dataset very closely as it contains the same number of classes

# 2) Reset model 

In [14]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
max_len = 512 

# Prepare Dataset for the Pharo comments
pharo_train_dataset = PharoCommentDataset(pharo_train_data, tokenizer, max_len)
pharo_train_loader = DataLoader(pharo_train_dataset, batch_size=32, shuffle=True)

pharo_val_dataset = PharoCommentDataset(pharo_val_data, tokenizer, max_len)
pharo_val_loader = DataLoader(pharo_val_dataset, batch_size=32, shuffle=False)

pharo_test_dataset = PharoCommentDataset(pharo_test, tokenizer, max_len)
pharo_test_loader = DataLoader(pharo_test_dataset, batch_size=32)

# 3) Implement models

In [15]:
# Initialize model for Pharo comments (7 classes)
pytorch_pharo_model = PyTorchCNN(vocab_size=vocab_size, embed_dim=embed_dim, num_classes=7)
lightning_pharo_model = LightningModel(model=pytorch_pharo_model, learning_rate=0.01, num_classes=7)

# 4) Set for Pharo comments

In [16]:
# Setup PyTorch Lightning trainer for GPU
trainer = pl.Trainer(
    max_epochs=200,
    accelerator="gpu",
    devices="auto",
    callbacks=[early_stopping],   
    logger=CSVLogger(save_dir="logs/", name="my-model"),
    deterministic=True
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


# 5) Train and test the model

In [17]:
# Train the model using the pharo dataset
trainer.fit(lightning_pharo_model, train_dataloaders=pharo_train_loader, val_dataloaders=pharo_val_loader)

# Test the model on Pharo test dataset
trainer.test(model=lightning_pharo_model, dataloaders=pharo_test_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 3060') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name            | Type                | Params | Mode 
-----------------------------------------------------------------
0  | model           | PyTorchCNN          | 24.6 M | train
1  | train_acc       | MultilabelAccuracy  | 0      | train
2  | val_acc         | MultilabelAccuracy  | 0      | train
3  | test_acc        | MultilabelAccuracy  | 0      | train
4  | train_precision | MultilabelPrecision | 0      | train
5  | val_precision   | MultilabelPrecision | 0      | train
6  | test_precision  | MultilabelPrecision | 0      | train
7  | train_recall    | MultilabelRecall    | 0 

Sanity Checking:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


                                                                           

c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\loops\fit_loop.py:298: The number of training batches (33) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 33/33 [00:24<00:00,  1.35it/s, v_num=78, train_precision_class_0=0.000, train_precision_class_1=0.500, train_precision_class_2=0.600, train_precision_class_3=0.000, train_precision_class_4=0.000, train_precision_class_5=0.000, train_precision_class_6=0.000, train_recall_class_0=0.000, train_recall_class_1=0.375, train_recall_class_2=0.750, train_recall_class_3=0.000, train_recall_class_4=0.000, train_recall_class_5=0.000, train_recall_class_6=0.000, train_f1_class_0=0.000, train_f1_class_1=0.429, train_f1_class_2=0.667, train_f1_class_3=0.000, train_f1_class_4=0.000, train_f1_class_5=0.000, train_f1_class_6=0.000, val_loss=1.820, val_acc=0.807, val_precision_class_0=0.238, val_precision_class_1=0.571, val_precision_class_2=0.326, val_precision_class_3=0.000, val_precision_class_4=0.000, val_precision_class_5=0.444, val_precision_class_6=0.000, val_recall_class_0=0.239, val_recall_class_1=0.474, val_recall_class_2=0.470, val_recall_class_3=0.000, val_recall_cla

Metric val_loss improved. New best score: 1.825


Epoch 1: 100%|██████████| 33/33 [00:24<00:00,  1.32it/s, v_num=78, train_precision_class_0=0.500, train_precision_class_1=0.571, train_precision_class_2=0.500, train_precision_class_3=0.000, train_precision_class_4=0.000, train_precision_class_5=0.000, train_precision_class_6=0.000, train_recall_class_0=0.200, train_recall_class_1=1.000, train_recall_class_2=1.000, train_recall_class_3=0.000, train_recall_class_4=0.000, train_recall_class_5=0.000, train_recall_class_6=0.000, train_f1_class_0=0.286, train_f1_class_1=0.727, train_f1_class_2=0.667, train_f1_class_3=0.000, train_f1_class_4=0.000, train_f1_class_5=0.000, train_f1_class_6=0.000, val_loss=1.510, val_acc=0.855, val_precision_class_0=0.482, val_precision_class_1=0.563, val_precision_class_2=0.596, val_precision_class_3=0.000, val_precision_class_4=0.123, val_precision_class_5=0.246, val_precision_class_6=0.000, val_recall_class_0=0.326, val_recall_class_1=0.930, val_recall_class_2=0.453, val_recall_class_3=0.000, val_recall_cla

Metric val_loss improved by 0.317 >= min_delta = 0.0. New best score: 1.507


Epoch 3: 100%|██████████| 33/33 [00:20<00:00,  1.58it/s, v_num=78, train_precision_class_0=0.333, train_precision_class_1=0.600, train_precision_class_2=0.500, train_precision_class_3=0.000, train_precision_class_4=0.000, train_precision_class_5=1.000, train_precision_class_6=0.000, train_recall_class_0=1.000, train_recall_class_1=1.000, train_recall_class_2=0.500, train_recall_class_3=0.000, train_recall_class_4=0.000, train_recall_class_5=0.500, train_recall_class_6=0.000, train_f1_class_0=0.500, train_f1_class_1=0.750, train_f1_class_2=0.500, train_f1_class_3=0.000, train_f1_class_4=0.000, train_f1_class_5=0.667, train_f1_class_6=0.000, val_loss=1.260, val_acc=0.894, val_precision_class_0=0.727, val_precision_class_1=0.827, val_precision_class_2=0.473, val_precision_class_3=0.000, val_precision_class_4=0.519, val_precision_class_5=0.729, val_precision_class_6=0.000, val_recall_class_0=0.445, val_recall_class_1=0.847, val_recall_class_2=0.620, val_recall_class_3=0.000, val_recall_cla

Metric val_loss improved by 0.245 >= min_delta = 0.0. New best score: 1.262


Epoch 13: 100%|██████████| 33/33 [00:20<00:00,  1.57it/s, v_num=78, train_precision_class_0=0.000, train_precision_class_1=1.000, train_precision_class_2=1.000, train_precision_class_3=0.000, train_precision_class_4=1.000, train_precision_class_5=1.000, train_precision_class_6=0.000, train_recall_class_0=0.000, train_recall_class_1=1.000, train_recall_class_2=0.400, train_recall_class_3=0.000, train_recall_class_4=1.000, train_recall_class_5=1.000, train_recall_class_6=0.000, train_f1_class_0=0.000, train_f1_class_1=1.000, train_f1_class_2=0.571, train_f1_class_3=0.000, train_f1_class_4=1.000, train_f1_class_5=1.000, train_f1_class_6=0.000, val_loss=1.480, val_acc=0.896, val_precision_class_0=0.536, val_precision_class_1=0.838, val_precision_class_2=0.618, val_precision_class_3=0.0615, val_precision_class_4=0.656, val_precision_class_5=0.606, val_precision_class_6=0.103, val_recall_class_0=0.552, val_recall_class_1=0.869, val_recall_class_2=0.337, val_recall_class_3=0.0308, val_recall_

Monitored metric val_loss did not improve in the last 10 records. Best score: 1.262. Signaling Trainer to stop.


Epoch 13: 100%|██████████| 33/33 [00:22<00:00,  1.49it/s, v_num=78, train_precision_class_0=0.000, train_precision_class_1=1.000, train_precision_class_2=1.000, train_precision_class_3=0.000, train_precision_class_4=1.000, train_precision_class_5=1.000, train_precision_class_6=0.000, train_recall_class_0=0.000, train_recall_class_1=1.000, train_recall_class_2=0.400, train_recall_class_3=0.000, train_recall_class_4=1.000, train_recall_class_5=1.000, train_recall_class_6=0.000, train_f1_class_0=0.000, train_f1_class_1=1.000, train_f1_class_2=0.571, train_f1_class_3=0.000, train_f1_class_4=1.000, train_f1_class_5=1.000, train_f1_class_6=0.000, val_loss=1.480, val_acc=0.896, val_precision_class_0=0.536, val_precision_class_1=0.838, val_precision_class_2=0.618, val_precision_class_3=0.0615, val_precision_class_4=0.656, val_precision_class_5=0.606, val_precision_class_6=0.103, val_recall_class_0=0.552, val_recall_class_1=0.869, val_recall_class_2=0.337, val_recall_class_3=0.0308, val_recall_

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\kingk\anaconda3\envs\pytorchenv\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.



Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 27.77it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.8882847428321838
     test_f1_class_0        0.3781798779964447
     test_f1_class_1        0.8131514191627502
     test_f1_class_2        0.38167139887809753
     test_f1_class_3        0.07381776720285416
     test_f1_class_4        0.5096940994262695
     test_f1_class_5         0.232232928276062
     test_f1_class_6        0.21618059277534485
 test_precision_class_0     0.4463667869567871
 test_precision_class_1     0.8587366342544556
 test_precision_class_2     0.5886966586112976
 test_precision_class_3     0.11072664707899094
 test_precision_class_4     0.5314878821372986
 test_precision_class_5     

[{'test_acc': 0.8882847428321838,
  'test_precision_class_0': 0.4463667869567871,
  'test_precision_class_1': 0.8587366342544556,
  'test_precision_class_2': 0.5886966586112976,
  'test_precision_class_3': 0.11072664707899094,
  'test_precision_class_4': 0.5314878821372986,
  'test_precision_class_5': 0.25221067667007446,
  'test_precision_class_6': 0.1845444142818451,
  'test_recall_class_0': 0.384212851524353,
  'test_recall_class_1': 0.8153007626533508,
  'test_recall_class_2': 0.3250617980957031,
  'test_recall_class_3': 0.05536332353949547,
  'test_recall_class_4': 0.5093425512313843,
  'test_recall_class_5': 0.25836217403411865,
  'test_recall_class_6': 0.3321799337863922,
  'test_f1_class_0': 0.3781798779964447,
  'test_f1_class_1': 0.8131514191627502,
  'test_f1_class_2': 0.38167139887809753,
  'test_f1_class_3': 0.07381776720285416,
  'test_f1_class_4': 0.5096940994262695,
  'test_f1_class_5': 0.232232928276062,
  'test_f1_class_6': 0.21618059277534485}]