In [1]:
import os
from argparse import Namespace

import numpy as np
import pandas as pd
import torch
import transformers
from datasets import load_from_disk
from tqdm.auto import tqdm
from transformers import AutoTokenizer

In [2]:
# Load the raw (untokenized) DatasetDict; each row has a "text" string and a float "labels" list.
# The location can be overridden with the GEO_CORPUS_DATASET_DIR environment variable so the
# notebook is not tied to one machine. The default is the original hard-coded path.
# NOTE(review): the cache messages logged under the tokenization cell reference
# ../../Violence_data/geo_corpus.0.0.1_dataset_for_train (a relative path), so those outputs
# came from a different path than this default. Confirm both point at the same data.
dataset_dir = os.environ.get(
    "GEO_CORPUS_DATASET_DIR",
    "/data3/mmendieta/Violence_data/geo_corpus.0.0.1_dataset_for_train",
)
ds = load_from_disk(dataset_dir)

In [3]:
# Peek at the first raw training example (its "text" and "labels" fields).
first_train_example = ds["train"][0]
first_train_example

{'text': 'Venezuela en crisis, y la Fiscal de shopping en Alemania (Video)',
 'labels': [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]}

In [3]:
# Run configuration. Candidate checkpoints for "model_ckpt":
#   Smaller-LaBSE: setu4993/smaller-LaBSE
#   LaBSE:         setu4993/LaBSE
#   XLM-T:         cardiffnlp/twitter-xlm-roberta-base-sentiment
# NOTE(review): "fout" hard-codes a "labse" suffix and does not follow model_ckpt. Update it by
# hand when switching checkpoints, or a different tokenization is saved under the LaBSE name.
# NOTE(review): batch_size and seed appear unused by the cells in this notebook.
config = dict(
    model_ckpt="setu4993/LaBSE",
    batch_size=1024,
    num_labels=6,
    max_length=32,  # passed to the tokenizer as model_max_length (truncation limit)
    seed=42,
    fout="../../Violence_data/geo_corpus.0.0.1_tok_ds_labse",
)

# Attribute-style access (args.model_ckpt, ...) over the same settings.
args = Namespace(**config)

In [4]:
# Instantiate the tokenizer
# Load the tokenizer for the configured checkpoint. model_max_length caps the length
# used when tokenize() is called with truncation=True.
model_ckpt = args.model_ckpt
tokenizer = AutoTokenizer.from_pretrained(
    model_ckpt,
    model_max_length=args.max_length,
)

In [5]:
def tokenize(batch):
    """Tokenize one batch of rows for ``Dataset.map(batched=True)``.

    Args:
        batch: mapping whose ``"text"`` entry holds the raw strings of the batch.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those strings with
        truncation enabled; ``map`` adds its fields as new columns.
    """
    texts = batch["text"]
    encoding = tokenizer(texts, truncation=True)
    return encoding

In [6]:
%time tokenized_ds = ds.map(tokenize, batched=True)

Loading cached processed dataset at ../../Violence_data/geo_corpus.0.0.1_dataset_for_train/train/cache-326757a4e4fe2325.arrow


  0%|          | 0/4193 [00:00<?, ?ba/s]

Loading cached processed dataset at ../../Violence_data/geo_corpus.0.0.1_dataset_for_train/test/cache-30e6c2d36da7870e.arrow


CPU times: user 37min 9s, sys: 37.6 s, total: 37min 46s
Wall time: 2min 50s


In [7]:
# Drop the raw text so only labels and tokenizer outputs are kept (and later saved).
# Guarded so re-running this cell after "text" is already gone does not raise.
if any("text" in columns for columns in tokenized_ds.column_names.values()):
    tokenized_ds = tokenized_ds.remove_columns("text")

In [8]:
# Return torch tensors when indexing. with_format produces a formatted copy instead of
# mutating in place, so the cell reads as a pure assignment.
tokenized_ds = tokenized_ds.with_format("torch")

In [9]:
# Sanity-check every split before it is saved.
# token_type_ids is deliberately not required. NOTE(review): presumably the XLM-R based
# checkpoint listed in the config does not emit it; confirm if switching models.
required_columns = {"labels", "input_ids", "attention_mask"}
for split_name, split_columns in tokenized_ds.column_names.items():
    missing_columns = required_columns - set(split_columns)
    assert not missing_columns, f"split {split_name!r} is missing columns {sorted(missing_columns)}"
    assert "text" not in split_columns, f"split {split_name!r} still has the raw 'text' column"
    # Cheap check on the first row that label vectors match the configured number of labels.
    if tokenized_ds[split_name].num_rows:
        first_labels = tokenized_ds[split_name][0]["labels"]
        assert len(first_labels) == args.num_labels, (
            f"split {split_name!r}: expected {args.num_labels} labels, got {len(first_labels)}"
        )

tokenized_ds

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16769932
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 4192483
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2329158
    })
})

In [10]:
# Column schema (feature types) of the training split.
train_split = tokenized_ds["train"]
train_split.features

{'labels': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}

In [11]:
# Persist the tokenized DatasetDict to the configured output directory.
output_path = args.fout
tokenized_ds.save_to_disk(output_path)