In [10]:
# Task B: Train RNN on FSDD (quantized)
# - full precision 8 bit precision (mixed-precision)
# - scaling factor unrealistic for fully fixed point implementations

In [11]:
import torch
import yaml

from utils import get_rec_paths, load_data, train_model, validate_model

In [12]:
# Load from YAML file
with open("config.yaml", "r") as f:
    args = yaml.safe_load(f)

# labels and paths in pd frame
data = get_rec_paths('./free-spoken-digit-dataset/recordings')

# load train, val, test data
trainset, validset, trainlabels, validlabels, testset, testlabels = load_data(data,True,**args)
print(trainset.shape, validset.shape, testset.shape)

torch.Size([2430, 20, 20]) torch.Size([270, 20, 20]) torch.Size([300, 20, 20])


In [13]:
class FSDNN_RNN(torch.nn.Module):
    def __init__(self, input_channels, hidden_size, num_layers, output_size):
        super(FSDNN_RNN, self).__init__()
        self.rnn = torch.nn.RNN(input_size=input_channels, 
                          hidden_size=hidden_size, 
                          num_layers=num_layers, 
                          batch_first=True)  # (batch, seq, features)
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)  # RNN output
        out = self.fc(out[:, -1, :])  # Take last time step for classification
        return out

In [None]:
class FSDNN_RNN_Q(torch.nn.Module):
    def __init__(self, input_channels, hidden_size, num_layers, output_size):
        super(FSDNN_RNN_Q, self).__init__()
        self.rnn = torch.nn.RNN(input_size=input_channels, 
                          hidden_size=hidden_size, 
                          num_layers=num_layers, 
                          batch_first=True)  # (batch, seq, features)
        self.fc = torch.nn.Linear(hidden_size, output_size)

        self.bits = 8
        self.enable_q = False

    # Quantization function
    def quantize(self, x, num_bits=8):
        scale = x.max() / (2 ** (num_bits-1) - 1)  # Scale factor for quantization
        x_quantized = torch.round(x / scale)  # Quantize by scaling and rounding
        x_quantized = torch.clamp(x_quantized, -2 ** (num_bits-1), 2 ** (num_bits-1) - 1)  # Clip to valid range
        return x_quantized, scale

    # Dequantization function
    def dequantize(self, x_quantized, scale):
        return x_quantized * scale
    
    def q_sym_noscale(self, x, num_bits=8, num_frac=6):
        s = 2 ** (num_bits - 1)
        q = torch.round(x * s)
        q = torch.clamp(q, -s, s - 1)
        q = q / (2 ** num_frac)
        return q

    # quantized (modified pytorch doc implementation -> fixed layered input)
    def forward(self, x):

        if self.rnn.batch_first:
            x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        h_t_minus_1 = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
        h_t = torch.zeros_like(h_t_minus_1)

        output = []

        if self.enable_q:
            x_quantized, scale = self.quantize(x.clone(), self.bits)
            x = self.dequantize(x_quantized, scale)
            #import pdb; pdb.set_trace()

        for t in range(seq_len):
            h_t_new = []
            for layer in range(self.rnn.num_layers):
                if self.enable_q:
                    weight_ih = self.q_sym_noscale(getattr(self.rnn, f'weight_ih_l{layer}'), self.bits, self.bits-2)
                    bias_ih = self.q_sym_noscale(getattr(self.rnn, f'bias_ih_l{layer}'), self.bits, self.bits-2)
                    weight_hh = self.q_sym_noscale(getattr(self.rnn, f'weight_hh_l{layer}'), self.bits, self.bits-2)
                    bias_hh = self.q_sym_noscale(getattr(self.rnn, f'bias_hh_l{layer}'), self.bits, self.bits-2)

                    getattr(self.rnn, f'weight_ih_l{layer}').data = weight_ih
                    getattr(self.rnn, f'bias_ih_l{layer}').data = bias_ih
                    getattr(self.rnn, f'weight_hh_l{layer}').data = weight_hh
                    getattr(self.rnn, f'bias_hh_l{layer}').data = bias_hh
                else:
                    weight_ih = getattr(self.rnn, f'weight_ih_l{layer}')
                    bias_ih = getattr(self.rnn, f'bias_ih_l{layer}')
                    weight_hh = getattr(self.rnn, f'weight_hh_l{layer}')
                    bias_hh = getattr(self.rnn, f'bias_hh_l{layer}')

                xin = x[t] if layer == 0 else h_t_new[layer-1]

                h_layer = torch.tanh(
                    xin @ weight_ih.T
                    + bias_ih
                    + h_t_minus_1[layer] @ weight_hh.T
                    + bias_hh
                )

                if self.enable_q:
                    h_layer = self.q_sym_noscale(h_layer.clone(), self.bits, self.bits-2)
                    # h_layer_quantized, scale = self.quantize(h_layer.clone(), self.bits)
                    # h_layer_q = self.dequantize(h_layer_quantized, scale)
                #import pdb; pdb.set_trace()

                h_t_new.append(h_layer)

            h_t = torch.stack(h_t_new)
            output.append(h_t[-1])

            h_t_minus_1 = h_t.detach()

        output = torch.stack(output)
        if self.rnn.batch_first:
            output = output.transpose(0, 1)

        if self.enable_q:
            self.fc.weight.data = self.q_sym_noscale(self.fc.weight.data, self.bits, self.bits - 2)
            self.fc.bias.data = self.q_sym_noscale(self.fc.bias.data, self.bits, self.bits - 2)

        out = self.fc(output[:, -1, :])

        if self.enable_q:
            out_quantized, scale = self.quantize(out.clone(), self.bits)
            out = self.dequantize(out_quantized, scale)
            
        return out

