In [None]:
pip --quiet install datasets transformers soundfile librosa evaluate jiwer

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m94.6 MB/s[0m eta [36m0:00:00[0m
[?25h

## PARTE 1: IMPORT e CARICAMENTO del dataset LibriSpeech

In [None]:
from datasets import load_dataset, Dataset, Features, Value, Audio, IterableDataset
import datasets
import torchaudio
import torchaudio.transforms as T
from torchaudio.datasets import LIBRISPEECH
import torch
import torch.nn as nn
from torch.utils.data import Subset
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.nn import CTCLoss
from torch.utils.data import DataLoader
from jiwer import wer, cer
import time
import random
import re
import os
from pprint import pprint
from torch.amp import autocast, GradScaler
from torch import amp
import torch.nn.functional as F

In [None]:
# Root directory under which torchaudio stores the LibriSpeech archives.
data_dir = "/content/data"

# Instantiating LIBRISPEECH with download=True fetches each split into data_dir.
# train_dataset and test_dataset are rebound further down in this notebook to
# Hugging Face datasets built from the files on disk.
train_dataset, eval_dataset, test_dataset = (
    LIBRISPEECH(root=data_dir, url=split_name, download=True)
    for split_name in ("train-clean-100", "dev-clean", "test-clean")
)
print("Caricamento dataset LibriSpeech avvenuto!")

In [None]:
def generate_examples(data_path):
    """Collect (id, audio path, transcript) records from a LibriSpeech-style tree.

    Every ``*.trans.txt`` file found under ``data_path`` is read line by line.
    Each line is expected to be ``<utterance-id> <transcript>``; the matching
    audio file is ``<utterance-id>.flac`` in the same directory.

    Directories and files are visited in sorted order so that the returned list
    (and therefore dataset indices such as ``dataset[0]``) is the same on every
    run and on every filesystem.

    Args:
        data_path: Root directory to scan recursively.

    Returns:
        A list of dicts with keys ``"id"``, ``"audio"`` (path to the .flac file)
        and ``"text"``. Lines without a transcript, and lines whose audio file
        does not exist, are skipped.
    """
    examples = []
    for root, dirs, files in os.walk(data_path):
        # Sorting in place makes os.walk descend into sub-directories in a stable order.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(".trans.txt"):
                continue
            with open(os.path.join(root, file), "r", encoding="utf-8") as f:
                for line in f:
                    parts = line.strip().split(" ", 1)
                    if len(parts) < 2:
                        # Blank line or an id without any transcript.
                        continue
                    file_id, text = parts
                    audio_path = os.path.join(root, file_id + ".flac")
                    if os.path.exists(audio_path):
                        examples.append({
                            "id": file_id,
                            "audio": audio_path,
                            "text": text,
                        })
    return examples

# Index the three splits that were downloaded under /content/data/LibriSpeech.
train_examples, dev_examples, test_examples = (
    generate_examples(split_dir)
    for split_dir in (
        "/content/data/LibriSpeech/train-clean-100",
        "/content/data/LibriSpeech/dev-clean",
        "/content/data/LibriSpeech/test-clean",
    )
)

# Schema of a single record.
# NOTE(review): this object is not passed to any call in this cell; the audio
# column is cast explicitly below instead.
features = Features({
    "id": Value("string"),
    "audio": Audio(sampling_rate=16000),
    "text": Value("string"),
})


def _to_audio_dataset(records):
    """Wrap a list of example dicts in a Dataset whose "audio" column is a 16 kHz Audio feature."""
    return Dataset.from_list(records).cast_column("audio", Audio(sampling_rate=16000))


# These rebind train_dataset / test_dataset, which earlier held the torchaudio datasets.
train_dataset = _to_audio_dataset(train_examples)
dev_dataset = _to_audio_dataset(dev_examples)
test_dataset = _to_audio_dataset(test_examples)

In [None]:
# Pick the compute device once: CUDA when a GPU is visible, CPU otherwise.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")

if has_cuda:
    # Report the name and total memory (in GB) of GPU 0.
    print(f"GPU disponibile: {torch.cuda.get_device_name(0)}")
    print(f"Memoria GPU: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
else:
    print("GPU non disponibile, usando CPU")

print(f"Device utilizzato: {device}")

## PARTE 2: VOCABOLARIO

In [None]:
def normalize_text(text):
    """Upper-case a transcript and reduce it to letters A-Z separated by single spaces.

    Every character other than A-Z and whitespace is removed (this includes
    digits, punctuation and apostrophes). Runs of whitespace of any kind
    (spaces, tabs, newlines) are then collapsed to one space, and leading and
    trailing spaces are stripped.

    Args:
        text: Raw transcript string.

    Returns:
        The normalized string; may be empty.
    """
    text = text.upper()
    # Keep all whitespace here: deleting tabs/newlines at this stage would glue
    # neighbouring words together before they can be collapsed to a space below.
    text = re.sub(r"[^A-Z\s]+", "", text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()

def build_vocab_dict(dataset):
    """Create the symbol -> id mapping from the transcripts of ``dataset``.

    Each entry of ``dataset["text"]`` is passed through ``normalize_text``; the
    distinct non-space characters are sorted and numbered from 0. The word
    separator ``"|"`` follows them, then the special tokens ``"<blank>"``,
    ``"<pad>"`` and ``"<unk>"``, in that order.

    Args:
        dataset: Any object whose ``["text"]`` item is an iterable of strings.

    Returns:
        Dict mapping each symbol to its integer id.
    """
    seen_chars = set()
    for transcript in dataset["text"]:
        seen_chars.update(normalize_text(transcript))
    # The space is represented by "|" in the vocabulary, never by itself.
    seen_chars.discard(" ")

    symbols = sorted(seen_chars)
    symbols.append("|")

    vocab_dict = dict(zip(symbols, range(len(symbols))))

    # Special tokens take the next free ids.
    for special_token in ("<blank>", "<pad>", "<unk>"):
        vocab_dict[special_token] = len(vocab_dict)

    return vocab_dict

# Build the character vocabulary from the training transcripts only and print it.
# The same mapping is later handed to the processor and its size sets the
# model's output dimension.
vocab_dict = build_vocab_dict(train_dataset)
print(vocab_dict)

## PARTE 3: PRE_PROCESSING

In [None]:
class SimpleProcessor:
    """Turn raw audio and an optional transcript into model-ready tensors.

    Audio is converted to a float32 tensor, resampled to the target rate when
    the source rate differs, and standardized to zero mean / unit variance.
    Text is mapped character by character to vocabulary ids, with spaces
    encoded as ``"|"`` and unknown characters as ``"<unk>"``.
    """

    def __init__(self, vocab_dict, target_sampling_rate=16000, augment=True):
        """Store the vocabulary and settings.

        Args:
            vocab_dict: Mapping symbol -> id; must contain ``"<unk>"``.
            target_sampling_rate: Sampling rate the audio is resampled to.
            augment: Stored on the instance; no method of this class reads it.
        """
        self.vocab = vocab_dict
        self.target_sr = target_sampling_rate
        self.augment = augment
        # One Resample transform per source rate, created on first use.
        self._resamplers = {}

    def get_resampler(self, orig_sr):
        """Return the cached resampler for ``orig_sr``, or None if no resampling is needed."""
        if orig_sr == self.target_sr:
            return None
        if orig_sr in self._resamplers:
            return self._resamplers[orig_sr]
        transform = torchaudio.transforms.Resample(orig_sr, self.target_sr)
        self._resamplers[orig_sr] = transform
        return transform

    def preprocess_audio(self, audio_array, orig_sr):
        """Convert to float32, resample if required, and standardize the waveform."""
        waveform = torch.tensor(audio_array, dtype=torch.float32)
        transform = self.get_resampler(orig_sr)
        if transform is not None:
            waveform = transform(waveform)
        centered = waveform - waveform.mean()
        # The small constant keeps the division finite for constant signals.
        return centered / (waveform.std() + 1e-5)

    def tokenize_text(self, text):
        """Map each character of ``text`` to its id (space -> "|", unknown -> "<unk>")."""
        token_ids = []
        for symbol in text.replace(" ", "|"):
            token_ids.append(self.vocab.get(symbol, self.vocab["<unk>"]))
        return token_ids

    def __call__(self, audio, sampling_rate, text=None):
        """Process one example.

        Returns:
            Dict with ``"input_values"`` (standardized waveform tensor) and, when
            ``text`` is given, ``"labels"`` (int64 tensor of token ids).
        """
        result = {"input_values": self.preprocess_audio(audio, sampling_rate)}
        if text is None:
            return result
        result["labels"] = torch.tensor(self.tokenize_text(text), dtype=torch.long)
        return result

# Single processor instance shared by the preprocess() function defined below;
# sampling rate and augment are left at their defaults.
processor = SimpleProcessor(vocab_dict)

def preprocess(batch):
    """Convert one dataset example into model inputs and CTC labels.

    The transcript is passed through ``normalize_text`` before tokenization,
    i.e. through the same normalization that ``build_vocab_dict`` applied when
    the vocabulary was created. Without this step, characters that
    normalization removes (apostrophes, for instance) would be absent from the
    vocabulary and end up as ``<unk>`` labels.

    Args:
        batch: A single example with ``"audio"`` (providing ``"array"`` and
            ``"sampling_rate"``) and ``"text"``.

    Returns:
        Dict with ``"input_values"`` (processed waveform) and ``"labels"``
        (token ids of the normalized transcript).
    """
    audio_info = batch["audio"]
    transcript = normalize_text(batch["text"])

    processed = processor(
        audio_info["array"],
        sampling_rate=audio_info["sampling_rate"],
        text=transcript,
    )
    return {
        "input_values": processed["input_values"],
        "labels": processed["labels"]
    }

print("Avvio preprocessing ottimizzato..")


def _apply_preprocessing(split):
    """Run preprocess() over every example of ``split`` in a single process, dropping the original columns."""
    return split.map(
        preprocess,
        remove_columns=split.column_names,
        num_proc=1,
    )


# After this step each split holds only "input_values" and "labels".
train_dataset_processed = _apply_preprocessing(train_dataset)
eval_dataset_processed = _apply_preprocessing(dev_dataset)
test_dataset_processed = _apply_preprocessing(test_dataset)

print("Fine preprocessing!")

### Esempio preprocessing

In [None]:
# Example showing the result of preprocessing: column names, then the first 20
# waveform values and the first 20 label ids of the first item of each split.
print(train_dataset_processed[0].keys())
print(train_dataset_processed[0]["input_values"][:20])
print(train_dataset_processed[0]["labels"][:20])

print(eval_dataset_processed[0]["input_values"][:20])
print(eval_dataset_processed[0]["labels"][:20])

print(test_dataset_processed[0]["input_values"][:20])
print(test_dataset_processed[0]["labels"][:20])

# How to read the output:

# NOTE(review): the vocabulary previously listed here ("'" at index 0, no
# '<blank>' entry) does not match build_vocab_dict as written, which strips
# apostrophes and appends '<blank>', '<pad>', '<unk>' after '|'.
# Assuming every letter A-Z occurs in the training transcripts, the vocabulary is:
#   {'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'J': 9, 'K': 10, 'L': 11, 'M': 12, 'N': 13,
#    'O': 14, 'P': 15, 'Q': 16, 'R': 17, 'S': 18, 'T': 19, 'U': 20, 'V': 21, 'W': 22, 'X': 23, 'Y': 24, 'Z': 25, '|': 26,
#    '<blank>': 27, '<pad>': 28, '<unk>': 29}

# Example: if the first transcript starts with  H A D   L A I D   B E F O R E
# it becomes [H, A, D, '|', L, A, I, D, '|', B, E, F, O, R, E] -> [7, 0, 3, 26, 11, 0, 8, 3, 26, 1, 4, 5, 14, 17, 4]
# labels: [7, 0, 3, 26, 11, 0, 8, 3, 26, 1, 4, 5, 14, 17, 4, 26, ...]

# The same applies to the eval and test splits.

## PARTE 4: MODELLO


In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding for sequence-first tensors.

    A table of shape ``(max_len, 1, d_model)`` is precomputed and registered as
    the buffer ``pe``. Even feature indices hold sines and odd feature indices
    hold cosines of ``position * 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature dimension of the inputs; both even and odd values are
            supported.
        max_len: Largest sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))

        # angles has ceil(d_model / 2) columns: one per sine slot.
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine slot fewer than sine slots, so
        # trim the angles to the number of odd feature indices.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Singleton middle axis lets the table broadcast over the batch axis.
        pe = pe.unsqueeze(1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions to ``x``.

        Args:
            x: Tensor whose first axis is the sequence position and whose last
                axis has size ``d_model``, e.g. ``(seq_len, batch, d_model)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0)]
        return x

class SimpleASRModel(nn.Module):
    """Waveform-to-character acoustic model: log-Mel -> CNN -> Transformer -> linear.

    The forward pass returns per-frame scores in ``(time, batch, vocab)`` layout,
    which is the layout ``nn.CTCLoss`` expects after a log-softmax.

    Args:
        vocab_size: Number of output classes (size of the vocabulary, special
            tokens included).
    """

    def __init__(self, vocab_size):
        super(SimpleASRModel, self).__init__()

        # 1. Audio -> Mel spectrogram (16 kHz input, hop of 160 samples, 80 Mel bins).
        self.melspec = T.MelSpectrogram(
            sample_rate=16000,
            n_fft=512,
            hop_length=160,
            n_mels=80
        )

        # 2. CNN feature extractor: two stride-2 convolutions over time, so the
        #    frame rate is reduced by a factor of 4.
        self.cnn = nn.Sequential(
            nn.Conv1d(80, 256, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv1d(256, 256, kernel_size=5, stride=2, padding=2),
            nn.ReLU()
        )

        # 3. Positional encoding (operates on sequence-first tensors).
        self.positional_encoding = PositionalEncoding(d_model=256)

        # 4. Transformer encoder. forward() feeds it (time, batch, feature)
        #    tensors, so the layers must NOT be batch_first; with
        #    batch_first=True the attention would run across the utterances of
        #    a batch instead of across time. The nested-tensor fast path
        #    requires batch_first=True and is therefore switched off explicitly.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=256,
                nhead=4,
                dim_feedforward=512,
                batch_first=False
            ),
            num_layers=3,
            enable_nested_tensor=False
        )

        # 5. Per-frame classifier over the vocabulary.
        self.classifier = nn.Linear(256, vocab_size)

    @staticmethod
    def log_transform(x):
        """Log-compress a spectrogram; the offset avoids log(0).

        Defined as a method rather than a lambda attribute so that the module
        can be pickled.
        """
        return torch.log(x + 1e-5)

    def forward(self, x):
        """Compute per-frame class scores.

        Args:
            x: Batch of waveforms, shape ``(batch, samples)``.

        Returns:
            Unnormalized scores of shape ``(time, batch, vocab_size)``.
        """
        x = self.melspec(x)              # (batch, n_mels, frames)
        x = self.log_transform(x)
        x = self.cnn(x)                  # (batch, 256, frames')
        x = x.permute(2, 0, 1)           # (frames', batch, 256)
        x = self.positional_encoding(x)
        x = self.transformer(x)
        x = self.classifier(x)           # (frames', batch, vocab_size)

        return x

## PARTE 5: COLLATE


In [None]:
def collate(batch):
    """Assemble a list of processed examples into one padded batch.

    Args:
        batch: List of items, each with ``"input_values"`` and ``"labels"``.

    Returns:
        Dict with
        ``"input_values"``: waveforms zero-padded to the longest one, batch first;
        ``"labels"``: label sequences padded with the ``"<pad>"`` id, batch first;
        ``"input_lengths"`` / ``"label_lengths"``: unpadded lengths of each item.
    """
    waveforms = []
    label_seqs = []
    for sample in batch:
        waveforms.append(torch.tensor(sample["input_values"]))
    for sample in batch:
        label_seqs.append(torch.tensor(sample["labels"]))

    pad = nn.utils.rnn.pad_sequence

    return {
        "input_values": pad(waveforms, batch_first=True),
        "labels": pad(label_seqs, batch_first=True, padding_value=vocab_dict["<pad>"]),
        # Lengths are taken before padding so the loss can ignore the padded part.
        "input_lengths": torch.tensor([len(wave) for wave in waveforms]),
        "label_lengths": torch.tensor([len(seq) for seq in label_seqs]),
    }

## PARTE 6: TRAIN LOOP

In [None]:
# Let cuDNN benchmark convolution algorithms.
torch.backends.cudnn.benchmark = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model: one output class per vocabulary entry, moved to the device and compiled.
model = torch.compile(SimpleASRModel(vocab_size=len(vocab_dict)).to(device))
print("Model on:", next(model.parameters()).device)

# Optimization: AdamW, LR halved after 2 epochs without improvement of the
# monitored value, CTC loss with "<blank>" as the blank class.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)
criterion = CTCLoss(blank=vocab_dict["<blank>"], zero_infinity=True)
# NOTE(review): the scaler is created for "cuda" regardless of the selected device.
scaler = amp.GradScaler(device="cuda")

train_batch_size = 1024
eval_batch_size = 1024

# Settings shared by both loaders.
loader_kwargs = dict(collate_fn=collate, pin_memory=True, num_workers=4)

train_loader = DataLoader(
    train_dataset_processed,
    batch_size=train_batch_size,
    shuffle=True,
    **loader_kwargs,
)
dev_loader = DataLoader(
    eval_dataset_processed,
    batch_size=eval_batch_size,
    shuffle=False,
    **loader_kwargs,
)

print("Numero di batch per epoca:", len(train_loader))

# Reverse lookup (token id -> symbol), used when turning ids back into text.
inv_vocab = {v: k for k, v in vocab_dict.items()}
def greedy_decode(log_probs, vocab_dict):
    """Greedy (best-path) CTC decoding.

    For every utterance the most likely class is taken at each frame, runs of
    identical consecutive classes are collapsed to one, ``<blank>`` and
    ``<pad>`` are dropped, and the remaining ids are mapped back to symbols.
    The word separator ``"|"`` is rendered as a space.

    A blank between two identical classes separates them, so
    ``A, <blank>, A`` decodes to ``"AA"``.

    Args:
        log_probs: Tensor of shape ``(time, batch, vocab)`` with per-frame scores.
        vocab_dict: Mapping symbol -> id; must contain ``"<blank>"`` and ``"<pad>"``.

    Returns:
        List with one decoded string per batch element.
    """
    id_to_symbol = {index: symbol for symbol, index in vocab_dict.items()}
    skip_ids = {vocab_dict["<blank>"], vocab_dict["<pad>"]}

    # (time, batch) -> one plain Python list of ids per utterance; a single
    # transfer instead of one .item() call per frame.
    best_paths = torch.argmax(log_probs, dim=-1).t().tolist()

    pred_texts = []
    for path in best_paths:
        prev = None
        sentence = []
        for idx in path:
            if idx != prev and idx not in skip_ids and idx in id_to_symbol:
                sentence.append(id_to_symbol[idx])
            # Track every frame, blanks included, so that a blank resets the
            # repeat-collapsing and doubled letters can be emitted.
            prev = idx
        pred_texts.append("".join(sentence).replace("|", " "))
    return pred_texts

def compute_output_lengths(input_lengths, num_layers=2, stride=2):
    """Sequence length after ``num_layers`` strided convolutions.

    Each layer is assumed to use an odd kernel size ``k`` with padding
    ``(k - 1) // 2`` (as the model's kernel 5 / padding 2 convolutions do), for
    which the output length is ``ceil(length / stride)``. For the default
    stride of 2 this equals ``(length + 1) // 2``.

    Args:
        input_lengths: Non-negative int, or integer tensor of lengths.
        num_layers: Number of convolution layers applied in sequence.
        stride: Stride shared by all layers.

    Returns:
        Lengths after the convolutions, same type as ``input_lengths``.
    """
    for _ in range(num_layers):
        # Ceil division written with integers only; stays at 0 for empty inputs.
        input_lengths = (input_lengths + stride - 1) // stride
    return input_lengths

def monitor_gpu():
    """Print the allocated and reserved CUDA memory in GB; do nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    bytes_per_gb = 1024**3
    allocated = torch.cuda.memory_allocated() / bytes_per_gb
    reserved = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")

# Early-stopping state: stop after `patience` epochs without a new best training loss.
patience = 10
best_loss = float('inf')
no_improve_epochs = 0
monitor_gpu()

# Mixed precision is only enabled when the model actually runs on CUDA.
use_amp = device.type == "cuda"

# Token id -> symbol, built once; used to rebuild the reference transcripts.
inv_vocab = {v: k for k, v in vocab_dict.items()}

for epoch in range(100):
    model.train()
    total_loss = 0
    start_time = time.time()

    for batch in train_loader:
        inputs = batch["input_values"].to(device)
        targets = batch["labels"].to(device)
        input_lengths = batch["input_lengths"].to(device)
        target_lengths = batch["label_lengths"].to(device)

        # The model's MelSpectrogram uses hop_length=160 with the default
        # centered framing, which yields samples // 160 + 1 frames; the two
        # stride-2 convolutions are accounted for by compute_output_lengths.
        mel_frame_length = input_lengths // 160 + 1
        adjusted_input_lengths = compute_output_lengths(mel_frame_length).to(device)

        optimizer.zero_grad()

        with autocast(device.type, enabled=use_amp):
            outputs = model(inputs)
            log_probs = F.log_softmax(outputs, dim=-1)
            # CTCLoss accepts the padded (batch, max_len) targets directly and
            # uses target_lengths to ignore the padding, so no per-sample
            # slicing (and no per-sample GPU sync) is needed.
            loss = criterion(log_probs, targets, adjusted_input_lengths, target_lengths)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()

    duration = time.time() - start_time
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}, Durata: {duration:.2f}s")

    # ======= EVALUATION ON DEV (epochs 1, 11, 21, ...) ========

    if epoch % 10 == 0:
        model.eval()
        all_preds, all_targets = [], []

        with torch.no_grad():
            for batch in dev_loader:
                inputs = batch["input_values"].to(device)
                targets = batch["labels"]
                target_lengths = batch["label_lengths"]

                outputs = model(inputs)
                log_probs = F.log_softmax(outputs, dim=-1)

                pred_texts = greedy_decode(log_probs.cpu(), vocab_dict)

                # Rebuild each reference transcript from its unpadded label ids.
                for i, length in enumerate(target_lengths):
                    target_ids = targets[i][:length].tolist()
                    target_text = "".join([inv_vocab[token_id] for token_id in target_ids])
                    target_text = target_text.replace("|", " ")
                    all_targets.append(target_text)

                all_preds.extend(pred_texts)

        wer_score = wer(all_targets, all_preds)
        cer_score = cer(all_targets, all_preds)
        print(f"Epoch {epoch+1} - WER: {wer_score:.4f}, CER: {cer_score:.4f}")

        print("\nEsempi di predizioni:")
        for i in range(min(3, len(all_preds))):
            print(f"Ref: {all_targets[i]}")
            print(f"Pred: {all_preds[i]}")
        monitor_gpu()

    # The LR schedule follows the average training loss of the epoch.
    scheduler.step(avg_loss)

    # ======= EARLY STOPPING (also on the average training loss) ========

    if avg_loss < best_loss:
        best_loss = avg_loss
        no_improve_epochs = 0
        print(f"Nuovo best loss: {best_loss:.4f}")
        print("-" * 50)
    else:
        no_improve_epochs += 1
        print(f"Nessun miglioramento della loss ({no_improve_epochs}/{patience})")
        print("-" * 50)
    if no_improve_epochs >= patience:
        print("Early stopping!")
        print("-" * 50)
        break