In [1]:
!git clone https://github.com/felixkreuk/SegFeat

Cloning into 'SegFeat'...
remote: Enumerating objects: 55, done.[K
remote: Counting objects: 100% (21/21), done.[K
remote: Compressing objects: 100% (9/9), done.[K
remote: Total 55 (delta 14), reused 12 (delta 12), pack-reused 34 (from 1)[K
Receiving objects: 100% (55/55), 32.80 MiB | 40.13 MiB/s, done.
Resolving deltas: 100% (23/23), done.


In [2]:
%cd /kaggle/working

/kaggle/working


In [3]:
%%writefile segfeat.patch
--- dataloader.py.orig	2026-01-30 19:09:45
+++ dataloader.py	2026-01-30 19:10:09
@@ -87,14 +87,14 @@
 
     # extract mel-spectrogram
     if hparams.feats == 'mel':
-        spect = librosa.feature.melspectrogram(wav,
+        spect = librosa.feature.melspectrogram(y=wav,
                                                sr=sr,
                                                n_fft=hparams.n_fft,
                                                hop_length=hparams.hop_length,
                                                n_mels=hparams.rnn_input_size)
     # extract mfcc
     elif hparams.feats == 'mfcc':
-        spect = librosa.feature.mfcc(wav,
+        spect = librosa.feature.mfcc(y=wav,
                                      sr=sr,
                                      n_fft=hparams.n_fft,
                                      hop_length=hparams.hop_length,
@@ -208,7 +208,10 @@
         raise NotImplementedError
 
     def process_file(self, wav_path):
-        phn_path = wav_path.replace("wav", "phn")
+        base = wav_path.rsplit(".", 1)[0]
+        phn_path = base + ".phn"
+        if not os.path.exists(phn_path):
+            phn_path = base + ".PHN"
 
         # load audio
         spect = extract_features(wav_path, self.hparams)
@@ -235,7 +238,7 @@
 
     def _make_dataset(self):
         files = []
-        wavs = list(iter_find_files(self.wav_path, "*.wav"))
+        wavs = list(iter_find_files(self.wav_path, "*.wav")) + list(iter_find_files(self.wav_path, "*.WAV"))
         if self.hparams.devrun:
             wavs = wavs[:self.hparams.devrun_size]
 
@@ -265,10 +268,19 @@
         self.data = self._make_dataset()
 
     @staticmethod
-    def get_datasets(hparams):
-        train_dataset = TimitDataset(join(hparams.wav_path, 'train'),
+    def _find_subdir(base, name):
+        """Find a subdirectory case-insensitively."""
+        target = name.lower()
+        for entry in os.listdir(base):
+            if entry.lower() == target and os.path.isdir(join(base, entry)):
+                return join(base, entry)
+        return join(base, name)
+
+    @staticmethod
+    def get_datasets(hparams):
+        train_dataset = TimitDataset(TimitDataset._find_subdir(hparams.wav_path, 'train'),
                                      hparams)
-        test_dataset  = TimitDataset(join(hparams.wav_path, 'test'),
+        test_dataset  = TimitDataset(TimitDataset._find_subdir(hparams.wav_path, 'test'),
                                      hparams)
 
         train_len   = len(train_dataset)


Writing segfeat.patch


In [4]:
%%writefile lightning.patch
--- main.py.orig	2026-01-30 19:07:04
+++ main.py	2026-01-30 19:07:16
@@ -11,7 +11,7 @@
 from loguru import logger
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
-from pytorch_lightning.logging import TestTubeLogger
+from pytorch_lightning.loggers import TensorBoardLogger
 from torch.backends import cudnn
 from torch.utils.data import DataLoader, Dataset
 
@@ -44,15 +44,13 @@
         mode='min'
     )
 
-    tt_logger = TestTubeLogger(
+    tb_logger = TensorBoardLogger(
         save_dir=hparams.run_dir,
         name="lightning_logs",
-        debug=False,
-        create_git_tag=False
     )
 
     checkpoint = ModelCheckpoint(
-        filepath=model_save_path,
+        dirpath=model_save_path,
         save_top_k=1,
         verbose=True,
         monitor='val_f1_at_2',
@@ -60,19 +58,17 @@
     )
 
     trainer = Trainer(
-            logger=tt_logger,
-            overfit_pct=hparams.overfit,
+            logger=tb_logger,
             check_val_every_n_epoch=1,
             min_epochs=1,
             max_epochs=hparams.epochs,
-            nb_sanity_val_steps=4,
-            checkpoint_callback=None,
-            val_percent_check=hparams.val_percent_check,
+            num_sanity_val_steps=4,
+            callbacks=[early_stop, checkpoint],
+            limit_val_batches=hparams.val_percent_check,
             val_check_interval=hparams.val_check_interval,
-            early_stop_callback=None,
-            gpus=hparams.gpus,
-            show_progress_bar=False,
-            distributed_backend=None,
+            devices="auto",
+            accelerator="auto",
+            enable_progress_bar=True,
             )
 
     if not hparams.test:
--- solver.py.orig	2026-01-30 19:07:04
+++ solver.py	2026-01-30 19:07:57
@@ -19,7 +19,7 @@
 class Solver(LightningModule):
     def __init__(self, config):
         super(Solver, self).__init__()
-        self.hparams = config
+        self.save_hyperparameters(config)
 
         if config.dataset == "timit":
             self.datasetClass = TimitDataset
@@ -46,23 +46,23 @@
                         'test':  StatsMeter()}
         self._device = 'cuda' if config.cuda else 'cpu'
 
+        self.validation_step_outputs = []
+        self.test_step_outputs = []
+
         self.build_model()
         logger.info(f"running on {self._device}")
         logger.info(f"rnn input size: {config.rnn_input_size}")
         logger.info(f"{self.segmentor}")
 
-    @pl.data_loader
     def train_dataloader(self):
         self.train_loader = DataLoader(self.train_dataset,
                                        batch_size=self.config.batch_size,
                                        shuffle=True,
                                        collate_fn=collate_fn_padd,
                                        num_workers=6)
-        logger.info(f"input shape: {self.train_dataset[0][0].shape}")
         logger.info(f"training set length {len(self.train_dataset)}")
         return self.train_loader
 
-    @pl.data_loader
     def val_dataloader(self):
         self.valid_loader = DataLoader(self.valid_dataset,
                                        batch_size=self.config.batch_size,
@@ -72,7 +72,6 @@
         logger.info(f"validation set length {len(self.valid_dataset)}")
         return self.valid_loader
 
-    @pl.data_loader
     def test_dataloader(self):
         self.test_loader  = DataLoader(self.test_dataset,
                                        batch_size=self.config.batch_size,
@@ -200,8 +199,6 @@
 
         for output in outputs:
             loss = output[f'{prefix}_loss']
-            if self.trainer.use_dp:
-                loss = torch.mean(loss)
             loss_mean += loss
 
         loss_mean /= len(outputs)
@@ -243,19 +240,28 @@
 
         logger.info(f"\nEVAL {prefix} STATS:\n{json.dumps(metrics, sort_keys=True, indent=4)}\n")
 
-        return metrics
+        for k, v in metrics.items():
+            self.log(k, v, prog_bar=(k == f'{prefix}_f1_at_2'))
 
     def validation_step(self, data_batch, batch_i):
-        return self.generic_eval_step(data_batch, batch_i, 'val')
-
-    def validation_epoch_end(self, outputs):
-        return self.generic_eval_end(outputs, 'val')
+        out = self.generic_eval_step(data_batch, batch_i, 'val')
+        self.validation_step_outputs.append(out)
+        return out
 
+    def on_validation_epoch_end(self):
+        outputs = self.validation_step_outputs
+        self.generic_eval_end(outputs, 'val')
+        self.validation_step_outputs.clear()
+
     def test_step(self, data_batch, batch_i):
-        return self.generic_eval_step(data_batch, batch_i, 'test')
+        out = self.generic_eval_step(data_batch, batch_i, 'test')
+        self.test_step_outputs.append(out)
+        return out
 
-    def test_epoch_end(self, outputs):
-        return self.generic_eval_end(outputs, 'test')
+    def on_test_epoch_end(self):
+        outputs = self.test_step_outputs
+        self.generic_eval_end(outputs, 'test')
+        self.test_step_outputs.clear()
 
     def configure_optimizers(self):
         optimizer = {'adam':     torch.optim.Adam(self.segmentor.parameters(), lr=self.config.lr),


Writing lightning.patch


In [5]:
%cd SegFeat

/kaggle/working/SegFeat


In [6]:
!git apply ../segfeat.patch
!git apply ../lightning.patch

In [7]:
%%writefile requirements.txt
torch
torchaudio
torchvision
pytorch-lightning
boltons
loguru
librosa
numpy
pandas
soundfile
tqdm

Writing requirements.txt


In [8]:
!pip install -r requirements.txt

Collecting boltons (from -r requirements.txt (line 5))
  Downloading boltons-25.0.0-py3-none-any.whl.metadata (6.5 kB)
Collecting loguru (from -r requirements.txt (line 6))
  Downloading loguru-0.7.3-py3-none-any.whl.metadata (22 kB)
Downloading boltons-25.0.0-py3-none-any.whl (194 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.2/194.2 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading loguru-0.7.3-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.6/61.6 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: loguru, boltons
Successfully installed boltons-25.0.0 loguru-0.7.3


In [9]:
!python main.py --wav_path /kaggle/input/darpa-timit-acousticphonetic-continuous-speech/data --dataset timit --delta_feats --dist_feats

[32m2026-01-30 18:24:32.245[0m | [1mINFO    [0m | [36m__main__[0m:[36mmain[0m:[36m28[0m - [1mrun dir: /tmp/segmentation/segmentation_experiment[0m
[32m2026-01-30 18:24:32.254[0m | [1mINFO    [0m | [36m__main__[0m:[36mmain[0m:[36m32[0m - [1msaving log in: /tmp/segmentation/segmentation_experiment/run.log[0m
[32m2026-01-30 18:24:32.254[0m | [1mINFO    [0m | [36m__main__[0m:[36mmain[0m:[36m35[0m - [1msaving models in: /tmp/segmentation/segmentation_experiment/ckpt[0m
[32m2026-01-30 18:24:32.254[0m | [1mINFO    [0m | [36m__main__[0m:[36mmain[0m:[36m36[0m - [1mearly stopping with patience of 5[0m
loading data into memory:   0%|                        | 0/9240 [00:15<?, ?it/s]
Traceback (most recent call last):
  File "/kaggle/working/SegFeat/main.py", line 131, in <module>
    main(args)
  File "/kaggle/working/SegFeat/main.py", line 38, in main
    solver = Solver(hparams)
             ^^^^^^^^^^^^^^^
  File "/kaggle/working/SegFea