In [29]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
import pickle

In [9]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 32)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

In [10]:
model = Net()
model.load_state_dict(torch.load('mnist_dnn.pt'));

In [11]:
print("> printing the model")
print(model)

> printing the model
Net(
  (fc1): Linear(in_features=784, out_features=32, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=32, out_features=10, bias=True)
)


In [13]:
print("> showing how many trainable parameters there are in the model")
print( sum(p.numel() for p in model.parameters() if p.requires_grad) )

> showing how many trainable parameters there are in the model
25450


#### Model summary
```
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x) <== skip for inference
        x = self.fc2(x)
```

## Manually dividing the model into contracts
- **Future: build graph parser for pytorch models and auto-generate deployable cairo contracts + deployment script for address linkage, taking StarkNet constraints + parallelization scheme intended by the user**
- FC1: X(1x784) * M1 (784x32) = V1 (1x32)
    - dividing fc1 into 32 sop (sum-of-product; each being 784 multiply-add + bias), each producing one element in V1
    - further dividing each sop into 8 sub-sop, first 7 having 100 terms each, while the 8th has 84 terms.
        - **if not dividing into sub-sop's => would trigger exception on Lark's recursion depth limit; Lark is the parser used by the Cairo compiler by StarkWare**
    - one contract per sub-sop
    - **mnist_v1_sop<0-31>.cairo** one contract per sop, grouping 8 sub-sop's
    - **mnist_v1.cairo**: one contract groups 32 sop's to V1
    - => 1 + 32 + 32\*8 = **289 contracts**
- RELU: V1 (1x32) ==relu==> H1 (1x32)
    - one contract to produce H1
    - => **1 contract**
- FC2: H1 (1x32) * M2 (32x10) = Z (1x10)
    - dividing fc2 into 10 sop, each producing one element in Z
    - one contract per sop (32 multiply-add + bias)
    - => **10 contracts**
- Total: **300 contracts**

### Contract: V1

In [14]:
print("> Printing out cairo code for top-level contracts")

# for i in range(32):
#     print(f'    let (addr_{i}) = stored_addresses.read({i})')
#     print(f'    local pedersen_ptr : HashBuiltin* = pedersen_ptr')
#     print(f'    let (v1_{i}) = IContractV1SOP{i}.compute(addr_{i}, 784, x)')
#     print(f'    assert [v1+{i}] = v1_{i}')
#     print()

# for i in range(32):
#     print(f'@contract_interface')
#     print(f'namespace IContractV1SOP{i}:')
#     print(f'    func compute(x_len : felt, x : felt*) -> (res : felt):')
#     print(f'    end')
#     print(f'end')
#     print()

> Printing out cairo code for top-level contracts


### Contract: V1_SOP<0-31>

In [15]:
part1 = [
'%lang starknet',
'%builtins pedersen range_check',
'',
'from starkware.cairo.common.cairo_builtins import HashBuiltin',
'from starkware.cairo.common.math import signed_div_rem',
'',
'@storage_var',
'func stored_addresses (idx : felt) -> (addr : felt):',
'end',
'',
'@external',
'func admin_store_addresses {',
'        syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr',
'    } (idx : felt, addr : felt) -> ():',
'    stored_addresses.write(idx, addr)',
'    return ()',
'end',
'',
'################################################',
''
]

############

def part2 (sop_idx):
    ret = []
    for i in range(8):
        ret += [
            '@contract_interface',
            f'namespace IContractV1SOP{sop_idx}SUB{i}:',
            '    func compute(x_len : felt, x : felt*) -> (res : felt):',
            '    end',
            'end',
            ''
        ]
    ret.append('################################################')
    ret.append('')
    return ret

############

def part3 (sop_idx, quantize_depth=8):
    ret = []
    ret += [
        '@view',
        'func compute {',
        '        syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr',
        '    }(',
        '        x_len : felt,',
        '        x : felt*',
        '    ) -> (',
        '        res : felt',
        '    ):',
        '    alloc_locals',
        ''
    ]
    
    for i in range(8):
        x_len = 100 if i!=7 else 84
        ret.append(f'    let (addr_{i}) = stored_addresses.read({i})')
        ret.append(f'    local pedersen_ptr : HashBuiltin* = pedersen_ptr')
        ret.append(f'    let (local sub{i}) = IContractV1SOP{sop_idx}SUB{i}.compute(addr_{i}, {x_len}, x+100*{i})')
        ret.append('')
    ret.append('    let res_ = sub0 + sub1 + sub2 + sub3 + sub4 + sub5 + sub6 + sub7')
    ret.append(f'    let (res, _) = signed_div_rem(res_, {10**quantize_depth}, 2 ** 64)')
    ret.append('    return (res)')
    ret.append('end')
    ret.append('')

    return ret

