In [1]:
!pip install -qq torchaudio
!pip install -qq torch_optimizer
!pip install -qq editdistance
!pip install -qq wandb
!pip install -qq git+https://github.com/albumentations-team/albumentations.git

  Building wheel for albumentations (setup.py) ... [?25l[?25hdone


In [1]:
import torch
import os
import numpy as np
import random
import torchaudio
from torch import nn
import torch_optimizer
import sys
from torch.utils.data import DataLoader
from torchaudio.transforms import MelSpectrogram
import albumentations as A
from drive.MyDrive.quarznet.model.QuarzNet import QuarzNet5x3

In [2]:
# from https://stackoverflow.com/questions/57416925/best-practices-for-generating-a-random-seeds-to-seed-pytorch
def seed_everything(seed=42):
  """Seed Python's `random`, NumPy and PyTorch, and make cuDNN deterministic.

  Note: assigning PYTHONHASHSEED here only influences Python interpreters
  launched after this call (e.g. subprocesses); the running interpreter keeps
  the hash seed it started with.
  """
  # cuDNN: choose deterministic kernels and disable autotuning.
  torch.backends.cudnn.benchmark = False
  torch.backends.cudnn.deterministic = True
  os.environ['PYTHONHASHSEED'] = str(seed)
  for seeder in (random.seed, np.random.seed, torch.manual_seed):
    seeder(seed)

In [3]:
# Open the LibriSpeech training and test splits under the current directory,
# downloading them first if they are missing.
train_ls, test_ls = (
    torchaudio.datasets.LIBRISPEECH(root="./", url=split, download=True)
    for split in ("train-clean-100", "test-clean")
)

In [None]:
import wandb
!wandb login

In [5]:
# Hyperparameters for the QuartzNet 5x3 run. This dict is handed to wandb.init
# and later cells read the values back through wandb.config.
quarznet5x3_config = dict(
    # audio front-end
    sample_rate=16000,
    n_mels=128,
    # output alphabet: 29 classes; index 28 (the last one) is the CTC blank
    labels=29,
    blank_idx=28,
    # batching
    train_batch_size=32,
    test_batch_size=32,
    # augmentation switches
    # NOTE(review): 'spectral_augmentation' is not read by any visible cell.
    spectral_augmentation=False,
    spectral_cutout=True,
    holes=24,
    # optimisation (NovoGrad)
    epochs=40,
    lr=0.015,
    beta1=0.95,
    # NOTE(review): beta2=0.25 is unusually low; NeMo's QuartzNet configs use 0.5 — confirm intent.
    beta2=0.25,
    weight_decay=0.001,
)

In [6]:
# Make the project helper modules (imported in the next cell) importable.
# The membership check keeps sys.path from growing each time this cell is re-run.
UTILS_DIR = '/content/drive/MyDrive/quarznet/utils'
if UTILS_DIR not in sys.path:
  sys.path.append(UTILS_DIR)

In [7]:
from TextTransforms import TextTransform, greedy_path_search
from DataTransforms import DataCollate
from metrics import WER, CER

In [None]:
# Start a W&B run in project 'quarznet5x3' with the hyperparameter dict above as
# its config; later cells read hyperparameters through wandb.config.
# NOTE(review): resume=True is passed without a run id, so which run (if any) gets
# resumed is left to wandb's own resume logic — confirm this is the intended run.
wandb.init(project='quarznet5x3', config=quarznet5x3_config, resume=True)

In [9]:
# Reset all RNGs, then build the train/test loaders from the run config.
seed_everything()
train_dataloader = DataLoader(
    train_ls,
    batch_size=wandb.config['train_batch_size'],
    # Honour the 'spectral_cutout' switch from the run config instead of a hard-coded True.
    collate_fn=DataCollate(
        n_mels=wandb.config['n_mels'],
        specCut=wandb.config['spectral_cutout'],
        holes=wandb.config['holes'],
    ),
    shuffle=True,
    num_workers=2,
)
test_dataloader = DataLoader(
    test_ls,
    batch_size=wandb.config['test_batch_size'],
    collate_fn=DataCollate(n_mels=wandb.config['n_mels']),
    num_workers=2,
)

In [10]:
def trainEpoch(train_dataloader, model, criterion, optimizer, scheduler, scaler, epoch: int, device='cuda:0'):
  """Run one mixed-precision training epoch and return the mean batch loss.

  Args:
    train_dataloader: iterable of ``(spectrogram, targets, input_lengths,
      target_lengths)`` batches.
    model: network whose output is log-softmaxed over dim 1 and permuted
      ``(2, 0, 1)`` before the loss, i.e. presumably ``(batch, classes, time)``
      — TODO confirm against the model definition.
    criterion: called as ``criterion(log_probs, targets, input_lengths,
      target_lengths)``; ``nn.CTCLoss`` in this notebook.
    optimizer: optimizer stepped through ``scaler``.
    scheduler: LR scheduler, stepped once per batch.
    scaler: ``GradScaler`` used for loss scaling.
    epoch: epoch number, used only for logging.
    device: device the spectrograms and targets are moved to.

  Returns:
    Mean of the per-batch losses, or ``nan`` if the loader yielded no batches.
  """
  model.train()
  batch_losses = []
  for spectrogram, targets, input_lengths, target_lengths in train_dataloader:
    # Only the tensors the model/loss compute on are moved; lengths stay as loaded.
    spectrogram, targets = spectrogram.to(device), targets.to(device)

    optimizer.zero_grad(set_to_none=True)

    with torch.cuda.amp.autocast():
      log_probs = nn.functional.log_softmax(model(spectrogram), dim=1)
      # nn.CTCLoss expects log-probs laid out as (time, batch, classes).
      loss = criterion(log_probs.permute(2, 0, 1), targets, input_lengths, target_lengths)

    batch_losses.append(loss.item())

    # Canonical AMP order: backward on the scaled loss, optimizer step through the
    # scaler (skipped on inf/nan grads), scale update, then the per-batch LR step.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    scheduler.step()

  # An empty loader reports nan instead of raising ZeroDivisionError.
  avg_loss = sum(batch_losses) / len(batch_losses) if batch_losses else float('nan')
  # Log the epoch too, so train/test curves can be aligned on the same x-axis.
  wandb.log({'train_loss': avg_loss, 'epoch': epoch})
  print(f"Train Epoch[{epoch}]. loss: {avg_loss} ")

  return avg_loss

@torch.no_grad()
def testEpoch(test_dataloader, model, criterion, scaler, epoch: int, device='cuda:0'):
    model.eval()
    criterion_loss = []
    wer = []
    cer = []
    for i, data in enumerate(test_dataloader):
      spectrogram, targets, input_lengths, target_lengths = data
      spectrogram, targets = spectrogram.to(device), targets.to(device)

      with torch.cuda.amp.autocast():
        log_probs = nn.functional.log_softmax(model(spectrogram), dim=1)
        loss = criterion(log_probs.permute(2, 0, 1), targets, input_lengths, target_lengths)

      scaler.scale(loss)
      criterion_loss.append(loss.item())
      

      sequences = log_probs.argmax(1)
      for k, target in enumerate(targets):
        hypothesis, reference = greedy_path_search(TextTransform(), sequences[k], target, target_lengths[k])
        cur_wer = WER(hypothesis.split(), reference.split())
        cur_cer = CER(hypothesis, reference)
        wer.append(cur_wer)
        cer.append(cur_cer)
        

    avg_loss = sum(criterion_loss) / len(criterion_loss)
    avg_wer = sum(wer) / len(wer)
    avg_cer = sum(cer) / len(cer)
    wandb.log({
        'test_loss': avg_loss,
        'WER': avg_wer,
        'CER': avg_cer
    })
    print(f"Test Epoch[{epoch}]. loss: {avg_loss}; wer: {avg_wer}; cer: {avg_cer} ")

    return avg_loss, avg_wer
      

In [11]:
# Build the model, loss, optimizer, LR schedule and AMP scaler on the selected device.
seed_everything()
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f"device: {DEVICE}")
model = QuarzNet5x3(n_mels=wandb.config['n_mels'], labels=wandb.config['labels'])
# Move the model to DEVICE (not a hard-coded 'cuda:0') so model, inputs and loss
# all live on the same device, including on CPU-only runtimes.
model = model.to(DEVICE)
criterion = nn.CTCLoss(blank=wandb.config['blank_idx']).to(DEVICE)
novograd = torch_optimizer.NovoGrad(
    model.parameters(),
    lr=wandb.config['lr'],
    betas=(wandb.config['beta1'], wandb.config['beta2']),
    weight_decay=wandb.config['weight_decay'],
)
# T_0=1000 is the restart period in scheduler steps; trainEpoch steps the scheduler once per batch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(novograd, 1000)
# Loss scaling is a CUDA feature; make it explicitly off on other devices.
scaler = torch.cuda.amp.GradScaler(enabled=DEVICE.startswith('cuda'))

device: cuda:0


In [12]:
# Register the model with W&B for tracking during training (what gets logged
# follows wandb.watch's defaults). The trailing ';' suppresses the TorchGraph repr
# that otherwise clutters the cell output.
wandb.watch(model);

[<wandb.wandb_torch.TorchGraph at 0x7fabafa6e8d0>]

In [13]:
# Main loop: train and evaluate every epoch, checkpointing periodically.
CHECKPOINT_DIR = '/content/drive/MyDrive/quarznet/checkpoints_5x3_cutouts128'
CHECKPOINT_EVERY = 5  # epochs between checkpoints
# torch.save does not create missing directories. Create the target up front
# rather than failing after the first CHECKPOINT_EVERY epochs of training.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

checkpoint = {}
n_epochs = wandb.config['epochs']
for i in range(1, n_epochs + 1):
  trainEpoch(train_dataloader, model, criterion, novograd, scheduler, scaler, i, DEVICE)
  test_loss, test_wer = testEpoch(test_dataloader, model, criterion, scaler, i, DEVICE)
  # Also save after the final epoch, so a run whose length is not a multiple of
  # CHECKPOINT_EVERY still ends with a checkpoint.
  if i % CHECKPOINT_EVERY == 0 or i == n_epochs:
    checkpoint = {
      'epoch': i,
      'state_dict': model.state_dict(),
      'optimizer': novograd.state_dict(),
      'scheduler': scheduler.state_dict(),
      'scaler': scaler.state_dict(),
      # Evaluation results at save time, handy for picking the best checkpoint.
      'test_loss': test_loss,
      'wer': test_wer,
    }
    torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'model_state{i}.pt'))



