In [None]:
import collections
import glob
import math
import pathlib
from typing import Dict, List, Tuple

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import torchaudio.compliance
import torchaudio.transforms as T
import tqdm.auto as tqdm
from pytorch_lightning.callbacks import EarlyStopping
from torch.utils.data import DataLoader, Dataset


# Record the library versions this notebook was run with (one per line).
print(torch.__version__, torchaudio.__version__, sep="\n")

In [None]:
class WoofalyticsDataset(Dataset):
    """Fixed-length fbank windows cut from ``.wav`` files, each labelled positive or negative.

    Every ``.wav`` in ``data_dir`` is resampled to ``resample_rate`` Hz, turned into 80-bin
    Kaldi fbank features (25 ms frames, ``frame_shift_ms`` shift) and split into overlapping
    windows of ``win_len`` frames. A window is labelled positive when its time span overlaps
    any labelled interval for that file in ``labels_tsv``.
    """

    # Kaldi fbank frame shift in milliseconds; also converts window indices to seconds.
    frame_shift_ms = 10
    # Sample rate every waveform is resampled to before feature extraction.
    resample_rate = 16_000

    def __init__(self, data_dir: str, labels_tsv: str, win_len: int = 6, overlap: int = 3):
        """Load, featurize, window and label every ``.wav`` file in ``data_dir``.

        Args:
            data_dir: Directory containing the ``.wav`` files.
            labels_tsv: Path to the whitespace-separated labels file (see ``_load_labels``).
            win_len: Window length in fbank frames.
            overlap: Frames shared by consecutive windows; must be smaller than ``win_len``.
        """
        # Sorted so the dataset order does not depend on filesystem listing order.
        wave_files = sorted(glob.glob(f"{data_dir}/*.wav"))
        labels_by_file = self._load_labels(labels_tsv)
        data = []
        times = []
        labels = []
        for wave_file in tqdm.tqdm(wave_files):
            file_id = pathlib.Path(wave_file).stem
            feats = self._extract_features(self._load_wave(wave_file))
            if file_id not in labels_by_file:
                print(f"File {file_id} does not have any labels.")
            # NOTE(review): `_load_labels` names the third TSV column "end", but here it is
            # treated as a duration (start + value). Confirm which the labels file contains.
            file_labels = [(start, start + duration)
                           for start, duration in labels_by_file.get(file_id, [])]

            sections = self._extract_overlapping_sections(feats, win_len, overlap)
            for idx, section in enumerate(sections):
                start, end = self._window_times(idx, win_len, overlap)
                data.append(section)
                labels.append(self._is_range_within_any(file_labels, (start, end)))
                times.append((file_id, start, end))

        self.data = data
        self.labels = labels
        self.times = times
        assert len(self.data) == len(self.labels) == len(self.times)

    def _load_labels(self, labels_tsv: str) -> Dict[str, List[Tuple[float, float]]]:
        """Parse the labels file into ``{file_id: [(float, float), ...]}``.

        Each line must have exactly three whitespace-separated fields: file id and two
        floats. Lines with any other number of fields (including blank lines) are printed
        and skipped.
        """
        output = {}
        with open(labels_tsv, "r") as file_handle:
            for idx, line in enumerate(file_handle):
                parts = line.strip().split()
                if len(parts) != 3:
                    print(f"Invalid line: {idx+1}: {line.strip()}: {parts}")
                    continue
                filename, start, end = parts
                output.setdefault(filename, []).append((float(start), float(end)))
        return output

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(features (1, N), target (1, 1) float, (file_id, start, end))``."""
        return (self.data[idx].flatten().unsqueeze(0),
                torch.tensor([[1.] if self.labels[idx] else [0.]]),
                self.times[idx])

    def _window_times(self, idx: int, win_len: int, overlap: int) -> Tuple[float, float]:
        """Approximate (start, end) in seconds of window ``idx``.

        Consecutive windows advance by ``win_len - overlap`` frames of ``frame_shift_ms``
        each; the span is taken as ``win_len`` frame shifts (ignoring that the last frame
        extends past its shift by the frame length).
        """
        start = idx * (win_len - overlap) * self.frame_shift_ms / 1000.
        end = start + win_len * self.frame_shift_ms / 1000.
        return start, end

    def _is_range_within_any(self, input_ranges, given_range):
        """
        Check whether a given range overlaps any of the input ranges.

        Args:
            input_ranges (list): List of tuples representing ranges in the format [(start1, end1), (start2, end2), ...].
            given_range (tuple): The range to check in the format (start, end), with start <= end.

        Returns:
            bool: True if the closed intervals intersect for any input range (this includes
            the given range fully containing a short input range), False otherwise.
        """
        start, end = given_range
        return any(start <= input_end and input_start <= end
                   for input_start, input_end in input_ranges)

    def _extract_features(self, wave_form):
        """80-bin Kaldi fbank features, one row per frame (input assumed at ``resample_rate``)."""
        mel_spectrogram = torchaudio.compliance.kaldi.fbank(num_mel_bins=80,
                                                            frame_length=25,
                                                            frame_shift=self.frame_shift_ms,
                                                            sample_frequency=self.resample_rate,
                                                            waveform=wave_form)
        return mel_spectrogram

    def _extract_overlapping_sections(self, tensor, section_length, overlap):
        """
        Extract overlapping sections from a torch tensor.

        Args:
            tensor (torch.Tensor): Input tensor of shape TxD.
            section_length (int): Length of each section to be extracted (clamped to T).
            overlap (int): Number of rows shared by consecutive sections.

        Returns:
            List of overlapping sections, each with shape section_lengthxD. Empty when T is 0;
            a single section of all T rows when T is too short for a second section.

        Raises:
            ValueError: if ``overlap >= section_length`` (the window would never advance).
        """
        if overlap >= section_length:
            raise ValueError(
                f"overlap ({overlap}) must be smaller than section_length ({section_length})")
        num_frames, _ = tensor.size()
        # Ensure the section_length is not greater than the input tensor size
        section_length = min(section_length, num_frames)
        if section_length == 0:
            return []
        step = section_length - overlap
        if step <= 0:
            # Input no longer than `overlap`: only the first section fits, and a
            # non-positive step would otherwise loop forever.
            return [tensor[:section_length]]
        return [tensor[start:start + section_length]
                for start in range(0, num_frames - section_length + 1, step)]

    def _load_wave(self, filename: str):
        """Load ``filename`` and resample it to ``resample_rate`` Hz, keeping all channels."""
        waveform, sample_rate = torchaudio.load(filename)
        resampler = T.Resample(sample_rate, self.resample_rate, dtype=waveform.dtype)
        return resampler(waveform)



In [None]:
# Training windows. NOTE(review): this cell reads from "../data/" while later cells use "data/".
train_dataset = WoofalyticsDataset(labels_tsv="../data/labels.tsv", data_dir="../data/train/")

In [None]:
# Dev windows, built with the dataset class defined above (`WoofGuardDataset` does not exist).
# NOTE(review): the train cell reads from "../data/" while this reads from "data/"; confirm the root.
dev_dataset = WoofalyticsDataset(data_dir="data/dev/", labels_tsv="data/labels.tsv")

In [None]:
# Test windows, built with the dataset class defined above (`WoofGuardDataset` does not exist).
# NOTE(review): the train cell reads from "../data/" while this reads from "data/"; confirm the root.
test_dataset = WoofalyticsDataset(data_dir="data/test/", labels_tsv="data/labels.tsv")

In [None]:
class WoofClassifier(pl.LightningModule):
    """Small MLP mapping a flattened feature window to a probability in [0, 1]."""

    def __init__(self, input_size):
        """
        Args:
            input_size: Length of the last dimension of the tensors passed to ``forward``.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 32)
        self.output_layer = nn.Linear(32, 1)

        # Per-batch validation metrics; aggregated and cleared in on_validation_epoch_end.
        self.validation_step_outputs = collections.defaultdict(list)

    def forward(self, x):
        """Return sigmoid probabilities; same leading shape as ``x`` with a last dim of 1."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.output_layer(x)
        return torch.sigmoid(x)

    def training_step(self, batch, batch_idx):
        inputs, targets, times = batch
        outputs = self(inputs)
        loss = F.binary_cross_entropy(outputs, targets)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets, times = batch
        outputs = self(inputs)

        loss = F.binary_cross_entropy(outputs, targets)
        self.log('val_loss', loss)

        acc = self.compute_accuracy(outputs, targets)
        # EarlyStopping in the training cell monitors this key.
        self.log('val_acc', acc)

        self.validation_step_outputs["loss"].append(loss)
        self.validation_step_outputs["acc"].append(acc)
        return loss

    def on_validation_epoch_end(self):
        """Average the collected per-batch metrics, log them and reset the buffers.

        Lightning ignores this hook's return value, so the unweighted per-batch means are
        logged as ``val_loss_avg`` / ``val_acc_avg`` to make them visible. The dict is still
        returned for direct callers.
        """
        losses = self.validation_step_outputs["loss"]
        accs = self.validation_step_outputs["acc"]
        if not losses:
            # No validation batches ran; torch.stack([]) would raise.
            self.validation_step_outputs.clear()
            return {}
        avg_loss = torch.stack(losses).mean()
        avg_acc = torch.stack(accs).mean()
        self.validation_step_outputs.clear()
        self.log('val_loss_avg', avg_loss)
        self.log('val_acc_avg', avg_acc)
        return {'val_loss': avg_loss, 'val_acc': avg_acc}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

    def compute_accuracy(self, predictions, labels):
        """Fraction of predictions that match ``labels`` after thresholding at 0.5."""
        binary_predictions = (predictions >= 0.5).float()
        correct_predictions = torch.eq(binary_predictions, labels.float())
        return torch.mean(correct_predictions.float())

In [None]:
# Flattened window length (win_len frames x 80 mel bins), taken from the data so it
# cannot drift from the feature settings.
input_size = train_dataset[0][0].shape[-1]
pl.seed_everything(42)  # reproducible weight init and shuffling
model = WoofClassifier(input_size)

early_stopping_callback = EarlyStopping(monitor='val_acc', mode='max', patience=5, strict=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=32)

# The logger must be handed to the Trainer; previously it was created afterwards and
# never used, so nothing was written to 'logs/'.
tb_logger = pl.loggers.TensorBoardLogger('logs/')
trainer = pl.Trainer(max_epochs=100, callbacks=[early_stopping_callback], logger=tb_logger)

trainer.fit(model, train_loader, dev_loader)

# tb_logger.close()

In [None]:
# Export the trained model by tracing and by scripting, reload both, and check that
# they reproduce the eager model's output on one training example.
model.eval()
example_input = train_dataset[0][0]
traced_model = torch.jit.trace(model, example_input)
scripted_model = torch.jit.script(model)

torch.jit.save(traced_model, "traced_model.pt")
torch.jit.save(scripted_model, "scripted_model.pt")

loaded_traced_model = torch.jit.load("traced_model.pt")
loaded_scripted_model = torch.jit.load("scripted_model.pt")

print(example_input.shape)
with torch.no_grad():
    reference_output = model(example_input)
    traced_output = loaded_traced_model(example_input)
    scripted_output = loaded_scripted_model(example_input)
print(traced_output)
print(scripted_output)
print(reference_output)
torch.testing.assert_close(traced_output, reference_output)
torch.testing.assert_close(scripted_output, reference_output)


In [None]:
from sklearn.metrics import f1_score

def find_best_f1_threshold(predictions, labels, thresholds=None):
    """Grid-search the decision threshold that maximises binary F1.

    Args:
        predictions: Sequence of predicted positive-class scores.
        labels: Sequence of ground-truth 0/1 labels aligned with ``predictions``.
        thresholds: Candidate thresholds; defaults to ``np.arange(0, 1, 0.01)``.

    Returns:
        ``(best_threshold, best_f1)``. Ties keep the earliest candidate; if every
        candidate scores 0 the result is ``(0.0, 0.0)``.
    """
    if thresholds is None:
        thresholds = np.arange(0, 1, 0.01)
    # float64 so comparisons match per-element float32-vs-float64 scalar comparison.
    scores = np.asarray(predictions, dtype=np.float64)
    best_f1 = 0.0
    best_threshold = 0.0

    for threshold in thresholds:
        binary_predictions = (scores >= threshold).astype(int)
        # zero_division=0 keeps the default score (0) for thresholds with no positive
        # predictions, without an UndefinedMetricWarning for every such threshold.
        f1 = f1_score(labels, binary_predictions, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    return best_threshold, best_f1

In [None]:
# Score every dev window (inference only, so no autograd graph), then choose the
# threshold that maximises F1 on the dev set.
model.eval()
preds = []
labels = []
with torch.no_grad():
    for inputs, target, _times in dev_dataset:
        preds.append(model(inputs).item())
        labels.append(target.item())

threshold, f1 = find_best_f1_threshold(preds, labels)
print(threshold, f1)

In [None]:
import collections
# Collect, per test file id, the (start, end) times of windows scoring at or above the
# dev-selected threshold. Each window is scored once, without autograd.
indexes = collections.defaultdict(list)
with torch.no_grad():
    for idx, (data, label, times) in enumerate(test_dataset):
        score = model(data)
        if score >= threshold:
            print(idx, times, score[0][0])
            indexes[times[0]].append((times[1], times[2]))
print("-"*80)

In [None]:
def merge_and_drop_segments(segments, threshold):
    """Merge overlapping or touching (start, end) segments and drop short results.

    Args:
        segments: Iterable of ``(start, end)`` pairs in any order.
        threshold: Minimum duration (``end - start``) a merged segment needs to be kept.

    Returns:
        List of merged ``(start, end)`` tuples sorted by start; ``[]`` for empty input.
    """
    sorted_segments = sorted(segments, key=lambda segment: segment[0])
    if not sorted_segments:
        # Nothing detected (e.g. `indexes[fn]` for a file with no hits).
        return []

    merged_segments = []
    current_start, current_end = sorted_segments[0]
    for start, end in sorted_segments[1:]:
        if start <= current_end:  # overlapping or touching: extend the current segment
            current_end = max(current_end, end)
        else:
            merged_segments.append((current_start, current_end))
            current_start, current_end = start, end
    merged_segments.append((current_start, current_end))

    return [segment for segment in merged_segments if segment[1] - segment[0] >= threshold]

In [None]:
# Detections for one test recording: merge overlapping hits and keep those >= 0.095 s.
fn = "1693562020483807935"
merge_and_drop_segments(segments=indexes[fn], threshold=0.095)

In [None]:
from IPython.display import Audio
def audio(fn, start_seconds=0, end_seconds=0, data_dir="data/test"):
    """Return an IPython ``Audio`` player for ``{data_dir}/{fn}.wav``, optionally clipped.

    Args:
        fn: File id (the ``.wav`` file stem).
        start_seconds: Clip start in seconds.
        end_seconds: Clip end in seconds. When both bounds are 0 the whole file is returned.
        data_dir: Directory containing the ``.wav`` files.

    Raises:
        ValueError: if a clip is requested whose end sample is not after its start sample.
    """
    waveform, sample_rate = torchaudio.load(f"{data_dir}/{fn}.wav")
    if start_seconds == 0 and end_seconds == 0:
        return Audio(waveform.numpy(), rate=sample_rate)
    # Seconds -> sample indices at the file's native rate.
    start_sample = int(start_seconds * sample_rate)
    end_sample = int(end_seconds * sample_rate)
    if end_sample <= start_sample:
        raise ValueError(f"empty clip: start={start_seconds}s, end={end_seconds}s")
    print(start_sample, end_sample, waveform.shape)
    return Audio(waveform.numpy()[:, start_sample:end_sample], rate=sample_rate)

In [None]:
# Play the whole recording (both clip bounds left at 0).
audio(fn, start_seconds=0, end_seconds=0)

In [None]:
# One playable clip per merged detection, followed by each segment and its duration.
merged = merge_and_drop_segments(indexes[fn], 0.095)
audios = [audio(fn, seg_start, seg_end) for seg_start, seg_end in merged]
print(len(audios))
for segment in merged:
    print(segment, segment[1] - segment[0])

In [None]:
# Display the first detection clip (raises IndexError if no clip was built).
audios[0]

In [None]:
# Display the second detection clip (raises IndexError if fewer than two clips were built).
audios[1]

In [None]:
# Raw, non-normalized samples of sample.wav; show the tensor size.
waveform, sample_rate = torchaudio.load("sample.wav", normalize=False)
waveform.size()

In [None]:
# Length of one flattened feature window (second dim of a (1, N) item).
train_dataset[0][0].shape[1]

In [None]:
def infer_file(wav_filename, model, threshold, win_len=6, overlap=3, min_duration=0.095):
    """Detect events in one audio file and return merged (start, end) segments in seconds.

    The pipeline mirrors the training features: resample to 16 kHz, 80-bin Kaldi fbank
    (25 ms frames, 10 ms shift), then overlapping windows of ``win_len`` frames. Windows
    scoring ``>= threshold`` are merged with ``merge_and_drop_segments``.

    Args:
        wav_filename: Path to the audio file.
        model: Callable applied to a ``(1, win_len * 80)`` tensor; its output is compared
            with ``threshold``.
        threshold: Decision threshold on the model output.
        win_len: Window length in fbank frames (must match the model's input size).
        overlap: Frames shared by consecutive windows; must be smaller than ``win_len``.
        min_duration: Merged segments shorter than this many seconds are dropped.

    Returns:
        List of ``(start, end)`` tuples; empty when nothing is detected.
    """
    resample_rate = 16_000
    frame_shift_ms = 10

    def load_wav(filename):
        # Resample to the rate the training features were computed at.
        waveform, sample_rate = torchaudio.load(filename)
        resampler = T.Resample(sample_rate, resample_rate, dtype=waveform.dtype)
        return resampler(waveform)

    def extract_features(wave_form):
        return torchaudio.compliance.kaldi.fbank(num_mel_bins=80,
                                                 frame_length=25,
                                                 frame_shift=frame_shift_ms,
                                                 waveform=wave_form)

    def extract_overlapping_sections(tensor, section_length, overlap):
        # Windows of `section_length` rows advancing by `section_length - overlap`; a
        # non-positive step would never terminate, so it is rejected / short-circuited.
        if overlap >= section_length:
            raise ValueError(
                f"overlap ({overlap}) must be smaller than section_length ({section_length})")
        num_frames, _ = tensor.size()
        section_length = min(section_length, num_frames)
        if section_length == 0:
            return []
        step = section_length - overlap
        if step <= 0:
            return [tensor[:section_length]]
        return [tensor[start:start + section_length]
                for start in range(0, num_frames - section_length + 1, step)]

    feats = extract_features(load_wav(wav_filename))
    sections = extract_overlapping_sections(feats, win_len, overlap)

    result = []
    with torch.no_grad():
        for idx, section in enumerate(sections):
            # Window idx starts (win_len - overlap) frame shifts after window idx - 1.
            start = idx * (win_len - overlap) * frame_shift_ms / 1000.
            end = start + win_len * frame_shift_ms / 1000.
            if model(section.flatten().unsqueeze(0)) >= threshold:
                result.append((start, end))

    if not result:
        return []
    return merge_and_drop_segments(result, min_duration)


In [None]:
# Reload the exported traced model from disk (default device mapping).
loaded_traced_model = torch.jit.load("traced_model.pt", map_location=None)

In [None]:
# Detection on one recording with a hard-coded threshold (not the dev-selected `threshold`).
infer_file(wav_filename="data/1693561903662528518.wav", model=loaded_traced_model, threshold=0.8888888888888888)

In [None]:
# Scratch arithmetic (presumably: number of 30 ms steps in 29.99 s).
int((29.99 * 1000) / 30)

In [None]:
# Interactive docstring lookup (IPython `?` syntax); development leftover.
torchaudio.load?

In [None]:
# Raw (normalize=False) samples of one test recording, used by the streaming cell.
# A separate name keeps the file id stored in `fn` by the earlier cells intact, so
# those cells still work if re-run after this one.
test_wav_path = "data/test/1693562020483807935.wav"
waveform, sample_rate = torchaudio.load(test_wav_path, normalize=False)
waveform, sample_rate

In [None]:
# Size of one model input as returned by the dataset.
train_dataset[0][0].size()

In [None]:
#             data = stream.read(self._chunk)
#             import numpy as np
#             audio_array = np.frombuffer(data, dtype=np.int16)
#             print(audio_array.shape)
#             audio_array = audio_array.reshape((2,-1), order='F')
#             print(audio_array.shape)
#             print(audio_array)

#             filename = f"/tmp/sample.wav"

#             print("data len", len(data))
#             wf = wave.open(filename, 'wb')
#             wf.setnchannels(self._channels)
#             wf.setsampwidth(self._sample_size)
#             wf.setframerate(self._fs)
#             wf.writeframes(data)
#             self._logger.info(f"Stored {filename}")
#             return

In [None]:
# Simulate streaming inference on `waveform`: buffer fixed-size chunks and, once a clip
# is long enough, run the training feature pipeline (resample to 16 kHz, 80-bin fbank)
# and score every overlapping window of `win_len` fbank frames.
win_len = 6
overlap = 3

chunk_size = 512
frame_length_ms, frame_shift_ms = 25, 10
# A clip must span at least `win_len` fbank frames: (win_len - 1) shifts plus one frame.
min_clip_samples = math.ceil(((win_len - 1) * frame_shift_ms + frame_length_ms) * sample_rate / 1000)
chunks_per_clip = math.ceil(min_clip_samples / chunk_size)
stream_resampler = T.Resample(sample_rate, 16_000)

clip = []
count = 0
with torch.no_grad():
    for i in range(0, waveform.shape[1], chunk_size):
        # Keep the (channels, samples) layout; flattening then reshaping mixes channels.
        clip.append(waveform[:, i:i + chunk_size])
        count += 1
        if len(clip) < chunks_per_clip:
            continue
        # Join along time and scale to [-1, 1] (assumes int16 PCM samples).
        clip = torch.cat(clip, dim=1) / torch.iinfo(torch.int16).max
        print(clip.shape)
        mel_spectrogram = torchaudio.compliance.kaldi.fbank(num_mel_bins=80,
                                                            frame_length=frame_length_ms,
                                                            frame_shift=frame_shift_ms,
                                                            waveform=stream_resampler(clip))
        # The model consumes flattened fbank windows (win_len x 80 bins), not raw samples.
        for start in range(0, mel_spectrogram.shape[0] - win_len + 1, win_len - overlap):
            pred = model(mel_spectrogram[start:start + win_len].flatten().unsqueeze(0))
            if pred >= threshold:
                print(pred)
        clip = []
        
#     stereo_chunk = chunk.reshape((2,-1))
#     clip.append(mel_spectrogram)
#     if len(clip) == 6:
#         clip = torch.cat(clip, dim=0).flatten().unsqueeze(0)
#         pred = model(clip)
#         if pred >= threshold:
#             print(pred)
#         clip = []
        

In [None]:
# Size of one model input as returned by the dataset (repeat check).
train_dataset[0][0].size()

In [None]:
# NOTE(review): after the streaming loop `clip` is a list (reset to [] or holding leftover
# chunks), so this call likely fails on a fresh run — confirm what input was intended.
model(clip)

In [None]:
# NOTE(review): despite its name this is a 44.1 kHz -> 16 kHz resampler, not a dtype
# conversion, and `b` is not defined in any visible cell (relies on leftover kernel state).
int16_to_float32 = T.Resample(44100, 16000, dtype=torch.float32)
int16_to_float32(b)

In [None]:
# Scratch arithmetic (presumably a sample count at 16 kHz converted to seconds).
479_818 / 16_000

In [None]:
# NOTE(review): `b` is not defined in any visible cell; this relies on leftover kernel state.
b

In [None]:
# NOTE(review): `a` is not defined in any visible cell; this relies on leftover kernel state.
a

In [None]:
# NOTE(review): `a` is not defined in any visible cell; this relies on leftover kernel state.
a.dtype