In [2]:
import os
import functools

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

In [3]:
DATA_PATH = "./data/"

class Config:
    TRANSFORMER_CHECKPOINT = "allenai/longformer-base-4096"
    BATCH_SIZE = 4
    MAX_LENGTH = 4096
    STRIDE = 64
    NUM_FOLDS = 5
    RANDOM_STATE = 42
    NUM_WORKERS = 2

In [4]:
df_train = pd.read_csv(DATA_PATH + "train.csv")
df_train["predictionstring"] = df_train.predictionstring.apply(lambda str: [int(item) for item in str.split()])
df_train["discoursetype"] = df_train.loc[:, "discourse_type"]
df_train.head()

Unnamed: 0,id,discourse_id,discourse_start,discourse_end,discourse_text,discourse_type,discourse_type_num,predictionstring,discoursetype
0,423A1CA112E2,1622628000000.0,8.0,229.0,Modern humans today are always on their phone....,Lead,Lead 1,"[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14...",Lead
1,423A1CA112E2,1622628000000.0,230.0,312.0,They are some really bad consequences when stu...,Position,Position 1,"[45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 5...",Position
2,423A1CA112E2,1622628000000.0,313.0,401.0,Some certain areas in the United States ban ph...,Evidence,Evidence 1,"[60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 7...",Evidence
3,423A1CA112E2,1622628000000.0,402.0,758.0,"When people have phones, they know about certa...",Evidence,Evidence 2,"[76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 8...",Evidence
4,423A1CA112E2,1622628000000.0,759.0,886.0,Driving is one of the way how to get around. P...,Claim,Claim 1,"[139, 140, 141, 142, 143, 144, 145, 146, 147, ...",Claim


In [5]:
df_train_onehot = pd.get_dummies(df_train, columns=["discoursetype"])
df_train_onehot = df_train_onehot.groupby(["id"], as_index=False).sum()
label_cols = [c for c in df_train_onehot.columns if c.startswith("discoursetype_") or c == "id"]
df_train_onehot = df_train_onehot[label_cols]
df_train_onehot.head()

Unnamed: 0,id,discoursetype_Claim,discoursetype_Concluding Statement,discoursetype_Counterclaim,discoursetype_Evidence,discoursetype_Lead,discoursetype_Position,discoursetype_Rebuttal
0,0000D23A521A,1,1,1,3,0,1,1
1,00066EA9880D,3,1,0,3,1,1,0
2,000E6DE9E817,5,1,1,3,0,1,1
3,001552828BD0,4,0,0,4,1,1,0
4,0016926B079C,7,0,0,3,0,1,0


In [6]:
def create_multilabel_targets(data_row, label_cols):
    targets = []
    for col in label_cols:
        targets.append(data_row[col])
    return targets

In [7]:
# For each essay, there can be multiple discourse_types, the target which is discourse type is thus multilabel
# For each essay this multilabel target column needs to be created first 

if "id" in label_cols:
    label_cols.remove("id")
df_train_onehot["targets"] = df_train_onehot.apply(lambda row: create_multilabel_targets(row, label_cols), axis=1)
df_train_onehot["targets_str"] = df_train_onehot.targets.apply(lambda x: ",".join([str(item) for item in x]))
df_train_onehot["kfold"] = -1
df_train_onehot.head()

Unnamed: 0,id,discoursetype_Claim,discoursetype_Concluding Statement,discoursetype_Counterclaim,discoursetype_Evidence,discoursetype_Lead,discoursetype_Position,discoursetype_Rebuttal,targets,targets_str,kfold
0,0000D23A521A,1,1,1,3,0,1,1,"[1, 1, 1, 3, 0, 1, 1]",1113011,-1
1,00066EA9880D,3,1,0,3,1,1,0,"[3, 1, 0, 3, 1, 1, 0]",3103110,-1
2,000E6DE9E817,5,1,1,3,0,1,1,"[5, 1, 1, 3, 0, 1, 1]",5113011,-1
3,001552828BD0,4,0,0,4,1,1,0,"[4, 0, 0, 4, 1, 1, 0]",4004110,-1
4,0016926B079C,7,0,0,3,0,1,0,"[7, 0, 0, 3, 0, 1, 0]",7003010,-1


