In [None]:
# Clone the upstream punctuation-restoration repository; the following cells run from inside this checkout.
!git clone https://github.com/xashru/punctuation-restoration.git

In [None]:
# Switch the working directory to the checkout so relative paths such as requirements.txt resolve.
%cd punctuation-restoration/

In [None]:
!pip download -r requirements.txt

In [None]:

!pip install gdown 

In [None]:
!rm -rf /content/punctuation-restoration/punctuation_adder
!gdown --id --folder 1PsmJvIa0MwUdryAj2pholK3_kh98nWuc

In [None]:
# Import only the two classes this notebook uses instead of the whole transformers
# namespace, so it is clear where each name comes from.
from transformers import XLMRobertaModel, XLMRobertaTokenizer

# Vocabulary ids of the special tokens for each tokenizer family, written per
# family as (START_SEQ, PAD, END_SEQ, UNK).
TOKEN_IDX = {
    style: dict(zip(('START_SEQ', 'PAD', 'END_SEQ', 'UNK'), ids))
    for style, ids in (
        ('bert', (101, 0, 102, 100)),
        ('xlm', (0, 2, 1, 3)),
        ('roberta', (0, 1, 2, 3)),
        ('albert', (2, 0, 3, 1)),
    )
}

# Punctuation label -> class index; 'O' means "no punctuation after this word".
punctuation_dict = {
    label: index
    for index, label in enumerate(('O', 'COMMA', 'PERIOD', 'QUESTION'))
}


# Registry keyed by pretrained checkpoint name. Each value is
# (model class, tokenizer class, encoder output dimension, token style),
# where the token style is the key used to look up special tokens in TOKEN_IDX.
MODELS = dict([
    ('xlm-roberta-large', (XLMRobertaModel, XLMRobertaTokenizer, 1024, 'roberta')),
])

In [None]:
import torch.nn as nn
import torch
from torchcrf import CRF


