In [1]:
# Task B: Train RNN on FSDD 
# - full precision 32 float
# - close to comparable designs (current deviation -6%, fixable by scaling network and tuning hyperparameters)
# - RNN layer sizing (64x64x2+64x2)x4bytes = ~33kB < 36 kB from Task B constraint

In [2]:
import glob
import pandas as pd
import torch
import random
from tqdm import tqdm
import torch.quantization as quantization

from collections import OrderedDict

from sigproc import gen_logmel, feat2img
import math

In [3]:
# mainly adapted from https://github.com/saztorralba/CNNWordReco due to following
# - deeplake / hub version broken -> replaced with original wavs (cloned orig repo: https://github.com/Jakobovski/free-spoken-digit-dataset)
# - logmel suitable for detection of spoken speech -> normalized, resampled, high-pass filtered, time axis scaling

In [4]:
args = {
    'train_val_percentage': 0.1,
    'xsize': 20,
    'ysize': 20,
    'rnn_layers': 3,
    'rnn_hidden': 64,
    'rnn_outputs': 10,
    'epochs': 1000,
    'batch_size': 32,
    'learning_rate': 0.001,
    'device': 'cpu',
    'verbose': 1,
    'augment': False,
    'vocab': OrderedDict({'ZERO': 0, 'ONE': 1, 'TWO': 2, 'THREE': 3, 'FOUR': 4, 'FIVE': 5, 'SIX': 6, 'SEVEN': 7, 'EIGHT': 8, 'NINE': 9})
}

In [5]:
## labels and paths in pd frame
wavfiles = glob.glob('./free-spoken-digit-dataset/recordings/*.wav')
speakers = [file.split('/')[-1].split('_')[1] for file in wavfiles]
words = [list(args['vocab'].keys())[int(file.split('/')[-1].split('_')[0])] for file in wavfiles]
rec_number = [int(file.split('/')[-1].split('_')[2].split('.')[0]) for file in wavfiles]
data = pd.DataFrame({'wavfile':wavfiles,'speaker':speakers,'word':words,'rec_number':rec_number})

## train/test split according to https://github.com/Jakobovski/free-spoken-digit-dataset
train_data = data.loc[data['rec_number']>=5].reset_index(drop=True)
test_data = data.loc[data['rec_number']<5].reset_index(drop=True)

In [6]:
# log mels for audio; time scaled by PIL.Image to xsize, 40 nmels
def load_data(data,cv=False,**kwargs):
    n_samples = len(data)
    dataset = torch.zeros((n_samples,kwargs['ysize'],kwargs['xsize']),dtype=torch.uint8)
    labels = torch.zeros((n_samples),dtype=torch.uint8)
    for i in tqdm(range(n_samples),disable=(kwargs['verbose']<2)):
        path = data['wavfile'][i]
        dataset[i,:,:] = torch.from_numpy(feat2img(gen_logmel(path,(kwargs['n_mels'] if 'n_mels' in kwargs else 40),(kwargs['sampling'] if 'sampling' in kwargs else 8000),True),kwargs['ysize'],kwargs['xsize']))
        labels[i] = kwargs['vocab'][data['word'][i]]

    if cv == False:
        return dataset, labels

    #Do random train/validation split
    idx = [i for i in range(n_samples)]
    random.shuffle(idx)
    trainset = dataset[idx[0:int(n_samples*(1-kwargs['train_val_percentage']))]]
    trainlabels = labels[idx[0:int(n_samples*(1-kwargs['train_val_percentage']))]]
    validset = dataset[idx[int(n_samples*(1-kwargs['train_val_percentage'])):]]
    validlabels = labels[idx[int(n_samples*(1-kwargs['train_val_percentage'])):]]
    return trainset, validset, trainlabels, validlabels

def load_test_data(data,**kwargs):
    n_samples = len(data)
    dataset = torch.zeros((n_samples,kwargs['ysize'],kwargs['xsize']),dtype=torch.uint8)
    labels = torch.zeros((n_samples),dtype=torch.uint8)
    for i in tqdm(range(n_samples),disable=(kwargs['verbose']<2)):
        path = data['wavfile'][i]
        dataset[i,:,:] = torch.from_numpy(feat2img(gen_logmel(path,(kwargs['n_mels'] if 'n_mels' in kwargs else 40),(kwargs['sampling'] if 'sampling' in kwargs else 8000),True),kwargs['ysize'],kwargs['xsize']))
        labels[i] = kwargs['vocab'][data['word'][i]]

    return dataset, labels

In [7]:
trainset, validset, trainlabels, validlabels = load_data(train_data,True,**args)
print(trainset.shape, validset.shape)
testset, testlabels = load_test_data(test_data,**args)
print(testset.shape)

torch.Size([2430, 20, 20]) torch.Size([270, 20, 20])
torch.Size([300, 20, 20])


