In [1]:
import torch
import pytorch_lightning as pl
import torchmetrics
import numpy as np
import random

  from .autonotebook import tqdm as notebook_tqdm
2023-08-02 10:50:55.237579: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-08-02 10:50:55.284690: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [2]:
class ConvBlock(torch.nn.Module):
    def __init__(self,  chi, cho, k, s,activation=torch.nn.ReLU, dropout_rate=0.2):
        super().__init__()
        self.layer = torch.nn.Sequential(torch.nn.Conv1d(chi,cho,k,s), torch.nn.Dropout(dropout_rate), activation())
        
    def forward(self,x):
        return self.layer(x)

class SpeakerCounter(pl.LightningModule):
    def __init__(self, cnn_channels=[64,64,128,128,256,256], cnn_kernel_size=[16,8,4,4,4,4], cnn_strides=[8,1,4,1,4,1], activation=torch.nn.ReLU, dropout_rate=0.2, lstm_hidden_size=128, n_lstm=2):
        super().__init__()
        cnn_layers = [ConvBlock(chi,cho,k,s,activation,dropout_rate) for chi,cho,k,s in zip([1]+cnn_channels[:-1],cnn_channels,cnn_kernel_size,cnn_strides)]
        self.cnn_encoder = torch.nn.Sequential(*cnn_layers)
        self.lstm = torch.nn.LSTM(cnn_channels[-1],lstm_hidden_size, batch_first=True, num_layers=n_lstm)
        self.classification_layer = torch.nn.Linear(lstm_hidden_size,1)
        self.train_acc = torchmetrics.Accuracy('binary')
        self.val_acc = torchmetrics.Accuracy('binary')
        
    def forward(self, x):
        x = torch.transpose(x,1,2)
        cnn_out = self.cnn_encoder(x)
        cnn_out = torch.transpose(cnn_out,1,2)
        out, (final_hidden, cn) = self.lstm(cnn_out)
        pooled_out = torch.mean(out, axis=1)
        prob = self.classification_layer(pooled_out)
        
        return prob
    
    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())
    
    def training_step(self, batch, batch_idx):
        x,y = batch
        print(x.shape)
        print(y.shape)
        yhat = self(x)
        yhat = yhat[:,0]
        loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat,y)
        #self.train_acc(yhat[:,0],y)
        self.log('train_loss', loss)
        
        return loss
        
    def validation_step(self, batch, batch_idx):
        x,y = batch
        yhat = self(x)
        yhat = yhat[:,0]
        loss = torch.nn.functional.binary_cross_entropy_with_logits(yhat,y)
        #self.val_acc(yhat[:,0],y)
        #self.log('val_loss', loss)
        self.log('val_acc', self.val_acc)


In [6]:
model = SpeakerCounter()
trainer = pl.Trainer(callbacks=[pl.callbacks.ModelCheckpoint('checkpoints')],gpus=[0],logger=pl.loggers.TensorBoardLogger('tb_logs'), max_epochs=10)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 17:  47%|██████████████████████████▋                              | 15/32 [00:13<00:15,  1.07it/s, loss=0.693, v_num=0]

In [7]:
class DummyDataset(torch.utils.data.Dataset):
    def __init__(self):
        super().__init__()
        
    def __getitem__(self, idx):
        return torch.randn((16000,1)), np.array(random.randint(0,1), dtype=np.float32)
    
    def __len__(self):
        return 1000

train_dataset = DummyDataset()
val_dataset = DummyDataset()

train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset,batch_size=64,shuffle=False)