class DeepPunctuation(nn.Module):
    """Per-token punctuation classifier: pretrained encoder, then a BiLSTM, then a linear head.

    Args:
        pretrained_model: Key into ``MODELS``; selects the encoder class, the
            checkpoint name handed to ``from_pretrained`` and the encoder output width.
        freeze_bert: If True, ``requires_grad`` is switched off for every encoder parameter.
        lstm_dim: Hidden size of each LSTM direction; ``-1`` reuses the encoder output width.
    """

    def __init__(self, pretrained_model, freeze_bert=False, lstm_dim=-1):
        super(DeepPunctuation, self).__init__()
        self.output_dim = len(punctuation_dict)
        self.bert_layer = MODELS[pretrained_model][0].from_pretrained(pretrained_model)
        # Freeze encoder layers
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False
        bert_dim = MODELS[pretrained_model][2]
        hidden_size = bert_dim if lstm_dim == -1 else lstm_dim
        self.lstm = nn.LSTM(input_size=bert_dim, hidden_size=hidden_size, num_layers=1, bidirectional=True)
        # The bidirectional LSTM concatenates both directions, hence hidden_size * 2 inputs.
        self.linear = nn.Linear(in_features=hidden_size * 2, out_features=self.output_dim)

    def forward(self, x, attn_masks):
        """Return per-token class scores.

        Args:
            x: Token ids, either ``(B, N)`` or a single unbatched sequence ``(N,)``.
            attn_masks: Attention mask matching ``x``; a 1-D mask is given a batch
                axis just like ``x`` so both stay aligned.

        Returns:
            Tensor whose last dimension has ``len(punctuation_dict)`` entries;
            assumed ``(B, N, len(punctuation_dict))`` given a ``(B, N, E)`` encoder
            output -- TODO confirm against the encoder in use.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)  # add dummy batch for single sample
        if attn_masks is not None and attn_masks.dim() == 1:
            # Previously only the ids were batched, leaving a 1-D mask that no
            # longer matched the 2-D input.
            attn_masks = attn_masks.unsqueeze(0)
        # First element of the encoder output is taken as the per-token hidden states.
        x = self.bert_layer(x, attention_mask=attn_masks)[0]
        # The LSTM is built without batch_first, so swap to (N, B, E) and back.
        x = torch.transpose(x, 0, 1)
        x, _ = self.lstm(x)
        x = torch.transpose(x, 0, 1)
        return self.linear(x)

class DeepPunctuationCRF(nn.Module):
    """``DeepPunctuation`` emissions followed by a CRF layer over the punctuation labels.

    Args:
        pretrained_model: Forwarded to ``DeepPunctuation``.
        freeze_bert: Forwarded to ``DeepPunctuation``.
        lstm_dim: Forwarded to ``DeepPunctuation``.
    """

    def __init__(self, pretrained_model, freeze_bert=False, lstm_dim=-1):
        super(DeepPunctuationCRF, self).__init__()
        self.bert_lstm = DeepPunctuation(pretrained_model, freeze_bert, lstm_dim)
        self.crf = CRF(len(punctuation_dict), batch_first=True)

    def log_likelihood(self, x, attn_masks, y):
        """Return the negated CRF score of tags ``y`` (``reduction='token_mean'``).

        Unbatched 1-D ``x``, ``attn_masks`` and ``y`` are each given a batch axis.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)
        if attn_masks.dim() == 1:
            attn_masks = attn_masks.unsqueeze(0)
        if y.dim() == 1:
            y = y.unsqueeze(0)
        x = self.bert_lstm(x, attn_masks)
        attn_masks = attn_masks.byte()
        return -self.crf(x, y, mask=attn_masks, reduction='token_mean')

    def forward(self, x, attn_masks, y):
        """Decode the best tag sequence for every sample.

        Args:
            x: Token ids, ``(B, N)`` or unbatched ``(N,)``.
            attn_masks: Attention mask matching ``x``; also used as the CRF mask.
            y: Only its device is used, to place the returned tensor.

        Returns:
            Long tensor with the first two dimensions of the emissions; positions
            beyond a sample's decoded length stay 0.
        """
        if x.dim() == 1:
            x = x.unsqueeze(0)  # add dummy batch for single sample
        if attn_masks.dim() == 1:
            # Keep the mask aligned with the batched ids.
            attn_masks = attn_masks.unsqueeze(0)
        emissions = self.bert_lstm(x, attn_masks)
        mask = attn_masks.byte()
        dec_out = self.crf.decode(emissions, mask=mask)
        # Size the buffer from the emissions rather than from `y`, so a placeholder
        # `y` of a different shape cannot break the 2-D indexing below.
        y_pred = torch.zeros(emissions.shape[:2], dtype=torch.long, device=y.device)
        for i, tags in enumerate(dec_out):
            y_pred[i, :len(tags)] = torch.tensor(tags, dtype=torch.long, device=y.device)
        return y_pred

In [None]:
!pip install sentencepiece

In [None]:
import re
import torch

import argparse

# --- Inference settings ---
lstm_dim = -1                 # -1 makes the LSTM reuse the encoder output width
use_crf = False               # True selects DeepPunctuationCRF instead of DeepPunctuation
language = 'bn'               # read by inference() to choose the full-stop character
in_file = 'data/test_bn.txt'  # not read by any cell shown in this notebook
weight_path = '/content/punctuation-restoration/xlm-roberta-large-bn.pt'
sequence_length = 256         # number of token ids per window fed to the model
out_file = 'data/test_en_out.txt'  # not written by any cell shown in this notebook
pretrained_model = 'xlm-roberta-large'

# tokenizer
tokenizer = MODELS[pretrained_model][1].from_pretrained(pretrained_model)
token_style = MODELS[pretrained_model][3]

# logs
model_save_path = weight_path

# Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if use_crf:
    deep_punctuation = DeepPunctuationCRF(pretrained_model, freeze_bert=False, lstm_dim=lstm_dim)
