In [1]:
import syft as sy
# Check the installed syft version against the range this notebook was written for
# (the cell output reports whether the installed version matches).
sy.requires(">=0.8,<0.8.1")



✅ The installed version of syft==0.8.0 matches the requirement >=0.8 and the requirement <0.8.1


In [10]:
# Start a dev-mode domain server called "test-domain-1" on port 8080
# (the cell output shows it binding to 0.0.0.0:8080).
node = sy.orchestra.launch(
    name="test-domain-1",
    port=8080,
    dev_mode=True,
)

Starting test-domain-1 server on 0.0.0.0:8080

SQLite Store Path:
!open file:///var/folders/4s/vsgdjz495453yymsd09nx80c0000gn/T/7bca415d13ed1ec841f0d0aede098dbb.sqlite

> Domain: test-domain-1 - 7bca415d13ed1ec841f0d0aede098dbb - NodeType.DOMAIN

Services:
ActionService
DataSubjectMemberService
DataSubjectService
DatasetService
MessageService
MetadataService
NetworkService
PolicyService
ProjectService
RequestService
UserCodeService
UserService


INFO:     Started server process [3353]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8080 (Press CTRL+C to quit)


INFO:     127.0.0.1:50022 - "GET /api/v1/new/metadata HTTP/1.1" 200 OK
INFO:     127.0.0.1:50022 - "POST /api/v1/new/login HTTP/1.1" 200 OK
INFO:     127.0.0.1:50022 - "GET /api/v1/new/api?verify_key=aec6ea4dfc049ceacaeeebc493167a88a200ddc367b1fa32da652444b635d21f HTTP/1.1" 200 OK
INFO:     127.0.0.1:50027 - "POST /api/v1/new/api_call HTTP/1.1" 200 OK
INFO:     127.0.0.1:50030 - "POST /api/v1/new/api_call HTTP/1.1" 200 OK
INFO:     127.0.0.1:50033 - "POST /api/v1/new/api_call HTTP/1.1" 200 OK
INFO:     127.0.0.1:50036 - "GET /api/v1/new/metadata HTTP/1.1" 200 OK
INFO:     127.0.0.1:50039 - "POST /api/v1/new/api_call HTTP/1.1" 200 OK
INFO:     127.0.0.1:50036 - "GET /api/v1/new/metadata HTTP/1.1" 200 OK
INFO:     127.0.0.1:50044 - "POST /api/v1/new/api_call HTTP/1.1" 200 OK
INFO:     127.0.0.1:50047 - "GET /api/v1/new/api?verify_key=aec6ea4dfc049ceacaeeebc493167a88a200ddc367b1fa32da652444b635d21f HTTP/1.1" 200 OK
INFO:     127.0.0.1:50050 - "POST /api/v1/new/api_call HTTP/1.1" 200 OK
IN

INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [3353]


In [11]:
import os

# Read credentials from the environment when they are set, so real secrets never need to
# be committed with the notebook. The fallbacks are the demo credentials this notebook used.
domain_client = node.login(
    email=os.environ.get("SYFT_EMAIL", "info@openmined.org"),
    password=os.environ.get("SYFT_PASSWORD", "changethis"),
)

In [12]:
from jax import random
from flax import linen as nn
# Fixed seed, so the data draw and the initial weights (and the sums asserted later) are reproducible.
# NOTE(review): this single key is used both for `random.uniform` (training data) and for
# `model.init` below. JAX normally expects a fresh key from `random.split` for each use, but
# the recorded output sum depends on the current behaviour.
key = random.PRNGKey(seed=42)

In [13]:
# Synthetic stand-in data: 4 samples shaped 28x28x1 (presumably mimicking single-channel
# 28x28 images), with values drawn uniformly from jax's default [0, 1) range.
train_data = random.uniform(key, (4, 28, 28, 1))

In [14]:
# Guard against silent changes in the data draw (different seed, shape or jax version).
# Converting to a Python float first makes this a plain scalar comparison, the same way
# the final check of the notebook treats its result.
assert round(float(train_data.sum())) == 1602, (
    f"train_data sum is {float(train_data.sum()):.2f}, expected ~1602"
)

In [15]:
# Wrap the raw jax array in a syft ActionObject so it can be uploaded to the domain.
train = sy.ActionObject.from_obj(train_data)

In [16]:
# Inspect the wrapper: the concrete payload type, the ActionObject's id and the array shape.
(
    type(train.syft_action_data),
    train.id,
    train.shape,
)

(jaxlib.xla_extension.DeviceArray,
 <UID: 0089f873adf74a0fa284290ce119bc02>,
 (4, 28, 28, 1))

In [17]:
# Upload the wrapped training data through the domain's action service. The returned
# object's id is what the input policy of `train_mlp` pins further down.
train_domain_obj = domain_client.api.services.action.set(train)

In [18]:
class MLP(nn.Module):
    """Two-layer perceptron: flatten -> Dense(128) -> ReLU -> Dense(out_dims)."""

    out_dims: int  # width of the final (output) Dense layer

    @nn.compact
    def __call__(self, x):
        # Keep the leading axis (the batch axis for the (4, 28, 28, 1) input used here)
        # and collapse every remaining axis into a single feature axis.
        flat = x.reshape(x.shape[0], -1)
        hidden = nn.relu(nn.Dense(128)(flat))
        return nn.Dense(self.out_dims)(hidden)

model = MLP(out_dims=10)

In [19]:
# Initialise the MLP parameters using the raw (unwrapped) training array as sample input.
# NOTE(review): `key` is the same PRNG key that generated `train_data`. JAX normally
# expects a fresh key from `random.split`, but the expected output sum asserted later
# depends on these exact weights, so changing this would change that value.
weights = model.init(key, train.syft_action_data)

