Skip to content
Permalink
Browse files

more robust PyTorch draw_model using distiller

  • Loading branch information
sytelus committed Nov 13, 2019
1 parent 4e79878 commit 6822d434954545bce2b4a46335e3d25896e3f973
@@ -448,4 +448,6 @@ ASALocalRun/
.localhistory/

# BeatPulse healthcheck temp database
healthchecksdb
healthchecksdb

runs/
@@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}
BIN +389 KB abc
Binary file not shown.
@@ -24,6 +24,7 @@
),
include_package_data=True,
install_requires=[
'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat', 'ipywidgets', 'sklearn', 'nbformat', 'scikit-image' # , 'receptivefield'
'matplotlib', 'numpy', 'pyzmq', 'plotly', 'torchstat', 'ipywidgets',
'sklearn', 'nbformat', 'scikit-image', 'nbformat', 'yaml' # , 'receptivefield'
]
)
@@ -30,9 +30,9 @@



def draw_model(model, input_shape=None, orientation='TB'): #orientation = 'LR' for landscpe
from .model_graph.hiddenlayer import graph
g = graph.build_graph(model, input_shape, orientation=orientation)
def draw_model(model, input_shape=None, orientation='TB', png_filename=None): #orientation = 'LR' for landscpe
from .model_graph.hiddenlayer import pytorch_draw_model
g = pytorch_draw_model.draw_graph(model, input_shape)
return g


@@ -14,13 +14,13 @@ def __init__(self, for_write:bool, file_name:str, stream_name:str=None, console_
self._file = open(file_name, 'wb' if for_write else 'rb')
self.file_name = file_name
self.for_write = for_write
utils.debug_log('FileStream started', self.file_name, verbosity=1)
utils.debug_log('FileStream started', os.path.realpath(self._file.name), verbosity=0)

def close(self):
if not self._file.closed:
self._file.close()
self._file = None
utils.debug_log('FileStream is closed', self.file_name, verbosity=1)
utils.debug_log('FileStream is closed', os.path.realpath(self._file.name), verbosity=0)
super(FileStream, self).close()

def write(self, val:Any, from_stream:'Stream'=None):
@@ -0,0 +1,85 @@
#
# Copyright (c) 2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from .distiller_utils import *

import logging
logging.captureWarnings(True)

def model_find_param_name(model, param_to_find):
"""Look up the name of a model parameter.
Arguments:
model: the model to search
param_to_find: the parameter whose name we want to look up
Returns:
The parameter name (string) or None, if the parameter was not found.
"""
for name, param in model.named_parameters():
if param is param_to_find:
return name
return None


def model_find_module_name(model, module_to_find):
"""Look up the name of a module in a model.
Arguments:
model: the model to search
module_to_find: the module whose name we want to look up
Returns:
The module name (string) or None, if the module was not found.
"""
for name, m in model.named_modules():
if m == module_to_find:
return name
return None


def model_find_param(model, param_to_find_name):
"""Look a model parameter by its name
Arguments:
model: the model to search
param_to_find_name: the name of the parameter that we are searching for
Returns:
The parameter or None, if the paramter name was not found.
"""
for name, param in model.named_parameters():
if name == param_to_find_name:
return param
return None


def model_find_module(model, module_to_find):
"""Given a module name, find the module in the provided model.
Arguments:
model: the model to search
module_to_find: the module whose name we want to look up
Returns:
The module or None, if the module was not found.
"""
for name, m in model.named_modules():
if name == module_to_find:
return m
return None

0 comments on commit 6822d43

Please sign in to comment.
You can’t perform that action at this time.