In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from utils import get_transforms, get_model, get_last_layer
from cnn import run_model, apply_gradcam

# Global Matplotlib style for the figures that follow (until changed later).
plt.style.use(style="ggplot")

In [None]:
# Model name: passed to get_model/get_last_layer and used to build the
# weights path and the metric CSV file names below.
MODEL: str = "cvgg13"

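### Визуализация работы модели (Grad-CAM)

Верхний ряд — входные изображения с меткой из имени файла, нижний — карты Grad-CAM с предсказанием модели.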

In [None]:
images_dir = Path("data/gradcam_images/ref")
transform = get_transforms()
device = "cuda" if torch.cuda.is_available() else "cpu"

model = get_model(MODEL).to(device)
# map_location lets weights that were saved on a GPU load on a CPU-only machine.
model.load_state_dict(torch.load(f"data/weights/{MODEL}.pt", map_location=device))
# Evaluation mode for layers such as Dropout/BatchNorm (if the model has any);
# autograd stays enabled, so apply_gradcam can still backpropagate.
model.eval()
layer = get_last_layer(model, MODEL)

# inputs:  list of (transformed image, label taken from the file name)
# outputs: list of (apply_gradcam result, first element of run_model's result)
inputs = []
outputs = []

# Path.iterdir() yields entries in arbitrary order; sort so that the order of
# images (and therefore which ones land in each figure below) is reproducible.
for path in sorted(images_dir.iterdir()):
    if not path.is_file():
        continue  # skip sub-directories; Image.open would fail on them
    image = transform(Image.open(path))
    label = path.stem
    inputs.append((image, label))

    gradcam_image = apply_gradcam(model, device, image, layer)
    prediction = run_model(model, device, image)[0]
    outputs.append((gradcam_image, prediction))

In [None]:
def visualize(inputs, outputs, figsize=None):
    """Plot input images above their Grad-CAM outputs, one column per sample.

    Parameters
    ----------
    inputs : sequence of (image, label)
        Shown in the top row; each panel is titled with ``label``.
    outputs : sequence of (gradcam_image, prediction)
        Shown in the bottom row; each panel is titled with ``prediction``.
        Must have the same length as ``inputs``.
    figsize : tuple of float, optional
        Forwarded to ``plt.subplots``; ``None`` keeps the Matplotlib default.

    Raises
    ------
    ValueError
        If ``inputs`` and ``outputs`` have different lengths.

    Nothing is drawn when ``inputs`` is empty.
    """
    if len(inputs) != len(outputs):
        raise ValueError(
            f"inputs and outputs differ in length: {len(inputs)} != {len(outputs)}"
        )
    n = len(inputs)
    if n == 0:
        return  # plt.subplots(2, 0) would raise

    # squeeze=False keeps `axes` 2-D even for a single column, so
    # axes[row, col] indexing works for every n >= 1.
    fig, axes = plt.subplots(2, n, figsize=figsize, squeeze=False)

    for ax in axes.flat:
        ax.axis("off")

    for col, ((image, label), (gradcam_image, prediction)) in enumerate(
        zip(inputs, outputs)
    ):
        axes[0, col].imshow(image)
        axes[0, col].set_title(label)
        axes[1, col].imshow(gradcam_image)
        axes[1, col].set_title(prediction)

    # Row captions placed to the left of the first column.
    for row, caption in enumerate(("Вход", "Выход")):
        axes[row, 0].annotate(
            caption,
            xy=(-0.3, 0.5),
            xycoords="axes fraction",
            rotation=90,
            va="center",
            fontsize=12,
        )

    fig.tight_layout()
    # Negative hspace pulls the two rows closer together than tight_layout does.
    fig.subplots_adjust(hspace=-0.5)
    plt.show()

In [None]:
# First five samples: inputs on top, their Grad-CAM outputs below.
first_row = slice(5)
visualize(inputs[first_row], outputs[first_row])

In [None]:
# The remaining samples (after the first five), at most IMAGES_PER_ROW per
# figure so long image lists stay readable. Nothing is drawn when there are
# no more than five images, instead of calling visualize with empty lists.
IMAGES_PER_ROW = 5
for start in range(IMAGES_PER_ROW, len(inputs), IMAGES_PER_ROW):
    stop = start + IMAGES_PER_ROW
    visualize(inputs[start:stop], outputs[start:stop])

### Метрики

F1-мера модели на обучающей и тестовой выборках в ходе обучения.

In [None]:
import polars as pl
import seaborn as sns

# Seaborn "whitegrid" axes style for the metric plots below. This rewrites
# Matplotlib rcParams, so it replaces part of the style set in the first cell.
sns.set_style(style="whitegrid")

In [None]:
# Per-split metric logs for MODEL (F1, per the file names); the plot below
# reads their "Step" and "Value" columns.
test, train = (
    pl.read_csv(f"data/csv/f1_test_{MODEL}.csv"),
    pl.read_csv(f"data/csv/f1_train_{MODEL}.csv"),
)

In [None]:
# Train/test F1 curves on a single explicit Axes.
fig, ax = plt.subplots()
sns.lineplot(test, x="Step", y="Value", label="test", ax=ax)
sns.lineplot(train, x="Step", y="Value", label="train", ax=ax)
# The values come from the f1_*.csv files, so label them as F1 rather than
# "Точность" (accuracy).
# NOTE(review): the x column is "Step"; confirm it counts epochs.
ax.set(title="F1-мера модели", xlabel="Эпоха", ylabel="F1-мера")
plt.show()