In [1]:
import json
import tarfile
from pathlib import Path

from tqdm import tqdm


def list_tar_files(path):
    """Return the ``*.tar`` shards found directly inside ``path``.

    Args:
        path: Directory to scan (``str`` or ``Path``). Subdirectories are not searched.

    Returns:
        A list of ``Path`` objects, sorted in ascending path order.
    """
    directory = Path(path)
    return sorted(directory.glob("*.tar"))


def read_webdataset(tar_path):
    """Yield the raw bytes of every regular file stored in a (WebDataset) tar shard.

    Members are yielded in archive order. Every regular file is returned whatever
    its extension, so a shard that stores several files per sample yields each of
    them separately. Directories, links, etc. are skipped.

    Args:
        tar_path: Path to the ``.tar`` shard (``str`` or ``Path``).

    Yields:
        bytes: the full contents of one member file.
    """
    with tarfile.open(tar_path) as tar:
        # Iterate the archive lazily instead of materialising getmembers() up front.
        for member in tar:
            if not member.isfile():
                continue
            # extractfile() returns an open file object; close it before handing the
            # data to the consumer rather than leaving it for the garbage collector.
            with tar.extractfile(member) as f:
                data = f.read()
            yield data


# Example usage: locate every training shard in the dataset directory
path = "/home/aj/repo/rescoring/src/rescoring_webdataset_train2/"
tar_files = list_tar_files(path)
print("Found", len(tar_files), "tar files")

Found 282 tar files


In [2]:
# Oracle statistics over all training shards:
#   * transcription_founded_in_hypotheses: the reference text appears somewhere in "hyps"
#   * improvement_is_possible: it appears, but not at index 0 (presumably the
#     decoder's top-ranked hypothesis)
num_samples = transcription_founded_in_hypotheses = improvement_is_possible = 0
for shard in tqdm(tar_files):
    for raw_sample in read_webdataset(shard):
        num_samples += 1
        sample = json.loads(raw_sample.decode())
        hyps = sample["hyps"]
        reference = sample["transcription"]
        if reference not in hyps:
            continue
        transcription_founded_in_hypotheses += 1
        if hyps.index(reference) > 0:
            improvement_is_possible += 1

100%|██████████| 282/282 [04:31<00:00,  1.04it/s]


In [3]:
# Fraction of samples whose reference transcription is present in the hypothesis list
oracle_hit_rate = transcription_founded_in_hypotheses / num_samples
oracle_hit_rate

0.6814094705784575

In [4]:
# Fraction of samples where the reference is present but not at index 0,
# i.e. where reordering the hypotheses could fix the output
rescoring_headroom = improvement_is_possible / num_samples
rescoring_headroom

0.14921580646054963

# Now test our model: score every hypothesis with the NTP language model

In [5]:
import torch

# Prefer the first GPU, but fall back to CPU so the notebook still runs on machines
# without CUDA (inference will be much slower there).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
from model.model import NTPModel

# Architecture hyper-parameters passed to the NTPModel constructor; the checkpoint's
# weights have to fit these shapes.
checkpoint_path = "checkpoints/train-minloss-epoch=0-step=1340000.ckpt"
model_kwargs = dict(
    token_size=2000,
    d_model=512,
    n_heads=8,
    dim_feedforward=2048,
    num_layers=6,
)
model = NTPModel.load_from_checkpoint(checkpoint_path, **model_kwargs).to(device)
model.eval()

/home/aj/venv/lib/python3.11/site-packages/pytorch_lightning/utilities/migration/utils.py:56: The loaded checkpoint was produced with Lightning v2.5.5, which is newer than your current Lightning version: v2.5.0.post0


NTPModel(
  (embedding): Embedding(2000, 512)
  (transformer_layers): ModuleList(
    (0-5): 6 x TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
      )
      (linear1): Linear(in_features=512, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=512, bias=True)
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
  )
  (fc): Linear(in_features=512, out_features=2000, bias=False)
  (criterion): NLLLoss()
)

In [7]:
import sentencepiece as spm

sp = spm.SentencePieceProcessor(model_file="data/tokenizer/unigram_2000.model")
# Every token id produced by this tokenizer must be a valid row of the model's
# embedding table (token_size=2000 when the checkpoint was loaded); fail early
# rather than with an index error deep inside the forward pass.
assert sp.get_piece_size() <= 2000, (
    f"tokenizer has {sp.get_piece_size()} pieces but the model embeds only 2000 tokens"
)

In [8]:
from torch import nn

from data.positional_encoder import get_positional_encoding

# Positional-encoding table reused by every forward pass. d_model must equal the
# model's embedding width (512); max_len presumably bounds how many positions
# (including BOS/EOS) create_model_inputs can slice from it.
positional_encoding = get_positional_encoding(max_len=1200, d_model=512)


