# note-1-train-and-test
This notebook covers the main task: training and testing the model.
It loads an ImageNet-pretrained ResNet-18 and fine-tunes it on the MNIST dataset.

## Flow
1. Set up the environment
1. Train the model
1. Test the model


## Steps

### 1. Set up the environment
#### Install the dependencies

In [None]:
!pip install --quiet --upgrade pip
!pip install --quiet -r requirements.txt

#### Import the libraries and set the configuration

In [None]:
import os
import matplotlib.pyplot as plt
import torch, torchvision

# Config
FORCE_CPU = True          # True: always use the CPU; False: prefer CUDA, then MPS, then CPU
FREEZE_RESNET = False     # True: freeze the pre-trained backbone and train only the new head
CHANNEL_SIZE = 3          # grayscale MNIST images are replicated to this many channels
RESIZE_SIZE = (224, 224)  # target size passed to transforms.Resize

# Per-channel mean/std for transforms.Normalize; 0.5/0.5 maps [0, 1] inputs to [-1, 1].
NORMALIZE_MEAN = (0.5, 0.5, 0.5)
NORMALIZE_STD = (0.5, 0.5, 0.5)

DATA_ROOT_PATH = "./data"
MODEL_ROOT_PATH = "./models"

# Reproducibility: seed torch's global RNG before any weights are initialized or any
# data is shuffled, so re-running the notebook is repeatable (GPU kernels may still
# introduce some nondeterminism).
SEED = 42
torch.manual_seed(SEED)

# Hyperparameters
learning_rate = 0.001
momentum = 0.9
num_epochs = 3
batch_size = 100
num_workers = 0  # 0 = load data in the main process
t_max = 200      # CosineAnnealingLR period, counted in scheduler.step() calls

# Preprocessing applied to every MNIST image, in order:
#   1. replicate the single grayscale channel CHANNEL_SIZE times (presumably to match
#      the 3-channel input of the ImageNet-pretrained ResNet-18 loaded later),
#   2. resize to RESIZE_SIZE,
#   3. convert to a float tensor with values in [0, 1],
#   4. shift/scale each channel with NORMALIZE_MEAN / NORMALIZE_STD.
_transform_steps = [
    torchvision.transforms.Grayscale(num_output_channels=CHANNEL_SIZE),
    torchvision.transforms.Resize(RESIZE_SIZE),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
]
transform = torchvision.transforms.Compose(_transform_steps)

def _select_device(force_cpu: bool) -> torch.device:
    """Pick the compute device: CPU if forced, else CUDA, else Apple MPS, else CPU.

    Prints a one-line message explaining the choice.
    """
    if force_cpu:
        print("CPU is forced. Using CPU.")
        return torch.device("cpu")
    if torch.cuda.is_available():
        print("CUDA is available. Using GPU.")
        return torch.device("cuda")
    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        print("MPS is available and built. Using Apple Silicon GPU.")
        return torch.device("mps")
    print("Neither CUDA nor MPS is available or built. Using CPU.")
    return torch.device("cpu")


device = _select_device(FORCE_CPU)

#### Download the dataset and create the data loaders

In [None]:
# Build both MNIST splits under DATA_ROOT_PATH (download=True fetches the files if
# they are missing); every image is preprocessed on the fly by `transform`.
dataset_train, dataset_test = (
    torchvision.datasets.MNIST(root=DATA_ROOT_PATH, train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Both loaders share the same settings. Because shuffle=True is also used for the
# test loader, each next(iter(dataloader_test)) starts from a new random permutation.
_loader_options = {"batch_size": batch_size, "shuffle": True, "num_workers": num_workers}
dataloader_train = torch.utils.data.DataLoader(dataset=dataset_train, **_loader_options)
dataloader_test = torch.utils.data.DataLoader(dataset=dataset_test, **_loader_options)

#### Show some samples of the dataset

In [None]:
# Dataset Info.
print(f"Training Dataset: {len(dataset_train)} samples")
print(f"Testing Dataset: {len(dataset_test)} samples")


def _to_displayable(img: torch.Tensor) -> torch.Tensor:
    """Undo Normalize(NORMALIZE_MEAN, NORMALIZE_STD) and reorder (C, H, W) -> (H, W, C) for imshow."""
    mean = torch.tensor(NORMALIZE_MEAN).view(-1, 1, 1)
    std = torch.tensor(NORMALIZE_STD).view(-1, 1, 1)
    # Clamp guards against tiny float overshoot outside [0, 1], which imshow would clip with a warning.
    return (img * std + mean).clamp(0.0, 1.0).permute(1, 2, 0)


def _show_first_sample(dataloader: torch.utils.data.DataLoader, title: str) -> None:
    """Draw one batch from `dataloader` and plot its first image with the label in the title."""
    images, labels = next(iter(dataloader))
    plt.imshow(_to_displayable(images[0]))
    plt.title(f"{title} (Label: {labels[0].item()})")
    plt.show()


# Visualize one sample from each split.
_show_first_sample(dataloader_train, "Training sample")
_show_first_sample(dataloader_test, "Testing sample")

### 2. Train the model

#### Define the training function

In [None]:
def train_model(
    model : torch.nn.Module,
    dataloader : torch.utils.data.DataLoader,
    criterion : torch.nn.Module,
    optimizer : torch.optim.Optimizer,
    show_progress : bool = True,
    epoch : int = 0,
    num_epochs : int = 0,
    device : "torch.device | None" = None,
):
    """Run one training epoch over `dataloader` and return (mean loss, accuracy).

    Args:
        model: network to train; it is switched to train mode.
        dataloader: yields (images, labels) batches.
        criterion: loss called as criterion(outputs, labels); assumed to use mean
            reduction (the CrossEntropyLoss default) — the batch loss is re-weighted
            by batch size below.
        optimizer: stepped once per batch.
        show_progress: if True, print per-batch and per-epoch statistics.
        epoch: zero-based epoch index, used only in the progress message.
        num_epochs: total number of epochs, used only in the progress message.
        device: where each batch is moved. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        (total_loss, total_acc): loss averaged over all samples, and accuracy as a
        fraction in [0, 1]. Both are NaN if the dataloader yields no samples.
    """
    if device is None:
        # Infer the target device from the model instead of relying on a notebook global.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch loss is a true
        # per-sample average even when the last batch is smaller.
        n_in_batch = images.size(0)
        train_loss += loss.item() * n_in_batch
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if show_progress:
            print(f"  -- batch [Partial Loss: {loss.item():.4f}, Partial Accuracy: {100*correct/total:.1f}%]")

    if total == 0:
        # Empty dataloader: the averages are undefined, so report NaN instead of dividing by zero.
        total_loss = total_acc = float("nan")
    else:
        total_loss = train_loss / total
        total_acc = correct / total

    if show_progress:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss:.4f}, Accuracy: {100*total_acc:.1f}%")

    return total_loss, total_acc

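#### Load and set up the model
Load the ImageNet-pretrained ResNet-18 (optionally freezing its weights), replace its final fully connected layer with a new head that has one output per MNIST class, and create the loss function, optimizer, and learning-rate scheduler.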

In [None]:
# ResNet-18 initialized with ImageNet (IMAGENET1K_V1) weights.
model = torchvision.models.resnet18(
    weights = torchvision.models.resnet.ResNet18_Weights.IMAGENET1K_V1
)

if FREEZE_RESNET:
    # Freeze the pre-trained backbone; only the head created below stays trainable.
    for param in model.parameters():
        param.requires_grad = False

# Replace the ImageNet classifier with a fresh head sized for the dataset's classes.
# A newly constructed Linear layer has requires_grad=True, so it is trained even when
# the backbone is frozen.
model.fc = torch.nn.Linear(model.fc.in_features, len(dataset_train.classes))
model = model.to(device)

criterion = torch.nn.CrossEntropyLoss()
# Give the optimizer only the parameters that will actually receive gradients.
optimizer = torch.optim.SGD(
    params = [p for p in model.parameters() if p.requires_grad],
    lr = learning_rate,
    momentum = momentum,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=t_max)

#### Run the training loop

In [None]:
# Train for num_epochs passes over the training set, keeping (loss, accuracy) per epoch.
train_history = []
for epoch in range(num_epochs):
    epoch_loss, epoch_acc = train_model(
        model, dataloader_train, criterion, optimizer,
        show_progress=True, epoch=epoch, num_epochs=num_epochs,
    )
    train_history.append((epoch_loss, epoch_acc))
    # Advance the cosine learning-rate schedule once per epoch; without this call the
    # scheduler created alongside the optimizer never changes the learning rate.
    scheduler.step()

#### Save the model

In [None]:
# Create the output directory if needed; exist_ok avoids the check-then-create race.
os.makedirs(MODEL_ROOT_PATH, exist_ok=True)

# Compile the trained model to TorchScript so it can be reloaded with torch.jit.load
# without re-running this notebook's model-construction code.
model_scripted = torch.jit.script(model)
model_scripted.save(os.path.join(MODEL_ROOT_PATH, "model.pt"))

### 3. Test the model
#### Load the saved model

In [None]:
# map_location places the saved tensors on the current `device`, so a checkpoint saved
# on another device (e.g. CUDA) still loads on a machine that lacks that device.
loaded_model = torch.jit.load(os.path.join(MODEL_ROOT_PATH, "model.pt"), map_location=device)
# Switch to inference behaviour (e.g. BatchNorm uses its running statistics).
loaded_model.eval()
loaded_model = loaded_model.to(device)

#### Predict a test sample

In [None]:
# Visualize a sample
sample_test = next(iter(dataloader_test))

img_test = sample_test[0][0].permute(1, 2, 0) # (C, H, W) -> (H, W, C)
img_test = (img_test + 1) / 2 # Normalize Grayscale to [0, 1]

test_img = img_test.permute(2, 0, 1).unsqueeze(dim=0).to(device)  # (H,W,C) -> (1,C,H,W)
response = loaded_model(test_img)
probabilities = torch.nn.functional.softmax(response, dim=1)
predicted_class = torch.argmax(probabilities).item()

plt.imshow(img_test)
plt.title(f"Testing sample (Real: {sample_test[1][0]}, Predicted: {predicted_class})")
plt.show()

print(f"Predicted Number: {predicted_class}")
print(f"Probability Distribution:")
for idx, prob in enumerate(probabilities[0]):
    print(f" - Number {idx}: {prob.item()*100:.2f}%")