In [1]:
import torch
import transformers
import pandas as pd
import numpy as np
from argparse import Namespace
from tqdm.auto import tqdm
from datasets import load_from_disk
from transformers import AutoTokenizer

In [2]:
# Load the previously saved dataset from local disk; the later cells tokenize
# every split of it.
ds = load_from_disk(dataset_path="../../Violence_data/geo_corpus.0.0.1_dataset_for_train")

In [4]:
# Run configuration.
# Possible values for `model_ckpt`:
#   - Smaller-LaBSE: setu4993/smaller-LaBSE
#   - LaBSE:         setu4993/LaBSE
#   - XLM-T:         cardiffnlp/twitter-xlm-roberta-base-sentiment
config = dict(
    model_ckpt="cardiffnlp/twitter-xlm-roberta-base-sentiment",
    batch_size=1024,
    num_labels=6,
    max_length=32,  # handed to the tokenizer as its model_max_length
    seed=42,
    fout="../../Violence_data/geo_corpus.0.0.1_tok_ds_xlmt",  # save location of the tokenized dataset
)

# Expose the same settings with attribute access (args.model_ckpt, args.fout, ...).
args = Namespace(**config)

In [5]:
# Build the tokenizer that belongs to the configured checkpoint.
# model_max_length is set from the config, so calls made with truncation=True
# cut sequences at args.max_length tokens.
model_ckpt = args.model_ckpt
tokenizer = AutoTokenizer.from_pretrained(
    model_ckpt,
    model_max_length=args.max_length,
)

In [6]:
def tokenize(batch):
    """Tokenize the ``"text"`` entries of a batch with the module-level tokenizer.

    Args:
        batch: Mapping with a ``"text"`` key; whatever is stored under that
            key is handed to the tokenizer unchanged.

    Returns:
        The tokenizer's output for those texts. Truncation is enabled; no
        padding argument is passed.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded

In [7]:
# Apply `tokenize` to the dataset in batches; %time reports how long this
# expensive step takes. The result still carries the raw 'text' column next to
# the tokenizer outputs (it is dropped in a later cell).
%time tokenized_ds = ds.map(tokenize, batched=True)

  0%|          | 0/16770 [00:00<?, ?ba/s]

  0%|          | 0/4193 [00:00<?, ?ba/s]

  0%|          | 0/2330 [00:00<?, ?ba/s]

CPU times: user 3h 26min 55s, sys: 4min 48s, total: 3h 31min 44s
Wall time: 14min 40s


In [8]:
# Drop the raw text now that it has been tokenized. The guard keeps this cell
# safe to re-run: calling remove_columns("text") a second time would raise
# because the column no longer exists.
if all("text" in columns for columns in tokenized_ds.column_names.values()):
    tokenized_ds = tokenized_ds.remove_columns("text")

In [9]:
# Switch the output format of the tokenized dataset to 'torch' (in place).
tokenized_ds.set_format(type="torch")

In [10]:
# Inspect the result: which splits exist, their columns and their row counts.
tokenized_ds

DatasetDict({
    train: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 16769932
    })
    validation: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 4192483
    })
    test: Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 2329158
    })
})

In [11]:
# Check the stored feature types of the train split (labels, input_ids, attention_mask).
tokenized_ds["train"].features

{'labels': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}

In [12]:
# Persist the tokenized dataset at the output location from the config (args.fout).
tokenized_ds.save_to_disk(args.fout)