def create_model_inputs(sentence):
    """Tokenize one sentence and build the tensors passed to the NTP model.

    The sentence is upper-cased before encoding (presumably the tokenizer was
    trained on upper-case text) and wrapped in BOS/EOS tokens.

    Args:
        sentence: Raw hypothesis / transcription text.

    Returns:
        ``(x, pe, causal_mask)``:
            x: LongTensor of token ids, shape ``(1, seq_len)``, on ``device``.
            pe: the first ``seq_len`` rows of the positional-encoding table, on ``device``.
            causal_mask: bool ``(seq_len, seq_len)`` mask, True above the diagonal
                (future positions a token may not attend to).

    Raises:
        ValueError: if the tokenized sentence is longer than the positional-encoding table.
    """
    tokens = [sp.bos_id()] + sp.encode(sentence.upper()) + [sp.eos_id()]
    seq_len = len(tokens)

    # Assumes dim 0 of the table indexes positions (as the slice below implies).
    # Slicing past its end would silently return a shorter tensor, so check explicitly.
    pe_table = positional_encoding.pe
    if seq_len > pe_table.size(0):
        raise ValueError(
            f"sentence encodes to {seq_len} tokens (including BOS/EOS), but the "
            f"positional encoding only covers {pe_table.size(0)} positions"
        )

    x = torch.LongTensor(tokens).unsqueeze(0).to(device)  # (batch_size=1, seq_len)
    pe = pe_table[:seq_len].to(device)
    # generate_square_subsequent_mask yields -inf above the diagonal and 0 elsewhere;
    # .bool() turns that into True = "masked out".
    causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).bool().to(device)
    return (x, pe, causal_mask)


def get_average_sentence_prob(x, pe, mask, lm=None):
    """Return the mean probability the LM assigns to each token of a sentence.

    Each token after BOS (including EOS) is scored by the distribution the model
    produced at the previous position.

    Args:
        x: token ids, shape ``(batch, seq_len)``, BOS first, as built by
            ``create_model_inputs``. Must contain at least two tokens.
        pe: positional encodings, passed through to the model.
        mask: causal attention mask, passed through to the model.
        lm: model to score with; defaults to the notebook-level ``model``. The model's
            output is treated as log-probabilities of shape ``(batch, seq_len, vocab)``.

    Returns:
        float: arithmetic mean of the target-token probabilities.

    Raises:
        ValueError: if ``x`` has fewer than two positions (nothing to predict).
    """
    if lm is None:
        lm = model
    if x.size(1) < 2:
        raise ValueError("need at least two tokens (e.g. BOS + EOS) to score a sentence")
    with torch.inference_mode():
        log_probs = lm(x, pe, mask)  # (batch_size, seq_len, token_size) log-probs
        # With a causal mask, position t only sees tokens <= t, so its output is the
        # distribution over the NEXT token: score x[:, t + 1] with log_probs[:, t].
        # NOTE(review): assumes NTPModel is trained with targets shifted by one
        # (standard next-token prediction) — confirm against its training step.
        targets = x[:, 1:].unsqueeze(-1)
        token_log_probs = torch.gather(log_probs[:, :-1], 2, targets).squeeze(-1)
        # Exponentiate only the gathered values, not the whole vocabulary.
        token_probs = token_log_probs.exp()
    return token_probs.mean().item()

In [None]:
# Score every hypothesis with the LM and write one JSON file per shard.
# A full pass takes well over a day, so the loop is resumable: shards whose output
# file already exists are reloaded instead of re-scored.
# NOTE: delete shard_checks/ after changing the scoring function, otherwise stale
# avg_probs from the previous version are reused.
out_dir = Path("shard_checks")
out_dir.mkdir(exist_ok=True)

num_samples, transcription_founded_in_hypotheses, improvement_is_possible = 0, 0, 0
for shard in tqdm(tar_files):
    out_path = out_dir / f"{shard.name}.json"
    if out_path.exists():
        with open(out_path) as f:
            shard_json = json.load(f)
    else:
        shard_json = []
        for raw_sample in read_webdataset(shard):
            sample = json.loads(raw_sample.decode())
            sample["avg_probs"] = []
            # Only samples whose reference is among the hypotheses are scored.
            if sample["transcription"] in sample["hyps"]:
                for hyp in sample["hyps"]:
                    x, pe, mask = create_model_inputs(hyp)
                    sample["avg_probs"].append(get_average_sentence_prob(x, pe, mask))
            shard_json.append(sample)
        # Write to a temporary file and rename, so an interrupted dump never leaves
        # a truncated JSON file that the resume branch above would fail to parse.
        tmp_path = out_path.with_suffix(".json.tmp")
        with open(tmp_path, "w") as f:
            json.dump(shard_json, f)
        tmp_path.replace(out_path)

    # Counters are computed from the (fresh or cached) records so they stay
    # correct when the run is resumed.
    for sample in shard_json:
        num_samples += 1
        if sample["transcription"] in sample["hyps"]:
            transcription_founded_in_hypotheses += 1
            if sample["hyps"].index(sample["transcription"]) != 0:
                improvement_is_possible += 1

 54%|█████▍    | 152/282 [33:53:16<28:11:45, 780.81s/it]