In [52]:
import pandas as pd
import regex as re
import argparse, os, csv       

<xml.etree.ElementTree.ElementTree at 0x7f074228a160>

## dff

In [4]:
import regex as re

parentheses=r'\([^)(]+[^)( ] *\)'
parenthesestokeep=r'\([^)(]+[^)(.!?—\-, ] *\)'
speakertag=r'((?<=[^\w\d \",])|^) *(?![?\.,!:\-\—\[\]\(\)])(?:[A-Z\d][^\s.?!\[\]\(\)]*\s?)*:(?=[^\w]*[A-Z])'#lookahead keeps semicolon in false cases.
parenthesestoremove=r'\(([^\(\)]+[\w ]+)\):?'
parenthesesaroundsentence=r'\(([^\w]*[^\(\)]+[_^\W]+)\):?'
squarebracketsaroundsentence=r'\[([^\[\]]+)\]' #generic since it seems like the square brackets just denote unclear speech.


def to_emdash(s):
    return re.sub('--','—',s)

def strip_accents(s):
    return ''.join(c for c in unicodedata.normalize('NFKD', s)
                  if unicodedata.category(c) != 'Mn')

def removespeakertags(text):
    return re.sub(speakertag,' ',text)

def removenametags(text):
    # return re.sub(r"(?<=[a-z][.?!;]) *[ A-z.,\-']{1,25}:",' ',text)
    return re.sub(r"(?<=[a-z][.?!;])([\(\[]* *)[ A-Za-z.,\-']{1,25}:", "\g<1>",text)

def removeparentheses(text):
    return re.sub(parenthesestoremove, ' ',text)

def removeparenthesesaroundsentence(text):
    return re.sub(parenthesesaroundsentence,r'\g<1>',text)

def removedashafterpunct(text):
    return re.sub(r"([^A-Za-zÀ-ÖØ-öø-ÿ0-9 ]+ *)-+( *[^- ])",r"\g<1> \g<2>",text)

def removesquarebrackets(text):
    return re.sub(squarebracketsaroundsentence, r'\g<1>',text)

def removemusic(text):
    text = re.sub(r'♫( *[^♫ ])+ *♫', ' ',text)
    return re.sub(r'♪( *[^♫ ])+ *♪', ' ',text)

def reducewhitespaces(text):
    text=re.sub(r'(?<=[.?!,;:\—\-]) *(?=[.?!,;:\—\-])','',text)
    return re.sub(r'\s+', ' ',text)

def removeemptyquotes(text):
    text= re.sub(r"'[_^\W]*'",' ',text)
    text= re.sub(r"\([_^\W]*\)",' ',text)
    text= re.sub(r"\[[_^\W]*\]",' ',text)
    return re.sub(r'"[_^\W]*"',' ',text)

def ellipsistounicode(text):
    text = re.sub(r'\.{3,}(?= )','…',text) #ellipsis without trailing punctuation
    return re.sub(r'\.{3,}([^\w\s])','…\g<1>',text) #ellipsis with trailing punctuation

def removenonsentencepunct(text):
    return re.sub(r'[^A-Za-z\d\s$%&+=€²£¢¥…,.!?;:\-\–\—\']',' ',text)

def combinerepeatedpunct(text):
    newtext=[text,re.sub(r'([_^\W]+) *\1+','\g<1> ',text)]
    i=1
    while (newtext[0]!=newtext[1]):
        i+=1
        newtext[i%2]=re.sub(r'([_^\W]+) *\1+','\g<1> ',newtext[(1+i)%2])
    return newtext[i%2]

def endashtohyphen(text):
    return re.sub('–','-',text)

def removedashafterpunct(text):
    return re.sub(r"([^A-Za-z0-9 ]+ *)-+( *[^- ])",r"\g<1> \g<2>",text)

def pronouncesymbol(text):
    text=re.sub("\$ *([\d](\.[\d])?+)", "\g<1> dollars ",text)
    text=re.sub('\£ *([\d](\.[\d])?+)', " pounds ",text)
    text=re.sub("\$", " dollars ",text)
    text=re.sub("\£", " pounds ",text)
    text=re.sub('€', " euro ",text)
    text=re.sub('¥', " yen ",text)
    text=re.sub("¢"," cents ",text)
    text=re.sub('(?<=\d)\.(?=\d)',' point ',text)
    text=re.sub('\+',' plus ',text)
    text=re.sub('%',' percent ',text)
    text=re.sub('²',' squared ',text)
    text=re.sub('&', ' and ',text)
    return text

def stripleadingpunctuation(text):
    return re.sub(r'^[^A-Z]*','',text)

def striptrailingtext(text):
    return re.sub(r'[^!.?…;]*$','',text)

def preprocess(tedtalks):
    print('stripping accents')
    tedtalks=tedtalks.apply(strip_accents)
    print('removing speaker tags')
    tedtalks=tedtalks.apply(removespeakertags)
    print('removing name tags')
    tedtalks=tedtalks.apply(removenametags) # Remove *Mr Brown: *Hi!
    print('removing non-sentence parenthesis')
    tedtalks=tedtalks.apply(removeparentheses) # Remove (Whispers) without punct
    print('removing parenthesis')
    tedtalks=tedtalks.apply(removeparenthesesaroundsentence) #Remove -> (<- Hi Everyone! ->)<-
    print('removing square brackets')
    tedtalks=tedtalks.apply(removesquarebrackets) #Remove entire [unclear text]
    print('removing music lyrics')
    tedtalks=tedtalks.apply(removemusic)
    print('removing empty tags')
    tedtalks=tedtalks.apply(removeemptyquotes)
    print('removing non-sentence punctuation')
    tedtalks=tedtalks.apply(removenonsentencepunct)
    print('change to unicode ellipsis')
    tedtalks=tedtalks.apply(ellipsistounicode)
    print('2 hyphen to emdash')
    tedtalks=tedtalks.apply(to_emdash)
    print('endash to hyphen')
    tedtalks=tedtalks.apply(endashtohyphen)
    print('remove hyphen after punct')
    tedtalks=tedtalks.apply(removedashafterpunct)
    print('combine repeated punctuation')
    tedtalks=tedtalks.apply(combinerepeatedpunct)
    print('pronounce symbol')
    tedtalks=tedtalks.apply(pronouncesymbol)
    print('strip leading')
    tedtalks=tedtalks.apply(stripleadingpunctuation)
    print('strip trailing')
    tedtalks=tedtalks.apply(striptrailingtext)
    print('reduce whitespaces')
    tedtalks=tedtalks.apply(reducewhitespaces)
    print('--done--')
    return tedtalks

