In [None]:
# Print the nvidia-smi report to check which GPU (if any) this runtime exposes.
!nvidia-smi

In [None]:
import os
if not os.path.isfile('master.zip'):
    !wget -q https://github.com/karoldvl/ESC-50/archive/master.zip
    !unzip -qo master.zip
    !pip install -qq wandb encodec
    !git clone -q https://github.com/davda54/sam
    !cp sam/sam.py sam.py
os.environ['WANDB_API_KEY'] = ''
import wandb
wandb.login()

In [None]:
import torch
from torch import nn
import torchaudio
from encodec import EncodecModel
import pandas as pd
from sklearn import preprocessing
from tqdm.auto import tqdm
import torch.nn.functional as F
from sam import SAM

batch_size = 16*4
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch so loader shuffling, dropout and weight init are reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# ESC-50 fold used as the held-out evaluation split; all other folds are used for training.
TEST_FOLD = 5

data = pd.read_csv('ESC-50-master/meta/esc50.csv')
data_dir = 'ESC-50-master/audio'

# Fit the label encoder on every category so train and test share one class index space.
le = preprocessing.LabelEncoder()
le.fit(data.category)

is_test = data.fold == TEST_FOLD
train_rows, test_rows = data[~is_test], data[is_test]

X_train = [os.path.join(data_dir, name) for name in train_rows.filename.values]
y_train = le.transform(train_rows.category)
X_test = [os.path.join(data_dir, name) for name in test_rows.filename.values]
y_test = le.transform(test_rows.category)

assert len(X_train) > 0 and len(X_test) > 0, f"empty split for TEST_FOLD={TEST_FOLD}"

In [None]:
class ESC50Dataset(torch.utils.data.Dataset):
    """In-memory dataset over precomputed EnCodec features.

    Args:
        files: a 2-item pair ``(encoded, codes)``, e.g. the result of ``get_encoder_output``.
            Despite its name this holds encoder outputs, not file paths; the first item is
            stored as ``self.files`` and the second as ``self.codes``.
        labels: per-item labels, indexed in parallel with the two feature sequences.
    """

    def __init__(self, files, labels):
        encoded, codes = files
        self.files = encoded
        self.codes = codes
        self.labels = labels

    def load_wave(self, file):
        """Load ``file`` and return its first channel with a leading axis of size 1."""
        waveform, _sample_rate = torchaudio.load(file)
        return waveform[0].unsqueeze(0)

    def __getitem__(self, index):
        # (encoder features, quantizer codes, label) for one item.
        features = self.files[index]
        code_ids = self.codes[index]
        target = self.labels[index]
        return features, code_ids, target

    def __len__(self):
        return len(self.files)


# Pretrained 24 kHz EnCodec model, used only to precompute features for the loaders below.
encoder_model = EncodecModel.encodec_model_24khz()
encoder_model = encoder_model.to(device)
encoder_model.set_target_bandwidth(1.5)
# Convert the 44.1 kHz input audio to the encoder's own sample rate.
resampler = torchaudio.transforms.Resample(44100, encoder_model.sample_rate)

@torch.no_grad()
def get_encoder_output(files, resampler, model):
    """Encode audio files with an EnCodec model and quantize the result.

    Args:
        files: iterable of audio file paths.
        resampler: callable applied to each loaded waveform before encoding (here a
            ``torchaudio.transforms.Resample``). If it exposes ``orig_freq``, each file's
            sample rate is checked against it.
        model: EnCodec model whose ``encoder``, ``quantizer``, ``frame_rate`` and
            ``bandwidth`` are used.

    Returns:
        ``(encoded, codes)`` on CPU. ``encoded`` concatenates the per-file encoder outputs
        along dim 0. ``codes`` is the quantizer output with dims 0 and 1 swapped, so dim 0
        indexes files (assumed layout of ``quantizer.encode`` is (n_q, batch, T) — TODO confirm).

    Raises:
        ValueError: if a file's sample rate differs from ``resampler.orig_freq``.
    """
    expected_rate = getattr(resampler, 'orig_freq', None)

    def encode_file(path):
        waveform, sample_rate = torchaudio.load(path)
        # A mismatched rate would be resampled as if it were `expected_rate`, silently
        # distorting the audio.
        if expected_rate is not None and sample_rate != expected_rate:
            raise ValueError(f'{path}: expected {expected_rate} Hz audio, got {sample_rate} Hz')
        return model.encoder(resampler(waveform).unsqueeze(0).to(device)).cpu()

    encoded = torch.cat([encode_file(path) for path in tqdm(files)])
    # Use the `model` argument, not the global `encoder_model`: the notebook sets that global
    # to None after building the loaders, so the old reference broke any later call.
    codes = model.quantizer.encode(encoded.to(device), model.frame_rate, model.bandwidth).transpose(0, 1).cpu()
    return encoded, codes

