In [16]:
from farm.data_handler.data_silo import DataSilo
from farm.data_handler.processor import TextClassificationProcessor
from farm.modeling.optimization import initialize_optimizer
from farm.infer import Inferencer
from farm.modeling.adaptive_model import AdaptiveModel
from farm.modeling.language_model import LanguageModel
from farm.modeling.prediction_head import TextClassificationHead
from farm.modeling.tokenization import Tokenizer
from farm.train import Trainer
from farm.utils import set_all_seeds, MLFlowLogger, initialize_device_settings
from transformers import AutoTokenizer, AutoModel
import pandas as pd

In [17]:
# Training hyperparameters: plain constants that do not depend on the calls below.
n_epochs = 1
batch_size = 32 # larger batch sizes might use too much computing power in Colab
evaluate_every = 100

# Seed the random number generators, then obtain the device and GPU count.
set_all_seeds(seed=42)
device, n_gpu = initialize_device_settings(use_cuda=True)

10/04/2020 16:54:02 - INFO - farm.utils -   device: cpu n_gpu: 0, distributed training: False, automatic mixed precision training: None


In [23]:
# Hugging Face model id of the pretrained checkpoint. The tokenizer below is
# loaded from this same variable so the model id is defined in one place only.
lang_model = "allenai/longformer-base-4096"
# NOTE(review): this flag is not passed to the tokenizer in this cell —
# confirm whether any other cell still reads it before removing it.
do_lower_case = True

tokenizer = AutoTokenizer.from_pretrained(lang_model)

In [24]:
# Topic taxonomy: class index (0-20) -> human-readable topic name.
# The position of each name in the list below is its class index.
topics_index_to_name_map = dict(enumerate([
    'Agriculture, animals, food and rural affairs',
    'Asylum, immigration and nationality',
    'Business, industry and consumers',
    'Communities and families',
    'Crime, civil law, justice and rights',
    'Culture, media and sport',
    'Defence',
    'Economy and finance',
    'Education',
    'Employment and training',
    'Energy and environment',
    'European Union',
    'Health services and medicine',
    'Housing and planning',
    'International affairs',
    'Parliament, government and politics',
    'Science and technology',
    'Social security and pensions',
    'Social services',
    'Transport',
    'Others',
]))
# Reverse lookup: topic name -> class index.
topics_name_to_index_map = {
    name: index for index, name in topics_index_to_name_map.items()
}

def strip_short2(text):
    """Call `strip_short` on `text` with `minsize=4` and return its result.

    NOTE(review): `strip_short` is not imported in the cells visible here;
    presumably it is gensim.parsing.preprocessing.strip_short — confirm.
    """
    min_token_length = 4
    return strip_short(text, minsize=min_token_length)


def preprocess_text(text):
    """Pass `text` and a fixed list of filters to `preprocess_string`.

    The filter list is, in this order: lower-casing, strip_multiple_whitespaces,
    strip_tags, strip_punctuation, strip_non_alphanum, strip_numeric and
    strip_short2.

    NOTE(review): `preprocess_string` and the `strip_*` helpers are not imported
    in the cells visible here; presumably they come from
    gensim.parsing.preprocessing — confirm.
    """
    def to_lower(value):
        return value.lower()

    filter_chain = [
        to_lower,
        strip_multiple_whitespaces,
        strip_tags,
        strip_punctuation,
        strip_non_alphanum,
        strip_numeric,
        strip_short2,
    ]
    return preprocess_string(text, filter_chain)

def preprocess(topic):
    """Return the first topic of a possibly multi-valued topic string.

    A topic cell may hold several topics separated by '|'. Only the first one
    is kept, with surrounding whitespace removed. A string without '|' is
    returned stripped.

    Args:
        topic: raw topic string, e.g. " Education | Transport ".

    Returns:
        The first '|'-separated segment, stripped of leading/trailing whitespace.
    """
    # split() on a string without '|' yields the whole string as the single
    # segment, so one expression covers both the single- and multi-topic case.
    return topic.split('|', 1)[0].strip()

