In [2]:
from importlib import reload

import datasets
import tokenizers
import transformers

import src.models.components.feature_extractor_dinov2
import src.models.components.sign_language_net

In [3]:
rwth_phoenix_pretrain = datasets.load_dataset('lukasbraach/rwth_phoenix_weather_2014', 'pre-training', streaming=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [4]:
from itertools import chain


def string_iterator():
    it = chain.from_iterable(
        (rwth_phoenix_pretrain['train'], rwth_phoenix_pretrain['validation'], rwth_phoenix_pretrain['test']))

    for batch in it:
        yield batch['transcription']

In [20]:
from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import Whitespace, WhitespaceSplit
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer, WordLevelTrainer

model = WordLevel(unk_token="__UNK__")
tokenizer = Tokenizer(model=model)
tokenizer.pre_tokenizer = WhitespaceSplit()

trainer = WordLevelTrainer(special_tokens=["__PAD__", "__UNK__"])

tokenizer.train_from_iterator(string_iterator(), trainer)
tokenizer.add_special_tokens([
    tokenizers.AddedToken("__ON__"),
    tokenizers.AddedToken("__OFF__"),
    tokenizers.AddedToken("__EMOTION__"),
    tokenizers.AddedToken("__PU__"),
])

print(tokenizer.get_vocab_size())

1297


In [21]:
tokenizer.save("../src/etc/rwth_phoenix_tokenizer_wordlevel.json")

In [36]:
output = tokenizer.encode("__ON__ SUED VERAENDERN KAUM WIE HEUTE SONNE ODER NEBEL __OFF__")

for batch in string_iterator():
    enc = tokenizer.encode(batch, is_pretokenized=True)
    print(enc.ids)

[2, 134, 140, 40, 247, 466, 80, 298, 22, 357, 893, 9]
[162, 77, 162, 59, 60]
[60, 60, 74, 17, 55, 1271, 46, 55, 82, 93, 94, 123]
[10, 73, 112, 33, 133, 117, 133]
[25, 66, 47, 41, 9, 65, 47, 6, 6]
[57, 30, 16, 105, 79, 4, 31, 19]
[35, 41, 62, 79, 23, 23, 77, 171, 87, 42, 211, 89, 40, 3]
[2, 29, 53, 10, 61, 326, 231, 3]
[28, 17, 37, 6, 80, 4, 91, 10, 127, 41]
[9, 31, 40, 47, 125, 27, 6, 6]
[192, 32, 56, 51, 51, 162, 24, 32, 3]
[2, 28, 17, 52, 96, 55, 111, 93, 194, 122]
[10, 104, 619, 128, 52]
[518, 79, 25, 6, 6, 64, 4, 16, 3]
[2, 161, 9, 64, 131, 163, 77, 34, 63, 12, 76, 306, 216, 73, 73, 3]
[2, 14, 54, 43, 57, 34, 6, 6, 76, 20, 3]
[62, 9, 14, 21, 38, 136, 49, 119, 11, 212, 9, 958, 101, 23, 12, 3]
[2, 32, 56, 51, 24, 167, 9, 21, 109, 15, 15, 15, 3]
[2, 44, 262, 28, 40, 99, 13, 52, 149, 13, 3]
[2, 323, 10, 128, 13, 52, 216, 82, 36, 33, 111, 36, 13, 3]
[2, 57, 47, 97, 11, 21, 6, 6, 151, 20, 21]
[121, 177, 12, 65, 77, 3]
[2, 84, 40, 134, 140, 216, 21, 89, 47, 137, 108, 29, 70, 485, 100, 159

In [66]:
reload(src.models.components.feature_extractor_dinov2)
from src.models.components.feature_extractor_dinov2 import SignLanguageFeatureExtractor

feature_extractor = SignLanguageFeatureExtractor()


def collate_fn(batch):
    labels = tokenizer.encode(batch['tokens'], is_pretokenized=True)
    feature = feature_extractor(batch['frames'], sampling_rate=25)

    return {"input_values": feature.input_values[0], "labels": labels.ids}


train = rwth_phoenix['train'].map(function=collate_fn, batched=False, remove_columns=['frames', 'tokens'])
first = next(iter(train))

In [67]:
first

{'input_values': array([[-0.84833705, -0.58704543, -2.074996  , ..., -0.26034883,
         -2.8003612 , -0.06684101],
        [ 1.2415382 , -1.3554696 ,  0.14667816, ..., -0.8437252 ,
         -2.207518  ,  0.22651023],
        [ 1.0892667 , -1.7075235 ,  0.19240695, ..., -1.3566458 ,
         -2.3756495 , -0.35655874],
        ...,
        [ 0.10560069, -1.5657716 , -0.36898893, ..., -0.51297307,
         -3.1072054 , -0.00521034],
        [ 0.02705131, -1.019656  , -0.52293456, ...,  0.08987308,
         -2.6816874 ,  0.95835376],
        [-0.42416984, -0.6026987 , -1.3009195 , ..., -0.2116571 ,
         -2.1713347 ,  0.44469047]], dtype=float32),
 'labels': [2,
  418,
  427,
  194,
  617,
  990,
  77,
  5,
  75,
  726,
  77,
  5,
  72,
  832,
  1707,
  87]}

In [69]:
reload(src.models.components.sign_language_net)
from src.models.components.sign_language_net import SignLanguageNet

transformers_tokenizer = transformers.PreTrainedTokenizerFast(
    model_input_names=['input_values'],
    pad_token="__PAD__",
    bos_token="__ON__",
    eos_token="__OFF__",
    unk_token="__UNK__",
    tokenizer_object=tokenizer
)

model = SignLanguageNet(tokenizer=transformers_tokenizer)

model

SignLanguageNet(
  (encoder): SpatiotemporalEncoder(
    (feature_extractor): SpatialFeatureEncoder()
    (feature_projection): SpatiotemporalFeatureProjection(
      (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Wav2Vec2Encoder(
      (pos_conv_embed): Wav2Vec2PositionalConvEmbedding(
        (conv): Conv1d(768, 768, kernel_size=(128,), stride=(1,), padding=(64,), groups=16)
        (padding): Wav2Vec2SamePadLayer()
        (activation): GELUActivation()
      )
      (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (layers): ModuleList(
        (0-5): 6 x Wav2Vec2EncoderLayer(
          (attention): Wav2Vec2Attention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=7

In [76]:
import torch
from transformers import DataCollatorForSeq2Seq

with torch.no_grad():
    collator = DataCollatorForSeq2Seq(
        model=model,
        tokenizer=transformers_tokenizer,
        pad_to_multiple_of=16,
        return_tensors='pt'
    )

collated = collator([first])

collated

{'input_values': tensor([[[-0.8483, -0.5870, -2.0750,  ..., -0.2603, -2.8004, -0.0668],
         [ 1.2415, -1.3555,  0.1467,  ..., -0.8437, -2.2075,  0.2265],
         [ 1.0893, -1.7075,  0.1924,  ..., -1.3566, -2.3756, -0.3566],
         ...,
         [ 0.1056, -1.5658, -0.3690,  ..., -0.5130, -3.1072, -0.0052],
         [ 0.0271, -1.0197, -0.5229,  ...,  0.0899, -2.6817,  0.9584],
         [-0.4242, -0.6027, -1.3009,  ..., -0.2117, -2.1713,  0.4447]]]), 'labels': tensor([[   2,  418,  427,  194,  617,  990,   77,    5,   75,  726,   77,    5,
           72,  832, 1707,   87]]), 'decoder_input_ids': tensor([[   3,    2,  418,  427,  194,  617,  990,   77,    5,   75,  726,   77,
            5,   72,  832, 1707]])}