In [1]:
import os

import numpy as np
import torch

from slr.datasets.Phoenix2014DataModule import Phoenix2014DataModule
from slr.models.CorrNet import CorrNet
from slr.models.utils.decode import Decode

In [2]:
# load model
# Keep the checkpoint location in its own name so it is easy to spot and swap.
checkpoint_path = "/new_home/xzj23/workspace/SLR/experiments/Phoenix2014/CorrNet/2024-09-23_12-48-46/checkpoints/epoch=21-DEV_WER=19.10.ckpt"
model = CorrNet.load_from_checkpoint(checkpoint_path)
# Bare last expression: displays the module tree as the cell output.
model

CorrNet(
  (conv2d): ResNet(
    (conv1): Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv3d(64, 64, kernel_size=(1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=False)
        (bn1): Batc

In [3]:
# get data

# Load the gloss vocabulary stored as a pickled object inside a .npy file;
# .item() unwraps the 0-d object array into the underlying Python object
# (presumably a dict, given how it is passed around below).
# NOTE(review): allow_pickle=True can execute arbitrary code while loading —
# only use it on files from a trusted source.
with open("/new_home/xzj23/workspace/SLR/data/global_files/gloss_dict/phoenix2014_gloss_dict.npy",
          'rb') as f:
    gloss_dict = np.load(f, allow_pickle=True).item()

# Build the Phoenix2014 data module on top of the extracted frames and the
# manual annotations, then prepare it for the "fit" stage.
datamodule = Phoenix2014DataModule(
    features_path="/new_home/xzj23/workspace/SLR/data/phoenix2014/phoenix-2014-multisigner/features/fullFrame-256x256px",
    annotations_path="/new_home/xzj23/workspace/SLR/data/phoenix2014/phoenix-2014-multisigner/annotations/manual",
    gloss_dict=gloss_dict,
    batch_size=2, num_workers=10
)
datamodule.setup(stage="fit")

# Pull a single batch from the training loader to use as a sample input.
x, x_lgt, y, y_lgt, info = next(iter(datamodule.train_dataloader()))

In [4]:
# Shape of the sampled input batch. The leading 2 matches batch_size=2;
# the remaining axes are presumably (frames, channels, height, width) — TODO confirm.
x.shape

torch.Size([2, 188, 3, 224, 224])

In [5]:
# Target label ids for the batch. The output holds 31 values, which equals
# the sum of y_lgt (16 + 15), so the targets appear to be concatenated across samples.
y

tensor([1204,  909,  825, 1118, 1257,  719, 1257, 1097,  906, 1149, 1257,  114,
         465,  848,  848, 1203, 1204,  925, 1132,   31,    8,  448,  788,  788,
         788,  788,   59,  951, 1014,  753, 1146])

In [6]:
# Per-sample input lengths (presumably frame counts; both equal x.shape[1] here).
x_lgt

tensor([188, 188])

In [7]:
# Per-sample target lengths (presumably the number of glosses in each annotation).
y_lgt

tensor([16, 15])

In [8]:
# Per-sample metadata; the output shows folder, signer and annotation fields for each sample.
info

[folder        15October_2009_Thursday_tagesschau_default-12/...
 signer                                                 Signer03
 annotation    __ON__ SONNTAG SAMSTAG WEST loc-REGION NORD lo...
 Name: 15October_2009_Thursday_tagesschau_default-12, dtype: object,
 folder        23September_2010_Thursday_heute_default-11/1/*...
 signer                                                 Signer01
 annotation    __ON__ STARK WIND ALPEN ABER IN-KOMMEND REGEN ...
 Name: 23September_2010_Thursday_heute_default-11, dtype: object]

In [9]:
# to device
# Move the model and every tensor of the sampled batch onto the same GPU.
device = "cuda:1"
model = model.to(device)
x = x.to(device)
y = y.to(device)
x_lgt = x_lgt.to(device)
y_lgt = y_lgt.to(device)

# forward
# eval() puts layers such as BatchNorm into inference behaviour.
model.eval()
# This is inference only, so disable autograd: without it the forward pass
# builds a graph and keeps activations alive on the GPU for no benefit.
with torch.no_grad():
    conv_logits, y_hat_logits, y_hat_lgt = model(x, x_lgt)

In [10]:
# Shape of the first model output. The output shows (44, 2, 1296) —
# presumably (time steps, batch, classes); TODO confirm.
conv_logits.shape

torch.Size([44, 2, 1296])

In [11]:
# Shape of the second model output; same layout as conv_logits in the output shown
# (presumably (time steps, batch, classes) — TODO confirm).
y_hat_logits.shape

torch.Size([44, 2, 1296])

In [12]:
# One predicted output length per sample in the batch (shape (2,) for batch_size=2).
y_hat_lgt.shape

torch.Size([2])

In [13]:
# set decoder
# Beam-search decoder over the gloss vocabulary; num_classes matches the
# last dimension of the logits produced by the model above.
decoder = Decode(
    gloss_dict=gloss_dict,
    num_classes=1296,
    search_mode='beam_search',
)

In [14]:
# Turn logits into per-class probabilities and bring everything back to the
# CPU before handing it to the decoder.
gloss_probs = y_hat_logits.softmax(dim=-1).cpu()
gloss_probs_lgt = y_hat_lgt.cpu()
# batch_first=False: the tensors are passed with the batch on the second axis;
# probs=True: the input is already softmax-normalised.
sentence = decoder.decode(gloss_probs, gloss_probs_lgt, batch_first=False, probs=True)

In [15]:
# Print one decoded hypothesis per line; the first element of each decoded
# item is the text that gets printed.
for hypothesis in sentence:
    glosses = (token[0] for token in hypothesis)
    print(" ".join(glosses))

__ON__ SONNTAG SAMSTAG WEST loc-REGION NORD loc-REGION WECHSELHAFT SONNE WOLKE loc-REGION BERG IX SCHNEE __OFF__
__ON__ STARK WIND ALPEN ABER IN-KOMMEND REGEN AUCH SUED UND OST WOCHENENDE
