In [1]:
# Task C: Train RNN on FSDD; quantized; weights in pow2
# - same as Task B, custom quant fcn for params

In [2]:
import glob
import pandas as pd
import torch
import random
from tqdm import tqdm
import torch.quantization as quantization

from collections import OrderedDict

from sigproc import gen_logmel, feat2img
import math

In [3]:
# mainly adapted from https://github.com/saztorralba/CNNWordReco due to following
# - deeplake / hub version broken -> replaced with original wavs (cloned orig repo: https://github.com/Jakobovski/free-spoken-digit-dataset)
# - logmel suitable for detection of spoken speech -> normalized, resampled, high-pass filtered, time axis scaling

In [4]:
args = {
    'train_val_percentage': 0.1,
    'xsize': 20,
    'ysize': 20,
    'rnn_layers': 3,
    'rnn_hidden': 64,
    'rnn_outputs': 10,
    'epochs': 1000,
    'batch_size': 32,
    'learning_rate': 0.001,
    'device': 'cpu',
    'verbose': 1,
    'augment': False,
    'vocab': OrderedDict({'ZERO': 0, 'ONE': 1, 'TWO': 2, 'THREE': 3, 'FOUR': 4, 'FIVE': 5, 'SIX': 6, 'SEVEN': 7, 'EIGHT': 8, 'NINE': 9})
}

In [5]:
## labels and paths in pd frame
wavfiles = glob.glob('./free-spoken-digit-dataset/recordings/*.wav')
speakers = [file.split('/')[-1].split('_')[1] for file in wavfiles]
words = [list(args['vocab'].keys())[int(file.split('/')[-1].split('_')[0])] for file in wavfiles]
rec_number = [int(file.split('/')[-1].split('_')[2].split('.')[0]) for file in wavfiles]
data = pd.DataFrame({'wavfile':wavfiles,'speaker':speakers,'word':words,'rec_number':rec_number})

## train/test split according to https://github.com/Jakobovski/free-spoken-digit-dataset
train_data = data.loc[data['rec_number']>=5].reset_index(drop=True)
test_data = data.loc[data['rec_number']<5].reset_index(drop=True)

In [6]:
# log mels for audio; time scaled by PIL.Image to xsize, 40 nmels
def load_data(data,cv=False,**kwargs):
    n_samples = len(data)
    dataset = torch.zeros((n_samples,kwargs['ysize'],kwargs['xsize']),dtype=torch.uint8)
    labels = torch.zeros((n_samples),dtype=torch.uint8)
    for i in tqdm(range(n_samples),disable=(kwargs['verbose']<2)):
        path = data['wavfile'][i]
        dataset[i,:,:] = torch.from_numpy(feat2img(gen_logmel(path,(kwargs['n_mels'] if 'n_mels' in kwargs else 40),(kwargs['sampling'] if 'sampling' in kwargs else 8000),True),kwargs['ysize'],kwargs['xsize']))
        labels[i] = kwargs['vocab'][data['word'][i]]

    if cv == False:
        return dataset, labels

    #Do random train/validation split
    idx = [i for i in range(n_samples)]
    random.shuffle(idx)
    trainset = dataset[idx[0:int(n_samples*(1-kwargs['train_val_percentage']))]]
    trainlabels = labels[idx[0:int(n_samples*(1-kwargs['train_val_percentage']))]]
    validset = dataset[idx[int(n_samples*(1-kwargs['train_val_percentage'])):]]
    validlabels = labels[idx[int(n_samples*(1-kwargs['train_val_percentage'])):]]
    return trainset, validset, trainlabels, validlabels

def load_test_data(data,**kwargs):
    n_samples = len(data)
    dataset = torch.zeros((n_samples,kwargs['ysize'],kwargs['xsize']),dtype=torch.uint8)
    labels = torch.zeros((n_samples),dtype=torch.uint8)
    for i in tqdm(range(n_samples),disable=(kwargs['verbose']<2)):
        path = data['wavfile'][i]
        dataset[i,:,:] = torch.from_numpy(feat2img(gen_logmel(path,(kwargs['n_mels'] if 'n_mels' in kwargs else 40),(kwargs['sampling'] if 'sampling' in kwargs else 8000),True),kwargs['ysize'],kwargs['xsize']))
        labels[i] = kwargs['vocab'][data['word'][i]]

    return dataset, labels

In [7]:
trainset, validset, trainlabels, validlabels = load_data(train_data,True,**args)
print(trainset.shape, validset.shape)
testset, testlabels = load_test_data(test_data,**args)
print(testset.shape)

torch.Size([2430, 20, 20]) torch.Size([270, 20, 20])
torch.Size([300, 20, 20])


In [8]:
# Quantization function (simulating int8 quantization)
def quantize(x, num_bits=8):
    scale = x.max() / (2 ** (num_bits-1) - 1)  # Scale factor for quantization
    x_quantized = torch.round(x / scale)  # Quantize by scaling and rounding
    x_quantized = torch.clamp(x_quantized, -2 ** (num_bits-1), 2 ** (num_bits-1) - 1)  # Clip to valid range
    return x_quantized, scale

