<a href="https://colab.research.google.com/github/buganart/descriptor-transformer/blob/main/descriptor_model_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@markdown Before starting please save the notebook in your drive by clicking on `File -> Save a copy in drive`

In [None]:
#@markdown Check GPU, should be a Tesla V100
# List the GPUs visible to this runtime.
!nvidia-smi -L
import os
# Report how many CPU cores Python sees on this machine.
print(f"We have {os.cpu_count()} CPU cores.")

In [None]:
#@markdown Mount google drive
from google.colab import drive
from google.colab import output
# Mount the user's Google Drive under /content/drive.
drive.mount('/content/drive')

from pathlib import Path
# Fail early, with instructions, when the shared-folder shortcut is missing
# from the mounted Drive.
if not Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database").exists():
    raise RuntimeError(
        "Shortcut to our shared drive folder doesn't exits.\n\n"
        "\t1. Go to the google drive web UI\n"
        "\t2. Right click shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\""
    )

def clear_on_success(msg="Ok!"):
    """Clear the cell output and print ``msg`` when ``_exit_code`` is 0.

    Nothing happens for a non-zero ``_exit_code``, so the output of a failed
    command stays visible.

    NOTE(review): ``_exit_code`` is not defined anywhere in this notebook;
    presumably it is the global IPython sets after a shell command
    (``!cmd`` / ``%pip``) -- confirm.
    """
    if _exit_code != 0:
        return
    output.clear()
    print(msg)

In [None]:
#@markdown Install wandb and log in
%pip install wandb
output.clear()
import wandb
from pathlib import Path
wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
wandb_local_netrc_path = Path("/root/.netrc")
if wandb_drive_netrc_path.exists():
    import shutil

    print("Wandb .netrc file found, will use that to log in.")
    shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
else:
    print(
        f"Wandb config not found at {wandb_drive_netrc_path}.\n"
        f"Using manual login.\n\n"
        f"To use auto login in the future, finish the manual login first and then run:\n\n"
        f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
        f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
        f"Then that file will be used to login next time.\n"
    )

!wandb login
output.clear()
print("ok!")

In [None]:
#@title Configuration

#@markdown Directories can be found via file explorer on the left by navigating into `drive` to the desired folders.
#@markdown Then right-click and `Copy path`.
audio_db_dir = "/content/drive/My Drive/AUDIO DATABASE/MUSIC TRANSFORMER/Transformer Corpus" #@param {type:"string"}
# audio_db_dir = "/content/drive/My Drive/AUDIO DATABASE/TESTING" #@param {type:"string"}
experiment_dir = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/colab-violingan/descriptor-model" #@param {type:"string"}

#@markdown ### Resumption of previous runs
#@markdown Optional resumption argument below; leaving it empty will start a new run from scratch.
#@markdown - The ID can be found on wandb.
#@markdown - It's 8 characters long and may contain a-z letters and digits (for example `1t212ycn`).

#@markdown Resume a previous run
# NOTE(review): the default below is a concrete run ID, so running this cell
# unchanged resumes that run; clear the field to start a new run.
resume_run_id = "4gn7g6xq" #@param {type:"string"}

#@markdown Training arguments
window_size = 15 #@param {type: "integer"}
learning_rate = 1e-4 #@param {type: "number"}
batch_size = 64 #@param {type: "integer"}
epochs = 3000 #@param {type: "integer"}

# log_interval = 10 #@param {type: "integer"}
# A checkpoint is saved and uploaded every `save_interval` epochs.
save_interval = 10 #@param {type: "integer"}
# n_test_samples = 8 #@param {type: "integer"}

# Free-form notes stored with the run configuration.
notes = "" #@param {type: "string"}

import re
from pathlib import Path
from argparse import Namespace

# Turn the form strings into Path objects for the rest of the notebook.
audio_db_dir, experiment_dir = Path(audio_db_dir), Path(experiment_dir)

# The experiment directory is created on demand ...
experiment_dir.mkdir(parents=True, exist_ok=True)

