In [39]:
from itertools import chain

import torch
from torch import nn
from torch.utils.data import random_split, DataLoader

from elasticai.creator.nn.fixed_point._two_complement_fixed_point_config import FixedPointConfig
from elasticai.creator.vhdl.testbench_helper import tensor_to_vhdl_vector
from elasticai.creator.nn.fixed_point._number_conversion import bits_to_integer, bits_to_rational
from elasticai.creator.vhdl.code_generation.code_abstractions import to_vhdl_binary_string
from elasticai.creator.nn.fixed_point._math_operations import MathOperations
from elasticai.creator.nn.fixed_point import Linear, Tanh, HardTanh, ReLU
from elasticai.creator.nn.sequential import Sequential
from elasticai.creator.vhdl.testbench_helper import tensor_to_vhdl_vector
from examples.cable_length_dataset import CableLengthDataset
from elasticai.creator.vhdl.code_generation.addressable import calculate_address_width

from serial import Serial

In [40]:
total_bits = 8
frac_bits = 5
config = FixedPointConfig(total_bits=total_bits, frac_bits=frac_bits)
ops = MathOperations(config=config)
config_double = FixedPointConfig(total_bits=2*total_bits, frac_bits=2*frac_bits)
double_ops = MathOperations(config=config_double)

In [41]:
dataset = CableLengthDataset("fixed_current_training_data.csv")

train, valid, test = random_split(dataset, lengths=[0.8, 0.1, 0.1], generator=torch.Generator().manual_seed(0))
train_dl = DataLoader(train, batch_size=1000, shuffle=True)
valid_dl = DataLoader(valid, batch_size=1000, shuffle=True)
test_dl = DataLoader(test, batch_size=1000, shuffle=True)

In [42]:
test_data, test_data_label = dataset[5]
test_data_1, test_data_1_label = dataset[500]
test_data_2, test_data_2_label = dataset[1000]
test_data_3, test_data_3_label = dataset[1500]
print(test_data)

tensor([2.6197, 2.6543, 2.6839, 2.7106, 2.7356, 2.7583, 2.7800, 2.8006, 2.8202,
        2.8390, 2.8567, 2.8739, 2.8903, 2.9060, 2.9212, 2.9357, 2.9497, 2.9632,
        2.9764])


In [43]:
q_test_data = ops.quantize(test_data)
q_test_data_1 = ops.quantize(test_data_1)
q_test_data_2 = ops.quantize(test_data_2)
q_test_data_3 = ops.quantize(test_data_3)
print(q_test_data)

tensor([2.5938, 2.6250, 2.6562, 2.6875, 2.7188, 2.7500, 2.7500, 2.7812, 2.8125,
        2.8125, 2.8438, 2.8438, 2.8750, 2.8750, 2.9062, 2.9062, 2.9375, 2.9375,
        2.9688])


In [44]:
print(tensor_to_vhdl_vector(q_test_data, config, True))
print(tensor_to_vhdl_vector(q_test_data_1, config, True))
print(tensor_to_vhdl_vector(q_test_data_2, config, True))
print(tensor_to_vhdl_vector(q_test_data_3, config, True))

("01010011","01010100","01010101","01010110","01010111","01011000","01011000","01011001","01011010","01011010","01011011","01011011","01011100","01011100","01011101","01011101","01011110","01011110","01011111")
("01001110","01010000","01010010","01010011","01010101","01010110","01011000","01011001","01011010","01011011","01011100","01011101","01011110","01011111","01100000","01100000","01100001","01100010","01100010")
("00111011","00111100","00111100","00111101","00111110","00111110","00111111","00111111","00111111","01000000","01000000","01000000","01000001","01000001","01000001","01000010","01000010","01000010","01000010")
("00111101","00111101","00111110","00111111","00111111","00111111","01000000","01000000","01000000","01000001","01000001","01000001","01000001","01000010","01000010","01000010","01000010","01000010","01000011")


In [45]:
multi_linear = Sequential(
    Linear(in_features=19, out_features=10, total_bits=total_bits, frac_bits=frac_bits, parallel=False),
    Tanh(total_bits=total_bits, frac_bits=frac_bits, num_steps=2**8, sampling_intervall=(-4, 3.96875)),
    Linear(in_features=10, out_features=4, total_bits=total_bits, frac_bits=frac_bits, parallel=True),
)