# Precompute EnCodec features once (train first, then test); the loaders then serve
# in-memory tensors.
train_loader = torch.utils.data.DataLoader(
    dataset=ESC50Dataset(get_encoder_output(X_train, resampler, encoder_model), y_train),
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=ESC50Dataset(get_encoder_output(X_test, resampler, encoder_model), y_test),
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    drop_last=False,
)
# Drop the notebook's reference to the encoder; only its precomputed outputs are used below.
encoder_model = None

In [None]:
# Peek at one training batch: print the feature and code shapes, and read the number of
# quantizer levels (dim 1 of the codes) that the code-prediction head must cover.
batch = next(iter(train_loader), None)
if batch is None:
    # Without this, an empty loader (e.g. batch_size larger than the training set with
    # drop_last=True) left num_vq_encodings undefined and failed much later.
    raise RuntimeError('train_loader produced no batches; check batch_size against the training set size.')
print(batch[0].shape)
print(batch[1].shape)
num_vq_encodings = batch[1].shape[1]

In [None]:
# https://github.com/pytorch/pytorch/issues/1333
class CausalConv1d(torch.nn.Conv1d):
    """1-D convolution that pads only on the left, so output step t depends on inputs <= t.

    The built-in padding is disabled. Instead, ``forward`` prepends
    ``(kernel_size - 1) * dilation`` zeros along the last axis before convolving.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 dilation=1, groups=1, bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        left_context = (kernel_size - 1) * dilation
        self.__padding = left_context

    def forward(self, input):
        padded = F.pad(input, (self.__padding, 0))
        return super().forward(padded)

In [None]:
def sum_params(model):
    """Return the total number of scalar entries across parameters with requires_grad=True."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

class _ChannelLayerNorm(nn.LayerNorm):
    """LayerNorm over the channel axis of a (batch, channels, time) tensor, applied per time step.

    ``GroupNorm(1, channels)`` pools its mean/variance over every channel AND every time step,
    so each step's output depends on statistics of later steps. That breaks causality.
    Normalizing each step independently avoids the leak. The parameters are still named
    ``weight``/``bias`` with shape ``(channels,)``.
    """

    def forward(self, x):
        return super().forward(x.transpose(1, 2)).transpose(1, 2)


class ConvNextBlock(nn.Module):
    """Residual ConvNeXt-style block over (batch, channels, time) tensors.

    Layers, in order: causal depthwise conv → per-step channel LayerNorm → pointwise
    expansion ×4 → GELU → pointwise projection → element-wise dropout → channel dropout.
    The result is added back to the input.

    Args:
        channels: number of input/output channels.
        kernel: depthwise kernel size (time steps of left context minus one).
        dropout: rate for both dropout layers.
    """

    def __init__(self, channels, kernel=49, dropout=0.0):
        super().__init__()
        self.conv = nn.Sequential(
            CausalConv1d(channels, channels, kernel_size=kernel, groups=channels),
            # Per-time-step norm keeps the block causal (GroupNorm(1, C) mixed in future steps).
            _ChannelLayerNorm(channels),
            nn.Conv1d(channels, channels*4, 1),
            nn.GELU(),
            nn.Conv1d(channels*4, channels, 1),
            nn.Dropout(dropout),
            nn.Dropout1d(dropout))

    def forward(self, x):
        shortcut = x
        x = self.conv(x)
        return shortcut + x

