In [1]:
import torch
import torch, torchaudio, glob
import random
import numpy as np  

def seed_everything(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and every CUDA device), then configures cuDNN for reproducible runs.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one; this is a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes the convolution algorithm and may pick a
    # different one from run to run, which defeats the purpose of seeding.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

# Transformer Encoder

The transformer encoder is a stack of self-attention and feed-forward layers.

In [2]:
class FeedForward(torch.nn.Module):
    """Pre-norm position-wise feed-forward sub-layer.

    The input is layer-normalised, expanded to ``d_ff`` features, passed
    through ReLU and dropout, and projected back to ``d_model`` features.
    No residual connection is added here; the caller is responsible for it.

    Args:
        d_model: Size of the input and output feature dimension.
        d_ff: Size of the hidden (expanded) feature dimension.
        dropout: Dropout probability applied after the ReLU.
        **kwargs: Ignored. Accepted so that one hyper-parameter dictionary
            can be shared with the other modules.
    """

    def __init__(self, d_model=512, d_ff=1024, dropout=0.1, **kwargs):
        super().__init__()
        # The two linear layers are created in this order on purpose: it
        # determines how the random generator is consumed at initialisation.
        expand = torch.nn.Linear(d_model, d_ff)
        contract = torch.nn.Linear(d_ff, d_model)
        stages = [
            torch.nn.LayerNorm(d_model),
            expand,
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            contract,
        ]
        self.ff = torch.nn.Sequential(*stages)

    def forward(self, x):
        """Apply the sub-layer to ``x`` (last dimension must be ``d_model``)."""
        transformed = self.ff(x)
        return transformed

class SelfAttention(torch.nn.Module):
    """Pre-norm multi-head scaled dot-product self-attention.

    The input is layer-normalised, projected to queries, keys and values,
    split into ``n_heads`` heads of size ``d_head``, attended, merged and
    projected back to ``d_model``. The same dropout module is applied to the
    attention weights and to the merged head outputs. No residual connection
    is added here.

    Args:
        d_model: Size of the input and output feature dimension.
        n_heads: Number of attention heads.
        d_head: Feature size of each head.
        dropout: Dropout probability.
        **kwargs: Ignored. Accepted so that one hyper-parameter dictionary
            can be shared with the other modules.
    """

    def __init__(self, d_model, n_heads=8, d_head=64, dropout=0.1, **kwargs):
        super().__init__()
        inner_dim = d_head * n_heads
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_head
        # sqrt(d_head), stored as a plain 0-dim tensor attribute (neither a
        # parameter nor a registered buffer).
        self.scale = torch.sqrt(torch.tensor(d_head, dtype=torch.float32))
        self.norm = torch.nn.LayerNorm(d_model)
        # Creation order (q, v, k, out) fixes both the parameter order and
        # how the random generator is consumed at initialisation.
        self.q_linear = torch.nn.Linear(d_model, inner_dim)
        self.v_linear = torch.nn.Linear(d_model, inner_dim)
        self.k_linear = torch.nn.Linear(d_model, inner_dim)
        self.dropout = torch.nn.Dropout(dropout)
        self.out = torch.nn.Linear(inner_dim, d_model)

    def forward(self, x):
        """Attend over dimension 1 of ``x``; the first dimension is the batch."""
        normed = self.norm(x)
        batch = normed.shape[0]
        per_head = (batch, -1, self.n_heads, self.d_head)

        queries = self.q_linear(normed).view(*per_head)
        keys = self.k_linear(normed).view(*per_head)
        values = self.v_linear(normed).view(*per_head)

        # Pairwise query/key similarity per head, scaled by sqrt(d_head).
        weights = torch.einsum('bihd,bjhd->bhij', queries, keys) / self.scale
        weights = self.dropout(weights.softmax(dim=-1))

        # Weighted sum of the values, then concatenate the heads again.
        mixed = torch.einsum('bhij,bjhd->bihd', weights, values)
        merged = mixed.reshape(batch, -1, self.n_heads * self.d_head)
        return self.out(self.dropout(merged))