In [8]:
# we need to split the train data into k folds using multilabel stratification
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

# This method uses the iterstrat library for multilabel stratification
def iterstrat_multilabel_stratified_kfold_cv_split(df_train_onehot):
    mskf = MultilabelStratifiedKFold(n_splits=Config.NUM_FOLDS, shuffle=True, random_state=Config.RANDOM_STATE)    
    df_targets = df_train_onehot[label_cols]
    for fold, (train_index, val_index) in enumerate(mskf.split(df_train_onehot["id"], df_targets)):        
        df_train_onehot.loc[val_index, "kfold"] = fold
    return df_train_onehot

In [9]:
from skmultilearn.model_selection import IterativeStratification

# This method uses the skmultilearn library for multilabel stratification
def skml_multilabel_stratified_kfold_cv_split(df_train_onehot):
    mskf = IterativeStratification(n_splits=Config.NUM_FOLDS, order=1)
    X = df_train_onehot["id"]
    y = df_train_onehot[label_cols]
    for fold, (train_index, val_index) in enumerate(mskf.split(X, y)):        
        df_train_onehot.loc[val_index, "kfold"] = fold
    return df_train_onehot

In [10]:
df_train_onehot = skml_multilabel_stratified_kfold_cv_split(df_train_onehot)
df_train_onehot.kfold.value_counts()

0    3122
3    3120
2    3118
1    3117
4    3117
Name: kfold, dtype: int64

In [11]:
from skmultilearn.model_selection.measures import get_combination_wise_output_matrix
from collections import Counter

def get_train_val_split_stats(df):
    counts = {}
    for fold in range(Config.NUM_FOLDS):
        y_train = df[df.kfold != fold][label_cols].values
        y_val = df[df.kfold == fold][label_cols].values
        counts[(fold, "train_count")] = Counter(
                                        str(combination) for row in get_combination_wise_output_matrix(y_train, order=1) 
                                        for combination in row
                                    )
        counts[(fold, "val_count")] = Counter(
                                        str(combination) for row in get_combination_wise_output_matrix(y_val, order=1) 
                                        for combination in row
                                    )
    # View distributions
    df_counts = pd.DataFrame(counts).T.fillna(0)
    df_counts.index.set_names(["fold", "counts"], inplace=True)
    for fold in range(Config.NUM_FOLDS):
        train_counts = df_counts.loc[(fold, "train_count"), :]
        val_counts = df_counts.loc[(fold, "val_count"), :]
        val_train_ratio = pd.Series({i: val_counts[i] / train_counts[i] for i in train_counts.index}, name=(fold, "val_train_ratio"))
        df_counts = df_counts.append(val_train_ratio)
    df_counts = df_counts.sort_index() 
    return df_counts

In [12]:
df_stats = get_train_val_split_stats(df_train_onehot)
df_stats

Unnamed: 0_level_0,Unnamed: 1_level_0,"(6,)","(2,)","(5,)","(1,)","(0,)","(3,)","(4,)"
fold,counts,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
0,train_count,2895.0,3674.0,12284.0,10735.0,11941.0,12440.0,7437.0
0,val_count,703.0,902.0,3082.0,2683.0,2986.0,3110.0,1864.0
0,val_train_ratio,0.242832,0.245509,0.250895,0.24993,0.250063,0.25,0.250639
1,train_count,2878.0,3674.0,12296.0,10734.0,11942.0,12440.0,7437.0
1,val_count,720.0,902.0,3070.0,2684.0,2985.0,3110.0,1864.0
1,val_train_ratio,0.250174,0.245509,0.249675,0.250047,0.249958,0.25,0.250639
2,train_count,2870.0,3654.0,12294.0,10734.0,11942.0,12440.0,7450.0
2,val_count,728.0,922.0,3072.0,2684.0,2985.0,3110.0,1851.0
2,val_train_ratio,0.253659,0.252326,0.249878,0.250047,0.249958,0.25,0.248456
3,train_count,2898.0,3672.0,12281.0,10735.0,11941.0,12440.0,7455.0