In [15]:
# load pretrained model
model_pre = FSDNN_RNN(args['ysize'], args['rnn_hidden'], args['rnn_layers'], args['rnn_outputs'])
model_pre.load_state_dict(torch.load('chkpt_t1.pt', weights_only=True)) #load pretrained 

# 
model = FSDNN_RNN_Q(args['ysize'], args['rnn_hidden'], args['rnn_layers'], args['rnn_outputs'])
pretrained_weights = model_pre.state_dict()
new_model_dict = model.state_dict()
pretrained_weights = {k: v for k, v in pretrained_weights.items() if k in new_model_dict}
new_model_dict.update(pretrained_weights)
model.load_state_dict(new_model_dict)

<All keys matched successfully>

In [16]:
model.rnn.weight_ih_l0.data

tensor([[-0.1432, -0.0888,  0.0867,  ..., -0.0444,  0.1106,  0.0258],
        [ 0.1011, -0.0047,  0.0121,  ..., -0.0466, -0.1202, -0.0847],
        [ 0.3043,  0.1816,  0.0604,  ..., -0.0505, -0.0778,  0.0912],
        ...,
        [ 0.1131, -0.1353,  0.0475,  ..., -0.1440, -0.0864, -0.2182],
        [ 0.0720, -0.1547, -0.1590,  ...,  0.0074,  0.0366, -0.0103],
        [-0.0082, -0.0992, -0.0657,  ...,  0.0088, -0.0150, -0.0058]])

In [17]:
# validate equivalence
acc = validate_model(testset,testlabels,model_pre,**args)
print(acc)
acc = validate_model(testset,testlabels,model,**args)
print(acc)
model.enable_q = True
acc = validate_model(testset,testlabels,model,**args)
print(acc)

91.66666666666667
91.66666666666667
78.33333333333333


In [None]:
# QAT
## Training Setup
optimizer = torch.optim.RMSprop(model.parameters(), lr=args['learning_rate'])
#scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
criterion = torch.nn.CrossEntropyLoss()

## Training Loop
acc_best = 0
for ep in range(1,args['epochs']+1):

    # training
    loss = train_model(trainset,trainlabels,model,optimizer,criterion,**args)
    #scheduler.step()
    acc = validate_model(validset,validlabels,model,**args)

    # save best model
    if acc > acc_best:
        acc_best = acc
        torch.save(model.state_dict(), 'chkpt_2.pt')    

    # display progress
    # if ep % 10 == 0:
    #     print('Epoch {0:d} of {1:d}. Training loss: {2:.2f}, Validation accuracy: {3:.2f}%'.format(ep,args['epochs'],loss,acc))
    print('Epoch {0:d} of {1:d}. Training loss: {2:.2f}, Validation accuracy: {3:.2f}%'.format(ep,args['epochs'],loss,acc))

Epoch 1 of 1000. Training loss: 7.29, Validation accuracy: 79.26%
Epoch 2 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 3 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 4 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 5 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 6 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 7 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 8 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 9 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 10 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 11 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 12 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 13 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 14 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoch 15 of 1000. Training loss: 7.28, Validation accuracy: 79.26%
Epoc

In [None]:
acc = validate_model(testset,testlabels,model,**args)
acc

77.0