In [1]:
import os
import re
import logging
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader

from datasets import load_metric

from load_fsq_model import load_model
from modules.CustomWav2Vec2 import CustomWav2Vec2Config, CustomWav2Vec2Model
from modules.Wav2Vec2Model import Wav2Vec2Config, Wav2Vec2Model
from utils import *

In [2]:
# CONFIG ------------------------------
TEACHER_MODEL = 'wav2vec2_vox_960h_new.pt'
DATA_PATH = '../'
NUM_EPOCHS = 100
GPUS = 2
BATCH_SIZE = 2
LEARNING_RATE = 1e-5
ACCUMULATE_GRAD_BATCHES = 1
OUTPUT_DIR = './results/'
# CHECKPOINT = 'last.ckpt'
CHECKPOINT = None
TEST = False
# --------------------------------------

In [3]:
data_collator = DataCollatorWithPadding()

In [4]:
wer_metric = load_metric("wer")
cer_metric = load_metric("cer")

In [5]:
teacher_model = load_model(TEACHER_MODEL)
teacher_model.eval();

In [6]:
teacher_tf_encoder = teacher_model.w2v_encoder.w2v_model.encoder.layers

In [8]:
# Print Teacher Model Wav2Vec2Model config
teacher_model.w2v_encoder.w2v_model.cfg

Wav2Vec2Config(_name=None, conv_layer_setting=ConvFeatureExtractionModelConfig(_name=None, extractor_mode='layer_norm', conv_feature_layers='[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2', conv_bias=True, drop_out=0.5), encoder_setting=TransformerEncoderConfig(_name=None, layer_setting=TransformerSentenceEncoderLayerConfig(_name=None, encoder_embed_dim=1024, encoder_ffn_embed_dim=4096, encoder_attention_heads=16, dropout=0.0, attention_dropout=0.1, activation_dropout=0.0, activation_fn='gelu', layer_norm_first=True), encoder_layers=24, conv_pos=128, conv_pos_groups=16, encoder_layerdrop=0.0), dropout_input=0.1, dropout_features=0.1, final_dim=768, logit_temp=0.1, quantize_targets=True, quantize_input=False, same_quantizer=False, target_glu=False, feature_grad_mult=1.0, quantizer_depth=1, quantizer_factor=3, latent_vars=320, latent_groups=2, latent_dim=0, mask_length=10, mask_prob=0.65, mask_selection='static', mask_other=0.0, no_mask_overlap=False, mask_min_space=1, mask_channel

In [8]:
# Print Student Model Wav2Vec2Model config
student_config

NameError: name 'student_config' is not defined

In [7]:
# Freeze model parameters
for param in teacher_model.parameters():
    param.requires_grad = False

In [8]:
student_config = CustomWav2Vec2Config()

In [9]:
student_config.conv_layer_setting.extractor_mode = "layer_norm"
student_config.conv_bias = True
student_config.encoder_setting.layer_setting.encoder_embed_dim = 1024
student_config.encoder_setting.layer_setting.encoder_ffn_embed_dim = 4096
student_config.encoder_setting.layer_setting.encoder_attention_heads = 16
student_config.encoder_setting.layer_setting.dropout = 0.0
student_config.encoder_setting.layer_setting.layer_norm_first=True
student_config.encoder_setting.type_of_tr_layer = "conv1d"
student_config.encoder_setting.encoder_layers = 6
student_config.encoder_setting.tr_layer_floor = 3
student_config.encoder_setting.dropout_input = 0.1
student_config.encoder_setting.dropout_features = 0.1
student_config.encoder_setting.final_dim = 768
student_config.encoder_setting.latent_temp = (2, 0.1, 0.999995)
student_config.final_dropout = 0.0
student_config.targ_d = 32

In [10]:
student_model = CustomStudentModel(student_config)

In [11]:
import torchaudio

train_data = torchaudio.datasets.LIBRISPEECH(DATA_PATH, "train-clean-100", download=True)
eval_data = torchaudio.datasets.LIBRISPEECH(DATA_PATH, "dev-clean", download=True)
test_data = torchaudio.datasets.LIBRISPEECH(DATA_PATH, "test-clean", download=True)
# sample = test_data[0]

In [12]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_data, batch_size=2, collate_fn=data_collator, num_workers=4)
val_dataloader = DataLoader(eval_data, batch_size=2, collate_fn=data_collator, num_workers=4)
test_dataloader = DataLoader(test_data, batch_size=2, collate_fn=data_collator, num_workers=4)