# ... whereas the audio database has to exist already.
if not audio_db_dir.exists():
    raise RuntimeError(f"audio_db_dir {audio_db_dir} does not exists.")

def check_wandb_id(run_id):
    """Validate the format of a wandb run ID.

    An empty or ``None`` ID is accepted (it means "no run to resume").
    Any other value must be exactly 8 characters drawn from ``a-z`` and
    ``0-9``.

    Args:
        run_id: The run ID string to check, or a falsy value.

    Raises:
        RuntimeError: If ``run_id`` is non-empty and malformed.
    """
    # fullmatch + an explicit ASCII class: `^...$` with re.match would accept
    # a trailing newline, and `\d` would accept non-ASCII digits.
    if run_id and not re.fullmatch(r"[0-9a-z]{8}", run_id):
        raise RuntimeError(
            "Run ID needs to be 8 characters long and contain only letters a-z and digits.\n"
            f"Got \"{run_id}\""
        )

# Reject a malformed run ID before anything else is built.
check_wandb_id(resume_run_id)

# Values chosen in the form above.
colab_config = dict(
    audio_db_dir=audio_db_dir,
    experiment_dir=experiment_dir,
    resume_run_id=resume_run_id,
    window_size=window_size,
    learning_rate=learning_rate,
    batch_size=batch_size,
    epochs=epochs,
    save_interval=save_interval,
    notes=notes,
)

# Echo the chosen values, one per line.
for key, value in colab_config.items():
    print(f"=> {key:20}: {value}")

# The run configuration: the form values plus the fixed seed and the model
# hyper-parameters that are not exposed in the form.
config = Namespace(
    **colab_config,
    seed=1234,
    descriptor_size=5,
    hidden_size=100,
    num_layers=3,
)

In [None]:
# NOTE(review): the install is unpinned, while the cells below import
# `pytorch_lightning.callbacks.base` and pass `checkpoint_callback` /
# `resume_from_checkpoint` to the Trainer -- confirm which release range
# provides these and consider pinning it.
%pip install pytorch-lightning
# Replace the pip log with a short message when `_exit_code` is 0.
clear_on_success()

# Load descriptor files

In [None]:
import json
import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import pytorch_lightning as pl

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.loggers import WandbLogger
import pprint

