Skip to content
This repository has been archived by the owner. It is now read-only.
Permalink
Browse files

fix for plaidML

  • Loading branch information...
iperov committed May 24, 2019
1 parent 3114ae9 commit c9da4cd89b5af4f65efe7aba6cab6947d20eb81b
Showing with 6 additions and 4 deletions.
  1. +1 −1 nnlib/device.py
  2. +5 −3 nnlib/nnlib.py
@@ -299,7 +299,7 @@ def get_plaidML_devices():
except:
pass
return plaidML_devices

if not has_nvidia_device:
get_plaidML_devices()

@@ -161,7 +161,6 @@ def import_keras(device_config):
return nnlib.code_import_keras

nnlib.backend = device_config.backend

if "tensorflow" in nnlib.backend:
nnlib._import_tf(device_config)
elif nnlib.backend == "plaidML":
@@ -174,6 +173,9 @@ def import_keras(device_config):
import keras as keras_
nnlib.keras = keras_

if 'KERAS_BACKEND' in os.environ:
os.environ.pop('KERAS_BACKEND')

if nnlib.backend == "plaidML":
import plaidml
import plaidml.tile
@@ -591,7 +593,7 @@ def get_config(self):
nnlib.Adam = Adam

def CAInitializerMP( conv_weights_list ):
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
result = CAInitializerMPSubprocessor ( [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ], K.floatx(), K.image_data_format() ).run()
for idx, weights in result:
K.set_value ( conv_weights_list[idx], weights )
@@ -696,7 +698,7 @@ def __call__(self,x):
x = ReflectionPadding2D( self.pad ) (x)
return self.func(x)
nnlib.Conv2D = Conv2D

class Conv2DTranspose():
def __init__ (self, *args, **kwargs):
self.reflect_pad = False

0 comments on commit c9da4cd

Please sign in to comment.
You can’t perform that action at this time.