In [1]:
import argparse
import torch
import pandas as pd
import os
from datasets import load_metric
from dataclasses import dataclass
from tqdm import tqdm
from typing import Union, Dict, List, Optional
from transformers import AdamW, AutoTokenizer, T5ForConditionalGeneration, T5Config
from transformers import (
    DataCollator,
    Seq2SeqTrainer, 
    Seq2SeqTrainingArguments,
    set_seed,
)

import sys
sys.path.append('src')
from data_utils import load_pronuncation_dictionary, load_all_pronuncation_dictionaries
from ByT5_MoE import SwitchT5ForConditionalGeneration

In [2]:

@dataclass
class DataCollatorWithPadding:
    """Collate raw word / pronunciation features into one padded seq2seq batch.

    Each feature is expected to carry its source under ``"input_ids"`` and its
    target under ``"labels"``; both are handed to the tokenizer as-is.
    """

    # Tokenizer used to encode both the sources and the targets.
    tokenizer: AutoTokenizer
    # Forwarded unchanged as the tokenizer's ``padding`` argument.
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Tokenize and pad a list of features.

        Args:
            features: one mapping per example with ``"input_ids"`` and
                ``"labels"`` entries.

        Returns:
            The tokenizer output for the sources, extended with a ``"labels"``
            entry holding the target ids in which every position whose
            attention mask is not 1 has been replaced by -100.
        """

        def column(key):
            return [entry[key] for entry in features]

        # Sources and targets have different lengths and need different
        # special-token handling, so they are encoded in two separate calls.
        sources, targets = column("input_ids"), column("labels")

        common = {
            "padding": self.padding,
            "return_attention_mask": True,
            "return_tensors": 'pt',
        }
        encoded = self.tokenizer(sources, add_special_tokens=False, **common)
        encoded_targets = self.tokenizer(targets, add_special_tokens=True, **common)

        # Padded target positions become -100 so the loss ignores them.
        is_padding = encoded_targets.attention_mask.ne(1)
        encoded['labels'] = encoded_targets['input_ids'].masked_fill(is_padding, -100)

        return encoded
    
    
def prepare_dataset(batch):
    """Expose the ``word`` / ``pron`` columns under the names the collator reads.

    The mapping is modified in place and also returned: ``word`` is copied to
    ``input_ids`` and ``pron`` is copied to ``labels``.
    """
    column_aliases = (('input_ids', 'word'), ('labels', 'pron'))
    for alias, column in column_aliases:
        batch[alias] = batch[column]
    return batch
    

In [8]:
# Checkpoint directory that provides both the tokenizer and the model weights.
checkpoint = '/scratch/lingjzhu_root/lingjzhu1/lingjzhu/g2p/mt5_8_layers_baseline/checkpoint-245000'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Load the weights, switch to evaluation mode and move to the GPU in one chain
# (``eval()`` and ``to()`` both return the module itself).
# Alternative loader kept from the original cell:
#model = SwitchT5ForConditionalGeneration.from_pretrained(checkpoint)
model = T5ForConditionalGeneration.from_pretrained(checkpoint).eval().to('cuda')

# Read every test dictionary and add the columns the collator expects.
test_data = load_all_pronuncation_dictionaries('data/test', prefix=True).map(prepare_dataset)

100%|██████████| 99/99 [00:00<00:00, 260.55it/s]


  0%|          | 0/49500 [00:00<?, ?ex/s]

In [9]:
# Batch the test set; the collator tokenizes and pads each batch on the fly.
collator = DataCollatorWithPadding(tokenizer=tokenizer,padding=True)
loader = torch.utils.data.DataLoader(test_data,batch_size=512,collate_fn=collator)

# Send inputs to whichever device the model's parameters are on instead of
# hard-coding 'cuda'.
device = next(model.parameters()).device

# One tensor per batch is collected in each list; labels stay on the CPU and
# predictions are moved back to the CPU after generation.
preds = []
labels = []
for batch in tqdm(loader):
    input_ids = batch['input_ids'].to(device)
    masks = batch['attention_mask'].to(device)
    # No squeeze(): it would drop the batch dimension whenever a batch holds a
    # single example, so the collected tensors would not all have the same rank.
    labels.append(batch['labels'])
    with torch.no_grad():
        # NOTE(review): generate() is called with its default length settings;
        # confirm they are long enough for the longest pronunciation.
        pred = model.generate(input_ids,attention_mask=masks).cpu()
    preds.append(pred)

100%|██████████| 97/97 [00:22<00:00,  4.37it/s]


In [10]:
# Back-of-the-envelope throughput in examples per second: 49500 matches the
# example count shown by the dataset-map progress bar and 22 the elapsed
# seconds reported for the generation loop (presumably the intended operands;
# confirm before reusing the figure).
49500/22

2250.0

In [11]:
# Total number of scalar entries across all parameter tensors of the model.
num_params = sum(map(torch.numel, model.parameters()))
num_params

67177600