In [1]:
import logging
import os
from pprint import pprint
from timeit import default_timer as timer

import torch
import torchaudio
from torch import nn
from torch.profiler import profile, record_function, ProfilerActivity
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.classification import MulticlassF1Score
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import wandb

from SongDataset import SongDataset
from SongDataset import ArrangementUtils
from SongDataset import GuitarCollater
from TranscriptionModel import GuitarModel
from TUtils import random_string
from Tokenizer import GuitarTokenizer


# Default figure size for plots in this notebook.
plt.rcParams['figure.figsize'] = [12, 8]
# Check that PyTorch can see a CUDA device (training moves the model to DEVICE).
print(f"Cuda : {torch.cuda.is_available()}")

# Keep Weights & Biases runs local (no network upload). wandb reads this when
# wandb.init() is called in the training cell, so setting it after `import wandb` is fine.
os.environ['WANDB_MODE'] = 'offline'

Cuda : True


In [2]:
# Train/infer on the GPU whenever PyTorch can see one.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#TODO: load this from dataset file
SAMPLE_RATE = 44_100
DatasetFile = "./Trainsets/massive.hdf5"

In [3]:
# Mel front end handed to GuitarModel: 128 mel bins computed from 2048-point FFTs
# with a 512-sample hop at SAMPLE_RATE.
mel_spectrogram = torchaudio.transforms.MelSpectrogram(
    sample_rate=SAMPLE_RATE,
    n_fft=2048,
    hop_length=512,
    n_mels=128,
)
dataset = SongDataset(DatasetFile, sampleRate=SAMPLE_RATE)
# GuitarCollater receives the dataset's pad token (presumably to pad variable-length
# token sequences within a batch — confirm in SongDataset).
collate_fn = GuitarCollater(dataset.pad_token)

In [4]:
# --- Transformer size (passed to GuitarModel) ---
EMB_SIZE = 512            # emb_size
NHEAD = 8                 # multi_head_attention_size
FFN_HID_DIM = 512         # dim_feedforward
NUM_ENCODER_LAYERS = 6
NUM_DECODER_LAYERS = 6

# --- Optimisation ---
BATCH_SIZE = 64
LEARNING_RATE = 1e-4

# First epoch the training loop runs; load_model() overwrites it to resume a checkpoint.
epoch = 1

In [5]:

# Input shape handed to GuitarModel is (BATCH_SIZE, 2, 128, 87). 128 matches n_mels of
# mel_spectrogram; 2 and 87 are presumably audio channels and spectrogram frames per
# chunk — TODO confirm against SongDataset.
model = GuitarModel(
    (BATCH_SIZE, 2, 128, 87),
    mel_spectrogram,
    emb_size=EMB_SIZE,
    num_encoder_layers=NUM_ENCODER_LAYERS,
    num_decoder_layers=NUM_DECODER_LAYERS,
    multi_head_attention_size=NHEAD,
    dim_feedforward=FFN_HID_DIM,
    tgt_vocab_size=dataset.vocabSize,
)
# Xavier-uniform init for every parameter of rank >= 2; 1-D parameters (e.g. biases)
# keep their default initialisation.
for param in model.parameters():
    if param.dim() > 1:
        nn.init.xavier_uniform_(param)
model = model.to(DEVICE)

# NOTE(review): the training/eval code feeds softmax probabilities into BCELoss;
# CrossEntropyLoss(ignore_index=dataset.pad_token) on raw logits was the earlier choice
# and is usually the more stable option for single-label token prediction.
loss_fn = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)
metric = MulticlassF1Score(num_classes=int(dataset.vocabSize)).to(DEVICE)