def gen_v1_sop_contract(sop_idx, quantize_depth):
    ret = []
    ret += part1
    ret += part2(sop_idx)
    ret += part3(sop_idx, quantize_depth)
    
    with open(f'gen_contract/mnist_v1_sop{sop_idx}.cairo', 'a') as f:
        for line in ret:
            f.write(line+'\n')

In [23]:
print("> Generating 32 sop contracts for V1")
# for sop_idx in range(32):
#     gen_v1_sop_contract(sop_idx, quantize_depth=8)

> Generating 32 sop contracts for V1


### Contract: V1_SOP<0-31>_SUBSOP<0-7>
- subsop0 - subsop6: adding 100 product terms
- subsop7: adding 84 product terms + bias

#### Contract structure
```
%lang starknet
%builtins pedersen range_check

from starkware.cairo.common.cairo_builtins import HashBuiltin
from starkware.cairo.common.alloc import alloc

@view
func compute {
        syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr
    }(
        x_len : felt,
        x : felt*
    ) -> (
        res : felt
    ):
    
    let res = [x+0]*___ + [x+1]*___ + .... + [x+99]*___
    
    return (res)
end
```

In [274]:
def gen_v1_sop_subsop_contract(sop_idx, subsop_idx, model, quantize_depth = 8):
    b_s = model.fc1.bias.tolist()
    b = b_s[sop_idx]
    w_list = model.fc1.weight[sop_idx].tolist()
    f_quantize = lambda e,n : int(e* 10**n)
    w_list_quantized = [f_quantize(w, quantize_depth) for w in w_list]
    b_quantized = quantize(b, quantize_depth)
    ret = []
    
    ret += [
        '%lang starknet',
        '%builtins pedersen range_check',
        '',
        'from starkware.cairo.common.cairo_builtins import HashBuiltin',
        ''
    ]
    
    ret += [
        '@view',
        'func compute {',
        '        syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr',
        '    }(',
        '        x_len : felt,',
        '        x : felt*',
        '    ) -> (',
        '        res : felt',
        '    ):',
        ''
    ]
    if subsop_idx == 7:
        ## add 84 product terms and bias
        w_list_quantized_ = w_list_quantized[700 : 783]
        product_terms = [f'[x+{i}] * {w}' for i,w in enumerate(w_list_quantized_)]
        expression = ' + '.join([str(b_quantized)] + product_terms)
    else:
        ## add 100 product terms
        w_list_quantized_ = w_list_quantized[subsop_idx*100 : (subsop_idx+1)*100-1]
        product_terms = [f'[x+{i}] * {w}' for i,w in enumerate(w_list_quantized_)]
        expression = ' + '.join(product_terms)
    ret.append(f'    let res = {expression}')
    
    ret += [
        '    return (res)',
        'end'
        ''
    ]
    
    with open(f'gen_contract/mnist_v1_sop{sop_idx}_subsop{subsop_idx}.cairo', 'a') as f:
        for line in ret:
            f.write(line+'\n')
            
    return ret

In [22]:
print("> Generating 8 subsop contracts for each of the 32 sop's")
# for sop_idx in range(32):
#     for subsop_idx in range(8):
#         ret = gen_v1_sop_subsop_contract(sop_idx, subsop_idx, model, quantize_depth = 8)

> Generating 8 subsop contracts for each of the 32 sop's


### relu contract

In [21]:
print("> Generating cairo code for relu contract")
# for i in range(32):
#     print(f'    let (h) = _relu([v1+{i}])')
#     print(f'    assert [h1+{i}] = h')
#     print()

> Generating cairo code for relu contract


### Z contract

In [25]:
print("> Generating cairo code for the Z (FC2) contract")
# for i in range(10):
#     print(f'    let (addr_{i}) = stored_addresses.read({i})')
#     print(f'    local pedersen_ptr : HashBuiltin* = pedersen_ptr')
#     print(f'    let (z_{i}) = IContractZSOP{i}.compute(addr_{i}, 32, h1)')
#     print(f'    assert [z+{i}] = z_{i}')
#     print()

> Generating cairo code for the z (FC2) contract


### z_sop<0-9> contracts