In [13]:
decoder = Decoder()

In [14]:
ctc_converter = CTCSequenceConverter(return_type="pt")

In [15]:
L1loss = nn.L1Loss()
CTCloss = nn.CTCLoss(blank=4, zero_infinity=True)

In [20]:
for i, batch in enumerate(val_dataloader):
    teacher_results = teacher_model.w2v_encoder.w2v_model.extract_features(
            source=batch['src'], 
            # padding_mask=batch['mask'],
            padding_mask=None,
            layer=100
        )
    print(teacher_results['layer_results'][11][1].shape)
    break

torch.Size([2, 292, 292])


In [21]:
teacher_results['layer_results'][11][0].shape

torch.Size([292, 2, 1024])

In [23]:
for i, batch in enumerate(val_dataloader):
    print(batch)
    break

{'src': tensor([[ 2.3804e-03,  2.0752e-03,  1.9836e-03,  ...,  4.2725e-04,
          5.7983e-04,  1.0376e-03],
        [-1.5259e-04, -9.1553e-05, -1.8311e-04,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]]), 'mask': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 1., 1., 1.]]), 'labels': ['MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL', "NOR IS MISTER QUILTER'S MANNER LESS INTERESTING THAN HIS MATTER"]}


In [33]:
optimizer = torch.optim.Adam(student_model.parameters(), lr=LEARNING_RATE)

In [24]:
for i, batch in enumerate(val_dataloader):
    result = teacher_model.w2v_encoder.w2v_model.extract_features(
            source=batch['src'], 
            # padding_mask=batch['mask'],
            padding_mask=None,
            layer=100
        )
    
    x = result['x'].transpose(0, 1)
    x = teacher_model.w2v_encoder.proj(x)

    teacher_results = {
        "encoder_out": x,  # T x B x C
        "padding_mask": result["padding_mask"],  # B x T,
        "layer_results": result["layer_results"],
    }
    
    student_results = student_model(batch['src'], padding_mask=None)
    
    x = student_results['tr_layer_results'][0].detach()
    
    for i, layer in enumerate(teacher_tf_encoder):
        if i >= 12:
            x, z = layer(x)

    x = x.transpose(0, 1)
    if teacher_model.w2v_encoder.w2v_model.encoder.layer_norm_first:
        x = teacher_model.w2v_encoder.w2v_model.encoder.layer_norm(x)
    x = x.transpose(0, 1)
    teacher_tf_encoder_out = teacher_model.w2v_encoder.proj(x)
    
    ctc_input = student_results['encoder_out'].log_softmax(2) # -> Revise this
    logits = teacher_results['encoder_out'].transpose(0,1)
    predicted_ids = torch.argmax(logits, dim=-1)
    fused_tokens = [ctc_converter(ids) for ids in predicted_ids]
    target = torch.cat(fused_tokens)
    target_lengths = torch.tensor([len(tokens) for tokens in fused_tokens]) # -> Revise this
    
    loss1 = L1loss(student_results['layer_results'][2][0], teacher_results['layer_results'][11][0])
    loss2 = L1loss(student_results['encoder_out'], teacher_tf_encoder_out)
    loss3 = CTCloss(
                ctc_input, 
                target, 
                torch.full(size=(ctc_input.shape[1],), fill_value=ctc_input.shape[0]), # -> Revise this
                target_lengths
            )
    
    loss = loss1 + loss2 + loss3
    
    # optimizer.zero_grad()
    # loss.backward()
    # optimizer.step()
    print(loss)

    break