In [8]:
# Quantization function (simulating int8 quantization)
def quantize(x, num_bits=8):
    scale = x.max() / (2 ** num_bits - 1)  # Scale factor for quantization
    x_quantized = torch.round(x / scale)  # Quantize by scaling and rounding
    x_quantized = torch.clamp(x_quantized, 0, 2 ** num_bits - 1)  # Clip to valid range
    return x_quantized, scale

# Dequantization function
def dequantize(x_quantized, scale):
    return x_quantized * scale

In [9]:
class FSDNN_RNN(torch.nn.Module):
    def __init__(self, input_channels, hidden_size, num_layers, output_size):
        super(FSDNN_RNN, self).__init__()
        self.rnn = torch.nn.RNN(input_size=input_channels, 
                          hidden_size=hidden_size, 
                          num_layers=num_layers, 
                          batch_first=True)  # (batch, seq, features)
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x):
        out, _ = self.rnn(x)  # RNN output
        out = self.fc(out[:, -1, :])  # Take last time step for classification
        return out

In [25]:
class FSDNN_RNN_Q(torch.nn.Module):
    def __init__(self, input_channels, hidden_size, num_layers, output_size):
        super(FSDNN_RNN_Q, self).__init__()
        self.rnn = torch.nn.RNN(input_size=input_channels, 
                          hidden_size=hidden_size, 
                          num_layers=num_layers, 
                          batch_first=True)  # (batch, seq, features)
        self.fc = torch.nn.Linear(hidden_size, output_size)

        self.bits = 8
        self.enable_q = False

    # quantized (modified pytorch doc implementation -> fixed layered input)
    def forward(self, x):
        hx = None
        if self.rnn.batch_first:
            x = x.transpose(0, 1)
        seq_len, batch_size, _ = x.size()

        if hx is None:
            hx = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size, device=x.device)
        h_t_minus_1 = hx.detach()
        h_t = torch.zeros_like(h_t_minus_1)

        output = []

        # Quantize inputs
        if self.enable_q:
            x_quantized, scale = quantize(x.clone(), self.bits)
            x = dequantize(x_quantized, scale).requires_grad_()

        for t in range(seq_len):
            h_t_new = []
            for layer in range(self.rnn.num_layers):
                weight_ih = getattr(self.rnn, f'weight_ih_l{layer}')
                bias_ih = getattr(self.rnn, f'bias_ih_l{layer}')
                weight_hh = getattr(self.rnn, f'weight_hh_l{layer}')
                bias_hh = getattr(self.rnn, f'bias_hh_l{layer}')

                xin = x[t] if layer == 0 else h_t[layer-1]

                h_layer = torch.tanh(
                    xin @ weight_ih.T
                    + bias_ih
                    + h_t_minus_1[layer] @ weight_hh.T
                    + bias_hh
                )

                # Quantize hidden state
                if self.enable_q:
                    ht_quantized, scale = quantize(h_layer, self.bits)
                    h_layer = dequantize(ht_quantized, scale).requires_grad_()

                h_t_new.append(h_layer)

            h_t = torch.stack(h_t_new)
            output.append(h_t[-1])

            h_t_minus_1 = h_t.detach()

        output = torch.stack(output)
        if self.rnn.batch_first:
            output = output.transpose(0, 1)

        out = self.fc(output[:, -1, :])
        return out

In [26]:
# load pretrained model
model_pre = FSDNN_RNN(args['ysize'], args['rnn_hidden'], args['rnn_layers'], args['rnn_outputs'])
model_pre.load_state_dict(torch.load('chkpt_t1.pt', weights_only=True)) #load pretrained 

# 
model = FSDNN_RNN_Q(args['ysize'], args['rnn_hidden'], args['rnn_layers'], args['rnn_outputs'])
pretrained_weights = model_pre.state_dict()
new_model_dict = model.state_dict()
pretrained_weights = {k: v for k, v in pretrained_weights.items() if k in new_model_dict}
new_model_dict.update(pretrained_weights)
model.load_state_dict(new_model_dict)

<All keys matched successfully>

In [12]:
for layer in range(model.rnn.num_layers):
    print(f"Layer {layer} weight_ih diff:", torch.norm(
        getattr(model.rnn, f'weight_ih_l{layer}') - getattr(model_pre.rnn, f'weight_ih_l{layer}')
    ).item())
    print(f"Layer {layer} weight_hh diff:", torch.norm(
        getattr(model.rnn, f'weight_hh_l{layer}') - getattr(model_pre.rnn, f'weight_hh_l{layer}')
    ).item())

Layer 0 weight_ih diff: 0.0
Layer 0 weight_hh diff: 0.0
Layer 1 weight_ih diff: 0.0
Layer 1 weight_hh diff: 0.0
Layer 2 weight_ih diff: 0.0
Layer 2 weight_hh diff: 0.0


