In [2]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets

import torch.nn.utils.prune as prune
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio import transforms
from torch.utils.data import random_split

In [3]:
# Miscellaneous configuration values

SEED = 9814  # seed of the torch.Generator that drives the train/validation random_split

batch_size = 32  # mini-batch size of the feature DataLoaders used for training/evaluation
EPOCHS = 50  # number of training epochs; a pruning step is attempted every EPOCHS // 5 epochs

# Mel-spectrogram front-end parameters (read by MelDataset when it builds its transform)
n_fft = 4096
win_length = 400
hop_length = 160
n_mels = 80
# n_mfcc = 40
sr = 16000  # NOTE(review): not read by the visible cells; MelDataset takes the sample rate from the dataset itself

# NOTE: `input` shadows the Python builtin of the same name; it is kept because the
# model-construction cell passes `input` / `output` to the network constructors.
input = n_mels
output = 10

In [None]:
# Prefer the GPU when CUDA is usable, otherwise run everything on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Using PyTorch version:", torch.__version__,' Device:', DEVICE)


In [3]:
class Net(nn.Module):  # nn.Module is the base class of every neural network in PyTorch.
    """Plain fully connected classifier with four hidden layers.

    Args:
        input_size: Number of features of each input vector.
        output_size: Number of classes.
        hidden_sizes: Indexable of at least four hidden-layer widths; entries
            0..3 are used in that order.

    The forward pass returns log-probabilities (log-softmax over dim 1).
    """

    def __init__(self, input_size, output_size, hidden_sizes):
        super().__init__()
        width_1 = hidden_sizes[0]
        width_2 = hidden_sizes[1]
        width_3 = hidden_sizes[2]
        width_4 = hidden_sizes[3]
        # Layers are created in network order; the attribute names are part of
        # the state_dict keys and must not change.
        self.fc1_2 = nn.Linear(input_size, width_1)
        self.fc2_3 = nn.Linear(width_1, width_2)
        self.fc3_4 = nn.Linear(width_2, width_3)
        self.fc4_5 = nn.Linear(width_3, width_4)
        self.fc5_6 = nn.Linear(width_4, output_size)

    def forward(self, x):
        """Apply Linear+ReLU four times, then the output layer and log-softmax."""
        for hidden_layer in (self.fc1_2, self.fc2_3, self.fc3_4, self.fc4_5):
            x = F.relu(hidden_layer(x))

        logits = self.fc5_6(x)
        return F.log_softmax(logits, dim=1)

In [None]:
# Miscellaneous utility functions

def make_metadata_file(metadata_path, target_path, type, max_speakers=10):
    """Write a reduced split file containing only the first `max_speakers` speakers.

    Each non-blank line of `metadata_path` must have two whitespace-separated
    fields: a numeric split id (1 = train, 2 = dev, 3 = test) and a relative
    path whose first component ends in the speaker number. Lines of the
    requested split are copied, in file order, until a line belonging to a new
    speaker beyond the limit is met; reading stops there.

    Args:
        metadata_path: Path of the full split file to read.
        target_path: Path of the reduced file to write (overwritten).
        type: One of "train", "dev" or "test".
        max_speakers: Maximum number of distinct speakers to keep.

    Returns:
        The number of distinct speakers actually written to `target_path`.

    Raises:
        ValueError: If `type` is not a known split name, or a non-blank line
            does not have exactly two fields.
    """
    split_ids = {"train": 1, "dev": 2, "test": 3}
    if type not in split_ids:
        # Previously an unknown name left `typeid` unbound and failed later
        # with an unrelated UnboundLocalError.
        raise ValueError(
            "type must be one of {}, got {!r}".format(sorted(split_ids), type)
        )
    typeid = split_ids[type]

    speaker_id_set = set()
    cherrypick_list = []
    with open(metadata_path, "r") as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines (e.g. a trailing newline)
            strid, path = fields
            if int(strid) != typeid:
                continue
            # Only the last three characters of the first path component are
            # used, so ids that differ only in higher digits are merged.
            speaker_id = int(path.split("/")[0][-3:])
            if speaker_id not in speaker_id_set:
                if len(speaker_id_set) >= max_speakers:
                    # Stop BEFORE registering the extra speaker so the return
                    # value matches what is written to the file.
                    break
                speaker_id_set.add(speaker_id)
            cherrypick_list.append(line)

    with open(target_path, "w") as f:  # existing content is overwritten
        f.writelines(cherrypick_list)

    return len(speaker_id_set)