class ConvNextlassifier(nn.Module):
    """Stack of ConvNextBlocks with a clip-classification head and a code-prediction head.

    The input ``x`` (batch, channels, T) is left-padded by one zero step, so the stack
    produces T + 1 positions.
    - Positions 0..T-1 feed ``self.codes``. Position t is aligned with code t, and its causal
      convolutions only reach input steps before t.
    - The last position feeds ``self.out``.

    Args:
        n_out: number of output classes.
        num_vq_encodings: number of quantizer levels whose codes are predicted.
        num: number of ConvNextBlocks.
        channels: feature width; must equal the input's channel dimension.
        dropout: dropout rate inside each block.
        codebook_size: entries per quantizer codebook (default 1024, the original hard-coded value).

    ``forward`` returns ``(class_logits, code_logits)``:
    - ``class_logits`` has shape (batch, n_out).
    - ``code_logits`` has shape (batch, codebook_size * num_vq_encodings, T).
    """

    def __init__(self, n_out, num_vq_encodings, num=8, channels=128, dropout=0.0, codebook_size=1024):
        super().__init__()
        self.conv = nn.Sequential(*[ConvNextBlock(channels, dropout=dropout) for _ in range(num)])
        self.out = nn.Conv1d(channels, n_out, 1, 1)
        self.codes = nn.Conv1d(channels, codebook_size*num_vq_encodings, 1, 1)

    def forward(self, x):
        x = F.pad(x, (1, 0))
        x = self.conv(x)
        x_codes = self.codes(x[:, :, :-1])
        # Squeeze only the length-1 time axis. The previous second squeeze(-1) also dropped
        # the class axis whenever n_out == 1.
        x_out = self.out(x[:, :, -1:]).squeeze(-1)
        return x_out, x_codes


# Correctly spelled alias; the original name is kept for existing callers.
ConvNextClassifier = ConvNextlassifier

# One output per label-encoder class; the code head covers every quantizer level seen in the data.
model = ConvNextlassifier(
    n_out=len(le.classes_),
    num_vq_encodings=num_vq_encodings,
    dropout=0.05,
)
model = model.to(device)
# SAM wrapping AdamW: `train` performs first_step/second_step with two forward/backward passes per batch.
optimizer = SAM(model.parameters(), torch.optim.AdamW, lr=1e-4)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

In [None]:
# Trainable parameter count, then the module tree via the cell's rich display.
print(f"{sum_params(model)}")
model

In [None]:
def _code_pairs(codes_pred, codes):
    """Pair each quantizer level's logits with that level's target codes.

    ``codes_pred`` is split along dim 1 into ``num_vq_encodings`` equal chunks. ``codes`` is
    split along dim 1 into single-level slices, with that axis squeezed away.
    """
    targets = (code.squeeze(1) for code in codes.chunk(num_vq_encodings, dim=1))
    return zip(codes_pred.chunk(num_vq_encodings, dim=1), targets)


def _codes_loss(criterion, codes_pred, codes):
    """Sum of ``criterion`` over all quantizer levels."""
    return sum(criterion(code_pred, code) for code_pred, code in _code_pairs(codes_pred, codes))


def _codes_accuracy(codes_pred, codes):
    """Mean over quantizer levels of per-level code accuracy, as a Python float."""
    per_level = [(code_pred.argmax(1) == code).float().mean() for code_pred, code in _code_pairs(codes_pred, codes)]
    return torch.stack(per_level).mean().item()