In [46]:
multi_linear.load_state_dict(torch.load(f="/home/silas/PycharmProjects/elastic-ai.creator/examples/build_dir/leds/eval_model_3bit.pth"))
multi_linear.eval()

  multi_linear.load_state_dict(torch.load(f="/home/silas/PycharmProjects/elastic-ai.creator/examples/build_dir/leds/eval_model_3bit.pth"))


Sequential(
  (0): Linear(in_features=19, out_features=10, bias=True)
  (1): Tanh(
    (_base_module): Tanh()
  )
  (2): Linear(in_features=10, out_features=4, bias=True)
)

In [47]:
#with torch.no_grad():
#    multi_linear[0].weight.data = ops.quantize(multi_linear[0].weight)
#    multi_linear[0].bias.data = ops.quantize(multi_linear[0].bias)
#    multi_linear[2].weight.data = ops.quantize(multi_linear[2].weight)
#    multi_linear[2].bias.data = ops.quantize(multi_linear[2].bias)

In [48]:
multi_linear[0].weight

Parameter containing:
tensor([[ 0.2188,  0.0312,  0.0000,  0.0000,  0.1250, -0.1250, -0.0938,  0.0938,
         -0.2188, -0.0938, -0.2188, -0.2188, -0.2188,  0.0625, -0.2188, -0.0625,
         -0.0312, -0.0938, -0.0625],
        [ 0.9688,  0.5938,  0.4688,  0.6562,  0.4688,  0.5000,  0.3125,  0.2188,
          0.1875,  0.1250,  0.0625, -0.0312, -0.4688, -0.4062, -0.5000, -0.5625,
         -0.2812, -0.6875, -0.6250],
        [-0.8438, -0.6875, -0.7500, -0.2812, -0.5938, -0.4375, -0.3750,  0.0000,
          0.1250, -0.0312,  0.1250,  0.1875,  0.3125,  0.3750,  0.4062,  0.5938,
          0.5938,  0.3750,  0.5312],
        [ 0.0938,  0.0625,  0.0625,  0.0000,  0.1875,  0.1562,  0.0938,  0.1250,
         -0.1250,  0.2500,  0.2188,  0.0312,  0.2188,  0.0625,  0.0000,  0.1562,
          0.1875,  0.1875,  0.0000],
        [ 0.2500,  0.3750,  0.0312,  0.2500, -0.0625, -0.1562, -0.1562, -0.0938,
          0.0312,  0.0938, -0.0312,  0.0000, -0.2812,  0.0000, -0.1562, -0.4375,
         -0.4688, -0

In [49]:
mult1 = q_test_data * multi_linear[0].weight[0]
mult1   

tensor([ 0.5674,  0.0820,  0.0000,  0.0000,  0.3398, -0.3438, -0.2578,  0.2607,
        -0.6152, -0.2637, -0.6221, -0.6221, -0.6289,  0.1797, -0.6357, -0.1816,
        -0.0918, -0.2754, -0.1855], grad_fn=<MulBackward0>)

In [50]:
torch.sum(mult1)

tensor(-3.2939, grad_fn=<SumBackward0>)

In [51]:
out1 = multi_linear[0](q_test_data)
out1

tensor([-0.7500,  0.5312, -0.0625,  3.9688, -0.7188, -0.0625,  0.6562, -4.0000,
         0.3125,  0.6562], grad_fn=<RoundToFixedPointBackward>)

In [52]:
multi_linear[1](out1)

tensor([-0.6250,  0.4688, -0.0312,  0.9688, -0.5938, -0.0312,  0.5625, -0.9688,
         0.2812,  0.5625], grad_fn=<RoundToFixedPointBackward>)

In [53]:
from typing import cast
from elasticai.creator.nn.fixed_point.precomputed.identity_step_function import IdentityStepFunction

sampling_intervall = (-4, 3.96875)

step_lut = torch.nn.Parameter(
            torch.linspace(*sampling_intervall, 2**8), requires_grad=False)
print(step_lut)

cast(torch.Tensor, IdentityStepFunction.apply(out1, step_lut))

Parameter containing:
tensor([-4.0000, -3.9688, -3.9375, -3.9062, -3.8750, -3.8438, -3.8125, -3.7812,
        -3.7500, -3.7188, -3.6875, -3.6562, -3.6250, -3.5938, -3.5625, -3.5312,
        -3.5000, -3.4688, -3.4375, -3.4062, -3.3750, -3.3438, -3.3125, -3.2812,
        -3.2500, -3.2188, -3.1875, -3.1562, -3.1250, -3.0938, -3.0625, -3.0312,
        -3.0000, -2.9688, -2.9375, -2.9062, -2.8750, -2.8438, -2.8125, -2.7812,
        -2.7500, -2.7188, -2.6875, -2.6562, -2.6250, -2.5938, -2.5625, -2.5312,
        -2.5000, -2.4688, -2.4375, -2.4062, -2.3750, -2.3438, -2.3125, -2.2812,
        -2.2500, -2.2188, -2.1875, -2.1562, -2.1250, -2.0938, -2.0625, -2.0312,
        -2.0000, -1.9688, -1.9375, -1.9062, -1.8750, -1.8438, -1.8125, -1.7812,
        -1.7500, -1.7188, -1.6875, -1.6562, -1.6250, -1.5938, -1.5625, -1.5312,
        -1.5000, -1.4688, -1.4375, -1.4062, -1.3750, -1.3438, -1.3125, -1.2812,
        -1.2500, -1.2188, -1.1875, -1.1562, -1.1250, -1.0938, -1.0625, -1.0312,
        -1.0000, -

tensor([-0.7500,  0.5312, -0.0625,  3.9688, -0.7188, -0.0625,  0.6562, -4.0000,
         0.3125,  0.6562], grad_fn=<IdentityStepFunctionBackward>)

# On Device inference

In [54]:
fpga = Serial(port="/dev/ttyUSB1", baudrate=115200)

In [55]:
def value_to_fxp(x: float) -> str:
    val = config.as_integer(x)
    return val

def inference_data_point(input):
    q_input = ops.quantize(input)
    
    exp_output = multi_linear(q_input).detach().numpy()
    
    input_as_fxp = [value_to_fxp(data.item()).to_bytes(1, 'little', signed=True) for data in q_input]
    
    real_output = []
    
    for idx, input in enumerate(input_as_fxp):
        fpga.write(input)
        
    for idx, output in enumerate(exp_output):
        out = fpga.read(1)
        out = int.from_bytes(out, 'little', signed=True)
        real_output.append(config.as_rational(out))
    
    return {'exp_output':exp_output, 'real_output':real_output}

In [56]:
print(inference_data_point(q_test_data))
print(inference_data_point(q_test_data_1))
print(inference_data_point(q_test_data_2))

{'exp_output': array([ 3.5625 , -3.59375, -4.     , -4.     ], dtype=float32), 'real_output': [3.5625, -3.59375, -4.0, -4.0]}
{'exp_output': array([-0.875  ,  3.96875, -1.375  , -4.     ], dtype=float32), 'real_output': [-0.875, 3.96875, -1.375, -4.0]}
{'exp_output': array([-4.     , -4.     ,  2.84375,  0.21875], dtype=float32), 'real_output': [-4.0, -4.0, 2.84375, 0.21875]}


In [57]:
def infer_batch_on_device(batch: torch.Tensor):
    outputs = torch.zeros(batch.shape[0], 4)
    for idx, datapoint in enumerate(batch):
        datapoint = ops.quantize(datapoint)
        output = inference_data_point(datapoint)
        outputs[idx] = torch.tensor(output['real_output'])
        print(idx)
    
    return outputs

In [59]:
# Testing
criterion = nn.CrossEntropyLoss()

test_loss = 0.0
test_correct_predictions = 0
test_total_samples = 0
for batch_features, batch_targets in test_dl:
    with torch.no_grad():
        outputs = infer_batch_on_device(batch_features)
        loss = criterion(outputs, batch_targets)

        # Accumulate loss
        test_loss += loss.item()

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)  # Get predicted labels
        test_correct_predictions += (predicted == batch_targets).sum().item()
        test_total_samples += batch_targets.size(0)
test_accuracy = test_correct_predictions / test_total_samples
print(f"Test Loss: {test_loss/1000:.6f}, Test Accuracy: {test_accuracy:.2%}")
print(f"{fpga.in_waiting=}")


0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
Test Loss: 0.000049, Test Accuracy: 100.00%
fpga.in_waiting=0
