forked from ivy-llc/ivy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
multiversion_frontend_test.py
164 lines (132 loc) · 4.92 KB
/
multiversion_frontend_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
from ivy_tests import config
import sys
import jsonpickle
import importlib
def available_frameworks():
    """Return the names of the frameworks that can be imported here.

    Returns
    -------
    list of str
        Always starts with ``"numpy"``; ``"jax"``, ``"tensorflow"`` and
        ``"torch"`` follow, in that order, for each one whose import
        succeeds (an ``ImportError`` drops it from the list).
    """
    found = ["numpy"]
    for module_name in ("jax", "tensorflow", "torch"):
        try:
            # Importing is the availability probe; the module itself is unused.
            importlib.import_module(module_name)
        except ImportError:
            continue
        found.append(module_name)
    return found
def convtrue(argument):
    """Unwrap a ``NativeClass`` placeholder into the framework class it holds.

    Parameters
    ----------
    argument
        Any value.

    Returns
    -------
    The wrapped ``_native_class`` if ``argument`` is a ``NativeClass``;
    otherwise ``argument`` itself, unchanged.
    """
    return argument._native_class if isinstance(argument, NativeClass) else argument
class NativeClass:
    """
    Placeholder for a class that exists only in one specific framework.

    ``convtrue`` replaces instances of this placeholder with the wrapped class.

    Attributes
    ----------
    _native_class : class reference
        The framework-specific class being represented.
    """

    def __init__(self, native_class):
        """
        Wrap a framework-specific class.

        Parameters
        ----------
        native_class : class reference
            The framework-specific class to represent.
        """
        self._native_class = native_class
def _get_type_dict(framework):
return {
"valid": framework.valid_dtypes,
"numeric": framework.valid_numeric_dtypes,
"float": framework.valid_float_dtypes,
"integer": framework.valid_int_dtypes,
"unsigned": framework.valid_uint_dtypes,
"signed_integer": tuple(
set(framework.valid_int_dtypes).difference(framework.valid_uint_dtypes)
),
"complex": framework.valid_complex_dtypes,
"real_and_complex": tuple(
set(framework.valid_numeric_dtypes).union(framework.valid_complex_dtypes)
),
"float_and_complex": tuple(
set(framework.valid_float_dtypes).union(framework.valid_complex_dtypes)
),
"bool": tuple(
set(framework.valid_dtypes).difference(framework.valid_numeric_dtypes)
),
}
def dtype_handler(framework):
    """Print an ivy backend's dtype categories as one jsonpickle string.

    Parameters
    ----------
    framework : str
        Backend name; imported as ``ivy.functional.backends.<framework>``.
    """
    backend = importlib.import_module("ivy.functional.backends." + framework)
    dtypes = jsonpickle.dumps(_get_type_dict(backend))
    # Flush so the reply leaves the process immediately: when stdout is a
    # pipe it is block-buffered, and this output is a reply to a single
    # request read from stdin by the main loop.
    print(dtypes, flush=True)
if __name__ == "__main__":
    # Subprocess entry point: make the requested framework versions
    # importable, set the ivy backend, then serve requests from stdin and
    # write one jsonpickle reply per request to stdout.
    #
    # Requests, as read by the loop below:
    #   "1"                       -> reply with the backend's dtype categories
    #   otherwise three lines:
    #     1. jsonpickle dict with "a" (positional args) and "b" (kwargs)
    #     2. name of the frontend module to import
    #     3. name of the function in that module to call
    #   -> reply with the function's result converted via ivy.to_numpy
    # The loop exits when stdin is closed.
    arg_lis = sys.argv
    fw_lis = []
    for arg in arg_lis[1:]:
        parts = arg.split("/")
        if parts[0] == "jax":
            # A jax argument packs two "<name>/<version>" pairs into one token
            # (presumably jax and jaxlib); register them as two entries.
            fw_lis.append(parts[0] + "/" + parts[1])
            fw_lis.append(parts[2] + "/" + parts[3])
        else:
            fw_lis.append(arg)
    config.allow_global_framework_imports(fw=fw_lis)

    import ivy

    # NOTE(review): the backend is taken from the *second* CLI argument.
    backend_name = arg_lis[2].split("/")[0]
    ivy.set_backend(backend_name)

    import numpy

    def _to_native_arg(x):
        # Positional args: numpy arrays -> native arrays,
        # ivy dtypes -> native dtypes, anything else unchanged.
        if isinstance(x, numpy.ndarray):
            return ivy.native_array(x)
        if isinstance(x, ivy.Dtype):
            return ivy.as_native_dtype(x)
        return x

    def _to_native_kwarg(x):
        # Keyword args: only numpy arrays are converted here; the "dtype" and
        # "device" entries are converted explicitly below.
        return ivy.native_array(x) if isinstance(x, numpy.ndarray) else x

    while True:
        try:
            request = input()
            if request == "1":
                dtype_handler(backend_name)
                continue
            # NOTE(review): jsonpickle.loads can instantiate arbitrary
            # objects; this process must only be fed by a trusted parent.
            pickle_dict = jsonpickle.loads(request)
            frontend_fw = importlib.import_module(input())
            func = input()
        except EOFError:
            # stdin is closed: every further input() would raise EOFError
            # immediately, so stop rather than spin forever.
            break

        args_np, kwargs_np = pickle_dict["a"], pickle_dict["b"]
        args_frontend = ivy.nested_map(args_np, _to_native_arg, shallow=False)
        kwargs_frontend = ivy.nested_map(
            kwargs_np, _to_native_kwarg, shallow=False
        )
        # change ivy dtypes to native dtypes
        if "dtype" in kwargs_frontend:
            kwargs_frontend["dtype"] = ivy.as_native_dtype(kwargs_frontend["dtype"])
        # change ivy device to native devices
        if "device" in kwargs_frontend:
            kwargs_frontend["device"] = ivy.as_native_dev(kwargs_frontend["device"])
        # replace NativeClass placeholders in the arguments with the real
        # framework classes they wrap
        args_frontend = ivy.nested_map(
            args_frontend, fn=convtrue, include_derived=True, max_depth=10
        )
        kwargs_frontend = ivy.nested_map(
            kwargs_frontend, fn=convtrue, include_derived=True, max_depth=10
        )
        frontend_ret = frontend_fw.__dict__[func](*args_frontend, **kwargs_frontend)
        frontend_ret = ivy.to_numpy(frontend_ret)
        # Flush so the reply is delivered immediately even when stdout is a
        # (block-buffered) pipe.
        print(jsonpickle.dumps(frontend_ret), flush=True)