In [275]:
def gen_z_sop_contract(sop_idx, model, quantize_depth = 8):
    b_s = model.fc2.bias.tolist()
    b = b_s[sop_idx]
    w_list = model.fc2.weight[sop_idx].tolist()
    f_quantize = lambda e,n : int(e* 10**n)
    w_list_quantized = [f_quantize(w, quantize_depth) for w in w_list]
    b_quantized = quantize(b, quantize_depth)
    ret = []
    
    ret += [
        '%lang starknet',
        '%builtins pedersen range_check',
        '',
        'from starkware.cairo.common.cairo_builtins import HashBuiltin',
        ''
    ]
    
    ret += [
        '@view',
        'func compute {',
        '        syscall_ptr : felt*, pedersen_ptr : HashBuiltin*, range_check_ptr',
        '    }(',
        '        x_len : felt,',
        '        x : felt*',
        '    ) -> (',
        '        res : felt',
        '    ):',
        ''
    ]
    
    product_terms = [f'[x+{i}] * {w}' for i,w in enumerate(w_list_quantized)]
    expression = ' + '.join([str(b_quantized)] + product_terms)
    ret.append(f'    let res = {expression}')
    
    ret += [
        '    return (res)',
        'end'
        ''
    ]
    
    with open(f'gen_contract/mnist_z_sop{sop_idx}.cairo', 'a') as f:
        for line in ret:
            f.write(line+'\n')
            
    return ret

In [26]:
print("> Generating 10 sop contracts for Z contract")
# for i in range(10):
#     gen_z_sop_contract(i, model, 8);

> Generating 10 sop contracts for Z contract


-------

### Training

In [237]:
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args['dry_run']:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def run():
    args = {
        'no_cuda' : True,
        'seed' : 1,
        'batch_size' : 64,
        'test_batch_size' : 1000,
        'lr' : 1.0,
        'gamma' : 0.7,
        'epochs' : 14,
        'save_model' : True,
        'log_interval' : 10,
        'dry_run' : False
        
    }
   
    use_cuda = not args['no_cuda'] and torch.cuda.is_available()

    torch.manual_seed(args['seed'])

    device = torch.device("cuda" if use_cuda else "cpu")

    train_kwargs = {'batch_size': args['batch_size']}
    test_kwargs = {'batch_size': args['test_batch_size']}
    if use_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                       transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1,**train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adadelta(model.parameters(), lr=args['lr'])

    scheduler = StepLR(optimizer, step_size=1, gamma=args['gamma'])
    for epoch in range(1, args['epochs'] + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if args['save_model']:
        torch.save(model.state_dict(), "mnist_dnn.pt")

In [58]:
run()


Test set: Average loss: 0.2707, Accuracy: 9185/10000 (92%)




Test set: Average loss: 0.2419, Accuracy: 9271/10000 (93%)


Test set: Average loss: 0.2263, Accuracy: 9343/10000 (93%)




Test set: Average loss: 0.2152, Accuracy: 9377/10000 (94%)


Test set: Average loss: 0.2060, Accuracy: 9395/10000 (94%)




Test set: Average loss: 0.2032, Accuracy: 9397/10000 (94%)




Test set: Average loss: 0.2007, Accuracy: 9426/10000 (94%)


Test set: Average loss: 0.2000, Accuracy: 9427/10000 (94%)




Test set: Average loss: 0.1990, Accuracy: 9435/10000 (94%)


Test set: Average loss: 0.1976, Accuracy: 9435/10000 (94%)




Test set: Average loss: 0.1977, Accuracy: 9437/10000 (94%)


Test set: Average loss: 0.1973, Accuracy: 9447/10000 (94%)




Test set: Average loss: 0.1965, Accuracy: 9443/10000 (94%)




Test set: Average loss: 0.1959, Accuracy: 9436/10000 (94%)



---

### Datastructure exploration

In [72]:
b_s = model.fc1.bias.tolist()
b_0 = b_s[0]
w_0_list = model.fc1.weight[0].tolist()

In [73]:
b_0

-0.03493674099445343

In [74]:
w_0_list

[0.02740933559834957,
 -0.006756555289030075,
 0.0020833066664636135,
 0.02577025257050991,
 -0.024615662172436714,
 0.03042559139430523,
 0.001659609959460795,
 0.027176501229405403,
 0.01397180836647749,
 0.0046339379623532295,
 0.018912842497229576,
 0.010768760927021503,
 0.01650378294289112,
 -0.011312826536595821,
 0.006403053179383278,
 0.005791680887341499,
 0.01418349239975214,
 0.008864249102771282,
 0.04022698849439621,
 0.02012101747095585,
 -0.004293275065720081,
 -0.012563132680952549,
 0.003020653733983636,
 -0.006398078054189682,
 -0.0024376793298870325,
 0.010716964490711689,
 0.03029737062752247,
 0.028419116511940956,
 -0.025905447080731392,
 0.03114704042673111,
 0.01898438297212124,
 0.0428837388753891,
 0.032580479979515076,
 -0.02377472072839737,
 -0.04583946242928505,
 -0.03219585120677948,
 0.019051160663366318,
 -0.01690705679357052,
 0.0018371023470535874,
 -0.02316276542842388,
 0.012664335779845715,
 -0.04223133996129036,
 0.021850770339369774,
 -0.01896392

In [32]:
transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
        ])
