In [None]:
import torch

# Sanity check: number of CUDA devices visible to this kernel (0 means CPU only).
visible_gpu_count = torch.cuda.device_count()
visible_gpu_count

In [28]:
from datasets import load_from_disk
import torch

### Preprocess the dataset

In [37]:
# Load the train / validation / test splits that were written earlier with `save_to_disk`.
raw_dataset_dir = "../../Violence_data/geo_corpus.0.0.1_datasets"
ds = load_from_disk(raw_dataset_dir)

In [None]:
# Look at the first training example to see the raw column layout.
first_example = ds["train"][0]
first_example

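This is a multi-label classification problem with 6 labels: `pre7geo10`, `pre7geo30`, `pre7geo50`, `post7geo10`, `post7geo30`, `post7geo50`. A single example can have several labels set at the same time. Note: the combined `labels` vector built below is ordered `post7geo10`, `post7geo30`, `post7geo50`, `pre7geo10`, `pre7geo30`, `pre7geo50` (the dataset's column order), not in the order listed here.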

In [38]:
# Remove unnecessary columns: keep only the raw text plus the six label columns,
# and collect every other column name of the train split for removal.
keep_cols = ['text', 'pre7geo10', 'pre7geo30', 'pre7geo50', 'post7geo10',
             'post7geo30', 'post7geo50']
remove_columns = list(filter(lambda name: name not in keep_cols, ds['train'].column_names))

In [39]:
# Drop the unneeded columns from every split.
# Only names still present in the schema are passed, so re-running this cell is a
# no-op instead of an error (`remove_columns` raises on columns that no longer exist).
ds = ds.remove_columns([name for name in remove_columns if name in ds["train"].column_names])

In [45]:
# Schema of the training split after dropping the unused columns.
train_schema = ds["train"].features
train_schema

{'text': Value(dtype='string', id=None),
 'post7geo10': Value(dtype='int64', id=None),
 'post7geo30': Value(dtype='int64', id=None),
 'post7geo50': Value(dtype='int64', id=None),
 'pre7geo10': Value(dtype='int64', id=None),
 'pre7geo30': Value(dtype='int64', id=None),
 'pre7geo50': Value(dtype='int64', id=None)}

In [46]:
# We need to cast the integer label columns to float in order to compute the
# Binary Cross-Entropy loss during training.
# At this point the only non-label column is 'text', so every other column is
# cast to float32. DatasetDict.cast applies the same schema to every split
# (train / validation / test), replacing the three copy-pasted per-split casts.
# Re-running the cell is harmless: float32 -> float32 is a no-op cast.
from datasets import Value

new_features = ds["train"].features.copy()
new_features.update({name: Value(dtype='float32') for name in new_features if name != 'text'})
ds = ds.cast(new_features)

Casting the dataset:   0%|          | 0/1677 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/420 [00:00<?, ?ba/s]

Casting the dataset:   0%|          | 0/233 [00:00<?, ?ba/s]

In [47]:
# Confirm the label columns are now float32.
cast_schema = ds["train"].features
cast_schema

{'text': Value(dtype='string', id=None),
 'post7geo10': Value(dtype='float32', id=None),
 'post7geo30': Value(dtype='float32', id=None),
 'post7geo50': Value(dtype='float32', id=None),
 'pre7geo10': Value(dtype='float32', id=None),
 'pre7geo30': Value(dtype='float32', id=None),
 'pre7geo50': Value(dtype='float32', id=None)}

In [50]:
# Create the multi-label target column 'labels' from the individual label columns.
# It is important that the labels are float in order to compute Binary Cross-Entropy loss
# (the label columns were cast to float32 above).
#
# The label order is listed explicitly and matches the order the labels were originally
# built in (the dataset's column order): post7geo10, post7geo30, post7geo50,
# pre7geo10, pre7geo30, pre7geo50. An explicit list keeps the order independent of the
# schema. It also makes the cell safe to re-run: "every column except text" would
# otherwise pick up the existing 'labels' column on a second run.
#
# batched=True processes many rows per call instead of one Python call per example.
# The unbatched version took roughly 30 min on this corpus.
label_cols = ['post7geo10', 'post7geo30', 'post7geo50',
              'pre7geo10', 'pre7geo30', 'pre7geo50']
ds = ds.map(
    # batch[c] is a list of values for column c; zip(*) turns the columns into per-row vectors.
    lambda batch: {"labels": [list(row) for row in zip(*(batch[c] for c in label_cols))]},
    batched=True,
)
# Invariant: 'labels' of an example equals its individual label columns, in label_cols order.
assert ds["train"][0]["labels"] == [ds["train"][0][c] for c in label_cols]
ds

  0%|          | 0/16769932 [00:00<?, ?ex/s]

  0%|          | 0/4192483 [00:00<?, ?ex/s]

  0%|          | 0/2329158 [00:00<?, ?ex/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'post7geo10', 'post7geo30', 'post7geo50', 'pre7geo10', 'pre7geo30', 'pre7geo50', 'labels'],
        num_rows: 16769932
    })
    validation: Dataset({
        features: ['text', 'post7geo10', 'post7geo30', 'post7geo50', 'pre7geo10', 'pre7geo30', 'pre7geo50', 'labels'],
        num_rows: 4192483
    })
    test: Dataset({
        features: ['text', 'post7geo10', 'post7geo30', 'post7geo50', 'pre7geo10', 'pre7geo30', 'pre7geo50', 'labels'],
        num_rows: 2329158
    })
})

In [51]:
# Spot-check that 'labels' lines up with the individual label columns.
labeled_example = ds["train"][0]
labeled_example

{'text': 'Venezuela en crisis, y la Fiscal de shopping en Alemania (Video)',
 'post7geo10': 1.0,
 'post7geo30': 1.0,
 'post7geo50': 1.0,
 'pre7geo10': 0.0,
 'pre7geo30': 0.0,
 'pre7geo50': 0.0,
 'labels': [1.0, 1.0, 1.0, 0.0, 0.0, 0.0]}

In [52]:
# The saved dataset keeps only 'text' and 'labels'; every other column name is
# collected here for removal. list.remove raises if either expected column is missing.
col_names = list(ds["train"].column_names)
for kept_column in ("labels", "text"):
    col_names.remove(kept_column)

In [53]:
# Drop the per-label columns; 'labels' already holds them as a single vector.
ds_clean = ds.remove_columns(column_names=col_names)
ds_clean

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 16769932
    })
    validation: Dataset({
        features: ['text', 'labels'],
        num_rows: 4192483
    })
    test: Dataset({
        features: ['text', 'labels'],
        num_rows: 2329158
    })
})

In [54]:
# Persist the training-ready splits (columns 'text' and 'labels' only).
train_ready_dir = "../../Violence_data/geo_corpus.0.0.1_dataset_for_train"
ds_clean.save_to_disk(train_ready_dir)