In [1]:
import torch
import torchaudio
import numpy as np
from types import SimpleNamespace
from datasets import load_dataset, Audio
from fastprogress import master_bar, progress_bar
from torchvision.transforms.v2 import RandomCrop

In [2]:
# Global numeric / device setup for the rest of the notebook.
torch.set_float32_matmul_precision('high')
device = "cuda:1"

# Load the train split and keep the 'opus' column as raw encoded bytes
# (decode=False); decoding is done later inside the collate function.
audioset = load_dataset("danjacobellis/audioset_opus_24kbps", split='train')
audioset = audioset.cast_column('opus', Audio(decode=False))

# First 1,000,000 rows are used for training, the next 40,000 for validation.
dataset_train = audioset.select(range(0, 1_000_000))
dataset_valid = audioset.select(range(1_000_000, 1_040_000))

Resolving data files:   0%|          | 0/96 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/96 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/117 [00:00<?, ?it/s]

In [3]:
# All tunable hyper-parameters live on a single namespace object.
config = SimpleNamespace(
    # Training and optimizer config
    batch_size=128,
    grad_accum_steps=1,
    plot_update=64,
    epochs=200,
    lr_scale=40,
    lr_offset=0.25,
    lr_pow=6,
    weight_decay=0.,
    num_workers=24,
    audio_len=483840,
    crop_size=512*512,   # number of samples produced by the random crop
    # model config
    channels=1,
    J=10,                # passed as J to the wavelet packet transform
    embed_dim=256,
    dim_head=32,
    exp_ratio=4.0,
    classifier_num_classes=632,
)

# Derived values: these depend on the batch size and on the training split.
config.steps_per_epoch = dataset_train.num_rows//config.batch_size
config.max_lr = (config.batch_size/128)*6e-4   # scaled linearly with batch size
config.min_lr = config.max_lr/100

In [4]:
import torch
import einops
from timm.models.maxxvit import LayerScale

class RMSNormAct(torch.nn.Module):
    def __init__(self, normalized_features):
        super(RMSNormAct, self).__init__()
        self.norm = torch.nn.RMSNorm(normalized_features)
        self.act = torch.nn.GELU()

    def forward(self, x):
        x = self.norm(x)
        x = self.act(x)
        return x

