In [2]:
import os
import tempfile
from functools import lru_cache

import cv2
import numpy as np
import torch
from fastai.vision.all import load_learner, PILImage
# from fastxtend.vision.all import adam

In [3]:
# Paths to the two exported model artifacts, resolved relative to the current
# working directory: the fastai learner pickle and the plain PyTorch model.
model_path_fastai, model_path_torch = (
    os.path.join(os.curdir, artifact) for artifact in (r"model.pkl", r"model.pt")
)

In [4]:
# Load the full pickled nn.Module. map_location="cpu" matches predict_torch, which
# always builds CPU tensors (torch.from_numpy), so a GPU-saved model still works.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# On torch >= 2.6 torch.load defaults to weights_only=True, which rejects a pickled
# full model; pass weights_only=False there if this cell raises an UnpicklingError.
model = torch.load(model_path_torch, map_location="cpu")
# float64 to match the .double() inputs in predict_torch; eval() puts any dropout /
# batch-norm layers into inference mode so predictions are deterministic.
model = model.double().eval()

In [5]:
def round_pred(pred, threshold: float = 0, min_forward: float = 0.4, reverse_value: float = -1):
    """snaps a raw throttle prediction onto the values used to drive the car

    Parameters:
    ---
    pred: float
        raw throttle prediction
    threshold: float
        predictions strictly below this are replaced by ``reverse_value``
    min_forward: float
        predictions strictly between ``threshold`` and ``min_forward`` are
        raised to ``min_forward`` (default 0.4, the previously hard-coded value)
    reverse_value: float
        value returned for predictions strictly below ``threshold``
        (default -1, the previously hard-coded value)

    Returns:
    ---
    float
        ``reverse_value`` if ``pred < threshold``; ``min_forward`` if
        ``threshold < pred < min_forward``; otherwise ``pred`` unchanged.
        Note that ``pred == threshold`` exactly, and NaN, fall through unchanged.
    """
    if threshold < pred < min_forward:
        return min_forward
    if pred < threshold:
        return reverse_value
    return pred

@lru_cache(maxsize=None)
def _cached_learner(model_path):
    """Load the exported fastai learner at ``model_path`` once and reuse it.

    NOTE(review): load_learner unpickles the file — only load trusted models.
    A file replaced on disk after the first load is not picked up; call
    ``_cached_learner.cache_clear()`` to force a reload.
    """
    return load_learner(model_path)


def predict(img, model_path=model_path_fastai):
    """predicts steer and throttle for one image with the exported fastai learner

    Parameters:
    ---
    img: str or PILImage
        image (or path to an image) to pass to ``Learner.predict``
    model_path: str
        path to the exported learner; loaded once per path and then cached

    Returns:
    ---
    steer: float
        predicted steer value
    throttle: float
        predicted throttle value, snapped by ``round_pred(throttle, -0.5)``
    """
    # Unpickling the learner on every call is wasteful; reuse one per path.
    learner = _cached_learner(model_path)
    # Learner.predict returns a 3-tuple (decoded target, decoded prediction,
    # raw prediction); unpacking it directly into two names raised ValueError.
    # Assumes a single regression target with exactly two outputs — TODO confirm.
    _, decoded, _ = learner.predict(img)
    steer, throttle = decoded.tolist()
    throttle = round_pred(throttle, -0.5)
    return steer, throttle

In [8]:
def predict_torch(img, model):
    """predicts throttle and steer for one grayscale frame with a plain PyTorch model

    Parameters:
    ---
    img: numpy.ndarray
        single-channel (grayscale) image, e.g. from
        ``cv2.imread(path, cv2.IMREAD_GRAYSCALE)``; it is resized to 256x256,
        replicated to 3 RGB channels and divided by 255
    model: torch.nn.Module
        model taking a float64 batch of shape (1, 3, 256, 256); the first row of
        its output is unpacked as (steer, throttle)

    Returns:
    ---
    throttle: float
        predicted throttle value, snapped by ``round_pred(throttle, -0.5)``
    steer: float
        predicted steer value

    Note: the return order is (throttle, steer), the reverse of ``predict``.
    """
    resized = cv2.resize(img, (256, 256))
    rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
    # HWC -> CHW, add a batch axis of 1, scale by 1/255 (float64)
    batch = torch.from_numpy(rgb.transpose(2, 0, 1)[np.newaxis] / 255).double()

    # Inference only: don't build an autograd graph for every frame.
    with torch.no_grad():
        outputs = model(batch)
    steer, throttle = outputs.tolist()[0]
    throttle = round_pred(throttle, -0.5)
    return throttle, steer

In [10]:
%%timeit
# start_time = time.perf_counter()

folder = r"C:\Users\medha\AppData\Local\Temp\airsim_car"
images = os.listdir(folder)

for img_path in images:
    img = cv2.imread(os.path.join(folder, img_path), cv2.IMREAD_GRAYSCALE)
    throttle, steer = predict_torch(img, model)
# end_time = time.perf_counter()
# print(f"Time taken: {end_time - start_time:.2f} seconds for {len(images)} images")

4min ± 8.11 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
