In [1]:
from datasets import Dataset
import pandas as pd

In [5]:
# 1. Load the data to a dataframe and get necessary columns
df_data = pd.read_csv('./sample_few_shot_data/consolidated_2023-10-24.csv')
df_data.head(2)
# assuming that the label column is "FIRST_PREDICTION_CLASS" - to be replaced later
df = df_data[['URI', 'TOPIC', 'BODY_SUMMARY','FIRST_PREDICTION_CLASS']]
df.head(2)

Unnamed: 0,URI,TOPIC,BODY_SUMMARY,FIRST_PREDICTION_CLASS
0,7790305844,weather,Air Quality Index or AQI measures the concentr...,later reports of past transportation disruptio...
1,2023-10-125234841,warehouse_fire,Israeli military says it has bombed hundreds o...,leisure or other news


In [7]:
# 2. convert the class label column to onehot format, with each class label occupying one column of [0,1] values
onehot_df = pd.get_dummies(df, columns = ['FIRST_PREDICTION_CLASS'], prefix="CLASS_LABEL", dtype=int) 
onehot_df.head(2)

Unnamed: 0,URI,TOPIC,BODY_SUMMARY,"CLASS_LABEL_general social, business, economic reports, studies and trends","CLASS_LABEL_later reports of past transportation disruption event, bad news","CLASS_LABEL_lawsuits, legal or insurance impact of past event, bad news",CLASS_LABEL_leisure or other news,CLASS_LABEL_very recent breaking news on forced labor and sweatshop,"CLASS_LABEL_very recent breaking news on major air transportation or airport disruption, bad news","CLASS_LABEL_very recent breaking news on major and large scale worker strike actions causing disruption, bad news","CLASS_LABEL_very recent breaking news on major maritime transportation disruption, bad news","CLASS_LABEL_very recent breaking news on major railway transportation disruption, bad news","CLASS_LABEL_very recent breaking news on severe and extreme weather causing disruption, bad news","CLASS_LABEL_very recent breaking news on warehouse and storage facilities disruption or destruction, bad news"
0,7790305844,weather,Air Quality Index or AQI measures the concentr...,0,1,0,0,0,0,0,0,0,0,0
1,2023-10-125234841,warehouse_fire,Israeli military says it has bombed hundreds o...,0,0,0,1,0,0,0,0,0,0,0


In [9]:
# 2. convert to 'list' type dictionary to meet Dataset requirement
dict_data = onehot_df[:10].to_dict('list')
dict_data