In [13]:
from collections import defaultdict

ner_labels = df_train.discourse_type.unique().tolist()
labels = defaultdict()

for index, lbl in enumerate(ner_labels):
    labels[f"B-{lbl}"] = index
    labels[f"I-{lbl}"] = index + len(ner_labels)

labels[f"O"] = 2 * len(ner_labels)
labels[f"Special"] = -100

ids_to_labels = {value: key for key, value in enumerate(labels)}

In [14]:
def read_text(file_name):
    with open(DATA_PATH + "train/" + file_name + ".txt", "r") as file:
        text = file.read()
        return text

In [15]:
df_train_grouped = df_train.groupby(["id"])
essay_id = pd.Series([*df_train_grouped.groups.keys()])
text = essay_id.apply(lambda x: read_text(x))
df_text = pd.concat([essay_id, text], axis=1, keys=["id", "text"])
df_text["text_length"] = df_text.text.apply(lambda text: len(text.split()))
df_ner_labelslist = df_train_grouped["discourse_type"].apply(lambda x:list(x.sort_values())).reset_index(name="ner_labelslist")
df_discourse_start = df_train_grouped["discourse_start"].apply(list).reset_index(name="discourse_start")
df_discourse_end = df_train_grouped["discourse_end"].apply(list).reset_index(name="discourse_end")
df_predictionsstring = df_train_grouped["predictionstring"].apply(list).reset_index(name="predictionstring")
df_train_onehot = df_train_onehot[["id", "targets", "targets_str", "kfold"]]
df_list = [df_train_onehot, df_ner_labelslist, df_discourse_start, df_discourse_end, df_predictionsstring, df_text]
df_train_merged = functools.reduce(lambda df1, df2: pd.merge(left=df1, right=df2, on=["id"], how="inner"), df_list)
df_train_merged.head()

Unnamed: 0,id,targets,targets_str,kfold,ner_labelslist,discourse_start,discourse_end,predictionstring,text,text_length
0,0000D23A521A,"[1, 1, 1, 3, 0, 1, 1]",1113011,3,"[Claim, Concluding Statement, Counterclaim, Ev...","[0.0, 170.0, 358.0, 438.0, 627.0, 722.0, 836.0...","[170.0, 357.0, 438.0, 626.0, 722.0, 836.0, 101...","[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...","Some people belive that the so called ""face"" o...",251
1,00066EA9880D,"[3, 1, 0, 3, 1, 1, 0]",3103110,0,"[Claim, Claim, Claim, Concluding Statement, Ev...","[0.0, 456.0, 638.0, 738.0, 1399.0, 1488.0, 231...","[455.0, 592.0, 738.0, 1398.0, 1487.0, 2219.0, ...","[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...",Driverless cars are exaclty what you would exp...,646
2,000E6DE9E817,"[5, 1, 1, 3, 0, 1, 1]",5113011,0,"[Claim, Claim, Claim, Claim, Claim, Concluding...","[17.0, 64.0, 158.0, 310.0, 438.0, 551.0, 776.0...","[56.0, 157.0, 309.0, 422.0, 551.0, 775.0, 961....","[[2, 3, 4, 5, 6, 7, 8], [10, 11, 12, 13, 14, 1...",Dear: Principal\n\nI am arguing against the po...,274
3,001552828BD0,"[4, 0, 0, 4, 1, 1, 0]",4004110,2,"[Claim, Claim, Claim, Claim, Evidence, Evidenc...","[0.0, 161.0, 872.0, 958.0, 1191.0, 1542.0, 161...","[160.0, 872.0, 957.0, 1190.0, 1541.0, 1612.0, ...","[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...",Would you be able to give your car up? Having ...,512
4,0016926B079C,"[7, 0, 0, 3, 0, 1, 0]",7003010,4,"[Claim, Claim, Claim, Claim, Claim, Claim, Cla...","[0.0, 58.0, 94.0, 206.0, 236.0, 272.0, 542.0, ...","[57.0, 91.0, 150.0, 235.0, 271.0, 542.0, 650.0...","[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, ...",I think that students would benefit from learn...,261


