## How to Transmit a Machine Learning Model in the GoodData System

### Option 1: Save the model to a file
#### Submit Query
1. The Query Customer (QC) defines a model locally and saves it to a file with ```torch.save()```.
2. The QC uploads the file through the UI. The UI (JavaScript) reads the file, converts it to a byte stream, and sends it to the GDS Service using the gRPC call ```SubmitQuery```.
3. The GDS Service stores the serialized model in the DB without deserializing it.
4. When a DO is notified by the blockchain, it queries the GDS Service for the byte stream using the gRPC call ```GetQueryInfo```.
5. The DO writes the byte stream to a temp file whose name contains a timestamp and the query uuid (see the example below).
6. The DO calls ```torch.load()``` to load the model from the temp file and starts training.
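Steps 5 and 6 on the DO side can be sketched as below. This is a minimal sketch under assumptions, not the GDS implementation: `load_submitted_model` is a hypothetical helper, `payload` stands for the bytes returned by ```GetQueryInfo```, and the temp directory and file-name pattern are modeled loosely on the example cells further down.

```python
import os
import tempfile
import time

import torch
import torch.nn as nn


def load_submitted_model(query_uuid: str, payload: bytes) -> tuple[nn.Module, str]:
    """Hypothetical DO helper: write the GetQueryInfo bytes to a temp file and load the model.

    Returns the model and the temp file path so that the caller can delete the
    file when the query is completed.
    """
    # The file name combines the query uuid and a timestamp so queries do not collide.
    path = os.path.join(tempfile.gettempdir(), f"model_{query_uuid}_start{time.time()}.pt")
    with open(path, "wb") as f:
        f.write(payload)
    # weights_only=False is needed to unpickle a full nn.Module on PyTorch >= 2.6.
    model = torch.load(path, weights_only=False)
    return model, path
```

Returning the path instead of deleting the file right away leaves cleanup to the Query Completed stage, where all temp files for the query are removed.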

#### Query Completed
1. The DO saves the trained model's ```state_dict``` to a temp file.
2. The DO reads the temp file as a byte stream and calls ```QueryCompleted``` to send the result to the GDS Service. **Delete all temp files related to this query.**
3. The GDS Service receives the result and runs a consistency check (this requires calling a Python script that compares the model results from different DOs).
4. Once the QC knows the query has completed, it calls ```GetQueryExecutionInfo```, downloads the result into a file, and loads the file to get the model's ```state_dict```.
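The file round trip in this stage can be sketched with two hypothetical helpers: `state_dict_to_bytes` for the DO (steps 1 and 2, including temp-file deletion) and `load_result` for the QC (step 4). This is an illustration under assumptions; the gRPC transport itself is left out.

```python
import os
import tempfile

import torch
import torch.nn as nn


def state_dict_to_bytes(model: nn.Module, query_uuid: str) -> bytes:
    """Hypothetical DO helper (steps 1-2): save the state_dict to a temp file and read it back as bytes."""
    fd, path = tempfile.mkstemp(prefix=f"model_{query_uuid}_completed", suffix=".pt")
    os.close(fd)
    try:
        torch.save(model.state_dict(), path)
        with open(path, "rb") as f:
            return f.read()
    finally:
        # Step 2: delete the temp file once its bytes have been read.
        os.remove(path)


def load_result(payload: bytes, model: nn.Module) -> nn.Module:
    """Hypothetical QC helper (step 4): write the downloaded bytes to a file and load the state_dict."""
    fd, path = tempfile.mkstemp(suffix=".pt")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
        model.load_state_dict(torch.load(path))
    finally:
        os.remove(path)
    return model
```

The sketch uses ```tempfile.mkstemp()``` instead of ```time.time()```-based names so that file names cannot collide; either approach works for illustrating the flow.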


```python
import time

import torch
import torch.nn as nn
```

### Define a Model

```python
# A toy model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.fc1(x)
        x = torch.sigmoid(x)
        return x


model = Net()
```

### Submit Query Use Case

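```python
# Save the model to a temp file (the name combines a placeholder query uuid and a timestamp)
filename_start = './model_uuid_start' + str(time.time()) + '.pt'
torch.save(model, filename_start, pickle_protocol=2)

# .... Sending and receiving with gRPC .......

# Load the model from the file.
# weights_only=False is required on PyTorch >= 2.6, where torch.load defaults to
# weights_only=True and refuses to unpickle a full nn.Module.
model_new_start = torch.load(filename_start, weights_only=False)
```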



```python
print(filename_start)
print(model.state_dict())
print(model_new_start.state_dict())
```

```
./model_uuid_start1597822720.728288.pt
OrderedDict([('fc1.weight', tensor([[-0.3154, -0.3879,  0.3616, -0.1143,  0.2432]])), ('fc1.bias', tensor([0.3548]))])
OrderedDict([('fc1.weight', tensor([[-0.3154, -0.3879,  0.3616, -0.1143,  0.2432]])), ('fc1.bias', tensor([0.3548]))])
```


### Query Completed Use Case

```python
# Save the model's state_dict to a temp file
filename_completed = './model_uuid_completed' + str(time.time()) + '.pt'
torch.save(model.state_dict(), filename_completed)

# .... Sending and receiving with gRPC .......

# Load the state_dict from the file into a freshly initialized Net
model_new_completed = Net()
print(model_new_completed.state_dict())  # random initial weights, before loading
model_new_completed.load_state_dict(torch.load(filename_completed))
```

```
OrderedDict([('fc1.weight', tensor([[-0.0267, -0.4148,  0.3465,  0.0986, -0.3573]])), ('fc1.bias', tensor([0.4338]))])

<All keys matched successfully>
```

```python
print(filename_completed)
print(model.state_dict())
print(model_new_completed.state_dict())
```

```
./model_uuid_completed1597822720.771007.pt
OrderedDict([('fc1.weight', tensor([[-0.3154, -0.3879,  0.3616, -0.1143,  0.2432]])), ('fc1.bias', tensor([0.3548]))])
OrderedDict([('fc1.weight', tensor([[-0.3154, -0.3879,  0.3616, -0.1143,  0.2432]])), ('fc1.bias', tensor([0.3548]))])
```


### Option 2: Serialize the model in memory (Recommended)
#### Submit Query
1. The Query Customer (QC) defines a model locally, serializes it with ```pickle.dumps()```, and writes the resulting bytes to a file.
2. The QC uploads the file through the UI. The UI (JavaScript) reads the file, converts it to a byte stream, and sends it to the GDS Service using the gRPC call ```SubmitQuery```.
3. The GDS Service stores the serialized model in the DB without deserializing it.
4. When a DO is notified by the blockchain, it queries the GDS Service for the byte stream using the gRPC call ```GetQueryInfo```.
5. The DO loads the byte stream directly into memory with ```pickle.loads()```; no temp file is needed.

#### Query Completed
1. The DO serializes the trained model's ```state_dict``` in memory with ```pickle.dumps()``` and calls ```QueryCompleted``` to send it to the GDS Service.
2. The GDS Service receives the result and runs a consistency check (this requires calling a Python script that compares the model results from different DOs).
3. Once the QC knows the query has completed, it calls ```GetQueryExecutionInfo```, downloads the result into a file, and loads the file to get the model's ```state_dict```.
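The document does not specify what the consistency check compares. Assuming it means "every DO returned numerically close parameters", the consistency-check script could look like the hypothetical `state_dicts_consistent` below; the tolerance values are assumptions passed through to ```torch.allclose()```.

```python
import torch


def state_dicts_consistent(results: list[dict], rtol: float = 1e-5, atol: float = 1e-8) -> bool:
    """Hypothetical check: True if every DO's state_dict matches the first one within tolerance."""
    if not results:
        return True
    reference = results[0]
    for other in results[1:]:
        # Both sides must contain the same parameter names.
        if set(reference) != set(other):
            return False
        for name, tensor in reference.items():
            candidate = other[name]
            # Shapes must match before comparing values element-wise.
            if tensor.shape != candidate.shape:
                return False
            if not torch.allclose(tensor, candidate, rtol=rtol, atol=atol):
                return False
    return True
```

Whether a tolerance-based comparison is meaningful depends on how the query is distributed: DOs that train on different data would not be expected to return matching weights.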


### Submit Query Use Case

```python
import pickle

# Serialize the whole model to a bytes object in memory
buffer = pickle.dumps(model, protocol=2)

# .... Sending and receiving with gRPC .......

# Deserialize the bytes back into a model object
new_model2 = pickle.loads(buffer)

# Even better, but I have had problems making it work:
# import io
# buffer = io.BytesIO()
# torch.save(model, buffer, pickle_protocol=2)
# model2 = torch.load(buffer)
```

```python
print(model.state_dict())
print(new_model2.state_dict())
```

```
OrderedDict([('fc1.weight', tensor([[-0.3154, -0.3879,  0.3616, -0.1143,  0.2432]])), ('fc1.bias', tensor([0.3548]))])
OrderedDict([('fc1.weight', tensor([[-0.3154, -0.3879,  0.3616, -0.1143,  0.2432]])), ('fc1.bias', tensor([0.3548]))])
```
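One likely cause of the problem with the commented-out ```torch.save()```/```io.BytesIO``` code is the stream position: after ```torch.save()``` writes to the buffer, the position is at the end, so ```torch.load()``` starts reading at end-of-file. The sketch below (helper names are hypothetical) avoids this by extracting the bytes with ```getvalue()``` and loading from a fresh buffer, which starts at position 0; calling ```buffer.seek(0)``` before loading should work as well.

```python
import io

import torch
import torch.nn as nn


def model_to_bytes(model: nn.Module) -> bytes:
    """Hypothetical helper: serialize a full model with torch.save into an in-memory buffer."""
    buffer = io.BytesIO()
    torch.save(model, buffer, pickle_protocol=2)
    # getvalue() returns everything written, regardless of the current stream position.
    return buffer.getvalue()


def model_from_bytes(payload: bytes) -> nn.Module:
    """Hypothetical helper: load a full model from bytes produced by model_to_bytes."""
    # A fresh BytesIO starts at position 0, so torch.load reads from the beginning.
    # weights_only=False is needed to unpickle a full nn.Module on PyTorch >= 2.6.
    return torch.load(io.BytesIO(payload), weights_only=False)


original = nn.Linear(5, 1)
restored = model_from_bytes(model_to_bytes(original))
print(torch.equal(original.weight, restored.weight))  # True
```

Because ```model_to_bytes``` returns plain ```bytes```, its output could be sent over gRPC the same way as the ```pickle.dumps()``` output above.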


### Query Completed Use Case

```python
# Serialize the model's state_dict to a bytes object
query_completed_buffer = pickle.dumps(model.state_dict(), protocol=2)

# .... Sending and receiving with gRPC .......

# Load the state_dict from the bytes object into a freshly initialized Net
model_new_completed2 = Net()
print(model_new_completed2.state_dict())  # random initial weights, before loading
model_new_completed2.load_state_dict(pickle.loads(query_completed_buffer))
```

```
OrderedDict([('fc1.weight', tensor([[-0.0977, -0.0238,  0.1705,  0.2426, -0.2601]])), ('fc1.bias', tensor([-0.0855]))])

<All keys matched successfully>
```

```python
print(model.state_dict())
print(model_new_completed2.state_dict())
```

```
OrderedDict([('fc1.weight', tensor([[-0.3154, -0.3879,  0.3616, -0.1143,  0.2432]])), ('fc1.bias', tensor([0.3548]))])
OrderedDict([('fc1.weight', tensor([[-0.3154, -0.3879,  0.3616, -0.1143,  0.2432]])), ('fc1.bias', tensor([0.3548]))])
```



### Future work
1. Encrypt the model with the DO's public key before sending it to the DO.
2. Resolve the problem with using ```torch.save(model, io.BytesIO(), pickle_protocol=2)```.

### Open question
1. Will the model be too large to fit in memory?
If yes, Option 1 is preferred; otherwise, Option 2 can be used.
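To help answer this question for a given model, the memory held by its parameters and buffers can be estimated as below. The helper name is hypothetical; the serialized payload also carries format overhead, so measuring ```len()``` of the actual bytes is the more direct check.

```python
import pickle

import torch.nn as nn


def parameter_bytes(model: nn.Module) -> int:
    """Hypothetical helper: bytes held by the model's parameters and buffers."""
    tensors = list(model.parameters()) + list(model.buffers())
    return sum(t.numel() * t.element_size() for t in tensors)


toy = nn.Linear(5, 1)  # same layer shape as the toy Net above
print(parameter_bytes(toy))  # 5 weights + 1 bias, 4 bytes each as float32 -> 24
print(len(pickle.dumps(toy.state_dict(), protocol=2)))  # actual payload size, including overhead
```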

### Reference
1. https://pytorch.org/tutorials/beginner/saving_loading_models.html
2. https://docs.python.org/3/library/pickle.html