In [13]:
#Train the model for an epoch
def train_model(trainset,trainlabels,model,optimizer,criterion,**kwargs):
    trainlen = trainset.shape[0]
    nbatches = math.ceil(trainlen/kwargs['batch_size'])
    if trainlen % kwargs['batch_size'] == 1:
        nbatches -= 1
    total_loss = 0
    total_backs = 0
    with tqdm(total=nbatches,disable=(kwargs['verbose']<2)) as pbar:
        model = model.train()
        for b in range(nbatches):

            #Obtain batch
            X = trainset[b*kwargs['batch_size']:min(trainlen,(b+1)*kwargs['batch_size'])].clone().float()
            X = X.to(kwargs['device'])
            Y = trainlabels[b*kwargs['batch_size']:min(trainlen,(b+1)*kwargs['batch_size'])].clone().long().to(kwargs['device'])
            #import pdb; pdb.set_trace()

            #Propagate
            posteriors = model(X)

            #Backpropagate
            loss = criterion(posteriors,Y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            #Track loss
            if total_backs == 100:
                total_loss = total_loss*0.99+loss.detach().cpu().numpy()
            else:
                total_loss += loss.detach().cpu().numpy()
                total_backs += 1
            pbar.set_description(f'Training epoch. Loss {total_loss/(total_backs+1):.2f}')
            pbar.update()
    return total_loss/(total_backs+1)

#Validate last epoch's model
def validate_model(validset,validlabels,model,**kwargs):
    validlen = validset.shape[0]
    acc = 0
    total = 0
    nbatches = math.ceil(validlen/kwargs['batch_size'])
    with torch.no_grad():
        with tqdm(total=nbatches,disable=(kwargs['verbose']<2)) as pbar:
            model = model.eval()
            for b in range(nbatches):
                #Obtain batch
                X = validset[b*kwargs['batch_size']:min(validlen,(b+1)*kwargs['batch_size'])].clone().float().to(kwargs['device'])
                Y = validlabels[b*kwargs['batch_size']:min(validlen,(b+1)*kwargs['batch_size'])].clone().long().to(kwargs['device'])
                #Propagate
                posteriors = model(X)
                #Accumulate accuracy
                estimated = torch.argmax(posteriors,dim=1)
                acc += sum((estimated.cpu().numpy() == Y.cpu().numpy()))
                total+=Y.shape[0]
                pbar.set_description(f'Evaluating epoch. Accuracy {100*acc/total:.2f}%')
                pbar.update()
    return 100*acc/total

In [19]:
# validate equivalence
acc = validate_model(testset,testlabels,model_pre,**args)
print(acc)
acc = validate_model(testset,testlabels,model,**args)
print(acc)
model.enable_q = True
acc = validate_model(testset,testlabels,model,**args)
print(acc)

90.33333333333333
90.33333333333333
19.333333333333332


In [27]:
# QAT
torch.autograd.set_detect_anomaly(True)
## Training Setup
optimizer = torch.optim.RMSprop(model.parameters(), lr=args['learning_rate'])
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.5)
criterion = torch.nn.CrossEntropyLoss()

## Training Loop
for ep in range(1,args['epochs']+1):
    #Do backpropgation and validation epochs
    loss = train_model(trainset,trainlabels,model,optimizer,criterion,**args)
    scheduler.step()
    acc = validate_model(validset,validlabels,model,**args)
    print('Epoch {0:d} of {1:d}. Training loss: {2:.2f}, Validation accuracy: {3:.2f}%'.format(ep,args['epochs'],loss,acc))

Epoch 1 of 1000. Training loss: 2.67, Validation accuracy: 39.63%
Epoch 2 of 1000. Training loss: 1.67, Validation accuracy: 45.93%
Epoch 3 of 1000. Training loss: 1.54, Validation accuracy: 55.19%
Epoch 4 of 1000. Training loss: 1.44, Validation accuracy: 53.70%
Epoch 5 of 1000. Training loss: 1.32, Validation accuracy: 53.33%
Epoch 6 of 1000. Training loss: 1.31, Validation accuracy: 56.30%
Epoch 7 of 1000. Training loss: 1.34, Validation accuracy: 51.11%
Epoch 8 of 1000. Training loss: 1.26, Validation accuracy: 55.56%
Epoch 9 of 1000. Training loss: 1.20, Validation accuracy: 58.15%
Epoch 10 of 1000. Training loss: 1.13, Validation accuracy: 59.63%
Epoch 11 of 1000. Training loss: 1.12, Validation accuracy: 59.63%
Epoch 12 of 1000. Training loss: 1.09, Validation accuracy: 60.74%
Epoch 13 of 1000. Training loss: 1.11, Validation accuracy: 59.26%
Epoch 14 of 1000. Training loss: 1.14, Validation accuracy: 59.63%
Epoch 15 of 1000. Training loss: 1.13, Validation accuracy: 58.89%
Epoc

In [28]:
torch.save(model.state_dict(), 'chkpt_t2.pt')

In [29]:
acc = validate_model(testset,testlabels,model,**args)
acc

67.66666666666667