In [16]:
def label_words(row):    
    words = row["text"].split()
    word_labels = ["O" for word in words]
    word_label_ids = [labels["O"] for word in words]
    for idx, label in enumerate(row["ner_labelslist"]):
        word_idx = row["predictionstring"][idx]        
        # As per the NER IOB tagging scheme
        # The starting word of the discourse has label B-
        word_labels[word_idx[0]] = f"B-{label}"
        word_label_ids[word_idx[0]] = labels[f"B-{label}"]
        # All other words of the discourse have label I-
        for index in word_idx[1:]:
            word_labels[index] = f"I-{label}"
            word_label_ids[index] = labels[f"I-{label}"]        
    row["word_labels"] = word_labels
    row["word_label_ids"] = word_label_ids
    return row

In [17]:
df_train_final = df_train_merged.apply(lambda row: label_words(row), axis=1)
df_train_final.head()

Unnamed: 0,id,targets,targets_str,kfold,ner_labelslist,discourse_start,discourse_end,predictionstring,text,text_length,word_labels,word_label_ids
0,0000D23A521A,"[1, 1, 1, 3, 0, 1, 1]",1113011,3,"[Claim, Concluding Statement, Counterclaim, Ev...","[0.0, 170.0, 358.0, 438.0, 627.0, 722.0, 836.0...","[170.0, 357.0, 438.0, 626.0, 722.0, 836.0, 101...","[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...","Some people belive that the so called ""face"" o...",251,"[B-Claim, I-Claim, I-Claim, I-Claim, I-Claim, ...","[3, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10..."
1,00066EA9880D,"[3, 1, 0, 3, 1, 1, 0]",3103110,0,"[Claim, Claim, Claim, Concluding Statement, Ev...","[0.0, 456.0, 638.0, 738.0, 1399.0, 1488.0, 231...","[455.0, 592.0, 738.0, 1398.0, 1487.0, 2219.0, ...","[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...",Driverless cars are exaclty what you would exp...,646,"[B-Claim, I-Claim, I-Claim, I-Claim, I-Claim, ...","[3, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10..."
2,000E6DE9E817,"[5, 1, 1, 3, 0, 1, 1]",5113011,0,"[Claim, Claim, Claim, Claim, Claim, Concluding...","[17.0, 64.0, 158.0, 310.0, 438.0, 551.0, 776.0...","[56.0, 157.0, 309.0, 422.0, 551.0, 775.0, 961....","[[2, 3, 4, 5, 6, 7, 8], [10, 11, 12, 13, 14, 1...",Dear: Principal\n\nI am arguing against the po...,274,"[O, O, B-Claim, I-Claim, I-Claim, I-Claim, I-C...","[14, 14, 3, 10, 10, 10, 10, 10, 10, 14, 3, 10,..."
3,001552828BD0,"[4, 0, 0, 4, 1, 1, 0]",4004110,2,"[Claim, Claim, Claim, Claim, Evidence, Evidenc...","[0.0, 161.0, 872.0, 958.0, 1191.0, 1542.0, 161...","[160.0, 872.0, 957.0, 1190.0, 1541.0, 1612.0, ...","[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13...",Would you be able to give your car up? Having ...,512,"[B-Claim, I-Claim, I-Claim, I-Claim, I-Claim, ...","[3, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10..."
4,0016926B079C,"[7, 0, 0, 3, 0, 1, 0]",7003010,4,"[Claim, Claim, Claim, Claim, Claim, Claim, Cla...","[0.0, 58.0, 94.0, 206.0, 236.0, 272.0, 542.0, ...","[57.0, 91.0, 150.0, 235.0, 271.0, 542.0, 650.0...","[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [10, 11, 12, ...",I think that students would benefit from learn...,261,"[B-Claim, I-Claim, I-Claim, I-Claim, I-Claim, ...","[3, 10, 10, 10, 10, 10, 10, 10, 10, 10, 3, 10,..."