class InvertedResidual1D(torch.nn.Module):
    """1D inverted-residual convolution block operating on (batch, channel, sequence).

    Pipeline: pointwise expansion conv -> RMSNorm+GELU -> depthwise 3-tap conv
    -> RMSNorm+GELU -> pointwise projection conv -> RMSNorm -> layer scale,
    with a residual connection when the input channel count equals ``out_dim``.
    ``se``, ``dw_end`` and ``drop_path`` are identity placeholders.

    Every norm is built with shape ``(channels, sequence_dim)``, so the block
    only accepts inputs whose sequence length equals ``sequence_dim``.

    Args:
        in_dim: Number of input channels.
        out_dim: Number of output channels.
        sequence_dim: Fixed sequence length of the input.
        exp_ratio: Channel expansion factor; the hidden width is
            ``int(in_dim * exp_ratio)``.
    """

    def __init__(self, in_dim, out_dim, sequence_dim, exp_ratio):
        super().__init__()
        self.exp_dim = int(in_dim * exp_ratio)
        self.pw_exp = torch.nn.Sequential(
            torch.nn.Conv1d(in_dim, self.exp_dim, kernel_size=1, stride=1, bias=False),
            RMSNormAct((self.exp_dim, sequence_dim))
        )
        self.dw_mid = torch.nn.Sequential(
            # groups == channels makes this a depthwise convolution
            torch.nn.Conv1d(self.exp_dim, self.exp_dim, kernel_size=3, stride=1, padding=1, groups=self.exp_dim, bias=False),
            RMSNormAct((self.exp_dim, sequence_dim))
        )
        self.se = torch.nn.Identity()
        self.pw_proj = torch.nn.Sequential(
            torch.nn.Conv1d(self.exp_dim, out_dim, kernel_size=1, stride=1, bias=False),
            torch.nn.RMSNorm((out_dim, sequence_dim))
        )
        self.dw_end = torch.nn.Identity()
        self.layer_scale = LayerScale(out_dim)
        self.drop_path = torch.nn.Identity()

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, in_dim, sequence_dim).

        Returns:
            Tensor of shape (batch, out_dim, sequence_dim).
        """
        # The residual is only usable when input and output widths agree.
        shortcut = x if x.shape[1] == self.pw_proj[0].out_channels else None
        x = self.pw_exp(x)
        x = self.dw_mid(x)
        x = self.se(x)
        x = self.pw_proj(x)
        x = self.dw_end(x)
        # LayerScale holds one gain per channel (constructed with out_dim).
        # NOTE(review): timm's LayerScale multiplies by a (dim,)-shaped gamma,
        # which broadcasts against the LAST axis -- confirm against the
        # installed timm version. On a channels-first (B, C, S) tensor that
        # would scale sequence positions instead of channels (and only runs
        # at all when S == C), so apply it channels-last and move back.
        x = self.layer_scale(x.transpose(1, 2)).transpose(1, 2)
        x = self.drop_path(x)
        if shortcut is not None:
            # Out-of-place add: avoids mutating a tensor autograd may still need.
            x = x + shortcut
        return x

class TransformerBlock1D(torch.nn.Module):
    """Transformer encoder layer wrapped for channels-first 1D feature maps.

    Args:
        embed_dim: Channel count of the input (the encoder's ``d_model``).
        dim_feedforward: Hidden width of the encoder's feed-forward network.
        nhead: Number of attention heads.
    """

    def __init__(self, embed_dim, dim_feedforward, nhead):
        super().__init__()
        encoder_kwargs = dict(
            d_model=embed_dim,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=0.0,
            activation='gelu',
            batch_first=True,
        )
        self.layer = torch.nn.TransformerEncoderLayer(**encoder_kwargs)

    def forward(self, x):
        """Run the encoder layer on ``x`` of shape (batch, channel, sequence).

        The encoder is batch-first and expects features last, so the channel
        and sequence axes are swapped on the way in and restored on the way out.
        """
        tokens = einops.rearrange(x, 'b c s -> b s c')
        tokens = self.layer(tokens)
        return einops.rearrange(tokens, 'b s c -> b c s')

class AsCAN1D(torch.nn.Module):
    """Backbone mixing convolutional (C) and transformer (T) blocks.

    A pointwise-conv stem maps ``input_dim`` channels to ``embed_dim``; it is
    followed by twelve blocks in the order ``CCCT CCTT CTTT``. All blocks keep
    the (batch, embed_dim, sequence_dim) shape.

    Args:
        input_dim: Number of input channels.
        embed_dim: Channel width used throughout the backbone.
        sequence_dim: Fixed sequence length of the input.
        dim_head: Channels per attention head; head count is
            ``embed_dim // dim_head``.
        exp_ratio: Expansion factor for both the conv blocks and the
            transformer feed-forward width.
    """

    # One character per block, in execution order.
    _BLOCK_PATTERN = 'CCCT' 'CCTT' 'CTTT'

    def __init__(self, input_dim, embed_dim, sequence_dim, dim_head, exp_ratio):
        super().__init__()

        def conv_block():
            return InvertedResidual1D(embed_dim, embed_dim, sequence_dim, exp_ratio)

        def attn_block():
            return TransformerBlock1D(embed_dim, int(exp_ratio*embed_dim), embed_dim//dim_head)

        make = {'C': conv_block, 'T': attn_block}
        # Modules are instantiated strictly in execution order (stem first).
        stem = [
            torch.nn.Conv1d(input_dim, embed_dim, kernel_size=1),
            RMSNormAct((embed_dim, sequence_dim)),
        ]
        body = [make[kind]() for kind in self._BLOCK_PATTERN]
        self.layers = torch.nn.Sequential(*stem, *body)

    def forward(self, x):
        """Apply the stem and all blocks to ``x`` (batch, input_dim, sequence_dim)."""
        return self.layers(x)

class WaveletPooling1D(torch.nn.Module):
    def __init__(self, embed_dim, sequence_dim, wpt, num_levels):
        super().__init__()
        self.wpt = wpt
        self.num_levels = num_levels
        current_sequence_dim = sequence_dim
        self.projection_down = torch.nn.ModuleList()
        for _ in range(num_levels):
            self.projection_down.append(
                torch.nn.Sequential(
                    torch.nn.Conv1d(embed_dim, embed_dim // 2, kernel_size=1, padding=0),
                    torch.nn.RMSNorm((embed_dim // 2, current_sequence_dim))
                )
            )
            current_sequence_dim //= 2
    def forward(self, x):
        for i in range(self.num_levels):
            x = self.projection_down[i](x)
            x = self.wpt.analysis_one_level(x)
        return x


class TFTClassifier1D(torch.nn.Module):
    """Audio classifier: wavelet packet transform -> AsCAN backbone -> wavelet pooling -> 1x1 conv head.

    Args:
        config: Namespace providing ``channels``, ``J``, ``crop_size``,
            ``embed_dim``, ``dim_head``, ``exp_ratio`` and
            ``classifier_num_classes``.
        wpt: Callable transform applied to the raw input; it is also handed
            to the pooling stage, which uses its ``analysis_one_level`` method.
    """

    def __init__(self, config, wpt):
        super().__init__()
        self.wpt = wpt

        n_subbands = 2**config.J
        # Sequence length used to size the backbone and the pooling stage.
        tokens = config.crop_size//n_subbands

        self.ascan = AsCAN1D(
            input_dim=config.channels*n_subbands,
            embed_dim=config.embed_dim,
            sequence_dim=tokens,
            dim_head=config.dim_head,
            exp_ratio=config.exp_ratio
        )
        self.pool = WaveletPooling1D(
            embed_dim=config.embed_dim,
            sequence_dim=tokens,
            wpt=wpt,
            # log2(crop_size) - J levels: one level per halving of `tokens` down to 1
            num_levels=int(np.log2(config.crop_size) - config.J)
        )
        self.classifier = torch.nn.Sequential(
            torch.nn.Conv1d(config.embed_dim, config.classifier_num_classes, kernel_size=1),
            torch.nn.Flatten()
        )

    def forward(self, x):
        """Return the flattened classifier output (logits) for a batch ``x``."""
        coeffs = self.wpt(x)
        features = self.ascan(coeffs)
        pooled = self.pool(features)
        return self.classifier(pooled)

In [5]:
from pytorch_wavelets import DWT1DForward
from tft.transforms import WPT1D

# Single-level DWT ('bior4.4', periodization mode) used as the building block
# of the J-level wavelet packet transform.
wt = DWT1DForward(wave='bior4.4', mode='periodization', J=1)
wpt = WPT1D(wt, J=config.J).to(device)

model = TFTClassifier1D(config, wpt).to(device)

# Report the parameter count (in millions) of each top-level submodule.
for name, module in model.named_children():
    n_params = sum(p.numel() for p in module.parameters())
    print(f"{n_params/1e6} \t {name}")

0.0 	 wpt
11.771136 	 ascan
0.328448 	 pool
0.162424 	 classifier


In [6]:
# Random crop to a (1, crop_size) window; shorter inputs are zero-padded.
rand_crop = RandomCrop(size=(1, config.crop_size), pad_if_needed=True, fill=0)

def train_collate_fn(batch):
    """Decode, randomly crop and stack a list of dataset samples.

    Args:
        batch: Sequence of samples, each with a ``'label'`` entry and encoded
            audio bytes under ``sample['opus']['bytes']``.

    Returns:
        Tuple ``(x, y)`` where ``x`` is a float tensor of shape
        (len(batch), config.channels, config.crop_size) and ``y`` is the list
        of each sample's ``'label'`` value, in batch order.
    """
    labels = [sample['label'] for sample in batch]
    # Extra singleton axis so the crop's (1, crop_size) size applies to the
    # last two dimensions; it is dropped again on return.
    x = torch.zeros((len(batch), config.channels, 1, config.crop_size), dtype=torch.float)
    for i_sample, sample in enumerate(batch):
        waveform, _ = torchaudio.load(uri=sample['opus']['bytes'], normalize=False)
        x[i_sample] = rand_crop(waveform[None, None])
    return x[:, :, 0, :], labels

def create_multi_hot_labels(batch_of_label_indices, num_classes):
    """Convert per-sample lists of class indices into a multi-hot matrix.

    Args:
        batch_of_label_indices: Sequence with one list of integer class
            indices per sample (a list may be empty).
        num_classes: Width of the output, i.e. total number of classes.

    Returns:
        float32 tensor of shape (len(batch_of_label_indices), num_classes)
        with 1.0 at every listed class index and 0.0 elsewhere.
    """
    n_samples = len(batch_of_label_indices)
    multi_hot = torch.zeros(n_samples, num_classes, dtype=torch.float32)
    # Iterating a tensor yields row views, so the in-place scatter writes
    # straight into `multi_hot`.
    for row, class_ids in zip(multi_hot, batch_of_label_indices):
        row.scatter_(0, torch.tensor(class_ids, dtype=torch.long), 1.0)
    return multi_hot

In [7]:
# Multi-label objective: one independent sigmoid/BCE term per class.
# Built once rather than on every batch.
criterion = torch.nn.BCEWithLogitsLoss()

mb = master_bar(range(config.epochs))
mb.names = ['per batch','smoothed']
for i_epoch in mb:
    # A fresh loader per epoch; shuffle=True reshuffles the training split.
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        drop_last=True,
        pin_memory=True,
        collate_fn=train_collate_fn
    )
    pb = progress_bar(dataloader_train, parent=mb)
    for i_batch, (x,y) in enumerate(pb):
        # The collate function returns raw label lists; convert to multi-hot targets.
        y = create_multi_hot_labels(y, config.classifier_num_classes).to(device)
        x = x.to(device)
        logits = model(x)
        # BCEWithLogitsLoss takes (input=logits, target=labels). The arguments
        # were previously swapped, which treated the 0/1 labels as logits and
        # the raw logits as targets.
        loss = criterion(logits, y)
        # Deliberate debugging halt after the first batch: this loop has no
        # backward pass or optimizer step yet.
        assert 1==0

AssertionError: 