def train(model, optimizer, criterion, iterator):
    """Run one training epoch with the SAM two-pass update.

    Each batch is optimized on the classification loss plus the summed code-prediction loss.
    The per-step loss and code loss are logged to wandb and shown on a progress bar.

    Returns:
        (mean classification loss, mean classification accuracy, mean code accuracy)
        averaged over batches; all zeros if ``iterator`` yields no batches.
    """
    model.train()
    epoch_loss = 0.0
    epoch_acc = 0.0
    epoch_acc_codes = 0.0

    # The context manager closes the bar at the end of the epoch. Before, it was never
    # closed, so leave=False had no effect and one stale bar piled up per epoch.
    with tqdm(total=len(iterator), leave=False) as progress_bar:
        for x, codes, y in iterator:
            x = x.to(device)
            codes = codes.to(device)
            y = y.to(device)

            # First SAM pass: gradients at the current weights, then optimizer.first_step.
            y_pred, codes_pred = model(x)
            loss = criterion(y_pred, y)
            loss_codes = _codes_loss(criterion, codes_pred, codes)
            (loss + loss_codes).backward()
            optimizer.first_step(zero_grad=True)

            # Metrics come from the first-pass predictions.
            epoch_loss += loss.item()
            epoch_acc += (y_pred.argmax(-1) == y).float().mean().item()
            epoch_acc_codes += _codes_accuracy(codes_pred, codes)
            logs = {'loss': loss.item(),
                    'loss_codes': loss_codes.item()}

            # Second SAM pass: recompute the same objective, then optimizer.second_step.
            y_pred, codes_pred = model(x)
            (_codes_loss(criterion, codes_pred, codes) + criterion(y_pred, y)).backward()
            optimizer.second_step(zero_grad=True)

            progress_bar.update(1)
            progress_bar.set_postfix(**logs)
            wandb.log(logs)

    # Guard against ZeroDivisionError on an empty loader.
    n_batches = max(len(iterator), 1)
    return epoch_loss / n_batches, epoch_acc / n_batches, epoch_acc_codes / n_batches


def evaluate(model, criterion, iterator):
    """Evaluate without gradient tracking.

    Returns:
        (mean classification loss, mean code loss, mean classification accuracy,
        mean code accuracy) averaged over batches; all zeros if ``iterator`` is empty.
    """
    model.eval()
    epoch_loss = 0.0
    epoch_acc = 0.0
    epoch_loss_codes = 0.0
    epoch_acc_codes = 0.0

    with torch.no_grad():
        for x, codes, y in iterator:
            x = x.to(device)
            codes = codes.to(device)
            y = y.to(device)
            y_pred, codes_pred = model(x)
            loss = criterion(y_pred, y)
            loss_codes = _codes_loss(criterion, codes_pred, codes)
            epoch_loss += loss.item()
            epoch_loss_codes += loss_codes.item()
            epoch_acc += (y_pred.argmax(-1) == y).float().mean().item()
            epoch_acc_codes += _codes_accuracy(codes_pred, codes)

    n_batches = max(len(iterator), 1)
    return epoch_loss / n_batches, epoch_loss_codes / n_batches, epoch_acc / n_batches, epoch_acc_codes / n_batches

In [None]:
wandb.init(project='esc50', entity='had', name='on encoder with self supervision ConvNextlassifier 128 8 adamw lr 1e-4 bigger bs bandwidth 1.5 dropout 0.05 sam')
wandb.watch(model)
N_EPOCHS = 10000

best_valid_loss = float('inf')
best_state = None  # CPU copy of the weights with the lowest eval loss so far

# The loop is effectively stopped by hand (10000 epochs). `finally` guarantees the best
# loss is reported and the wandb run is closed even on KeyboardInterrupt.
try:
    epoch_bar = tqdm(range(N_EPOCHS))
    for _ in epoch_bar:
        train_loss, train_acc, train_acc_codes = train(model, optimizer, criterion, train_loader)
        valid_loss, valid_loss_codes, valid_acc, valid_acc_codes = evaluate(model, criterion, test_loader)
        # Progress on the bar instead of printing raw tuples every epoch.
        epoch_bar.set_postfix(train_loss=train_loss, eval_loss=valid_loss, eval_acc=valid_acc)
        wandb.log({"eval_loss": valid_loss, "eval_codes": valid_loss_codes, "eval_acc": valid_acc,
                   "eval_acc_codes": valid_acc_codes, "train_loss": train_loss,
                   "train_acc": train_acc, "train_acc_codes": train_acc_codes})
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
finally:
    print(best_valid_loss)
    wandb.finish()