# Training a Mamba-based segmentation model from scratch

## Pre-requisites


- To run this notebook, you only need the AISHELL-4 dataset (or any other dataset available as a `pyannote.database` protocol). Scripts that prepare such protocols are available at https://github.com/FrenchKrab/datasets-pyannote.
- You need a working plaqntt installation. Please check the previous notebook if you have trouble setting it up.

## Specifying an architecture

An architecture is specified as a dictionary passed to the `blocks` parameter of the `SdResBlocks` class.

For convenience, we provide a set of premade architectures in [architectures/blocks/](./architectures/blocks/), including the configuration used in the paper: [mamba_plaqntt.yaml](./architectures/blocks/mamba_plaqntt.yaml).



In [None]:
import yaml

# Load the architecture file
with open('architectures/blocks/mamba_plaqntt.yaml', 'r') as file:
    blocks_architecture = yaml.safe_load(file)
blocks_architecture
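
Before passing the dictionary to the model, it can help to look at its overall shape. The helper below is a minimal sketch that summarizes any nested dictionary or list without assuming a particular schema. The name `summarize` and the `example` dictionary are hypothetical and only serve as an illustration; they are not the contents of `mamba_plaqntt.yaml`.

```python
from typing import Any, List


def summarize(node: Any, depth: int = 0, max_depth: int = 3) -> List[str]:
    """Return indented lines describing nested dicts/lists, down to max_depth levels."""
    pad = "  " * depth
    if depth >= max_depth:
        return [f"{pad}..."]  # deeper levels are elided
    lines: List[str] = []
    if isinstance(node, dict):
        items = [(f"{key}:", value) for key, value in node.items()]
    elif isinstance(node, list):
        items = [(f"[{index}]", value) for index, value in enumerate(node)]
    else:
        return [f"{pad}{node!r}"]
    for label, value in items:
        if isinstance(value, (dict, list)):
            # Containers get a one-line header, then their children one level deeper
            lines.append(f"{pad}{label} <{type(value).__name__}, {len(value)} items>")
            lines.extend(summarize(value, depth + 1, max_depth))
        else:
            lines.append(f"{pad}{label} {value!r}")
    return lines


# Hypothetical stand-in for a loaded configuration (NOT the real file contents)
example = {"block_a": {"type": "demo", "layers": [1, 2]}}
print("\n".join(summarize(example)))
```

In the notebook, calling `print("\n".join(summarize(blocks_architecture)))` should give a compact overview of whichever architecture file was loaded.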

Load the pyannote database protocol of your choice.

In [None]:
from pyannote.database import registry, FileFinder

# Adjust this path to point to your own database.yml
DB_YAML_PATH = '/home/aplaquet/work58/databases/database.yml'
PROTOCOL_NAME = 'AISHELL4.SpeakerDiarization.Custom'

registry.load_database(DB_YAML_PATH)
protocol = registry.get_protocol(PROTOCOL_NAME, preprocessors={"audio": FileFinder()})

Prepare the task. The task determines part of the model architecture and the loss function: it sets the number of features (speakers) in the output tensor and whether powerset cross-entropy loss is used.

In [None]:
from pyannote.audio.tasks.segmentation.speaker_diarization import SpeakerDiarization

# Create a small model for this tutorial. To match the paper configuration,
# increase the duration and the maximum number of speakers.
task = SpeakerDiarization(
    protocol=protocol,
    duration=5.0,
    max_speakers_per_chunk=4,
    max_speakers_per_frame=None,    # set only for short-duration models
    batch_size=32,
    balance=['database'],   # no effect here, but useful for compound datasets
    num_workers=1,  # set to the number of available CPU cores
    cache=f'{PROTOCOL_NAME}.protocol.cache'
)
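
To make the effect of these two parameters on the output size concrete, here is a minimal sketch. It assumes the usual pyannote.audio convention: with `max_speakers_per_frame=None` the model has one output per speaker (multilabel), otherwise it has one output per speaker subset of size at most `max_speakers_per_frame` (powerset). The function name `num_output_classes` is hypothetical.

```python
from math import comb
from typing import Optional


def num_output_classes(max_speakers_per_chunk: int,
                       max_speakers_per_frame: Optional[int] = None) -> int:
    """Size of the last dimension of the model output (assumed convention)."""
    if max_speakers_per_frame is None:
        # Multilabel: one independent activation per speaker
        return max_speakers_per_chunk
    # Powerset: one class per subset of speakers with at most
    # max_speakers_per_frame simultaneously active speakers
    return sum(comb(max_speakers_per_chunk, k)
               for k in range(max_speakers_per_frame + 1))


print(num_output_classes(4))       # configuration used in this tutorial
print(num_output_classes(4, 2))    # powerset with at most 2 overlapping speakers
```

With the tutorial settings this gives 4 outputs. Allowing at most 2 overlapping speakers in powerset mode would give 1 + 4 + 6 = 11 classes.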

**NOTE:** You need a recent GPU to initialize this model. Building the architecture runs model inference (to estimate the parameter count), and the current implementation of Mamba only runs on CUDA.
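
A quick way to check this requirement before building the model is sketched below. The helper name `cuda_ready` is hypothetical; it only relies on `torch.cuda.is_available()` and returns `False` if PyTorch is not installed.

```python
def cuda_ready() -> bool:
    """Return True only if PyTorch is importable and reports a usable CUDA device."""
    try:
        import torch
    except ImportError:
        return False
    return bool(torch.cuda.is_available())


if cuda_ready():
    print("CUDA device detected.")
else:
    print("No CUDA device detected: model initialization is expected to fail.")
```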

In [None]:
from plaqntt.pyannote_audio.sdresblocks import SdResBlocks
import torch

mamba_diar = SdResBlocks(
    wav2vec='WAVLM_BASE_PLUS',
    wav2vec_layer=-1,
    blocks=blocks_architecture,
    linear={
        "hidden_size": 128,
        "num_layers": 2,
    },
    task=task,
)

In [None]:
task.prepare_data()  # preload the protocol and save the cache to disk
task.setup()

Now prepare the trainer. Here we train for only one epoch; adjust the settings as you see fit.

In [None]:
from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=1,
    accelerator='gpu',
    gradient_clip_val=1.0,
)

trainer.validate(mamba_diar)
trainer.fit(mamba_diar)