In [1]:
import numpy as np
import pandas as pd
import random
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn.utils import weight_norm
import torch.optim as optim
from torchinfo import summary
from tqdm import tqdm
from glob import glob

In [2]:
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.has_mps else "cpu")
epochs = 32
batch_size = 16

print(device)

cuda


In [22]:
class CustomDataset(Dataset):
    def __init__(self, path, is_train=True):
        self.path = path
        self.is_train = is_train
        self.mel_path = path + "mel"
        self.wav_path = path + "audio"
        self.mel_list = glob(self.mel_path + "/*.npy")
        self.wav_list = glob(self.wav_path + "/*.npy")
        self.hop_length = 256
        self.seq_len = 32

        if self.is_train:
            self.df = pd.read_csv("data/train_data.csv")
            self.y = self.df["covid19"]
        else:
            self.df = pd.read_csv("data/test_data.csv")

        self.mel_files = glob(self.mel_path + "/*.npy")
        self.wav_files = glob(self.wav_path + "/*.wav")

    def __len__(self):
        return len(self.wav_list)

    def random_select(self):
        self.random_covid = bool(random.randint(0, 1))

        if self.random_covid: self.data_list = self.df[self.df["covid19"] == 1].index.tolist()
        else: self.data_list = self.df[self.df["covid19"] == 0].index.tolist()

        self.index_list = self.data_list[random.randint(0, len(self.data_list)-1)]

        return self.index_list

    def __getitem__(self, idx):
        if self.is_train: index = self.random_select()
        else: index = idx

        mel = np.load(self.mel_list[index])
        mel = torch.from_numpy(mel).float()
        start = random.randint(0, mel.size(1) - self.seq_len - 1)
        wav = np.load(self.wav_list[index])
        wav = torch.from_numpy(wav).float()
        start *= self.hop_length
        wav = wav[start : start + self.seq_len * self.hop_length]

        if self.is_train:
            try:
                y = self.y[index]
            except:
                y = 0
            return wav.unsqueeze(0), y
        else:
            return wav.unsqueeze(0)

In [23]:
def decoder_sequential(input_size, output_size, *args, **kwargs):
    return nn.Sequential(
        weight_norm((nn.Conv1d(input_size, output_size, *args, **kwargs))),
        nn.LeakyReLU(0.2, inplace=True)
    )

class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()

        self.discriminator = nn.ModuleList([
            # Feature map x 1
            nn.Sequential(
                nn.ReflectionPad1d(7), # 7+1+7 = 15
                weight_norm(nn.Conv1d(1, 16, kernel_size=15)),
                nn.LeakyReLU(0.2, inplace=True) # modify the input
            ),
            # Downsampling layer Feature map x 4
            decoder_sequential(16, 64, kernel_size=41, stride=4, padding=20, groups=4),
            decoder_sequential(64, 256, kernel_size=41, stride=4, padding=20, groups=16),
            decoder_sequential(256, 1024, kernel_size=41, stride=4, padding=20, groups=64),
            decoder_sequential(1024, 1024, kernel_size=41, stride=4, padding=20, groups=256),
            # Feature map x 1
            nn.Sequential(
                weight_norm(nn.Conv1d(1024, 1024, kernel_size=5, padding=2)),
                nn.LeakyReLU(0.2, inplace=True)
            ),
            # Output x 1
            weight_norm(nn.Conv1d(1024, 1, kernel_size=3, padding=1))
        ])

    def forward(self, x):
        feature_map = []
        for module in self.discriminator:
            x = module(x)
            feature_map.append(x)
        return feature_map

class MultiScale(nn.Module):
    def __init__(self):
        super().__init__()

        self.block = nn.ModuleList([
            Discriminator() for _ in range(3)
        ])

        self.avgpool = nn.AvgPool1d(kernel_size=4, stride=2, padding=1, count_include_pad=False)

    def forward(self, x):
        result = []
        for idx, module in enumerate(self.block):
            result.append(module(x))
            if idx <= 1:
                x = self.avgpool(x)
        return result

