In [None]:
from google.colab import drive

# Mount Google Drive into the Colab VM; the cell output reports the mount
# point as /content/gdrive.
drive.mount('/content/gdrive')
# Project folder inside the mounted Drive.
path = '/content/gdrive/My Drive/MyCode/animal-code'
# IPython magic: make the project folder the working directory ({path} is
# interpolated from the Python variable above), so relative paths used later
# in the notebook resolve against it.
%cd {path}

Mounted at /content/gdrive
/content/gdrive/.shortcut-targets-by-id/1PX4wlgoBojg5BTTQ3yOEDoz8vk2IF1ju/animal-code


In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import argparse
import warnings
from classification_models import CNN
import matplotlib.pyplot as plt

# Suppress every Python warning for the rest of the session.
# NOTE(review): this also hides deprecation warnings from the libraries
# imported above; narrow the filter if those become relevant.
warnings.filterwarnings("ignore")

In [None]:
def get_args(argv=None):
    """Build the option parser for the animal video classifier and parse it.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly like the original
            zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``size`` (int),
        ``input_path``, ``output_path`` and ``checkpoint_path`` (str).
    """
    parser = argparse.ArgumentParser(description="Animal classifier")
    parser.add_argument("-s", "--size", type=int, default=224)
    parser.add_argument("-i", "--input_path", type=str, default="test_video.mp4")
    parser.add_argument("-o", "--output_path", type=str, default="test_video_output.mp4")
    parser.add_argument("-c", "--checkpoint_path", type=str, default="trained_models/best.pt")
    # A Jupyter/Colab kernel is launched with its own options in sys.argv
    # (e.g. ``-f <connection file>``). parse_args() would abort on them with
    # "unrecognized arguments"; parse_known_args() parses the options declared
    # above and sets everything else aside instead of exiting.
    args, _unrecognized = parser.parse_known_args(argv)
    return args

In [None]:
def test(args):
    categories = ["butterfly", "cat", "chicken", "cow", "dog", "elephant", "horse", "sheep", "spider", "squirrel"]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN(num_classes=len(categories)).to(device)

    if args.checkpoint_path and os.path.isfile(args.checkpoint_path):
        checkpoint = torch.load(args.checkpoint_path)
        model.load_state_dict(checkpoint["model"])
        model.eval()
    else:
        print("A checkpoint must be provided")
        exit(0)

    if not args.input_path:
        print("An image must be provided")
        exit(0)

    cap = cv2.VideoCapture(args.input_path)
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    out = cv2.VideoWriter(args.output_path, cv2.VideoWriter_fourcc(*"MJPG"), int(cap.get(cv2.CAP_PROP_FPS)), (width, height))
    counter = 0
    while cap.isOpened():
        print(counter)
        counter += 1
        flag, frame = cap.read()
        if not flag:
            break
        image = cv2.resize(frame, (args.size, args.size))
        image = np.transpose(image, (2, 0, 1))
        image = image / 255
        # image = np.expand_dims(image, 0)
        image = torch.from_numpy(image).to(device).float()[None, :, :, :]
        softmax = nn.Softmax()
        with torch.no_grad():
            prediction = model(image)
        probs = softmax(prediction)
        max_value, max_index = torch.max(probs, dim=1)
        cv2.putText(frame, categories[max_index], (10, 100), cv2.FONT_HERSHEY_SIMPLEX, 5, (128, 0, 128), 2, cv2.LINE_AA)
        out.write(frame)
    cap.release()
    out.release()

In [None]:
# Entry point: parse the options and run the video classification.
if __name__ == "__main__":
    # `args` is bound at module level, so it stays available after this runs.
    args = get_args()
    test(args)