In [1]:
# CONFIG
# Experiment-level settings. In the visible cells, `model_checkpoint` is used to
# load the tokenizer and `max_length` / `stride` are passed to the tokenizer call.

# Experiment number; in the visible cells it only forms the suffix of `model_path`.
EXP_NUM = 4
task = "ner"
model_checkpoint = "allenai/longformer-base-4096"
# Forwarded as `max_length=` to the tokenizer (together with truncation=True).
max_length = 1024
# Forwarded as `stride=` to the tokenizer (together with return_overflowing_tokens=True).
stride = 128
# NOTE(review): not referenced in the visible cells - confirm where it is consumed.
min_tokens = 6
# Last path component of the checkpoint name plus the experiment number,
# i.e. "longformer-base-4096-4" for the values above.
model_path = f'{model_checkpoint.split("/")[-1]}-{EXP_NUM}'

# TRAINING HYPERPARAMS
# NOTE(review): none of these are read in the visible cells. Judging by the names
# only, they are presumably batch size, gradient-accumulation steps, learning rate,
# weight decay, warmup ratio and number of epochs - confirm against the training cell.
BS = 4
GRAD_ACC = 8
LR = 5e-5
WD = 0.01
WARMUP = 0.1
N_EPOCHS = 5

In [2]:
import pandas as pd

# read train data
# Loads the annotation table from `train.csv` in the working directory. Later cells
# read its `id`, `discourse_type`, `discourse_start`, `discourse_end` and
# `predictionstring` columns. The last line previews the first row.
train = pd.read_csv('train.csv')
train.head(1)

Unnamed: 0,id,discourse_id,discourse_start,discourse_end,discourse_text,discourse_type,discourse_type_num,predictionstring
0,423A1CA112E2,1622628000000.0,8.0,229.0,Modern humans today are always on their phone....,Lead,Lead 1,1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 1...


In [3]:
# Distinct discourse types, in order of first appearance in `train`.
# The position of a type in this list is what the tag-id scheme is built from.
classes = train["discourse_type"].unique().tolist()
classes

['Lead',
 'Position',
 'Evidence',
 'Claim',
 'Concluding Statement',
 'Counterclaim',
 'Rebuttal']

In [4]:
from collections import defaultdict
# Build the BIO tag vocabulary used for token classification.
#
# For the class at position i in `classes`:
#   'B-<class>' -> i
#   'I-<class>' -> i + len(classes)
# 'O' takes the next free id (2 * len(classes)) and 'Special' maps to -100, the
# value written onto special tokens when labels are aligned to tokens.
#
# Plain dicts are used on purpose: a defaultdict without a default_factory behaves
# exactly like a dict, so it only obscured the intent.
tags = {}

for i, c in enumerate(classes):
    tags[f'B-{c}'] = i
    tags[f'I-{c}'] = i + len(classes)
tags['O'] = len(classes) * 2
tags['Special'] = -100

# label -> id
l2i = dict(tags)

# id -> label (the -100 -> 'Special' entry comes from inverting l2i; no separate
# assignment is needed)
i2l = {v: k for k, v in l2i.items()}

N_LABELS = len(i2l) - 1 # not accounting for -100

In [8]:
from pathlib import Path

# Directory holding one `<id>.txt` file per essay.
path = Path('./train')