In [32]:
from transformers import LongformerTokenizerFast, DataCollatorWithPadding

tokenizer = LongformerTokenizerFast.from_pretrained(DATA_PATH + "longformer/", local_files_only=True, add_prefix_space=True)
# DataCollatorWithPadding pads each batch to the longest sequence length
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [22]:
def tokenize_and_label_with_offsets(tokenizer, row):    
    text = row["text"]    
    result = tokenizer(
        text, 
        max_length=Config.MAX_LENGTH,
        padding=False, 
        truncation=True,
        return_offsets_mapping=True, 
        return_overflowing_tokens=True,
        stride=Config.STRIDE
    )    
    # If you are doing batch encoding, one sentence with length exceeding > max_length gets split into
    # multiple rows, "overflow_to_sample_mapping" which is an array with length = number of splits tells you
    # which split came from which sentence of the batch. So if you batch has two sentences (0, 1) with sentence 0
    # having two splits and sentence 1 having three splits the result["overflow_to_sample_mapping"] = [0, 0, 1, 1, 1]    
    token_label_ids = []
    token_labels = []
    token_words = []
    # the word_id in the text that a token belongs to. 
    token_word_ids = []
    for i in result["overflow_to_sample_mapping"]:        
        # There are as many labels as there are tokens. For each token set the label to a default value
        row_token_label_ids = [labels["O"] for j in range(len(result["input_ids"][i]))]         
        row_token_labels = ["O" for j in range(len(result["input_ids"][i]))]        
        subtext_token_words = ["" for j in range(len(result["input_ids"][i]))]
        row_ner_labels = row["ner_labelslist"][i]        
        discourse_start = row["discourse_start"][i]
        discourse_end = row["discourse_end"][i]
        token_word_ids.append(result.word_ids(batch_index=i))
        # loop thru the tokens
        for j in range(len(result["input_ids"][i])):
            input_id = result["input_ids"][i][j]
            # Set the label of special tokens 'CLS' and 'SEP' to -100
            if input_id in [0, 2]:
                row_token_label_ids[j] = -100
                row_token_labels[j] = "Special"
                continue
            token_start, token_end = result["offset_mapping"][i][j]
            subtext_token_words[j] = text[i][token_start:token_end]
            for ner_label, label_start, label_end in list(zip(row_ner_labels, discourse_start, discourse_end)):
                if token_start == label_start and token_end > token_start:
                    row_token_label_ids[j] = labels[f"B-{ner_label}"]
                    row_token_labels[j] = f"B-{ner_label}"
                elif token_start > label_start and token_end <= label_end:
                    row_token_label_ids[j] = labels[f"I-{ner_label}"]
                    row_token_labels[j] = f"I-{ner_label}"
        token_labels.append(row_token_labels)
        token_label_ids.append(row_token_label_ids)
        token_words.append(subtext_token_words)
    result["token_label_ids"] = token_label_ids    
    result["token_labels"] = token_labels
    result["token_words"] = token_words
    result["token_word_ids"] = token_word_ids
    return result    