class DataModule_descriptor(pl.LightningDataModule):
    """Data module that builds sliding-window samples from descriptor files.

    Every ``*.txt`` file found recursively under ``config.audio_db_dir`` is
    read as JSON. Each file is expected to hold a list of single-key mappings
    ``{timestamp: {attribute_name: value, ...}}``. The descriptors of a file
    are ordered by integer timestamp and cut into samples: ``window_size``
    consecutive descriptors as input and the descriptor that immediately
    follows them as target.

    Attributes filled by ``setup``:
        attribute_list: Sorted attribute names taken from the first descriptor
            read; it fixes the column order of the arrays.
        dataset_input: Array of shape
            ``(num_samples, window_size, num_attributes)``.
        dataset_target: Array of shape ``(num_samples, 1, num_attributes)``.
    """

    def __init__(self, config):
        """Store the data location and batching parameters from ``config``.

        Args:
            config: Object with ``audio_db_dir`` (a ``pathlib.Path``),
                ``window_size`` and ``batch_size`` attributes.
        """
        super().__init__()
        self.data_path = config.audio_db_dir
        self.attribute_list = None
        self.dataset_input = None
        self.dataset_target = None
        self.window_size = config.window_size
        self.batch_size = config.batch_size

    def setup(self, stage=None):
        """Read all descriptor files and build the input/target arrays.

        Args:
            stage: Unused; accepted for the Lightning hook signature.

        Raises:
            RuntimeError: If no sample could be built (no ``.txt`` files, or
                every file holds ``window_size`` descriptors or fewer).
        """
        window_size = self.window_size
        # Only ".txt" files are supported. Sorting makes the sample order
        # independent of the filesystem enumeration order.
        filepath_list = sorted(
            path
            for path in self.data_path.rglob("*.*")
            if path.suffix == ".txt" and path.is_file()
        )

        attribute_list = []
        dataset_input = []
        dataset_target = []
        for path in tqdm.tqdm(filepath_list, desc="Descriptor Files"):
            # Keep the file open only while it is being parsed.
            with open(path) as json_file:
                data = json.load(json_file)

            data_list = []
            for des in data:
                # Each entry maps a single timestamp to its descriptor dict.
                timestamp = next(iter(des))
                descriptor = des[timestamp]
                if not attribute_list:
                    # The first descriptor read defines the attribute order.
                    attribute_list = sorted(descriptor.keys())
                values = [float(descriptor[k]) for k in attribute_list]
                data_list.append((int(timestamp), values))

            # Order the descriptors by timestamp and drop the timestamps.
            des_array = np.array([values for (_, values) in sorted(data_list)])
            num_des = des_array.shape[0]

            # One sample per position where a full window is followed by at
            # least one more descriptor.
            num_samples = num_des - window_size
            if num_samples <= 0:
                # Too short to yield a sample; an empty entry would also break
                # the concatenation below.
                continue

            input_array = np.stack(
                [des_array[i:i + window_size] for i in range(num_samples)]
            )
            # The target is the descriptor right after the window (index
            # i + window_size). The previous i + 1 + window_size skipped one
            # frame, which disagrees with DescriptorModel.predict appending
            # its prediction directly after the window.
            target_array = des_array[window_size:][:, np.newaxis, :]

            dataset_input.append(input_array)
            dataset_target.append(target_array)

        if not dataset_input:
            raise RuntimeError(
                f"No training samples could be built from the .txt descriptor "
                f"files in {self.data_path}: at least one file must contain "
                f"more than window_size={window_size} descriptors."
            )

        self.attribute_list = attribute_list
        self.dataset_input = np.concatenate(dataset_input, axis=0)
        self.dataset_target = np.concatenate(dataset_target, axis=0)

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the arrays built by ``setup``."""
        dataset = TensorDataset(
            torch.tensor(self.dataset_input, dtype=torch.float32),
            torch.tensor(self.dataset_target, dtype=torch.float32),
        )
        return DataLoader(
            dataset, batch_size=self.batch_size, shuffle=True, num_workers=8
        )

# Model
Only a simple RNN (an LSTM) is implemented.

TODO: add more sophisticated time-series models (such as a transformer).

In [None]:
class DescriptorModel(pl.LightningModule):
    """LSTM followed by a linear layer that predicts the next descriptor.

    The network maps a batch-first sequence of descriptors of shape
    ``(batch, sequence, descriptor_size)`` to an output of the same shape;
    the output at the last sequence position is used as the prediction.
    """

    def __init__(self, config):
        """Build the network from ``config``.

        Args:
            config: Object with ``learning_rate`` and, optionally,
                ``descriptor_size`` (default 5), ``hidden_size`` (default 100)
                and ``num_layers`` (default 3).
        """
        super().__init__()
        self.config = config
        self.save_hyperparameters("config")
        # Read the sizes from the config. The previous chained assignments
        # (`x = config.x = 5`) overwrote the configured values with constants;
        # those constants remain as defaults when an attribute is missing.
        descriptor_size = getattr(config, "descriptor_size", 5)
        hidden_size = getattr(config, "hidden_size", 100)
        num_layers = getattr(config, "num_layers", 3)

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(descriptor_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, descriptor_size)
        self.loss_function = nn.MSELoss()

    def forward(self, x):
        """Run the LSTM and the linear layer over every position of ``x``.

        Args:
            x: Tensor of shape ``(batch, sequence, descriptor_size)``.

        Returns:
            Tensor of shape ``(batch, sequence, descriptor_size)``.
        """
        # nn.LSTM defaults to zero initial states and creates them on the
        # input's device; the explicit torch.zeros(...) used before lived on
        # the CPU and broke as soon as the model ran on a GPU.
        x, _ = self.lstm(x)
        return self.linear(x)

    def training_step(self, batch, batch_idx):
        """Return the MSE between the last-step prediction and the target.

        Args:
            batch: ``(data, target)`` pair; ``target`` is compared against a
                ``(batch, 1, descriptor_size)`` prediction.
            batch_idx: Unused.
        """
        data, target = batch
        output = self(data)
        # Keep the sequence axis so the shape matches the target.
        pred = output[:, -1, :].unsqueeze(1)

        loss = self.loss_function(pred, target)
        # Send the loss to the attached logger.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer using ``config.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.config.learning_rate)

    def predict(self, data, step):
        """Autoregressively generate ``step`` descriptors following ``data``.

        Args:
            data: Tensor of shape ``(batch, window_size, descriptor_size)``.
            step: Number of descriptors to generate.

        Returns:
            NumPy array of shape ``(batch, step, descriptor_size)`` holding
            only the generated descriptors (empty along axis 1 when
            ``step`` is 0).
        """
        all_descriptors = data
        batch_size, window_size, des_size = data.shape
        # No gradients are needed for generation.
        with torch.no_grad():
            for i in range(step):
                # Sliding window: the most recent window_size descriptors.
                window = all_descriptors[:, i:, :]
                pred = self(window)
                # Use the output at the LAST position, as training_step does
                # (index 1 was used before, which is not the trained output).
                new_descriptor = pred[:, -1, :].reshape(batch_size, 1, des_size)
                all_descriptors = torch.cat((all_descriptors, new_descriptor), 1)
        # Slice from window_size instead of -step so step == 0 yields an empty
        # result rather than the whole array.
        return all_descriptors.detach().cpu().numpy()[:, window_size:, :]

# Helper functions

In [None]:
class SaveWandbCallback(Callback):
    """Callback that periodically saves a checkpoint and uploads it to wandb."""

    def __init__(self, log_interval, save_model_path):
        """
        Args:
            log_interval: Save every ``log_interval`` epochs, counted from the
                moment this callback is created.
            save_model_path: File the checkpoint is written to before upload.
        """
        super().__init__()
        self.epoch = 0
        self.log_interval = log_interval
        self.save_model_path = save_model_path

    def on_train_epoch_end(self, trainer, pl_module, outputs=None):
        """Save and upload a checkpoint on every ``log_interval``-th epoch.

        ``outputs`` is unused. It is optional so the hook can be called both
        with and without that third argument.
        """
        if self.epoch % self.log_interval == 0:
            trainer.save_checkpoint(self.save_model_path)
            save_checkpoint_to_cloud(self.save_model_path)
        self.epoch += 1

# function to save/load files from wandb
def save_checkpoint_to_cloud(checkpoint_path):
    """Pass ``checkpoint_path`` to ``wandb.save``."""
    wandb.save(checkpoint_path)


def load_checkpoint_from_cloud(checkpoint_path="model_dict.pth"):
    """Restore a file through ``wandb.restore`` and return its local path.

    Args:
        checkpoint_path: Name of the file to restore.

    Returns:
        The ``name`` of the file object returned by ``wandb.restore``.

    Raises:
        RuntimeError: If ``wandb.restore`` returns ``None`` for the file.
    """
    checkpoint_file = wandb.restore(checkpoint_path)
    if checkpoint_file is None:
        raise RuntimeError(
            f"Could not restore \"{checkpoint_path}\" from wandb."
        )
    # Only the path is needed; close the handle instead of leaking it.
    with checkpoint_file:
        return checkpoint_file.name

#######################         train functions

def save_model_args(config, run):
    """Write ``config`` as JSON into the run directory and upload the file.

    Every attribute of ``config`` is stored as its ``str()`` form in
    ``<run.dir>/model_args.json``.
    """
    filepath = str(Path(run.dir).absolute() / "model_args.json")

    # Stringify every value so the result can always be dumped as JSON.
    stringified = {name: str(value) for name, value in vars(config).items()}

    with open(filepath, "w") as fp:
        json.dump(stringified, fp)
    save_checkpoint_to_cloud(filepath)

def init_wandb_run(config, run_dir="./", mode="run"):
    """Initialise the wandb run used for training and return it.

    The run ID is ``config.resume_run_id`` when that is set, otherwise a
    freshly generated one. The run ID and name are printed.

    Args:
        config: Object with a ``resume_run_id`` attribute.
        run_dir: Directory handed to ``wandb.init`` (made absolute first).
        mode: Forwarded to ``wandb.init`` as ``mode``.

    Returns:
        The object returned by ``wandb.init``.
    """
    # An empty resume ID falls back to a newly generated ID.
    run_id = config.resume_run_id or wandb.util.generate_id()

    run = wandb.init(
        project="descriptor_model",
        id=run_id,
        entity="demiurge",
        resume=True,
        dir=Path(run_dir).absolute(),
        mode=mode,
    )

    print(f"run id: {wandb.run.id}")
    print(f"run name: {wandb.run.name}")
    wandb.watch_called = False
    # run.tags = run.tags + (selected_model,)
    return run

def setup_datamodule(config):
    """Seed NumPy and PyTorch with ``config.seed`` and build the data module."""
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    return DataModule_descriptor(config)

def setup_model(config, run):
    """Create the model, restoring it from wandb when resuming a run.

    Args:
        config: Run configuration; ``config.resume_run_id`` selects whether a
            checkpoint is restored.
        run: The wandb run (kept for interface compatibility).

    Returns:
        ``(model, extra_trainer_args)``. When resuming, ``extra_trainer_args``
        holds ``resume_from_checkpoint`` pointing at the restored file;
        otherwise it is empty.
    """
    if not config.resume_run_id:
        return DescriptorModel(config), {}

    # Use the path reported by the download itself rather than assuming where
    # the file was placed.
    downloaded = load_checkpoint_from_cloud(checkpoint_path="checkpoint.ckpt")
    checkpoint_path = str(Path(downloaded).absolute())

    model = DescriptorModel.load_from_checkpoint(checkpoint_path)
    return model, {"resume_from_checkpoint": checkpoint_path}

def train(config, run, model, dataModule, extra_trainer_args):
    """Train ``model`` on ``dataModule`` and log to the wandb ``run``.

    Args:
        config: Run configuration (``seed``, ``epochs``, ``save_interval``, ...).
        run: wandb run used for logging and as the output directory.
        model: The Lightning module to fit.
        dataModule: The Lightning data module providing the training data.
        extra_trainer_args: Additional keyword arguments for ``pl.Trainer``.
    """
    seed = config.seed
    np.random.seed(seed)
    torch.manual_seed(seed)

    run_dir = Path(run.dir).absolute()

    # Route Lightning's logging through the existing wandb run.
    wandb_logger = WandbLogger(experiment=run, log_model=True, save_dir=run_dir)

    # Record the configuration: in wandb, as a JSON file, and on screen.
    wandb.config.update(config)
    save_model_args(config, run)
    pprint.pprint(vars(config))

    # Checkpoints go to <run dir>/checkpoint.ckpt and are uploaded periodically.
    checkpoint_saver = SaveWandbCallback(
        config.save_interval, str(run_dir / "checkpoint.ckpt")
    )

    trainer = pl.Trainer(
        max_epochs=config.epochs,
        logger=wandb_logger,
        callbacks=[checkpoint_saver],
        default_root_dir=wandb.run.dir,
        checkpoint_callback=None,
        **extra_trainer_args,
    )

    trainer.fit(model, dataModule)

# Train

In [None]:
# Start (or resume) the wandb run; its files are stored under experiment_dir.
# Pass mode="offline" to init_wandb_run to keep the run local.
run = init_wandb_run(config, run_dir=experiment_dir)
datamodule = setup_datamodule(config)
# When resuming, extra_trainer_args carries the restored checkpoint path so
# the Trainer continues from it. Resetting it to {} here discarded that
# information and restarted the trainer state on every resumed run.
model, extra_trainer_args = setup_model(config, run)
train(config, run, model, datamodule, extra_trainer_args)