In [1]:
import sys
# appending a path
sys.path.append('../')

import torch
import torch.nn as nn
import torch.nn.functional as F

import pychop
from pychop import Chopi
pychop.backend('torch')

class QuantizedNet(nn.Module):
    def __init__(self):
        super(QuantizedNet, self).__init__()
        # Convolutional Layers
        self.conv1d = nn.Conv1d(1, 16, 3, padding=1)
        self.conv2d = nn.Conv2d(1, 16, 3, padding=1)
        self.conv3d = nn.Conv3d(1, 16, 3, padding=1)

        self.wquant_conv1d = Chopi(8, symmetric=True, per_channel=True, channel_dim=0)
        self.wquant_conv2d = Chopi(8, symmetric=True, per_channel=True, channel_dim=0)
        self.wquant_conv3d = Chopi(8, symmetric=True, per_channel=True, channel_dim=0)

        # Recurrent Layers
        self.lstm = nn.LSTM(16, 32, batch_first=True)
        self.gru = nn.GRU(32, 16, batch_first=True)

        self.wquant_lstm_ih = Chopi(8, symmetric=True, per_channel=True, channel_dim=0)
        self.wquant_lstm_hh = Chopi(8, symmetric=True, per_channel=True, channel_dim=0)
        self.wquant_gru_ih = Chopi(8, symmetric=True, per_channel=True, channel_dim=0)
        self.wquant_gru_hh = Chopi(8, symmetric=True, per_channel=True, channel_dim=0)

        # BatchNorm, Pooling, and Linear
        self.bn2d = nn.BatchNorm2d(16)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(16 * 14 * 14, 10)  # 14x14 after pooling
        self.wquant_fc = Chopi(8, symmetric=True, per_channel=True, channel_dim=0)

        # Activation Chopis
        self.aquant_conv1d = Chopi(8, symmetric=False, per_channel=True, channel_dim=1)
        self.aquant_conv2d = Chopi(8, symmetric=False, per_channel=True, channel_dim=1)
        self.aquant_conv3d = Chopi(8, symmetric=False, per_channel=True, channel_dim=1)
        self.aquant_lstm = Chopi(8, symmetric=False, per_channel=True, channel_dim=2)  # dim 2 for [batch, seq, feat]
        self.aquant_gru = Chopi(8, symmetric=False, per_channel=True, channel_dim=2)
        self.aquant_fc = Chopi(8, symmetric=False)

        self.relu = nn.ReLU()

    def forward(self, x_1d, x_2d, x_3d, x_seq, training=True):
        # Conv1d
        w1d = self.wquant_conv1d(self.conv1d.weight, training=training)

        x_1d = F.conv1d(x_1d, w1d, self.conv1d.bias, padding=1)
        x_1d = self.aquant_conv1d(x_1d, training=training)
        x_1d = self.relu(x_1d)

        # Conv2d
        w2d = self.wquant_conv2d(self.conv2d.weight, training=training)
        x_2d = F.conv2d(x_2d, w2d, self.conv2d.bias, padding=1)
        x_2d = self.bn2d(x_2d)
        x_2d = self.aquant_conv2d(x_2d, training=training)
        x_2d = self.relu(x_2d)
        x_2d = self.pool(x_2d)  # [2, 16, 14, 14]
        x_2d = x_2d.view(x_2d.size(0), -1)  # [2, 16*14*14]

        # Conv3d
        w3d = self.wquant_conv3d(self.conv3d.weight, training=training)
        x_3d = F.conv3d(x_3d, w3d, self.conv3d.bias, padding=1)
        x_3d = self.aquant_conv3d(x_3d, training=training)
        x_3d = self.relu(x_3d)

        # RNN (LSTM + GRU)
        # Note: PyTorch RNNs use fused ops, so we quantize weights but apply them manually is complex.
        # For simplicity, we quantize inputs/outputs here; true integer RNNs need custom ops.
        w_lstm_ih = self.wquant_lstm_ih(self.lstm.weight_ih_l0, training=training)
        w_lstm_hh = self.wquant_lstm_hh(self.lstm.weight_hh_l0, training=training)
        x_seq, _ = self.lstm(x_seq)  # Fused op, weights not directly applied here
        x_seq = self.aquant_lstm(x_seq, training=training)
        w_gru_ih = self.wquant_gru_ih(self.gru.weight_ih_l0, training=training)
        w_gru_hh = self.wquant_gru_hh(self.gru.weight_hh_l0, training=training)
        x_seq, _ = self.gru(x_seq)
        x_seq = self.aquant_gru(x_seq, training=training)
        x_seq = x_seq[:, -1, :]  # Last timestep

        # Linear
        w_fc = self.wquant_fc(self.fc.weight, training=training)
        x_2d = F.linear(x_2d, w_fc, self.fc.bias)  # [2, 3136] * [3136, 10]
        x_2d = self.aquant_fc(x_2d, training=training)

        return x_1d, x_2d, x_3d, x_seq