def get_raw_text(ids):
    """Return the complete contents of ``path/<ids>.txt`` as one string.

    Args:
        ids: Essay identifier; it is formatted into the file name ``f'{ids}.txt'``.

    Returns:
        The file's text, read in text mode with the platform's default encoding
        (no explicit encoding is passed).

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    essay_file = path / f'{ids}.txt'
    return essay_file.read_text()

In [9]:
# Collapse the per-discourse rows of `train` into one row per essay id: each
# annotation column becomes a list holding that essay's values in row order.
by_essay = train.groupby('id')
df1 = by_essay['discourse_type'].apply(list).reset_index(name='classlist')
df2 = by_essay['discourse_start'].apply(list).reset_index(name='starts')
df3 = by_essay['discourse_end'].apply(list).reset_index(name='ends')
df4 = by_essay['predictionstring'].apply(list).reset_index(name='predictionstrings')

# Inner-join the four per-essay frames on `id`, left to right.
df = df1
for part in (df2, df3, df4):
    df = pd.merge(df, part, how='inner', on='id')

# Attach the raw essay text, read from disk once per id.
df['text'] = df['id'].apply(get_raw_text)

df.head()

Unnamed: 0,id,classlist,starts,ends,predictionstrings,text
0,0000D23A521A,"[Position, Evidence, Evidence, Claim, Counterc...","[0.0, 170.0, 358.0, 438.0, 627.0, 722.0, 836.0...","[170.0, 357.0, 438.0, 626.0, 722.0, 836.0, 101...",[0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 1...,"Some people belive that the so called ""face"" o..."
1,00066EA9880D,"[Lead, Position, Claim, Evidence, Claim, Evide...","[0.0, 456.0, 638.0, 738.0, 1399.0, 1488.0, 231...","[455.0, 592.0, 738.0, 1398.0, 1487.0, 2219.0, ...",[0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 1...,Driverless cars are exaclty what you would exp...
2,000E6DE9E817,"[Position, Counterclaim, Rebuttal, Evidence, C...","[17.0, 64.0, 158.0, 310.0, 438.0, 551.0, 776.0...","[56.0, 157.0, 309.0, 422.0, 551.0, 775.0, 961....","[2 3 4 5 6 7 8, 10 11 12 13 14 15 16 17 18 19 ...",Dear: Principal\n\nI am arguing against the po...
3,001552828BD0,"[Lead, Evidence, Claim, Claim, Evidence, Claim...","[0.0, 161.0, 872.0, 958.0, 1191.0, 1542.0, 161...","[160.0, 872.0, 957.0, 1190.0, 1541.0, 1612.0, ...",[0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 1...,Would you be able to give your car up? Having ...
4,0016926B079C,"[Position, Claim, Claim, Claim, Claim, Evidenc...","[0.0, 58.0, 94.0, 206.0, 236.0, 272.0, 542.0, ...","[57.0, 91.0, 150.0, 235.0, 271.0, 542.0, 650.0...","[0 1 2 3 4 5 6 7 8 9, 10 11 12 13 14 15, 16 17...",I think that students would benefit from learn...


In [11]:
from datasets import Dataset, load_metric

# Wrap the per-essay frame in a `Dataset` and hold out 10% of the essays as the
# `test` split; shuffling uses a fixed seed so the split is repeatable.
# Note: the variable `datasets` has the same name as the `datasets` package the
# classes above are imported from.
ds = Dataset.from_pandas(df)
datasets = ds.train_test_split(test_size=0.1, shuffle=True, seed=42)
datasets

DatasetDict({
    train: Dataset({
        features: ['id', 'classlist', 'starts', 'ends', 'predictionstrings', 'text', '__index_level_0__'],
        num_rows: 14034
    })
    test: Dataset({
        features: ['id', 'classlist', 'starts', 'ends', 'predictionstrings', 'text', '__index_level_0__'],
        num_rows: 1560
    })
})

In [12]:
from transformers import AutoTokenizer
    
# Load the tokenizer that belongs to `model_checkpoint`.
# NOTE(review): add_prefix_space=True is forwarded to the tokenizer; the visible
# cells tokenize raw essay strings rather than pre-split words - confirm this
# option is intended.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

Downloading: 100%|██████████| 694/694 [00:00<00:00, 238kB/s]
Downloading: 100%|██████████| 878k/878k [00:03<00:00, 249kB/s] 
Downloading: 100%|██████████| 446k/446k [00:01<00:00, 431kB/s]  
Downloading: 100%|██████████| 1.29M/1.29M [00:01<00:00, 1.12MB/s]


In [13]:
# Small worked example: the run of 9s is not preceded by its B- tag (2), so its
# first element must be turned into 2.
e = [0,7,7,7,1,1,8,8,8,9,9,9,14,4,4,4]

def fix_beginnings(labels, n_classes=7):
    """Make every run of an I- tag start with the matching B- tag.

    Tag ids follow the scheme B-<class> = k and I-<class> = k + n_classes for
    k in 0..n_classes-1. An I- tag is valid only if the previous label is the
    same I- tag or the matching B- tag; otherwise it is replaced by the
    matching B- tag. Any other value (e.g. the 'O' id or -100) is left as is.

    The first element is checked too: an I- tag at position 0 has no
    predecessor and therefore becomes a B- tag.

    Args:
        labels: Mutable sequence of integer tag ids. It is modified in place.
        n_classes: Number of classes in the tag scheme. Defaults to 7, the
            value previously hard-coded.

    Returns:
        The same `labels` object, after correction.
    """
    inside_ids = range(n_classes, 2 * n_classes)
    prev_lab = None  # no predecessor before the first element
    for i, curr_lab in enumerate(labels):
        if curr_lab in inside_ids and prev_lab != curr_lab and prev_lab != curr_lab - n_classes:
            curr_lab = curr_lab - n_classes
            labels[i] = curr_lab
        # Track the corrected value so the next I- tag sees its new B- predecessor.
        prev_lab = curr_lab
    return labels

fix_beginnings(e)

[0, 7, 7, 7, 1, 1, 8, 8, 8, 2, 9, 9, 14, 4, 4, 4]

In [14]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of essays and attach one tag id per token.

    The texts are tokenized with truncation to `max_length`, padding, character
    offsets, and overflowing tokens returned with the configured `stride`, so a
    single essay can yield several features. For every feature a `labels` list
    is built:

    * every token starts as 'O';
    * for each annotated span (start, end, class) of the source essay, a token
      whose start offset equals the span start gets 'B-<class>', and a token
      that starts after the span start and ends at or before the span end gets
      'I-<class>' (later spans overwrite earlier ones);
    * tokens whose input id is 0, 1 or 2 get -100;
    * the result is passed through `fix_beginnings`.

    Args:
        examples: Batch mapping with the keys 'text', 'starts', 'ends' and
            'classlist'; the last three hold one list per essay.

    Returns:
        The tokenizer output with an added "labels" entry (one list per feature).

    Raises:
        KeyError: If a class in 'classlist' has no 'B-'/'I-' entry in `l2i`.
    """
    o = tokenizer(
        examples['text'],
        truncation=True,
        padding=True,
        return_offsets_mapping=True,
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
    )

    # One essay may produce several features; this gives, for each feature, the
    # index of the essay (within `examples`) it was cut from.
    sample_mapping = o["overflow_to_sample_mapping"]
    # Per feature, one (char_start, char_end) pair per token. These are compared
    # with the annotated character spans to decide each token's tag.
    offset_mapping = o["offset_mapping"]

    outside_id = l2i['O']
    # NOTE(review): presumably the <s>, <pad> and </s> ids of this tokenizer -
    # confirm against the tokenizer's special-token ids.
    special_ids = (0, 1, 2)

    o["labels"] = []

    for feat_idx, offsets in enumerate(offset_mapping):
        sample_index = sample_mapping[feat_idx]
        input_ids = o['input_ids'][feat_idx]

        # Special tokens always end up as -100, so they are masked up front and
        # left out of span matching. On padded features this avoids scanning the
        # padding once per annotated span.
        labels = [-100 if tok in special_ids else outside_id for tok in input_ids]
        content = [
            (pos, offsets[pos][0], offsets[pos][1])
            for pos, tok in enumerate(input_ids)
            if tok not in special_ids
        ]

        spans = zip(
            examples['starts'][sample_index],
            examples['ends'][sample_index],
            examples['classlist'][sample_index],
        )
        for label_start, label_end, label in spans:
            begin_id = l2i[f'B-{label}']
            inside_id = l2i[f'I-{label}']
            for pos, token_start, token_end in content:
                if token_start == label_start:
                    labels[pos] = begin_id
                elif token_start > label_start and token_end <= label_end:
                    labels[pos] = inside_id

        # Repair I- runs whose B- token was not matched exactly (e.g. at the
        # start of an overflow window).
        o["labels"].append(fix_beginnings(labels))

    return o

# Run tokenization and label alignment over every split in batched mode. The
# original columns are removed, so only what the mapping function returns is kept.
# batch_size=20000 is larger than either split shown above, so each split is
# processed as a single batch.
# NOTE(review): the recorded run below was interrupted by hand (KeyboardInterrupt)
# during the first batch.
tokenized_datasets = datasets.map(tokenize_and_align_labels, batched=True, batch_size=20000, remove_columns=datasets["train"].column_names)

  0%|          | 0/1 [00:24<?, ?ba/s]


KeyboardInterrupt: 