In [6]:
class CheckpointSaver:
    """Save model/optimizer checkpoints and keep only the ``top_n`` best files on disk.

    Files are named ``<ModelClass>_epoch<epoch>_run<run_id>.pt`` inside ``dirpath``;
    ``run_id`` is a random string generated per saver so runs do not collide.
    """

    def __init__(self, dirpath, decreasing=True, top_n=5):
        """
        dirpath: Directory path where to store all model weights (created if missing)
        decreasing: If decreasing is `True`, then lower metric is better
        top_n: Total number of models to track based on validation metric value
        """
        os.makedirs(dirpath, exist_ok=True)
        self.dirpath = dirpath
        self.top_n = top_n
        self.decreasing = decreasing
        # Each entry is {'path': str, 'score': metric}, kept sorted best-first.
        self.top_model_paths = []
        # float('inf') rather than np.Inf: that alias was removed in NumPy 2.0.
        self.best_metric_val = float('inf') if decreasing else -float('inf')
        self.run_id = random_string()

    def __call__(self, model, epoch, metric_val, optimizer, force_save=False):
        """Save a checkpoint when ``metric_val`` beats the best so far, or when forced.

        model: module whose ``state_dict()`` is saved
        epoch: epoch number stored in the checkpoint and used in the file name
        metric_val: validation metric used to rank checkpoints
        optimizer: optimizer whose ``state_dict()`` is saved
        force_save: save even without an improvement; ``best_metric_val`` is
            still only updated when the metric actually improves
        """
        model_path = os.path.join(self.dirpath, model.__class__.__name__ + f'_epoch{epoch}_run{self.run_id}.pt')
        if self.decreasing:
            improved = metric_val < self.best_metric_val
        else:
            improved = metric_val > self.best_metric_val

        if improved or force_save:
            if improved:
                logging.info(f"Current metric value {metric_val} better than best {self.best_metric_val}, saving model at {model_path}")
                self.best_metric_val = metric_val
            else:
                logging.info(f"Force-saving model at {model_path} (metric {metric_val}, best {self.best_metric_val})")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # 'loss': loss,
            }, model_path)
            # Saving the same epoch again overwrites the same file; drop the stale entry
            # so cleanup() cannot delete the checkpoint that was just written.
            self.top_model_paths = [o for o in self.top_model_paths if o['path'] != model_path]
            self.top_model_paths.append({'path': model_path, 'score': metric_val})
            self.top_model_paths.sort(key=lambda o: o['score'], reverse=not self.decreasing)

        if len(self.top_model_paths) > self.top_n:
            self.cleanup()

    def cleanup(self):
        """Delete checkpoint files ranked below the best ``top_n`` and forget them."""
        to_remove = self.top_model_paths[self.top_n:]
        logging.info(f"Removing extra models.. {to_remove}")
        for o in to_remove:
            # Tolerate files that were already removed outside this object.
            if os.path.exists(o['path']):
                os.remove(o['path'])
        self.top_model_paths = self.top_model_paths[:self.top_n]

def save_model(PATH):
    """Save the module-level ``model``/``optimizer`` state and current ``epoch`` to PATH.

    The parent directory of PATH is created if it does not exist yet.
    """
    parent_dir = os.path.dirname(PATH)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # 'loss': loss,
    }, PATH)

