# Model Storage Demo

This notebook demonstrates API usage of the storage manager.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import os

extra_path = os.path.join(os.getcwd(), "..")
if extra_path not in sys.path:
    sys.path.append(extra_path)

In [3]:
import torch.nn as nn

In [5]:
# import APIs
from storage_manager.utils import ModelSerializer
from storage_manager.sql import NeurDB

In [6]:
# create a model
"""
This model has two linear layers, each followed by a ReLU activation function.
"""
class DemoModel(nn.Module):
    def __init__(self):
        super(DemoModel, self).__init__()
        self.fc1 = nn.Linear(10, 6)
        self.fc2 = nn.Linear(6, 3)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        return x

In [7]:
demo_model = DemoModel()
print(demo_model)

DemoModel(
  (fc1): Linear(in_features=10, out_features=6, bias=True)
  (fc2): Linear(in_features=6, out_features=3, bias=True)
  (relu): ReLU()
)


  return tensor.uniform_(-bound, bound, generator=generator)


In [8]:
# to serialize the model:
serialized_model = ModelSerializer.serialize_model(demo_model)
print(serialized_model)

<storage_manager.common.storage.ModelStorage.Pickled object at 0x7f579b8c05b0>


In [11]:
# to connect to the database:
connection = {
    "db_name": "postgres",
    "user": "postgres",
    "host": "localhost",
    "port": "5432",
}
database = NeurDB(
    db_name=connection["db_name"], 
    db_user=connection["user"], 
    db_host=connection["host"], 
    db_port=connection["port"]
)

A connection to the database has been established once the `NeurDB` object is created.

In [12]:
# save the serialized model to the database in layer-by-layer format
model_id = database.save_model(serialized_model)

In [13]:
# load, unpack, and convert the model to a nn.Module object
model = database.load_model(model_id).unpack().to_model()
print(model)

DemoModel(
  (fc1): Linear(in_features=10, out_features=6, bias=True)
  (fc2): Linear(in_features=6, out_features=3, bias=True)
  (relu): ReLU()
)


In [9]:
# to delete the model from the database
database.delete_model(model_id)