This notebook experiments with a zero-shot, sliding-window question-answering (QA) approach to the Modern Slavery Hackathon classification task: for each Modern Slavery Statement document, classify whether or not the document gives evidence that the company has provided any training about modern slavery to its employees.

The motivation is to use transfer learning from models pre-trained to extract relevant answers (as a span) from a document (context) in order to automate the identification of which small subsets of the documents might be relevant to modern slavery training. These smaller subsets can then make the job of human-labelling additional documents more efficient or be fed into another model which can only handle a limited number of tokens (perhaps a transformer trained for sequence classification).

The idea behind the approach is to use a pretrained QA model (one trained on SQuAD v2, so that it can return a "no span found" result) to ask questions of the documents. Since most documents in the dataset are longer than the model's maximum input length, a sliding-window approach is used: after the entire document is tokenized, the QA model is run on successive windows, each offset from the previous one by stride=128 tokens (about a quarter of the window size). All spans returned by the QA model are recorded in a new dataframe (df_with_segments.parquet). A notebook for visualizing the results of the sliding-window QA approach is available: 'QA results viewer.ipynb'.
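
The windowing step described above can be sketched as follows. This is a minimal illustration, not the code in SlidingWindowTransformersQA.py: the function name `window_bounds` is hypothetical, and the window size of 512 is an assumption inferred from the stride being about a quarter of the window. In practice the question and special tokens also take up part of the model's input budget, which this sketch ignores.

```python
from typing import List, Tuple


def window_bounds(n_tokens: int, window_size: int = 512, stride: int = 128) -> List[Tuple[int, int]]:
    """Return (start, end) token offsets of successive overlapping windows.

    window_size=512 is an assumed value; stride=128 matches the notebook.
    """
    if window_size <= 0 or stride <= 0:
        raise ValueError("window_size and stride must be positive")
    if n_tokens <= 0:
        return []
    bounds = []
    start = 0
    while True:
        # Clip the final window so it never runs past the end of the document.
        end = min(start + window_size, n_tokens)
        bounds.append((start, end))
        if end >= n_tokens:
            break
        start += stride
    return bounds


# A document of 872 tokens is covered by four overlapping windows.
for start, end in window_bounds(872):
    print(start, end)
```

Because consecutive windows overlap by `window_size - stride` tokens, a sentence that straddles one window boundary is seen whole in a neighbouring window.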

Six questions are trialed to see which one(s) provide the best results:
 - 'Is there training provided?'
 - 'Is there training already in place?'
 - 'Has training been done?'
 - 'Is training planned?'
 - 'Is training in development?'
 - 'What kind of training is provided?'

Note: because this is a zero-shot approach, no model is trained here, so the labels are not needed. The labeled (train) and unlabeled (test) data are therefore concatenated into a single input dataframe.

In [1]:
import pandas as pd
import re
from datetime import datetime

# the custom classes used in this notebook are defined in SlidingWindowTransformersQA.py:
from SlidingWindowTransformersQA import SliderModel, SliderDocument 

In [2]:
# import the data, strip away the labels and combine into a single df
df_labeled = pd.read_csv('train (3).csv',index_col=0)
df_hidden = pd.read_csv('test (3).csv',index_col=0)
df_labeled['source'] = 'labeled'
df_hidden['source'] = 'hidden'
df = pd.concat([df_labeled[['source','TEXT']],
                df_hidden[['source','TEXT']]],axis=0).reset_index()

# Collapse any run of 5 or more identical characters down to 4 (newlines are unaffected, since '.' does not match them):
# https://stackoverflow.com/questions/10072744/remove-repeating-characters-from-words
df['TEXT'] = df['TEXT'].apply(lambda x: re.sub(r'(.)\1{4,}', r'\1\1\1\1', str(x)))

df

Unnamed: 0,ID,source,TEXT
0,0,labeled,Modern Slavery Statement\n\nUa\n\n> Responsibi...
1,1,labeled,Burton's Biscuit Company (a trading name of Bu...
2,2,labeled,MODERN SLAVERY ACT STATEMENT\nOUR BUSINESS Zal...
3,3,labeled,MENU\nHOME\nU.K. MODERN SLAVERY ACT STATEMENT\...
4,4,labeled,Modern Slavery Act Statement\nIntroduction fro...
...,...,...,...
976,326,hidden,CECP Advisors LLP Modern Slavery Act Statement...
977,327,hidden,Modern Slavery Act Transparency Statement\n201...
978,328,hidden,MENU\n\n0333 2203 121\nBOOK A ROOM\n\nAnti Sla...
979,329,hidden,"We have placed cookies on your computer, as th..."
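
To show what the repeated-character cleanup in the cell above does, here is a small standalone sketch. The helper name `collapse_repeats` and its `keep` parameter are hypothetical; with `keep=4` it applies the same regular expression as the notebook.

```python
import re


def collapse_repeats(text: str, keep: int = 4) -> str:
    """Shorten any run of more than `keep` identical characters to `keep` copies."""
    # (.) captures one character; \1{keep,} requires at least `keep` more copies of it.
    pattern = r'(.)\1{%d,}' % keep
    return re.sub(pattern, lambda m: m.group(1) * keep, text)


print(collapse_repeats('Contents..........3'))  # long dot leader is shortened
print(collapse_repeats('aaaa'))                 # exactly 4 repeats: left unchanged
```

This mainly trims layout debris such as dot leaders and separator rules, which would otherwise consume tokens without adding meaning.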


In [3]:
# Model chosen based on the SQuAD v2 leaderboards as of December 2020, which favored ALBERT-based models.
# A base-sized model was selected for speed in this proof of concept. An ALBERT XLarge model could be substituted for better
# accuracy at the cost of slower inference.
model_name = 'twmkn9/albert-base-v2-squad2'

#instantiate the slider model:
slider_model = SliderModel(model_name = model_name,
                           max_batch_size = 8,
                           stride = 128)

In [4]:
questions=['Is there training provided?', 
           'Is there training already in place?',
           'Has training been done?',
           'Is training planned?',
           'Is training in development?',
           'What kind of training is provided?'
          ]
slider_model.set_questions(questions)

# Create placeholder columns for each question that will receive the answer-spans identified by the sliding window model:
for question in questions:
    df[question]=[[] for _ in range(len(df))] # each cell starts as an empty list

# Create a dataframe to receive the tokens and token-class-labels as the rows are processed:
df_tokens = pd.DataFrame(columns=['tokens','token classes'], index=df.index)
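
The placeholder columns are built with a list comprehension rather than `[[]] * len(df)`. The plain-Python sketch below (variable names are illustrative only) shows why that matters: list multiplication repeats one list object, so appending a span for one row would appear in every row.

```python
n_rows = 3

# List multiplication repeats a reference to the SAME inner list.
shared = [[]] * n_rows
shared[0].append('span')

# A comprehension creates a NEW inner list for every row.
independent = [[] for _ in range(n_rows)]
independent[0].append('span')

print(shared)       # every "row" sees the appended span
print(independent)  # only row 0 holds the span
```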

In [5]:
# This function will be applied to each document (row) in the dataframe, storing the results (tokens, token-classes, and 
# answer-spans) in the associated dataframes
def process_row(row_id, slider_model):
    #instantiate the slider document for this row:
    slider_document = SliderDocument(slider_model=slider_model,
                                     text=df.loc[row_id,'TEXT'])
    
    # feed the document through the slider model to classify all the tokens:
    slider_document.classify_tokens()
    
    # store the token info in the df_tokens dataframe
    df_tokens.loc[row_id,'tokens'] = slider_document.tokens[0].tolist()
    df_tokens.loc[row_id,'token classes'] = slider_document.token_classes.tolist()
    
    # store the answer-spans in the df dataframe
    filtered_text = slider_document.filtered_text()
    for question_response in filtered_text:
        col_header = question_response['question']
        text_segments = question_response['text segments']
        for text_segment in text_segments:
            df.loc[row_id,col_header].append(text_segment)
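
The internals of `SliderDocument` are not shown in this notebook. As a purely hypothetical illustration of how per-token class labels could be turned into answer spans, the sketch below assumes one list of 0/1 flags per question (one flag per token) and extracts each contiguous run of flagged tokens. The function name `contiguous_spans` and that data layout are assumptions, not the actual implementation.

```python
from typing import List, Sequence, Tuple


def contiguous_spans(token_flags: Sequence[int]) -> List[Tuple[int, int]]:
    """Return (start, end) index pairs (end exclusive) for each run of non-zero flags."""
    spans = []
    start = None
    for i, flag in enumerate(token_flags):
        if flag and start is None:
            start = i                   # a new run begins
        elif not flag and start is not None:
            spans.append((start, i))    # the current run ended at i
            start = None
    if start is not None:
        spans.append((start, len(token_flags)))  # a run reaching the last token
    return spans


# Tokens 1-2 and token 5 were flagged as part of an answer.
print(contiguous_spans([0, 1, 1, 0, 0, 1]))
```

Each (start, end) pair could then be used to slice the token list and decode the slice back into a text segment.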

In [6]:
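# Process all rows, writing both dataframes to disk after every row
# (a checkpoint in case this long run is interrupted).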
start_time=datetime.now()

for i in range(len(df)):
    row_start = datetime.now()
    process_row(row_id = i, slider_model = slider_model)
    df.to_parquet('df_with_segments.parquet')
    df_tokens.to_parquet('df_token_classes.parquet')
    print(f'row {i}: Row time = {datetime.now() - row_start}. Total time elapsed = {datetime.now() - start_time}')

Token indices sequence length is longer than the specified maximum sequence length for this model (872 > 512). Running this sequence through the model will result in indexing errors


row 0: Row time = 0:00:33.787627. Total time elapsed = 0:00:33.788615
row 1: Row time = 0:00:50.032386. Total time elapsed = 0:01:23.822017
row 2: Row time = 0:01:16.255244. Total time elapsed = 0:02:40.077261
row 3: Row time = 0:03:25.265755. Total time elapsed = 0:06:05.344019
row 4: Row time = 0:00:18.911833. Total time elapsed = 0:06:24.255852
row 5: Row time = 0:00:19.308359. Total time elapsed = 0:06:43.565209
row 6: Row time = 0:08:52.633812. Total time elapsed = 0:15:36.199021
row 7: Row time = 0:00:07.170617. Total time elapsed = 0:15:43.369638
row 8: Row time = 0:01:54.195973. Total time elapsed = 0:17:37.565611
row 9: Row time = 0:01:11.147554. Total time elapsed = 0:18:48.713165
row 10: Row time = 0:00:48.697067. Total time elapsed = 0:19:37.410232
row 11: Row time = 0:03:44.544050. Total time elapsed = 0:23:21.954282
row 12: Row time = 0:00:14.460498. Total time elapsed = 0:23:36.414780
row 13: Row time = 0:00:53.746393. Total time elapsed = 0:24:30.161173
row 14: Row time

row 116: Row time = 0:00:03.342288. Total time elapsed = 1:51:28.816284
row 117: Row time = 0:00:27.920168. Total time elapsed = 1:51:56.736452
row 118: Row time = 0:00:30.518221. Total time elapsed = 1:52:27.255639
row 119: Row time = 0:00:39.334996. Total time elapsed = 1:53:06.590635
row 120: Row time = 0:00:24.383040. Total time elapsed = 1:53:30.973675
row 121: Row time = 0:00:22.879761. Total time elapsed = 1:53:53.853436
row 122: Row time = 0:02:20.656675. Total time elapsed = 1:56:14.510111
row 123: Row time = 0:00:29.827295. Total time elapsed = 1:56:44.337406
row 124: Row time = 0:01:02.335293. Total time elapsed = 1:57:46.673697
row 125: Row time = 0:00:22.595291. Total time elapsed = 1:58:09.268988
row 126: Row time = 0:00:58.788209. Total time elapsed = 1:59:08.057197
row 127: Row time = 0:00:46.880522. Total time elapsed = 1:59:54.938678
row 128: Row time = 0:01:08.565568. Total time elapsed = 2:01:03.505207
row 129: Row time = 0:01:00.220274. Total time elapsed = 2:02:03

row 230: Row time = 0:00:29.364529. Total time elapsed = 4:07:28.900769
row 231: Row time = 0:00:43.414239. Total time elapsed = 4:08:12.315008
row 232: Row time = 0:00:05.900259. Total time elapsed = 4:08:18.215267
row 233: Row time = 0:00:15.695618. Total time elapsed = 4:08:33.910885
row 234: Row time = 0:00:01.682062. Total time elapsed = 4:08:35.592947
row 235: Row time = 0:00:40.984192. Total time elapsed = 4:09:16.577139
row 236: Row time = 0:31:32.564318. Total time elapsed = 4:40:49.141457
row 237: Row time = 0:01:56.289706. Total time elapsed = 4:42:45.431163
row 238: Row time = 0:00:54.776360. Total time elapsed = 4:43:40.207523
row 239: Row time = 0:00:30.619262. Total time elapsed = 4:44:10.827807
row 240: Row time = 0:00:37.451220. Total time elapsed = 4:44:48.279027
row 241: Row time = 0:00:05.066275. Total time elapsed = 4:44:53.345302
row 242: Row time = 0:00:26.541329. Total time elapsed = 4:45:19.886631
row 243: Row time = 0:00:54.985259. Total time elapsed = 4:46:14

row 344: Row time = 0:00:54.110879. Total time elapsed = 7:00:29.262747
row 345: Row time = 0:00:24.638285. Total time elapsed = 7:00:53.901032
row 346: Row time = 0:00:43.048772. Total time elapsed = 7:01:36.949804
row 347: Row time = 0:00:50.935165. Total time elapsed = 7:02:27.884969
row 348: Row time = 0:00:31.185105. Total time elapsed = 7:02:59.070074
row 349: Row time = 0:00:28.837956. Total time elapsed = 7:03:27.909032
row 350: Row time = 0:01:14.786427. Total time elapsed = 7:04:42.695459
row 351: Row time = 0:03:21.029727. Total time elapsed = 7:08:03.725186
row 352: Row time = 0:00:29.325019. Total time elapsed = 7:08:33.050205
row 353: Row time = 0:00:59.615852. Total time elapsed = 7:09:32.666057
row 354: Row time = 0:01:08.021738. Total time elapsed = 7:10:40.687795
row 355: Row time = 0:01:02.311775. Total time elapsed = 7:11:42.999570
row 356: Row time = 0:00:15.704419. Total time elapsed = 7:11:58.703989
row 357: Row time = 0:00:49.562303. Total time elapsed = 7:12:48

row 458: Row time = 0:01:04.234637. Total time elapsed = 9:26:55.530442
row 459: Row time = 0:00:47.801236. Total time elapsed = 9:27:43.331678
row 460: Row time = 0:00:32.379165. Total time elapsed = 9:28:15.710843
row 461: Row time = 0:00:41.948662. Total time elapsed = 9:28:57.660508
row 462: Row time = 0:04:10.989128. Total time elapsed = 9:33:08.649636
row 463: Row time = 0:03:32.630357. Total time elapsed = 9:36:41.280992
row 464: Row time = 0:01:34.213603. Total time elapsed = 9:38:15.494595
row 465: Row time = 0:01:11.820624. Total time elapsed = 9:39:27.315219
row 466: Row time = 0:03:32.363198. Total time elapsed = 9:42:59.678417
row 467: Row time = 0:01:25.315284. Total time elapsed = 9:44:24.993701
row 468: Row time = 0:00:49.117912. Total time elapsed = 9:45:14.111613
row 469: Row time = 0:01:09.133983. Total time elapsed = 9:46:23.245596
row 470: Row time = 0:00:24.946252. Total time elapsed = 9:46:48.191848
row 471: Row time = 0:00:30.884617. Total time elapsed = 9:47:19

row 571: Row time = 0:01:06.071196. Total time elapsed = 11:49:32.190206
row 572: Row time = 0:01:13.756329. Total time elapsed = 11:50:45.946535
row 573: Row time = 0:00:40.935431. Total time elapsed = 11:51:26.881966
row 574: Row time = 0:00:07.680317. Total time elapsed = 11:51:34.562283
row 575: Row time = 0:01:38.341853. Total time elapsed = 11:53:12.904136
row 576: Row time = 0:01:04.090599. Total time elapsed = 11:54:16.994735
row 577: Row time = 0:01:52.050902. Total time elapsed = 11:56:09.046610
row 578: Row time = 0:00:31.969801. Total time elapsed = 11:56:41.016411
row 579: Row time = 0:03:39.949157. Total time elapsed = 12:00:20.965568
row 580: Row time = 0:06:41.096755. Total time elapsed = 12:07:02.062323
row 581: Row time = 0:00:16.407177. Total time elapsed = 12:07:18.469500
row 582: Row time = 0:01:58.175624. Total time elapsed = 12:09:16.645124
row 583: Row time = 0:00:03.253691. Total time elapsed = 12:09:19.898815
row 584: Row time = 0:00:31.643220. Total time elap

row 684: Row time = 0:00:38.527693. Total time elapsed = 13:59:37.964878
row 685: Row time = 0:00:59.125149. Total time elapsed = 14:00:37.090027
row 686: Row time = 0:00:02.233246. Total time elapsed = 14:00:39.323273
row 687: Row time = 0:01:59.422673. Total time elapsed = 14:02:38.745946
row 688: Row time = 0:00:06.751321. Total time elapsed = 14:02:45.498269
row 689: Row time = 0:01:06.533817. Total time elapsed = 14:03:52.032086
row 690: Row time = 0:00:52.418348. Total time elapsed = 14:04:44.450434
row 691: Row time = 0:00:04.523232. Total time elapsed = 14:04:48.973666
row 692: Row time = 0:01:35.161315. Total time elapsed = 14:06:24.134981
row 693: Row time = 0:00:45.839423. Total time elapsed = 14:07:09.975403
row 694: Row time = 0:10:35.822680. Total time elapsed = 14:17:45.798083
row 695: Row time = 0:01:01.456459. Total time elapsed = 14:18:47.254542
row 696: Row time = 0:00:23.508289. Total time elapsed = 14:19:10.762831
row 697: Row time = 0:01:00.005933. Total time elap

row 797: Row time = 0:01:23.887409. Total time elapsed = 17:05:22.525898
row 798: Row time = 0:00:38.109839. Total time elapsed = 17:06:00.635737
row 799: Row time = 0:00:54.151735. Total time elapsed = 17:06:54.787472
row 800: Row time = 0:00:05.382786. Total time elapsed = 17:07:00.170258
row 801: Row time = 0:00:52.723378. Total time elapsed = 17:07:52.894637
row 802: Row time = 0:01:01.067137. Total time elapsed = 17:08:53.961774
row 803: Row time = 0:01:36.531625. Total time elapsed = 17:10:30.494401
row 804: Row time = 0:00:39.914972. Total time elapsed = 17:11:10.409373
row 805: Row time = 0:00:23.313855. Total time elapsed = 17:11:33.723228
row 806: Row time = 0:00:03.935455. Total time elapsed = 17:11:37.658683
row 807: Row time = 0:00:23.382317. Total time elapsed = 17:12:01.041000
row 808: Row time = 0:01:45.260101. Total time elapsed = 17:13:46.301101
row 809: Row time = 0:00:23.448600. Total time elapsed = 17:14:09.749701
row 810: Row time = 0:00:30.222795. Total time elap

row 910: Row time = 0:00:56.765522. Total time elapsed = 18:44:14.614405
row 911: Row time = 0:01:07.639037. Total time elapsed = 18:45:22.253442
row 912: Row time = 0:00:14.302737. Total time elapsed = 18:45:36.557216
row 913: Row time = 0:00:25.816732. Total time elapsed = 18:46:02.373948
row 914: Row time = 0:00:53.560627. Total time elapsed = 18:46:55.934575
row 915: Row time = 0:00:53.777387. Total time elapsed = 18:47:49.711962
row 916: Row time = 0:00:59.026414. Total time elapsed = 18:48:48.738376
row 917: Row time = 0:00:26.101316. Total time elapsed = 18:49:14.839692
row 918: Row time = 0:05:48.405764. Total time elapsed = 18:55:03.245456
row 919: Row time = 0:00:16.406285. Total time elapsed = 18:55:19.651741
row 920: Row time = 0:00:23.093905. Total time elapsed = 18:55:42.745646
row 921: Row time = 0:00:58.326573. Total time elapsed = 18:56:41.073218
row 922: Row time = 0:00:02.220695. Total time elapsed = 18:56:43.293913
row 923: Row time = 0:01:58.079495. Total time elap
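
Per-row times in the log above vary widely (row 236 alone took over 31 minutes). A small parser like the following sketch could help pick out the slowest documents from saved log text. The names `parse_row_time` and `slow_rows` and the 5-minute threshold are hypothetical; the pattern assumes the log format printed by the loop above and skips truncated lines.

```python
import re
from datetime import timedelta

# Matches e.g. "row 236: Row time = 0:31:32.564318. Total time elapsed = ..."
LOG_PATTERN = re.compile(r'row (\d+): Row time = (\d+):(\d{2}):(\d{2})(?:\.(\d+))?')


def parse_row_time(line):
    """Return (row_number, row_time) for a log line, or None if it does not parse."""
    match = LOG_PATTERN.match(line)
    if match is None:
        return None
    row, hours, minutes, seconds, fraction = match.groups()
    # Pad or trim the fractional part to microseconds (6 digits).
    micros = int((fraction or '0').ljust(6, '0')[:6])
    return int(row), timedelta(hours=int(hours), minutes=int(minutes),
                               seconds=int(seconds), microseconds=micros)


def slow_rows(lines, threshold=timedelta(minutes=5)):
    """Return the (row_number, row_time) pairs whose row time exceeds the threshold."""
    parsed = (parse_row_time(line) for line in lines)
    return [item for item in parsed if item is not None and item[1] > threshold]


sample = [
    'row 235: Row time = 0:00:40.984192. Total time elapsed = 4:09:16.577139',
    'row 236: Row time = 0:31:32.564318. Total time elapsed = 4:40:49.141457',
    'row 14: Row time',  # truncated line: ignored
]
print(slow_rows(sample))
```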

In [7]:
# The results!
pd.concat([df,df_tokens],axis=1)

Unnamed: 0,ID,source,TEXT,Is there training provided?,Is there training already in place?,Has training been done?,Is training planned?,Is training in development?,What kind of training is provided?,tokens,token classes
0,0,labeled,Modern Slavery Statement\n\nUa\n\n> Responsibi...,[we will also aim to develop training for our ...,[],[we will also aim to develop training for our ...,[we will also aim to develop training for our ...,[we will also aim to develop training for our ...,[],"[773, 9822, 3331, 13, 3786, 13, 1, 4024, 13, 1...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
1,1,labeled,Burton's Biscuit Company (a trading name of Bu...,[the board carries out a strategic risk assess...,[training relevant to other members of staff i...,[training relevant to other members of staff i...,[brc global standards (a manufacturing certifi...,"[no discrimination is practised., training rel...",[],"[9759, 22, 18, 20947, 237, 13, 5, 58, 5205, 20...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
2,2,labeled,MODERN SLAVERY ACT STATEMENT\nOUR BUSINESS Zal...,[all factories must provide a recent audit don...,[training the code of conduct is part of our <...,[training the code of conduct is part of our <...,"[all work shall be voluntary,, the training is...",[training the code of conduct is part of our <...,[<unk>compliance basics<unk> training],"[773, 9822, 601, 3331, 318, 508, 13, 10662, 13...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
3,3,labeled,MENU\nHOME\nU.K. MODERN SLAVERY ACT STATEMENT\...,"[no forced labor prison, indentured, bonded, i...",[],"[in 2017, we also conducted purchasing practic...","[no forced labor prison, indentured, bonded, i...",[any associate who contracts a factory that us...,[],"[11379, 213, 287, 9, 197, 9, 773, 9822, 601, 3...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
4,4,labeled,Modern Slavery Act Statement\nIntroduction fro...,[we pay all employees at least the national li...,[],[we provide training for our staff.],[we provide training for our staff.],[we provide training for our staff.],[],"[773, 9822, 601, 3331, 3445, 37, 14, 903, 1452...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
...,...,...,...,...,...,...,...,...,...,...,...
976,326,hidden,CECP Advisors LLP Modern Slavery Act Statement...,[training with respect to the msa policy will ...,[],[training with respect to the msa policy will ...,[training with respect to the msa policy will ...,"[training and communication as necessary,]",[training with respect to the msa policy will],"[23943, 306, 6721, 18, 13, 211, 306, 773, 9822...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
977,327,hidden,Modern Slavery Act Transparency Statement\n201...,[we firmly believe the work we have started re...,"[training training on ethical buying, social c...",[not be subjected to harsh or inhumane treatme...,[workers should not be subjected to harsh or i...,[not be subjected to harsh or inhumane treatme...,"[ethical buying, social compliance and factory...","[773, 9822, 601, 19668, 3331, 690, 8, 1053, 84...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
978,328,hidden,MENU\n\n0333 2203 121\nBOOK A ROOM\n\nAnti Sla...,[management at all levels are responsible for ...,[],[management at all levels are responsible for ...,[any person working for our business or as a s...,[],[adequate and regular training],"[11379, 713, 20165, 1024, 3601, 13, 12586, 360...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
979,329,hidden,"We have placed cookies on your computer, as th...",[greater internal training will be given to re...,[],[greater internal training will be given to re...,[greater internal training will be given to re...,[greater internal training will be given to re...,[greater internal training will],"[95, 57, 1037, 19396, 27, 154, 1428, 15, 28, 5...","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


All documents have now been processed through the sliding-window model, with the results stored in two parquet files (df_with_segments.parquet and df_token_classes.parquet). I have visualized the results in a separate notebook, 'QA results viewer.ipynb'. Feel free to hop over there to view them.