In [1]:
import argparse
import librosa

from coco_mulla.models import CoCoMulla
from coco_mulla.utilities import *
from coco_mulla.utilities.encodec_utils import extract_rvq, save_rvq
from coco_mulla.utilities.symbolic_utils import process_midi, process_chord

from coco_mulla.utilities.sep_utils import separate
from config import TrainCfg
import torch.nn.functional as F

# Global device used by the inference helpers in this notebook
# (get_device presumably comes from the coco_mulla.utilities star import).
device = get_device()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def generate(model_path, batch, num_layers=None, latent_dim=None):
    """Load a CoCoMulla checkpoint and run one no-grad forward pass on ``batch``.

    Args:
        model_path: weights file passed to ``model.load_weights``.
        batch: keyword arguments for the model call (as built by ``wrap_batch``).
        num_layers: model hyper-parameter; must match the checkpoint. When
            omitted, falls back to the notebook-global ``args.num_layers``
            (the original behaviour).
        latent_dim: same as ``num_layers`` but for ``args.latent_dim``.

    Returns:
        Whatever the model returns for the given batch (in inference mode,
        presumably the generated tokens that ``save_pred`` writes out).
    """
    # Explicit parameters avoid a hidden dependency on the global `args`;
    # the fallback keeps existing callers working.
    if num_layers is None:
        num_layers = args.num_layers
    if latent_dim is None:
        latent_dim = args.latent_dim
    model = CoCoMulla(TrainCfg.sample_sec,
                      num_layers=num_layers,
                      latent_dim=latent_dim).to(device)
    model.load_weights(model_path)
    model.eval()
    with torch.no_grad():
        gen_tokens = model(**batch)

    return gen_tokens

In [3]:
def generate_mask(xlen):
    """Build one conditioning mask per generation scenario.

    Args:
        xlen: number of frames each mask spans.

    Returns:
        mask: tensor of shape ``(4, 2, xlen)`` on ``device``; row ``i`` is the
            mask for scenario ``names[i]``. Judging by the scenario names,
            channel 0 gates the MIDI condition and channel 1 the drums
            condition (TODO confirm against the model).
        names: scenario tags, also used as output file names by ``save_pred``.
    """
    names = ["chord-only", "chord-drums", "chord-midi", "chord-drums-midi"]
    # (channel 0, channel 1) switches, one pair per entry of `names`.
    switches = [(0, 0), (0, 1), (1, 0), (1, 1)]
    mask = torch.zeros([len(names), 2, xlen]).to(device)
    for row, (first_on, second_on) in enumerate(switches):
        mask[row, 0] = first_on
        mask[row, 1] = second_on
    return mask, names


In [44]:
def load_data(audio_path, chord_path, midi_path, offset):
    """Load the three conditioning signals for one song and cut the same window.

    Args:
        audio_path: audio file; its "drums" stem is separated and RVQ-encoded.
        chord_path: chord annotation file parsed by ``process_chord``.
        midi_path: MIDI file parsed by ``process_midi`` (used as the piano roll).
        offset: window start in seconds (multiplied by ``TrainCfg.frame_res``
            inside ``crop``); applied to every signal so they stay aligned.

    Returns:
        ``(drums_rvq, midi, chord)`` on ``device``, each cropped to
        ``TrainCfg.sample_sec`` seconds. ``chord`` gets one extra trailing
        channel that is 1 on frames where no chord is active.
    """
    sr = TrainCfg.sample_rate
    res = TrainCfg.frame_res
    sample_sec = TrainCfg.sample_sec

    wav, _ = librosa.load(audio_path, sr=sr, mono=True)
    wav = np2torch(wav).to(device)[None, None, ...]
    wavs = separate(wav, sr)
    drums_rvq = extract_rvq(wavs["drums"], sr=sr)
    chord, _ = process_chord(chord_path)
    midi, _ = process_midi(midi_path)

    # The chord track must be cut from the same window as MIDI and drums,
    # so the offset is applied here too.
    chord = crop(chord[None, ...], "chord", sample_sec, res, offset=offset)
    # Extra "no chord" channel: true where every chord bin is zero.
    pad_chord = chord.sum(-1, keepdims=True) == 0
    chord = np.concatenate([chord, pad_chord], -1)

    midi = crop(midi[None, ...], "midi", sample_sec, res, offset=offset)
    drums_rvq = crop(drums_rvq[None, ...], "drums_rvq", sample_sec, res, offset=offset)

    chord = torch.from_numpy(chord).to(device).float()
    midi = torch.from_numpy(midi).to(device).float()
    drums_rvq = drums_rvq.to(device).long()

    return drums_rvq, midi, chord


def crop(x, mode, sample_sec, res, offset=0):
    """Cut a window of exactly ``int(sample_sec * res) + 1`` frames out of ``x``.

    Args:
        x: for ``mode`` "chord"/"midi", a NumPy array whose time axis is
            dim 1 (e.g. ``(1, T, F)``); for any other mode, a torch tensor
            whose time axis is the last dim (e.g. tokens ``(1, C, T)``).
        mode: "chord", "midi", or any other string for token tensors.
        sample_sec: window length in seconds.
        res: frames per second of ``x``.
        offset: window start in seconds; may be fractional.

    Returns:
        ``x`` cropped along its time axis to the window length. Inputs shorter
        than the window are zero-padded at the end (``offset`` is ignored),
        and a window that runs past the end of ``x`` is zero-padded too.

    Raises:
        AssertionError: if ``offset`` starts at or after the end of ``x``.
    """
    frame_major = mode == "chord" or mode == "midi"
    xlen = x.shape[1] if frame_major else x.shape[-1]
    sample_len = int(sample_sec * res) + 1

    def _pad_end(arr, n):
        # Zero-pad the time axis with n extra frames at the end.
        if frame_major:
            return np.pad(arr, ((0, 0), (0, n), (0, 0)))
        return F.pad(arr, (0, n), "constant", 0)

    if xlen < sample_len:
        return _pad_end(x, sample_len - xlen)

    # int() so fractional offsets still give valid slice indices.
    st = int(offset * res)
    assert xlen > st, f"offset {offset}s is past the end of the {mode} input ({xlen} frames)"
    # Anchor the end on `st` so the window is always exactly sample_len frames.
    ed = st + sample_len
    window = x[:, st: ed] if frame_major else x[:, :, st: ed]
    missing = ed - min(ed, xlen)
    if missing:
        window = _pad_end(window, missing)
    return window


def save_pred(output_folder, tags, pred):
    """Write generated samples to ``output_folder/<tag>``, one path per tag.

    Args:
        output_folder: destination directory (created via ``mkdir``).
        tags: one name per generated sample, in the same order as ``pred``.
        pred: generated tokens handed to ``save_rvq``.
    """
    mkdir(output_folder)
    destinations = []
    for tag in tags:
        destinations.append(os.path.join(output_folder, tag))
    save_rvq(output_list=destinations, tokens=pred)


def wrap_batch(drums_rvq, midi, chord, cond_mask, prompt):
    """Package one song's conditions into a model batch, one row per scenario.

    Each 3-D condition tensor is tiled ``len(cond_mask)`` times along dim 0
    so that row ``i`` pairs with ``cond_mask[i]``; the prompt is repeated
    the same number of times.

    Returns:
        dict of keyword arguments for the model call with ``mode="inference"``
        (``seq`` is None here, whereas training passes the mix tokens).
    """
    n_rows = len(cond_mask)

    def _tile(t):
        return t.repeat(n_rows, 1, 1)

    return dict(
        seq=None,
        desc=[prompt] * n_rows,
        chords=_tile(chord),
        num_samples=n_rows,
        cond_mask=cond_mask,
        drums=_tile(drums_rvq),
        piano_roll=_tile(midi),
        mode="inference",
    )


def inference(args):
    """Run the demo end to end: load conditions, generate every scenario, save.

    Args:
        args: namespace with ``audio_path``, ``chord_path``, ``midi_path``,
            ``offset``, ``prompt_path``, ``model_path`` and ``output_folder``.
            NOTE(review): ``generate`` may read ``num_layers``/``latent_dim``
            from the notebook-global ``args`` rather than from this parameter.

    Returns:
        The generated tokens, which are also written to ``args.output_folder``
        under the scenario names from ``generate_mask``.
    """
    drums_rvq, midi, chord = load_data(audio_path=args.audio_path,
                                       chord_path=args.chord_path,
                                       midi_path=args.midi_path,
                                       offset=args.offset)

    cond_mask, names = generate_mask(drums_rvq.shape[-1])
    # Only the first entry of the prompt file is used as the text description.
    prompt = read_lst(args.prompt_path)[0]
    batch = wrap_batch(drums_rvq, midi, chord, cond_mask, prompt)
    pred = generate(model_path=args.model_path,
                    batch=batch)
    save_pred(output_folder=args.output_folder,
              tags=names,
              pred=pred)
    return pred


In [5]:
import os
from types import SimpleNamespace

# Root of the coco-mulla checkout. Set the COCO_MULLA_REPO environment
# variable to run elsewhere instead of editing every path below.
REPO_DIR = os.environ.get("COCO_MULLA_REPO", "/l/users/fathinah.izzati/coco-mulla-repo")
DEMO_INPUT_DIR = f"{REPO_DIR}/demo/input"

args = SimpleNamespace(
    # Model hyper-parameters; must match the checkpoint at model_path.
    num_layers=48,
    latent_dim=12,
    output_folder=f"{REPO_DIR}/demo/output",
    model_path=f"{REPO_DIR}/diff_9_end.pth",
    audio_path=f"{DEMO_INPUT_DIR}/let_it_be.flac",
    prompt_path=f"{DEMO_INPUT_DIR}/let_it_be.prompt.txt",
    chord_path=f"{DEMO_INPUT_DIR}/let_it_be.flac.chord.lab",
    midi_path=f"{DEMO_INPUT_DIR}/let_it_be.mid.piano.mid",
    drums_path=None,
    # Window start in seconds.
    offset=0,
)

In [7]:
# Load the demo conditions once for interactive inspection.
drums_rvq, midi, chord = load_data(
    audio_path=args.audio_path,
    chord_path=args.chord_path,
    midi_path=args.midi_path,
    offset=args.offset,
)

In [9]:
# Sanity check: all three conditions share one frame count (last dim for the
# drum tokens, dim 1 for the piano roll and chords).
drums_rvq.shape,midi.shape, chord.shape

(torch.Size([1, 4, 1001]),
 torch.Size([1, 1001, 128]),
 torch.Size([1, 1001, 37]))

In [22]:
# One conditioning mask per scenario, sized to the drum tokens' frame count.
cond_mask, names = generate_mask(drums_rvq.shape[-1])

In [38]:
# Number of scenarios, i.e. the number of rows generated in one batch.
num_samples = len(cond_mask)

In [39]:
# Tile each condition once per scenario. New names keep the cell idempotent:
# re-running it no longer re-tiles already-tiled tensors.
midi_batch = midi.repeat(num_samples, 1, 1)
chord_batch = chord.repeat(num_samples, 1, 1)
drums_rvq_batch = drums_rvq.repeat(num_samples, 1, 1)

In [45]:
# End-to-end run: reloads the data from `args`, generates all conditioning
# scenarios in one batch and saves them to args.output_folder.
pred = inference(args)

In [None]:
import argparse
from torch.utils.tensorboard import SummaryWriter

import torch.distributed as dist
from torch.multiprocessing import spawn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
from coco_mulla.utilities.trainer_utils import Trainer

import torch
import torch.nn as nn
import os
from config import TrainCfg
import numpy as np

# Presumably silences the HuggingFace tokenizers fork/parallelism warning once
# worker processes are spawned.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from tqdm import tqdm

from coco_mulla.data_loader.cc_dataset_sampler import Dataset, collate_fn
from coco_mulla.models import CoCoMulla

# NOTE(review): when this cell runs in the notebook it rebinds the global
# `device` the inference helpers use (set from get_device() earlier);
# the training code itself uses a per-replica local `device` in train_dist.
device = "cuda"
# Number of worker processes spawned by main(); assumed to equal the GPU count.
N_GPUS = 4


def _get_free_port():
    import socketserver
    with socketserver.TCPServer(('localhost', 0), None) as s:
        return s.server_address[1]


def get_dataset(rid, dataset_split, sampling_strategy, sampling_prob):
    """Build the training dataset and its distributed DataLoader.

    Args:
        rid: replica id, forwarded to ``Dataset``.
        dataset_split: index selecting the file lists —
            0: closed_dataset_fm only, 1: musdb18 only, 2: both.
        sampling_strategy: forwarded to ``Dataset``.
        sampling_prob: forwarded to ``Dataset``.

    Returns:
        ``(dataset, dataloader)``; ordering is controlled by a
        ``DistributedSampler`` (requires an initialized process group).
    """
    file_lst = ["data/text/musdb18_full.lst",
                "data/text/closed_dataset_fm_full.lst"]
    splits = ((1,), (0,), (0, 1))
    selected = [file_lst[idx] for idx in splits[dataset_split]]

    dataset = Dataset(
        rid=rid,
        path_lst=selected,
        sampling_prob=sampling_prob,
        sampling_strategy=sampling_strategy,
        cfg=TrainCfg)

    sampler = DistributedSampler(dataset)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=TrainCfg.batch_size,
        collate_fn=collate_fn,
        shuffle=False,  # ordering is owned by the sampler
        num_workers=0,
        pin_memory=True,
        drop_last=True)

    return dataset, dataloader


def train_dist(replica_id, replica_count, port, model_dir, args):
    """Entry point of one spawned worker: set up NCCL, build the DDP model, train.

    Args:
        replica_id: this worker's rank; also its CUDA device index.
        replica_count: total number of workers (world size).
        port: localhost port used for the process-group rendezvous.
        model_dir: directory for logs and checkpoints.
        args: parsed command-line arguments.
    """
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)
    torch.distributed.init_process_group('nccl', rank=replica_id, world_size=replica_count)
    try:
        device = torch.device('cuda', replica_id)
        torch.cuda.set_device(device)
        model = CoCoMulla(TrainCfg.sample_sec, num_layers=args.num_layers, latent_dim=args.latent_dim).to(device)
        model.set_training()
        model = DDP(model, [replica_id])
        dataset, dataloader = get_dataset(rid=replica_id, dataset_split=args.dataset,
                                          sampling_strategy=args.sampling_strategy,
                                          sampling_prob=[args.sampling_prob_a, args.sampling_prob_b])

        train(replica_id, model, dataset, dataloader, device, model_dir,
              args.learning_rate)
    finally:
        # Release NCCL communicators so the worker exits cleanly, even on error.
        torch.distributed.destroy_process_group()


def loss_fn(outputs, y):
    """Cross-entropy between model logits and target tokens at masked positions.

    Args:
        outputs: model output exposing ``logits`` and a boolean ``mask`` that
            selects the positions to score.
        y: target token ids, indexed by the same ``mask``.

    Returns:
        Scalar mean cross-entropy over the selected positions, with logits
        reshaped to 2048 classes.
    """
    keep = outputs.mask
    selected_logits = outputs.logits[keep].view(-1, 2048)
    targets = y[keep]
    criterion = nn.CrossEntropyLoss()
    return criterion(selected_logits, targets)


def train(rank, model, dataset, dataloader, device, model_dir, learning_rate):
    """Run the epoch/step training loop for one DDP replica.

    Args:
        rank: replica index; only rank 0 shows a progress bar, logs to
            TensorBoard and saves ``diff_{epoch}_end.pth`` after each epoch.
        model: the DDP-wrapped model (``model.module`` is saved).
        dataset: dataset exposing ``reset_random_seed(seed, epoch)``.
        dataloader: loader over ``dataset``, normally with a DistributedSampler.
        device: this replica's device.
        model_dir: directory for TensorBoard logs and checkpoints.
        learning_rate: passed to ``Trainer``.
    """
    num_steps = len(dataloader)
    epochs = TrainCfg.epoch
    # Per-rank RNG so each replica draws its own dataset seed each epoch.
    rng = np.random.RandomState(569 + rank * 100)
    writer = SummaryWriter(model_dir, flush_secs=20) if rank == 0 else None

    trainer = Trainer(params=model.parameters(), lr=learning_rate, num_epochs=epochs, num_steps=num_steps)

    model = model.to(device)
    sampler = getattr(dataloader, "sampler", None)
    step = 0
    try:
        for e in range(epochs):
            mean_loss = 0.0
            n_element = 0
            model.train()
            # DistributedSampler derives its shuffle order from the epoch set
            # here; without this call every epoch replays the same order.
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(e)

            dl = tqdm(dataloader, desc=f"Epoch {e}") if rank == 0 else dataloader
            r = rng.randint(0, 233333)
            dataset.reset_random_seed(r, e)
            for batch in dl:
                mix = batch["mix"].to(device).long()
                model_inputs = {
                    "seq": mix,
                    "drums": batch["drums"].to(device).long(),
                    "chords": batch["chords"].to(device).float(),
                    "piano_roll": batch["piano_roll"].to(device).float(),
                    "cond_mask": batch["cond_mask"].to(device).long(),
                    "desc": batch["desc"],
                }
                outputs = model(**model_inputs)
                r_loss = loss_fn(outputs, mix)

                grad_1, lr_1 = trainer.step(r_loss, model.parameters())
                # One host sync per step, reused for logging and the mean.
                loss_value = r_loss.item()

                step += 1
                n_element += 1
                if writer is not None:
                    writer.add_scalar("r_loss", loss_value, step)
                    writer.add_scalar("grad_1", grad_1, step)
                    writer.add_scalar("lr_1", lr_1, step)

                mean_loss += loss_value

            if rank == 0:
                with torch.no_grad():
                    # Local (rank-0) mean only; skipped for an empty epoch,
                    # e.g. fewer samples than one batch with drop_last=True.
                    if n_element:
                        writer.add_scalar('train/mean_loss', mean_loss / n_element, step)
                    model.module.save_weights(os.path.join(model_dir, f"diff_{e}_end.pth"))
    finally:
        if writer is not None:
            writer.close()


def main(args):
    """Create ``experiment_folder/experiment_name`` and spawn one trainer per GPU.

    Args:
        args: parsed command-line arguments; forwarded to every worker.
    """
    model_dir = os.path.join(args.experiment_folder, args.experiment_name)
    # makedirs also creates missing parents and tolerates existing dirs
    # (os.mkdir failed when experiment_folder's parent did not exist).
    os.makedirs(model_dir, exist_ok=True)
    world_size = N_GPUS
    port = _get_free_port()
    spawn(train_dist, args=(world_size, port, model_dir, args), nprocs=world_size, join=True)


if __name__ == "__main__":
    # NOTE(review): inside a Jupyter kernel __name__ is also "__main__", so
    # running this cell there makes argparse read the kernel's own argv and
    # exit with an error (and spawn needs importable workers). Run this code
    # as a standalone script.
    parser = argparse.ArgumentParser(description="Distributed CoCoMulla training.")
    parser.add_argument('-d', '--experiment_folder', type=str, required=True,
                        help="parent directory for experiments")
    parser.add_argument('-n', '--experiment_name', type=str, required=True,
                        help="sub-directory for this run's logs and checkpoints")
    parser.add_argument('-l', '--num_layers', type=int,
                        help="CoCoMulla num_layers")
    parser.add_argument('-t', '--text_path', type=str, default=None,
                        help="currently not read by this script")
    parser.add_argument('-r', '--latent_dim', type=int,
                        help="CoCoMulla latent_dim")
    parser.add_argument('-lr', '--learning_rate', type=float,
                        help="learning rate passed to Trainer")
    parser.add_argument('-s', '--sampling_strategy', type=str,
                        help="sampling strategy passed to Dataset")
    parser.add_argument('-a', '--sampling_prob_a', type=float, default=0.,
                        help="first entry of sampling_prob passed to Dataset")
    parser.add_argument('-b', '--sampling_prob_b', type=float, default=0.,
                        help="second entry of sampling_prob passed to Dataset")
    parser.add_argument('-ds', '--dataset', type=int, default=0,
                        help="0: closed_dataset_fm, 1: musdb18, 2: both")

    args = parser.parse_args()
    main(args)


The Coco-Mulla repository is structured into three parts:
1. Input preparation 
This part loads the audio file, separates the drum stem and converts it to RVQ tokens, processes the chord labels, and converts the MIDI file into a piano roll.
2. COCO-Mulla model
The COCO-Mulla model is implemented in the `coco_mulla/models/coco_mulla.py` file. It is described here as having three main components: an encoder, a decoder, and a conditioning module (TODO: confirm against the source). The model is conditioned on a text prompt, frame-level chord vectors, a piano roll derived from the MIDI file, and the drum stem encoded as RVQ tokens. It generates RVQ tokens for the full mix; during training, its logits are compared against the `mix` tokens.
3. Inference
The inference part is implemented in the `coco_mulla/inference.py` file. This script takes a pre-trained COCO-Mulla checkpoint (from the training step) together with an audio file, chord labels, a MIDI file, and a text prompt. It generates one music clip per conditioning scenario.
The provided code snippet is part of the training process.
- The `train_dist` function initializes the distributed training environment, loads the dataset, and trains the COCO-Mulla model.
- The `loss_fn` function computes the cross-entropy loss between the predicted and ground-truth mix RVQ tokens at the masked positions.
- The `train` function is the main training loop. It iterates over the dataset, performs the forward and backward passes, and logs training progress to TensorBoard via `SummaryWriter`.

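There are four conditioning scenarios, listed in the order of the output tags (`chord-only`, `chord-drums`, `chord-midi`, `chord-drums-midi`):
1. conditioned on chords only
2. conditioned on chords and drums
3. conditioned on chords and MIDI
4. conditioned on chords, MIDI, and drums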

Each output is a generated music clip for one scenario; the generated RVQ tokens are written out with `save_rvq`.

I'm looking to replace the conditioning input with video, using a different encoder (a video encoder).
To do so, I think I need to train the RVQ module on video input data.
Question: how do I train the RVQ module for video?


In [1]:
## Motion encoder
import yaml
import importlib

def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"pkg.module.Name"`` to that object.

    Args:
        string: fully qualified name. Everything before the last "." is
            imported as a module, and the last component is looked up on it.
        reload: if True, re-import the module first so source edits are picked up.

    Returns:
        The attribute named by the last component.
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module))
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    """Instantiate ``config["target"]`` with ``config["params"]`` as keyword arguments.

    Args:
        config: mapping with a dotted ``target`` path and optional ``params``.

    Returns:
        The constructed object.

    Raises:
        KeyError: if ``config`` has no ``target`` key.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    # A YAML "params:" entry with no value loads as None; treat it as no params.
    params = config.get("params") or {}
    return get_obj_from_str(config["target"])(**params)

In [2]:
# Read the hyper-parameters saved with the VAURA run; the visual feature
# extractor is described there as a `target`/`params` config.
HPARAMS_PATH = "/l/users/fathinah.izzati/coco-mulla-repo/vaura/logs/24-08-01T08-34-26/vgg-9cb-viscond-avclip_delayed-llama-ib_03/hparams.yaml"
with open(HPARAMS_PATH, "r") as hparams_file:
    hparams = yaml.safe_load(hparams_file)

feature_extractor_config = hparams['feature_extractor_config']
feature_extractor_config['target']

'vaura.models.modules.feature_extractors.avclip.motionformer.MotionFormer'

In [3]:
# Build the MotionFormer from its config and freeze it (no gradients, eval
# mode) so it serves purely as a feature extractor; displays the module.
visual_feature_extractor = instantiate_from_config(feature_extractor_config)
visual_feature_extractor.requires_grad_(False)
visual_feature_extractor.eval()

  from .autonotebook import tqdm as notebook_tqdm


MotionFormer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  )
  (patch_embed_3d): PatchEmbed3D(
    (proj): Conv3d(3, 768, kernel_size=(2, 16, 16), stride=(2, 16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): DividedSpaceTimeBlock(
      (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (attn): DividedAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (timeattn): DividedAttention(
        (qkv): Linear(in_features=768, out_features=2304, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2):

In [4]:
# Identity bridge: visual features pass through unchanged.
visual_bridge = instantiate_from_config(dict(target="torch.nn.Identity"))

In [5]:
DEVICE = "cuda"
# Maximum clip duration in seconds; marked do-not-modify.
MODEL_MAX_DURATION = 2.56
# Clip length in seconds; keep it a multiple of 0.64 s.
DURATION = 2.56
# Presumably the hop between consecutive clips, in seconds.
STRIDE = 1.28

In [6]:
from omegaconf import OmegaConf
dl_cfg = OmegaConf.load("./vaura/data/demo/dataloader_config.yaml")
# Override the clip length before resolving, presumably so interpolated
# duration fields (e.g. in the transforms) pick up DURATION.
dl_cfg["sample_duration"] = DURATION
OmegaConf.resolve(dl_cfg)  # resolve durations

In [7]:
from vaura.utils.train_utils import get_datamodule_from_type

In [8]:
# Build the "motionformer_gen" datamodule from the demo config; its test
# split is what the feature-extraction loop iterates over.
datamodule = get_datamodule_from_type("motionformer_gen", dl_cfg)
datamodule.setup("test")
dataloader = datamodule.test_dataloader()



In [9]:
# Inspect the resolved dataloader config.
dl_cfg

{'batch_size': 1, 'num_workers': 0, 'path_to_metadata': 'vaura/data/demo', 'gen_videos_filepath': None, 'assert_fps': False, 'crop': False, 'partition_video_to_clips': True, 'sample_duration': 2.56, 'audio_transforms_test': [{'target': 'vaura.models.data.transforms.audio_transforms.AudioStereoToMono', 'params': {'keepdim': True}}, {'target': 'vaura.models.data.transforms.audio_transforms.AudioResample', 'params': {'target_sr': 44100, 'clip_duration': 2.56}}, {'target': 'vaura.models.data.transforms.audio_transforms.AudioTrim', 'params': {'duration': 2.56, 'sr': 44100}}], 'video_transforms_test': [{'target': 'vaura.models.data.transforms.video_transforms.Permute', 'params': {'permutation': [0, 3, 1, 2]}}, {'target': 'vaura.models.data.transforms.video_transforms.UniformTemporalSubsample', 'params': {'target_fps': 25, 'clip_duration': 2.56}}, {'target': 'vaura.models.data.transforms.video_transforms.Permute', 'params': {'permutation': [0, 2, 3, 1]}}, {'target': 'torchvision.transforms.v2

In [10]:
# Resolve generation parameters (in compression-model frames).
# MODEL_MAX_DURATION is defined once in the constants cell above; it is not
# redefined here, so there is a single source of truth.
COMPRESSION_MODEL_FRAME_RATE = 86  # do not modify; frames per second


total_gen_len = int(DURATION * COMPRESSION_MODEL_FRAME_RATE)
stride_tokens = int(STRIDE * COMPRESSION_MODEL_FRAME_RATE)

In [11]:
# Only shows the object repr; e.g. len(dataloader) would report the batch count.
dataloader

<torch.utils.data.dataloader.DataLoader at 0x15327bcb5a90>

In [12]:
from tqdm import tqdm
import torch

# Extract visual features for every clip. Results are collected in
# `all_vis_feats`; previously each iteration overwrote the last one.
# NOTE(review): the failed run above could not find 'data/demo/*.mp4'. The
# video paths look relative to the CWD while the config lives under
# 'vaura/data/demo', so run from the vaura/ directory or fix the metadata paths.
all_vis_feats = []
with torch.no_grad():
    for sample in tqdm(dataloader):
        frames = sample["frames"]
        vis_feats, _ = visual_feature_extractor(frames)
        B, S, Tv, D = vis_feats.shape
        # Merge the segment axis and per-segment time axis into one sequence.
        vis_feats = vis_feats.reshape(B, S * Tv, D)
        vis_feats = visual_bridge(vis_feats.detach())
        all_vis_feats.append(vis_feats)

  0%|          | 0/3 [00:00<?, ?it/s]ERROR:vaura.models.data.video_dataset:[Errno 2] No such file or directory: 'data/demo/xK-7W3ZPd3o_94000_104000.mp4'
ERROR:vaura.models.data.video_dataset:[Errno 2] No such file or directory: 'data/demo/76UZQRJq028_181000_191000.mp4'
ERROR:vaura.models.data.video_dataset:[Errno 2] No such file or directory: 'data/demo/xK-7W3ZPd3o_94000_104000.mp4'
ERROR:vaura.models.data.video_dataset:[Errno 2] No such file or directory: 'data/demo/xK-7W3ZPd3o_94000_104000.mp4'
ERROR:vaura.models.data.video_dataset:[Errno 2] No such file or directory: 'data/demo/Vi7kQhNcaOs_114000_124000.mp4'
ERROR:vaura.models.data.video_dataset:[Errno 2] No such file or directory: 'data/demo/xK-7W3ZPd3o_94000_104000.mp4'
ERROR:vaura.models.data.video_dataset:[Errno 2] No such file or directory: 'data/demo/Vi7kQhNcaOs_114000_124000.mp4'
ERROR:vaura.models.data.video_dataset:[Errno 2] No such file or directory: 'data/demo/Vi7kQhNcaOs_114000_124000.mp4'
ERROR:vaura.models.data.video_d

RuntimeError: Video could not be loaded correctly. Tried 10 times.