ds = datasets.MNIST('../data', train=True, download=True,
                       transform=transform)
IDX = 5
d = ds.data[IDX].tolist()
d_flatten = []
for i in range(28):
    for j in range(28):
        print("%3d" % d[i][j], end=" ")
        d_flatten.append(d[i][j])
    print()

# print( len(d_flatten) )
# print('[%s]' % ','.join( [str(e) for e in d_flatten] ))

  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 
  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 
  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 
  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 
  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0 
  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0  13  25 100 122   7   0   0   0   0   0   0   0   0 
  0   0   0   0   0   0   0   0   0   0   0   0   0  33 151 208 252 252 252 146   0   0   0   0   0   0   0   0 
  0   0   0   0   0   0   0   0   0   0   0  40 152 244 252 253 224 211 252 232  40   0   0   0   0   0   0   0 
  0   0   0   0   0   0   0   0   0  15 152 239 252 252 252 216  31  37 252 252  60   0   0   0 

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [34]:
print("> Investigating FC1 output collected from StarkNet contract for precision loss")
sol_s = [28255435953, -16423638096, -463597710, 13498191087, -30114092289, 45508329950, 26430724623, -7281002187, -28972881867, -37352675430, -22022185386, -33144123396, -16708183373, -1922013630, 25706864873, -3648043006, 22349402213, 20949610819, -25246367889, 19770606824, -3150400393, 86447757560, -32367798672, -14193289090, 72816301052, -42810611972, -1820129923, 13129317624, -7273541488, 97032159880, -54870506948, 49598410122]

for sol in sol_s:
    print(sol / 10**8)

> Investigating FC1 output collected from StarkNet contract for precision loss
282.55435953
-164.23638096
-4.6359771
134.98191087
-301.14092289
455.0832995
264.30724623
-72.81002187
-289.72881867
-373.5267543
-220.22185386
-331.44123396
-167.08183373
-19.2201363
257.06864873
-36.48043006
223.49402213
209.49610819
-252.46367889
197.70606824
-31.50400393
864.4775756
-323.67798672
-141.9328909
728.16301052
-428.10611972
-18.20129923
131.29317624
-72.73541488
970.3215988
-548.70506948
495.98410122


In [291]:
idx_s = range(10)
img_array_s = []
for idx in idx_s:
    d = ds.data[idx].tolist()
    d_flatten = []
    for i in range(28):
        for j in range(28):
            d_flatten.append(d[i][j])
    img_array_s.append(d_flatten)

with open("img_array_s.txt", "wb") as fp:   #Pickling
    pickle.dump(img_array_s, fp)

In [38]:
sol_s = [-16926425655610295354, -25769741692563623387, -12526590447583294615, -745904275336316246, -28866927377392675394, -1563107310419300343, -19919289842112617675, -21459652795700226659, -9713915730007115903, -13271941145501024253]

print("> Investigating fixed-point precision on starknet-NN output precision")
print("> Scale floats by 10^16:")
for sol in sol_s:
    print(sol / 10**16)


> Investigating fixed-point precision on starknet-NN output precision
> Scale floats by 10^16:
-1692.6425655610296
-2576.9741692563625
-1252.6590447583294
-74.59042753363163
-2886.6927377392676
-156.31073104193004
-1991.9289842112619
-2145.9652795700226
-971.3915730007116
-1327.1941145501023


In [37]:
sol_s = [-169122213668, -257449510423, -125158504463, -7408666570, -288326465022, -15572522167, -198993928326, -214469740798, -96905331675, -132517456612]

print("> Scale floats by 10^8:")
for sol in sol_s:
    print(sol / 10**8)


> Scale floats by 10^8:
-1691.22213668
-2574.49510423
-1251.58504463
-74.0866657
-2883.26465022
-155.72522167
-1989.93928326
-2144.69740798
-969.05331675
-1325.17456612