def text2csv(source:str,target:str):
    rows=dict()
    talkid=-1
    with open(source,'r') as f:
        for line in f:
            if line[:8]=='<talkid>':
                talkid=int(re.search("(?<=<talkid>)[0-9]+",line)[0])
                print(talkid)
                continue
            if line[0]!='<':
                line=re.sub('\n',' ',line)
                if not talkid in rows.keys():
                    rows[talkid]=''
                rows[talkid]+=line

    tedtalks=pd.DataFrame.from_dict({'id':rows.keys(),'transcript':rows.values()})

    tedtalks.loc[:,'transcript'] = preprocess(tedtalks.transcript.astype(str))
    tedtalks=tedtalks.loc[tedtalks.transcript.map(lambda x:len(x.split())>=1)]
    tedtalks.to_csv(target,index=None)
    
def xml2csv(source:str,target:str):
    tree=ET.parse(source)
    tree.getroot()[0]
    rows={}
    for child in tree.getroot()[0]:
        talkid=int(child[3].text)
        if not talkid in rows.keys():
            rows[talkid]=''
        for i in child.findall('seg'):
            rows[talkid]+=re.sub('\n',' ',i.text)
    tedtalks=pd.DataFrame.from_dict({'id':rows.keys(),'transcript':rows.values()})

    tedtalks.loc[:,'transcript'] = preprocess(tedtalks.transcript.astype(str))
    tedtalks=tedtalks.loc[tedtalks.transcript.map(lambda x:len(x.split())>=1)]
    tedtalks.to_csv(target,index=None)        

f=open("/home/nxingyu/data/LDC99T42/treebank_3/dysfl/dff/swbd/2/sw2005.dff")
started=False
for line in f:
    if started==False and re.match('^=+',line):
        started=True
        continue
    if started:
        print(line)



A.1:  Okay. /  {F Uh, } first, {F um, } I need to know, {F uh, } 

how do you feel [ about, + {F uh, } about ]

sending, {F uh, } an elderly, {F uh, } family member to a nursing home? /



B.2:  {D Well, } of course, [ it's, + {D you know, } it's ]

 one of the last few things in the

world you'd ever want to do, {D you know. }  

Unless it's just, {D you know, } really,

{D you know, }  and, {F uh, } [ for their, + {F uh, } {D you know, } 

for their ] own good. /



A.3:  Yes. /  Yeah. /



B.4:  I'd be very very careful [ and, + ] {F uh, } {D you know, } 

checking them out. / {F Uh, } our, -/ 

had t-, place my mother in a nursing home. /

  She had a rather massive stroke

[ about, + {F uh, } about ] --



A.5:  Uh-huh. /



B.6:  -- {F uh, } eight months ago I guess. / {C And, }  {F uh, } 

[ we were, + I was ] fortunate in that /

I was personally acquainted with the, {F uh, } 

people who, {F uh, } ran the nursing home

in our little hometown. /



A.7:  Yeah. /



B.8:  {C S

## rest

In [110]:
import unicodedata
import xml.etree.ElementTree as ET

parentheses=r'\([^)(]+[^)( ] *\)'
parenthesestokeep=r'\([^)(]+[^)(.!?—\-, ] *\)'
speakertag=r'((?<=[^\w\d \",])|^) *(?![?\.,!:\-\—\[\]\(\)])(?:[A-Z\d][^\s.?!\[\]\(\)]*\s?)*:(?=[^\w]*[A-Z])'#lookahead keeps semicolon in false cases.
parenthesestoremove=r'\(([^\(\)]+[\w ]+)\):?'
parenthesesaroundsentence=r'\(([^\w]*[^\(\)]+[_^\W]+)\):?'
squarebracketsaroundsentence=r'\[([^\[\]]+)\]' #generic since it seems like the square brackets just denote unclear speech.


def to_emdash(s):
    return re.sub('--','—',s)

def strip_accents(s):
    return ''.join(c for c in unicodedata.normalize('NFKD', s)
                  if unicodedata.category(c) != 'Mn')

def removespeakertags(text):
    return re.sub(speakertag,' ',text)

def removenametags(text):
    # return re.sub(r"(?<=[a-z][.?!;]) *[ A-z.,\-']{1,25}:",' ',text)
    return re.sub(r"(?<=[a-z][.?!;])([\(\[]* *)[ A-Za-z.,\-']{1,25}:", "\g<1>",text)

def removeparentheses(text):
    return re.sub(parenthesestoremove, ' ',text)

def removeparenthesesaroundsentence(text):
    return re.sub(parenthesesaroundsentence,r'\g<1>',text)

def removedashafterpunct(text):
    return re.sub(r"([^A-Za-zÀ-ÖØ-öø-ÿ0-9 ]+ *)-+( *[^- ])",r"\g<1> \g<2>",text)

def removesquarebrackets(text):
    return re.sub(squarebracketsaroundsentence, r'\g<1>',text)

def removemusic(text):
    text = re.sub(r'♫( *[^♫ ])+ *♫', ' ',text)
    return re.sub(r'♪( *[^♫ ])+ *♪', ' ',text)

def reducewhitespaces(text):
    text=re.sub(r'(?<=[.?!,;:\—\-]) *(?=[.?!,;:\—\-])','',text)
    return re.sub(r'\s+', ' ',text)

def removeemptyquotes(text):
    text= re.sub(r"'[_^\W]*'",' ',text)
    text= re.sub(r"\([_^\W]*\)",' ',text)
    text= re.sub(r"\[[_^\W]*\]",' ',text)
    return re.sub(r'"[_^\W]*"',' ',text)

def ellipsistounicode(text):
    text = re.sub(r'\.{3,}(?= )','…',text) #ellipsis without trailing punctuation
    return re.sub(r'\.{3,}([^\w\s])','…\g<1>',text) #ellipsis with trailing punctuation

def removenonsentencepunct(text):
    return re.sub(r'[^A-Za-z\d\s$%&+=€²£¢¥…,.!?;:\-\–\—\']',' ',text)

def combinerepeatedpunct(text):
    newtext=[text,re.sub(r'([_^\W]+) *\1+','\g<1> ',text)]
    i=1
    while (newtext[0]!=newtext[1]):
        i+=1
        newtext[i%2]=re.sub(r'([_^\W]+) *\1+','\g<1> ',newtext[(1+i)%2])
    return newtext[i%2]

def endashtohyphen(text):
    return re.sub('–','-',text)

def removedashafterpunct(text):
    return re.sub(r"([^A-Za-z0-9 ]+ *)-+( *[^- ])",r"\g<1> \g<2>",text)

def pronouncesymbol(text):
    text=re.sub("\$ *([\d](\.[\d])?+)", "\g<1> dollars ",text)
    text=re.sub('\£ *([\d](\.[\d])?+)', " pounds ",text)
    text=re.sub("\$", " dollars ",text)
    text=re.sub("\£", " pounds ",text)
    text=re.sub('€', " euro ",text)
    text=re.sub('¥', " yen ",text)
    text=re.sub("¢"," cents ",text)
    text=re.sub('(?<=\d)\.(?=\d)',' point ',text)
    text=re.sub('\+',' plus ',text)
    text=re.sub('%',' percent ',text)
    text=re.sub('²',' squared ',text)
    text=re.sub('&', ' and ',text)
    return text

def stripleadingpunctuation(text):
    return re.sub(r'^[^A-Z]*','',text)

def striptrailingtext(text):
    return re.sub(r'[^!.?…;]*$','',text)

def preprocess(tedtalks):
    print('stripping accents')
    tedtalks=tedtalks.apply(strip_accents)
    print('removing speaker tags')
    tedtalks=tedtalks.apply(removespeakertags)
    print('removing name tags')
    tedtalks=tedtalks.apply(removenametags) # Remove *Mr Brown: *Hi!
    print('removing non-sentence parenthesis')
    tedtalks=tedtalks.apply(removeparentheses) # Remove (Whispers) without punct
    print('removing parenthesis')
    tedtalks=tedtalks.apply(removeparenthesesaroundsentence) #Remove -> (<- Hi Everyone! ->)<-
    print('removing square brackets')
    tedtalks=tedtalks.apply(removesquarebrackets) #Remove entire [unclear text]
    print('removing music lyrics')
    tedtalks=tedtalks.apply(removemusic)
    print('removing empty tags')
    tedtalks=tedtalks.apply(removeemptyquotes)
    print('removing non-sentence punctuation')
    tedtalks=tedtalks.apply(removenonsentencepunct)
    print('change to unicode ellipsis')
    tedtalks=tedtalks.apply(ellipsistounicode)
    print('2 hyphen to emdash')
    tedtalks=tedtalks.apply(to_emdash)
    print('endash to hyphen')
    tedtalks=tedtalks.apply(endashtohyphen)
    print('remove hyphen after punct')
    tedtalks=tedtalks.apply(removedashafterpunct)
    print('combine repeated punctuation')
    tedtalks=tedtalks.apply(combinerepeatedpunct)
    print('pronounce symbol')
    tedtalks=tedtalks.apply(pronouncesymbol)
    print('strip leading')
    tedtalks=tedtalks.apply(stripleadingpunctuation)
    print('strip trailing')
    tedtalks=tedtalks.apply(striptrailingtext)
    print('reduce whitespaces')
    tedtalks=tedtalks.apply(reducewhitespaces)
    print('--done--')
    return tedtalks

def text2csv(source:str,target:str):
    rows=dict()
    talkid=-1
    with open(source,'r') as f:
        for line in f:
            if line[:8]=='<talkid>':
                talkid=int(re.search("(?<=<talkid>)[0-9]+",line)[0])
                print(talkid)
                continue
            if line[0]!='<':
                line=re.sub('\n',' ',line)
                if not talkid in rows.keys():
                    rows[talkid]=''
                rows[talkid]+=line

    tedtalks=pd.DataFrame.from_dict({'id':rows.keys(),'transcript':rows.values()})

    tedtalks.loc[:,'transcript'] = preprocess(tedtalks.transcript.astype(str))
    tedtalks=tedtalks.loc[tedtalks.transcript.map(lambda x:len(x.split())>=1)]
    tedtalks.to_csv(target,index=None)
    
def xml2csv(source:str,target:str):
    tree=ET.parse(source)
    tree.getroot()[0]
    rows={}
    for child in tree.getroot()[0]:
        talkid=int(child[3].text)
        if not talkid in rows.keys():
            rows[talkid]=''
        for i in child.findall('seg'):
            rows[talkid]+=re.sub('\n',' ',i.text)
    tedtalks=pd.DataFrame.from_dict({'id':rows.keys(),'transcript':rows.values()})

    tedtalks.loc[:,'transcript'] = preprocess(tedtalks.transcript.astype(str))
    tedtalks=tedtalks.loc[tedtalks.transcript.map(lambda x:len(x.split())>=1)]
    tedtalks.to_csv(target,index=None)        
            
xml2csv("/home/nxingyu2/data/2012-03/texts/en/fr/en-fr/IWSLT12.TALK.tst2010.en-fr.en.xml",        
        '/home/nxingyu2/data/ted2010.test.csv')

stripping accents
removing speaker tags
removing name tags
removing non-sentence parenthesis
removing parenthesis
removing square brackets
removing music lyrics
removing empty tags
removing non-sentence punctuation
change to unicode ellipsis
2 hyphen to emdash
endash to hyphen
remove hyphen after punct
combine repeated punctuation
pronounce symbol
strip leading
strip trailing
reduce whitespaces
--done--


In [109]:
tree=ET.parse("/home/nxingyu2/data/2012-03/texts/en/fr/en-fr/IWSLT12.TALK.tst2010.en-fr.en.xml")
tree.getroot()[0]
rows={}
# print(tre)
# for child in tree.getroot()[0]:
#     talkid=int(child[3].text)
#     if not talkid in rows.keys():
#         rows[talkid]=''
#     print(child.findall('seg'))
#     for i in child.findall('seg'):
#         rows[talkid]+=re.sub('\n',' ',i.text)

In [13]:
# !pip install kaggle==1.5.6
import os
os.environ['KAGGLE_USERNAME'] = "ngxingyu"
os.environ['KAGGLE_KEY'] = "1bbb182c67c54f035f76bf34bac11751"
%mkdir ~/data/
%cd ~/data/
!kaggle datasets download -d miguelcorraljr/ted-ultimate-dataset
!unzip ~/data/ted-ultimate-dataset.zip
from google.colab import drive
drive.mount('/content/gdrive')

mkdir: cannot create directory ‘/root/data/’: File exists
/root/data
Downloading ted-ultimate-dataset.zip to /root/data
 89% 177M/199M [00:02<00:00, 80.1MB/s]
100% 199M/199M [00:02<00:00, 95.1MB/s]


In [1]:
# !pip install datasets
# !pip install transformers
from datasets import load_dataset
from transformers import DistilBertTokenizerFast

In [2]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
tags=sorted(list('.?!,;:-—…'))
tag2id = {tag: id+1 for id, tag in enumerate(tags)}
tag2id[' ']=0
tag2id['']=-100
id2tag = {id: tag for tag, id in tag2id.items()}
tag2id,id2tag

({'!': 1,
  ',': 2,
  '-': 3,
  '.': 4,
  ':': 5,
  ';': 6,
  '?': 7,
  '—': 8,
  '…': 9,
  ' ': 0,
  '': -100},
 {1: '!',
  2: ',',
  3: '-',
  4: '.',
  5: ':',
  6: ';',
  7: '?',
  8: '—',
  9: '…',
  0: ' ',
  -100: ''})

In [3]:
import torch
from torch import LongTensor,FloatTensor
from typing import Optional
class PunctuationDataset(torch.utils.data.Dataset):
    def __init__(self, input_ids:LongTensor, attention_mask:FloatTensor, labels:Optional[LongTensor] = None) -> None:
        """
        :param input_ids: tokenids
        :param attention_mask: attention_mask, null->0
        :param labels: true labels, optional
        :return None
        """
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.labels = labels
    def __getitem__(self, idx):
        """:param idx: implement index"""
        item = {'input_ids': torch.as_tensor(self.input_ids[idx],dtype=torch.long),
        'attention_mask': torch.as_tensor(self.attention_mask[idx],dtype=torch.float32),
        'labels': torch.as_tensor(self.labels[idx],dtype=torch.long)}
        return item
    def view(self,idx:int)->str:
        return ' '.join([''.join(x) for x in list(zip(tokenizer.convert_ids_to_tokens(self.input_ids[idx]),[id2tag[x] for x in self.labels[idx].tolist()]))])
    def __len__(self):
        return len(self.labels)


In [5]:
# ted=load_dataset('csv',data_files={'train':'/content/gdrive/MyDrive/ASR/ted_talks_processed.train.csv',
#                                          'dev':'/content/gdrive/MyDrive/ASR/ted_talks_processed.dev.csv',
#                                          'test':'/content/gdrive/MyDrive/ASR/ted_talks_processed.test.csv'})
ted=load_dataset('csv',data_files={'train':'/home/nxingyu/project/data/ted_talks_processed.train.csv',
                                         'dev':'/home/nxingyu/project/data/ted_talks_processed.dev.csv',
                                         'test':'/home/nxingyu/project/data/ted_talks_processed.test.csv'})
# subtitles=load_dataset('csv',data_files={'train':'/home/nxingyu/project/data/open_subtitles_processed.train.csv',
#                                          'dev':'/home/nxingyu/project/data/open_subtitles_processed.dev.csv',
#                                          'test':'/home/nxingyu/project/data/open_subtitles_processed.test.csv'})

Using custom data configuration default


Downloading and preparing dataset csv/default-5020f345641810cf (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /home/nxingyu/.cache/huggingface/datasets/csv/default-5020f345641810cf/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset csv downloaded and prepared to /home/nxingyu/.cache/huggingface/datasets/csv/default-5020f345641810cf/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2. Subsequent calls will reuse this data.


In [5]:
import pandas as pd
def punct_proportion(df):
    l=[]
    st=pd.Series(df['transcript']).str
    for c in ".?!,;:-—…":
        l.append(sum(st.count("\\"+c)))
    [print(i[0],i[1],i[1]/sum(l)) for i in zip(list(".?!,;:-—…"),l)]
    print(st.count('\w+').sum())
for split in ted:
    punct_proportion(ted[split])
    print()

. 40835 0.3979476484690198
? 3679 0.035852807609098175
! 291 0.002835870349075175
, 48727 0.4748572319566531
; 586 0.005710721733876469
: 1243 0.012113356851891554
- 3874 0.03775313310074649
— 3168 0.03087298029508644
… 211 0.00205624963455279
734841



In [None]:
for split in subtitles:
    punct_proportion(subtitles[split])
    print()

. 47443035 0.44392753421469083
? 13250829 0.12398886041482206
! 6519047 0.0609991426589736
, 24551760 0.22973239965425646
; 57285 0.0005360194346227758
: 374276 0.003502124638437183
- 10479724 0.09805945244798349
— 9564 8.949096399986432e-05
… 4185605 0.03916497557221373
419201886

. 5921757 0.4430048481920479
? 1660096 0.12419127911939411
! 813038 0.06082312660995144
, 3062385 0.22909609462708524
; 8073 0.0006039386856729181
: 57019 0.004265574125899185
- 1321067 0.09882862227992877
— 591 4.421253105818092e-05
… 523225 0.03914230382896229
52422785

. 5921304 0.4422626972770535
? 1670946 0.12480309826421737
! 834767 0.06234881793231256
, 3063118 0.2287845428570959
; 7344 0.0005485239820152251
: 44933 0.0033560495756930976
- 1324557 0.09893127451608667
— 1210 9.037500248344532e-05
… 520479 0.03887462059304226
52341671



In [10]:
degree=0
max_length=128
data=ted['dev'].map(chunk_examples_with_degree(degree), batched=True, batch_size=128,remove_columns=ted['dev'].column_names)
# tokenizer.tokenize()

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))




In [488]:
# tokenizer.convert_tokens_to_ids(data['texts'][0])
# subwords = list([list(map(tokenizer.tokenize, x)) for x in data['texts']])

max_len=128
def flatten(list_of_lists):
    for list in list_of_lists:
        for item in list:
            yield item

def subword_tokenize(tokens):
    subwords = list(map(tokenizer.tokenize, tokens))
    subword_lengths = list(map(len, subwords))
    subwords = list(flatten(subwords))
    token_end_idxs = np.cumsum([0]+subword_lengths[:-1])+np.array(subword_lengths)-1
    return subwords, token_end_idxs

def position_to_mask(max_len,indices):
    o=np.zeros(max_len,dtype=np.int)
    o[indices%(max_len-2)+1]=1
    return o

def pad_ids_to_len(max_len,ids):
    o=np.zeros(max_len, dtype=np.int)
    o[:len(ids)]=np.array(ids)
    return o

def chunk_to_len(max_len,tokens,labels):
    subwords,token_end_idxs = subword_tokenize(tokens)
    teim=token_end_idxs%(max_length-2)
    split_token_end_idxs=np.array_split(token_end_idxs,(np.argwhere((teim[1:])<teim[:-1]).flatten()+1).tolist())
    split_subwords=np.array_split(subwords,token_end_idxs[np.argwhere((teim[1:])<teim[:-1]).flatten()+1].tolist())
    split_labels=np.array_split(labels[1:],(np.argwhere((teim[1:])<teim[:-1]).flatten()+1).tolist())
    ids=torch.tensor([pad_ids_to_len(max_len,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)[:max_len-2]+['[SEP]'])) for _ in split_subwords])
    masks=torch.tensor([position_to_mask(max_len,_) for _ in split_token_end_idxs])
    padded_labels=torch.tensor([pad_ids_to_len(max_len,[0]+list(_)[:max_len-2]+[0]) for _ in split_labels])
    return ids,masks,padded_labels

def chunk_to_len_batch(max_len,tokens,labels):
    batch_ids=[]
    batch_masks=[]
    batch_labels=[]
    for i,_ in enumerate(zip(tokens,labels)):
        a,b,c=chunk_to_len(max_len,*_)
        batch_ids.append(a)
        batch_masks.append(b)
        batch_labels.append(c)
    return PunctuationDataset(torch.cat(batch_ids), torch.cat(batch_masks), torch.cat(batch_labels))