trainer.fit(model, train_dataloader, val_dataloader)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                 | Type           | Params
--------------------------------------------------------
0 | cnn_encoder          | Sequential     | 526 K 
1 | lstm                 | LSTM           | 329 K 
2 | classification_layer | Linear         | 129   
3 | train_acc            | BinaryAccuracy | 0     
4 | val_acc              | BinaryAccuracy | 0     
--------------------------------------------------------
856 K     Trainable params
0         Non-trainable params
856 K     Total params
3.424     Total estimated model params size (MB)


                                                                                                                             

  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Epoch 0:   0%|                                                                                        | 0/32 [00:00<?, ?it/s]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 0:   3%|█▊                                                         | 1/32 [00:00<00:00, 32.80it/s, loss=0.698, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 0:   6%|███▋                                                       | 2/32 [00:00<00:00, 32.80it/s, loss=0.696, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 0:   9%|█████▌                                                     | 3/32 [00:00<00:00, 32.80it/s, loss=0.695, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 0:  12%|███████▍                                                   | 4/32 [00:00<00:00, 32.78it/s, loss=0.696, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 0:  16%|█████████▏                                                 | 5/32 [00:00<00:00, 32.76it/s, loss=0.695, v_num=1]torch.Size([64, 16000, 1])
tor

Validation:   0%|                                                                                     | 0/16 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                                                                        | 0/16 [00:00<?, ?it/s][A
Epoch 1:  53%|██████████████████████████████▊                           | 17/32 [00:00<00:00, 36.45it/s, loss=0.693, v_num=1][A
Epoch 1:  56%|████████████████████████████████▋                         | 18/32 [00:00<00:00, 38.03it/s, loss=0.693, v_num=1][A
Epoch 1:  59%|██████████████████████████████████▍                       | 19/32 [00:00<00:00, 39.57it/s, loss=0.693, v_num=1][A
Epoch 1:  62%|████████████████████████████████████▎                     | 20/32 [00:00<00:00, 41.04it/s, loss=0.693, v_num=1][A
Epoch 1:  66%|██████████████████████████████████████                    | 21/32 [00:00<00:00, 42.50it/s, loss=0.693, v_num=1][A
Epoch 1:  69%|███████████████████████████████████████▉                  | 22/32 [00:00<00:00, 43.

Epoch 3:   6%|███▋                                                       | 2/32 [00:00<00:00, 34.06it/s, loss=0.693, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 3:   9%|█████▌                                                     | 3/32 [00:00<00:00, 34.04it/s, loss=0.693, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 3:  12%|███████▍                                                   | 4/32 [00:00<00:00, 34.08it/s, loss=0.693, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 3:  16%|█████████▏                                                 | 5/32 [00:00<00:00, 34.11it/s, loss=0.693, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 3:  19%|███████████                                                | 6/32 [00:00<00:00, 34.16it/s, loss=0.693, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 3:  22%|████████████▉                                              | 7/32 [00:00<00:00, 34.12it/s, loss=0.693, v_num=1]torch.Size([64, 16000, 1])
tor

Epoch 4:  53%|██████████████████████████████▊                           | 17/32 [00:00<00:00, 36.37it/s, loss=0.694, v_num=1][A
Epoch 4:  56%|████████████████████████████████▋                         | 18/32 [00:00<00:00, 37.93it/s, loss=0.694, v_num=1][A
Epoch 4:  59%|██████████████████████████████████▍                       | 19/32 [00:00<00:00, 39.47it/s, loss=0.694, v_num=1][A
Epoch 4:  62%|████████████████████████████████████▎                     | 20/32 [00:00<00:00, 40.96it/s, loss=0.694, v_num=1][A
Epoch 4:  66%|██████████████████████████████████████                    | 21/32 [00:00<00:00, 42.40it/s, loss=0.694, v_num=1][A
Epoch 4:  69%|███████████████████████████████████████▉                  | 22/32 [00:00<00:00, 43.79it/s, loss=0.694, v_num=1][A
Epoch 4:  72%|█████████████████████████████████████████▋                | 23/32 [00:00<00:00, 45.14it/s, loss=0.694, v_num=1][A
Epoch 4:  75%|███████████████████████████████████████████▌              | 24/32 [00:00<00:00, 46.

Epoch 6:  12%|███████▍                                                   | 4/32 [00:00<00:00, 34.18it/s, loss=0.693, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 6:  16%|█████████▏                                                 | 5/32 [00:00<00:00, 34.19it/s, loss=0.694, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 6:  19%|███████████                                                | 6/32 [00:00<00:00, 34.24it/s, loss=0.694, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 6:  22%|████████████▉                                              | 7/32 [00:00<00:00, 34.30it/s, loss=0.693, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 6:  25%|██████████████▊                                            | 8/32 [00:00<00:00, 34.35it/s, loss=0.693, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 6:  28%|████████████████▌                                          | 9/32 [00:00<00:00, 34.37it/s, loss=0.694, v_num=1]torch.Size([64, 16000, 1])
tor

Epoch 7:  62%|████████████████████████████████████▎                     | 20/32 [00:00<00:00, 40.83it/s, loss=0.693, v_num=1][A
Epoch 7:  66%|██████████████████████████████████████                    | 21/32 [00:00<00:00, 42.27it/s, loss=0.693, v_num=1][A
Epoch 7:  69%|███████████████████████████████████████▉                  | 22/32 [00:00<00:00, 43.64it/s, loss=0.693, v_num=1][A
Epoch 7:  72%|█████████████████████████████████████████▋                | 23/32 [00:00<00:00, 45.00it/s, loss=0.693, v_num=1][A
Epoch 7:  75%|███████████████████████████████████████████▌              | 24/32 [00:00<00:00, 46.35it/s, loss=0.693, v_num=1][A
Epoch 7:  78%|█████████████████████████████████████████████▎            | 25/32 [00:00<00:00, 47.65it/s, loss=0.693, v_num=1][A
Epoch 7:  81%|███████████████████████████████████████████████▏          | 26/32 [00:00<00:00, 48.90it/s, loss=0.693, v_num=1][A
Epoch 7:  84%|████████████████████████████████████████████████▉         | 27/32 [00:00<00:00, 50.

Epoch 9:  19%|███████████                                                | 6/32 [00:00<00:00, 34.07it/s, loss=0.694, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 9:  22%|████████████▉                                              | 7/32 [00:00<00:00, 34.04it/s, loss=0.694, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 9:  25%|██████████████▊                                            | 8/32 [00:00<00:00, 34.10it/s, loss=0.694, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 9:  28%|████████████████▌                                          | 9/32 [00:00<00:00, 34.14it/s, loss=0.694, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 9:  31%|██████████████████▏                                       | 10/32 [00:00<00:00, 34.18it/s, loss=0.694, v_num=1]torch.Size([64, 16000, 1])
torch.Size([64])
Epoch 9:  34%|███████████████████▉                                      | 11/32 [00:00<00:00, 34.20it/s, loss=0.694, v_num=1]torch.Size([64, 16000, 1])
tor

In [None]:
x = torch.randn((16, 16000, 1))
out = model(x)