class Encoder(torch.nn.Module):
    """Stack of pre-norm self-attention and feed-forward layers.

    A learned positional embedding is added to the input, then each layer
    applies self-attention and a feed-forward block, both wrapped in a
    residual connection.

    Args:
        nb_layers: Number of (attention, feed-forward) layer pairs.
        seq_len: Number of positions covered by the positional embedding,
            i.e. the longest sequence the encoder accepts.
        **kwargs: Hyper-parameters forwarded to ``SelfAttention`` and
            ``FeedForward``. Must contain ``d_model``.

    Raises:
        KeyError: If ``d_model`` is missing from ``kwargs``.
    """

    def __init__(self, nb_layers=6, seq_len=200, **kwargs):
        super().__init__()
        # One learned embedding vector per position, shared across the batch.
        self.pos = torch.nn.Parameter(torch.randn(1, seq_len, kwargs['d_model']))
        self.att = torch.nn.ModuleList([SelfAttention(**kwargs) for _ in range(nb_layers)])
        self.ff = torch.nn.ModuleList([FeedForward(**kwargs) for _ in range(nb_layers)])

    def forward(self, x):
        """Encode ``x`` of shape (batch, time, d_model).

        Raises:
            RuntimeError: If the time dimension exceeds the number of
                positional embeddings.
        """
        _, t, _ = x.shape
        max_len = self.pos.shape[1]
        if t > max_len:
            # Without this check the slice below is silently truncated, which
            # yields an opaque broadcasting error (or, when seq_len == 1, a
            # silent broadcast of one embedding over every position).
            raise RuntimeError(
                f"Encoder received a sequence of length {t}, but the positional "
                f"embedding only covers seq_len={max_len} positions"
            )
        x = x + self.pos[:, :t, :]
        for att, ff in zip(self.att, self.ff):
            x = x + att(x)
            x = x + ff(x)
        return x
    

# Feature Extractor

The feature extractor is composed of a pre-trained WavLM model and a linear layer.
The output of the convolutional layers of the WavLM model is used as features.

In [3]:
class PretrainedFeatures(torch.nn.Module):
    """Audio front-end: pre-trained WavLM feature extractor plus a linear layer.

    Only the ``feature_extractor`` sub-module of the loaded ``WavLMModel`` is
    used in ``forward``; its output is projected to ``d_model`` features.

    Args:
        freeze: If true, the pre-trained weights are excluded from training:
            their ``requires_grad`` flag is cleared and the feature extractor
            runs without gradient tracking. If false, gradients flow into the
            feature extractor so it can be fine-tuned.
        d_model: Output feature size of the linear projection.
        **kwargs: Ignored. Accepted so that one hyper-parameter dictionary
            can be shared with the other modules.
    """

    def __init__(self, freeze=True, d_model=512, **kwargs):
        super().__init__()
        from transformers import WavLMModel
        self.freeze = freeze
        self.fe = WavLMModel.from_pretrained("patrickvonplaten/wavlm-libri-clean-100h-base-plus")
        # 512 must match the channel count produced by the feature extractor.
        self.linear = torch.nn.Linear(512, d_model)
        if freeze:
            for p in self.fe.parameters():
                p.requires_grad = False

    def forward(self, x):
        """Turn a batch of waveforms into a sequence of ``d_model`` features.

        ``x`` is presumably of shape (batch, samples) -- confirm against caller.
        """
        if self.freeze:
            # Frozen backbone: skip graph construction to save memory.
            with torch.no_grad():
                x = self.fe.feature_extractor(x)
        else:
            # Previously this also ran under no_grad, which silently made
            # freeze=False unable to fine-tune the feature extractor.
            x = self.fe.feature_extractor(x)

        # Swap dimensions 1 and 2 so the 512-channel axis is last for the linear layer.
        x = x.transpose(1, 2)
        x = self.linear(x)
        return x

# Classification network

The classification network is composed of an audio feature extractor and a transformer encoder.
The encoder output is layer-normalized, averaged over time, and projected to class logits by a linear layer.

In [4]:
class ClassificationNetwork(torch.nn.Module):
    """Audio classifier: feature extractor, transformer encoder, pooled linear head.

    Args:
        output_dim: Number of output values produced per input example.
        feat_dim: Not used by this class; kept for interface compatibility.
        **kwargs: Hyper-parameters forwarded to ``PretrainedFeatures`` and
            ``Encoder``. Must contain ``d_model``.
    """

    def __init__(self, output_dim, feat_dim=80, **kwargs):
        super().__init__()
        self.fe = PretrainedFeatures(**kwargs)
        self.encoder = Encoder(**kwargs)
        width = kwargs['d_model']
        self.norm = torch.nn.LayerNorm(width)
        self.out = torch.nn.Linear(width, output_dim)

    def forward(self, x):
        """Return one ``output_dim``-sized prediction per batch element."""
        hidden = self.encoder(self.fe(x))
        # Normalise every frame, then average over dimension 1 before the head.
        pooled = self.norm(hidden).mean(1)
        return self.out(pooled)
    
# Hyper-parameters shared by the feature extractor and the encoder.
hparams = dict(
    d_model=256,
    n_heads=4,
    d_head=32,
    dropout=0.1,
    d_ff=256,
    nb_layers=4,
)
model = ClassificationNetwork(output_dim=10, **hparams)

# Smoke test: a batch of 10 random inputs of 16000 values each should
# produce one 10-dimensional output per input.
dummy_audio = torch.randn(10, 16000)
print( model(dummy_audio).shape )

