In [1]:
import os

import onnx

import torch
from torch import nn
from torchvision import transforms
import torch.nn.functional as F
from PIL import Image

#### Adding preprocessing layers to `forward`

In [2]:
# Load the pretrained production models (whole pickled nn.Module objects, not
# state_dicts) and put them in inference mode on the CPU.
# map_location='cpu' lets checkpoints saved from a GPU deserialize on CPU-only
# hosts. Calling .cpu() afterwards cannot help there, because torch.load itself
# fails first.
# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted files.
# On torch >= 2.6 (weights_only=True by default), loading a whole module may
# additionally require weights_only=False.
def _load_prod_model(path):
    """Load the pickled model at ``path`` onto the CPU and switch it to eval mode."""
    model = torch.load(path, map_location='cpu')
    model.cpu()
    model.eval()
    return model


eng_model = _load_prod_model('models/ENG_PROD.pth')
rus_model = _load_prod_model('models/RUS_PROD.pth')

In [None]:
class EngModelProd(nn.Module):
    '''For cropped input: classify an already-cropped image tensor.

    Wraps the ``features`` / ``avgpool`` / ``classifier`` stages of
    ``original_model``. The submodules are shared with it, not copied.

    NOTE(review): the next cell redefines ``EngModelProd``, so this version is
    shadowed once both cells have run.
    '''
    def __init__(self, original_model):
        super(EngModelProd, self).__init__()
        self.features = nn.Sequential(*list(original_model.features))
        self.avgpool = nn.Sequential(original_model.avgpool)
        self.classifier = nn.Sequential(*list(original_model.classifier))

    def forward(self, x):
        """Classify ``x``.

        Args:
            x: a single image ``(C, H, W)`` or a batch ``(N, C, H, W)``.

        Returns:
            Classifier output with one row per image (a single image gives
            one row).
        """
        # Add the batch dimension *before* the feature extractor. 2D layers
        # such as BatchNorm2d reject 3D input, and unsqueezing after pooling
        # merged the batch into the feature axis whenever N > 1.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x

In [None]:
class EngModelProd(nn.Module):
    '''Classifier that accepts a raw frame holding 4 * 720 * 1280 values.

    ``forward`` rebuilds the frame, keeps its first three channels, pads and
    downsamples it, then runs the wrapped ``features`` / ``avgpool`` /
    ``classifier`` stages of ``original_model``.
    '''

    def __init__(self, original_model):
        super(EngModelProd, self).__init__()
        # Fresh containers around the (shared, not copied) stages of the source model.
        self.features = nn.Sequential(*list(original_model.features))
        self.avgpool = nn.Sequential(original_model.avgpool)
        self.classifier = nn.Sequential(*list(original_model.classifier))

    def forward(self, x):
        """Preprocess a raw frame and classify it; the result has one row."""
        # View the input as (4, 720, 1280) and keep channels 0-2.
        # NOTE(review): presumably an RGBA frame whose alpha channel is dropped; confirm.
        rgb = x.reshape(4, 720, 1280)[:3, :, :]
        # Replicate-pad 860 px left/right and 580 px top/bottom -> (3, 1880, 3000),
        # then average non-overlapping 10x10 blocks -> (3, 188, 300).
        padded = F.pad(rgb, pad=(860, 860, 580, 580), mode='replicate')
        pooled = nn.functional.avg_pool2d(padded, kernel_size=10, stride=(10, 10))
        feats = self.avgpool(self.features(pooled.unsqueeze(0)))
        return self.classifier(torch.flatten(feats, 1))

In [8]:
class ImageModel(nn.Module):
    '''USING IMAGE FROM PATH: classify the image stored at a file path.

    ``forward`` takes a path string rather than a tensor, so this wrapper is
    meant for eager Python inference, not for tracing / ONNX export.
    '''
    def __init__(self, original_model):
        super(ImageModel, self).__init__()
        self.features = nn.Sequential(*list(original_model.features))
        self.avgpool = nn.Sequential(original_model.avgpool)
        self.classifier = nn.Sequential(*list(original_model.classifier))

    def forward(self, x: str) -> torch.Tensor:
        """Load the image at path ``x``, preprocess it and classify it.

        Args:
            x: path to an image file readable by PIL.

        Returns:
            Classifier output for the single image (batch dimension of 1).
        """
        # Decode to 3-channel RGB while the file is open. This drops any alpha
        # channel, which is what the former ``[:3]`` channel slice intended.
        # The previous ``.reshape`` call on a PIL image always raised.
        with Image.open(x) as img:
            rgb = img.convert('RGB')
        # NOTE(review): 244 (not the common 224) is kept as-is; confirm it
        # matches the training pipeline.
        rgb = transforms.Resize((244, 244))(rgb)
        # PIL image -> float tensor (3, 244, 244) scaled to [0, 1].
        # NOTE(review): no mean/std normalization is applied; confirm the model
        # expects un-normalized input.
        x = transforms.ToTensor()(rgb)
        x = x.unsqueeze(0)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x

In [3]:
class TransformsModel(nn.Module):
    '''Standard model with input transformation.

    Wraps the ``features`` / ``avgpool`` / ``classifier`` stages of
    ``original_model`` (shared, not copied) and prepends a resize step to
    ``forward``, so the traced graph includes the preprocessing.
    '''

    def __init__(self, original_model):
        super(TransformsModel, self).__init__()
        self.features = nn.Sequential(*list(original_model.features))
        self.avgpool = nn.Sequential(original_model.avgpool)
        self.classifier = nn.Sequential(*list(original_model.classifier))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` and classify it.

        Args:
            x: image tensor, either a single image ``(C, H, W)`` or a batch
                ``(N, C, H, W)``.

        Returns:
            Classifier output with one row per image.
        """
        # Resize with an int size rescales the shorter spatial side to 244 px,
        # keeping the aspect ratio.
        # NOTE(review): 244 (not the common 224); confirm against training.
        x = transforms.Resize(244)(x)
        # Add a batch dimension only for a single image, so batched input also
        # works. The branch is decided at trace time, so the export is unaffected.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)

        return x

In [4]:
# Wrap both loaded production models so their forward() includes the resize step.
eng_md_prod, rus_md_prod = (TransformsModel(src) for src in (eng_model, rus_model))

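Replacing `nn.Hardswish` with a custom, hardtanh-based equivalent (presumably so the graph exports with ONNX opsets that lack a native HardSwish operator)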

In [5]:
class New_Hardswish(nn.Module):
    '''Hard-swish activation, ``x * clamp(x + 3, 0, 6) / 6``, built from ``F.hardtanh``.

    Computes the same function as ``nn.Hardswish`` using only a clamp,
    a multiply and a divide, so it can stand in for that layer.
    '''

    @staticmethod
    def forward(x):
        gate = F.hardtanh(x + 3, min_val=0.0, max_val=6.0)  # clamp(x + 3, 0, 6)
        return x * gate / 6.0

In [6]:
def replace_layers(model, old, new):
    """Recursively replace every descendant module of type ``old`` in ``model``.

    The model is modified in place and nothing is returned. The root ``model``
    itself is never replaced, only its (nested) children.

    Args:
        model: ``nn.Module`` whose children are searched recursively.
        old: class (or tuple of classes, as accepted by ``isinstance``) to replace.
        new: either an ``nn.Module`` instance, inserted as-is at every match (so
            all matches share that one object), or a zero-argument callable such
            as a module class, called once per match so each slot gets its own
            independent module. Any other value (e.g. ``None``) is assigned as-is.
    """
    for name, child in model.named_children():
        # Recurse first, so matches nested inside ``child`` are handled even
        # when ``child`` itself is about to be replaced.
        if len(list(child.children())) > 0:
            replace_layers(child, old, new)

        if isinstance(child, old):
            if isinstance(new, nn.Module) or not callable(new):
                replacement = new
            else:
                # Factory: build a fresh module for every slot instead of sharing one.
                replacement = new()
            # Overwrites an existing key of model._modules, which is safe while
            # iterating named_children() because the dict size does not change.
            setattr(model, name, replacement)

In [7]:
# Swap every nn.Hardswish for the hardtanh-based equivalent before export.
# Each model gets its own (stateless) replacement instance.
replace_layers(eng_md_prod, nn.Hardswish, New_Hardswish())
replace_layers(rus_md_prod, nn.Hardswish, New_Hardswish())

# Restore the English model's weights; the key-matching report is the cell output.
# map_location='cpu' keeps this working on hosts without CUDA even if the
# state_dict was saved from GPU tensors.
eng_md_prod.load_state_dict(torch.load('models/eng_md_prod', map_location='cpu'))


<All keys matched successfully>

In [8]:
# Restore the Russian model's weights; map_location='cpu' avoids a CUDA
# deserialization failure on CPU-only hosts if the state_dict holds GPU tensors.
# NOTE(review): this checkpoint is named differently from 'models/eng_md_prod';
# confirm it is the intended one.
rus_md_prod.load_state_dict(torch.load('models/rus_new_dataset', map_location='cpu'))

<All keys matched successfully>

EXPORTING TO ONNX

In [9]:
def export_to_onnx(model, path, example_input):
    """Export ``model`` to an ONNX file at ``path`` by tracing it on ``example_input``.

    The model is moved to the CPU and switched to eval mode first (in place).
    The graph has one input named ``'input'`` and one output named ``'output'``.
    No ``dynamic_axes`` are passed, so the input shape is fixed to
    ``example_input.shape``.
    """
    out_dir = os.path.dirname(path)
    if out_dir:
        # Create the target directory so a fresh checkout can export without
        # the directory already existing.
        os.makedirs(out_dir, exist_ok=True)
    model.cpu()
    model.eval()
    torch.onnx.export(model,
                      example_input,
                      path,
                      export_params=True,
                      do_constant_folding=False,
                      input_names=['input'],
                      output_names=['output'])


# One unbatched 3 x 720 x 1280 frame; it fixes the exported input shape.
dummy_input = torch.ones(3, 720, 1280)

export_to_onnx(eng_md_prod, 'onnx_models/ENG_MD_PROD.onnx', dummy_input)
export_to_onnx(rus_md_prod, 'onnx_models/RUS_MD_PRODv2.onnx', dummy_input)



Loading the exported ONNX models and checking their input signatures

In [9]:
# Reload the exported English model and show its declared graph inputs,
# to confirm the input name and the fixed shape baked in at export time.
model = onnx.load('onnx_models/ENG_MD_PROD.onnx')
input_all, input_initializer, output = (
    model.graph.input,
    model.graph.initializer,
    model.graph.output,
)
input_all

[name: "input"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 720
      }
      dim {
        dim_value: 1280
      }
    }
  }
}
]

In [10]:
# Reload the exported Russian model and show its declared graph inputs,
# to confirm the input name and the fixed shape baked in at export time.
model = onnx.load('onnx_models/RUS_MD_PRODv2.onnx')
input_all, input_initializer, output = (
    model.graph.input,
    model.graph.initializer,
    model.graph.output,
)
input_all

[name: "input"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 720
      }
      dim {
        dim_value: 1280
      }
    }
  }
}
]