In [25]:
class CustomTextClassificationProcessor(TextClassificationProcessor):
    """TextClassificationProcessor that reads CSV files with `topic` and `transcript` columns."""

    # we need to overwrite this function from the parent class
    def file_to_dicts(self, file: str) -> [dict]:
        """Read a CSV file and convert it to a list of record dicts.

        Rows are dropped when the topic or transcript is missing, when the
        topic is exactly 'admin', or when the transcript has fewer than 10
        whitespace-separated words. The topic is then reduced to its first
        '|'-separated entry via `preprocess`.

        Args:
            file: path of the CSV file; it must contain `topic` and
                `transcript` columns. Any other columns are ignored.

        Returns:
            A list of dicts with the keys "text_classification_label"
            (the cleaned topic) and "text" (the transcript).
        """
        df = pd.read_csv(file)

        # Rows without a label or a text cannot be filtered or preprocessed
        # (len() / .strip() fail on NaN), so remove them up front.
        df = df.dropna(subset=["topic", "transcript"])
        df = df[df["topic"] != 'admin']
        word_counts = df["transcript"].str.split().str.len()
        df = df[word_counts >= 10]

        # Select and rename the two columns by name rather than by position, so
        # extra or reordered CSV columns cannot mislabel the data.
        # Series.map (instead of DataFrame.apply(axis=1)) also works when no
        # rows are left after filtering.
        records = pd.DataFrame({
            "text_classification_label": df["topic"].map(preprocess),
            "text": df["transcript"],
        })
        return records.to_dict(orient="records")

In [26]:
# Labels in our data set, in class-index order. Derived from
# topics_index_to_name_map so the label names are defined in one place only.
label_list = [topics_index_to_name_map[index] for index in sorted(topics_index_to_name_map)]

metric = "f1_macro" # desired metric for evaluation

processor = CustomTextClassificationProcessor(tokenizer=tokenizer,
                                            max_seq_len=4096, # the longformer-base-4096 checkpoint is used here, so sequences are clipped at 4096 tokens rather than BERT's 512
                                            data_dir='data/', 
                                            label_list=label_list,
                                            metric=metric,
                                            quote_char='"',
                                            multilabel=False,
                                            train_filename="2012_debate.csv",
                                            dev_filename=None,
                                            test_filename="2013_debate.csv",
                                            dev_split=0.1 # this will extract 10% of the train set to create a dev set
                                            )

In [27]:
# Build the data silo from the processor, using the batch size from the setup cell.
data_silo = DataSilo(batch_size=batch_size, processor=processor)