# Dequantization function
def dequantize(x_quantized, scale):
    return x_quantized * scale

In [9]:
class FSDNN_RNN(torch.nn.Module):
    def __init__(self, input_channels, hidden_size, num_layers, output_size):
        super(FSDNN_RNN, self).__init__()
        self.rnn = torch.nn.RNN(input_size=input_channels, 
                          hidden_size=hidden_size, 
                          num_layers=num_layers, 
                          batch_first=True)  # (batch, seq, features)
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)  # RNN output
        out = self.fc(out[:, -1, :])  # Take last time step for classification
        return out

In [14]:
class FSDNN_RNN_POW2WEIGHTS_Q(torch.nn.Module):
    def __init__(self, input_channels, hidden_size, num_layers, output_size):
        super(FSDNN_RNN_POW2WEIGHTS_Q, self).__init__()
        self.rnn = torch.nn.RNN(input_size=input_channels, 
                          hidden_size=hidden_size, 
                          num_layers=num_layers, 
                          batch_first=True)  # (batch, seq, features)
        self.fc = torch.nn.Linear(hidden_size, output_size)

        self.bits = 8
        self.enable_q = False

    # Quantization function
    def quantize(self, x, num_bits=8):
        scale = x.max() / (2 ** (num_bits-1) - 1)  # Scale factor for quantization
        x_quantized = torch.round(x / scale)  # Quantize by scaling and rounding
        x_quantized = torch.clamp(x_quantized, -2 ** (num_bits-1), 2 ** (num_bits-1) - 1)  # Clip to valid range
        return x_quantized, scale

    # Dequantization function
    def dequantize(self, x_quantized, scale):
        return x_quantized * scale
    
    def q_sym_noscale(self, x, num_bits=8, num_frac=6):
        s = 2 ** (num_bits - 1)
        q = torch.round(x * s)
        q = torch.clamp(q, -s, s - 1)
        q = q / (2 ** num_frac)
        return q
    
    def q_pow2_w(self, win):
        sgn = torch.sign(win)
        w_q = torch.pow(2, torch.round(torch.log2(torch.abs(win))))
        #import pdb; pdb.set_trace()
        return sgn * w_q  

    # quantized (modified pytorch doc implementation -> fixed layered input)
    def forward(self, x):

        if self.rnn.batch_first:
            x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        h_t_minus_1 = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size)
        h_t = torch.zeros_like(h_t_minus_1)

        output = []

        if self.enable_q:
            x_quantized, scale = self.quantize(x.clone(), self.bits)
            x = self.dequantize(x_quantized, scale)
            #import pdb; pdb.set_trace()

        for t in range(seq_len):
            h_t_new = []
            for layer in range(self.rnn.num_layers):
                if self.enable_q:
                    weight_ih = self.q_pow2_w(getattr(self.rnn, f'weight_ih_l{layer}'))
                    bias_ih = self.q_pow2_w(getattr(self.rnn, f'bias_ih_l{layer}'))
                    weight_hh = self.q_pow2_w(getattr(self.rnn, f'weight_hh_l{layer}'))
                    bias_hh = self.q_pow2_w(getattr(self.rnn, f'bias_hh_l{layer}'))
                else:
                    weight_ih = getattr(self.rnn, f'weight_ih_l{layer}')
                    bias_ih = getattr(self.rnn, f'bias_ih_l{layer}')
                    weight_hh = getattr(self.rnn, f'weight_hh_l{layer}')
                    bias_hh = getattr(self.rnn, f'bias_hh_l{layer}')

                xin = x[t] if layer == 0 else h_t_new[layer-1]

                h_layer = torch.tanh(
                    xin @ weight_ih.T
                    + bias_ih
                    + h_t_minus_1[layer] @ weight_hh.T
                    + bias_hh
                )

                if self.enable_q:
                    h_layer = self.q_sym_noscale(h_layer.clone(), self.bits, self.bits-2)
                # h_layer_quantized, scale = self.quantize(h_layer.clone(), self.bits)
                # h_layer_q = self.dequantize(h_layer_quantized, scale)
                #import pdb; pdb.set_trace()

                h_t_new.append(h_layer)

            h_t = torch.stack(h_t_new)
            output.append(h_t[-1])

            h_t_minus_1 = h_t.detach()

        output = torch.stack(output)
        if self.rnn.batch_first:
            output = output.transpose(0, 1)

        if self.enable_q:
            self.fc.weight.data = self.q_pow2_w(self.fc.weight.data)
            self.fc.bias.data = self.q_pow2_w(self.fc.bias.data)

        out = self.fc(output[:, -1, :])

        if self.enable_q:
            out_quantized, scale = self.quantize(out.clone(), self.bits)
            out = self.dequantize(out_quantized, scale)

        return out