tensor(14.3856, grad_fn=<AddBackward0>)


In [37]:
logits = teacher_results['encoder_out'].transpose(0,1)
predicted_ids = torch.argmax(logits, dim=-1)
fused_tokens = [ctc_converter(ids) for ids in predicted_ids]
target = torch.cat(fused_tokens)
target_lengths = torch.tensor([len(tokens) for tokens in fused_tokens])
target_lengths

tensor([90, 64])

In [None]:
ctc_input = student_results['encoder_out'].log_softmax(2) # -> Revise this
logits = teacher_results['encoder_out'].transpose(0,1)
predicted_ids = torch.argmax(logits, dim=-1)
fused_tokens = [ctc_converter(ids) for ids in predicted_ids]
target = torch.cat(fused_tokens)
target_lengths = torch.tensor([len(tokens) for tokens in fused_tokens]) # -> Revise this
loss = CTCloss(
                ctc_input, 
                target, 
                torch.full(size=(ctc_input.shape[1],), fill_value=ctc_input.shape[0]), # -> Revise this
                target_lengths
            )

In [32]:
target

tensor([17, 10, 12,  6,  5, 13,  4, 30, 16, 10, 15,  6,  5, 13,  4, 10, 12,  4,
         6, 11,  5,  4,  7, 23,  8, 12,  6, 15,  5,  4,  8, 20,  4,  6, 11,  5,
         4, 17, 10, 14, 14, 15,  5,  4, 19, 15,  7, 12, 12,  5, 12,  4,  7,  9,
        14,  4, 18,  5,  4,  7, 13,  5,  4, 21, 15,  7, 14,  4,  6,  8,  4, 18,
         5, 15, 19,  8, 17,  5,  4, 11, 10, 12,  4, 21,  8, 12, 23,  5, 15,  4,
         9,  8, 13,  4, 10, 12,  4, 17, 10, 12,  6,  5, 13,  4, 30, 16, 10, 15,
         6,  5, 13, 27, 12,  4, 17,  7,  9,  9,  5, 13,  4, 15,  5, 12, 12,  4,
        10,  9,  6,  5, 13,  5, 12,  6, 10,  9, 21,  4,  6, 11,  7,  9,  4, 11,
        10, 12,  4, 17,  7,  6,  6,  5, 13,  4])

In [37]:
teacher_results['encoder_out'].shape

torch.Size([461, 2, 32])

In [39]:
batch

