In [None]:
## Uncomment if using Colab.
## The commands below clone the project repository and copy its requirements file
## and package directories into /content (presumably the notebook's working
## directory on Colab, so the `pip install` and project imports further down
## can find them — TODO confirm).
# !git clone https://github.com/olijacklu/MLSuperb-Project.git

# !cp -r /content/MLSuperb-Project/requirements.txt /content/
# !cp -r /content/MLSuperb-Project/config/ /content/
# !cp -r /content/MLSuperb-Project/data/ /content/
# !cp -r /content/MLSuperb-Project/evaluation/ /content/
# !cp -r /content/MLSuperb-Project/models/ /content/
# !cp -r /content/MLSuperb-Project/training/ /content/

In [None]:
!pip install -r requirements.txt

In [None]:
import os
import json
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import wandb

from config.config import TRAIN_PAIRS, TORCH_DEFAULT_TYPE
from data.preprocess import preprocess_data
from data.dataset import ASRDataset
from data.utils import data_loaders_and_vocab
from models.utils import load_model, clean_memory
from training.monolingual import train_and_evaluate_monolingual
from training.multilingual import train_and_evaluate_multilingual
from evaluation.test import test_model
from evaluation.analysis import analyze_layer_weights

In [None]:
# Important: set `base_dir` to the directory where the data is stored and where
# you wish to save any results.
base_dir = '/content/drive/MyDrive/MVA/NLP/AlgorithmsSpeechNLP'

# Apply the project-configured default floating-point dtype to new tensors.
torch.set_default_dtype(TORCH_DEFAULT_TYPE)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Run the project's preprocessing step and store its result as JSON under
# `base_dir`; the following cell reads the same file back.
datasets = preprocess_data()

dataset_json_path = f'{base_dir}/ml_superb_dataset.json'
with open(dataset_json_path, 'w') as json_file:
    json.dump(datasets, json_file, indent=2)

print(f"Found {len(datasets)} language-source pairs")

In [None]:
# Read the previously saved dataset description back from `base_dir`.
dataset_json_path = f'{base_dir}/ml_superb_dataset.json'
with open(dataset_json_path, 'r') as json_file:
    datasets = json.load(json_file)

print(f"Loaded {len(datasets)} language-source pairs")

In [None]:
# Select the 'fra1' (French monolingual) training pair and build its data
# loaders and vocabulary. Only the fourth returned value, the character
# mappings, is kept; the first three are discarded.
pair = TRAIN_PAIRS['fra1']
_, _, _, char_mappings = data_loaders_and_vocab(
    datasets,
    pair,
    batch_size=32,
    device=device,
)

In [None]:
# fine tuned monolingual models dirs
dir_ft_model_base = os.path.join(base_dir, 'reproduced_models/hubert_base/monolingual/')
dir_ft_model_base_lora = os.path.join(base_dir, 'lora_models/hubert_base/monolingual/')
dir_ft_model_q_lora = os.path.join(base_dir, 'qlora_models/hubert_base/monolingual/')

In [None]:
# Log in to Weights & Biases so the runs started below can be recorded.
# Credentials are requested by wandb itself; none are stored in this notebook.
wandb.login()

**Experiments using the model fine-tuned with the paper's approach**

In [None]:
# Load the baseline fine-tuned model from its checkpoint directory.
# `map_location=device` remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be opened on a CPU-only runtime.
# SECURITY: torch.load unpickles the file; only load checkpoints you trust.
# NOTE(review): recent PyTorch releases default to `weights_only=True`, which
# rejects fully pickled modules — pass `weights_only=False` if the pinned torch
# version requires it (TODO confirm against requirements.txt).
model = torch.load(
    os.path.join(dir_ft_model_base, 'model.pt'),
    map_location=device,
)
model = model.to(device)

In [None]:
# Start the W&B run that records the baseline monolingual evaluation.
wandb.init(
    project="ml-superb",
    name="monolingual_base_fr1",
)