def load_model(PATH):
    """Restore the module-level ``model`` and ``optimizer`` from a checkpoint at PATH.

    Expects the dict layout written by save_model/CheckpointSaver. Rebinds the
    module-level ``epoch`` to the saved epoch + 1 so training resumes after it.
    ``model`` and ``optimizer`` are updated in place, so they need no ``global``.
    """
    global epoch
    # map_location moves tensors onto DEVICE, so a checkpoint saved on GPU can be
    # loaded on a CPU-only machine (and vice versa).
    checkpoint = torch.load(PATH, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch'] + 1


In [7]:
# Fixed-seed 90/10 split so the train/validation partition is identical across runs.
split_generator = torch.Generator().manual_seed(47)
train_set, val_set = torch.utils.data.random_split(dataset, [0.9, 0.1], generator=split_generator)

In [8]:
def train_epoch(model, optimizer, epoch):
    """Run one training epoch over ``train_set`` and return the mean per-batch loss.

    model: model to train; switched to train mode here
    optimizer: optimizer stepped once per batch
    epoch: epoch number, used only for the progress-bar label

    Reads module-level ``train_set``, ``BATCH_SIZE``, ``collate_fn``, ``DEVICE``,
    ``loss_fn`` and ``config["LogInterval"]``. The running mean loss is logged to
    wandb every ``LogInterval`` batches.
    Returns ``nan`` if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    count = 0
    log_interval = config["LogInterval"]
    train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, collate_fn=collate_fn)

    for spectrogram, tuning, tokens, oneHot in (pbar := tqdm(train_dataloader, desc=f"Epoch {epoch}")):
        spectrogram = spectrogram.to(DEVICE)
        tuning = tuning.to(DEVICE)
        tokens = tokens.to(DEVICE)
        oneHot = oneHot.to(DEVICE)

        # Decoder input is every token except the last. Presumably oneHot encodes the
        # shifted targets tokens[:, 1:] — confirm in GuitarCollater.
        tokens_input = tokens[:, :-1]
        target_mask, token_padding_mask = model.create_masks(tokens_input)
        target_mask = target_mask.to(DEVICE)
        token_padding_mask = token_padding_mask.to(DEVICE)

        logits = model(spectrogram, tuning, tokens_input, target_mask, token_padding_mask)
        # loss_fn is BCELoss, which needs probabilities. dim=2 is presumably the
        # vocabulary axis — confirm against GuitarModel's output layout.
        output = torch.nn.functional.softmax(logits, dim=2)
        loss = loss_fn(output, oneHot)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        count += 1
        metrics_to_save = {"train_loss": total_loss / count}
        if count % log_interval == 0:
            wandb.log(metrics_to_save)
        pbar.set_postfix(metrics_to_save)

    return total_loss / count if count else float("nan")


def evaluate(model,epoch,checkpoint_saver,optimizer):
    """Score the model on ``val_set`` without updating it, then checkpoint it.

    model: model to evaluate; switched to eval mode
    epoch: epoch number, used for the progress-bar label and checkpoint name
    checkpoint_saver: CheckpointSaver called once with the mean validation loss
        (with force_save=True, so a checkpoint is written on every call)
    optimizer: NOT stepped here; only forwarded to ``checkpoint_saver`` so its
        state is stored next to the model

    Reads module-level ``val_set``, ``BATCH_SIZE``, ``collate_fn``, ``DEVICE``,
    ``loss_fn``, ``metric`` and ``config["LogInterval"]``.
    Returns the mean per-batch validation loss over the whole validation set
    (``nan`` if it yields no batches).
    """
    model.eval()
    total_loss = 0.0
    total_f1 = 0.0
    count = 0
    log_interval = config["LogInterval"]
    # Fixed order so validation numbers are comparable between epochs.
    val_dataloader = DataLoader(val_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, collate_fn=collate_fn)

    # Validation must never train the model: no gradients and no optimizer step.
    with torch.no_grad():
        for spectrogram, tuning, tokens, oneHot in (pbar := tqdm(val_dataloader, desc=f"Eval {epoch}")):
            spectrogram = spectrogram.to(DEVICE)
            tuning = tuning.to(DEVICE)
            tokens = tokens.to(DEVICE)
            oneHot = oneHot.to(DEVICE)

            # Decoder input is every token except the last (same as train_epoch).
            tokens_input = tokens[:, :-1]
            target_mask, token_padding_mask = model.create_masks(tokens_input)
            target_mask = target_mask.to(DEVICE)
            token_padding_mask = token_padding_mask.to(DEVICE)

            logits = model(spectrogram, tuning, tokens_input, target_mask, token_padding_mask)
            output = torch.nn.functional.softmax(logits, dim=2)
            loss = loss_fn(output, oneHot)

            total_loss += loss.item()
            # NOTE(review): MulticlassF1Score normally expects integer class targets;
            # passing the one-hot tensor may error or measure something unintended — confirm.
            total_f1 += metric(output, oneHot).item()
            count += 1
            metrics_to_save = {
                "val_loss": total_loss / count,
                "f1_acc_val": total_f1 / count,
            }
            if count % log_interval == 0:
                wandb.log(metrics_to_save)
            pbar.set_postfix(metrics_to_save)

    mean_loss = total_loss / count if count else float("nan")
    checkpoint_saver(model, epoch, mean_loss, optimizer, True)
    return mean_loss


In [9]:
# load_model("./model_weights/massive_first_test_0epoch.pt")

In [None]:
NUM_EPOCHS = 1
checkpoint_saver = CheckpointSaver(dirpath='./model_weights', decreasing=True, top_n=5)
config = {
    "EMB_SIZE":EMB_SIZE,
    "NHEAD":NHEAD,
    "FFN_HID_DIM":FFN_HID_DIM,
    "BATCH_SIZE":BATCH_SIZE,
    "NUM_ENCODER_LAYERS":NUM_ENCODER_LAYERS,
    "NUM_DECODER_LAYERS":NUM_DECODER_LAYERS,
    "VocabSize":dataset.vocabSize,
    "TotalSize":dataset.size,
    "MaxTokens":dataset.maxTokens,
    "NumberOfTimeTokensPerSecond":dataset.numberOfTimeTokensPerSecond,
    "SpectrogramSizeInSeconds":dataset.spectrogramSizeInSeconds,
    "SampleRate":dataset.sample_rate,
    "LearningRate":LEARNING_RATE,
    "DataSetFile":DatasetFile,
    "LogInterval":1000
}

# NOTE(review): the profiler wraps the entire training loop with record_shapes=True and
# no schedule, so every op of every batch is recorded. Over a full epoch this can use a
# very large amount of host memory; consider profiling only a handful of batches.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True) as prof:
    with wandb.init(project="MusicTranscription", config=config, notes=f"CheckpointID : {checkpoint_saver.run_id}"):
        wandb.watch(model, log_freq=config["LogInterval"], log_graph=True, criterion=loss_fn)
        # `epoch` is intentionally the module-level name: load_model() sets it to resume,
        # and save_model() records it.
        for epoch in range(epoch, NUM_EPOCHS + 1):
            start_time = timer()
            with record_function("train_epoch"):
                train_loss = train_epoch(model, optimizer, epoch)
            end_time = timer()
            # Validation is currently disabled:
            # val_loss = evaluate(model, epoch, checkpoint_saver, optimizer)
            print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s")

print(prof.key_averages().table(sort_by="cuda_time_total"))

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


Epoch 1:   0%|          | 0/254902 [00:11<?, ?it/s]

In [11]:
def predict(Tokenizer,model,spectrogram : torch.Tensor,tuning : torch.Tensor,arrangement : str,capo : float):
    """Greedily decode a token sequence for one un-batched audio segment.

    Tokenizer: tokenizer providing ``sosToken`` and ``eosToken``
    model: trained model; switched to eval mode
    spectrogram: one model input without a batch dimension (one is added here). In
        this notebook callers pass the waveform returned by get_spectrogram /
        get_waveform_all.
    tuning, arrangement, capo: forwarded to ``ArrangementUtils.getArrangementTensor``
        (the tuning type is whatever that helper accepts — TODO confirm)

    Returns the token ids as a list. It starts with SOS and ends with EOS if the
    model emitted EOS within ``dataset.maxTokens`` decoding steps.
    """
    model.eval()
    sos_token = Tokenizer.sosToken
    eos_token = Tokenizer.eosToken
    max_length = dataset.maxTokens

    y_input = torch.tensor([[sos_token]], dtype=torch.long, device=DEVICE)
    tuningAndArrangement = ArrangementUtils.getArrangementTensor(tuning, arrangement, capo)
    tuningAndArrangement = torch.unsqueeze(tuningAndArrangement, dim=0).to(DEVICE)
    spectrogram = torch.unsqueeze(spectrogram, dim=0).to(DEVICE)

    # Inference only: without no_grad each decoding step would extend an autograd graph
    # and keep every intermediate activation alive.
    with torch.no_grad():
        for _ in range(max_length):
            target_mask, token_padding_mask = model.create_masks(y_input)
            target_mask = target_mask.to(DEVICE)
            token_padding_mask = token_padding_mask.to(DEVICE)

            pred = model(spectrogram, tuningAndArrangement, y_input, target_mask, token_padding_mask)

            # Arg-max class id at the last flattened position, i.e. the newest decoding step
            # (assumes the vocabulary is pred's last dim and the sequence is the one before it).
            next_id = pred.topk(1)[1].view(-1)[-1].item()

            # Append the predicted token to the decoder input.
            next_token = torch.tensor([[next_id]], dtype=torch.long, device=DEVICE)
            y_input = torch.cat((y_input, next_token), dim=1)

            # Stop once the model predicts end-of-sequence.
            if next_id == eos_token:
                break

    return y_input.view(-1).tolist()

def get_spectrogram(filepath,location_in_secs):
    """Read one ``dataset.spectrogramSizeInSeconds``-long chunk of audio from a file.

    filepath: audio file readable by torchaudio
    location_in_secs: start of the chunk, in seconds from the beginning of the file

    Despite the name, this returns the waveform (resampled to ``dataset.sample_rate``
    if the file's rate differs), not a mel spectrogram; the mel step is disabled.
    Returns None when the read did not yield exactly the requested number of
    frames (e.g. the chunk runs past the end of the file).
    """
    src_rate = torchaudio.info(filepath).sample_rate
    start_frame = int(location_in_secs * src_rate)
    frames_wanted = int(dataset.spectrogramSizeInSeconds * src_rate)

    audio, loaded_rate = torchaudio.load(
        filepath,
        normalize=True,
        frame_offset=start_frame,
        num_frames=frames_wanted,
    )
    if audio.size(1) != frames_wanted:
        return None

    if loaded_rate == dataset.sample_rate:
        return audio
    return torchaudio.functional.resample(audio, orig_freq=src_rate, new_freq=dataset.sample_rate)

def get_waveform_all(filepath):
    """Load a whole audio file and cut it into consecutive equal-length chunks.

    The audio is resampled to ``dataset.sample_rate`` if needed, then split along
    dim 1 (the time axis of the ``torchaudio.load`` output) into chunks of
    ``dataset.spectrogramSizeInSeconds`` seconds. The tail is zero-padded up to a
    full chunk so every chunk has the same length; presumably the model expects a
    fixed input length (it is constructed with a fixed input shape).

    Returns a tuple of tensors, in file order.
    """
    waveform, sample_rate = torchaudio.load(filepath, normalize=True)
    if sample_rate != dataset.sample_rate:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sample_rate, new_freq=dataset.sample_rate)
    # waveform = mel_spectrogram(waveform)
    chunk_len = int(dataset.spectrogramSizeInSeconds * dataset.sample_rate)
    # torch.split would otherwise leave a shorter final chunk.
    remainder = waveform.size(1) % chunk_len
    if remainder:
        waveform = torch.nn.functional.pad(waveform, (0, chunk_len - remainder))
    return torch.split(waveform, chunk_len, dim=1)

def predict_from_file(filename,location_in_time):
    """Transcribe one chunk of ``filename`` starting at ``location_in_time`` seconds.

    Decodes with ``ArrangementUtils.DSharp_Standard`` tuning, the "lead"
    arrangement and capo 0, then pretty-prints each decoded token.

    Raises ValueError if the file cannot supply a full chunk at that offset.
    """
    tokenizer = GuitarTokenizer(dataset.spectrogramSizeInSeconds, dataset.numberOfTimeTokensPerSecond)
    waveform = get_spectrogram(filename, location_in_time)
    # get_spectrogram returns None on a short read; fail clearly instead of inside predict().
    if waveform is None:
        raise ValueError(
            f"{filename}: not enough audio for a {dataset.spectrogramSizeInSeconds}s chunk at {location_in_time}s"
        )
    tokens = predict(tokenizer, model, waveform, ArrangementUtils.DSharp_Standard, "lead", 0)
    for token in tokens:
        pprint(tokenizer.encoder.decode(token))

def predict_entire_file(filename):
    """Transcribe every fixed-length chunk of ``filename`` in order.

    Uses ``ArrangementUtils.DSharp_Standard`` tuning, the "lead" arrangement and
    capo 0 for each chunk, and prints every decoded token on its own line.
    """
    tokenizer = GuitarTokenizer(dataset.spectrogramSizeInSeconds, dataset.numberOfTimeTokensPerSecond)
    chunks = get_waveform_all(filename)
    for chunk in chunks:
        token_ids = predict(tokenizer, model, chunk, ArrangementUtils.DSharp_Standard, "lead", 0)
        for token_id in token_ids:
            print(tokenizer.encoder.decode(token_id))

In [None]:
# Transcribe a whole song, chunk by chunk.
song_path = r"C:\Users\ritwi\Github\MusicTranscription\Downloads2\S_Tier\greewelc_p\greewelc.ogg"
predict_entire_file(song_path)

In [None]:
# Transcribe a single chunk starting 25 s into the song.
song_path = r"C:\Users\ritwi\Github\MusicTranscription\Downloads2\S_Tier\greewelc_p\greewelc.ogg"
start_seconds = 25.0
predict_from_file(song_path, start_seconds)

In [None]:
# Persist the current model/optimizer state and epoch.
checkpoint_path = "model_weights/massive_first_test_0epoch.pt"
save_model(checkpoint_path)