{'src': tensor([[-1.2207e-03, -8.5449e-04,  2.4414e-04,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-6.1035e-05, -9.1553e-05, -1.2207e-04,  ...,  1.5259e-04,
          -3.0518e-05, -9.1553e-05]]),
 'mask': tensor([[0., 0., 0.,  ..., 1., 1., 1.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 'labels': ['AS FOR ETCHINGS THEY ARE OF TWO KINDS BRITISH AND FOREIGN',
  'HE LAMENTS MOST BITTERLY THE DIVORCE THAT HAS BEEN MADE BETWEEN DECORATIVE ART AND WHAT WE USUALLY CALL PICTURES MAKES THE CUSTOMARY APPEAL TO THE LAST JUDGMENT AND REMINDS US THAT IN THE GREAT DAYS OF ART MICHAEL ANGELO WAS THE FURNISHING UPHOLSTERER']}

In [40]:
predicted_ids = np.argmax(student_results['encoder_out'].transpose(0,1).cpu().detach().numpy(), axis=-1)
predictions = [decoder.decode(ids) for ids in predicted_ids]
predictions

['', '']

In [None]:
for i, batch in enumerate(tqdm(val_dataloader)):
    logits = model(source=batch['src'].cuda(), padding_mask=batch['mask'])["encoder_out"].transpose(0,1)
    predicted_ids = np.argmax(logits.cpu().detach().numpy(), axis=-1)
    predictions = [decoder.decode(ids) for ids in predicted_ids]
    
    # wer_.append(wer_metric.compute(predictions=predictions, references=labels))
    # cer_.append(cer_metric.compute(predictions=predictions, references=labels))
    wer_metric.add_batch(predictions=predictions, references=batch['labels'])
    cer_metric.add_batch(predictions=predictions, references=batch['labels'])
    
wer = wer_metric.compute()
cer = cer_metric.compute()

In [23]:
predicted_ids

tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])

In [24]:
fused_tokens = torch.cat([ctc_converter(ids) for ids in predicted_ids])

In [33]:
temp = [ctc_converter(ids) for ids in predicted_ids]
temp[0].shape

torch.Size([159])

In [38]:
[len(tokens) for tokens in temp]

[159, 43]

In [39]:
target_length = torch.tensor([len(tokens) for tokens in temp])

In [40]:
target_length

tensor([159,  43])

In [25]:
fused_tokens

tensor([11,  5,  4, 11,  8, 23,  5, 14,  4,  6, 11,  5, 13,  5,  4, 18,  8, 16,
        15, 14,  4, 24,  5,  4, 12,  6,  5, 18,  4, 20,  8, 13,  4, 14, 10,  9,
         9,  5, 13,  4,  6, 16, 13,  9, 10, 23, 12,  4,  7,  9, 14,  4, 19,  7,
        13, 13,  8,  6, 12,  4,  7,  9, 14,  4, 24, 13, 16, 10, 12,  5, 14,  4,
        23,  8,  6,  7,  6,  8,  5, 12,  4,  7,  9, 14,  4, 20,  7,  6,  4, 17,
        16,  6,  6,  8,  9,  4, 23, 10,  5, 19,  5, 12,  4,  6,  8,  4, 24,  5,
         4, 15,  7, 14, 15,  5, 14,  4,  8, 16,  6,  4, 10,  9,  4,  6, 11, 10,
        19, 26,  4, 23,  5, 23, 23,  5, 13,  5, 14,  4, 20, 15,  8, 16, 13,  4,
        20,  7,  6,  6,  5,  9,  5, 14,  4, 12,  7, 16, 19,  5,  4, 12,  6, 16,
        20, 20,  4, 10,  6,  4, 10,  9,  6,  8,  4, 22,  8, 16,  4, 11, 10, 12,
         4, 24,  5, 15, 15, 22,  4, 19,  8, 16,  9, 12,  5, 15, 15,  5, 14,  4,
        11, 10, 17,  4])

In [24]:
test = torch.tensor([[1,2,3],[4,5,6]])
test.reshape(1,-1)
fused_tokens = ctc_converter(predicted_ids)
fused_tokens.shape

torch.Size([202])

In [26]:
dict = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3, "|": 4, "E": 5, 
    "T": 6, "A": 7, "O": 8, "N": 9, "I": 10, "H": 11, "S": 12, 
    "R": 13, "D": 14, "L": 15, "U": 16, "M": 17, "W": 18, "C": 19, 
    "F": 20, "G": 21, "Y": 22, "P": 23, "B": 24, "V": 25, "K": 26, 
    "'": 27, "X": 28, "J": 29, "Q": 30, "Z": 31}

look_up = np.asarray(list(dict.keys()))
look_up[fused_tokens]