In [15]:
# load pretrained model
model_pre = FSDNN_RNN(args['ysize'], args['rnn_hidden'], args['rnn_layers'], args['rnn_outputs'])
model_pre.load_state_dict(torch.load('chkpt_t1.pt', weights_only=True)) #load pretrained 

# 
model = FSDNN_RNN_POW2WEIGHTS_Q(args['ysize'], args['rnn_hidden'], args['rnn_layers'], args['rnn_outputs'])
pretrained_weights = model_pre.state_dict()
new_model_dict = model.state_dict()
pretrained_weights = {k: v for k, v in pretrained_weights.items() if k in new_model_dict}
new_model_dict.update(pretrained_weights)
model.load_state_dict(new_model_dict)

<All keys matched successfully>

In [16]:
#Train the model for an epoch
def train_model(trainset,trainlabels,model,optimizer,criterion,**kwargs):
    trainlen = trainset.shape[0]
    nbatches = math.ceil(trainlen/kwargs['batch_size'])
    if trainlen % kwargs['batch_size'] == 1:
        nbatches -= 1
    total_loss = 0
    total_backs = 0
    with tqdm(total=nbatches,disable=(kwargs['verbose']<2)) as pbar:
        model = model.train()
        for b in range(nbatches):

            #Obtain batch
            X = trainset[b*kwargs['batch_size']:min(trainlen,(b+1)*kwargs['batch_size'])].clone().float()
            X = X.to(kwargs['device'])
            Y = trainlabels[b*kwargs['batch_size']:min(trainlen,(b+1)*kwargs['batch_size'])].clone().long().to(kwargs['device'])
            #import pdb; pdb.set_trace()

            #Propagate
            posteriors = model(X)

            #Backpropagate
            loss = criterion(posteriors,Y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            #Track loss
            if total_backs == 100:
                total_loss = total_loss*0.99+loss.detach().cpu().numpy()
            else:
                total_loss += loss.detach().cpu().numpy()
                total_backs += 1
            pbar.set_description(f'Training epoch. Loss {total_loss/(total_backs+1):.2f}')
            pbar.update()
    return total_loss/(total_backs+1)

#Validate last epoch's model
def validate_model(validset,validlabels,model,**kwargs):
    validlen = validset.shape[0]
    acc = 0
    total = 0
    nbatches = math.ceil(validlen/kwargs['batch_size'])
    with torch.no_grad():
        with tqdm(total=nbatches,disable=(kwargs['verbose']<2)) as pbar:
            model = model.eval()
            for b in range(nbatches):
                #Obtain batch
                X = validset[b*kwargs['batch_size']:min(validlen,(b+1)*kwargs['batch_size'])].clone().float().to(kwargs['device'])
                Y = validlabels[b*kwargs['batch_size']:min(validlen,(b+1)*kwargs['batch_size'])].clone().long().to(kwargs['device'])
                #Propagate
                posteriors = model(X)
                #Accumulate accuracy
                estimated = torch.argmax(posteriors,dim=1)
                acc += sum((estimated.cpu().numpy() == Y.cpu().numpy()))
                total+=Y.shape[0]
                pbar.set_description(f'Evaluating epoch. Accuracy {100*acc/total:.2f}%')
                pbar.update()
    return 100*acc/total

In [17]:
# validate equivalence
acc = validate_model(testset,testlabels,model_pre,**args)
print(acc)
acc = validate_model(testset,testlabels,model,**args)
print(acc)
model.enable_q = True
acc = validate_model(testset,testlabels,model,**args)
print(acc)

93.0
93.0
81.33333333333333


In [18]:
# QAT
torch.autograd.set_detect_anomaly(True)
## Training Setup
optimizer = torch.optim.RMSprop(model.parameters(), lr=args['learning_rate'])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
criterion = torch.nn.CrossEntropyLoss()

## Training Loop
for ep in range(1,args['epochs']+1):
    #Do backpropgation and validation epochs
    loss = train_model(trainset,trainlabels,model,optimizer,criterion,**args)
    scheduler.step()
    acc = validate_model(validset,validlabels,model,**args)
    print('Epoch {0:d} of {1:d}. Training loss: {2:.2f}, Validation accuracy: {3:.2f}%'.format(ep,args['epochs'],loss,acc))

Epoch 1 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 2 of 1000. Training loss: 2.17, Validation accuracy: 85.19%
Epoch 3 of 1000. Training loss: 2.17, Validation accuracy: 85.19%
Epoch 4 of 1000. Training loss: 2.17, Validation accuracy: 85.19%
Epoch 5 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 6 of 1000. Training loss: 2.17, Validation accuracy: 85.19%
Epoch 7 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 8 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 9 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 10 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 11 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 12 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 13 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 14 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoch 15 of 1000. Training loss: 2.18, Validation accuracy: 85.19%
Epoc

KeyboardInterrupt: 

In [19]:
torch.save(model.state_dict(), 'chkpt_t3.pt')

In [20]:
acc = validate_model(testset,testlabels,model,**args)
acc

81.33333333333333