In [1]:
import torch
import torchaudio

print(torch.__version__)
print(torchaudio.__version__)

torch.random.manual_seed(0)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
print(device)

2.7.1
2.7.1
mps


In [2]:
import os
MEDIA_DIR = "/Volumes/KINGSTON/veteran_interviews"
# index = 0
# audio_path = f"{MEDIA_DIR}/{index}/audio.mp3"
# video_path = f"{MEDIA_DIR}/{index}/video.mp4"
# # check if the audio file exists
# if not os.path.exists(audio_path):
#     raise FileNotFoundError(f"Audio file not found: {audio_path}")
# # check if the video file exists
# if not os.path.exists(video_path):
#     raise FileNotFoundError(f"Video file not found: {video_path}")

# develop a function to return media type, and audio/ video path based on index
def get_media_type_and_path(index):
    audio_path = f"{MEDIA_DIR}/{index}/audio.mp3"
    video_path = f"{MEDIA_DIR}/{index}/video.mp4"
    if os.path.exists(video_path):
        return "video", video_path
    elif os.path.exists(audio_path):
        return "audio", audio_path
    else:
        return False
    
get_media_type_and_path(0)  # Test the function with index 0

('video', '/Volumes/KINGSTON/veteran_interviews/0/video.mp4')

In [3]:
# load parquet file
import pandas as pd
df = pd.read_parquet("../datasets/veterans_history_project_resources.parquet")
df['media_type'] = df.index.to_series().apply(lambda x: get_media_type_and_path(x)[0] if get_media_type_and_path(x) else None)
df['media_filepath'] = df.index.to_series().apply(lambda x: get_media_type_and_path(x)[1] if get_media_type_and_path(x) else None)

In [4]:
# create a sample dataset of 100 rows where returned media type is audio
df_sample = df[df['media_type'] == 'audio'].sample(n=10, random_state=42)

In [5]:
import pprint
# testing_raw_transcript = df_sample['fulltext_file_str'][730]
# strip content from the xml (only get the text between <p> tags)
import re
def strip_xml_tags(text):
    # make sure the input is a string
    if not isinstance(text, str):
        # force it to be a string
        text = str(text)
    # Remove all XML tags except <p> and <speaker>
    # Find the position of the first <speaker> tag
    speaker_match = re.search(r'<speaker>.*?</speaker>', text, re.DOTALL)
    if not speaker_match:
        return ""
    start_pos = speaker_match.end()
    # Only search for <p>...</p> after the first <speaker>
    paragraphs = re.findall(r'<p>(.*?)</p>', text[start_pos:], re.DOTALL)
    # Remove everything after "[Conclusion of Interview]"
    result = []
    for para in paragraphs:
        if "[Conclusion of Interview]" in para:
            break
        result.append(para)
    return '\n'.join(result)
# testing_raw_transcript_stripped = strip_xml_tags(testing_raw_transcript)
# apply this function to the fulltext_file_str column
df_sample['raw_transcript_stripped'] = df_sample['fulltext_file_str'].apply(strip_xml_tags)

In [6]:
df_sample['raw_transcript_stripped']

730     Joe, give us the basic identifying information...
1278    Good morning. Today is January 14, 2012. My na...
3545    This is Ashley Hancher interviewing Thomas Mar...
3691    It's Monday, May 31, 2014, Memorial Day. We ar...
1176    Hi. My name is Megan Schwartz. Today is Januar...
4199    Every morning I went by the path that went int...
5409    Oral History interview of World War II Veteran...
3601    The interview is being conducted for the Veter...
3418    This tape, it is July, 2003. My name is Julie ...
3417    My name is Steve Estes and today is August 9, ...
Name: raw_transcript_stripped, dtype: object

In [7]:
df_sample['media_filepath']

730     /Volumes/KINGSTON/veteran_interviews/730/audio...
1278    /Volumes/KINGSTON/veteran_interviews/1278/audi...
3545    /Volumes/KINGSTON/veteran_interviews/3545/audi...
3691    /Volumes/KINGSTON/veteran_interviews/3691/audi...
1176    /Volumes/KINGSTON/veteran_interviews/1176/audi...
4199    /Volumes/KINGSTON/veteran_interviews/4199/audi...
5409    /Volumes/KINGSTON/veteran_interviews/5409/audi...
3601    /Volumes/KINGSTON/veteran_interviews/3601/audi...
3418    /Volumes/KINGSTON/veteran_interviews/3418/audi...
3417    /Volumes/KINGSTON/veteran_interviews/3417/audi...
Name: media_filepath, dtype: object

In [8]:
# save the sample dataset to a new parquet file
df_sample.to_parquet("../datasets/veterans_history_project_sample.parquet", index=False)

In [9]:
from datasets import load_dataset
# load parquet file as a Hugging Face dataset
hf_dataset = load_dataset('parquet', data_files="../datasets/veterans_history_project_sample.parquet")

  from .autonotebook import tqdm as notebook_tqdm
Generating train split: 10 examples [00:00, 1479.94 examples/s]


In [3]:
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

In [3]:
import torchaudio

def preprocess_audio(file_path):
    waveform, sample_rate = torchaudio.load(file_path)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)
    return waveform

def normalize_text(text):
    text = text.lower()
    text = re.sub(r'[^a-z\s]', '', text)
    return text

def prepare_dataset(batch):
    audio = preprocess_audio(batch["media_filepath"])
    batch["input_values"] = processor(audio, sampling_rate=16000).input_values[0]
    batch["labels"] = processor.tokenizer(batch["raw_transcript_stripped"]).input_ids
    return batch

In [12]:
dataset = hf_dataset.map(prepare_dataset)

Map: 100%|██████████| 10/10 [08:35<00:00, 51.51s/ examples]


In [37]:
# save dataset to disk
dataset.save_to_disk("../datasets/veterans_history_project_sample_processed")

Saving the dataset (8/8 shards): 100%|██████████| 10/10 [00:12<00:00,  1.21s/ examples]


In [2]:
# retrieve the dataset from disk
from datasets import load_from_disk
dataset = load_from_disk("../datasets/veterans_history_project_sample_processed")

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from datasets import DatasetDict

# Split 90% train, 10% validation
split_dataset = dataset['train'].train_test_split(test_size=0.1)

# Rename 'test' split to 'validation'
split_dataset = DatasetDict({
    'train': split_dataset['train'],
    'validation': split_dataset['test']
})

In [None]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
from importlib import reload
from transformers import Trainer, TrainingArguments
import accelerate
reload(accelerate)
# reload(TrainingArguments)
# reload(Trainer)

training_args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=1,
    eval_strategy="no",
    num_train_epochs=2,
    fp16=False,
    save_strategy="no",
    dataloader_num_workers=0,
    logging_steps=10,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["validation"],
    tokenizer=processor.feature_extractor,
)

  trainer = Trainer(


In [10]:
import torch
print(torch.backends.mps.is_available())  # Should be True
print(torch.backends.mps.is_built())       # Should be True

True
True


In [9]:
trainer.train()



KeyboardInterrupt: 

END for HF train here^

In [3]:
torchaudio.set_audio_backend("soundfile")
# Load and resample audio
waveform, sample_rate = torchaudio.load(wav_filepath)

  torchaudio.set_audio_backend("soundfile")


In [4]:
waveform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(waveform)

In [5]:
waveform

tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -5.1479e-05,
         -5.2398e-05,  2.4214e-06]])

In [7]:
from datasets import load_dataset
parentNaId = "653144"
filepath = f'../datasets/{parentNaId}_transcriptions_with_audio.parquet'
filepath = f'../datasets/veterans_history_project_resources.parquet'
# Load dataset from CSV
dataset = load_dataset('parquet', data_files=filepath)

Generating train split: 10404 examples [00:01, 6927.53 examples/s]


In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['collection_number', 'fulltext_file_url', 'fulltext_file_str', 'video_url', 'audio_url', 'title', 'description', 'dates', 'language', 'location', 'location_home', 'location_service', 'partof', 'subject', 'subject_battles', 'subject_branch', 'subject_conflict', 'subject_entrance', 'subject_format', 'subject_gender', 'subject_rank', 'subject_status', 'subject_unit', 'subject_race'],
        num_rows: 10404
    })
})

In [49]:
dataset["train"][0]['audio_filepaths'][0]
# create new column to store the filepath after conversion
# audio_filepaths are lists, only take the first element for conversion
dataset = dataset.map(lambda x: {'audio_filepath_1st': [fp[0].replace("./", "../datasets/") for fp in x['audio_filepaths']]}, batched=True)
dataset = dataset.map(lambda x: {'transcription_str': [next(iter(t.values()))['transcription'] for t in x['transcription']]}, batched=True)
dataset = dataset.map(lambda x: {'audio_filepath_1st': [convert_mp3_to_wav(fp) for fp in x['audio_filepath_1st']]}, batched=True)

Map: 100%|██████████| 2/2 [00:00<00:00, 26.32 examples/s]
Map: 100%|██████████| 2/2 [00:00<00:00, 16.90 examples/s]
Map: 100%|██████████| 2/2 [00:03<00:00,  1.86s/ examples]


In [50]:
dataset["train"][0]['audio_filepath_1st']

'../datasets/audio/208-192.wav'

In [56]:
dataset["train"][0]['transcription_str']

'Announcer: Good evening from the White House in Washington. Ladies and Gentlemen, the President of the United States. \n\nTruman: Fellow citizens. On August the 18th, 1945, four days after the surrender of Japan, I issued Executive Order 9599 which laid down the guiding policies of your government during the transition from war to peace. Briefly stated these policies are: First, to assist in the maximum production of civilian goods. Second, as rapidly as possible to remove Government controls and restore collective bargaining and free markets. Third, to avoid both inflation and deflation. Those are still our policies. One of the major factors determining whether or not we shall succeed in carrying out those policies is the question of wages and prices. If wages go down substantially, we face deflation. If prices go up substantially, we face inflation. We must be on our guard, and steer clear of both these dangers to our security.\n\nWhat happens to wages is important to all of us, eve

In [None]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-large-960h")

def preprocess_text(batch):
    batch["input_ids"] = tokenizer(batch["transcription_str"], padding=True, truncation=True).input_ids
    return batch

    # Convert tokenized output to a numpy array to ensure consistent dtype
    tokenized = tokenizer(texts, padding=True, truncation=True, return_tensors='np')
    batch["input_ids"] = tokenized["input_ids"].tolist()
    batch["attention_mask"] = tokenized["attention_mask"].tolist()
    return batch

# Apply preprocessing
dataset = dataset.map(preprocess_text, batched=True)

Map:   0%|          | 0/2 [00:00<?, ? examples/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
Map: 100%|██████████| 2/2 [00:00<00:00, 19.58 examples/s]


In [64]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

# Load tokenizer and model
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'Wav2Vec2CTCTokenizer'. 
The class this function is called from is 'Wav2Vec2Tokenizer'.
Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [65]:
# Update model configuration for new vocabulary size
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = len(tokenizer)

In [77]:
train_dataset = dataset['train']
eval_dataset = dataset['train']


In [78]:
from torch.utils.data import DataLoader

# Define a custom collator
def data_collator(batch):
    audio_features = [item["input_values"] for item in batch]
    labels = [item["labels"] for item in batch]
    return {"input_values": audio_features, "labels": labels}

# Create DataLoader
train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, collate_fn=data_collator
)
eval_loader = DataLoader(
    eval_dataset, batch_size=16, shuffle=False, collate_fn=data_collator
)

In [85]:
for i in train_loader:
    print(i)
    break

KeyError: 'input_values'

In [79]:
dataset['train']

Dataset({
    features: ['_index', '_id', '_score', '_source', 'sort', 'hit_title', 'hit_record_urls', 'hit_digitalObjects_metadata', 'transcription', 'audio_filepaths', 'audio_filepath_1st', 'transcription_str', 'input_ids', 'attention_mask'],
    num_rows: 2
})

In [84]:
import torch
from torch.optim import AdamW

# Define optimizer and scheduler
optimizer = AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.1)

# Training loop
for epoch in range(5):
    model.train()
    for batch in train_loader:
        # Use "input_ids" as both input and label, to avoid KeyError
        input_ids = batch["input_ids"]
        labels = batch["input_ids"]

        # Forward pass
        outputs = model(input_ids, labels=labels)
        loss = outputs.loss

        # Backward pass
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    print(f"Epoch {epoch + 1} completed with loss {loss.item():.4f}")

KeyError: 'input_values'