else:
    deep_punctuation = DeepPunctuation(pretrained_model, freeze_bert=False, lstm_dim=lstm_dim)
deep_punctuation.to(device)

# map_location lets the checkpoint load onto whichever device was selected above;
# without it a checkpoint saved on CUDA cannot be loaded on a CPU-only runtime.
deep_punctuation.load_state_dict(torch.load(model_save_path, map_location=device))
deep_punctuation.eval()

In [None]:
def inference(text):
    """Return ``text`` with predicted punctuation appended to each word.

    The text is split on whitespace, packed into windows of ``sequence_length``
    token ids (start token, word pieces, end token, padding) and run through
    ``deep_punctuation``. The prediction at the last piece of each word decides
    which punctuation mark follows that word.

    Reads the notebook globals ``tokenizer``, ``token_style``, ``TOKEN_IDX``,
    ``sequence_length``, ``language``, ``use_crf``, ``device`` and
    ``deep_punctuation``.

    Args:
        text: Input string without punctuation.

    Returns:
        The words joined as ``word + mark + ' '`` (so a non-empty result ends in
        a space); an empty string when ``text`` has no words.

    Raises:
        ValueError: If ``sequence_length`` is too small to hold the start token,
            one word piece and the end token.
    """
    max_len = sequence_length
    if max_len < 3:
        raise ValueError('sequence_length must be at least 3')

    words = text.split()
    special = TOKEN_IDX[token_style]

    punctuation_map = {0: '', 1: ',', 2: '.', 3: '?'}
    if language != 'en':
        punctuation_map[2] = '।'

    pieces = []
    word_pos = 0
    decode_idx = 0

    while word_pos < len(words):
        x = [special['START_SEQ']]
        y_mask = [0]

        while len(x) < max_len and word_pos < len(words):
            token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(words[word_pos]))
            if not token_ids:
                # A word that tokenizes to nothing still needs one position so the
                # word/prediction alignment is kept.
                token_ids = [special['UNK']]
            if len(x) == 1:
                # First word of a window: cut it to fit. Otherwise a word longer
                # than the window would never be consumed and the outer loop
                # would spin forever.
                token_ids = token_ids[:max_len - 2]
            if len(token_ids) + len(x) >= max_len:
                break
            x.extend(token_ids)
            # Only the last piece of a word carries that word's prediction.
            y_mask.extend([0] * (len(token_ids) - 1) + [1])
            word_pos += 1

        x.append(special['END_SEQ'])
        y_mask.append(0)
        real_len = len(x)
        pad_count = max_len - real_len
        x = x + [special['PAD']] * pad_count
        y_mask = y_mask + [0] * pad_count
        attn_mask = [1] * real_len + [0] * pad_count

        x = torch.tensor(x, device=device).reshape(1, -1)
        attn_mask = torch.tensor(attn_mask, device=device).reshape(1, -1)

        with torch.no_grad():
            if use_crf:
                # Placeholder tags with the same (batch, length) shape as the
                # input, on the same device.
                y = torch.zeros(x.shape, dtype=torch.long, device=device)
                y_predict = deep_punctuation(x, attn_mask, y).view(-1)
            else:
                logits = deep_punctuation(x, attn_mask)
                logits = logits.view(-1, logits.shape[2])
                y_predict = torch.argmax(logits, dim=1).view(-1)

        # One host transfer per window instead of one .item() call per token.
        labels = y_predict.tolist()
        for is_word_end, label in zip(y_mask, labels):
            if is_word_end == 1:
                pieces.append(words[decode_idx] + punctuation_map[label] + ' ')
                decode_idx += 1

    return ''.join(pieces)

In [None]:
!!pip install pandarallel

In [None]:
import pandas as pd
from tqdm import tqdm
from IPython import display as ipd

from pandarallel import pandarallel
# Configure pandarallel with 8 workers and progress bars, then enable tqdm's pandas integration.
pandarallel.initialize(
    nb_workers=8,
    progress_bar=True,
)
tqdm.pandas()



