In [None]:
import warnings

warnings.simplefilter(action="ignore", category=FutureWarning)
warnings.simplefilter(action="ignore", category=UserWarning)
warnings.simplefilter(action="ignore", category=RuntimeWarning)

from tasks import EmbeddingKnowledgeDistillationTask
import pytorch_lightning as pl
import torch
from pyannote.database import registry
from pyannote.audio import Model
from models import ReducedEmbedding

# Register the AMI database configuration and fetch the diarization protocol
registry.load_database("AMI-diarization-setup/pyannote/database.yml")
ami = registry.get_protocol("AMI.SpeakerDiarization.word_and_vocalsounds")

# Use the first CUDA device when one is available, otherwise fall back to CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Fetch the pretrained teacher embedding model and move it onto `device`
teacher = Model.from_pretrained("pyannote/embedding")
teacher.to(device)

In [None]:
# Count the parameters of the teacher and of a freshly constructed student
teacher_size = sum(map(torch.numel, teacher.parameters()))
student_size = sum(map(torch.numel, ReducedEmbedding().parameters()))
print(f"Teacher model size: {teacher_size}")
print(f"Student model size: {student_size}")

# Show the teacher's module structure
print(teacher)

In [None]:
import pandas as pd
from utils import benchmark

# Sweep over the knowledge-distillation regularization weight.
# Set TRAIN to True to re-run the sweep (one 5-epoch training per candidate);
# otherwise the previously saved results are read back from disk.
TRAIN = False

if TRAIN:
    results = []
    # Candidate weights: 1 and 5 per decade, from 1 up to 1e6.
    regularizations = [
        1,
        5,
        10,
        50,
        100,
        500,
        1000,
        5000,
        10000,
        50000,
        100000,
        500000,
        1000000,
    ]
    for kd_regularization in regularizations:
        task = EmbeddingKnowledgeDistillationTask(
            ami, teacher=teacher, kd_regularization=kd_regularization
        )
        student = ReducedEmbedding(task=task)
        student.to(device)
        trainer = pl.Trainer(devices=1, max_epochs=5)
        trainer.fit(student)
        # Every run writes to the same checkpoint path; it is benchmarked
        # right away, before the next iteration overwrites it.
        trainer.save_checkpoint("models/embedding/student.ckpt")

        results.append(
            benchmark("models/embedding/student.ckpt", "pyannote/segmentation")
        )

    # Pair each benchmark report with the weight that produced it.
    # NOTE(review): the reports are presumably metric DataFrames exposing a
    # "TOTAL" row under the "%" sub-column -- confirm against utils.benchmark.
    result_list = [
        {
            "kd_regularization": weight,
            "diarization_error_rate": result["diarization error rate"]["%"].TOTAL,
            "false_alarm": result["false alarm"]["%"].TOTAL,
        }
        for weight, result in zip(regularizations, results)
    ]

    df = pd.DataFrame(result_list)
    # The row index is written as the first (unnamed) CSV column ...
    df.to_csv("results/regularization_embedding.csv")
else:
    # ... so read it back as the index; otherwise the cached frame would gain
    # a spurious "Unnamed: 0" column that the freshly trained frame lacks.
    df = pd.read_csv("results/regularization_embedding.csv", index_col=0)

# Pick the weight with the lowest diarization error rate. A scalar
# `.loc[row, column]` lookup keeps the column's own dtype (selecting the whole
# mixed-dtype row first would upcast the integer weight to float), and
# `.item()` converts the NumPy scalar into a plain Python number.
best_row = df["diarization_error_rate"].idxmin()
kd_regularization = df.loc[best_row, "kd_regularization"].item()

In [None]:
# Final run: train the student with the selected regularization weight and
# compare it against the teacher. With TRAIN set to False the previously saved
# comparison table is loaded instead.
NUM_EPOCHS = 10

TRAIN = True

if TRAIN:
    # `kd_regularization` is the weight selected by the sweep cell above.
    task = EmbeddingKnowledgeDistillationTask(
        ami, teacher=teacher, kd_regularization=kd_regularization
    )
    student = ReducedEmbedding(task=task)
    student.to(device)
    trainer = pl.Trainer(devices=1, max_epochs=NUM_EPOCHS)
    trainer.fit(student)
    trainer.save_checkpoint("models/embedding/student.ckpt")

    # Benchmark the freshly trained student and the pretrained teacher with the
    # same second argument ("pyannote/segmentation").
    benchmark_student = benchmark(
        "models/embedding/student.ckpt", "pyannote/segmentation"
    )
    benchmark_teacher = benchmark("pyannote/embedding", "pyannote/segmentation")

    # One record per model, teacher first, built from a single template
    # instead of two copy-pasted dict literals.
    # NOTE(review): the reports are presumably metric DataFrames exposing a
    # "TOTAL" row under the "%" sub-column -- confirm against utils.benchmark.
    results = [
        {
            "diarization_error_rate": report["diarization error rate"]["%"].TOTAL,
            "false_alarm": report["false alarm"]["%"].TOTAL,
            "model": model_name,
        }
        for model_name, report in (
            ("teacher", benchmark_teacher),
            ("student", benchmark_student),
        )
    ]

    df = pd.DataFrame(results)
    # The row index is written as the first (unnamed) CSV column ...
    df.to_csv("results/embedding.csv")
else:
    # ... so read it back as the index; the printed table then looks the same
    # whether it was just computed or loaded from disk.
    df = pd.read_csv("results/embedding.csv", index_col=0)

print(df)