In [24]:
class Disc(nn.Module):
    def __init__(self):
        super().__init__()

        self.out_layer = nn.Sequential(
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(8192, 4196),
            nn.Conv1d(1, 1, kernel_size=1),
            nn.Linear(4196, 1024),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv1d(1, 1, kernel_size=1),
            nn.Linear(512, 64),
            nn.Linear(64, 1),
            nn.Flatten(),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Sigmoid()
        )

    def forward(self, x):
        # out = self.in_layer(x)
        # print(out.shape)
        return self.out_layer(x)

In [25]:
train_dataset = CustomDataset(path="./data/train/")
test_dataset = CustomDataset(path="./data/test/", is_train=False)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader= DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [7]:
checkpoint_path = glob("MelGAN-pytorch/ckpt/train/ckpt-*.pt")[-1]
checkpoint = torch.load(checkpoint_path, map_location=device)

In [8]:
multi_scale = MultiScale()
multi_scale.load_state_dict(checkpoint["D"])

<All keys matched successfully>

In [9]:
next(iter(train_dataset))[0].unsqueeze(0)

tensor([[[-2.2815e-03, -1.5467e-03,  5.2056e-04,  ..., -1.5382e-04,
          -9.4758e-05, -1.4260e-05]]])

In [10]:
for param in multi_scale.parameters():
    param.requires_grad = False

In [11]:
disc = Disc().to(device)

In [12]:
summary(disc, (1, 1, 8192))

Layer (type:depth-idx)                   Output Shape              Param #
Disc                                     [1, 1]                    --
├─Sequential: 1-1                        [1, 1]                    --
│    └─LeakyReLU: 2-1                    [1, 1, 8192]              --
│    └─Linear: 2-2                       [1, 1, 4196]              34,377,828
│    └─Conv1d: 2-3                       [1, 1, 4196]              2
│    └─Linear: 2-4                       [1, 1, 1024]              4,297,728
│    └─Linear: 2-5                       [1, 1, 512]               524,800
│    └─LeakyReLU: 2-6                    [1, 1, 512]               --
│    └─Conv1d: 2-7                       [1, 1, 512]               2
│    └─Linear: 2-8                       [1, 1, 64]                32,832
│    └─Linear: 2-9                       [1, 1, 1]                 65
│    └─Flatten: 2-10                     [1, 1]                    --
│    └─LeakyReLU: 2-11                   [1, 1]                

In [13]:
optimizer = optim.NAdam(disc.parameters(), lr=0.002)
criterion = nn.BCEWithLogitsLoss()

In [14]:
val = torch.FloatTensor(next(iter(train_dataset))[0].unsqueeze(0))

In [15]:
ls = []
for i, (x, y) in enumerate(train_loader):
    for j in x:
        ls.append(j)
    break

out = torch.stack(ls).to(device)
out.shape

torch.Size([16, 1, 8192])

In [16]:
print("Start training")
multi_scale.to(device)
disc.to(device)

try:
    disc.load_state_dict(torch.load(glob("model/eval-*.pt")[-1], map_location=device))
    print("Loaded model")
except:
    print("No model found or model checkpoing load failed")

for epoch in range(epochs):
    multi_scale.eval()
    disc.train()
    for i, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device).unsqueeze(1).type(torch.float)
        # print("x", x.shape)
        # print("y", y.shape)

        x_multiscale = [k for k in [j for j in x]]

        x_multiscale = torch.stack(x_multiscale).to(device)
        # print("x_multiscale", x_multiscale.shape)

        disc.zero_grad()
        out = disc(x_multiscale).type(torch.float)
        # print("disc out", out.shape)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        if i % 100 == 0:
            print(f"Epoch: {epoch}, Step: {i}, Loss: {loss.item()}")
    if epoch % 1 == 0:
        torch.save(disc.state_dict(), f"model/eval-{epoch}.pt")