In [62]:
def tokenize_and_label(tokenizer, data_row):    
    # convert the text to word tokens splitting on " "
    text_words = [item.split() for item in data_row["text"]]
    encoding = tokenizer(
        text_words, 
        is_split_into_words=True,
        max_length=Config.MAX_LENGTH,
        padding=False, 
        truncation=True,
        return_offsets_mapping=True, 
        return_overflowing_tokens=True,
        stride=Config.STRIDE
    )       
    token_label_ids = [] 
    token_labels = []
    tokens = []
    token_words = []
    # A text(essay) may get split into multiple sub texts if text length > max_length. 
    # For e.g. text_id=0 => sub_text_id=[0], text_id=1 => sub_text_ids = [1, 2, 3], text_id=2 => sub_text_ids=[4,5] 
    for sub_text_id, (token_ids, text_id) in enumerate(zip(encoding["input_ids"], encoding["overflow_to_sample_mapping"])):
        # There are as many labels as there are tokens. For each token set the label to a default value
        sub_token_label_ids = [labels["O"] for j in range(len(token_ids))]         
        sub_token_labels = ["O" for j in range(len(token_ids))]   
        sub_tokens = tokenizer.convert_ids_to_tokens(token_ids)     
        words = text_words[text_id]
        word_ids = encoding.word_ids(batch_index=sub_text_id)        
        sub_text_words = []
        for word_id in word_ids:
            if word_id is None:
                sub_text_words.append(None)
            else:
                sub_text_words.append(words[word_id])                
        for token_idx, word_idx in enumerate(word_ids):
            # Set the label of special tokens 'CLS' and 'SEP' to -100
            if word_idx is None:
                sub_token_label_ids[token_idx] = -100
                sub_token_labels[token_idx] = "Special"
            else:                            
                sub_token_label_ids[token_idx] = data_row["word_label_ids"][text_id][word_idx]
                sub_token_labels[token_idx] = data_row["word_labels"][text_id][word_idx]            
        token_labels.append(sub_token_labels)            
        token_label_ids.append(sub_token_label_ids)
        tokens.append(sub_tokens)
        token_words.append(sub_text_words)
    encoding["token_labels"] = token_labels
    encoding["token_label_ids"] = token_label_ids
    encoding["tokens"] = tokens
    encoding["token_words"] = token_words
    return encoding
        

In [30]:
test_data = df_train_final[df_train_final.id.isin(["0000D23A521A", "00066EA9880D"])]
text_words = [item.split() for item in test_data["text"].values]

In [47]:
#test_data["word_label_ids"][0]

In [48]:
#test_result = tokenize_and_label(tokenizer, test_data)

In [63]:
from functools import partial

preprocess_train_data = partial(tokenize_and_label, tokenizer)

In [64]:
from torch.utils.data import Dataset, DataLoader
from datasets import Dataset

def get_fold_dls(fold, df):
    train_df = df[df.kfold != fold].reset_index(drop=True)
    valid_df = df[df.kfold == fold].reset_index(drop=True)
    ds_train_raw = Dataset.from_pandas(train_df)
    ds_valid_raw = Dataset.from_pandas(valid_df)
    raw_ds_col_names = ds_train_raw.column_names    
    ds_train = ds_train_raw.map(preprocess_train_data, batched=True, batch_size=1000, remove_columns=raw_ds_col_names)
    ds_valid = ds_valid_raw.map(preprocess_train_data, batched=True, batch_size=1000, remove_columns=raw_ds_col_names)
    dl_train = DataLoader(ds_train, batch_size=Config.BATCH_SIZE, shuffle=True, collate_fn=data_collator, num_workers=Config.NUM_WORKERS)
    dl_valid = DataLoader(ds_valid, batch_size=Config.BATCH_SIZE, collate_fn=data_collator, num_workers=Config.NUM_WORKERS)
    return dl_train, dl_valid, ds_train, ds_valid

In [65]:
dl_train, dl_valid, ds_train, ds_valid = get_fold_dls(0, df_train_final)

  0%|          | 0/13 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

In [20]:
from transformers import AutoModelForTokenClassification, LongformerForTokenClassification
model = AutoModelForTokenClassification.from_pretrained(Config.TRANSFORMER_CHECKPOINT)

Downloading:   0%|          | 0.00/694 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570M [00:00<?, ?B/s]

Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerForTokenClassification: ['lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing LongformerForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerForTokenClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN

In [None]:
model.config.id2label