In [1]:
import datasets
import evaluate
import fire
import numpy as np
import os
import pandas as pd
import torch
import transformers
import wandb
import random
from pretrain_multilingual_model import create_prompt, tokenize, create_trainer

pretrained_model = "google/byt5-base"
random.seed(0)
MODEL_INPUT_LENGTH = 1024
mode = 'predict'
test_split = 'id'
model_path = 'lecslab/byt5-translation-all'

tokenizer = transformers.ByT5Tokenizer.from_pretrained(
    pretrained_model, use_fast=False
)
dataset = datasets.load_dataset('lecslab/glosslm-split')
dataset = dataset.filter(lambda x: x["transcription"] is not None and x["glosses"] is not None)
dataset = dataset.map(create_prompt)
dataset = dataset.map(
    tokenize(tokenizer, max_length=MODEL_INPUT_LENGTH), batched=True
)

dataset["train"] = dataset["train"].shuffle()

print(f"Loading model from {pretrained_model}")
model: transformers.T5ForConditionalGeneration = transformers.T5ForConditionalGeneration.from_pretrained(pretrained_model if mode == 'train' else model_path)
trainer = create_trainer(
    model,
    dataset=dataset,
    tokenizer=tokenizer,
    batch_size=2,
    lr=5e-5,
    max_epochs=10,
)

if mode == "train":
    print("Training...")
    trainer.train()
    if not os.path.exists(model_path):
        os.makedirs(model_path)
    print(f"Saving model to {model_path}")
    trainer.save_model(model_path)
    print(f"Model saved at {model_path}")

elif mode == "predict":
    print("Creating predictions...")
    assert test_split in ['id', 'ood']
    test_split = "test_" + test_split.upper()

    preds = trainer.predict(dataset[test_split].select(range(10)))
    # preds_df = pd.DataFrame({
    #     "ID": dataset[test_split]["ID"],
    #     "glottocode": dataset[test_split]["glottocode"],
    #     "is_segmented": dataset[test_split]["is_segmented"],
    #     "pred": preds,
    # })
    # preds_df.to_csv(f"{test_split}-preds.csv", index=False)
    preds

cpu
Loading model from google/byt5-base
Creating trainer...
Creating predictions...


