In [1]:
import os
import torch
import numpy as np
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import cv2
import torch.nn as nn

class F110_YOLO(torch.nn.Module):
    """Single-scale, YOLO-style convolutional detector.

    Five stride-2 convolutions shrink a 180x320 RGB frame to a 5x10 grid;
    three stride-1 layers and a final 1x1 convolution then produce the
    per-cell predictions. ``forward`` returns five channels per grid cell,
    which ``get_loss`` reads as (confidence, x, y, w, h).

    Layer names and shapes are part of the checkpoint format
    (``load_state_dict``), so they must not be renamed or resized.
    """

    def __init__(self):
        super(F110_YOLO, self).__init__()  # 180x320x3
        self.conv1 = nn.Conv2d(3, 32, kernel_size = 4, padding = 1, stride = 2)  # 90x160x32
        self.batchnorm1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU(inplace = True)

        self.conv2 = nn.Conv2d(32, 64, kernel_size = 4, padding = 1, stride = 2)  # 45x80x64
        self.batchnorm2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace = True)

        self.conv3 = nn.Conv2d(64, 128, kernel_size = 4, padding = 1, stride = 2)  # 22x40x128
        self.batchnorm3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace = True)

        self.conv4 = nn.Conv2d(128, 256, kernel_size = 4, padding = 1, stride = 2)  # 11x20x256
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU(inplace = True)

        self.conv5 = nn.Conv2d(256, 512, kernel_size = 4, padding = 1, stride = 2)  # 5x10x512
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(inplace = True)

        self.conv6 = nn.Conv2d(512, 1024, kernel_size = 3, padding = 1, stride = 1)  # 5x10x1024
        self.batchnorm6 = nn.BatchNorm2d(1024)
        self.relu6 = nn.ReLU(inplace = True)

        self.conv7 = nn.ConvTranspose2d(1024, 256, kernel_size = 3, padding = 1, stride = 1)  # 5x10x256
        self.batchnorm7 = nn.BatchNorm2d(256)
        self.relu7 = nn.ReLU(inplace = True)

        self.conv8 = nn.ConvTranspose2d(256, 64, kernel_size = 3, padding = 1, stride = 1)  # 5x10x64
        self.batchnorm8 = nn.BatchNorm2d(64)
        self.relu8 = nn.ReLU(inplace = True)

        # conv9 emits 8 channels but forward() only uses the first five; the
        # width is left unchanged so previously saved checkpoints still load.
        self.conv9 = nn.Conv2d(64, 8, kernel_size = 1, padding = 0, stride = 1)  # 5x10x8
        # relu9 is not applied in forward(); it is kept only so the module
        # tree stays the same as before.
        self.relu9 = nn.ReLU()

    def forward(self, x, debug = False):
        """Run the network on a batch of images.

        Args:
            x: image batch laid out as (N, 3, H, W); the shape comments in
                ``__init__`` assume H=180 and W=320.
            debug: when truthy, print the tensor shape after every stage to
                check the network dimensions.

        Returns:
            Tensor with five channels per grid cell. Channels 0-2 are the raw
            conv9 outputs; channels 3-4 are passed through a sigmoid so they
            lie in [0, 1].
        """
        stages = (
            (self.conv1, self.batchnorm1, self.relu1),
            (self.conv2, self.batchnorm2, self.relu2),
            (self.conv3, self.batchnorm3, self.relu3),
            (self.conv4, self.batchnorm4, self.relu4),
            (self.conv5, self.batchnorm5, self.relu5),
            (self.conv6, self.batchnorm6, self.relu6),
            (self.conv7, self.batchnorm7, self.relu7),
            (self.conv8, self.batchnorm8, self.relu8),
        )

        if debug: print(0, x.shape)
        for stage_index, (conv, norm, activation) in enumerate(stages, start = 1):
            x = activation(norm(conv(x)))
            if debug: print(stage_index, x.shape)

        x = self.conv9(x)
        if debug: print(9, x.shape)

        # Keep channels 0-2 as raw values, squash channels 3-4 with a sigmoid,
        # and drop the remaining conv9 channels.
        x = torch.cat([x[:, 0:3, :, :], torch.sigmoid(x[:, 3:5, :, :])], dim=1)

        return x

    def get_loss(self, result, truth, lambda_coord = 5, lambda_noobj = 1, eps = 1e-6):
        """Sum-of-squares detection loss over every grid cell in the batch.

        Both tensors are indexed as (N, C, H, W) with channel 0 = confidence,
        1 = x, 2 = y, 3 = w, 4 = h. ``truth[:, 0]`` doubles as the mask that
        selects which cells contribute to the coordinate terms.

        Args:
            result: network predictions (output of ``forward``).
            truth: ground-truth tensor with the same layout as ``result``.
            lambda_coord: weight of the x/y/w/h terms.
            lambda_noobj: weight of the confidence term where the mask is 0.
            eps: lower bound applied to the predicted w/h before the square
                root. sqrt has an infinite derivative at 0, which would turn
                into NaN gradients (0 * inf) even in masked-out cells.

        Returns:
            Scalar tensor holding the total loss.
        """
        obj_mask = truth[:, 0, :, :]
        conf_err = (obj_mask - result[:, 0, :, :]) ** 2

        x_loss = (result[:, 1, :, :] - truth[:, 1, :, :]) ** 2
        y_loss = (result[:, 2, :, :] - truth[:, 2, :, :]) ** 2
        # Only the prediction is clamped: it is the tensor that carries
        # gradients, and values at or above eps are left untouched.
        pred_w = torch.clamp(result[:, 3, :, :], min = eps)
        pred_h = torch.clamp(result[:, 4, :, :], min = eps)
        w_loss = (torch.sqrt(pred_w) - torch.sqrt(truth[:, 3, :, :])) ** 2
        h_loss = (torch.sqrt(pred_h) - torch.sqrt(truth[:, 4, :, :])) ** 2

        class_loss_obj = obj_mask * conf_err
        class_loss_noobj = (1 - obj_mask) * lambda_noobj * conf_err

        total_loss = torch.sum(lambda_coord * obj_mask * (x_loss + y_loss + w_loss + h_loss) + class_loss_obj + class_loss_noobj)

        return total_loss

In [3]:
# Use the GPU when one is available; otherwise everything stays on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = F110_YOLO()
model_save_name = 'best_model.pt'
path = model_save_name  # prepend a directory here if the checkpoint lives elsewhere
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load(path, map_location=device))
model = model.to(device)
# Switch BatchNorm to inference behaviour (running statistics) before export.
model.eval()

# Convert the model to ONNX format
ONNX_FILE_PATH = 'yolo.onnx'
# The export only traces a forward pass, so the dummy input needs no gradient
# tracking; it is created directly on the target device.
dummy_input = torch.randn(1, 3, 180, 320, device=device)
torch.onnx.export(model, dummy_input, ONNX_FILE_PATH, input_names=['input'], output_names=['output'], export_params=True)

In [4]:
# Check the model converted fine: reload the exported file from disk and run
# ONNX's checker on it. check_model raises if the model is not valid, so a
# silent pass means the check succeeded.
import onnx
onnx_model = onnx.load(ONNX_FILE_PATH)
onnx.checker.check_model(onnx_model)