In [None]:
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

from pathlib import Path

import pandas as pd
import pytorch_lightning as pl
import torch
from pyannote.audio import Model, Inference
from pyannote.audio.models.segmentation import PyanNet
from pyannote.database import registry

from tasks import SegmentationKnowledgeDistillationTask
from utils import benchmark

# Load the AMI protocol through the pyannote.database registry.
registry.load_database("AMI-diarization-setup/pyannote/database.yml")
ami = registry.get_protocol("AMI.SpeakerDiarization.word_and_vocalsounds")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed Python, NumPy and torch RNGs (and dataloader workers) so that the
# regularization sweep and the final training run below are reproducible
# under Restart & Run All.
SEED = 42
pl.seed_everything(SEED, workers=True)

# Load the pretrained teacher model.
teacher = Model.from_pretrained("pyannote/segmentation")
# NOTE(review): pyannote's Model.from_pretrained may return None instead of
# raising (e.g. gated model without authentication); fail fast with a clear
# message rather than an AttributeError on the next line.
if teacher is None:
    raise RuntimeError(
        'Could not load "pyannote/segmentation"; check Hugging Face access/authentication.'
    )
# nn.Module.to returns the module itself; assigning the result keeps the cell
# from echoing the full module repr as its output.
teacher = teacher.to(device)

In [None]:
# Compare parameter counts: the pretrained teacher vs. a freshly constructed
# student (SincNet stride 10, single-layer LSTM) attached to a distillation task.
task = SegmentationKnowledgeDistillationTask(ami, teacher=teacher)
teacher_size = sum(param.numel() for param in teacher.parameters())

# The student instance is only needed for counting, so it is discarded afterwards.
# NOTE(review): this student has not gone through Trainer setup; if PyanNet
# builds its output layers lazily at setup time, they are not counted here.
_student_preview = PyanNet(task=task, sincnet={"stride": 10}, lstm={"num_layers": 1})
student_size = sum(param.numel() for param in _student_preview.parameters())
del _student_preview

print(f"Teacher model size: {teacher_size}")
print(f"Student model size: {student_size}")

In [None]:
# Regularization sweep: train a short student run for each knowledge-distillation
# weight, benchmark it, and keep the weight with the lowest diarization error rate.
# Set TRAIN = True to rerun the sweep (slow); otherwise the cached CSV is loaded.
TRAIN = False

SWEEP_EPOCHS = 5
SWEEP_RESULTS_CSV = Path("results/regularization_segmentation.csv")
SWEEP_CHECKPOINT = "models/segmentation/student.ckpt"

if TRAIN:
    # Weights from 1 to 1e6 in alternating x5 / x2 steps.
    regularizations = [
        1, 5, 10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000, 1000000,
    ]
    results = []  # raw benchmark reports, one per weight
    result_list = []  # flattened metrics, one record per weight
    for kd_regularization in regularizations:
        task = SegmentationKnowledgeDistillationTask(
            ami, teacher=teacher, kd_regularization=kd_regularization
        )
        student = PyanNet(task=task, sincnet={"stride": 10}, lstm={"num_layers": 1})
        student.to(device)
        trainer = pl.Trainer(devices=1, max_epochs=SWEEP_EPOCHS)
        trainer.fit(student)
        # Every run overwrites the same checkpoint; it is benchmarked right away.
        trainer.save_checkpoint(SWEEP_CHECKPOINT)

        report = benchmark("pyannote/embedding", SWEEP_CHECKPOINT)
        results.append(report)
        # Build the record in the same iteration so each row is tied to its own
        # weight (no positional pairing of two parallel lists).
        result_list.append(
            {
                "kd_regularization": kd_regularization,
                "diarization_error_rate": report["diarization error rate"]["%"].TOTAL,
                "false_alarm": report["false alarm"]["%"].TOTAL,
            }
        )

    df = pd.DataFrame(result_list)
    SWEEP_RESULTS_CSV.parent.mkdir(parents=True, exist_ok=True)
    # index=False: the default RangeIndex carries no information and would
    # otherwise come back as an "Unnamed: 0" column on reload.
    df.to_csv(SWEEP_RESULTS_CSV, index=False)
else:
    # Caches written by earlier versions include the index; drop it if present.
    df = pd.read_csv(SWEEP_RESULTS_CSV).drop(columns="Unnamed: 0", errors="ignore")

# Pick the weight with the lowest DER. `.loc[row, column]` reads the cell
# directly; selecting the whole (mixed int/float) row first would upcast the
# weight to float. `.item()` converts the NumPy scalar to a plain Python number.
best_row = df["diarization_error_rate"].idxmin()
kd_regularization = df.loc[best_row, "kd_regularization"].item()

In [None]:
# Final run: train the student with the selected regularization weight for
# more epochs, then benchmark it against the pretrained teacher.
NUM_EPOCHS = 10

# NOTE: with TRAIN = True this cell retrains on every Run All; set it to False
# to load the cached comparison CSV instead.
TRAIN = True

COMPARISON_CSV = Path("results/segmentation.csv")
STUDENT_CHECKPOINT = "models/segmentation/student.ckpt"

if TRAIN:
    task = SegmentationKnowledgeDistillationTask(
        ami, teacher=teacher, kd_regularization=kd_regularization
    )
    student = PyanNet(task=task, sincnet={"stride": 10}, lstm={"num_layers": 1})
    student.to(device)
    trainer = pl.Trainer(devices=1, max_epochs=NUM_EPOCHS)
    trainer.fit(student)
    trainer.save_checkpoint(STUDENT_CHECKPOINT)

    benchmark_student = benchmark("pyannote/embedding", STUDENT_CHECKPOINT)
    benchmark_teacher = benchmark("pyannote/embedding", "pyannote/segmentation")

    # One row per model, with the same metric columns as the regularization sweep.
    results = [
        {
            "diarization_error_rate": report["diarization error rate"]["%"].TOTAL,
            "false_alarm": report["false alarm"]["%"].TOTAL,
            "model": model_name,
        }
        for model_name, report in (
            ("teacher", benchmark_teacher),
            ("student", benchmark_student),
        )
    ]

    df = pd.DataFrame(results)
    COMPARISON_CSV.parent.mkdir(parents=True, exist_ok=True)
    # index=False so the reloaded frame has no spurious "Unnamed: 0" column.
    df.to_csv(COMPARISON_CSV, index=False)
else:
    # Caches written by earlier versions include the index; drop it if present.
    df = pd.read_csv(COMPARISON_CSV).drop(columns="Unnamed: 0", errors="ignore")

df