In [None]:
# Load the input table; later cells use its `sentence` and `spelling` columns.
df = pd.read_csv(
    filepath_or_buffer='/content/punctuation-restoration/2.csv',
)

In [None]:
def space_remover(x):
    """Collapse every run of whitespace in ``x`` to a single space and trim both ends."""
    tokens = x.split()
    return " ".join(tokens)

In [None]:
# Discard the incoming `sentence` column; a new one is produced from the model output later.
df = df.drop(columns=['sentence'])

In [None]:
# Normalise whitespace in every `spelling` entry.
df['spelling'] = df['spelling'].apply(space_remover)

In [None]:
# Display the frame after whitespace normalisation of `spelling`.
df

In [None]:
# `spelling` holds the text that will be punctuated; call it `predicted` from here on.
df = df.rename({'spelling': 'predicted'}, axis='columns')

In [None]:
# Remove any danda already present so the model sees unpunctuated text.
df['predicted'] = df['predicted'].str.replace('।', '', regex=False)

In [None]:
# Display the frame after the rename to `predicted` and danda removal.
df

In [None]:
# Punctuate every row; inference() is called once per entry of `predicted`.
df['punctuation'] = df['predicted'].apply(inference)

In [None]:
# Copy of the frame without the unpunctuated input column, for the raw export below.
df1 = df.drop(columns=['predicted'])

In [None]:
# Save the raw model output (row index included) before any clean-up is applied.
df1.to_csv(path_or_buf='final_with_punctuation_needs_cleaning.csv')

In [None]:
# Spot-check the punctuated text of the row labelled 7612.
df.loc[7612, 'punctuation']

In [None]:
# Show the rows whose punctuated text contains a question mark.
df.loc[df['punctuation'].str.find('?').ne(-1)]

In [None]:
# space_remover collapses internal whitespace and trims both ends, so one pass is
# enough. The separate .strip() assignment that used to precede it was overwritten
# before its result was ever read.
df['punctuation_truncated'] = df['punctuation'].apply(space_remover)

In [None]:
# Remove sentence terminators as literal characters. regex=False is spelled out
# because '.' interpreted as a regular expression matches every character.
df['punctuation_truncated'] = (
    df['punctuation_truncated']
    .str.replace('।', '', regex=False)
    .str.replace('.', '', regex=False)
)

In [None]:
# Drop one trailing comma from each sentence. The vectorised form assigns through
# the frame itself (no chained indexing), works for any index, and leaves empty
# strings alone instead of failing on `[-1]`.
df['punctuation_truncated'] = df['punctuation_truncated'].str.replace(r',$', '', regex=True)

In [None]:
# Close every non-empty sentence that does not already end in ',' or '?' with a
# danda. Empty strings are left untouched (indexing `[-1]` on them used to raise),
# and the write goes through .loc rather than chained indexing.
needs_danda = (
    df['punctuation_truncated'].ne('')
    & ~df['punctuation_truncated'].str[-1:].isin([',', '?'])
)
df.loc[needs_danda, 'punctuation_truncated'] = df.loc[needs_danda, 'punctuation_truncated'] + '।'

In [None]:
# Keep only the cleaned text; the raw input and raw model output are no longer needed.
df = df.drop(columns=['predicted', 'punctuation'])

In [None]:
# The cleaned, punctuated text becomes the final `sentence` column.
df = df.rename({'punctuation_truncated': 'sentence'}, axis='columns')

In [None]:
# Sanity check: print every sentence that still ends in ',' or '?'.
# Iterating over the values avoids relying on a 0..n-1 index, and str.endswith
# copes with empty strings, on which `[-1]` would raise.
for sentence in df['sentence']:
    if sentence.endswith((',', '?')):
        print(sentence)

In [None]:
# Write the final table without the row index.
df.to_csv(path_or_buf='final.csv', index=False)