In [20]:
# Wrap the initialised parameters (a FrozenDict, per the cell output) for upload.
w = sy.ActionObject.from_obj(weights)

In [21]:
# Inspect the wrapped weights: the concrete payload type and the ActionObject's id.
(
    type(w.syft_action_data),
    w.id,
)

(flax.core.frozen_dict.FrozenDict, <UID: fd3e9ba7920d4c469bc259b62aaf29dd>)

In [22]:
# Upload the wrapped weights. The returned object's id is pinned by `train_mlp`'s input policy.
weight_domain_obj = domain_client.api.services.action.set(w)

In [23]:
# Define the code submitted to the domain for execution.
# The input policy pins each argument to the id of one uploaded object. Going by its name,
# the output policy lets the result be produced only once (TODO confirm against syft docs).
@sy.syft_function(input_policy=sy.ExactMatch(weights=weight_domain_obj.id, data=train_domain_obj.id),
                  output_policy=sy.SingleExecutionExactOutput())
def train_mlp(weights, data):
    """Run one forward pass of a freshly defined MLP using ``weights`` on ``data``.

    Despite the name, no training happens: the parameters are only applied
    (``model.apply``) and never updated. Returns the model output; the cell output
    further down shows a (4, 10) array for the 4-sample input.
    """
    # The import and the MLP class are repeated inside the body. Presumably this is
    # because syft captures and runs only this function's source on the domain, where
    # notebook-level names are not available. NOTE(review): confirm.
    from flax import linen as nn

    class MLP(nn.Module):
        out_dims: int

        @nn.compact
        def __call__(self, x):
            x = x.reshape((x.shape[0], -1))
            x = nn.Dense(128)(x)
            x = nn.relu(x)
            x = nn.Dense(self.out_dims)(x)
            return x

    model = MLP(out_dims=10)
    output = model.apply(weights, data)
    return output

In [24]:
# Dry run: call the decorated function directly on the local, unwrapped objects. This
# presumably executes locally and gives a reference result before requesting remote execution.
output = train_mlp(weights=weights, data=train_data)

In [25]:
# Sanity-check the local dry run. Converting to a Python float first makes this the
# same scalar check that the notebook's final cell applies to the remote result.
assert round(float(output.sum()), 2) == -3.24, (
    f"local output sum is {float(output.sum()):.4f}, expected ~-3.24"
)

In [26]:
# Submit `train_mlp` to the domain as a code-execution request. The request starts out
# PENDING (see the repr below) until it is approved.
request = domain_client.api.services.code.request_code_execution(train_mlp)
request

```python
class Request:
  id: str = 226c8fc6d5104a329cf32b5353f6a036
  requesting_user_verify_key: str = aec6ea4dfc049ceacaeeebc493167a88a200ddc367b1fa32da652444b635d21f
  approving_user_verify_key: str = None
  request_time: str = 2023-05-01 23:36:24
  approval_time: str = None
  status: str = RequestStatus.PENDING
  node_uid: str = 7bca415d13ed1ec841f0d0aede098dbb
  request_hash: str = "4450f15c19d0acfdaa84d7f3d34d2670b0953c5b068572fe4d272b4df41c5720"
  changes: str = [syft.service.request.request.UserCodeStatusChange]

```

In [27]:
# Approve the pending request.
# NOTE(review): the request came from this same `domain_client` session, so requester
# and approver are presumably the same account in this demo.
request.approve()

In [28]:
# Discard the client's (presumably cached) API object and fetch it again, so the newly
# approved `train_mlp` endpoint is reachable under `domain_client.api.services.code`.
# NOTE(review): this resets a private attribute (`_api`). Prefer a public refresh call if
# the installed syft version offers one.
domain_client._api = None
_ = domain_client.api

In [29]:
# Execute the approved code on the domain. Pass the ids of the *uploaded* objects, which
# are exactly the ids the ExactMatch input policy of `train_mlp` was built from, rather
# than the ids of the local wrappers.
result = domain_client.api.services.code.train_mlp(
    weights=weight_domain_obj.id,
    data=train_domain_obj.id,
)

In [30]:
# Display the array returned by the remote execution.
result

DeviceArray([[ 0.14943041, -0.36854096, -0.64575584, -0.38621526,
              -0.28981561,  0.14723957,  0.35607396,  0.898455  ,
              -0.46983801,  0.21583178],
             [-0.36093625, -0.0785419 , -0.41703793, -0.829131  ,
               0.06887782,  0.079618  ,  0.22278813,  0.55593109,
              -0.53083418, -0.0054186 ],
             [-0.31463861,  0.0295174 , -0.62358003, -0.08584506,
              -0.24341324, -0.17701983,  0.3985397 ,  0.67374497,
              -0.14091304,  0.0577738 ],
             [-0.3278211 , -0.35691213, -0.77101191, -0.52124855,
               0.10943515, -0.01648953,  0.27638874,  0.55057775,
              -0.11716184,  0.05130892]], dtype=float64)

In [31]:
# The approved remote execution must reproduce both the local dry run and the recorded value.
assert round(float(result.sum()), 2) == round(float(output.sum()), 2) == -3.24, (
    f"remote sum {float(result.sum()):.4f} vs local sum {float(output.sum()):.4f}; expected ~-3.24"
)

In [32]:
# Shut the node down, but only for the "python" node type (presumably an in-process node).
# Other node types are left running here (confirm they are managed elsewhere).
if node.node_type.value == "python":
    node.land()

Stopping test-domain-1