array(['H', 'E', '|', 'H', 'O', 'P', 'E', 'D', '|', 'T', 'H', 'E', 'R',
       'E', '|', 'W', 'O', 'U', 'L', 'D', '|', 'B', 'E', '|', 'S', 'T',
       'E', 'W', '|', 'F', 'O', 'R', '|', 'D', 'I', 'N', 'N', 'E', 'R',
       '|', 'T', 'U', 'R', 'N', 'I', 'P', 'S', '|', 'A', 'N', 'D', '|',
       'C', 'A', 'R', 'R', 'O', 'T', 'S', '|', 'A', 'N', 'D', '|', 'B',
       'R', 'U', 'I', 'S', 'E', 'D', '|', 'P', 'O', 'T', 'A', 'T', 'O',
       'E', 'S', '|', 'A', 'N', 'D', '|', 'F', 'A', 'T', '|', 'M', 'U',
       'T', 'T', 'O', 'N', '|', 'P', 'I', 'E', 'C', 'E', 'S', '|', 'T',
       'O', '|', 'B', 'E', '|', 'L', 'A', 'D', 'L', 'E', 'D', '|', 'O',
       'U', 'T', '|', 'I', 'N', '|', 'T', 'H', 'I', 'C', 'K', '|', 'P',
       'E', 'P', 'P', 'E', 'R', 'E', 'D', '|', 'F', 'L', 'O', 'U', 'R',
       '|', 'F', 'A', 'T', 'T', 'E', 'N', 'E', 'D', '|', 'S', 'A', 'U',
       'C', 'E', '|', 'S', 'T', 'U', 'F', 'F', '|', 'I', 'T', '|', 'I',
       'N', 'T', 'O', '|', 'Y', 'O', 'U', '|', 'H', 'I', 'S', '|

---

In [None]:
logits = teacher_model(source=sample[0], padding_mask=None)["encoder_out"].transpose(0,1)

# --------------------------------------------------------------- #
predicted_ids = np.argmax(logits.cpu().detach().numpy(), axis=-1)
predictions = [decoder.decode(ids) for ids in predicted_ids]
fused_tokens = [tok[0] for tok in groupby(predicted_ids[0]) if tok[0] != 0]

print(logits)
print(logits.shape)
print(predictions)
print(fused_tokens)

In [55]:
result = teacher_model.w2v_encoder.w2v_model.extract_features(
    source=sample[0], 
    padding_mask=None,
    layer=100
)

x = result['x'].transpose(0, 1)
x = teacher_model.w2v_encoder.proj(x)

teacher_results = {
    "encoder_out": x,  # T x B x C
    "padding_mask": result["padding_mask"],  # B x T,
    "layer_results": result["layer_results"],
}

logits = teacher_results['encoder_out'].transpose(0,1)

# --------------------------------------------------------------- #
predicted_ids = np.argmax(logits.cpu().detach().numpy(), axis=-1)
predictions = [decoder.decode(ids) for ids in predicted_ids]
fused_tokens = [tok[0] for tok in groupby(predicted_ids[0]) if tok[0] != 0]

print(logits_to_match_with)
print(logits_to_match_with.shape)
print(predictions)
print(fused_tokens)

tensor([[[ 15.3397, -13.0896, -13.1977,  ...,  -5.5268,  -5.8605,  -4.5874],
         [ 15.4613, -13.2443, -13.3591,  ...,  -5.6072,  -5.9839,  -4.6637],
         [ 15.2651, -12.9915, -13.0938,  ...,  -5.4673,  -5.7874,  -4.5339],
         ...,
         [ 15.1254, -12.8308, -12.9267,  ...,  -5.3783,  -5.6602,  -4.4518],
         [ 15.0574, -12.7499, -12.8424,  ...,  -5.3364,  -5.6016,  -4.4117],
         [ 14.9019, -12.5752, -12.6623,  ...,  -5.2408,  -5.4742,  -4.3249]]])
torch.Size([1, 521, 32])
['HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE']
[11, 5, 4, 11, 8, 23, 5, 14, 4, 6, 11, 5, 13, 5, 4, 18, 8, 16, 15, 14, 4, 24, 5, 4, 12, 6, 5, 18, 4, 20, 8, 13, 4, 14, 10, 9, 9, 5, 13, 4, 6, 16, 13, 9, 10, 23, 12, 4, 7, 9, 14, 4, 19, 7, 13, 13, 8, 6, 12, 4, 7, 9, 14, 4, 24, 13, 16, 10, 12, 5, 14, 4, 23, 8, 6, 7, 6, 8, 5, 12, 4, 7, 9, 14, 4, 20, 7, 6, 4, 17, 16, 6, 6, 8, 9, 4, 23, 

In [93]:
# Torch Version
result = teacher_model.w2v_encoder.w2v_model.extract_features(
    source=sample[0], 
    padding_mask=None,
    layer=100
)

x = result['x'].transpose(0, 1)
x = teacher_model.w2v_encoder.proj(x)

teacher_results = {
    "encoder_out": x,  # T x B x C
    "padding_mask": result["padding_mask"],  # B x T,
    "layer_results": result["layer_results"],
}

logits = teacher_results['encoder_out'].transpose(0,1)

# --------------------------------------------------------------- #
predicted_ids = torch.argmax(logits, dim=-1)
# fused_tokens = torch.tensor([tok[0] for tok in groupby(predicted_ids[0]) if tok[0] != 0])
fused_tokens = converter(predicted_ids)

print(logits_to_match_with)
print(logits_to_match_with.shape)
print(predicted_ids.shape)
# print(predictions)
print(fused_tokens)

tensor([[[ 15.3397, -13.0896, -13.1977,  ...,  -5.5268,  -5.8605,  -4.5874],
         [ 15.4613, -13.2443, -13.3591,  ...,  -5.6072,  -5.9839,  -4.6637],
         [ 15.2651, -12.9915, -13.0938,  ...,  -5.4673,  -5.7874,  -4.5339],
         ...,
         [ 15.1254, -12.8308, -12.9267,  ...,  -5.3783,  -5.6602,  -4.4518],
         [ 15.0574, -12.7499, -12.8424,  ...,  -5.3364,  -5.6016,  -4.4117],
         [ 14.9019, -12.5752, -12.6623,  ...,  -5.2408,  -5.4742,  -4.3249]]])
torch.Size([1, 521, 32])
torch.Size([1, 521])
tensor([11,  5,  4, 11,  8, 23,  5, 14,  4,  6, 11,  5, 13,  5,  4, 18,  8, 16,
        15, 14,  4, 24,  5,  4, 12,  6,  5, 18,  4, 20,  8, 13,  4, 14, 10,  9,
         9,  5, 13,  4,  6, 16, 13,  9, 10, 23, 12,  4,  7,  9, 14,  4, 19,  7,
        13, 13,  8,  6, 12,  4,  7,  9, 14,  4, 24, 13, 16, 10, 12,  5, 14,  4,
        23,  8,  6,  7,  6,  8,  5, 12,  4,  7,  9, 14,  4, 20,  7,  6,  4, 17,
        16,  6,  6,  8,  9,  4, 23, 10,  5, 19,  5, 12,  4,  6,  8,  4, 24, 

---

In [56]:
student_results = student_model(sample[0], padding_mask=None)

In [59]:
student_results.keys()

dict_keys(['encoder_out', 'padding_mask', 'layer_results', 'tr_layer_results'])

In [14]:
x = teacher_results['layer_results'][11][0]
x

tensor([[[ 14.1960,  -5.6175,   2.2727,  ...,  -8.1675,   0.7943, -14.7274]],

        [[ 11.7198,  -7.6046,   3.0378,  ...,  -3.9107,   1.5182,  -4.7468]],

        [[ 10.9948,   2.3512,   0.1073,  ...,  -4.7215,   1.0011,  -2.7422]],

        ...,

        [[ 10.2565,  -3.1519,  -1.8599,  ...,  -1.9192,  -5.8508,  -0.1338]],

        [[  7.1228,   4.6386,  -5.1138,  ...,   2.3087,  -9.0622,   5.4171]],

        [[  9.7594,   1.1248,  -1.2406,  ...,   6.0834, -13.6151,   1.1677]]])

In [45]:
student_results['layer_results'][2][0]

tensor([[[-2.1155,  1.4240,  0.8586,  ...,  2.0274, -0.6959, -0.2633]],

        [[-2.0062,  0.4673,  0.3084,  ...,  1.1150, -0.8888, -0.6580]],

        [[ 0.5156,  0.9136, -0.3271,  ...,  1.7274, -1.0569, -0.4347]],

        ...,

        [[-0.2336,  0.0499, -0.7006,  ...,  1.2007, -0.9971, -0.9232]],

        [[-1.2208, -0.1211, -0.3233,  ...,  1.3017,  0.3959,  0.3928]],

        [[-2.2006, -0.5248,  0.7346,  ...,  0.8345, -0.1040,  0.3916]]],
       grad_fn=<AddBackward0>)

In [35]:
L1loss = nn.L1Loss()
loss1 = L1loss(student_results['layer_results'][2][0], teacher_results['layer_results'][11][0])
loss1

tensor(6.3655, grad_fn=<L1LossBackward0>)

In [58]:
x = student_results['tr_layer_results'][0].detach()

for i, layer in enumerate(teacher_tf_encoder):
    if i >= 12:
        x, z = layer(x)
        
x = x.transpose(0, 1)
if teacher_model.w2v_encoder.w2v_model.encoder.layer_norm_first:
    x = teacher_model.w2v_encoder.w2v_model.encoder.layer_norm(x)

x = x.transpose(0, 1)
teacher_tf_encoder_out = teacher_model.w2v_encoder.proj(x)

In [59]:
teacher_tf_encoder_out.shape

torch.Size([260, 1, 32])

In [60]:
teacher_tf_encoder_out

tensor([[[  5.8814, -11.0086, -11.1685,  ...,  -3.4234,  -3.9079,  -3.7239]],

        [[  5.9002, -11.1906, -11.3423,  ...,  -3.4618,  -3.9210,  -3.7639]],

        [[  5.9180, -11.0942, -11.2541,  ...,  -3.4218,  -4.0035,  -3.7718]],

        ...,

        [[  5.9409, -10.8936, -11.0562,  ...,  -3.3947,  -3.8838,  -3.6830]],

        [[  5.9506, -11.0034, -11.1554,  ...,  -3.3877,  -3.9427,  -3.7168]],

        [[  5.9455, -10.9126, -11.0684,  ...,  -3.4173,  -3.9037,  -3.6471]]])

In [51]:
student_results['encoder_out']

tensor([[[ 0.0133,  1.7132,  1.4479,  ...,  2.5222, -0.2823, -1.1558]],

        [[-1.6635,  0.0967,  1.3286,  ..., -0.5803, -0.0128,  1.4295]],

        [[-0.9927,  0.9092,  0.5922,  ...,  2.5229,  0.0198, -0.1499]],

        ...,

        [[ 0.5080,  2.6394,  0.2370,  ..., -2.1678, -1.7372,  0.0479]],

        [[-0.7190,  1.7406,  2.1662,  ...,  2.8859,  0.7594, -2.0479]],

        [[-1.4107,  1.5948,  0.2705,  ..., -0.8034,  1.7776, -2.7136]]],
       grad_fn=<AddBackward0>)

In [52]:
student_results['encoder_out'].shape

torch.Size([260, 1, 32])

In [54]:
loss2 = L1loss(student_results['encoder_out'], teacher_tf_encoder_out)
loss2

tensor(3.0481, grad_fn=<L1LossBackward0>)

In [72]:
ctc_input = student_results['encoder_out'].log_softmax(2) # -> Revise this

In [89]:
print(ctc_input)
print(ctc_input.shape)

tensor([[[-2.3118, -2.5565, -1.5400,  ..., -5.7079, -4.2700, -5.0472]],

        [[-3.5404, -4.1167, -3.4455,  ..., -3.4902, -4.2522, -6.0133]],

        [[-3.6018, -5.2308, -1.9200,  ..., -5.4679, -5.6073, -6.5199]],

        ...,

        [[-7.0460, -6.1014, -2.2327,  ..., -5.6954, -5.0860, -5.1001]],

        [[-4.5830, -4.4180, -2.2562,  ..., -4.4676, -4.3130, -5.1296]],

        [[-1.5732, -4.0880, -3.9460,  ..., -4.4798, -4.6704, -5.5060]]],
       grad_fn=<LogSoftmaxBackward0>)
torch.Size([260, 1, 32])


In [87]:
fused_tokens

tensor([11,  5,  4, 11,  8, 23,  5, 14,  4,  6, 11,  5, 13,  5,  4, 18,  8, 16,
        15, 14,  4, 24,  5,  4, 12,  6,  5, 18,  4, 20,  8, 13,  4, 14, 10,  9,
         9,  5, 13,  4,  6, 16, 13,  9, 10, 23, 12,  4,  7,  9, 14,  4, 19,  7,
        13, 13,  8,  6, 12,  4,  7,  9, 14,  4, 24, 13, 16, 10, 12,  5, 14,  4,
        23,  8,  6,  7,  6,  8,  5, 12,  4,  7,  9, 14,  4, 20,  7,  6,  4, 17,
        16,  6,  6,  8,  9,  4, 23, 10,  5, 19,  5, 12,  4,  6,  8,  4, 24,  5,
         4, 15,  7, 14, 15,  5, 14,  4,  8, 16,  6,  4, 10,  9,  4,  6, 11, 10,
        19, 26,  4, 23,  5, 23, 23,  5, 13,  5, 14,  4, 20, 15,  8, 16, 13,  4,
        20,  7,  6,  6,  5,  9,  5, 14,  4, 12,  7, 16, 19,  5,  4])

In [90]:
fused_tokens.shape

torch.Size([159])

In [94]:
ctc_input.shape

torch.Size([260, 1, 32])

In [95]:
CTCloss = nn.CTCLoss(blank=4)
loss3 = CTCloss(
    ctc_input, 
    fused_tokens, 
    torch.full(size=(ctc_input.shape[1],), fill_value=ctc_input.shape[0]),
    # torch.tensor(
    torch.tensor(fused_tokens.shape)
)
loss3

tensor(6.0094, grad_fn=<MeanBackward0>)

In [19]:
logits = teacher_model(source=sample[0], padding_mask=None)["encoder_out"].transpose(0,1)
predicted_ids = np.argmax(logits.cpu().detach().numpy(), axis=-1)
predictions = [decoder.decode(ids) for ids in predicted_ids]
predictions

['HE HOPED THERE WOULD BE STEW FOR DINNER TURNIPS AND CARROTS AND BRUISED POTATOES AND FAT MUTTON PIECES TO BE LADLED OUT IN THICK PEPPERED FLOUR FATTENED SAUCE ']

In [17]:
t_logits = teacher_results['encoder_out'].transpose(0,1)
predicted_ids = np.argmax(t_logits.cpu().detach().numpy(), axis=-1)
predictions = [decoder.decode(ids) for ids in predicted_ids]
t_logits

tensor([[[  2304.1094, 108910.8281, 109119.4531,  ...,  25060.8281,
           47617.1797,  40467.1797],
         [  3238.6797, 106824.4141, 107127.7812,  ...,  24026.8047,
           46491.0977,  39706.0820],
         [  1863.9941, 109858.3047, 110005.1875,  ...,  25653.4727,
           47935.5312,  40887.2773],
         ...,
         [   984.2344, 111826.3984, 111824.1562,  ...,  26681.4766,
           48915.2578,  41566.8203],
         [   531.6680, 112751.0469, 112670.9688,  ...,  27177.0664,
           49357.2578,  41902.7656],
         [  -493.8281, 114384.1719, 114133.1328,  ...,  28157.3457,
           50213.3320,  42554.9023]]])

In [80]:
# predicted_ids = torch.argmax(teacher_results['encoder_out'], dim=-1)
predicted_ids = np.argmax(teacher_results['encoder_out'].transpose(0,1).cpu().detach().numpy(), axis=-1)
predicted_ids.shape

(1, 521)

In [None]:
ctc_target = 