def get_num_speakers(path):
    """Count the distinct speakers listed in a split file.

    The speaker number is parsed from the last three characters of the first
    path component of the second field of every line. The set of numbers found
    is printed, and its size is returned.
    """
    with open(path, "r") as f:
        speaker_id_set = {
            int(line.split()[1].split("/")[0][-3:])
            for line in f
        }

    print(speaker_id_set)
    return len(speaker_id_set)

# Return an MFCC feature extractor built from torchaudio's transforms.MFCC
def get_mfcc_transform(sr, n_mfcc, n_fft, n_mels, win_length, hop_length):
    """Build a torchaudio MFCC transform.

    Args:
        sr: Sample rate of the audio the transform will receive.
        n_mfcc: Number of MFCC coefficients to keep.
        n_fft: FFT size of the underlying mel spectrogram.
        n_mels: Number of mel filter banks.
        win_length: Analysis window length in samples.
        hop_length: Hop between successive windows in samples.

    Returns:
        A `torchaudio.transforms.MFCC` module using the HTK mel scale.
    """
    # Qualified with `torchaudio.` on purpose: the bare name `transforms` is
    # imported from both torchvision and torchaudio in this notebook, so which
    # module it refers to depends on import order.
    return torchaudio.transforms.MFCC(
        sample_rate=sr,
        n_mfcc=n_mfcc,
        melkwargs={
            "n_fft": n_fft,
            "n_mels": n_mels,
            "win_length": win_length,
            "hop_length": hop_length,
            "mel_scale": "htk"
        })

# Return a mel-spectrogram extractor built from torchaudio's transforms.MelSpectrogram
def get_mels_transform(sr, n_fft, n_mels, win_length, hop_length):
    """Build a torchaudio mel-spectrogram transform.

    Args:
        sr: Sample rate of the audio the transform will receive.
        n_fft: FFT size.
        n_mels: Number of mel filter banks.
        win_length: Analysis window length in samples.
        hop_length: Hop between successive windows in samples.

    Returns:
        A `torchaudio.transforms.MelSpectrogram` module.
    """
    # Qualified with `torchaudio.` on purpose: the bare name `transforms` is
    # imported from both torchvision and torchaudio in this notebook, so which
    # module it refers to depends on import order.
    return torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_fft=n_fft,
        win_length=win_length,
        hop_length=hop_length,
        n_mels=n_mels
        )

def normalize_dataset(x, mean=None, std=None):
    """Standardize `x` feature-wise: (x - mean) / std.

    Args:
        x: Array-like of samples with the sample axis first.
        mean: Per-feature mean to subtract; computed from `x` over axis 0
            when None.
        std: Per-feature standard deviation to divide by; computed from `x`
            over axis 0 when None.

    Returns:
        (normalized_x, mean, std). `mean` and `std` are returned exactly as
        given or computed, so they can be reused to normalize other splits.

    Entries of `std` that are exactly 0 (constant features, or a scalar 0) are
    treated as 1 for the division, so such features come out as `x - mean`
    instead of NaN/inf.
    """
    # Convert once; the original converted the input again for every NumPy call.
    x = np.asarray(x)
    if mean is None:
        mean = np.mean(x, axis=0)
    if std is None:
        std = np.std(x, axis=0)

    # Guard against division by zero without altering the returned statistics.
    std_arr = np.asarray(std)
    safe_std = np.where(std_arr == 0, np.ones_like(std_arr), std_arr)

    x = (x - mean) / safe_std

    return x, mean, std

In [None]:
# Derive one reduced file list per split ("train", "dev", "test") from iden_split.txt.
# Each call overwrites its target file; the last call's return value is the cell output.
make_metadata_file("iden_split.txt", "train_moredata_list2.txt", "train")
make_metadata_file("iden_split.txt", "val_list2.txt", "dev")
make_metadata_file("iden_split.txt", "test_list2.txt", "test")

In [None]:
# Download the VoxCeleb data (if needed) and build the dataset objects
train_dataset = torchaudio.datasets.VoxCeleb1Identification(
    root="./",
    subset="train",
    meta_url="train_moredata_list2.txt",
    download=True,
)
test_dataset = torchaudio.datasets.VoxCeleb1Identification(
    root="./",
    subset="test",
    meta_url="test_list2.txt",
    download=True,
)

# 80% of the training utterances stay in train, the remainder becomes validation.
train_dataset_size = int(len(train_dataset) * 0.80)
validation_dataset_size = len(train_dataset) - train_dataset_size

# A dedicated, seeded generator keeps the split identical across runs.
generator = torch.Generator()
generator.manual_seed(SEED)
train_dataset, val_dataset = random_split(
    train_dataset,
    [train_dataset_size, validation_dataset_size],
    generator=generator,
)

print("train size: ", len(train_dataset))
print("val size: ", len(val_dataset))
print("test size: ", len(test_dataset))

In [6]:
class MelDataset(Dataset):
    """In-memory dataset of normalized, window-averaged mel-spectrogram features.

    The whole `dataloader` is consumed in the constructor. Every waveform is
    turned into a dB-scaled mel spectrogram, the frames are averaged over
    overlapping windows, and each window average becomes one sample labelled
    with the utterance's speaker id.

    Args:
        dataloader: DataLoader whose batches are indexable with the waveform at
            position 0 and the speaker id at position 2. Each batch is
            flattened into ONE waveform, so this assumes batch_size=1 (as the
            notebook's extraction loaders use).
        type: Split name. "train" computes the normalization statistics from
            this data; any other value normalizes with the given `mean`/`std`.
        mean: Per-feature mean for non-train splits.
        std: Per-feature standard deviation for non-train splits.
        window_frames: Number of spectrogram frames averaged per sample.
        hop_frames: Frame step between the starts of consecutive windows.

    Attributes:
        x_train: Normalized features, one row per window.
        y_train: List with one speaker-id entry per row of `x_train`.
        mean, std: Statistics used for normalization (computed when
            type == "train").

    Raises:
        ValueError: If a non-train split is built with a scalar `std` of 0
            (the default), which would divide every feature by zero.
    """

    def __init__(self, dataloader, type, mean=0, std=0, window_frames=50, hop_frames=25):
        is_train = type == "train"
        if not is_train and np.ndim(std) == 0 and std == 0:
            # Fail before the expensive feature extraction instead of silently
            # producing inf/NaN features afterwards.
            raise ValueError(
                "MelDataset for split {!r} needs the mean/std computed on the training split".format(type)
            )

        self.dataloader = dataloader
        self.x_train = []
        self.y_train = []
        self.mean = mean
        self.std = std
        # The sample rate is read from the first item of the wrapped dataset.
        self.sr = dataloader.dataset[0][1]
        print("detected sample rate: ", self.sr)
        # Front end: mel spectrogram followed by conversion to decibels.
        # n_fft / n_mels / win_length / hop_length come from the notebook's config cell.
        self.get_features = nn.Sequential(
            get_mels_transform(self.sr, n_fft, n_mels, win_length, hop_length),
            torchaudio.transforms.AmplitudeToDB(),
        )
        print(f"====================== ({type}) Generating Mel dataset ======================")
        for sample in dataloader:
            wave = torch.flatten(sample[0])  # Tensor, flattened to 1-D
            speaker_id = sample[2]  # Tensor
            # Swap the two axes so that frames come first: (frames, n_mels).
            mels = self.get_features(wave).transpose(0, 1)

            # One averaged vector per window; windows near the end of the
            # utterance may cover fewer than `window_frames` frames.
            window_means = [
                torch.mean(mels[i * hop_frames:i * hop_frames + window_frames, :], dim=0)
                for i in range(mels.shape[0] // hop_frames)
            ]
            self.x_train += window_means
            self.y_train += [speaker_id] * len(window_means)  # repeat the id once per window

        # Stack once into a single array instead of letting NumPy convert the
        # list of tensors implicitly on every operation.
        if self.x_train:
            features = torch.stack(self.x_train).numpy()
        else:
            features = np.empty((0, n_mels), dtype=np.float32)

        if is_train:
            self.x_train, self.mean, self.std = normalize_dataset(features)
        else:
            self.x_train, _, _ = normalize_dataset(features, self.mean, self.std)
        print(f"====================== ({type}) x: ", np.shape(self.x_train), "y: ", np.shape(self.y_train), " ======================")

    def __getitem__(self, index):
        """Return the (feature vector, speaker id) pair at `index`."""
        return self.x_train[index], self.y_train[index]

    def __len__(self):
        """Return the number of window-level samples."""
        return len(self.x_train)

In [None]:
# Create the mel-feature DataLoaders used for training

# Extraction loaders: one utterance per batch, in dataset order. MelDataset iterates
# over each of them once inside its constructor.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=4, prefetch_factor=2)
val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, prefetch_factor=2)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4, prefetch_factor=2)

# The training split computes mean/std; validation and test reuse those statistics.
train_meldataset = MelDataset(train_dataloader, "train")
val_meldataset = MelDataset(val_dataloader, "val", train_meldataset.mean, train_meldataset.std)
test_meldataset = MelDataset(test_dataloader, "test", train_meldataset.mean, train_meldataset.std)

# Mini-batch loaders over the extracted feature vectors (despite the `_dataset` suffix
# these are DataLoader objects).
mels_train_dataset = DataLoader(train_meldataset, batch_size=batch_size, shuffle=True)
mels_val_dataset = DataLoader(val_meldataset, batch_size=batch_size, shuffle=True)
mels_test_dataset = DataLoader(test_meldataset, batch_size=batch_size, shuffle=True)

# Sanity check: print the first validation batch and the shapes of its two tensors, then stop.
for batch in mels_val_dataset:
    print(batch)
    print(f"mels shape: {batch[0].shape}")
    print(f"id shape: {batch[1].shape}")

    break

In [None]:
# Check which audio backends torchaudio has available
print(torchaudio.list_audio_backends())

In [None]:
from DCMLP_r2 import DCMLPr2
from DCMLP_r3 import DCMLPr3
from DCMLP_r3_deep import DeepDCMLPr3
from MLP_r2 import MLPr2
from MLP_r3 import MLPr3
from MLP_r3 import DeepMLPr3
from MLP_deep import DeepMLP

def lr_lambda(step):
    """Learning-rate multiplier for LambdaLR.

    The multiplier starts at 1 and is scaled by 0.1 each time another
    5 million scheduler steps have been completed.
    """
    completed_decays = step // (5 * 10**6)  # 5 million steps per decay
    return 0.1 ** completed_decays

# Seed the global RNG right before the model is built so that weight initialisation
# (and everything drawn from the global RNG afterwards) is reproducible whenever this
# cell is re-run. SEED comes from the configuration cell.
torch.manual_seed(SEED)

# Exactly one architecture is active; the commented lines are the alternatives.
# model = DCMLPr2(input, output, (512, 256, 256, 256)).to(DEVICE)
# model = DCMLPr3(input, output, (256, 256, 256, 256)).to(DEVICE)
model = DeepDCMLPr3(input, output, 128).to(DEVICE)
# model = MLPr3(input, output, (512, 256, 256, 256)).to(DEVICE)
# model = MLPr2(input, output, (512, 256, 256, 256)).to(DEVICE)
# model = DeepMLP(input, output, 128).to(DEVICE)
# model = DeepMLPr3(input, output, 128).to(DEVICE)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
# The scheduler is stepped once per batch by train_with_decay, so lr_lambda counts batches.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
criterion = nn.CrossEntropyLoss()

print(model)

In [11]:
def train_with_decay(model, train_loader, optimizer, log_interval, scheduler, epoch,
                     criterion=None, label_offset=3):
    """Train `model` for one epoch, stepping the LR scheduler after every batch.

    Args:
        model: Network to train; switched to train mode.
        train_loader: DataLoader yielding (features, label) batches.
        optimizer: Optimizer updating `model`'s parameters.
        log_interval: A progress line is printed every `log_interval` batches.
        scheduler: Learning-rate scheduler, stepped once per batch.
        epoch: Epoch number, used only in the progress line.
        criterion: Loss function. When None, the notebook-level `criterion`
            is used, which is what this function always did before.
        label_offset: Value subtracted from the raw labels to make them
            zero-based class indices.

    Returns:
        (mean of the per-batch losses, accuracy in percent). Both are NaN when
        the loader yields no data.
    """
    if criterion is None:
        # Backward-compatible fallback to the module-level loss object.
        criterion = globals()["criterion"]

    model.train()
    train_loss_sum = train_correct = train_total = 0
    total_train_batch = len(train_loader)

    for batch_idx, (mels, label) in enumerate(train_loader):
        mels = mels.to(DEVICE)
        # Flatten the labels to 1-D and shift them to zero-based class indices
        # (with the default offset, labels 3..12 become 0..9).
        label = (torch.flatten(label) - label_offset).to(DEVICE)

        optimizer.zero_grad()
        logits = model(mels)
        loss = criterion(logits, label)
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss_sum += loss.item()
        train_total += label.size(0)
        train_correct += (torch.argmax(logits, 1) == label).sum().item()

        if batch_idx % log_interval == 0:
            print("Train epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}".format(
                epoch, batch_idx * len(mels),
                len(train_loader.dataset), 100. * batch_idx / len(train_loader),
                loss.item()))

    # An empty loader has no defined loss/accuracy; report NaN instead of
    # failing with ZeroDivisionError.
    if total_train_batch == 0 or train_total == 0:
        return (float("nan"), float("nan"))

    train_avg_loss = train_loss_sum / total_train_batch
    train_avg_acc = 100 * train_correct / train_total

    return (train_avg_loss, train_avg_acc)

In [12]:
def evaluate(model, test_loader, criterion, label_offset=3):
    """Evaluate `model` on every batch of `test_loader` without gradient tracking.

    Args:
        model: Network to evaluate; switched to eval mode.
        test_loader: DataLoader yielding (features, label) batches.
        criterion: Loss function applied to (model output, label).
        label_offset: Value subtracted from the raw labels to make them
            zero-based class indices.

    Returns:
        (mean of the per-batch losses, accuracy in percent). Both are NaN when
        the loader yields no data.
    """
    model.eval()
    test_loss_sum = test_correct = test_total = 0
    total_test_batch = len(test_loader)

    with torch.no_grad():
        for mels, label in test_loader:
            mels = mels.to(DEVICE)
            # Flatten the labels to 1-D and shift them to zero-based class indices
            # (with the default offset, labels 3..12 become 0..9).
            label = (torch.flatten(label) - label_offset).to(DEVICE)
            logits = model(mels)

            test_loss_sum += criterion(logits, label).item()

            test_total += label.size(0)
            test_correct += (torch.argmax(logits, 1) == label).sum().item()

    # An empty loader has no defined loss/accuracy; report NaN instead of
    # failing with ZeroDivisionError.
    if total_test_batch == 0 or test_total == 0:
        return (float("nan"), float("nan"))

    test_avg_loss = test_loss_sum / total_test_batch
    test_avg_acc = 100 * test_correct / test_total
    return (test_avg_loss, test_avg_acc)

In [13]:
# Pruning schedule, one entry per pruning step. The training loop uses the first entry
# directly as the amount to prune and, for later steps, derives the amount from the
# ratio (1 - current entry) / (1 - previous entry) -- i.e. the entries read as cumulative
# sparsity targets.
prune_rate = [0.1, 0.4, 0.6, 0.8]

# MLP-r=2

# parameters_to_prune = (
#     (model.fc1_2, 'weight'),

#     (model.fc2_3_i, 'weight'),

#     (model.fc2_3_ii, 'weight'),

#     (model.fc3_4_i, 'weight'),

#     (model.fc3_4_ii, 'weight'),

#     (model.fc4_5, 'weight'),

#     (model.fc5_6, 'weight'),

# )

# MLP-r=3
# parameters_to_prune = (
#     (model.fc1_2, 'weight'),
#     (model.fc2_3_i, 'weight'),
#     (model.fc2_3_ii, 'weight'),
# 	(model.fc2_3_iii, 'weight'),
#     (model.fc3_4_i, 'weight'),
#     (model.fc3_4_ii, 'weight'),
#     (model.fc4_5, 'weight'),
#     (model.fc5_6, 'weight'),
# )

# DeepMLP-r=3
# parameters_to_prune = (
#     (model.layers["fc1_2"], 'weight'),
#     (model.layers["fc2_3"], 'weight'),
#     (model.layers["fc3_4"], 'weight'),
#     (model.layers["fc4_5"], 'weight'),
#     (model.layers["fc5_6"], 'weight'),
#     (model.layers["fc6_7"], 'weight'),
#     (model.layers["fc7_8"], 'weight'),
#     (model.layers["fc8_9"], 'weight'),
#     (model.layers["fc9_10"], 'weight'),
#     (model.layers["fc10_11"], 'weight'),
#     (model.layers["fc11_12"], 'weight'),
#     (model.layers["fc12_13"], 'weight'),
#     (model.layers["fc13_14"], 'weight'),
#     (model.layers["fc14_15"], 'weight'),
#     (model.layers["fc2_3_ii"], 'weight'),
#     (model.layers["fc2_3_iii"], 'weight'),
#     (model.layers["fc6_7_ii"], 'weight')
# )

# DCMLP-r=2
# parameters_to_prune = (
#     (model.fc1_2, "weight"),
#     (model.fc2_3, "weight"),
#     (model.fc2_4, "weight"),
#     (model.fc3_4, "weight"),
#     (model.fc3_5, "weight"),
#     (model.fc4_5, "weight"),
#     (model.fc5_6, "weight"),
# )

# DCMLP-r=3
# parameters_to_prune = (
#     (model.fc1_2, "weight"),
#     (model.fc2_3, "weight"),
#     (model.fc2_4, "weight"),
# 	(model.fc2_5, "weight"),
#     (model.fc3_4, "weight"),
#     (model.fc3_5, "weight"),
#     (model.fc4_5, "weight"),
#     (model.fc5_6, "weight"),
# )

# DeepDCMLPr3
# parameters_to_prune = (
#     (model.layers["fc1_2"], "weight"),
#     (model.layers["fc2_3"], "weight"),
#     (model.layers["fc3_4"], "weight"),
#     (model.layers["fc4_5"], "weight"),
#     (model.layers["fc5_6"], "weight"),
#     (model.layers["fc6_7"], "weight"),
#     (model.layers["fc7_8"], "weight"),
#     (model.layers["fc8_9"], "weight"),
#     (model.layers["fc9_10"], "weight"),
#     (model.layers["fc10_11"], "weight"),
#     (model.layers["fc11_12"], "weight"),
#     (model.layers["fc12_13"], "weight"),
#     (model.layers["fc13_14"], "weight"),
#     (model.layers["fc14_15"], "weight"),
#     (model.layers["fc2_10_i"], "weight"),
#     (model.layers["fc2_14_i"], "weight"),
#     (model.layers["fc6_14_i"], "weight")
# )

# DeepDCMLP-r=2toAny
# parameters_to_prune = (
#     (model.layers["fc1_2"], "weight"),
#     (model.layers["fc2_3"], "weight"),
#     (model.layers["fc3_4"], "weight"),
#     (model.layers["fc4_5"], "weight"),
#     (model.layers["fc5_6"], "weight"),
#     (model.layers["fc6_7"], "weight"),
#     (model.layers["fc7_8"], "weight"),
#     (model.layers["fc8_9"], "weight"),
#     (model.layers["fc9_10"], "weight"),
#     (model.layers["fc10_11"], "weight"),
#     (model.layers["fc11_12"], "weight"),
#     (model.layers["fc12_13"], "weight"),
#     (model.layers["fc13_14"], "weight"),
#     (model.layers["fc14_15"], "weight"),
#     (model.layers["fc2_4"], "weight"),
#     (model.layers["fc2_5"], "weight"),
#     (model.layers["fc2_6"], "weight"),
#     (model.layers["fc2_7"], "weight"),
#     (model.layers["fc2_8"], "weight"),
#     (model.layers["fc2_9"], "weight"),
#     (model.layers["fc2_10"], "weight"),
#     (model.layers["fc2_11"], "weight"),
#     (model.layers["fc2_12"], "weight"),
#     (model.layers["fc2_13"], "weight"),
#     (model.layers["fc2_14"], "weight")
# )

# DeepDCMLP-r=Anyto14
# (module, parameter name) pairs handed to prune.global_unstructured by the training loop.
# NOTE(review): the model cell currently instantiates DeepDCMLPr3, while the list labelled
# "DeepDCMLPr3" above is commented out and uses different layer keys (e.g. "fc2_10_i").
# Confirm that model.layers really contains the "fcN_14" keys used below; a missing key
# would make this cell fail.
parameters_to_prune = (
    (model.layers["fc1_2"], "weight"),
    (model.layers["fc2_3"], "weight"),
    (model.layers["fc2_14"], "weight"),
    (model.layers["fc3_4"], "weight"),
    (model.layers["fc3_14"], "weight"),
    (model.layers["fc4_5"], "weight"),
    (model.layers["fc4_14"], "weight"),
    (model.layers["fc5_6"], "weight"),
    (model.layers["fc5_14"], "weight"),
    (model.layers["fc6_7"], "weight"),
    (model.layers["fc6_14"], "weight"),
    (model.layers["fc7_8"], "weight"),
    (model.layers["fc7_14"], "weight"),
    (model.layers["fc8_9"], "weight"),
    (model.layers["fc8_14"], "weight"),
    (model.layers["fc9_10"], "weight"),
    (model.layers["fc9_14"], "weight"),
    (model.layers["fc10_11"], "weight"),
    (model.layers["fc10_14"], "weight"),
    (model.layers["fc11_12"], "weight"),
    (model.layers["fc11_14"], "weight"),
    (model.layers["fc12_13"], "weight"),
    (model.layers["fc12_14"], "weight"),
    (model.layers["fc13_14"], "weight"),
    (model.layers["fc14_15"], "weight")
)


# Number of pruning steps applied so far; the training loop reads it as the index into
# prune_rate and increments it after each pruning step.
flag=0

In [None]:
# Per-epoch history, consumed by the summary and plotting cells.
train_loss_list = []
train_acc_list = []

val_loss_list = []
val_acc_list = []

# A pruning step is attempted every fifth of the run. max() keeps the modulo valid
# when EPOCHS < 5 (EPOCHS // 5 would be 0 and raise ZeroDivisionError).
prune_interval = max(1, EPOCHS // 5)

for epoch in range(1, EPOCHS + 1):
    train_avg_loss, train_avg_acc = train_with_decay(model, mels_train_dataset, optimizer, log_interval = 200, scheduler=scheduler, epoch=epoch)
    train_loss_list.append(train_avg_loss)
    train_acc_list.append(train_avg_acc)

    val_loss, val_accuracy = evaluate(model, mels_val_dataset, criterion)
    val_loss_list.append(val_loss)
    val_acc_list.append(val_accuracy)

    # These are validation metrics (computed on mels_val_dataset), so they are
    # labelled "Val" rather than "Test".
    print("\n[EPOCH: {}], \tTrain Loss: {:.4f}, \tTrain Accuracy: {:.2f}, \tVal Loss: {:.4f}, \tVal Accuracy: {:.2f} %\n".format(
        epoch, train_avg_loss, train_avg_acc, val_loss, val_accuracy))

    # `flag` counts the pruning steps already applied; stop once every entry of
    # prune_rate has been used (bounded by the list length, not a hard-coded 3).
    if epoch % prune_interval == 0 and flag < len(prune_rate):
        print('prune_step')
        if flag == 0:
            amount_prune = prune_rate[flag]
        else:
            # Amount for this step, derived from the previous and current schedule
            # entries: 1 - (1 - current) / (1 - previous).
            amount_prune = 1-(1/(1-prune_rate[flag-1])*(1-prune_rate[flag]))
        print('amount to prune: ', amount_prune)
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=amount_prune,
        )
        flag += 1

In [None]:
# Best value reached by each tracked metric over all epochs.
print(
    f"min train loss: {min(train_loss_list)}",
    f"max train acc: {max(train_acc_list)}",
    f"min val loss: {min(val_loss_list)}",
    f"max val acc: {max(val_acc_list)}",
    sep="\n",
)

In [19]:
# plotting train, val results

def plot_trend(train_loss_list, train_acc_list, val_loss_list, val_acc_list):
    """Show two figures: the loss curves, then the accuracy curves.

    Each figure plots the training series and the validation series against
    the epoch index and is displayed with plt.show() before the next is drawn.
    """
    panels = (
        ('Loss trend', 'loss',
         (train_loss_list, 'train_loss'), (val_loss_list, 'val_loss')),
        ('Accuracy trend', 'accuracy',
         (train_acc_list, 'train_acc'), (val_acc_list, 'val_acc')),
    )

    for title, y_label, *curves in panels:
        plt.title(title)
        plt.xlabel('epochs')
        plt.ylabel(y_label)
        plt.grid()

        for values, legend_name in curves:
            plt.plot(values, label=legend_name)
        plt.legend(loc='best')

        plt.show()

In [None]:
# Draw the loss and accuracy curves collected by the training loop.
plot_trend(
    train_loss_list=train_loss_list,
    train_acc_list=train_acc_list,
    val_loss_list=val_loss_list,
    val_acc_list=val_acc_list,
)