In [18]:
import { mnist } from './src/nn/datasets.ts'
import { document } from 'jsr:@ry/jupyter-helper'
import { barY, cell, plot } from 'npm:@observablehq/plot'
import { MNIST } from './beautiful_mnist.ts'
import { Tensor } from './src/tensor.ts'

// Load the MNIST training split and pick one random example to inspect.
const [x_train, y_train] = await mnist()

// `sample` is a 1-element index tensor whose upper bound is the number of training
// images (x_train.shape[0]); the `undefined` argument presumably leaves the lower
// bound at randint's default — TODO confirm against Tensor.randint.
const sample = Tensor.randint([1], undefined, x_train.shape[0])

// Pull the label and the image pixels back to plain JS arrays for plotting.
const label = await y_train.get(sample).tolist<number[]>()
const img = await x_train
  .get(sample)
  .reshape([28, 28]) // 28×28 grid of pixel values
  .tolist<number[][]>()

// Render the sampled 28×28 image as a grid of cells: each pixel becomes one cell at
// (x = column, y = row) whose fill is the pixel value, shaded with the 'greys' scheme.
// The title shows the dataset label for this sample.
await Deno.jupyter.display(
  plot({
    document,
    title: `Label: ${label[0]}`,
    width: 500,
    height: 500,
    color: { scheme: 'greys' },
    marks: [
      cell(
        img.map((row, y) => row.map((value, x) => ({ x, y, value }))).flat(),
        { x: 'x', y: 'y', fill: 'value' },
      ),
    ],
  }),
)

// Build the model, load its trained weights, and run it on the same sampled image.
const model = new MNIST()
await model.load('./model.safetensors')

// Flatten the output to 10 numbers, one per digit class. These are the raw outputs of
// MNIST.call; no softmax is applied here (presumably logits — TODO confirm).
const res = await model
  .call(x_train.get(sample))
  .reshape([10])
  .tolist<number[]>()

// Pair each class index with the model's output for that class so it can be charted.
const pred = res.map((value, i) => ({ i, value }))

// Bar chart of per-class outputs; every bar equal to the maximum output is drawn red.
// Kept in its own block scope so the helper binding below does not leak into the
// notebook's shared top-level namespace.
{
  // Compute the maximum once instead of once per bar inside the fill callback
  // (the original rebuilt an array and rescanned it for every datum).
  // `res` holds exactly the same values as `pred[*].value`.
  const maxValue = Math.max(...res)

  await Deno.jupyter.display(
    plot({
      title: 'Predictions',
      height: 400,
      marks: [
        barY(pred, {
          x: 'i',
          y: 'value',
          fill: (d) => (d.value === maxValue ? 'red' : 'steelblue'),
        }),
      ],
      document,
    }),
  )
}