# Grammar-correction fine-tuning

In [None]:
import torch

from transformers import AutoTokenizer
# NOTE(review): AdamW is not used anywhere in this notebook, and transformers
# deprecated its own AdamW in favour of torch.optim.AdamW — newer releases may
# no longer export it, so confirm against the installed version.
from transformers import AdamW, DataCollatorWithPadding, \
       TrainingArguments, Trainer

import pandas as pd

# This is transformers' logging module (it shadows the stdlib `logging` name
# in this notebook); only errors are printed from here on.
from transformers import logging
logging.set_verbosity_error()

# Local helper module; reloaded so edits to utils.py take effect without
# restarting the kernel.
import utils
import importlib
importlib.reload(utils)

In [None]:
# Make IPython display the value of every bare expression in a cell, not just
# the last one (several cells below show two results at once).
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

### Load checkpoint

In [None]:
# NOTE(review): despite the constant's name, the id points at vennify's T5
# grammar-correction checkpoint rather than a Gramformer one — the name may be
# stale.
GRAMFORMER_CHECKPOINT = "vennify/t5-base-grammar-correction"
# utils.load_tokenizer_model is defined outside this notebook. The unpacking
# below assumes it returns (model, tokenizer) in that order; the helper's name
# suggests the opposite order, so confirm against utils.py.
grammar_model, grammar_tokenizer = utils.load_tokenizer_model(GRAMFORMER_CHECKPOINT)

### Model inputs 

In [None]:
# One garbled input and its expected correction.
# vennify/t5-base-grammar-correction was trained with the "grammar: " task
# prefix ("gec: " is the prefix used by Gramformer's checkpoints), so prepend
# the prefix the loaded checkpoint expects.
GEC_PREFIX = "grammar: "
wrong_sent = GEC_PREFIX + "Energyie and Security Speciavissts - DT"
right_sent = "Energy and Security Specialists - DT"  # reference target
# Encode as a single-item PyTorch batch; truncation=True cuts inputs that
# exceed the tokenizer's maximum length.
tokens = grammar_tokenizer(wrong_sent, truncation=True, return_tensors='pt')
tokens

In [None]:
# Generate a correction with the model's own decoding settings.
# max_new_tokens overrides the model-agnostic default max_length=20, which can
# silently cut off longer corrections. skip_special_tokens drops the leading
# <pad> decoder-start token and the trailing </s>, so the result is plain text
# that can be compared directly with right_sent.
outputs = grammar_model.generate(**tokens, max_new_tokens=64)
grammar_tokenizer.decode(outputs[0], skip_special_tokens=True)

### Fine-tuning test

In [None]:
# Peek at the first training example.
# NOTE(review): raw_dataset is never defined in this notebook. Presumably it is
# a datasets.DatasetDict with "train"/"validation" splits loaded in another
# session (the sentence1/sentence2 columns used below look like GLUE MRPC).
# TODO: load it explicitly here.
raw_dataset['train'][0]

In [None]:
# Inspect the column schema of the training split.
# NOTE(review): relies on raw_dataset, which this notebook never defines.
raw_dataset['train'].features

In [None]:
def tokenize_function(example):
    """Encode the ``sentence1``/``sentence2`` pair(s) of *example* for ``Dataset.map``.

    The two fields are passed to the tokenizer as a text pair and truncated to
    the tokenizer's maximum length. The tokenizer is looked up from the
    module-level name ``tokenizer`` at call time; this notebook assigns that
    name in a later cell, which must therefore run first.
    """
    return tokenizer(
        text=example['sentence1'],
        text_pair=example['sentence2'],
        truncation=True,
    )

In [None]:
# Tokenize every split in batches, then build a collator that pads each
# training batch to its longest sequence.
# NOTE(review): `tokenizer` is only assigned further down this notebook (the
# bert-base-cased tokenizer), so this cell fails when run top-to-bottom.
# DataCollatorWithPadding pads model inputs only; a seq2seq fine-tune of the
# T5 grammar model would need labels and DataCollatorForSeq2Seq — confirm intent.
tokenized_dataset = raw_dataset.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer)

In [None]:
# Fine-tune with the HF Trainer; checkpoints and logs go to ./models
# (TrainingArguments' first positional argument is output_dir).
# NOTE(review): `model` is never defined in this notebook (the loaded model is
# `grammar_model`). `tokenizer` is the BERT tokenizer assigned further down,
# which would not match a T5 model. This cell looks carried over from a
# sentence-pair classification example; reconcile it before running.
training_args = TrainingArguments('models')

trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer  # deprecated in favour of processing_class in newer transformers — confirm
    )

trainer.train()

In [None]:
# Load the paired label files (presumably garbled labels in wrong.csv and
# their clean counterparts in correct.csv), then preview two rows of each.
wrong, right = pd.read_csv('wrong.csv'), pd.read_csv('correct.csv')
wrong.head(2)
right.head(2)

In [None]:
# Give the first column of wrong.csv the label-column name used by correct.csv
# (t_grp_org), then list the distinct values it contains.
wrong = wrong.rename({wrong.columns[0]: 't_grp_org'}, axis='columns')
wrong['t_grp_org'].unique()

In [None]:
# Pick one known-garbled label from wrong.csv and its clean counterpart from
# correct.csv. The "Ã\xa0" in wrong_sent is intentional: it is what "à"
# (UTF-8 bytes C3 A0) looks like after being decoded as Latin-1.
wrong_sent = "VDT/Addetti Canali ViabilitÃ\xa0"
mask = wrong.t_grp_org == wrong_sent  # rows of wrong.csv carrying this label
# na=False makes missing labels count as non-matches instead of breaking the
# boolean index. The value is read from the t_grp_org column by name rather
# than "first column of the row", which silently depended on column order.
right_sent = right.loc[
    right.t_grp_org.str.startswith('VDT/Addetti Canali Via', na=False),
    't_grp_org',
].iloc[0]
wrong_sent, right_sent

In [None]:
# Generic cased BERT tokenizer. The cells below use it to compare how the
# garbled and clean labels tokenize; it is unrelated to the T5 grammar model.
# NOTE(review): earlier cells (tokenize_function, the collator and the Trainer)
# also read `tokenizer`, so they only work after this cell has run.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

In [None]:
# Encode the clean and garbled labels as one batch, padded to the longer of the
# two so their id rows line up position by position.
tokens = tokenizer([right_sent, wrong_sent], padding='longest', return_tensors='pt')
right_ids, wrong_ids = tokens['input_ids']  # row 0 = clean, row 1 = garbled
# Element-wise id difference: zero wherever both tokenizations agree. Token ids
# are arbitrary vocabulary indices, so the non-zero magnitudes carry no meaning.
right_ids - wrong_ids

In [None]:
# Cosine similarity of the two padded id rows. input_ids are 1-D integer
# tensors, so cast to float and reduce over dim=0. With the defaults (integer
# dtype, dim=1), torch.cosine_similarity raises instead of returning a value.
# Like the difference above, this treats arbitrary vocabulary ids as
# coordinates, so read it only as a rough signal.
torch.cosine_similarity(right_ids.float(), wrong_ids.float(), dim=0)

In [None]:
from difflib import SequenceMatcher

In [None]:
# Character-level similarity ratio (2*matches / total length):
# - the clean vs. garbled label;
# - for scale, two labels that differ in a single character.
SequenceMatcher(isjunk=None, a=right_sent, b=wrong_sent).ratio()
SequenceMatcher(isjunk=None, a='Dirigenti - DT1', b='Dirigenti - DT2').ratio()