In [489]:
# data=ted['dev'].map(chunk_examples_with_degree(degree), batched=True, batch_size=128,remove_columns=ted['dev'].column_names)
dev_data=chunk_to_len_batch(max_len,data['texts'],data['tags'])

In [491]:
dev_data[0]

{'input_ids': tensor([  101,  1045,  2064,  1005,  1056,  2393,  2021,  2023,  4299,  2000,
          2228,  2055,  2043,  2017,  1005,  2128,  1037,  2210,  4845,  1998,
          2035,  2115,  2814,  3198,  2017,  2065,  1037, 22519,  2071,  2507,
          2017,  2028,  4299,  1999,  1996,  2088,  2054,  2052,  2009,  2022,
          1998,  1045,  2467,  4660,  2092,  1045,  1005,  1040,  2215,  1996,
          4299,  2000,  2031,  1996,  9866,  2000,  2113,  3599,  2054,  2000,
          4299,  2005,  2092,  2059,  2017,  1005,  1040,  2022, 14180,  2138,
          2017,  1005,  1040,  2113,  2054,  2000,  4299,  2005,  1998,  2017,
          1005,  1040,  2224,  2039,  2115,  4299,  1998,  2085,  2144,  2057,
          2069,  2031,  2028,  4299,  4406,  2197,  2095,  2027,  2018,  2093,
          8996,  1045,  1005,  1049,  2025,  2183,  2000,  4299,  2005,  2008,
          2061,  2292,  1005,  1055,  2131,  2000,  2054,  1045,  2052,  2066,
          2029,  2003,  2088,  3521,  1

In [370]:
# subwords,token_end_idxs = subword_tokenize(tokens)
# token_end_idxs
len(tokens),len(labels)

(3834, 3835)

In [413]:
tokens=data['texts'][0]
labels=data['tags'][0]
subwords,token_end_idxs = subword_tokenize(tokens)
teim=token_end_idxs%(max_length-2)
split_token_end_idxs=np.array_split(token_end_idxs,(np.argwhere((teim[1:])<teim[:-1]).flatten()+1).tolist())
split_subwords=np.array_split(subwords,token_end_idxs[np.argwhere((teim[1:])<teim[:-1]).flatten()+1].tolist())
split_labels=np.array_split(labels[1:],(np.argwhere((teim[1:])<teim[:-1]).flatten()+1).tolist())
list(zip([len(_) for _ in split_token_end_idxs],[len(_) for _ in split_labels]))
# ids=torch.tensor([pad_ids_to_len(max_len,tokenizer.convert_tokens_to_ids(['[CLS]']+list(_)+['[SEP]'])) for _ in split_subwords])
# masks=torch.tensor([position_to_mask(max_len,_) for _ in split_token_end_idxs])
# labels=torch.tensor([pad_ids_to_len(max_len,[0]+list(_)+[0]) for _ in split_labels])

[(110, 110),
 (107, 107),
 (116, 116),
 (112, 112),
 (109, 109),
 (122, 122),
 (123, 123),
 (124, 124),
 (124, 124),
 (122, 122),
 (119, 119),
 (112, 112),
 (122, 122),
 (116, 116),
 (111, 111),
 (116, 116),
 (117, 117),
 (115, 115),
 (116, 116),
 (109, 109),
 (112, 112),
 (113, 113),
 (116, 116),
 (118, 118),
 (114, 114),
 (112, 112),
 (108, 108),
 (124, 124),
 (116, 116),
 (120, 120),
 (120, 120),
 (114, 114),
 (112, 112),
 (13, 13)]

In [7]:
import torch
import regex as re
import numpy as np

def text2masks(n):
    def text2masks(text):
        '''Converts single paragraph of text into a list of words and corresponding punctuation based on the degree requested.'''
        if n==0: 
            refilter="(?<=[.?!,;:\-—… ])(?=[^.?!,;:\-—… ])|$" 
        else:
            refilter="[.?!,;:\-—…]{1,%d}(?= *[^.?!,;:\-—…]+|$)|(?<=[^.?!,;:\-—…]) +(?=[^.?!,;:\-—…])"%(n)
        word=re.split(refilter,text, flags=re.V1)
        punct=re.findall(refilter,text, flags=re.V1)
        wordlist,punctlist=([] for _ in range(2))
        for i in zip(word,punct+['']):
            w,p=i[0].strip(),i[1].strip()
            if w!='':
                wordlist.append(re.sub(r'[.?!,;:\-—… ]','',w))
                punctlist.append(0 if not w[-1] in '.?!,;:-—…' else tag2id[w[-1]])
            if p!='':
                wordlist.append(p)
                punctlist.append(0)
        return(wordlist,punctlist)
    return text2masks
assert(text2masks(0)('"Hello!!')==(['"Hello'], [1]))
assert(text2masks(1)('"Hello!!')==(['"Hello', '!'], [1, 0]))
assert(text2masks(0)('"Hello!!, I am human.')==(['"Hello','I','am','human'], [2,0,0,4]))
assert(text2masks(2)('"Hello!!, I am human.')==(['"Hello', '!,','I','am','human','.'], [1,0,0,0,0,0]))
def chunk_examples_with_degree(n):
    '''Ensure batched=True if using dataset.map or ensure the examples are wrapped in lists.'''
    def chunk_examples(examples):
        output={}
        output['texts']=[]
        output['tags']=[]
        for sentence in examples['transcript']:
            text,tag=text2masks(n)(sentence)
            output['texts'].append(text)
            output['tags'].append([0]+tag if text[0]!='' else tag) # [0]+tag so that in all case, the first tag refers to [CLS]
        return output
    return chunk_examples
assert(chunk_examples_with_degree(0)({'transcript':['Hello!Bye…']})=={'texts': [['Hello', 'Bye']], 'tags': [[0, 1, 9]]})

def encode_tags(encodings, docs, max_length, overlap):
    encoded_labels = []
    doc_id=0
    label_offset=0
    for doc_offset,current_doc_id in zip(encodings.offset_mapping,encodings['overflow_to_sample_mapping']):
        if current_doc_id>doc_id:
            doc_id+=1
            label_offset=0
            print('.', end='')
        doc_enc_labels = np.ones(len(doc_offset),dtype=int) * 0 #-100
        arr_offset = np.array(doc_offset)
        if arr_offset[1,0]!=0: #Resolution if first token belongs to previous word.
            label_offset+=1
        # Gives the labels that should be assigned punctuation 
        # (after the tok before word prefixes: arr_offset :(0,i) where i>0)
        arr_mask = ((arr_offset[:,0] == 0) & (arr_offset[:,1] != 0))[1:].tolist()+[False] 
        #Get index of last non-sep/unk/pad word
        sep_idx=np.argwhere(arr_offset.sum(1)>0)[-1,0]
        arr_mask[sep_idx]=True
        doc_enc_labels[arr_mask] = docs[doc_id][label_offset:label_offset+sum(arr_mask)]
        encoded_labels.append(doc_enc_labels)
        label_offset+=sum(arr_mask[:max_length-overlap-1])-1 #-1 Assuming the last token is standalone word
    return encoded_labels

def process_dataset(dataset, split, max_length=128, overlap=63, degree=0):
    data=dataset[split].map(chunk_examples_with_degree(degree), batched=True, batch_size=max_length,remove_columns=dataset[split].column_names)
    encodings=tokenizer(data['texts'], is_split_into_words=True, return_offsets_mapping=True,
              return_overflowing_tokens=True, padding=True, truncation=True, max_length=max_length, stride=overlap)
    labels=encode_tags(encodings, data['tags'], max_length, overlap)
    encodings.pop("offset_mapping")
    encodings.pop("overflow_to_sample_mapping")
    return PunctuationDataset(torch.tensor(encodings['input_ids'],dtype=torch.long),
        torch.tensor(encodings['attention_mask'],dtype=torch.long),
        torch.tensor(labels,dtype=torch.long))
    
test_dataset=process_dataset(ted,'test',degree=1)#,10,3)
# dev_dataset=process_dataset(ted,'dev')
# train_dataset=process_dataset(ted,'train')

# for name,dataset in {'test':test_dataset ,'train':train_dataset, 'dev':dev_dataset}.items():#
#     torch.save(dataset, '/content/gdrive/MyDrive/ASR/ted-'+name+'.pt')

HBox(children=(FloatProgress(value=0.0, max=4.0), HTML(value='')))


...............................................................................................................................................................................................................................................................................................................................................................................................................

In [14]:
encodings=tokenizer(['Helloooo','I','am','a','human','meat','meat','meat'], is_split_into_words=True, return_offsets_mapping=True,
              return_overflowing_tokens=True, padding=True, truncation=True,max_length=8)

In [103]:
arr_offset=np.array(encodings['offset_mapping'][1])
arr_mask = ((arr_offset[:,0] == 0) & (arr_offset[:,1] != 0))[1:].tolist()+[False] # Gives the labels that should be assigned punctuation
# arr_mask[encodings['input_ids'][1].index(102)-1]=True
sep_idx=np.argwhere(arr_offset.sum(1)>0)[-1,0]
# arr_mask[sep_idx]=True
# arr_mask,encodings
arr_offset[0]

array([0, 0])

In [571]:
torch.hstack([torch.zeros(5),torch.zeros(5)]).astype(torch.long)

AttributeError: 'Tensor' object has no attribute 'astype'

In [642]:
# import numpy as np
# !rm 'sample.csv'
# # f=open('sample','ab')
# with open("sample.csv","ab") as f:
#     np.savetxt(f,torch.hstack([torch.ones(1,10),torch.zeros(1,10),torch.zeros(1,10)]),)
#     np.savetxt(f,torch.hstack([torch.zeros(1,10),torch.ones(1,10),torch.zeros(1,10)]),)
#     np.savetxt(f,torch.hstack([torch.zeros(1,10),torch.zeros(1,10),torch.ones(1,10)]),)
# f.close()
# np.loadtxt("./data/ted_talks_processed.test-batched.csv").shape
import torch.utils.data as data

class CSVDataset(data.Dataset):
    def __init__(self, path, chunksize, nb_samples):
        self.path = path
        self.chunksize = chunksize
        self.len = int(nb_samples / self.chunksize)

    def __getitem__(self, index):
        print(index)
        x = next(
            pd.read_csv(
                self.path,
                skiprows=(index%self.len)*self.chunksize,
                chunksize=self.chunksize,
                header=None,
                delimiter=' '))
        x = torch.from_numpy(x.values).reshape(-1,3,128)
        return x[:,0,:],x[:,1,:],x[:,2,:]
#         return x

    def __len__(self):
        return self.len
ted_train_batched=CSVDataset('./data/ted_talks_processed.train-batched.csv',1000,57676)
ted_dev_batched=CSVDataset('./data/ted_talks_processed.dev-batched.csv',1000,7439)
ted_test_batched=CSVDataset('./data/ted_talks_processed.test-batched.csv',1000,7236)
open_test_batched=CSVDataset('./data/open_subtitles_processed.test-batched.csv',1000,568113)
# test_batched[0].shape
# pd.read_csv('./data/ted_talks_processed.dev-batched.csv',header=None,delimiter=' ')


In [646]:
# ted_train_batched[5]
pd.read_csv('./data/open_subtitles_processed.test-batched.csv',header=None,delimiter=' ')

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,374,375,376,377,378,379,380,381,382,383
0,101,2054,1996,3109,2003,2008,1029,5037,4157,1998,...,0,0,0,0,0,0,0,0,0,0
1,101,2000,2022,7294,1996,6971,1010,1045,1005,1049,...,0,0,0,0,0,0,0,0,0,0
2,101,2222,2022,2204,2005,3010,2945,1012,2057,2097,...,0,0,0,0,0,0,0,0,0,0
3,101,2183,2000,3288,2032,2067,1012,3357,2185,2013,...,0,0,0,0,0,0,0,0,0,0
4,101,2054,2055,26879,17083,3051,1005,1055,4861,1029,...,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
568108,101,1010,2012,2115,2067,1529,2022,6176,1529,2017,...,0,0,0,0,0,0,0,0,0,0
568109,101,1029,2428,1029,2033,1029,2054,2055,2009,1029,...,0,0,0,0,0,0,0,0,0,0
568110,101,1029,4030,2091,1529,2054,1005,1055,1999,2023,...,0,0,0,0,0,0,0,0,0,0
568111,101,3040,11132,2055,2023,1029,1996,28997,2003,6728,...,0,0,0,0,0,0,0,0,0,0


In [520]:
from torchtext import data
dev = data.TabularDataset(
    path='/home/nxingyu/project/data/ted_talks_processed.dev.csv', format='csv',
    fields={'transcript': ('transcripts', data.Field(sequential=False)),})
# train_iter = data.BucketIterator(dataset=ted['dev'], batch_size=32)





AttributeError: 'Field' object has no attribute 'vocab'

In [119]:
dataset=process_dataset(ted,'test')
# dataset.view(10)

Loading cached processed dataset at /root/.cache/huggingface/datasets/csv/default-bdd308fcdec498e9/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2/cache-5964375f1255733d.arrow


...............................................................................................................................................................................................................................................................................................................................................................................................................

In [190]:
test_dataset.view(3007)

"[CLS]  rightful  space  in  public  life  .  and  as  we  saw  earlier  ,  one  of  the  most  critical  variables  in  determining  whether  a  movement  will  be  successful  or  not  is  a  movement  '  s  ideology  regarding  the  role  of  women  in  public  life  .  this  is  a  question  of  whether  we  '  re  moving  towards  more  democratic  and  peaceful  societies  .  in  a  world  where  so  much  change  is  happening  ,  and  where  change  is  bound  to  continue  at  an  increasingly  faster  pace  ,  it  is  not  a  question  of  whether  we  will  face  conflict  ,  but  rather  a  question  of  which  stories  will  shape  how  we  choose  to  wage  conflict  .  thank  you  .  [SEP]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD] "

In [116]:
# %pprint24
dataset.view(-1)

"[CLS]  of  eyes  on  her  naked  body, even  though, intellectual  ##ly, she  knows  it  wasn  '  t  her  body. and  she  has  frequent  panic  attacks, especially  when  someone  she  doesn  '  t  know  tries  to  take  her  picture. what  if  they  '  re  going  to  make  another  deep  ##fa  ##ke? she  thinks  to  herself. and  so  for  the  sake  of  individuals  like  rana  a  ##y  ##yu  ##b  and  the  sake  of  our  democracy, we  need  to  do  something  right  now. thank  you. [SEP]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD]  [PAD] "

In [71]:
dataset.view(24)

"[CLS]  re, often  concerned  with  performing  at  our  best  and, as  a  result  we  try  and  control  what  we  '  re  doing  to  force  the  best  performance  the. end  result  is  that  we  actually  screw  up  in. basketball  the, term  unconscious  is  used  to  describe  a  shooter  who  can  '  t  miss  and. san  antonio  spurs  star  tim  duncan  has  said  when, you  have  to  stop  and  think  that  '  s, when  you  mess  up  in. dance  the, great  choreographer  george, bala  ##nch  ##ine  used, to  urge  his  dancers  don  '  t, think  just, do  when. the  pressure  '  s  on  when, we  want  to  put  our  best  foot  forward  somewhat, ironically  we, often  try  and  control  what  we  '  re  doing  in  a  way  that  leads  [SEP] "

In [None]:
list(iter(test_dataset))[:4]

[(tensor([  101,  2627,  1303,  1110, 17136,  1118,  1297,  1223,  1103,   102]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([   0,    0,    0,    0,    0,    0,    0,    0,    0, -100])),
 (tensor([  101,  1297,  1223,  1103,  2343, 19420,  1986,  1184,  1225,   102]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([   0,    0,    0,    0,    7,    4,    2,    0,    0, -100])),
 (tensor([ 101, 1986, 1184, 1225, 1195, 1198, 1202, 2421,  112,  102]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([   4,    2,    0,    0,    0,    0,    7,    0, -100, -100])),
 (tensor([  101,  1202,  2421,   112,   188,  4267, 11553,  5822,  1142,   102]),
  tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1]),
  tensor([   0,    7,    0, -100, -100,    0, -100, -100,    0, -100]))]

In [None]:
try:
    # Disable all GPUS
    tf.config.set_visible_devices([], 'GPU')
    visible_devices = tf.config.get_visible_devices()
    for device in visible_devices:
        assert device.device_type != 'GPU'
except:
    # Invalid device or cannot modify virtual devices once initialized.
    pass

In [None]:
import tensorflow as tf

# Set the offset as 1
# Set the cls tag as offset-1 value
# Set the non zeros with the offset to sum(nonzeros)
# offset+= sum(nonzeros in first half)
# stop when offset+ len(mapping)>len(tag)


def dataset2tf(dataset,split):
    data=dataset[split].map(chunk_examples_with_degree(0), batched=True, batch_size=10,remove_columns=dataset[split].column_names)
    train_encodings=tokenizer(data['texts'], is_split_into_words=True, return_offsets_mapping=True, 
              return_overflowing_tokens=True, padding=True, truncation=True, max_length=32, stride=15,)

    train_labels=encode_tags(train_encodings, data['tags'])
    train_encodings.pop("offset_mapping")
    train_dataset = tf.data.Dataset.from_tensor_slices((
        dict(train_encodings),
        train_labels
    ))
    return train_dataset
test_dataset=dataset2tf(ted,'test')
    

Loading cached processed dataset at /home/nxingyu/.cache/huggingface/datasets/csv/default-295ea44b803e5492/0.0.0/2960f95a26e85d40ca41a230ac88787f715ee3003edaacb8b1f0891e9f04dda2/cache-47a2b754535508b1.arrow


.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.
.


In [None]:
for i in test_dataset.take(4):
    print(i)

AttributeError: 'TensorDataset' object has no attribute 'take'

In [None]:
train_encodings[146]

Encoding(num_tokens=32, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [None]:
# [len(data[i]['tags']) for i in range(144,150)]
list(zip(data[145]['texts'][2700:],[id2tag[t] if t>0 else '' for t in data[145]['tags'][2701:]]))

[('was', ''),
 ('to', ''),
 ('focus', ''),
 ('on', ''),
 ('math', ''),
 ('and', ''),
 ('science', ','),
 ('to', ''),
 ('focus', ''),
 ('on', ''),
 ('basic', ''),
 ('research', '.'),
 ('And', ''),
 ("that's", ''),
 ('what', ''),
 ("we've", ''),
 ('done', '.'),
 ('Six', ''),
 ('years', ''),
 ('ago', ''),
 ('or', ''),
 ('so', ','),
 ('I', ''),
 ('left', ''),
 ('Renaissance', ''),
 ('and', ''),
 ('went', ''),
 ('to', ''),
 ('work', ''),
 ('at', ''),
 ('the', ''),
 ('foundation', '.'),
 ('So', ''),
 ("that's", ''),
 ('what', ''),
 ('we', ''),
 ('do', '.'),
 ('And', ''),
 ('so', ''),
 ('Math', ''),
 ('for', ''),
 ('America', ''),
 ('is', ''),
 ('basically', ''),
 ('investing', ''),
 ('in', ''),
 ('math', ''),
 ('teachers', ''),
 ('around', ''),
 ('the', ''),
 ('country', ','),
 ('giving', ''),
 ('them', ''),
 ('some', ''),
 ('extra', ''),
 ('income', ','),
 ('giving', ''),
 ('them', ''),
 ('support', ''),
 ('and', ''),
 ('coaching', '.'),
 ('And', ''),
 ('really', ''),
 ('trying', ''),
 ('to

In [None]:
# chunked_ted_val=ted['val'].select([0,1]).map(chunk_examples_with_degree(0), batched=True, batch_size=10,remove_columns=ted['train'].column_names)
# encodings=tokenizer(chunked_ted['texts'], is_split_into_words=True, return_offsets_mapping=True, 
#           return_overflowing_tokens=True, padding=True, truncation=True, max_length=32, stride=15,)
# labels=encode_tags(encodings, chunked_ted['tags'])
# encodings.pop("offset_mapping")
# dataset = tf.data.Dataset.from_tensor_slices((
#     dict(encodings),
#     labels
# ))
# val_dataset

In [None]:
from transformers import TFDistilBertForTokenClassification
model = TFDistilBertForTokenClassification.from_pretrained('distilbert-base-cased', num_labels=len(tags)+1)

Some layers from the model checkpoint at distilbert-base-cased were not used when initializing TFDistilBertForTokenClassification: ['activation_13', 'vocab_projector', 'vocab_layer_norm', 'vocab_transform']
- This IS expected if you are initializing TFDistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFDistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier', 'dropout_39']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
model.summary()
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
model.compile(optimizer=optimizer, loss=model.compute_loss) # can also use any keras loss fn
model.fit(test_dataset.batch(16), epochs=3, batch_size=16, validation_data=test_dataset,verbose=1)

Model: "tf_distil_bert_for_token_classification_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
distilbert (TFDistilBertMain multiple                  65190912  
_________________________________________________________________
dropout_39 (Dropout)         multiple                  0         
_________________________________________________________________
classifier (Dense)           multiple                  7690      
Total params: 65,198,602
Trainable params: 65,198,602
Non-trainable params: 0
_________________________________________________________________
Epoch 1/3

KeyboardInterrupt: 

In [None]:
sample_output=model.predict(dataset)

In [None]:
sample_output['logits'].argmax(axis=2).reshape(32,-1)

array([[0, 0, 0, ..., 2, 0, 0],
       [5, 0, 0, ..., 0, 0, 0],
       [0, 5, 0, ..., 0, 0, 2],
       ...,
       [0, 0, 0, ..., 5, 0, 0],
       [0, 0, 0, ..., 0, 5, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [None]:
tokenizer(chunked_ted['texts'][0], is_split_into_words=True, return_offsets_mapping=True, return_overflowing_tokens=True, padding=True, truncation=True, stride=8, max_length=16)

{'input_ids': [[101, 1448, 2247, 4427, 1107, 1381, 5227, 2021, 15474, 8449, 1105, 8703, 170, 1299, 1150, 102], [101, 2021, 15474, 8449, 1105, 8703, 170, 1299, 1150, 1691, 10108, 102, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]], 'offset_mapping': [[(0, 0), (0, 3), (0, 6), (0, 9), (0, 2), (0, 4), (0, 2), (0, 6), (0, 8), (0, 10), (0, 3), (0, 8), (0, 1), (0, 3), (0, 3), (0, 0)], [(0, 0), (0, 6), (0, 8), (0, 10), (0, 3), (0, 8), (0, 1), (0, 3), (0, 3), (0, 8), (0, 10), (0, 0), (0, 0), (0, 0), (0, 0), (0, 0)]], 'overflow_to_sample_mapping': [0, 0]}

In [None]:
import tensorflow as tf
encodings.pop("offset_mapping")

dataset = tf.data.Dataset.from_tensor_slices((
    dict(encodings),
    labels
))

In [None]:
for data in dataset.take(1):
    print(data)

({'input_ids': <tf.Tensor: shape=(25,), dtype=int32, numpy=
array([  101,  1448,  2247,  4427,  1107,  1381,  5227,  2021, 15474,
        8449,  1105,  8703,   170,  1299,  1150,  1691,   102,     0,
           0,     0,     0,     0,     0,     0,     0], dtype=int32)>, 'token_type_ids': <tf.Tensor: shape=(25,), dtype=int32, numpy=
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0], dtype=int32)>, 'attention_mask': <tf.Tensor: shape=(25,), dtype=int32, numpy=
array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
       0, 0, 0], dtype=int32)>}, <tf.Tensor: shape=(25,), dtype=int32, numpy=
array([-100,    0,    0,    0,    0,    2,    0,    0,    2,    0,    0,
          0,    0,    0,    0,    0, -100, -100, -100, -100, -100, -100,
       -100, -100, -100], dtype=int32)>)


In [None]:
tokenizer(ted['train'][0]['transcript'], return_offsets_mapping=True, padding=True, truncation=True)

{'input_ids': [101, 20062, 117, 1166, 122, 1553, 126, 3775, 1234, 2541, 4223, 4139, 119, 1130, 2593, 117, 1234, 1132, 2257, 1106, 10556, 1147, 1583, 117, 2128, 1166, 1405, 1550, 8940, 119, 4288, 117, 1443, 170, 4095, 117, 1132, 1103, 1211, 7386, 1105, 8018, 5256, 795, 1133, 1136, 1198, 1121, 1103, 5119, 2952, 18552, 117, 1133, 1121, 1103, 1510, 8362, 20080, 27443, 3154, 1115, 8755, 1138, 1113, 1147, 2073, 119, 1109, 5758, 1104, 1594, 1817, 1482, 1120, 170, 1842, 1344, 3187, 1111, 1103, 1718, 1104, 6438, 1105, 18560, 2645, 119, 4288, 117, 1112, 1195, 1169, 1178, 5403, 117, 1209, 1631, 4472, 117, 4963, 1105, 1120, 3187, 119, 1252, 1175, 1110, 1363, 2371, 119, 1109, 3068, 1104, 1920, 1115, 1482, 3531, 1107, 1147, 2073, 1169, 1138, 170, 1167, 2418, 2629, 1113, 1147, 1218, 118, 1217, 1190, 1121, 1103, 4315, 5758, 1104, 1594, 1115, 1152, 1138, 1151, 5490, 1106, 119, 1573, 2140, 117, 1482, 1169, 1129, 4921, 1118, 3258, 117, 5343, 6486, 1158, 1219, 1105, 1170, 4139, 119, 1130, 1349, 117, 146, 

In [None]:
#create_pretraining_data.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import random
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
import tensorflow as tf

flags = tf.compat.v1.flags
FLAGS = flags.FLAGS

In [None]:
class TrainingInstance(object):
    def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, is_random_next):
        self.tokens = tokens
        self.segment_ids = segment_ids
        self.is_random_next = is_random_next
        self.masked_lm_positions = masked_lm_positions
        self.masked_lm_labels = masked_lm_labels

<module 'tensorflow.python.platform.flags' from '/home/nxingyu/miniconda3/envs/NLP/lib/python3.8/site-packages/tensorflow/python/platform/flags.py'>

In [None]:
import logging
root_logger= logging.getLogger()
root_logger.setLevel(logging.DEBUG)
handler = logging.FileHandler('test.log', 'w', 'utf-8')
handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
root_logger.addHandler(handler)

import subprocess
root_logger.info(subprocess.check_output(['git', 'describe', '--always']))
import datetime
root_logger.warning(datetime.datetime.now().strftime('%Y-%m-%d %H:%M'))

2020-12-21 11:43:00,756 [INFO] b'fcfb139\n'