10/04/2020 16:55:34 - INFO - farm.data_handler.data_silo -   
Loading data into the data silo ... 
              ______
               |o  |   !
   __          |:`_|---'-.
  |__|______.-/ _ \-----.|       
 (o)(o)------'\ _ /     ( )      
 
10/04/2020 16:55:34 - INFO - farm.data_handler.data_silo -   Loading train set from: data/2012_debate.csv 
10/04/2020 16:55:51 - INFO - farm.data_handler.data_silo -   Got ya 31 parallel workers to convert 1792 dictionaries to pytorch datasets (chunksize = 12)...
10/04/2020 16:55:51 - INFO - farm.data_handler.data_silo -    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0 
10/04/2020 16:55:51 - INFO - farm.data_handler.data_silo -   /w\  /|\  /|\  /w\  /|\  /w\  /w\  /w\  /w\  /w\  /w\  /w\  /w\  /w\  /w\  /|\  /w\  /|\  /|\  /|\  /w\  /|\  /w\  /|\  /w\  /w\  /w\  /w\  /w\  /|\  /w\
10/04/2020 16:55:51 - INFO - farm.data_handler.data_silo -   /'\

10/04/2020 16:56:01 - INFO - farm.data_handler.processor -   

      .--.        _____                       _      
    .'_\/_'.     / ____|                     | |     
    '. /\ .'    | (___   __ _ _ __ ___  _ __ | | ___ 
      "||"       \___ \ / _` | '_ ` _ \| '_ \| |/ _ \ 
       || /\     ____) | (_| | | | | | | |_) | |  __/
    /\ ||//\)   |_____/ \__,_|_| |_| |_| .__/|_|\___|
   (/\||/                             |_|           
______\||/___________________________________________                     

ID: 1-0
Clear Text: 
 	text_classification_label: Crime, civil law, justice and rights
 	text: 6.  What steps she has taken to ensure that the criminalisation of forced marriage does not discourage victims from bringing complaints forward. [114073]  Forced marriage is a hidden problem, and criminalising this abhorrent act will give victims the option of seeking  the toughest form of justice. To ensure that victims and others are not discouraged from coming forward, civil remedie

Preprocessing Dataset data/2012_debate.csv: 100%|██████████| 1792/1792 [06:34<00:00,  4.55 Dicts/s]
10/04/2020 17:02:26 - INFO - farm.data_handler.data_silo -   Loading dev set as a slice of train set
10/04/2020 17:02:26 - INFO - farm.data_handler.data_silo -   Took 184 samples out of train set to create dev set (dev split is roughly 0.1)
10/04/2020 17:02:26 - INFO - farm.data_handler.data_silo -   Loading test set from: data/2013_debate.csv
10/04/2020 17:02:44 - INFO - farm.data_handler.data_silo -   Got ya 31 parallel workers to convert 1901 dictionaries to pytorch datasets (chunksize = 13)...
10/04/2020 17:02:44 - INFO - farm.data_handler.data_silo -    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0    0 
10/04/2020 17:02:44 - INFO - farm.data_handler.data_silo -   /|\  /|\  /|\  /|\  /|\  /|\  /|\  /w\  /w\  /|\  /|\  /|\  /|\  /|\  /|\  /w\  /w\  /|\  /w\  /w\  /|\  /w\  /w\  /|\  /

10/04/2020 17:02:45 - INFO - farm.data_handler.processor -   

      .--.        _____                       _      
    .'_\/_'.     / ____|                     | |     
    '. /\ .'    | (___   __ _ _ __ ___  _ __ | | ___ 
      "||"       \___ \ / _` | '_ ` _ \| '_ \| |/ _ \ 
       || /\     ____) | (_| | | | | | | |_) | |  __/
    /\ ||//\)   |_____/ \__,_|_| |_| |_| .__/|_|\___|
   (/\||/                             |_|           
______\||/___________________________________________                     

ID: 10-0
Clear Text: 
 	text_classification_label: Crime, civil law, justice and rights
 	text: 15.  What recent steps she has taken to reduce gang-related and youth violence. [135594]  Large areas of Government policy are having a positive impact on the matter. Specifically, we are supporting 29 local areas that face problems of gang and youth violence. That includes tackling young people possessing knives, which we were talking about a moment ago. We have also recently announc

Preprocessing Dataset data/2013_debate.csv: 100%|██████████| 1901/1901 [13:30<00:00,  2.35 Dicts/s]
10/04/2020 17:16:17 - INFO - farm.data_handler.data_silo -   Examples in train: 1608
10/04/2020 17:16:17 - INFO - farm.data_handler.data_silo -   Examples in dev  : 184
10/04/2020 17:16:17 - INFO - farm.data_handler.data_silo -   Examples in test : 1901
10/04/2020 17:16:17 - INFO - farm.data_handler.data_silo -   
10/04/2020 17:16:17 - INFO - farm.data_handler.data_silo -   Longest sequence length observed after clipping:     4096
10/04/2020 17:16:17 - INFO - farm.data_handler.data_silo -   Average sequence length after clipping: 1660.2294776119404
10/04/2020 17:16:17 - INFO - farm.data_handler.data_silo -   Proportion clipped:      0.24502487562189054


In [30]:
# NOTE(review): the FARM LanguageModel / AdaptiveModel path is commented out in
# this cell; a plain Hugging Face model is loaded through AutoModel instead.
# language_model = LanguageModel.load(lang_model)
# Prediction head for topic classification: one output per entry in label_list,
# using the class weights computed by the data silo for "text_classification".
prediction_head = TextClassificationHead(
    class_weights=data_silo.calculate_class_weights(task_name="text_classification"), num_labels=len(label_list))

# model = AdaptiveModel(
#         language_model=language_model,
#         prediction_heads=[prediction_head],
#         embeds_dropout_prob=0.1,
#         lm_output_types=["per_sequence"],
#         device=device)

# NOTE(review): prediction_head is not passed to this model anywhere in this
# cell. The object returned here is a LongformerModel, and the later
# model.connect_heads_with_processor(...) call fails on it with
# "object has no attribute 'connect_heads_with_processor'".
model = AutoModel.from_pretrained("allenai/longformer-base-4096")

# Wrap the model with optimizer and learning-rate schedule; the number of
# batches comes from the train loader of the data silo.
model, optimizer, lr_schedule = initialize_optimizer(
        model=model,
        learning_rate=2e-5,
        device=device,
        n_batches=len(data_silo.loaders["train"]),
        n_epochs=n_epochs)

 'Asylum, immigration and nationality' 'Business, industry and consumers'
 'Communities and families' 'Crime, civil law, justice and rights'
 'Culture, media and sport' 'Defence' 'Economy and finance' 'Education'
 'Employment and training' 'Energy and environment' 'European Union'
 'Health services and medicine' 'Housing and planning'
 'International affairs' 'Parliament, government and politics'
 'Science and technology' 'Social security and pensions' 'Social services'
 'Transport' 'Others'], y=['Agriculture, animals, food and rural affairs', 'Asylum, immigration and nationality', 'Business, industry and consumers', 'Communities and families', 'Crime, civil law, justice and rights', 'Culture, media and sport', 'Defence', 'Economy and finance', 'Education', 'Employment and training', 'Energy and environment', 'European Union', 'Health services and medicine', 'Housing and planning', 'International affairs', 'Parliament, government and politics', 'Science and technology', 'Social securit

10/04/2020 18:06:53 - INFO - farm.modeling.prediction_head -   Using class weights for task 'text_classification': [ 1.9890109   5.5408163   0.7835498   0.825228    0.36938775  1.5210084
  1.0206767   0.6464286   0.91260505  0.90199333  0.8431677   4.3095236
  0.61564624  8.619047    0.69260204  0.3447619   7.757143   11.081633
 15.514286    0.95767194  1.1081632 ]
10/04/2020 18:06:53 - INFO - filelock -   Lock 139734587902992 acquired on /home/ubuntu/.cache/torch/transformers/dfc92dbbf5c555abf807425ebdb22b55de7a17e21fe1c48cbaa5764982c1d9c0.cd65234711d2e83d420aa696eb9186cdec6ab79ef8bf090b442cf249443dfa92.lock


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=597257159.0, style=ProgressStyle(descri…

10/04/2020 18:07:08 - INFO - filelock -   Lock 139734587902992 released on /home/ubuntu/.cache/torch/transformers/dfc92dbbf5c555abf807425ebdb22b55de7a17e21fe1c48cbaa5764982c1d9c0.cd65234711d2e83d420aa696eb9186cdec6ab79ef8bf090b442cf249443dfa92.lock





Some weights of LongformerModel were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['longformer.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
10/04/2020 18:07:12 - INFO - farm.modeling.optimization -   Loading optimizer `TransformersAdamW`: '{'correct_bias': False, 'weight_decay': 0.01, 'lr': 2e-05}'
10/04/2020 18:07:12 - INFO - farm.modeling.optimization -   Using scheduler 'get_linear_schedule_with_warmup'
10/04/2020 18:07:12 - INFO - farm.modeling.optimization -   Loading schedule `get_linear_schedule_with_warmup`: '{'num_warmup_steps': 5.1000000000000005, 'num_training_steps': 51}'


In [31]:
# NOTE(review): the recorded output of this cell is a ModuleAttributeError —
# `model` is a LongformerModel, which has no connect_heads_with_processor attribute.
model.connect_heads_with_processor(processor.tasks, require_labels=True)

ModuleAttributeError: 'LongformerModel' object has no attribute 'connect_heads_with_processor'

In [10]:
# Assemble the trainer from the model, optimizer, schedule and data built in earlier cells.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    lr_schedule=lr_schedule,
    data_silo=data_silo,
    epochs=n_epochs,
    evaluate_every=evaluate_every,  # value defined in the setup cell (set to 100)
    n_gpu=n_gpu,
    device=device,
)

In [11]:
# Run training with the trainer built in the previous cell.
trainer.train()

10/04/2020 12:27:00 - INFO - farm.train -   
 

          &&& &&  & &&             _____                   _             
      && &\/&\|& ()|/ @, &&       / ____|                 (_)            
      &\/(/&/&||/& /_/)_&/_&     | |  __ _ __ _____      ___ _ __   __ _ 
   &() &\/&|()|/&\/ '%" & ()     | | |_ | '__/ _ \ \ /\ / / | '_ \ / _` |
  &_\_&&_\ |& |&&/&__%_/_& &&    | |__| | | | (_) \ V  V /| | | | | (_| |
&&   && & &| &| /& & % ()& /&&    \_____|_|  \___/ \_/\_/ |_|_| |_|\__, |
 ()&_---()&\&\|&&-&&--%---()~                                       __/ |
     &&     \|||                                                   |___/
             |||
             |||
             |||
       , -=-~  .-^- _
              `

Train epoch 0/0 (Cur. train loss: 3.9871): 100%|██████████| 51/51 [26:22<00:00, 31.02s/it] 
Evaluating: 100%|██████████| 60/60 [02:00<00:00,  2.02s/it]
  _warn_prf(average, modifier, msg_start, len(result))
10/04/2020 12:55:23 - INFO - farm.eval -   

\\|//       \\|//  

AdaptiveModel(
  (language_model): Electra(
    (model): ElectraModel(
      (embeddings): ElectraEmbeddings(
        (word_embeddings): Embedding(30522, 1024, padding_idx=0)
        (position_embeddings): Embedding(512, 1024)
        (token_type_embeddings): Embedding(2, 1024)
        (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (embeddings_project): Linear(in_features=1024, out_features=256, bias=True)
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=256, out_features=256, bias=True)
                (key): Linear(in_features=256, out_features=256, bias=True)
                (value): Linear(in_features=256, out_features=256, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOut