In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import ipywidgets as widgets

# Datasets offered in the selector. The loading cell below branches on `menu.value`
# and only has loaders for 'GTZAN', 'EmoMusic' and 'Deezer'.
dataset_list = [
    'GTZAN',
    'EmoMusic',
    'Deezer',
    'MagnaTagATune',
]

# Interactive picker. The choice lives only in widget state, so a fresh
# top-to-bottom run uses the default (first option, 'GTZAN') unless changed by hand.
menu = widgets.RadioButtons(options=dataset_list,
                            description='Dataset:',
                            disabled=False)

# Last expression: render the widget as the cell output.
menu

RadioButtons(description='Dataset:', options=('GTZAN', 'EmoMusic', 'Deezer', 'MagnaTagATune'), value='GTZAN')

In [5]:
from source.datasets.fast_datasets import *
from source.utils.load_utils import *
from source.datasets.pretrain_datasets import MSDDatasetPretrain

# Build the labelled dataset chosen in the selector and record the task type:
# classification with `num_classes` classes, or regression with `num_outputs` targets.
if menu.value == 'GTZAN':
    classification = True
    num_classes = 10
    dataset = GTZANFastDataset()
elif menu.value == 'EmoMusic':
    classification = False
    num_outputs = 2
    dataset = EmoMusicFastDataset()
elif menu.value == 'Deezer':
    classification = False
    num_outputs = 2
    dataset = DeezerFastDataset(length=5000)
else:
    # 'MagnaTagATune' is offered by the selector but has no loader here. Fail loudly
    # instead of silently reusing `dataset` / `classification` left over from an
    # earlier run (or hitting a NameError below on a fresh kernel).
    raise NotImplementedError(
        f"No dataset loader implemented for {menu.value!r}; "
        "choose one of 'GTZAN', 'EmoMusic', 'Deezer'."
    )

# Split the labelled data into train/validation loaders. split_size=0.75 is
# presumably the train fraction; confirm in split_and_load.
train_dataloader, val_dataloader = split_and_load(dataset, workers=4, batch_size=4, split_size=0.75)

# Unlabelled MSD songs that the teacher pseudo-labels in a later cell.
song_dataset = MSDDatasetPretrain(length=2000)

In [6]:
from source.models_task_specific.mb_classification import MusicBertClassifier
from source.models_task_specific.mb_regression import MusicBertRegression

# Passed as `eval_per_epoch` to every train_model call in this notebook.
# Its exact meaning is defined inside train_model (not visible here); TODO confirm.
evals = 0.2

# Teacher: a 4-encoder-layer MusicBert (RNN=False) matching the task type chosen above.
if not classification:
    print("Regression Task!")
    teacher = MusicBertRegression(num_outputs, RNN=False, num_encoder_layers=4).cuda()
else:
    print("Classification Task!")
    teacher = MusicBertClassifier(num_classes, RNN=False, num_encoder_layers=4).cuda()

# Optional warm start from pretrained weights (currently disabled):
# teacher.load_pretrained()

Regression Task!


In [7]:
# Train the teacher on the labelled training split, validating on val_dataloader.
# NOTE(review): the saved progress bar below reads "0/10 Epoch", so the stored
# output was produced with a different `epochs` value than the one here.
teacher.train_model(
    train_dataloader,
    val_dataloader,
    epochs=100,
    eval_per_epoch=evals,
)

HBox(children=(IntProgress(value=0, description='Train (0/10 Epoch) - Loss...', max=1390, style=ProgressStyle(…

Eval Loss (1387 steps) -0.5250              


In [18]:
%matplotlib widget

from source.utils.plot_utils import *

loss = smooth(teacher.loss_curve.cpu().numpy(), 20)
plot_curve(loss, 1, color="red")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [17]:
%matplotlib widget

plot_curve(teacher.validation_curve.cpu().numpy(), 1, color="red")
# plot_curve(student.validation_curve.cpu().numpy(), 1, color="green")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [10]:
### Generate Pseudo-Labels from Teacher
from tqdm.notebook import tqdm
import torch

# Label every unlabelled MSD song with the trained teacher's prediction so the
# student can train on them later.
#
# This is pure inference:
#  * eval() disables train-time behaviour such as dropout, so labels are not
#    randomly perturbed (assumes `teacher` is a torch.nn.Module, as its .cuda()
#    call suggests). The previous mode is restored afterwards, so re-running other
#    cells sees the model exactly as before.
#  * no_grad() stops autograd from recording a graph. Without it every stored label
#    keeps a grad_fn, and the dataset ends up referencing one forward graph per song.
was_training = teacher.training
teacher.eval()
try:
    with torch.no_grad():
        # Index explicitly so `i` is guaranteed to match the dataset position passed
        # to set_pseudo_label, rather than relying on the __getitem__/IndexError
        # iteration fallback.
        for i in tqdm(range(len(song_dataset))):
            item = song_dataset[i]
            pseudo_out = teacher(item).squeeze().cpu()

            if classification:
                # Store the predicted class index as a 1-element float tensor.
                pseudo_out = torch.argmax(pseudo_out).unsqueeze(0).float()

            song_dataset.set_pseudo_label(i, pseudo_out)
finally:
    teacher.train(was_training)


HBox(children=(IntProgress(value=0, max=2000), HTML(value='')))




In [11]:
from torch.utils.data import ConcatDataset

# Loader over the pseudo-labelled MSD songs only. ConcatDataset is imported for the
# alternative of mixing in labelled data, e.g.
#     ConcatDataset([song_dataset, genre_dataset])
# but `genre_dataset` is not defined anywhere above, so that variant stays disabled.
extra_dataloader = just_load(
    song_dataset,
    workers=4,
    batch_size=4,
)

In [12]:
# Student: a new model built with the same constructor arguments as the teacher.
if not classification:
    student = MusicBertRegression(num_outputs, RNN=False, num_encoder_layers=4).cuda()
else:
    student = MusicBertClassifier(num_classes, RNN=False, num_encoder_layers=4).cuda()

# Optional warm start from pretrained weights (currently disabled):
# student.load_pretrained()

# Two rounds. Each round runs 10 epochs on the pseudo-labelled songs, then 10 epochs
# on the labelled training split. Validation always uses val_dataloader.
for _round in range(2):
    for loader in (extra_dataloader, train_dataloader):
        student.train_model(loader, val_dataloader, epochs=10, eval_per_epoch=evals)

HBox(children=(IntProgress(value=0, description='Train (0/10 Epoch) - Loss...', max=5000, style=ProgressStyle(…

Eval Loss (4997 steps) -0.4608              


HBox(children=(IntProgress(value=0, description='Train (0/10 Epoch) - Loss...', max=1390, style=ProgressStyle(…

Eval Loss (1387 steps) 0.0006               


HBox(children=(IntProgress(value=0, description='Train (0/10 Epoch) - Loss...', max=5000, style=ProgressStyle(…

Eval Loss (4997 steps) -0.5316              


HBox(children=(IntProgress(value=0, description='Train (0/10 Epoch) - Loss...', max=1390, style=ProgressStyle(…

Eval Loss (1387 steps) -0.4556              


In [13]:
from source.utils.generic_utils import allDone

# Project helper, presumably a completion signal after the long training run above.
allDone()

In [14]:
%matplotlib widget

from source.utils.plot_utils import *

loss = smooth(student.loss_curve.cpu().numpy(), 100)
plot_curve(loss, 1, color="red")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [16]:
%matplotlib widget

plot_curve(teacher.validation_curve.cpu().numpy(), 1, color="red")
plot_curve(student.validation_curve.cpu().numpy(), 1, color="green")

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …