In [2]:
from importlib import reload

import datasets
import tokenizers
import transformers

import src.models.components.feature_extractor_dinov2
import src.models.components.sign_language_net

In [3]:
# Open the 'pre-training' configuration of RWTH-PHOENIX-Weather 2014 in streaming mode:
# examples are fetched lazily while iterating instead of downloading the whole dataset.
# The result is indexed by split name ('train', 'validation', 'test') in the next cell.
rwth_phoenix_pretrain = datasets.load_dataset(
    path='lukasbraach/rwth_phoenix_weather_2014',
    name='pre-training',
    streaming=True,
)

In [4]:
from itertools import chain


def string_iterator(dataset=None, splits=('train', 'validation', 'test'), column='transcription'):
    """Yield one text field per example, walking the given splits in order.

    Used as the corpus for ``Tokenizer.train_from_iterator``. Called with no
    arguments it behaves as before: every example of the 'train', 'validation'
    and 'test' splits of ``rwth_phoenix_pretrain``, reading ``'transcription'``.

    Args:
        dataset: mapping from split name to an iterable of example dicts.
            ``None`` (default) means the module-level ``rwth_phoenix_pretrain``,
            looked up when the iterator is first advanced.
        splits: split names to read, chained in this order.
        column: key read from each example.

    Yields:
        ``example[column]`` for every example of every requested split.

    Raises:
        KeyError: if a split name is missing from ``dataset`` (raised before any
            value is yielded) or an example lacks ``column``.
    """
    if dataset is None:
        dataset = rwth_phoenix_pretrain

    # NOTE(review): the default includes 'validation' and 'test', so a tokenizer trained
    # on this iterator also covers held-out vocabulary; pass splits=('train',) for a
    # train-only vocabulary.
    split_datasets = [dataset[split] for split in splits]  # resolve all keys up front

    for example in chain.from_iterable(split_datasets):
        yield example[column]

In [22]:
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import Whitespace, WhitespaceSplit
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer, WordLevelTrainer

# Word-level tokenizer: the WhitespaceSplit pre-tokenizer splits on whitespace and
# every resulting word becomes one vocabulary entry; unknown words map to "__UNK__".
# Named `wordlevel_model` rather than `model` so this object is not confused with
# (or clobbered by / clobbering) the network bound to `model` in a later cell.
wordlevel_model = WordLevel(unk_token="__UNK__")
tokenizer = Tokenizer(model=wordlevel_model)
tokenizer.pre_tokenizer = WhitespaceSplit()

# Reserve the padding and unknown tokens in the vocabulary; "__UNK__" must match the
# unk_token given to WordLevel above.
trainer = WordLevelTrainer(special_tokens=["__PAD__", "__UNK__"])

# Streams all transcriptions yielded by string_iterator() to build the vocabulary.
tokenizer.train_from_iterator(string_iterator(), trainer)

# Register the remaining marker tokens as special tokens after training.
# NOTE(review): if these markers already occur in the transcriptions (the sample string
# encoded in a later cell contains __ON__/__OFF__), they presumably keep the ids learned
# during training rather than receiving new ones — confirm with tokenizer.token_to_id().
tokenizer.add_special_tokens([
    tokenizers.AddedToken("__ON__"),
    tokenizers.AddedToken("__OFF__"),
    tokenizers.AddedToken("__EMOTION__"),
    tokenizers.AddedToken("__PU__"),
])

# Vocabulary size (get_vocab_size defaults to including added tokens); shown as the cell output.
tokenizer.get_vocab_size()

In [23]:
# Serialize the trained tokenizer to a single JSON file. The path is relative to the
# notebook's working directory, so this cell assumes the notebook runs one level below
# the repository root.
tokenizer.save(path="../src/etc/rwth_phoenix_tokenizer_wordlevel.json")

In [36]:
from itertools import islice

# Sanity-check the trained vocabulary on a hand-written gloss sequence: any token shown
# as "__UNK__" is missing from the vocabulary.
output = tokenizer.encode("__ON__ SUED VERAENDERN KAUM WIE HEUTE SONNE ODER NEBEL __OFF__")
print(output.tokens, output.ids)