In [None]:
# Evaluate the currently loaded model on the French monolingual pair.
# `pair` is the TRAIN_PAIRS['fra1'] entry selected earlier in the notebook; the
# previous `data_pair` name was never assigned and raised a NameError.
# NOTE(review): `feature_extractor` is not assigned in any cell visible in this
# notebook — define it before running this cell (TODO confirm its source).
test_model(
    model=model,
    feature_extractor=feature_extractor,
    datasets=datasets,
    char_mappings=char_mappings,
    model_type="monolingual",
    data_pair=pair,
    device=device,
    num_samples=50,
)

In [None]:
# Close the active W&B run before the next experiment starts its own.
wandb.finish()

In [None]:
# Drop the reference to the current model, then call the project's
# clean_memory() helper (presumably releases cached memory — TODO confirm)
# before the next checkpoint is loaded.
del model
clean_memory()

**Experiments with the LoRA fine-tuned model**

In [None]:
# Load the LoRA fine-tuned model from its checkpoint directory.
# `map_location=device` remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be opened on a CPU-only runtime.
# SECURITY: torch.load unpickles the file; only load checkpoints you trust.
# NOTE(review): recent PyTorch releases default to `weights_only=True`, which
# rejects fully pickled modules — pass `weights_only=False` if the pinned torch
# version requires it (TODO confirm against requirements.txt).
model = torch.load(
    os.path.join(dir_ft_model_base_lora, 'model.pt'),
    map_location=device,
)
model = model.to(device)

In [None]:
# Start the W&B run that records the LoRA monolingual evaluation.
wandb.init(
    project="ml-superb",
    name="monolingual_lora_fr1",
)

In [None]:
# Evaluate the currently loaded model on the French monolingual pair.
# `pair` is the TRAIN_PAIRS['fra1'] entry selected earlier in the notebook; the
# previous `data_pair` name was never assigned and raised a NameError.
# NOTE(review): `feature_extractor` is not assigned in any cell visible in this
# notebook — define it before running this cell (TODO confirm its source).
test_model(
    model=model,
    feature_extractor=feature_extractor,
    datasets=datasets,
    char_mappings=char_mappings,
    model_type="monolingual",
    data_pair=pair,
    device=device,
    num_samples=50,
)

In [None]:
# Close the active W&B run before the next experiment starts its own.
wandb.finish()

In [None]:
# Drop the reference to the current model, then call the project's
# clean_memory() helper (presumably releases cached memory — TODO confirm)
# before the next checkpoint is loaded.
del model
clean_memory()

**Experiments with the QLoRA fine-tuned model**

In [None]:
# Load the QLoRA fine-tuned model from its checkpoint directory.
# `map_location=device` remaps the stored tensors onto the selected device, so a
# checkpoint saved on a GPU can still be opened on a CPU-only runtime.
# SECURITY: torch.load unpickles the file; only load checkpoints you trust.
# NOTE(review): recent PyTorch releases default to `weights_only=True`, which
# rejects fully pickled modules — pass `weights_only=False` if the pinned torch
# version requires it (TODO confirm against requirements.txt).
model = torch.load(
    os.path.join(dir_ft_model_q_lora, 'model.pt'),
    map_location=device,
)
model = model.to(device)

In [None]:
# Start the W&B run that records the QLoRA monolingual evaluation.
# The run name previously duplicated the LoRA run ("monolingual_lora_fr1"),
# which made the two experiments indistinguishable in the W&B project.
wandb.init(
    project="ml-superb",
    name="monolingual_qlora_fr1",
)

In [None]:
# Evaluate the currently loaded model on the French monolingual pair.
# `pair` is the TRAIN_PAIRS['fra1'] entry selected earlier in the notebook; the
# previous `data_pair` name was never assigned and raised a NameError.
# NOTE(review): `feature_extractor` is not assigned in any cell visible in this
# notebook — define it before running this cell (TODO confirm its source).
test_model(
    model=model,
    feature_extractor=feature_extractor,
    datasets=datasets,
    char_mappings=char_mappings,
    model_type="monolingual",
    data_pair=pair,
    device=device,
    num_samples=50,
)

In [None]:
# Close the active W&B run now that the QLoRA evaluation is done.
wandb.finish()

In [None]:
# Drop the reference to the current model, then call the project's
# clean_memory() helper (presumably releases cached memory — TODO confirm).
del model
clean_memory()