Start training
No model found or model checkpoing load failed
Epoch: 0, Step: 0, Loss: 0.7237203121185303
Epoch: 0, Step: 100, Loss: 0.6931471824645996
Epoch: 0, Step: 200, Loss: 0.6931471824645996
Epoch: 1, Step: 0, Loss: 0.6931471824645996
Epoch: 1, Step: 100, Loss: 0.6931471824645996
Epoch: 1, Step: 200, Loss: 0.6931471824645996
Epoch: 2, Step: 0, Loss: 0.6931471824645996
Epoch: 2, Step: 100, Loss: 0.6931471824645996
Epoch: 2, Step: 200, Loss: 0.6931471824645996
Epoch: 3, Step: 0, Loss: 0.6931471824645996
Epoch: 3, Step: 100, Loss: 0.6931471824645996
Epoch: 3, Step: 200, Loss: 0.6931471824645996
Epoch: 4, Step: 0, Loss: 0.6931471824645996
Epoch: 4, Step: 100, Loss: 0.6931471824645996
Epoch: 4, Step: 200, Loss: 0.6931471824645996
Epoch: 5, Step: 0, Loss: 0.6931471824645996
Epoch: 5, Step: 100, Loss: 0.6931471824645996
Epoch: 5, Step: 200, Loss: 0.6931471824645996
Epoch: 6, Step: 0, Loss: 0.6931471824645996
Epoch: 6, Step: 100, Loss: 0.6931471824645996
Epoch: 6, Step: 200, Loss: 0.693

KeyboardInterrupt: 

In [17]:
summary(disc, (1, 1, 8192))

Layer (type:depth-idx)                   Output Shape              Param #
Disc                                     [1, 1]                    --
├─Sequential: 1-1                        [1, 1]                    --
│    └─LeakyReLU: 2-1                    [1, 1, 8192]              --
│    └─Linear: 2-2                       [1, 1, 4196]              34,377,828
│    └─Conv1d: 2-3                       [1, 1, 4196]              2
│    └─Linear: 2-4                       [1, 1, 1024]              4,297,728
│    └─Linear: 2-5                       [1, 1, 512]               524,800
│    └─LeakyReLU: 2-6                    [1, 1, 512]               --
│    └─Conv1d: 2-7                       [1, 1, 512]               2
│    └─Linear: 2-8                       [1, 1, 64]                32,832
│    └─Linear: 2-9                       [1, 1, 1]                 65
│    └─Flatten: 2-10                     [1, 1]                    --
│    └─LeakyReLU: 2-11                   [1, 1]                

# Export model to ONNX format

In [18]:
# torch.onnx.export(multi_scale.to('cpu'), torch.randn(1, 1, 108486), "MelGAN MultiScale.onnx")

In [19]:
torch.onnx.export(disc.to("cpu"), torch.randn(1, 1, 8192), "Discriminator.onnx")

# Predict

In [20]:
checkpoint_path = glob("model/eval-*.pt")[-1]
checkpoint = torch.load(checkpoint_path, map_location=device)
disc.load_state_dict(checkpoint)

<All keys matched successfully>

In [26]:
disc.eval()
disc.to(device)

predict_list = np.empty(0)

with torch.no_grad():
    for x in tqdm(test_loader):
        x = x.to(device)
        x_multiscale = [k for k in [j for j in x]]
        x_multiscale = torch.stack(x_multiscale).to(device)
        out = disc(x_multiscale).type(torch.float)
        predict_list = np.concatenate((predict_list, out.cpu().numpy().reshape(-1)), axis=0)

100%|██████████| 359/359 [02:22<00:00,  2.53it/s]


In [27]:
predict_list = predict_list.astype(int)
predict_list[:5]

array([0, 0, 0, 0, 0])

In [28]:
predict_list[predict_list == 1]

array([], dtype=int32)