# Spot-check only a few real transcriptions. Iterating the full stream (train +
# validation + test) and printing every encoding floods the cell output and makes
# Restart-&-Run-All needlessly slow.
N_PREVIEW = 5
for transcription in islice(string_iterator(), N_PREVIEW):
    # NOTE(review): is_pretokenized=True requires each transcription to already be a
    # list/tuple of words; if the column holds plain strings, drop the flag so the
    # WhitespaceSplit pre-tokenizer splits them — confirm the column type.
    encoding = tokenizer.encode(transcription, is_pretokenized=True)
    print(encoding.ids)

In [66]:
# Reload the feature-extractor module first so edits to it are picked up without a
# kernel restart, then re-import the (fresh) class from it.
reload(src.models.components.feature_extractor_dinov2)
from src.models.components.feature_extractor_dinov2 import SignLanguageFeatureExtractor

feature_extractor = SignLanguageFeatureExtractor()


def collate_fn(batch):
    """Turn one dataset example into model inputs and label ids.

    Despite its name, this is applied per example through ``map(..., batched=False)``
    below, so ``batch`` is a single example dict.

    Args:
        batch: one example; its ``'tokens'`` are encoded pre-tokenized by
            ``tokenizer`` and its ``'frames'`` are passed to ``feature_extractor``.

    Returns:
        dict with ``"input_values"`` (element 0 of the extractor output's
        ``input_values``) and ``"labels"`` (the encoded token ids).
    """
    encoded_tokens = tokenizer.encode(batch['tokens'], is_pretokenized=True)
    # NOTE(review): 25 is presumably the video frame rate — confirm against the dataset.
    extracted = feature_extractor(batch['frames'], sampling_rate=25)
    return dict(input_values=extracted.input_values[0], labels=encoded_tokens.ids)


# FIXME(review): `rwth_phoenix` is not defined in any cell visible here (only
# `rwth_phoenix_pretrain` is, and its examples are read via 'transcription', not
# 'frames'/'tokens'), so this raises NameError on a fresh kernel. Load the dataset
# configuration that provides 'frames' and 'tokens' explicitly before this cell.
train = rwth_phoenix['train'].map(
    collate_fn,
    batched=False,
    remove_columns=['frames', 'tokens'],
)
first = next(iter(train))

In [67]:
# Display the first mapped training example: collate_fn added 'input_values' and
# 'labels', and map() removed the raw 'frames' and 'tokens' columns.
first

In [69]:
# Reload the model module first so edits to it are picked up without a kernel restart,
# then re-import the (fresh) class from it.
reload(src.models.components.sign_language_net)
from src.models.components.sign_language_net import SignLanguageNet

# Wrap the trained `tokenizers.Tokenizer` in a transformers fast tokenizer so it can be
# passed to transformers utilities. The special-token strings are the ones registered
# when training/extending `tokenizer` ("__PAD__"/"__UNK__" via the trainer,
# "__ON__"/"__OFF__" via add_special_tokens).
# NOTE(review): model_input_names=['input_values'] presumably makes tokenizer.pad() (used
# by the data collator) treat 'input_values' as the primary input key — confirm.
transformers_tokenizer = transformers.PreTrainedTokenizerFast(
    model_input_names=['input_values'],
    pad_token="__PAD__",
    bos_token="__ON__",
    eos_token="__OFF__",
    unk_token="__UNK__",
    tokenizer_object=tokenizer
)

model = SignLanguageNet(tokenizer=transformers_tokenizer)

# Display the model's repr as the cell output.
model

In [76]:
import torch
from transformers import DataCollatorForSeq2Seq

# Seq2seq collator: pads a list of mapped examples into one batch of PyTorch tensors,
# with padded lengths rounded up to a multiple of 16. Because `model` is passed, the
# transformers docs say decoder_input_ids are derived from the labels when the model
# implements prepare_decoder_input_ids_from_labels.
collator = DataCollatorForSeq2Seq(
    model=model,
    tokenizer=transformers_tokenizer,
    pad_to_multiple_of=16,
    return_tensors='pt'
)

# Constructing the collator does no tensor work. Any model code runs when the collator
# is *called*, so that call is what belongs under no_grad.
with torch.no_grad():
    collated = collator([first])

collated