Train Epoch[1]. loss: 2.439578383626425 
Test Epoch[1]. loss: 1.849220648044493; wer: 0.9613689187016096; cer: 0.7683015229177264 
Train Epoch[2]. loss: 1.5636070939992042 
Test Epoch[2]. loss: 1.286498380143468; wer: 0.9163609570419777; cer: 0.7130836878269521 
Train Epoch[3]. loss: 1.2497756378265774 
Test Epoch[3]. loss: 1.1115765077311819; wer: 0.8935426522197566; cer: 0.6990423178146049 
Train Epoch[4]. loss: 1.0862405755861992 
Test Epoch[4]. loss: 1.0468240576546366; wer: 0.8868267258377311; cer: 0.6942301708404202 
Train Epoch[5]. loss: 0.9819518660483338 
Test Epoch[5]. loss: 1.0245512564007828; wer: 0.8809716216170617; cer: 0.6882509853300528 
Train Epoch[6]. loss: 0.9070724264923232 
Test Epoch[6]. loss: 0.9965544321188112; wer: 0.876632062774684; cer: 0.6872719025517232 
Train Epoch[7]. loss: 0.8508413179439279 
Test Epoch[7]. loss: 1.0432480391932697; wer: 0.8831011478223296; cer: 0.6881551117089934 
Train Epoch[8]. loss: 0.8070951626707086 
Test Epoch[8]. loss: 0.98174635

KeyboardInterrupt: ignored