{'URI': ['7790305844',
  '2023-10-125234841',
  '7790289556',
  '2023-10-125242951',
  '7790432248',
  '7790329990',
  '7790295159',
  '7790356009',
  '7790279927',
  '7790341989'],
 'TOPIC': ['weather',
  'warehouse_fire',
  'weather',
  'train',
  'weather',
  'weather',
  'train',
  'train',
  'train',
  'weather'],
 'BODY_SUMMARY': ['Air Quality Index or AQI measures the concentration of PM 2.5 levels. There are six AQI categories, namely Good Satisfactory, Moderately polluted, Poor, Very Poor, and Severe. The most affected areas have been Andheri, Mazgaon, Navi Mumbai where AQI remained beyond 300. The situation was such that due to fog on Wednesday, local trains on the main line of Mumbai suburban network ran late by 15 to 20 mi due to the fog.',
  "Israeli military says it has bombed hundreds of sites in Gaza over past 24 hours. Targets included command centres and tunnel shafts used by Hamas terror group. Civilian casualties continue to mount, with at least 12 killed in Khan Yo

In [5]:
# 3. create a dataset object
ds = Dataset.from_dict(dict_data)
print(ds)

# by default, all columns are of "Value" type, i.e., X_data
ds.features

Dataset({
    features: ['URI', 'TOPIC', 'BODY_SUMMARY', 'CLASS_LABEL_general social, business, economic reports, studies and trends', 'CLASS_LABEL_later reports of past transportation disruption event, bad news', 'CLASS_LABEL_lawsuits, legal or insurance impact of past event, bad news', 'CLASS_LABEL_leisure or other news', 'CLASS_LABEL_very recent breaking news on forced labor and sweatshop', 'CLASS_LABEL_very recent breaking news on major air transportation or airport disruption, bad news', 'CLASS_LABEL_very recent breaking news on major and large scale worker strike actions causing disruption, bad news', 'CLASS_LABEL_very recent breaking news on major maritime transportation disruption, bad news', 'CLASS_LABEL_very recent breaking news on major railway transportation disruption, bad news', 'CLASS_LABEL_very recent breaking news on severe and extreme weather causing disruption, bad news ', 'CLASS_LABEL_very recent breaking news on warehouse and storage facilities disruption or destru

{'URI': Value(dtype='string', id=None),
 'TOPIC': Value(dtype='string', id=None),
 'BODY_SUMMARY': Value(dtype='string', id=None),
 'CLASS_LABEL_general social, business, economic reports, studies and trends': Value(dtype='int64', id=None),
 'CLASS_LABEL_later reports of past transportation disruption event, bad news': Value(dtype='int64', id=None),
 'CLASS_LABEL_lawsuits, legal or insurance impact of past event, bad news': Value(dtype='int64', id=None),
 'CLASS_LABEL_leisure or other news': Value(dtype='int64', id=None),
 'CLASS_LABEL_very recent breaking news on forced labor and sweatshop': Value(dtype='int64', id=None),
 'CLASS_LABEL_very recent breaking news on major air transportation or airport disruption, bad news': Value(dtype='int64', id=None),
 'CLASS_LABEL_very recent breaking news on major and large scale worker strike actions causing disruption, bad news': Value(dtype='int64', id=None),
 'CLASS_LABEL_very recent breaking news on major maritime transportation disruption, ba

In [6]:
# getting the label feature column names:
label_features = [feature for feature in ds.features if feature not in ['URI','TOPIC', 'BODY_SUMMARY']]
label_features 

['CLASS_LABEL_general social, business, economic reports, studies and trends',
 'CLASS_LABEL_later reports of past transportation disruption event, bad news',
 'CLASS_LABEL_lawsuits, legal or insurance impact of past event, bad news',
 'CLASS_LABEL_leisure or other news',
 'CLASS_LABEL_very recent breaking news on forced labor and sweatshop',
 'CLASS_LABEL_very recent breaking news on major air transportation or airport disruption, bad news',
 'CLASS_LABEL_very recent breaking news on major and large scale worker strike actions causing disruption, bad news',
 'CLASS_LABEL_very recent breaking news on major maritime transportation disruption, bad news',
 'CLASS_LABEL_very recent breaking news on major railway transportation disruption, bad news',
 'CLASS_LABEL_very recent breaking news on severe and extreme weather causing disruption, bad news ',
 'CLASS_LABEL_very recent breaking news on warehouse and storage facilities disruption or destruction, bad news']

In [7]:
# we can cast the label column to the correct type:
# IMPORTANT: if names= is specified, the order of the names matter 
from datasets import ClassLabel
# ds = ds.cast_column("RELEVANCE_CLASS", ClassLabel(num_classes=2,
#                                                   names=['NOT RELEVANT', 'RELEVANT']))
# ds.features


new_features = ds.features.copy()
for label_feature in label_features:
    new_features[label_feature] = ClassLabel(num_classes=2)
ds = ds.cast(new_features)
ds.features

Casting the dataset:   0%|          | 0/820 [00:00<?, ? examples/s]

{'URI': Value(dtype='string', id=None),
 'TOPIC': Value(dtype='string', id=None),
 'BODY_SUMMARY': Value(dtype='string', id=None),
 'CLASS_LABEL_general social, business, economic reports, studies and trends': ClassLabel(names=['0', '1'], id=None),
 'CLASS_LABEL_later reports of past transportation disruption event, bad news': ClassLabel(names=['0', '1'], id=None),
 'CLASS_LABEL_lawsuits, legal or insurance impact of past event, bad news': ClassLabel(names=['0', '1'], id=None),
 'CLASS_LABEL_leisure or other news': ClassLabel(names=['0', '1'], id=None),
 'CLASS_LABEL_very recent breaking news on forced labor and sweatshop': ClassLabel(names=['0', '1'], id=None),
 'CLASS_LABEL_very recent breaking news on major air transportation or airport disruption, bad news': ClassLabel(names=['0', '1'], id=None),
 'CLASS_LABEL_very recent breaking news on major and large scale worker strike actions causing disruption, bad news': ClassLabel(names=['0', '1'], id=None),
 'CLASS_LABEL_very recent break

In [65]:
ds.features['CLASS_LABEL_general social, business, economic reports, studies and trends'].__dict__

{'names': ['0', '1'],
 'id': None,
 'num_classes': 2,
 'names_file': None,
 '_int2str': ['0', '1'],
 '_str2int': {'0': 0, '1': 1}}

In [43]:
# the ds can be split if the stratify column is a ClassLabel column only
# ds = ds.train_test_split(test_size=0.2, shuffle=True, seed=99, stratify_by_column='RELEVANCE_CLASS')
# ds

DatasetDict({
    train: Dataset({
        features: ['URI', 'TOPIC', 'BODY_SUMMARY', 'RELEVANCE_CLASS'],
        num_rows: 96
    })
    test: Dataset({
        features: ['URI', 'TOPIC', 'BODY_SUMMARY', 'RELEVANCE_CLASS'],
        num_rows: 24
    })
})

In [8]:
# we can save to disk and reload later
ds.save_to_disk('./custom_datasets/sample_dataset')

Saving the dataset (0/1 shards):   0%|          | 0/820 [00:00<?, ? examples/s]

In [9]:
# loading from disk
from datasets import load_from_disk
reloaded_dataset = load_from_disk("./custom_datasets/sample_dataset")
reloaded_dataset

Dataset({
    features: ['URI', 'TOPIC', 'BODY_SUMMARY', 'CLASS_LABEL_general social, business, economic reports, studies and trends', 'CLASS_LABEL_later reports of past transportation disruption event, bad news', 'CLASS_LABEL_lawsuits, legal or insurance impact of past event, bad news', 'CLASS_LABEL_leisure or other news', 'CLASS_LABEL_very recent breaking news on forced labor and sweatshop', 'CLASS_LABEL_very recent breaking news on major air transportation or airport disruption, bad news', 'CLASS_LABEL_very recent breaking news on major and large scale worker strike actions causing disruption, bad news', 'CLASS_LABEL_very recent breaking news on major maritime transportation disruption, bad news', 'CLASS_LABEL_very recent breaking news on major railway transportation disruption, bad news', 'CLASS_LABEL_very recent breaking news on severe and extreme weather causing disruption, bad news ', 'CLASS_LABEL_very recent breaking news on warehouse and storage facilities disruption or destru