In [None]:
import os
import numpy as np
import cv2
import imageio.v3 as imageio
import matplotlib.pyplot as plt
import jax
from jax import numpy as jnp
from flax import nnx
import orbax.checkpoint as ocp
from net import RubbishClassifier

In [None]:
def letterbox(im, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True, stride=32):
    """Resize an image with its aspect ratio preserved, then pad it with a constant color.

    Args:
        im: Image array; only ``im.shape[:2]`` (height, width) is inspected here.
        new_shape: Target ``(height, width)``; a single int means a square target.
        color: Border color handed to ``cv2.copyMakeBorder``.
        auto: If true, the padding is reduced modulo ``stride`` ("minimum
            rectangle"), so the result may be smaller than ``new_shape``.
        scaleFill: Only consulted when ``auto`` is false. If true, the image is
            stretched to exactly ``new_shape`` and no padding is added.
        scaleup: If false, the image is never enlarged (scale is capped at 1.0).
        stride: Modulus used for the padding when ``auto`` is true.

    Returns:
        A tuple ``(image, ratio, (dw, dh))`` where ``ratio`` is the
        ``(width, height)`` scale factors and ``(dw, dh)`` is the padding per
        side (already halved, so it may be fractional).
    """
    src_hw = im.shape[:2]  # (height, width) of the incoming image
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)
    target_h, target_w = new_shape[0], new_shape[1]
    src_h, src_w = src_hw[0], src_hw[1]

    # Largest uniform scale at which the image still fits inside the target.
    scale = min(target_h / src_h, target_w / src_w)
    if not scaleup:
        # Shrink-only mode: never enlarge the input.
        scale = min(scale, 1.0)

    ratio = (scale, scale)  # (width, height) scale factors
    resized_wh = (int(round(src_w * scale)), int(round(src_h * scale)))
    pad_w = target_w - resized_wh[0]
    pad_h = target_h - resized_wh[1]

    if auto:
        # Minimum rectangle: keep only the padding left over modulo the stride.
        pad_w, pad_h = np.mod(pad_w, stride), np.mod(pad_h, stride)
    elif scaleFill:
        # Stretch to the exact target; aspect ratio is not preserved.
        pad_w = pad_h = 0.0
        resized_wh = (target_w, target_h)
        ratio = (target_w / src_w, target_h / src_h)

    # Safety net: a negative padding is replaced by the full remaining gap.
    if pad_w < 0:
        pad_w = target_w - resized_wh[0]
    if pad_h < 0:
        pad_h = target_h - resized_wh[1]

    # Distribute the padding over both sides.
    pad_w /= 2
    pad_h /= 2

    if (src_w, src_h) != resized_wh:
        im = cv2.resize(im, resized_wh, interpolation=cv2.INTER_LINEAR)

    # The -0.1 / +0.1 nudges send the extra pixel of an odd total padding to
    # the bottom / right side.
    top = int(round(pad_h - 0.1))
    bottom = int(round(pad_h + 0.1))
    left = int(round(pad_w - 0.1))
    right = int(round(pad_w + 0.1))
    im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)
    return im, ratio, (pad_w, pad_h)

In [None]:
# Number of output classes passed to the classifier; the restore target built
# below (and therefore the checkpoint it is matched against) depends on it.
NUM_CLASSES = 40

# Build the model under nnx.eval_shape so it can be split into a graph
# definition and an abstract state that serves as the restore target.
abstract_model = nnx.eval_shape(lambda: RubbishClassifier(num_classes=NUM_CLASSES, rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)

checkpoint_dir = os.path.join(os.getcwd(), "checkpoints/")
with ocp.CheckpointManager(
    checkpoint_dir,
    options=ocp.CheckpointManagerOptions(max_to_keep=1),
) as mngr:
    latest_step = mngr.latest_step()
    if latest_step is None:
        # latest_step() yields None when nothing has been saved; fail with a
        # clear message instead of an opaque error from restore(None, ...).
        raise FileNotFoundError(f"No checkpoint found in {checkpoint_dir!r}")
    state = mngr.restore(latest_step, args=ocp.args.StandardRestore(abstract_state))
    model = nnx.merge(graphdef, state)

# Switch to inference mode (deterministic dropout / running batch-norm stats).
# NOTE(review): this has no effect if RubbishClassifier contains no such layers.
model.eval()



In [None]:
# Sample image to classify (path is relative to the notebook's working directory).
image_path = "./data/train/10/img_4256.jpg"
img = imageio.imread(image_path)

# Preview the raw image before any preprocessing.
plt.imshow(img)
plt.show()

# Report the array shape of the loaded image.
print("Image shape:", img.shape)

In [None]:
# Letterbox to 224x224 (auto=False pads all the way to the requested shape).
# The result is stored under new names instead of overwriting `img`, so this
# cell can be re-run without feeding the already-batched array back into
# letterbox. Indexing [0] also avoids assigning to `_`, which IPython uses for
# the last cell output.
# NOTE(review): the pixels are passed to the model as loaded (no dtype cast or
# normalisation) — confirm this matches the training-time preprocessing.
img_letterboxed = letterbox(img, new_shape=(224, 224), auto=False)[0]
img_batch = np.expand_dims(img_letterboxed, axis=0)  # add batch dimension

logits = model(img_batch)
# Apply softmax to turn the logits into probabilities.
probabilities = jax.nn.softmax(logits, axis=-1)

# Most likely class and its probability for each item in the batch.
predicted_class = jnp.argmax(probabilities, axis=-1)
max_probability = jnp.max(probabilities, axis=-1)

# Show the five most likely classes for the (single) image, highest first.
top5_indices = jnp.argsort(probabilities[0])[-5:][::-1]
print("\nTop 5 predictions:")
for rank, idx in enumerate(top5_indices, start=1):
    prob = probabilities[0][idx]
    print(f"{rank}. Class {idx}: {prob:.4f} ({prob*100:.2f}%)")