# Model Loading and Inference with PyTorch

In this notebook we will:
- Load the model weights trained in the `01_model_creation_and_train` notebook
- Perform inference on test data
- Visualize predictions


In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

## Loading the Dataset

We load the test split (`train=False`) of the same FashionMNIST dataset used for training. These images were not used to fit the model's weights.

In [None]:
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

batch_size = 64
test_dataloader = DataLoader(test_data, batch_size=batch_size)
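To see what `test_dataloader` yields without downloading anything, the sketch below wraps 10,000 random tensors shaped like FashionMNIST images (1×28×28, the same count as the test split) in a `TensorDataset`. The random data is only a stand-in assumption, not the real dataset. With `batch_size=64`, the loader produces 157 batches, and the last one holds the 16 leftover samples.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for the FashionMNIST test split (assumption): 10,000 random "images"
# of shape (1, 28, 28) with integer labels in [0, 10). Not real data.
num_samples = 10_000
fake_images = torch.rand(num_samples, 1, 28, 28)
fake_labels = torch.randint(0, 10, (num_samples,))
fake_test_data = TensorDataset(fake_images, fake_labels)

batch_size = 64
loader = DataLoader(fake_test_data, batch_size=batch_size)  # shuffle=False by default

# Each batch is a pair: images of shape (batch, channels, height, width) and labels.
first_images, first_labels = next(iter(loader))
print(first_images.shape)  # torch.Size([64, 1, 28, 28])
print(first_labels.shape)  # torch.Size([64])

# 156 full batches of 64 plus one final batch of 10_000 - 156 * 64 = 16 samples.
batch_sizes = [imgs.shape[0] for imgs, _ in loader]
print(len(loader), batch_sizes[-1])  # 157 16
```

Because `drop_last` defaults to `False`, the partial final batch is kept, so every test image is seen exactly once.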

## Define the Network

We must recreate the **same model architecture** as used during training.
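The reason the architecture must match is that a state dictionary stores tensors under layer names with fixed shapes. `load_state_dict` (strict by default) refuses any mismatch. The sketch below copies the notebook's architecture and adds a `hidden` width argument, which is a hypothetical change for illustration only. It shows the saved keys, the total parameter count, and the `RuntimeError` raised when a model with a different hidden width tries to load the same weights.

```python
import torch
from torch import nn

class NeuralNetwork(nn.Module):
    """Same layers as the notebook; the `hidden` argument is a hypothetical addition."""
    def __init__(self, hidden=512):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 10),
        )

    def forward(self, x):
        return self.linear_relu_stack(self.flatten(x))

trained = NeuralNetwork()
state = trained.state_dict()

# Only the Linear layers (indices 0, 2, 4 in the Sequential) own parameters.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))

# (784*512 + 512) + (512*512 + 512) + (512*10 + 10) parameters in total.
n_params = sum(p.numel() for p in trained.parameters())
print(n_params)  # 669706

# Identical architecture: loads cleanly, with no missing or unexpected keys.
result = NeuralNetwork().load_state_dict(state)

# Different hidden width: same key names, incompatible shapes, so loading fails.
try:
    NeuralNetwork(hidden=256).load_state_dict(state)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", str(err).splitlines()[0])
```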

In [None]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Turn each 1x28x28 image into a 784-element vector
        self.flatten = nn.Flatten()
        # Two hidden layers of 512 units, then 10 output logits (one per class)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

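# Use the available accelerator (e.g. CUDA or MPS) if there is one, otherwise the CPU.
# torch.accelerator is only provided by recent PyTorch releases.
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
model = NeuralNetwork().to(device)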
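On PyTorch versions that predate `torch.accelerator`, the line above fails with an `AttributeError`. A common fallback is to probe CUDA and Apple's MPS backend explicitly. The helper name `pick_device` below is hypothetical, and this is only a sketch of that fallback order. Note that older releases may not recognize other accelerator types.

```python
import torch

def pick_device() -> str:
    """Hypothetical helper: choose a device string across PyTorch versions."""
    # Prefer the generic accelerator API when this PyTorch build provides it.
    if hasattr(torch, "accelerator") and torch.accelerator.is_available():
        return torch.accelerator.current_accelerator().type
    # Older releases: probe CUDA first, then Apple's MPS backend if it exists.
    if torch.cuda.is_available():
        return "cuda"
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"

device = pick_device()
print(f"Using {device} device")

# Tensors (and models, via .to(device)) can then be placed on that device.
x = torch.ones(2, 2, device=device)
```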


## Loading Model Weights

We load the state dictionary (the learned weights and biases) that the previous notebook saved to `model.pth`, then switch the model to evaluation mode.


In [None]:
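# map_location places the saved tensors on the current device, so weights saved
# on a GPU can still be loaded on a CPU-only machine.
# weights_only=True restricts unpickling to tensors and plain Python containers.
state_dict = torch.load(
    "../01_model_creation_and_train/model.pth",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()  # evaluation mode: affects layers like dropout and batch norm
print("Model loaded successfully!")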
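The save/load round trip can be checked end to end without the real checkpoint. The sketch below assumes a tiny stand-in model and a temporary file in place of `model.pth`. It also assumes the training notebook saved `model.state_dict()` with `torch.save`, which the loading code above implies but this notebook does not show. After loading, the restored model holds exactly the same tensors, and `eval()` clears the `training` flag on every submodule.

```python
import os
import tempfile

import torch
from torch import nn

def make_model():
    # Tiny stand-in architecture (assumption); the notebook uses NeuralNetwork.
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

torch.manual_seed(0)
trained = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pth")
    # Saving side (assumed to mirror the training notebook): store only the state dict.
    torch.save(trained.state_dict(), path)

    # Loading side: rebuild the architecture, then copy the saved tensors into it.
    restored = make_model()
    state = torch.load(path, map_location="cpu", weights_only=True)
    restored.load_state_dict(state)

restored.eval()  # sets .training = False on the module and all its submodules

same_weights = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(same_weights)       # True
print(restored.training)  # False
```

Saving the state dictionary rather than the whole pickled model object is why the architecture has to be redefined before loading.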

## Making Predictions

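We take the first batch of test images and predict their labels. Because `test_dataloader` does not shuffle (the `DataLoader` default), this is the same batch on every run.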
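Each row of the model's output holds raw, unnormalized scores (logits), one per class. The predicted label is the index of the largest one. The sketch below uses hand-written logits for 2 images over 3 classes (a made-up example; the real model outputs 10 per image) to show `argmax(dim=1)`. It also shows that applying `softmax` first, to read the scores as probabilities, does not change which class wins.

```python
import torch

# Hypothetical logits: 2 images (rows) x 3 classes (columns).
logits = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.5, 0.2]])

predictions = logits.argmax(dim=1)    # index of the largest logit in each row
probs = torch.softmax(logits, dim=1)  # optional: convert each row to probabilities

print(predictions)       # tensor([1, 0])
print(probs.sum(dim=1))  # each row sums to 1, up to float rounding
# softmax is monotonic, so it preserves the argmax
print(torch.equal(probs.argmax(dim=1), predictions))  # True
```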


In [None]:
images, labels = next(iter(test_dataloader))
images, labels = images.to(device), labels.to(device)

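# Gradients are not needed for inference; no_grad saves memory and time
with torch.no_grad():
    outputs = model(images)              # shape (batch_size, 10): one logit per class
    predictions = outputs.argmax(dim=1)  # index of the highest logit = predicted class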

print("Predictions:", predictions[:10])
print("True labels:", labels[:10])
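A single batch only hints at performance. To score the whole test split, the same steps can be looped over every batch to count correct predictions. `evaluate_accuracy` below is a hypothetical helper, not part of the original notebook. It is checked here on a synthetic loader with a dummy model that always predicts class 0, which is an assumption made only so that the result is predictable.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model: nn.Module, dataloader: DataLoader, device: str = "cpu") -> float:
    """Hypothetical helper: fraction of samples in `dataloader` classified correctly."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

class AlwaysZero(nn.Module):
    """Dummy model for the check below: the largest logit is always class 0."""
    def forward(self, x):
        logits = torch.zeros(x.shape[0], 10)
        logits[:, 0] = 1.0
        return logits

# Synthetic data: 4 of these 8 labels are 0, so the dummy model should score 0.5.
toy_labels = torch.tensor([0, 0, 0, 1, 2, 3, 0, 5])
toy_loader = DataLoader(TensorDataset(torch.rand(8, 1, 28, 28), toy_labels), batch_size=3)
print(evaluate_accuracy(AlwaysZero(), toy_loader))  # 0.5
```

In the notebook, the equivalent call would be `evaluate_accuracy(model, test_dataloader, device)`.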

## Visualizing Predictions

Display the first few test images, each titled with its predicted (P) and true (T) class index.

In [None]:
def show_images(images, labels, preds, n=6):
    """Plot the first n images, titled with predicted (P) and true (T) class indices."""
    plt.figure(figsize=(12, 3))
    for i in range(n):
        plt.subplot(1, n, i+1)
        plt.imshow(images[i].cpu().squeeze(), cmap='gray')
        plt.title(f"P: {preds[i].item()}\nT: {labels[i].item()}")
        plt.axis('off')
    plt.show()

show_images(images, labels, predictions)
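The titles above show bare class indices. FashionMNIST's ten indices map to clothing names, and torchvision exposes the same list as `test_data.classes`. The sketch below hard-codes that list and defines a hypothetical `title_for` helper. The helper builds a readable title and picks a color, green for a correct prediction and red for a wrong one, so that mistakes stand out.

```python
# FashionMNIST class names in index order (same as torchvision's FashionMNIST.classes).
CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def title_for(pred: int, true: int) -> tuple[str, str]:
    """Hypothetical helper: subplot title text and color (green if correct, red if not)."""
    text = f"P: {CLASSES[pred]}\nT: {CLASSES[true]}"
    color = "green" if pred == true else "red"
    return text, color

print(title_for(9, 9))  # ('P: Ankle boot\nT: Ankle boot', 'green')
print(title_for(6, 0))  # ('P: Shirt\nT: T-shirt/top', 'red')
```

Inside `show_images`, the `plt.title(...)` line could then become `text, color = title_for(preds[i].item(), labels[i].item())` followed by `plt.title(text, color=color)`.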