<a href="https://colab.research.google.com/github/gttae/gitae_githubTest/blob/master/Untitled12.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import deepspeed
import argparse
import random
import pandas as pd
import json
from allennlp.training.metrics import BLEU
from itertools import cycle
from pathlib import Path
import os
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

def get_arguments(argv=None):
    """Build the command-line parser for T5 training and parse arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original behavior. Passing an explicit list is useful when running
            inside a notebook, where ``sys.argv`` holds the kernel's own flags.

    Returns:
        argparse.Namespace holding the options declared below plus whatever
        options ``deepspeed.add_config_arguments`` registers on the parser.
    """
    parser = argparse.ArgumentParser(description='Train T5 on Lakh Midi Dataset Instruments-Lyrics')

    parser.add_argument('--dataset-file', '-df', type=str, required=True,
                        help='Dataset parquet file')

    parser.add_argument('--vocabulary-prefix', '-v', type=str, default='',
                        help='Prefix of the vocab files: <pref>_instrumental.vocab, <pref>_lyric.vocab')

    parser.add_argument('--save-dir', '-sd', type=str, required=True,
                        help='Directory to save checkpoints, states, event logs')

    # The value is used as a multiplier on the dataset length, so it is a
    # fraction in [0, 1], not a percentage.
    parser.add_argument('--train-split', '-ts', type=float, default=0.9,
                        help='Fraction of the dataset (0-1) to use for training')

    parser.add_argument('--epochs', '-e', type=int, default=20,
                        help='Number of epochs')

    parser.add_argument('--validate-every', '-ve', type=int, default=200,
                        help='Validate every n batches')

    parser.add_argument('--generate-every', '-ge', type=int, default=400,
                        help='Generate every n batches')

    parser.add_argument('--print-training-loss-every', '-ptle', type=int, default=20,
                        help='It will average training loss and print it every n steps')

    parser.add_argument('--validate-size', '-vs', type=int, default=40,
                        help='Will calculate average of validation loss for n batches')

    parser.add_argument('--validate-batch-size', '-vss', type=int, default=1,
                        help='Batch size for validation dataset')

    parser.add_argument('--checkpoints-per-epoch', '-cpp', type=int, default=3,
                        help='How many checkpoints to keep per epoch')

    # Filled in by the distributed launcher for each spawned process.
    parser.add_argument('--local_rank', type=int, default=-1,
                        help='Local rank passed from distributed launcher')

    # Let DeepSpeed register its own config options on the same parser.
    parser = deepspeed.add_config_arguments(parser)

    return parser.parse_args(argv)


class MidiDataset(torch.utils.data.Dataset):
    """Tokenized (instrumental -> lyric) pairs loaded from a parquet file.

    The parquet file must provide the columns ``file``, ``instrumental`` and
    ``lyric``. Every row is tokenized once, up front, in ``__init__``.

    Items are ``(input_encoding, target_encoding, file)`` tuples. Both
    encodings are whatever ``tokenizer.encode_plus(..., return_tensors="pt")``
    returns, padded and truncated to ``max_length``.
    """

    # T5 task prefix. It belongs on the source (input) side only; putting it
    # on the targets would train the model to generate the prefix itself.
    TASK_PREFIX = "lyric generation: "

    def __init__(self, dataset_file, tokenizer, max_length):
        """Load ``dataset_file`` and tokenize its inputs and targets.

        Args:
            dataset_file: Path to a parquet file readable by ``pd.read_parquet``.
            tokenizer: Object exposing ``encode_plus`` (e.g. a T5 tokenizer).
            max_length: Length that every sequence is padded/truncated to.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        df = pd.read_parquet(dataset_file)
        self.files = list(df['file'])
        self.inputs = self.prepare_input_sequences(df['instrumental'])
        self.targets = self.prepare_target_sequences(df['lyric'])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx], self.files[idx]

    def prepare_input_sequences(self, sequences):
        """Tokenize source sequences, each prepended with ``TASK_PREFIX``."""
        return [self.tokenize_sequence(seq) for seq in sequences]

    def prepare_target_sequences(self, sequences):
        """Tokenize target sequences as-is (no task prefix).

        NOTE(review): padded positions keep the tokenizer's pad id. Confirm
        that the collate function masks them (e.g. to -100) before they are
        used as labels.
        """
        return [self.tokenize_sequence(seq, prefix="") for seq in sequences]

    def tokenize_sequence(self, sequence, prefix=None):
        """Encode ``prefix + sequence`` to a fixed-length tensor encoding.

        Args:
            sequence: Value to encode; it is converted to text via an f-string.
            prefix: Text prepended to ``sequence``. ``None`` (the default)
                uses ``TASK_PREFIX``, which matches the original behavior for
                existing callers.

        Returns:
            The result of ``tokenizer.encode_plus`` with truncation and
            ``max_length`` padding, as PyTorch tensors.
        """
        if prefix is None:
            prefix = self.TASK_PREFIX
        text = f"{prefix}{sequence}"
        return self.tokenizer.encode_plus(text, max_length=self.max_length, truncation=True, padding='max_length', return_tensors="pt")


def main():
    args = get_arguments()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = T5Tokenizer.from_pretrained("t5-base")

    dataset = MidiDataset(args.dataset_file, tokenizer, max_length=512)

    train_size = int(args.train_split * len(dataset))
    val_size = len(dataset) - train_size

    torch.manual_seed(0)
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    train_log_dir = os.path.join(args.save_dir, 'train')
    val_log_dir = os.path.join(args.save_dir, 'val')
    Path(train_log_dir).mkdir(parents=True, exist_ok=True)
    Path(val_log_dir).mkdir(parents=True, exist_ok=True)
    writer_train = SummaryWriter(log_dir=train_log_dir)
    writer_val = SummaryWriter(log_dir=val_log_dir)

    model = T5ForConditionalGeneration.from_pretrained("t5-base").to(device)

    model_engine, optimizer, trainloader, _ = deepspeed.initialize(args=args, model=model, model_parameters=model.parameters(),  training_data=train_dataset, collate_fn=collate_fn_zero_pad)
    device = model_engine.local_rank

    torch.manual_seed(torch.initial_seed())
    val_loader_ = DataLoader(val_dataset, batch_size=args.validate_batch_size, shuffle=True, collate_fn=collate_fn_zero_pad)
    val_loader = cycle(val_loader_)

    num_batches = (len(train_dataset) + trainloader.batch_size - 1) // trainloader.batch_size

    save_every = num_batches // args.checkpoints_per_epoch
    save_at = 0
    saving_steps = []
    for _ in range(args.checkpoints_per_epoch - 1):
        save_at += save_every
        saving_steps.append(save_at)
    saving_steps.append(num_batches - 1)

    print("\n", "Train Dataset - size: {}, batches: {}".format(len(train_dataset), num_batches), "\n")
    print("\n", "Validate Dataset - size: {}, batches: {}".format(len(val_dataset), len(val_loader_)), "\n")

    checkpoint_name, client_state = model_engine.load_checkpoint(args.save_dir, load_module_strict=False)

    if checkpoint_name is not None:
        print("\nLoaded checkpoint: {}\n".format(checkpoint_name))
        i = client_state['i']
        i += 1
        epoch, step = divmod(i, num_batches)
        print("Epoch: {}, step: {}, i: {}".format(epoch, step, i))
        if step == 0:
            print("Starting next epoch...")
            rng = torch.get_rng_state()
            trainloader = iter(trainloader)
        else:
            rng = torch.load(os.path.join(args.save_dir, 'rng_state.pt'))
            torch.set_rng