[[   0  101  120  119   35   54   49   76   80   83   72   85   73   49
    81   72   74   48  118  100  124   49  118   49  119   49   48  115
   104  117  118   49   83   79   35  124  114  120   35  110  113  114
   122    1 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100
  -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100
  -100 -100 -100 -100 -100 -100 -100]
 [   0  101  120  119   35   54   49   76   80   83   72   85   73   49
    81   72   74   48  118  100  124   49  118   49  119   49   48  115
   104  117  118   49   83   79   35  124  114  120   35  110  113  114
   122    1 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100
  -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100
  -100 -100 -100 -100 -100 -100 -100]
 [   0  119  107  108  118   35  122  107  104  113   49   83   68   86
    87   48  118  115  104  100  110   48   54   83   79   35   76   70
    49  102  114  117  117  117  104  102  119   48   54   8

In [2]:
preds

PredictionOutput(predictions=array([[   0,  101,  120,  119,   35,   54,   49,   76,   80,   83,   72,
          85,   73,   49,   81,   72,   74,   48,  118,  100,  124,   49,
         118,   49,  119,   49,   48,  115,  104,  117,  118,   49,   83,
          79,   35,  124,  114,  120,   35,  110,  113,  114,  122,    1,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100],
       [   0,  101,  120,  119,   35,   54,   49,   76,   80,   83,   72,
          85,   73,   49,   81,   72,   74,   48,  118,  100,  124,   49,
         118,   49,  119,   49,   48,  115,  104,  117,  118,   49,   83,
          79,   35,  124,  114,  120,   35,  110,  113,  114,  122,    1,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
        

In [6]:
labels = preds.label_ids
labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

tokenizer.batch_decode(labels, skip_special_tokens=True)

['but 3.IMPERF.NEG-say.s.t.-pers.PL you know',
 'but 3.IMPERF.NEG-say.s.t.-pers.PL you know',
 'this when.PAST-speak-3PL IC.tell.the.truth-3PL',
 'this when.PAST-speak-3PL IC.tell.the.truth-3PL']

In [10]:
dataset[test_split].select(range(4))['glosses']

['but 3.IMPERF.NEG-say.s.t.-pers.PL you know',
 'but 3.IMPERF.NEG-say.s.t.-pers.PL you know',
 'this when.PAST-speak-3PL IC.tell.the.truth-3PL',
 'this when.PAST-speak-3PL IC.tell.the.truth-3PL']

In [9]:
trainer.args



In [19]:
model.generation_config.max_new_tokens = 1024

In [9]:
from datasets import load_dataset

dataset = load_dataset("lecslab/glosslm-split")['test_ID'].to_pandas()

In [10]:
dataset[dataset['glottocode'] == 'arap1274']

Unnamed: 0,ID,glottocode,transcription,glosses,translation,metalang_glottocode,is_segmented,source,language,metalang
0,st_test_arap1274_0,arap1274,'oh hiihoow-kohtobei-no' you know,but 3.IMPERF.NEG-say.s.t.-pers.PL you know,"But they don't say anything , you know ?",stan1293,yes,sigmorphon_st,Arapaho,English
1,st_test_arap1274_0,arap1274,'oh hiihoowkohtobeino' you know,but 3.IMPERF.NEG-say.s.t.-pers.PL you know,"But they don't say anything , you know ?",stan1293,no,sigmorphon_st,Arapaho,English
2,st_test_arap1274_1,arap1274,nuhu' tih-'eeneti-3i' heneenei3oobei-3i',this when.PAST-speak-3PL IC.tell.the.truth-3PL,"When they speak , they tell the truth .",stan1293,yes,sigmorphon_st,Arapaho,English
3,st_test_arap1274_1,arap1274,Nuhu' tih'eeneti3i' heneenei3oobei'i3i',this when.PAST-speak-3PL IC.tell.the.truth-3PL,"When they speak , they tell the truth .",stan1293,no,sigmorphon_st,Arapaho,English
4,st_test_arap1274_2,arap1274,nehe' nebesiibehe' nih'ii-P heesi-nihii-t heih...,this my.grandfather PAST.IMPERF-pause what-say...,"My grandfather , what he said , we aren't Whit...",stan1293,yes,sigmorphon_st,Arapaho,English
...,...,...,...,...,...,...,...,...,...,...
9779,st_test_arap1274_4889,arap1274,Ne'nihP ne'niinihiinoo,then-PAST-pause then-REDUP-say-1S,"Then , then I said those things .",stan1293,no,sigmorphon_st,Arapaho,English
9780,st_test_arap1274_4890,arap1274,nii-ciinen-ou'u beneesou-'u nonoocou-'u kookoy...,IMPERF-put-3PL IC.big-0.PL IC.white-0.PL squar...,Bricklayers lay large white blocks .,stan1293,yes,sigmorphon_st,Arapaho,English
9781,st_test_arap1274_4890,arap1274,Niiciinenou'u beneesou'u nonoocou'u kookoyou'u,IMPERF-put-3PL IC.big-0.PL IC.white-0.PL squar...,Bricklayers lay large white blocks .,stan1293,no,sigmorphon_st,Arapaho,English
9782,st_test_arap1274_4891,arap1274,ne'-P cee3ei'oo-no' kee'in,then-pause IC.set.off-12 you.know?,"They say , "" We're leaving ,"" you know ?",stan1293,yes,sigmorphon_st,Arapaho,English


0      and DEM.DIST-guy-EP-ABS.SG picked.up necklace-...
2      and this-NOM.PL real-NOM.PL stories-NOM.PL be....
4                              run.PRET there-ABL far.. 
6      let’s.go-1PL.INCL seaside-GEN-to say-AOR.3SG-3...
8                 1SG.ERG on vacation.GEN send-PASS.PST 
                             ...                        
888                         2SG PROG-die-PROG 2SG-go see
890    those COMP CL.2-children CL.2-those travel.IPF...
892                                             AUG-hand
894                     so 2-die-PL NEG 14-fall-IPFV PRT
896    3.SG 2-be.PRS two-two water NEG-3-sleep-PROG P...
Name: pred, Length: 449, dtype: object

In [54]:
import pandas as pd
all_eval = {}
pred_df = pd.read_csv('../preds/byt5-translation-all-beams3-ft-lez/lezg1247-test_OOD-preds.csv')
pred_df = pred_df[(pred_df["is_segmented"] != "yes")]
pred_df['pred'] = pred_df['pred'].str.replace(r'([a-zA-Z])([,.!?;:]) ', r'\1 \2 ', regex=True)


preds = pred_df["pred"]
print(preds)

from eval import eval_accuracy, eval_morpheme_glosses
import re
import datasets

def _eval(preds, gold):
    pred_words = [str(pred).split() for pred in preds]

    gold_words = [gloss.split() for gloss in gold]
    word_eval = eval_accuracy(pred_words, gold_words)

    pred_morphemes = [re.split(r"\s|-", str(pred)) for pred in preds]
    gold_morphemes = [re.split(r"\s|-", gloss) for gloss in gold]

    eval_dict = {
        "word_level": word_eval,
        **eval_morpheme_glosses(
            pred_morphemes=pred_morphemes, gold_morphemes=gold_morphemes
        ),
    }
    return eval_dict

dataset = datasets.load_dataset('lecslab/glosslm-split', split="test_OOD")
dataset = dataset.filter(lambda x: x["is_segmented"] != "yes")
assert pred_df["ID"].tolist() == dataset["ID"]
gold = dataset["glosses"]

print(len(gold))

all_eval["all"] = _eval(preds, gold)

for lang in ['lezg1247']:
    lang_dataset = dataset.filter(lambda x: x["glottocode"] == lang)
    lang_preds = pred_df[pred_df["glottocode"] == lang]
    assert lang_preds["ID"].tolist() == lang_dataset["ID"]
    lang_preds = lang_preds["pred"]
    lang_gold = lang_dataset["glosses"]
    all_eval[lang] = _eval(lang_preds, lang_gold)

all_eval

1      this boy-DIR-GEN-ERG hand threw-AOR one pair t...
3                  himself-FOC be-PERF-PTP story-PL was.
5                           far apperance-AOR 1pl.abs...
7        come go-GEN sea-OBL-GEN bank-ERG-DAT » say-ENT.
9                    then vacation give-AOR 1sg.ERG-DAT.
                             ...                        
889                              2.PRON smoke 1.PRON say
891    like COMP 2.people like 3.SG-buy-PAST house to to
893                                             live-NEG
895                                so harpy NEG like PRT
897    3.SG 2-be.two 2-be.on 1PL-be 1.SG.SBJ-PRS-hit-...
Name: pred, Length: 449, dtype: object
449


Filter:   0%|          | 0/449 [00:00<?, ? examples/s]

{'all': {'word_level': {'average_accuracy': 0.12521672584706878,
   'accuracy': 0.20238489566081483},
  'morpheme_level': {'average_accuracy': 0.10907540355175534,
   'accuracy': 0.15565163681284744},
  'classes': {'stem': {'prec': 0.19887051230334812,
    'rec': 0.1833395314243213,
    'f1': 0.1907894736842105},
   'gram': {'prec': 0.18639262934089298,
    'rec': 0.12130996309963099,
    'f1': 0.146968426934898}},
  'bleu': 0.1519495897134019},
 'lezg1247': {'word_level': {'average_accuracy': 0.5008002114010598,
   'accuracy': 0.5485327313769752},
  'morpheme_level': {'average_accuracy': 0.508352420625257,
   'accuracy': 0.5080174927113703},
  'classes': {'stem': {'prec': 0.5585241730279898,
    'rec': 0.49943117178612056,
    'f1': 0.5273273273273273},
   'gram': {'prec': 0.5, 'rec': 0.5233265720081136, 'f1': 0.5113974231912785}},
  'bleu': 0.4602559480357618}}

In [48]:
lang_preds

174    but priest Lesile ᴄᴏᴍ.ᴘʟ child those four NEG-...
176                   43. group-3MINII PCLF.RSBL Mr Lore
178    SUBR hear-PDIR.YON father-1MINII NMLZ1-speak-N...
180    and 3AUG-be-GDIR.UP-PDIR.YON-3AUGIS place.DEM2...
182    SUBR look-go-TR-3MINIA room DEM2.DIST and see-...
                             ...                        
362    because MID-work-APPL-1AUGII house SUBR-be.big...
364    start-again-1AUGI Noipx.vil and sleep-1AUGI Ne...
366    and mista Sadrak Sunday SUBR be SUBR-be.big GE...
368    but NEG-know-1MINI-NEG SUBR 3AUG-return-again-...
370    travel-GDIR.UP-1MINI place.DEM2.DIST RL-fork-A...
Name: pred, Length: 99, dtype: object

In [49]:
lang_gold

['but priest Lesile ᴄᴏᴍ.ᴘʟ child those four NEG-PAS-be-COS-NEG .',
 '43 . group-3MINII PCLF.RSBL Mr Lore',
 'SUBR hear-PDIR.YON father-1MINII NMLZ1-speak-NMLZ.POSS-PDIR.YON monk John , PFV RL-MID-say-COS-PDIR.YON-3MINIA SUBR , “ thank-INTS-COS John , PREP NMLZ1-speak-NMLZ.POSS-2MINII DEM1.PROΧ .',
 'and 3AUG-be-GDIR.UP-PDIR.YON-3AUGIS place.DEM2.DIST RL-stand-GDIR.DOWN-APPL-PDIR.HITHER foot-3MINII AT.ΦNT.PL.PL people .',
 'SUBR look-thruout-TR-3MINIA room DEM2.DIST and see-3MINII thing SUBR do-PLCT-1AUGI-PL , neck-3MINII feel.sad-INTS because NEG-NMLZ1-say-NMLZ.POSS-PDIR.HITHER-3MINIA-NEG DAT.1AUGII NMLZ1-forbid-NMLZ.POSS-3MINII NMLZ1-jump-GDIR.IN-NMLZ.POSS-1AUGII PREP room PCLF.B&G image cross .',
 '3AUG-desire-3AUGIA IRR-hurry-path-GDIR.IN-3AUGIA bush but place long-GDIR.OUT-PDIR.HITHER .',
 'leave-PDIR.HITHER-1AUGII ᴄᴏᴍ.ᴘʟ place.DEM2.DIST and sleep-INTS-1AUGI ᴄᴏᴍ.ᴘʟ PREP field pineapple PCLF.B&G school .',
 'choose-3MINIA those mankind SUBR IRR-be be.3ᴀᴜɢII committee PCLF.RSBL-3MINI

In [13]:
import re
def strip_gloss_punctuation(glosses: str):
    """Strips any punctuation from gloss string (assuming it is surrounded by spaces)"""
    return re.sub(r"(\s|^)[^\w\s](\s|$)", " ", glosses).strip()

strip_gloss_punctuation(". PART 1.PL , the.thing !")

'PART 1.PL the.thing'