In [None]:
!pip install -q datasets transformers deepspeed
!pip install wandb
!pip install pytorch_lightning

In [None]:
# hide
import datetime
import os
from pathlib import Path
import random
from typing import Any, Dict, List, Optional, Tuple
import re
import pandas as pd
from datetime import datetime
from sklearn.model_selection import train_test_split

import datasets
from deepspeed.ops import adam
import matplotlib.pyplot as plt
import numpy as np
import subprocess

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split, DataLoader
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torchvision import models
from tqdm.auto import tqdm
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import wandb
import string

from torch.utils.data import Dataset
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.optim import Adam
from torch.utils.data import DataLoader
import tqdm

import nltk
nltk.download('punkt')
from nltk.translate.gleu_score import sentence_gleu

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# Weights & Biases credentials. A key that is already exported in the
# environment is left untouched; the placeholder below is only installed when
# nothing is configured, so re-running this cell can no longer wipe a valid key.
if not os.environ.get("WANDB_API_KEY"):
    # Wandb KEY here
    os.environ["WANDB_API_KEY"] = ""

wandb.login()

print(torch.__version__, transformers.__version__)

In [None]:
def correct_sent(sentence, edits):
    """Apply M2-style edits to a whitespace-tokenised sentence.

    Args:
        sentence: Source sentence; tokens are separated by whitespace.
        edits: Iterable of edit field lists. Field 0 is ``"<start> <end>"``
            (token offsets, end exclusive) and field 2 is the replacement
            text. ``start == end`` inserts before token ``start``; an empty
            replacement deletes the span.

    Returns:
        The corrected sentence with tokens joined by single spaces. If any
        edit carries a negative offset the original ``sentence`` is returned
        unchanged.
    """
    tokens = sentence.split()
    spans = []
    for order, edit in enumerate(edits):
        offsets = [int(item) for item in edit[0].split()]
        start, end = offsets[0], offsets[1]
        if start < 0 or end < 0:
            return sentence
        if start > end:
            # Malformed span: ignore it rather than corrupt the sentence.
            continue
        spans.append((start, end, order, edit[2]))

    # Apply from right to left so earlier offsets stay valid while the token
    # list grows or shrinks. ``order`` breaks ties so that several insertions
    # at the same position keep the order in which they were listed.
    for start, end, _, replacement in sorted(spans, reverse=True):
        # Slice assignment replaces the whole span (not just its first token)
        # and removes it entirely when the replacement is empty.
        tokens[start:end] = replacement.split()
    return ' '.join(tokens)

def sentence_reform(sent):
    sent = re.sub(r"n't","not",sent)
    sent = re.sub(r"'m","am",sent)
    return sent

def parse_m2_file(file_path):
    """Read an M2 annotation file into parallel lists.

    ``S `` lines start a sentence, ``A `` lines add an edit (split on
    ``|||``) to the current sentence, and a blank line ends the sentence.
    Each sentence is normalised with ``sentence_reform`` and corrected with
    ``correct_sent``; this now also holds for a last sentence that is not
    followed by a blank line.

    Args:
        file_path: Path of the UTF-8 encoded M2 file.

    Returns:
        A tuple ``(sentences, edits, edited_sentences)`` of equally long
        lists: the normalised source sentences, the raw edit field lists per
        sentence, and the corrected sentences.
    """
    sentences = []
    edits = []
    edited_sentences = []

    def _store(tokens, annotations):
        # Single place where a finished sentence is recorded.
        source = sentence_reform(' '.join(tokens))
        sentences.append(source)
        edits.append(annotations)
        edited_sentences.append(correct_sent(source, annotations))

    current_sentence = []
    current_edits = []

    # Iterate over the file lazily instead of loading every line up front.
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            if line.startswith('S '):
                if current_sentence:
                    _store(current_sentence, current_edits)
                current_sentence = line[2:].split()
                # Always start a sentence with a fresh edit list so that
                # annotations can never leak in from an earlier record.
                current_edits = []
            elif line.startswith('A '):
                current_edits.append(line[2:].strip().split('|||'))
            elif line.strip() == '':
                if current_sentence:
                    _store(current_sentence, current_edits)
                current_sentence = []
                current_edits = []

    # File ended without a trailing blank line.
    if current_sentence:
        _store(current_sentence, current_edits)
    return sentences, edits, edited_sentences

# Training corpus: parse the gold M2 annotations into (original, corrected) pairs.
file_path = '/content/ABC.train.gold.bea19.m2'
sentences, edits, edited_sentences = parse_m2_file(file_path)

sent_df = pd.DataFrame({'original': sentences, 'corrected': edited_sentences})
# Keep only the pairs whose corrected side has more than three whitespace-separated tokens.
sent_df = sent_df.loc[sent_df['corrected'].str.split().str.len() > 3]

# Held-out evaluation corpus: parsed the same way, but not filtered.
test_path = '/content/ABCN.dev.gold.bea19.m2'
test_sent, test_edits, test_edited = parse_m2_file(test_path)
true_test = pd.DataFrame({'original': test_sent, 'corrected': test_edited})

### Tokenization
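Each training example is serialized as \<bos> + input_sent + \<clip> + output_sent + \<eos>, where \<bos> is the `<startofstring>` token, \<eos> is the `<endofstring>` token, and `<clip>` separates the original sentence from its correction.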

In [None]:
# Training hyperparameters shared by the cells below.
BATCH_SIZE = 16  # batch size handed to both DataLoaders and logged to wandb
MAX_LEN = 128  # logged to wandb as "max_len"; SentData tokenizes to int(MAX_LEN/2) tokens
LR = 1e-3  # learning rate passed to the Adam optimizer
EPOCHS = 5  # number of passes over the training loader in train()

In [None]:
class SentData(Dataset):
    """Tokenized ``<startofstring> source <clip> target <endofstring>`` examples.

    All examples are tokenized once in the constructor, truncated to
    ``int(MAX_LEN/2)`` tokens and padded to a common length. Indexing yields
    an ``(input_ids, attention_mask)`` pair for one example.
    """

    def __init__(self, inp, oup,tokenizer):
        self.input = inp
        self.output = oup
        # Pair each source with the target stored at the same position.
        self.sent = [
            "<startofstring> " + source + " <clip> " + self.output[position] + " <endofstring>"
            for position, source in enumerate(self.input)
        ]
        self.sent_encoded = tokenizer(
            self.sent,
            max_length=int(MAX_LEN/2),
            truncation=True,
            padding=True,
            return_tensors="pt",
        )
        self.input_ids = self.sent_encoded['input_ids']
        self.attention_mask = self.sent_encoded['attention_mask']

    def __len__(self):
        # One item per source sentence.
        return len(self.input)

    def __getitem__(self, idx):
        ids = self.input_ids[idx]
        mask = self.attention_mask[idx]
        return (ids, mask)

In [None]:
# GPT-2 tokenizer extended with the markers used to serialize each example.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "<pad>",
               "bos_token": "<startofstring>",
               "eos_token": "<endofstring>"})
tokenizer.add_tokens(["<clip>"])
# The pad token is set explicitly above, so no eos-as-pad fallback is needed here.

# 80/20 split of the filtered training pairs; the fixed seed keeps the split stable.
train_df, test_df = train_test_split(sent_df, test_size=0.2, random_state=42)
trainData = SentData(train_df['original'].tolist(), train_df['corrected'].tolist(), tokenizer)
testData = SentData(test_df['original'].tolist(), test_df['corrected'].tolist(), tokenizer)

# Reshuffle the training examples every epoch so batches differ between epochs.
# The validation loader stays ordered so its first batch is always the same.
trainData = DataLoader(trainData, batch_size=BATCH_SIZE, shuffle=True)
testData = DataLoader(testData, batch_size=BATCH_SIZE)

### Training

In [None]:
# Start a Weights & Biases run named after the current timestamp.
current_datetime = str(datetime.now())
wandb_run_config = {
    "language_model": 'GPT2LMHeadModel_GEC',
    "learning_rate": LR,
    "batch_size": BATCH_SIZE,
    "max_len": MAX_LEN,
}
wandb.init(
    project='GPT2LMHeadModel_GEC',
    name=current_datetime,
    config=wandb_run_config,
)
#

def _mean_validation_loss(model, testData):
    """Return the mean loss of ``model`` over ``testData``, or ``None`` if it yields no batch."""
    losses = []
    for ids, mask in testData:
        ids = ids.to(device)
        mask = mask.to(device)
        losses.append(model(ids, attention_mask=mask, labels=ids).loss.item())
    return sum(losses) / len(losses) if losses else None


def _sample_generations(model, testData):
    """Generate a correction for every example of the first ``testData`` batch.

    Returns a list of ``[source_text, generated_text]`` rows (empty when the
    loader yields no batch).
    """
    first_batch = next(iter(testData), None)
    if first_batch is None:
        return []
    batch_ids, _ = first_batch
    bos_marker = "<startofstring>"
    rows = []
    for ids in batch_ids:
        decoded = tokenizer.decode(ids)
        # Everything before the separator is the prompt (marker + source sentence).
        prompt_text = decoded.split('<clip>')[0]
        prompt = tokenizer(prompt_text + '<clip>', return_tensors="pt").to(device)
        generated = model.generate(prompt["input_ids"], attention_mask=prompt["attention_mask"])
        generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)
        # Drop the leading start marker from the displayed source sentence.
        rows.append([prompt_text[len(bos_marker):], generated_text])
    return rows


def train(trainData,testData, model, optim):
    """Fine-tune ``model`` for ``EPOCHS`` epochs and log progress to wandb.

    Per epoch: one pass over ``trainData`` (training loss logged every 50
    batches), then the mean validation loss over ``testData`` and a table of
    generations for the first validation batch are logged, and the model's
    ``state_dict`` is written to ``model_state.pt``.

    Args:
        trainData: Iterable of ``(input_ids, attention_mask)`` training batches.
        testData: Iterable of ``(input_ids, attention_mask)`` validation batches.
        model: Model called as ``model(ids, attention_mask=..., labels=...)``
            that returns an object with a ``.loss`` and supports ``generate``.
        optim: Optimizer over the model parameters.

    Uses the module-level ``device``, ``tokenizer``, ``EPOCHS`` and the active
    wandb run.
    """
    global_step = 0      # number of training samples seen so far (wandb x-axis)
    batches_seen = 0
    for epoch in tqdm.tqdm(range(EPOCHS)):
        for ids, mask in trainData:
            ids = ids.to(device)
            mask = mask.to(device)
            optim.zero_grad()
            # NOTE(review): labels are the unmasked input ids, so padded
            # positions contribute to the loss as well.
            loss = model(ids, attention_mask=mask, labels=ids).loss
            loss.backward()
            optim.step()
            batches_seen += 1
            # Count the real batch size; the last batch may be smaller.
            global_step += ids.size(0)
            if batches_seen % 50 == 0:
                wandb.log({"train_loss": loss.item()}, step=global_step)
        wandb.log({"epoch": epoch}, step=global_step)

        model.eval()  # switch to evaluation mode
        try:
            with torch.no_grad():
                v_loss = _mean_validation_loss(model, testData)
                if v_loss is not None:
                    wandb.log({"valid_loss": v_loss}, step=global_step)
                data = _sample_generations(model, testData)
                if data:
                    table = wandb.Table(data=data, columns=['Original Sentence','Generate Sentence'])
                    wandb.log({f"epoch{epoch} Generatrion": table})
        finally:
            # Switch back to training mode even if evaluation failed.
            model.train()
        torch.save(model.state_dict(), "model_state.pt")

# Choose the accelerator: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# NOTE(review): max_length is hard-coded to 128 here while the config cell
# defines MAX_LEN; confirm whether the two are meant to stay in sync.
model = GPT2LMHeadModel.from_pretrained("gpt2",max_length=128)
# The tokenizer gained extra tokens, so the embedding matrix must match its size.
model.resize_token_embeddings(len(tokenizer))
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id

model = model.to(device)
model.train()

optim = Adam(model.parameters(), lr=LR)

print("training .... ")
train(trainData, testData, model, optim)
wandb.finish()

In [ ]:
def text_pruning(pred, original, slack=0.2):
    """Cut ``pred`` after the first punctuation mark near the length of ``original``.

    If ``pred`` already ends in punctuation (ignoring surrounding
    whitespace), or is empty or whitespace-only, it is returned unchanged.
    Otherwise ``pred`` is scanned from character position
    ``len(original) - 1 - int(len(original) * slack)`` onwards and truncated
    just after the first punctuation character found. When none is found
    ``pred`` is returned unchanged.

    Args:
        pred: Generated text to prune.
        original: Reference text whose length anchors the scan start.
        slack: Fraction of ``len(original)`` to step back before scanning.

    Returns:
        The pruned (or untouched) prediction.
    """
    stripped = pred.strip()
    # Nothing to prune: empty prediction or one that already ends cleanly.
    if not stripped or stripped[-1] in string.punctuation:
        return pred

    last = len(original) - 1
    # Never step back past the beginning of the string.
    start = last - min(int(len(original) * slack), last)
    for position in range(start, len(pred)):
        if pred[position] in string.punctuation:
            return pred[:position + 1]
    return pred

In [None]:
def infer(inp):
    """Generate a corrected version of ``inp`` with the fine-tuned model.

    The sentence is wrapped as ``<startofstring> inp <clip>``, tokenized
    (truncated to 64 tokens) and passed to ``model.generate``. The text
    following ``<clip>`` in the decoded output is pruned with
    ``text_pruning`` against the length of ``inp`` and returned.

    Fallbacks: ``inp`` is returned if generation raises; the full decoded
    output is returned if it contains no ``<clip>`` separator or the
    generated segment cannot be pruned.

    Uses the module-level ``tokenizer``, ``model`` and ``device``.
    """
    prompt = tokenizer("<startofstring> "+inp+" <clip>", max_length=64, truncation=True, padding=True, return_tensors="pt")
    ids = prompt["input_ids"].to(device)
    mask = prompt["attention_mask"].to(device)

    # Generate in evaluation mode (dropout off) and restore the previous mode after.
    was_training = model.training
    model.eval()
    try:
        generated = model.generate(ids, attention_mask=mask)
    except Exception:
        # Best effort: fall back to the input. KeyboardInterrupt is deliberately
        # not caught so a long prediction loop can still be stopped.
        return inp
    finally:
        if was_training:
            model.train()

    output = tokenizer.decode(generated[0], skip_special_tokens=True)
    segments = output.split('<clip>')
    if len(segments) < 2:
        return output
    try:
        # Prune the generated text (first argument) against the input (second).
        return text_pruning(segments[1], inp)
    except IndexError:
        return output

In [None]:
# Quick smoke test: run a few hand-written erroneous sentences through infer().
samples = [
    'I loves my mom',
    'We should protect the enviorment',
    'Last week I go to party',
]
for sample_sentence in samples:
    print(infer(sample_sentence))


In [None]:
# Persist the whole model object (not just its state_dict) and load it back.
torch.save(model, "cancat_gpt_v3.pth")
# The file holds a pickled nn.Module, so it needs weights_only=False; recent
# PyTorch releases default to weights_only=True, which rejects such files.
# Only load checkpoints you created yourself: unpickling can execute code.
model = torch.load("cancat_gpt_v3.pth", map_location=device, weights_only=False)
# The cells below only run inference, so switch off dropout.
model.eval()

In [None]:
# Predict Process
# Run every held-out source sentence through infer(), with a progress bar.
trueTest = true_test['original'].tolist()
predict = [infer(source_sentence) for source_sentence in tqdm.tqdm(trueTest)]

In [None]:
# Collect predictions next to their source and gold sentences and save them.
test_output = pd.DataFrame({'predict': predict}).assign(
    test=true_test["original"],
    label=true_test["corrected"],
)
test_output.to_csv('test_output_lmhead3.csv')

### Evaluation

In [None]:
# GLEU score
# takes lists of decoded prediction and reference labels and compute gleu
def compute_metrics(pred,label):
    """Return the mean sentence-level GLEU of ``pred`` against ``label``.

    Both arguments are iterables of decoded sentences. Each item is
    converted with ``str`` and stripped; hypothesis and reference are split
    on whitespace and scored with ``sentence_gleu`` using a single reference.

    Returns:
        ``{"gleu": score}`` where ``score`` is the mean scaled to [0, 100]
        and rounded to four decimals.
    """
    hypotheses = [str(item).strip() for item in pred]
    references = [str(item).strip() for item in label]

    gleu_scores = []
    for hypothesis, reference in zip(hypotheses, references):
        # sentence_gleu expects a list of tokenized references and one tokenized hypothesis.
        gleu_scores.append(sentence_gleu([reference.split()], hypothesis.split()))

    # Multiply by 100 so the score lies in [0, 100] rather than [0, 1].
    scaled_mean = np.mean(gleu_scores) * 100
    return {"gleu": round(scaled_mean, 4)}

In [None]:
### Evaluate performance using Errant Scorer
# takes three separate csv files or three pandas dataframes/series
def get_errant_result(inputs, labels, preds, is_csv_filename = False, save_score = True, model_and_data_name = ''):
    """Score predictions with the ERRANT command-line tools.

    Runs ``errant_parallel`` twice (source vs. gold, source vs. prediction)
    and then ``errant_compare -cat 3`` on the two resulting M2 files.

    Args:
        inputs: Source sentences (pandas Series/DataFrame or any iterable of
            strings), or a file name when ``is_csv_filename`` is true.
        labels: Gold corrections, same conventions as ``inputs``.
        preds: Model predictions, same conventions as ``inputs``.
        is_csv_filename: When true, the three arguments are paths of files
            that already exist and are passed to ERRANT as they are.
        save_score: When true, the report is also written to
            ``errant_eval_<model_and_data_name><HH-MM-SS>.txt``.
        model_and_data_name: Free-form tag inserted into that file name.

    Returns:
        The text report printed by ``errant_compare``.

    Raises:
        subprocess.CalledProcessError: If an ERRANT command exits with a
            non-zero status.
    """
    def _write_lines(sentences, path):
        # Write one sentence per line with no index column and no CSV
        # quoting, so the three files stay line-aligned plain text.
        if hasattr(sentences, 'columns'):
            # DataFrame: use its first column (iterating would give column names).
            sentences = sentences.iloc[:, 0]
        with open(path, 'w', encoding='utf-8') as handle:
            for sentence in sentences:
                text = '' if pd.isna(sentence) else str(sentence)
                # Collapse internal whitespace so a stray newline cannot
                # shift the alignment between the files.
                handle.write(' '.join(text.split()) + '\n')

    def _run(command):
        # Argument lists avoid shell quoting problems with unusual file names.
        return subprocess.run(command, capture_output=True, text=True, check=True)

    if not is_csv_filename:
        inputs_csv = 'errant_inputs.csv'
        labels_csv = 'errant_labels.csv'
        preds_csv = 'errant_preds.csv'
        _write_lines(inputs, inputs_csv)
        _write_lines(labels, labels_csv)
        _write_lines(preds, preds_csv)
    else:
        inputs_csv = inputs
        labels_csv = labels
        preds_csv = preds

    # Reference edits: source -> gold correction.
    _run(['errant_parallel', '-orig', inputs_csv, '-cor', labels_csv, '-out', 'true_labs_val.m2'])
    # Hypothesis edits: source -> model prediction.
    _run(['errant_parallel', '-orig', inputs_csv, '-cor', preds_csv, '-out', 'pred_labs_val.m2'])

    result = _run(['errant_compare', '-hyp', 'pred_labs_val.m2', '-ref', 'true_labs_val.m2', '-cat', '3'])
    print(result.stdout)

    if save_score:
        # Tag the report with the current wall-clock time.
        current_time_str = datetime.now().strftime("%H-%M-%S")
        save_name = 'errant_eval_' + model_and_data_name + current_time_str + '.txt'
        with open(save_name, 'w', encoding='utf-8') as f:
            f.write(result.stdout)

    return result.stdout

In [None]:
# Reload the saved predictions and score them with GLEU and ERRANT.
valid_df = pd.read_csv('./test_output_lmhead3.csv')
ori, label, pred = (valid_df[column].tolist() for column in ('test', 'label', 'predict'))

gleu_score = compute_metrics(pred,label)
print('gleu_score: ', gleu_score)

res = get_errant_result(
    valid_df['test'],
    valid_df['label'],
    valid_df['predict'],
    False,
    True,
    'gpt2',
)