# Inference


In [None]:
import copy
import time
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim
import torchvision.transforms as transforms

from utils import dataset
from models import nvidia

# Tqdm progress bar
from tqdm import tqdm_notebook, tqdm

# Location of the trained DAVE-2 weights that main() restores.
WEIGHTS_FILE: str = "./checkpoints/nvidia_dave2.pth"
# Batch size handed to dataset.load_nvidia_dataset for the evaluation loader.
BATCH_SIZE: int = 32

def seed_everything(seed: int) -> None:
    """Seed every random number generator used here so runs are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices). It also configures cuDNN to use deterministic
    algorithms only.

    Args:
        seed: Seed value applied to every generator.
    """
    import os
    import random

    import numpy as np
    import torch

    random.seed(seed)
    # NOTE: hash randomization is fixed when the interpreter starts, so this
    # only influences child processes launched afterwards, not this one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible GPU; silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels only help if the cuDNN autotuner is disabled:
    # with benchmark=True cuDNN may select different algorithms per run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)

In [None]:
def main():
    """Evaluate the NVIDIA DAVE-2 checkpoint on the test split and report RMSE.

    Steps:
        1. Build the image transform.
        2. Load the test split via ``dataset.load_nvidia_dataset``.
        3. Restore the weights saved at ``WEIGHTS_FILE``.
        4. Run the model over every test batch.

    The squared errors of all batches are summed and divided by the total
    number of target elements. The reported value is therefore the exact
    root-mean-squared error over the whole test set, and a smaller final
    batch is weighted correctly.

    Returns:
        torch.Tensor: 0-dim tensor holding the test-set RMSE.

    Raises:
        ValueError: If the test split yields no samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resize every image to 192 x 256 and normalize each channel with a fixed
    # mean / std. NOTE(review): these values are the common ImageNet
    # statistics; confirm they match what the checkpoint was trained with.
    # Normalize operates on tensors, so the dataset presumably yields tensors.
    transform = transforms.Compose([
        # Citation:
        # https://pytorch.org/vision/stable/transforms.html#scriptable-transforms
        transforms.Resize((192, 256)),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Only the test split is used for inference; the other two are ignored.
    _, _, test_set = dataset.load_nvidia_dataset(transform=transform, batch_size=BATCH_SIZE)
    torch.cuda.empty_cache()

    model = nvidia.NvidiaDaveCNN()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(WEIGHTS_FILE, map_location=device))
    model = model.to(device)
    model.eval()

    # reduction="sum" yields each batch's sum of squared errors; the mean is
    # taken once at the end over every target element seen, rather than by
    # dividing by the number of batches.
    # NOTE(review): MSELoss broadcasts if `out` and `target` shapes differ
    # (e.g. (B, 1) vs (B,)); confirm the model output matches target shape.
    criterion = nn.MSELoss(reduction="sum")
    total_squared_error = 0.0
    total_elements = 0

    # Keep a handle on the progress bar so its description can be updated.
    progress_bar = tqdm_notebook(test_set, ascii=True)

    with torch.inference_mode():
        for idx, (data, target) in enumerate(progress_bar):
            data = data.to(device)
            target = target.to(device)

            batch_sse = criterion(model(data), target).item()
            batch_elements = target.numel()
            total_squared_error += batch_sse
            total_elements += batch_elements

            progress_bar.set_description_str(f"Batch: {idx+1}, Loss: {(batch_sse/batch_elements):.4f}")

    if total_elements == 0:
        raise ValueError("Test set is empty; cannot compute RMSE.")

    rmse = torch.tensor(total_squared_error / total_elements, dtype=torch.float64) ** 0.5
    print("* RMSE: ", rmse.item())
    return rmse


if __name__ == "__main__":
    # Script entry point: run the evaluation and keep the returned RMSE at
    # module scope so it can be inspected afterwards.
    rmse = main()