Some weights of WavLMModel were not initialized from the model checkpoint at patrickvonplaten/wavlm-libri-clean-100h-base-plus and are newly initialized: ['wavlm.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wavlm.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([10, 10])


# Dataset 

In [5]:
def identity(x):
    """Return ``x`` unchanged; serves as the default (no-op) transform."""
    return x

class TrainDataset(torch.utils.data.Dataset):
    """Training set of ``.wav`` files with optional waveform transforms.

    Each item is ``(waveform, label)``. The waveform is the first channel of
    the file, zero-padded at the end up to ``audio_len`` samples when shorter
    (longer files are returned at full length), then passed through every
    transform in order. The label is the integer following the last
    underscore in the part of the path that precedes the file extension,
    e.g. ``clip_7.wav`` -> ``7``.

    Args:
        data_dir: Directory that is searched (non-recursively) for ``*.wav``.
        audio_len: Minimum number of samples of a returned waveform.
        transform: Sequence of callables applied in order to the waveform.
            ``None`` (the default) means a single ``identity`` transform.
    """

    def __init__(self, data_dir='data1/train', audio_len = 16000, transform=None):
        # Build the default per instance: a mutable list in the signature
        # would be one shared object across every dataset created.
        self.transform = [identity] if transform is None else transform
        self.audio_len = audio_len
        # Sorted so that indices map to the same files on every run.
        self.files = sorted(glob.glob(data_dir + '/*.wav'))
        print(len(self.files))

    def __len__(self):
        """Return the number of ``.wav`` files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load, pad and transform file ``idx``; return ``(waveform, label)``."""
        path = self.files[idx]
        x, _ = torchaudio.load(path)
        missing = self.audio_len - x.shape[1]
        if missing > 0:
            x = torch.nn.functional.pad(x, (0, missing), value=0)

        # Keep only the first channel.
        x = x[0]
        for t in self.transform:
            x = t(x)

        label = path.split('.')[-2].split('_')[-1]
        return x, int(label)
    

class TestDataset(torch.utils.data.Dataset):
    """Test set of ``.wav`` files, without any waveform transform.

    Each item is ``(waveform, label)``. The waveform is the first channel of
    the file, zero-padded at the end up to ``audio_len`` samples when shorter
    (longer files are returned at full length). The label is the integer
    following the last underscore in the part of the path that precedes the
    file extension, e.g. ``clip_7.wav`` -> ``7``.

    Args:
        data_dir: Directory that is searched (non-recursively) for ``*.wav``.
        audio_len: Minimum number of samples of a returned waveform.
    """

    def __init__(self, data_dir='data1/test', audio_len = 16000):
        self.audio_len = audio_len
        pattern = data_dir + '/*.wav'
        # Sorted so that indices map to the same files on every run.
        self.files = sorted(glob.glob(pattern))
        print(len(self.files))

    def __len__(self):
        """Return the number of ``.wav`` files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load and pad file ``idx``; return ``(waveform, label)``."""
        path = self.files[idx]
        waveform, _ = torchaudio.load(path)

        shortfall = self.audio_len - waveform.shape[1]
        if shortfall > 0:
            waveform = torch.nn.functional.pad(waveform, (0, shortfall), value=0)

        # Text between the last two dots of the path, then its last '_' token.
        stem = path.split('.')[-2]
        label = int(stem.split('_')[-1])
        # Keep only the first channel.
        return waveform[0], label

# Build both splits from their default directories; each constructor prints
# the number of files it found.
trainset = TrainDataset(data_dir='data1/train')
testset = TestDataset(data_dir='data1/test')

25000
5000


# Train the network

In [6]:
# Prefer the GPU, but fall back to the CPU so the notebook also runs on
# machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

nb_epochs = 5
batch_size = 32
model.train()
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
for e in range(nb_epochs):
    # Running mean of the batch losses over the epoch.
    loss_sum = 0
    for x, y in trainloader:
        x = x.to(device)
        y = y.to(device)
        opt.zero_grad()
        out = model(x)
        loss = torch.nn.functional.cross_entropy(out, y)
        loss.backward()
        opt.step()
        loss_sum += loss.item() / len(trainloader)
    print('epoch %d, loss %.4f' % (e, loss_sum))

# torch.save([model, opt], 'model62.pt')

epoch 0, loss 0.9540
epoch 1, loss 0.2394
epoch 2, loss 0.2209
epoch 3, loss 0.1741
epoch 4, loss 0.1288


# Test the network

In [7]:
model.eval()

err = 0
# Inference only: disabling autograd avoids building a computation graph for
# every clip, which saves memory and time without changing the predictions.
with torch.no_grad():
    for x, y in testset:
        x = x.to(device)

        # Add a batch dimension: the model expects a batch of waveforms.
        out = model(x[None, ...])
        y_pred = out.argmax(dim=1).item()
        if y_pred != y:
            err += 1

print('error rate: %.4f' % (err/len(testset)))

error rate: 0.0538
