In [1]:
from torchmetrics import Metric
import torch
from typing import Optional, Union, Sequence, Tuple
import wandb
from torchmetrics.utilities.plot import (
    _AX_TYPE,
    _PLOT_OUT_TYPE,
    plot_single_or_multi_val,
)
import pandas as pd
import wandb

In [4]:
third.device

device(type='cuda', index=0)

In [2]:
class AccuracyPerClass(Metric):
    def __init__(self, targetToLabelMapper: dict, **kwargs):
        super().__init__(**kwargs)
        self.add_state("predicted", default=torch.tensor([]), dist_reduce_fx="cat")
        self.add_state("target", default=torch.tensor([]), dist_reduce_fx="cat")
        self.targetToLabelMapper = targetToLabelMapper

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        if preds.shape != target.shape:
            raise ValueError("preds and target must have the same shape")

        self.predicted = torch.cat([self.predicted, preds])
        self.target = torch.cat([self.target, target])

    def compute(self) -> torch.Tensor:
        target = self.target.cpu().numpy()
        predicted = self.predicted.cpu().numpy()
        output = []
        # output
        for t, p in zip(target, predicted):
            output.append([self.targetToLabelMapper[t], t == p])

        return output

    def log(self, title: str):
        data = self.compute()
        table = wandb.Table(data=data, columns=["label", "value"])
        bar_chart = wandb.plot.bar(table, "label", "value", title)
        wandb.log({"accuracy_per_class_id": bar_chart})

In [3]:
metric = AccuracyPerClass({0: "cat", 1: "dog"})

In [8]:
preds = torch.tensor([0, 1, 0, 1, 0, 1])
target = torch.tensor([0, 1, 1, 0, 1, 0])

metric.update(preds, target)
print(metric.compute())

[['cat', True], ['dog', True], ['dog', False], ['cat', False], ['cat', True], ['dog', True], ['dog', False], ['cat', False], ['dog', False], ['cat', False]]


In [11]:
wandb.login()

True

In [12]:
run = wandb.init(
    # Set the project where this run will be logged
    project="bird_clef_2024",
)

In [13]:
import random

table = wandb.Table(
    data=[
        ["car", random.random()],
        ["bus", random.random()],
        ["road", random.random()],
        ["person", random.random()],
    ],
    columns=["class", "acc"],
)
wandb.log({"bar-plot1": wandb.plot.bar(table, "class", "acc")})

In [9]:
metric.log("Accuracy per class")

In [14]:
wandb.finish()

VBox(children=(Label(value='0.004 MB of 0.016 MB uploaded\r'), FloatProgress(value=0.24157131493605824, max=1.…

In [24]:
from torchmetrics.classification import MulticlassF1Score

target = torch.tensor([2, 1, 0, 0])
preds = torch.tensor(
    [[0.16, 0.26, 0.58], [0.22, 0.61, 0.17], [0.71, 0.09, 0.20], [0.05, 0.82, 0.13]]
)
metric = MulticlassF1Score(num_classes=3, average=None)
metric.update(preds, target)
metric.compute()

# mcf1s = MulticlassF1Score(num_classes=3, multidim_average='samplewise', average=None)
# mcf1s(preds, target)

tensor([0.6667, 0.6667, 1.0000])

In [27]:
from torchmetrics import Precision


preds = torch.tensor([2, 0, 2, 1])
target = torch.tensor([1, 1, 2, 0])
precision = Precision(task="multiclass", average=None, num_classes=3)
precision(preds, target)

tensor(0.2500)