In [1]:
from slr.models.CorrNet import CorrNet
from slr.datasets.Phoenix2014DataModule import Phoenix2014DataModule
import numpy as np
import os
import torch

In [2]:
# load model
model = CorrNet.load_from_checkpoint(
    "/new_home/xzj23/workspace/SLR/experiments/Phoenix2014/CorrNet/2024-09-23_12-48-46/checkpoints/epoch=21-DEV_WER=19.10.ckpt")
model

CorrNet(
  (conv2d): ResNet(
    (conv1): Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
        (bn1): Batc

In [3]:
# get data

with open(os.path.join("/new_home/xzj23/workspace/SLR/data/global_files/gloss_dict/phoenix2014_gloss_dict.npy"),
          'rb') as f:
    gloss_dict = np.load(f, allow_pickle=True).item()
datamodule = Phoenix2014DataModule(
    features_path="/new_home/xzj23/workspace/SLR/data/phoenix2014/phoenix-2014-multisigner/features/fullFrame-256x256px",
    annotations_path="/new_home/xzj23/workspace/SLR/data/phoenix2014/phoenix-2014-multisigner/annotations/manual",
    gloss_dict=gloss_dict,
    batch_size=2, num_workers=10
)
datamodule.setup(stage="fit")

x, x_lgt, y, y_lgt, info = datamodule.train_dataloader().__iter__().__next__()

In [4]:
x.shape

torch.Size([2, 148, 3, 224, 224])

In [5]:
x_lgt

tensor([148, 140])

In [6]:
y

tensor([ 300,  788,  319,  791,  735,  137,  986,  986,  986,  303, 1204, 1145,
         174,  958, 1087,  412,  906,  438,  966,  500,   65])

In [7]:
y_lgt

tensor([10, 11])

In [8]:
info

[folder        23January_2011_Sunday_tagesschau_default-4/1/*...
 signer                                                 Signer04
 annotation    FLACH REGEN FROST REGION NUR BISSCHEN TROPFEN ...
 Name: 23January_2011_Sunday_tagesschau_default-4, dtype: object,
 folder        29September_2011_Thursday_heute_default-2/1/*.png
 signer                                                 Signer05
 annotation    __ON__ WOCHE DANACH SUPER WARM HERBST SONNE IC...
 Name: 29September_2011_Thursday_heute_default-2, dtype: object]

In [9]:
# to device
device = "cuda:1"
model = model.to(device)
x = x.to(device)
y = y.to(device)
x_lgt = x_lgt.to(device)
y_lgt = y_lgt.to(device)

# forward
model.eval()
conv_logits, y_hat_logits, y_hat_lgt = model(x, x_lgt)
print(conv_logits.shape, y_hat_logits.shape, y_hat_lgt)
torch.cuda.empty_cache()

torch.Size([34, 2, 1296]) torch.Size([34, 2, 1296]) tensor([34, 32])


In [10]:
yyy = y_hat_logits.softmax(dim=-1).cpu()
yyy_lgt = y_hat_lgt.cpu()

In [11]:
yyy.shape

torch.Size([34, 2, 1296])

In [12]:
yyy_lgt

tensor([34, 32])

# torchaudio ctc_decoder

In [13]:
from torchaudio.models.decoder import ctc_decoder
from slr.datasets.tknzs.simple_tokenizer import SimpleTokenizer

In [14]:
tokenizer = SimpleTokenizer(
    vocab_file="/new_home/xzj23/workspace/SLR/slr/datasets/vocabs/phoenix2014_gloss_vocab.txt"
)
tokens = list(tokenizer.vocab.keys())
tokens.append('|')
decoder = ctc_decoder(
    lexicon=None,
    tokens=tokens,
    nbest=10,
    beam_size=100,
    beam_size_token=None,
    beam_threshold=0,
    log_add=True,
    blank_token=tokenizer.pad_token,
    unk_word=tokenizer.unk_token
)

In [15]:
tokens

['<PAD>',
 'A',
 'AACHEN',
 'AB',
 'AB-JETZT',
 'AB-PLUSPLUS',
 'AB-SO',
 'ABEND',
 'ABER',
 'ABFALLEN',
 'ABKUEHLEN',
 'ABSCHIED',
 'ABSCHNITT',
 'ABSINKEN',
 'ABWECHSELN',
 'ACH',
 'ACHT',
 'ACHTE',
 'ACHTHUNDERT',
 'ACHTUNG',
 'ACHTZEHN',
 'ACHTZIG',
 'AEHNLCH',
 'AEHNLICH',
 'AENDERN',
 'AFRIKA',
 'AKTIV',
 'AKTUELL',
 'ALLE',
 'ALLGAEU',
 'ALLGEMEIN',
 'ALPEN',
 'ALPENRAND',
 'ALPENTAL',
 'ALS',
 'ALSO',
 'ALT',
 'AM',
 'AM-KUESTE',
 'AM-MEER',
 'AM-RAND',
 'AM-TAG',
 'AMERIKA',
 'AN',
 'ANDERE',
 'ANDERE-MOEGLICHKEIT',
 'ANDERS',
 'ANFANG',
 'ANGEMESSEN',
 'ANGENEHM',
 'ANGST',
 'ANHALT',
 'ANKLICKEN',
 'ANKOMMEN',
 'ANSAMMELN',
 'ANSCHAUEN',
 'APRIL',
 'ARM',
 'ATLANTIK',
 'AUCH',
 'AUCH-NICHT',
 'AUF',
 'AUF-JEDEN-FALL',
 'AUFBLUEHEN',
 'AUFEINANDERTREFFEN',
 'AUFFUELLEN',
 'AUFHEITERN',
 'AUFHOEREN',
 'AUFKLAREN',
 'AUFKOMMEN',
 'AUFLOCKERUNG',
 'AUFLOCKERUNG-PLUSPLUS',
 'AUFLOESEN',
 'AUFLOESEN-PLUSPLUS',
 'AUFPASSEN',
 'AUFTAUCHEN',
 'AUFZIEHEN',
 'AUFZIEHEN-PLUSPLUS',
 'AUG

In [16]:
decoder

<torchaudio.models.decoder._ctc_decoder.CTCDecoder at 0x7f3e9aa30580>

In [17]:
beam_search_result = decoder(yyy.permute(1, 0, 2), yyy_lgt)
beam_search_result

[[CTCHypothesis(tokens=tensor([1296,  300, 1204, 1145,  174,  788,  958,  319,  412, 1296]), words=[], score=31.43412610888481, timesteps=tensor([ 0,  1,  2,  8, 12, 13, 20, 21, 30, 35], dtype=torch.int32))],
 [CTCHypothesis(tokens=tensor([1296, 1204, 1145,  174,  788,  958,  319,  412, 1296]), words=[], score=29.570444375276566, timesteps=tensor([ 0,  1,  7, 11, 12, 19, 20, 29, 33], dtype=torch.int32))]]

In [18]:
beam_search_result[0][0]

CTCHypothesis(tokens=tensor([1296,  300, 1204, 1145,  174,  788,  958,  319,  412, 1296]), words=[], score=31.43412610888481, timesteps=tensor([ 0,  1,  2,  8, 12, 13, 20, 21, 30, 35], dtype=torch.int32))

In [19]:
beam_search_result[0][0].tokens.shape

torch.Size([10])

In [20]:
beam_search_result[0][0].timesteps.shape

torch.Size([10])

In [21]:
gloss_pred = tokenizer.decode(beam_search_result[0][0].tokens)
print(gloss_pred)

['<PAD>', 'FLACH', '__ON__', 'WOCHE', 'DANACH', 'REGEN', 'SUPER', 'FROST', 'HERBST', '<PAD>']


In [22]:
gloss_pred = tokenizer.decode(beam_search_result[1][0].tokens)
print(gloss_pred)

['<PAD>', '__ON__', 'WOCHE', 'DANACH', 'REGEN', 'SUPER', 'FROST', 'HERBST', '<PAD>']


In [23]:
while '<UNK>' in gloss_pred:
    gloss_pred.remove('<UNK>')

In [24]:
print(gloss_pred)

['<PAD>', '__ON__', 'WOCHE', 'DANACH', 'REGEN', 'SUPER', 'FROST', 'HERBST', '<PAD>']


In [25]:
y

tensor([ 300,  788,  319,  791,  735,  137,  986,  986,  986,  303, 1204, 1145,
         174,  958, 1087,  412,  906,  438,  966,  500,   65], device='cuda:1')

In [26]:
print(tokenizer.decode(y))

['FLACH', 'REGEN', 'FROST', 'REGION', 'NUR', 'BISSCHEN', 'TROPFEN', 'TROPFEN', 'TROPFEN', 'FLOCKEN', '__ON__', 'WOCHE', 'DANACH', 'SUPER', 'WARM', 'HERBST', 'SONNE', 'ICH', 'TANKEN', 'KOERPER', 'AUFFUELLEN']


# Decode

In [27]:
from slr.models.utils.decode import Decode

with open("/new_home/xzj23/workspace/SLR/data/global_files/gloss_dict/phoenix2014_gloss_dict.npy", "rb") as f:
    gloss_dict = np.load(f, allow_pickle=True).item()

aa_decoder = Decode(
    gloss_dict=gloss_dict,
    num_classes=1296,
    search_mode='beam'
)

In [28]:
res = aa_decoder.decode(yyy, yyy_lgt, batch_first=False, probs=True)
res

[[('FLACH', 0),
  ('REGEN', 1),
  ('FROST', 2),
  ('REGION', 3),
  ('NUR', 4),
  ('BISSCHEN', 5),
  ('TROPFEN', 6),
  ('FLOCKEN', 7)],
 [('__ON__', 0),
  ('WOCHE', 1),
  ('DANACH', 2),
  ('SUPER', 3),
  ('HERBST', 4),
  ('SONNE', 5),
  ('ICH', 6),
  ('TANKEN', 7),
  ('KOERPER', 8)]]

In [29]:
for i in range(len(res)):
    print([k[0] for k in res[i]])

['FLACH', 'REGEN', 'FROST', 'REGION', 'NUR', 'BISSCHEN', 'TROPFEN', 'FLOCKEN']
['__ON__', 'WOCHE', 'DANACH', 'SUPER', 'HERBST', 'SONNE', 'ICH', 'TANKEN', 'KOERPER']


# decode3.py

In [30]:
from slr.models.decoders.CTCBeamSearchDecoder import CTCBeamSearchDecoder

In [31]:
decoder3 = CTCBeamSearchDecoder(tokenizer=tokenizer, num_processes=1)

In [32]:
yyy.shape

torch.Size([34, 2, 1296])

In [33]:
yyy_lgt

tensor([34, 32])

In [34]:
len(decoder3.tokenizer.vocab)

1296

In [35]:
ress = decoder3.decode(yyy.permute(1, 0, 2), yyy_lgt, batch_first=True)

In [36]:
ress

[['FLACH',
  'REGEN',
  'FROST',
  'REGION',
  'NUR',
  'BISSCHEN',
  'TROPFEN',
  'FLOCKEN'],
 ['__ON__',
  'WOCHE',
  'DANACH',
  'SUPER',
  'HERBST',
  'SONNE',
  'ICH',
  'TANKEN',
  'KOERPER']]

In [37]:
for i in range(len(ress)):
    print(ress[i])

['FLACH', 'REGEN', 'FROST', 'REGION', 'NUR', 'BISSCHEN', 'TROPFEN', 'FLOCKEN']
['__ON__', 'WOCHE', 'DANACH', 'SUPER', 'HERBST', 'SONNE', 'ICH', 'TANKEN', 'KOERPER']


In [38]:
for i in range(len(ress)):
    for j in range(len(ress[i])):
        print(ress[i][j], res[i][j])

FLACH ('FLACH', 0)
REGEN ('REGEN', 1)
FROST ('FROST', 2)
REGION ('REGION', 3)
NUR ('NUR', 4)
BISSCHEN ('BISSCHEN', 5)
TROPFEN ('TROPFEN', 6)
FLOCKEN ('FLOCKEN', 7)
__ON__ ('__ON__', 0)
WOCHE ('WOCHE', 1)
DANACH ('DANACH', 2)
SUPER ('SUPER', 3)
HERBST ('HERBST', 4)
SONNE ('SONNE', 5)
ICH ('ICH', 6)
TANKEN ('TANKEN', 7)
KOERPER ('KOERPER', 8)


In [39]:
for i in range(len(ress)):
    print(ress[i] == [k[0] for k in res[i]])

True
True