In [2]:

def test():
    model = QuantizedNet()
    model.train()

    x_1d = torch.randn(2, 1, 64)  # [batch, channels, length]
    x_2d = torch.randn(2, 1, 28, 28)  # [batch, channels, height, width]
    x_3d = torch.randn(2, 1, 16, 16, 16)  # [batch, channels, depth, height, width]
    x_seq = torch.randn(2, 10, 16)  # [batch, seq_len, features]

    # Training mode
    out_1d, out_2d, out_3d, out_seq = model(x_1d, x_2d, x_3d, x_seq, training=True)
    print("Training outputs:")
    print("Conv1d:", out_1d.shape, out_1d[0, 0, :5])
    print("Conv2d+FC:", out_2d.shape, out_2d[0, :5])
    print("Conv3d:", out_3d.shape, out_3d[0, 0, 0, 0, :5])
    print("LSTM+GRU:", out_seq.shape, out_seq[0, :5])

    # Inference mode
    model.eval()
    with torch.no_grad():
        out_1d, out_2d, out_3d, out_seq = model(x_1d, x_2d, x_3d, x_seq, training=False)
        print("\nInference outputs (INT8):")
        print("Conv1d:", out_1d.shape, out_1d[0, 0, :5])
        print("Conv2d+FC:", out_2d.shape, out_2d[0, :5])
        print("Conv3d:", out_3d.shape, out_3d[0, 0, 0, :5])
        print("LSTM+GRU:", out_seq.shape, out_seq[0, :5])

if __name__ == "__main__":
    test()

Training outputs:
Conv1d: torch.Size([2, 16, 64]) tensor([0.0862, 0.0000, 0.0000, 0.0000, 0.0000])
Conv2d+FC: torch.Size([2, 10]) tensor([ 1.8073, -0.0506, -0.1454,  0.5371,  1.0490])
Conv3d: torch.Size([2, 16, 16, 16, 16]) tensor([0.6106, 0.3408, 0.4564, 1.0732, 0.1095])
LSTM+GRU: torch.Size([2, 16]) tensor([ 0.1751,  0.0754, -0.1278, -0.0274, -0.1226])

Inference outputs (INT8):
Conv1d: torch.Size([2, 16, 64]) tensor([72., 51., 56., 65., 53.])
Conv2d+FC: torch.Size([2, 10]) tensor([127.,  44.,  26.,  27.,  86.])
Conv3d: torch.Size([2, 16, 16, 16, 16]) tensor([[76., 69., 72., 88., 63., 61., 54., 70., 59., 62., 55., 60., 59., 50.,
         70., 74.],
        [55., 46., 81., 78., 70., 70., 80., 54., 55., 79., 81., 65., 57., 51.,
         77., 68.],
        [66., 69., 55., 78., 81., 77., 50., 77., 67., 72., 59., 84., 43., 56.,
         55., 65.],
        [73., 68., 76., 60., 60., 82., 55., 40., 51., 70., 74., 80., 65., 70.,
         51., 70.],
        [64., 74., 90., 58., 60., 58., 61., 