In [1]:
!pip install pandas==1.3.4
!pip install transformers==4.12.5
!pip install datasets==1.15.1
#!pip install datasets
!pip install ipywidgets

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting ipywidgets
  Downloading ipywidgets-7.7.0-py2.py3-none-any.whl (123 kB)
[K     |████████████████████████████████| 123 kB 18.6 MB/s eta 0:00:01
Collecting widgetsnbextension~=3.6.0
  Downloading widgetsnbextension-3.6.0-py2.py3-none-any.whl (1.6 MB)
[K     |████████████████████████████████| 1.6 MB 30.2 MB/s eta 0:00:01
Collecting jupyterlab-widgets>=1.0.0
  Downloading jupyterlab_widgets-1.1.0-py3-none-any.whl (245 kB)
[K     |████████████████████████████████| 245 kB 26.7 MB/s eta 0:00:01
Installing collected packages: widgetsnbextension, jupyterlab-widgets, ipywidgets
Successfully installed ipywidgets-7.7.0 jupyterlab-widgets-1.1.0 widgetsnbextension-3.6.0


In [2]:
import os
import pickle

from collections import Counter

# import pandas as pd
from sklearn.metrics import classification_report

import numpy as np
import torch
import torch.nn as nn

import transformers
from transformers import Trainer
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import Trainer, TrainingArguments
from transformers.data.data_collator import DataCollatorWithPadding

import datasets
from datasets import Dataset
from datasets import ClassLabel
from datasets import load_metric

In [3]:
torch.cuda.empty_cache()

## Global variables

In [4]:
DATA_FOLDER = '/notebooks/Data/bert_sequence_classification'
DATA_FILE = '/notebooks/ICANN/Datasets/dataset_persuasive_essays_icann.pt'
RESULTS_FOLDER = '/notebooks/cascade_bert/saved_models'

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
device

device(type='cuda')

## Load data

In [7]:
dataset = torch.load(DATA_FILE)

In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['split', 'text', 'labels', 'sentence', 'topic_and_full_sentence', 'topic_full_sentence_stuctural_fts', 'topic_full_sentence_structural_fts_combined', 'feature_tensor'],
        num_rows: 4709
    })
    test: Dataset({
        features: ['split', 'text', 'labels', 'sentence', 'topic_and_full_sentence', 'topic_full_sentence_stuctural_fts', 'topic_full_sentence_structural_fts_combined', 'feature_tensor'],
        num_rows: 1258
    })
})

In [9]:
dataset['train']['topic_full_sentence_structural_fts_combined'][230]

'Topic: Some young adults want independence from their parents quickly. Sentence: There will not be such worries when young adults live in their own home, because parents will take care for them. Structural features: Two. No. No. No. No.'

In [10]:
dataset['train'] = dataset['train'].flatten_indices()

Flattening the indices:   0%|          | 0/5 [00:00<?, ?ba/s]

In [11]:
dataset['test'] = dataset['test'].flatten_indices()

Flattening the indices:   0%|          | 0/2 [00:00<?, ?ba/s]

In [12]:
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [13]:
label_names = ['Claim', 'Premise', 'MajorClaim']
label_nb = len(label_names)
labels = ClassLabel(num_classes=label_nb, names=label_names)

In [14]:
labels

ClassLabel(num_classes=3, names=['Claim', 'Premise', 'MajorClaim'], names_file=None, id=None)

In [15]:
def tokenize(batch):
    tokens = tokenizer(batch['topic_full_sentence_structural_fts_combined'], truncation=True, padding=True, max_length=512)
    tokens['labels'] = labels.str2int(batch['labels'])
    return tokens

# this is just the text. if the results are nice, check transfer with text + topic 

In [16]:
dataset = dataset.map(tokenize, batched=True)



  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [17]:
dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

In [18]:
dataset

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'feature_tensor', 'input_ids', 'labels', 'sentence', 'split', 'text', 'token_type_ids', 'topic_and_full_sentence', 'topic_full_sentence_structural_fts_combined', 'topic_full_sentence_stuctural_fts'],
        num_rows: 4709
    })
    test: Dataset({
        features: ['attention_mask', 'feature_tensor', 'input_ids', 'labels', 'sentence', 'split', 'text', 'token_type_ids', 'topic_and_full_sentence', 'topic_full_sentence_structural_fts_combined', 'topic_full_sentence_stuctural_fts'],
        num_rows: 1258
    })
})

In [19]:
train_dataset = dataset['train']#.shuffle(seed=42)
test_dataset = dataset['test']#.shuffle(seed=42)

# train_val_datasets = dataset['train'].train_test_split(train_size=0.8)
# train_dataset = train_val_datasets['train']
# val_dataset = train_val_datasets['test']

In [20]:
dataset_d = {}
dataset_d['train'] = train_dataset
dataset_d['test'] = test_dataset
# dataset_d['val'] = val_dataset

In [21]:
test_dataset

Dataset({
    features: ['attention_mask', 'feature_tensor', 'input_ids', 'labels', 'sentence', 'split', 'text', 'token_type_ids', 'topic_and_full_sentence', 'topic_full_sentence_structural_fts_combined', 'topic_full_sentence_stuctural_fts'],
    num_rows: 1258
})

In [22]:
# 4709, 1258

In [23]:
tokenizer.decode(dataset['train'][2945]['input_ids'])

"[CLS] topic : what's more important : hard work or luck? sentence : it is not for which ronaldo is more fortune than me. structural features : two. no. no. no. no. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]"

In [24]:
# sanity check
set(dataset_d['train']['split'])

{'TRAIN'}

In [25]:
# sanity check
set(dataset_d['test']['split'])

{'TEST'}

## load model

In [26]:
# # load model
model_file = os.path.join("/notebooks/cascade_bert/saved_models", 'best-model-probs')
# model_file = os.path.join(RESULTS_FOLDER, 'checkpoint-1500')

model = BertForSequenceClassification.from_pretrained(model_file, num_labels=3)
model.to(device)
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [27]:
trainer = Trainer(model, data_collator=DataCollatorWithPadding(tokenizer))

In [28]:
test_raw_preds, test_labels, _ = trainer.predict(test_dataset)
test_preds = np.argmax(test_raw_preds, axis=1)

The following columns in the test set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: split, topic_and_full_sentence, topic_full_sentence_stuctural_fts, feature_tensor, topic_full_sentence_structural_fts_combined, sentence, text.
***** Running Prediction *****
  Num examples = 1258
  Batch size = 8


In [29]:
# target_name = labels.int2str([0,1,2])
print(classification_report(test_labels, test_preds))

              precision    recall  f1-score   support

           0       0.67      0.71      0.69       301
           1       0.93      0.88      0.90       805
           2       0.79      0.92      0.85       152

    accuracy                           0.84      1258
   macro avg       0.80      0.84      0.82      1258
weighted avg       0.85      0.84      0.85      1258



In [30]:
train_raw_preds, train_labels, _ = trainer.predict(train_dataset)
train_preds = np.argmax(train_raw_preds, axis=1)

The following columns in the test set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: split, topic_and_full_sentence, topic_full_sentence_stuctural_fts, feature_tensor, topic_full_sentence_structural_fts_combined, sentence, text.
***** Running Prediction *****
  Num examples = 4709
  Batch size = 8


In [31]:
# target_name = labels.int2str([0,1,2])
print(classification_report(train_labels, train_preds))

              precision    recall  f1-score   support

           0       0.64      0.69      0.66      1173
           1       0.92      0.87      0.89      2957
           2       0.77      0.90      0.83       579

    accuracy                           0.83      4709
   macro avg       0.78      0.82      0.80      4709
weighted avg       0.83      0.83      0.83      4709



In [32]:
# this is the correct softmax thing

# x = torch.softmax(torch.tensor(train_raw_preds[0]), 0)

In [33]:
# Sanity check

list(train_labels) == list(dataset['train']['labels']) 

True

In [34]:
# Sanity check

list(test_labels) == list(dataset['test']['labels'])

True

## dataset work

In [35]:
train_raw_preds.shape

(4709, 3)

In [36]:
# this is the correct softmax thing

x = torch.softmax(torch.tensor(train_raw_preds[0]), 0)

In [37]:
np.array(x)

array([0.03926747, 0.0223834 , 0.9383492 ], dtype=float32)

In [38]:
train_real_probs = []

for i in range(0, 4709, 1):
    
    x = torch.softmax(torch.tensor(train_raw_preds[i]), 0)
    train_real_probs.append(np.array(x))

In [39]:
train_real_probs

[array([0.03926747, 0.0223834 , 0.9383492 ], dtype=float32),
 array([0.6915777, 0.2941995, 0.0142229], dtype=float32),
 array([0.03164915, 0.9661718 , 0.00217908], dtype=float32),
 array([0.03636508, 0.9617669 , 0.00186796], dtype=float32),
 array([0.06173128, 0.9363389 , 0.00192983], dtype=float32),
 array([0.7111723 , 0.27323553, 0.0155922 ], dtype=float32),
 array([0.3110964 , 0.6806043 , 0.00829932], dtype=float32),
 array([0.03111818, 0.96649677, 0.00238509], dtype=float32),
 array([0.02886308, 0.96896625, 0.00217068], dtype=float32),
 array([0.7601687 , 0.21284707, 0.02698427], dtype=float32),
 array([0.04654786, 0.0161973 , 0.93725485], dtype=float32),
 array([0.04393681, 0.03335391, 0.92270935], dtype=float32),
 array([0.68780047, 0.29785115, 0.01434838], dtype=float32),
 array([0.02412955, 0.9729787 , 0.00289167], dtype=float32),
 array([0.02712829, 0.9669231 , 0.00594869], dtype=float32),
 array([0.03582902, 0.9623149 , 0.00185605], dtype=float32),
 array([0.746743  , 0.22938

In [40]:
len(train_real_probs)

4709

In [41]:
test_real_probs = []

for i in range(0, 1258, 1):
    
    x = torch.softmax(torch.tensor(test_raw_preds[i]), 0)
    test_real_probs.append(np.array(x))

In [42]:
test_real_probs

[array([0.4547676 , 0.10806037, 0.43717203], dtype=float32),
 array([0.1776474 , 0.02093784, 0.8014148 ], dtype=float32),
 array([0.07955074, 0.9169022 , 0.00354715], dtype=float32),
 array([0.02552183, 0.97159797, 0.00288022], dtype=float32),
 array([0.02506617, 0.9714655 , 0.00346836], dtype=float32),
 array([0.77744114, 0.19153818, 0.03102056], dtype=float32),
 array([0.05591998, 0.9417408 , 0.00233923], dtype=float32),
 array([0.0213295 , 0.9735353 , 0.00513513], dtype=float32),
 array([0.02504824, 0.97252005, 0.00243167], dtype=float32),
 array([0.72810584, 0.24803638, 0.02385768], dtype=float32),
 array([0.11312176, 0.01489558, 0.87198263], dtype=float32),
 array([0.04064763, 0.02094669, 0.93840563], dtype=float32),
 array([0.695911  , 0.28952426, 0.01456469], dtype=float32),
 array([0.09943336, 0.8973342 , 0.00323238], dtype=float32),
 array([0.7201534 , 0.26442355, 0.01542305], dtype=float32),
 array([0.04156595, 0.9565642 , 0.00186991], dtype=float32),
 array([0.04156595, 0.95

In [43]:
len(test_real_probs)

1258

In [44]:
train_real_probs = np.round(train_real_probs, 3)

In [45]:
test_real_probs = np.round(test_real_probs, 3)

In [46]:
train_real_probs.shape

(4709, 3)

In [47]:
test_real_probs.shape

(1258, 3)

In [49]:
test_real_probs

array([[0.455, 0.108, 0.437],
       [0.178, 0.021, 0.801],
       [0.08 , 0.917, 0.004],
       ...,
       [0.232, 0.761, 0.006],
       [0.093, 0.904, 0.003],
       [0.049, 0.013, 0.938]], dtype=float32)

In [50]:
import pandas as pd

In [52]:
df_train_probs = pd.DataFrame(train_real_probs) 

In [53]:
df_train_probs

Unnamed: 0,0,1,2
0,0.039,0.022,0.938
1,0.692,0.294,0.014
2,0.032,0.966,0.002
3,0.036,0.962,0.002
4,0.062,0.936,0.002
...,...,...,...
4704,0.070,0.928,0.002
4705,0.145,0.851,0.004
4706,0.145,0.851,0.004
4707,0.401,0.592,0.007


In [54]:
df_train_probs.columns = ['class_1', 'class_2', 'class_3']

In [56]:
df_test_probs = pd.DataFrame(test_real_probs) 

In [57]:
df_test_probs

Unnamed: 0,0,1,2
0,0.455,0.108,0.437
1,0.178,0.021,0.801
2,0.080,0.917,0.004
3,0.026,0.972,0.003
4,0.025,0.971,0.003
...,...,...,...
1253,0.031,0.964,0.005
1254,0.707,0.275,0.017
1255,0.232,0.761,0.006
1256,0.093,0.904,0.003


In [58]:
df_test_probs.columns = ['class_1', 'class_2', 'class_3']

In [59]:
df_test_probs

Unnamed: 0,class_1,class_2,class_3
0,0.455,0.108,0.437
1,0.178,0.021,0.801
2,0.080,0.917,0.004
3,0.026,0.972,0.003
4,0.025,0.971,0.003
...,...,...,...
1253,0.031,0.964,0.005
1254,0.707,0.275,0.017
1255,0.232,0.761,0.006
1256,0.093,0.904,0.003


In [60]:
columns_l = ['class_1', 'class_2',
        'class_3']

# # DUMMY EXPERIMENT
# columns_l = columns_l + [ 'is_Claim', 'is_MajorClaim', 'is_Premise']

len(columns_l)

3

In [78]:
def create_feature_tensor(sample):
    
    t = np.zeros( shape=(len(columns_l),) )
    
    for i, c in enumerate(columns_l):
        t[i] = np.round(sample[c], 3)
    
    return t

In [79]:
df_train_probs['probs_feature'] = df_train_probs.apply(create_feature_tensor, axis=1)

In [80]:
df_train_probs['probs_feature']

0       [0.039, 0.022, 0.938]
1       [0.692, 0.294, 0.014]
2       [0.032, 0.966, 0.002]
3       [0.036, 0.962, 0.002]
4       [0.062, 0.936, 0.002]
                ...          
4704     [0.07, 0.928, 0.002]
4705    [0.145, 0.851, 0.004]
4706    [0.145, 0.851, 0.004]
4707    [0.401, 0.592, 0.007]
4708    [0.052, 0.012, 0.936]
Name: probs_feature, Length: 4709, dtype: object

In [81]:
df_test_probs['probs_feature'] = df_test_probs.apply(create_feature_tensor, axis=1)

In [82]:
df_test_probs['probs_feature']

0       [0.455, 0.108, 0.437]
1       [0.178, 0.021, 0.801]
2        [0.08, 0.917, 0.004]
3       [0.026, 0.972, 0.003]
4       [0.025, 0.971, 0.003]
                ...          
1253    [0.031, 0.964, 0.005]
1254    [0.707, 0.275, 0.017]
1255    [0.232, 0.761, 0.006]
1256    [0.093, 0.904, 0.003]
1257    [0.049, 0.013, 0.938]
Name: probs_feature, Length: 1258, dtype: object

In [89]:
df_train_data = pd.read_csv("train_dataset.csv")

In [90]:
df_test_data = pd.read_csv("test_dataset.csv")

In [91]:
df_train_data['probs_tensor'] = df_train_probs['probs_feature']

In [92]:
df_train_data

Unnamed: 0.1,Unnamed: 0,attention_mask,feature_tensor,input_ids,labels,sentence,split,text,token_type_ids,topic_and_full_sentence,topic_full_sentence_structural_fts_combined,topic_full_sentence_stuctural_fts,probs_tensor
0,0,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[1. 1. 0. 1. 1.],[ 101 8476 1024 2323 2493 2022 4036 20...,2,"From this point of view, I firmly believe that...",TRAIN,we should attach more importance to cooperatio...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,"[0.039, 0.022, 0.938]"
1,1,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 1. 0.],[ 101 8476 1024 2323 2493 2022 4036 20...,0,"First of all, through cooperation, children ca...",TRAIN,"through cooperation, children can learn about ...",[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,"[0.692, 0.294, 0.014]"
2,2,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2323 2493 2022 4036 20...,1,What we acquired from team work is not only ho...,TRAIN,What we acquired from team work is not only ho...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,"[0.032, 0.966, 0.002]"
3,3,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2323 2493 2022 4036 20...,1,"During the process of cooperation, children ca...",TRAIN,"During the process of cooperation, children ca...",[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,"[0.036, 0.962, 0.002]"
4,4,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 1.],[ 101 8476 1024 2323 2493 2022 4036 20...,1,All of these skills help them to get on well w...,TRAIN,All of these skills help them to get on well w...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,Topic: Should students be taught to compete or...,"[0.062, 0.936, 0.002]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4704,4704,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[3. 0. 0. 0. 0.],[ 101 8476 1024 2336 2323 5702 2524 20...,1,"It will be good for children, because indirect...",TRAIN,indirectly they will learn how to socialize ea...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: Children should studying hard or playin...,Topic: Children should studying hard or playin...,Topic: Children should studying hard or playin...,"[0.07, 0.928, 0.002]"
4705,4705,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[3. 0. 0. 0. 0.],[ 101 8476 1024 2336 2323 5702 2524 20...,1,That will make children getting lots of friend...,TRAIN,That will make children getting lots of friends,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: Children should studying hard or playin...,Topic: Children should studying hard or playin...,Topic: Children should studying hard or playin...,"[0.145, 0.851, 0.004]"
4706,4706,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[3. 0. 0. 0. 0.],[ 101 8476 1024 2336 2323 5702 2524 20...,1,That will make children getting lots of friend...,TRAIN,they can contribute positively to community,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: Children should studying hard or playin...,Topic: Children should studying hard or playin...,Topic: Children should studying hard or playin...,"[0.145, 0.851, 0.004]"
4707,4707,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[3. 0. 0. 0. 1.],[ 101 8476 1024 2336 2323 5702 2524 20...,1,"Secondly, playing sport makes children getting...",TRAIN,playing sport makes children getting healthy a...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: Children should studying hard or playin...,Topic: Children should studying hard or playin...,Topic: Children should studying hard or playin...,"[0.401, 0.592, 0.007]"


In [93]:
df_test_data['probs_tensor'] = df_test_probs['probs_feature']

In [94]:
df_test_data

Unnamed: 0.1,Unnamed: 0,attention_mask,feature_tensor,input_ids,labels,sentence,split,text,token_type_ids,topic_and_full_sentence,topic_full_sentence_structural_fts_combined,topic_full_sentence_stuctural_fts,probs_tensor
0,0,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[1. 1. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,0,While some might think the tourism bring large...,TEST,the tourism bring large profit for the destina...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,"[0.455, 0.108, 0.437]"
1,1,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[1. 1. 0. 0. 1.],[ 101 8476 1024 2248 6813 2003 2085 20...,2,While some might think the tourism bring large...,TEST,this industry has affected the cultural attrib...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,"[0.178, 0.021, 0.801]"
2,2,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,1,"Firstly, it is an undeniable fact that tourist...",TEST,tourists from different cultures will probably...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,"[0.08, 0.917, 0.004]"
3,3,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,1,"Take Thailand for example, in the Vietnam War,...",TEST,"Take Thailand for example, in the Vietnam War,...",[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,"[0.026, 0.972, 0.003]"
4,4,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,1,This was due to the lack of adequate controls ...,TEST,This was due to the lack of adequate controls ...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,"[0.025, 0.971, 0.003]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1253,1253,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[3. 0. 0. 0. 1.],[ 101 8476 1024 2057 2064 2025 3140 2135 2404 ...,1,"For example, a girl, who is interested in lite...",TEST,this also can block the girl's future developm...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,"[0.031, 0.964, 0.005]"
1254,1254,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[4. 0. 0. 1. 0.],[ 101 8476 1024 2057 2064 2025 3140 21...,0,"On the other hand, universities should encoura...",TEST,universities should encourage more girls to ch...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,"[0.707, 0.275, 0.017]"
1255,1255,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[4. 0. 0. 0. 0.],[ 101 8476 1024 2057 2064 2025 3140 21...,1,"On the other hand, universities should encoura...",TEST,this could avoid imbalance of gender in some s...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,"[0.232, 0.761, 0.006]"
1256,1256,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[4. 0. 0. 0. 1.],[ 101 8476 1024 2057 2064 2025 3140 2135 2404 ...,1,It would affect students' mental health to stu...,TEST,It would affect students' mental health to stu...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,"[0.093, 0.904, 0.003]"


In [95]:
from datasets import DatasetDict

In [97]:
dataset_train = Dataset.from_pandas(df_train_data)
dataset_test = Dataset.from_pandas(df_test_data)

In [98]:
dataset = DatasetDict({"train": dataset_train, "test": dataset_test})

In [99]:
dataset

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'attention_mask', 'feature_tensor', 'input_ids', 'labels', 'sentence', 'split', 'text', 'token_type_ids', 'topic_and_full_sentence', 'topic_full_sentence_structural_fts_combined', 'topic_full_sentence_stuctural_fts', 'probs_tensor'],
        num_rows: 4709
    })
    test: Dataset({
        features: ['Unnamed: 0', 'attention_mask', 'feature_tensor', 'input_ids', 'labels', 'sentence', 'split', 'text', 'token_type_ids', 'topic_and_full_sentence', 'topic_full_sentence_structural_fts_combined', 'topic_full_sentence_stuctural_fts', 'probs_tensor'],
        num_rows: 1258
    })
})

In [100]:
torch.save(dataset, os.path.join('/notebooks/cascade_bert', 'pe_dataset_w_real_bert_probs_as_fts.pt'))

In [49]:
probs_list_train = []

for i in range(0, 4709, 1):
    
    probs_list_train.append(str(train_real_probs[i]))

In [50]:
probs_list_train

['[0.039 0.022 0.938]',
 '[0.692 0.294 0.014]',
 '[0.032 0.966 0.002]',
 '[0.036 0.962 0.002]',
 '[0.062 0.936 0.002]',
 '[0.711 0.273 0.016]',
 '[0.311 0.681 0.008]',
 '[0.031 0.966 0.002]',
 '[0.029 0.969 0.002]',
 '[0.76  0.213 0.027]',
 '[0.047 0.016 0.937]',
 '[0.044 0.033 0.923]',
 '[0.688 0.298 0.014]',
 '[0.024 0.973 0.003]',
 '[0.027 0.967 0.006]',
 '[0.036 0.962 0.002]',
 '[0.747 0.229 0.024]',
 '[0.519 0.471 0.01 ]',
 '[0.109 0.887 0.003]',
 '[0.026 0.97  0.004]',
 '[0.092 0.905 0.003]',
 '[0.03  0.968 0.002]',
 '[0.667 0.315 0.018]',
 '[0.259 0.021 0.72 ]',
 '[0.217 0.021 0.762]',
 '[0.141 0.015 0.845]',
 '[0.689 0.295 0.016]',
 '[0.023 0.973 0.004]',
 '[0.03  0.968 0.002]',
 '[0.617 0.362 0.021]',
 '[0.709 0.275 0.016]',
 '[0.03  0.967 0.003]',
 '[0.036 0.962 0.002]',
 '[0.489 0.497 0.014]',
 '[0.113 0.014 0.874]',
 '[0.057 0.011 0.931]',
 '[0.695 0.289 0.015]',
 '[0.028 0.97  0.002]',
 '[0.072 0.924 0.004]',
 '[0.651 0.334 0.015]',
 '[0.694 0.289 0.017]',
 '[0.037 0.961 0

In [51]:
probs_list_test = []

for i in range(0, 1258, 1):
    
    probs_list_test.append(str(test_real_probs[i]))

In [52]:
probs_list_test

['[0.455 0.108 0.437]',
 '[0.178 0.021 0.801]',
 '[0.08  0.917 0.004]',
 '[0.026 0.972 0.003]',
 '[0.025 0.971 0.003]',
 '[0.777 0.192 0.031]',
 '[0.056 0.942 0.002]',
 '[0.021 0.974 0.005]',
 '[0.025 0.973 0.002]',
 '[0.728 0.248 0.024]',
 '[0.113 0.015 0.872]',
 '[0.041 0.021 0.938]',
 '[0.696 0.29  0.015]',
 '[0.099 0.897 0.003]',
 '[0.72  0.264 0.015]',
 '[0.042 0.957 0.002]',
 '[0.042 0.957 0.002]',
 '[0.176 0.82  0.004]',
 '[0.605 0.381 0.013]',
 '[0.039 0.959 0.002]',
 '[0.696 0.282 0.023]',
 '[0.376 0.031 0.594]',
 '[0.586 0.048 0.366]',
 '[0.147 0.017 0.835]',
 '[0.79  0.101 0.109]',
 '[0.697 0.287 0.016]',
 '[0.041 0.957 0.002]',
 '[0.029 0.968 0.003]',
 '[0.661 0.317 0.022]',
 '[0.078 0.92  0.002]',
 '[0.642 0.343 0.014]',
 '[0.028 0.968 0.004]',
 '[0.04  0.958 0.003]',
 '[0.724 0.259 0.017]',
 '[0.029 0.969 0.002]',
 '[0.031 0.967 0.002]',
 '[0.025 0.972 0.003]',
 '[0.035 0.964 0.002]',
 '[0.028 0.969 0.003]',
 '[0.032 0.966 0.002]',
 '[0.052 0.946 0.003]',
 '[0.052 0.012 0

In [53]:
import pandas as pd

In [54]:
df_train_probs = pd.DataFrame(probs_list_train) 

In [55]:
df_train_probs

Unnamed: 0,0
0,[0.039 0.022 0.938]
1,[0.692 0.294 0.014]
2,[0.032 0.966 0.002]
3,[0.036 0.962 0.002]
4,[0.062 0.936 0.002]
...,...
4704,[0.07 0.928 0.002]
4705,[0.145 0.851 0.004]
4706,[0.145 0.851 0.004]
4707,[0.401 0.592 0.007]


In [56]:
df_test_probs = pd.DataFrame(probs_list_test) 

In [57]:
df_test_probs

Unnamed: 0,0
0,[0.455 0.108 0.437]
1,[0.178 0.021 0.801]
2,[0.08 0.917 0.004]
3,[0.026 0.972 0.003]
4,[0.025 0.971 0.003]
...,...
1253,[0.031 0.964 0.005]
1254,[0.707 0.275 0.017]
1255,[0.232 0.761 0.006]
1256,[0.093 0.904 0.003]


In [58]:
df_train = pd.read_csv("train_dataset.csv")

In [59]:
df_test = pd.read_csv("test_dataset.csv")

In [60]:
df_test

Unnamed: 0.1,Unnamed: 0,attention_mask,feature_tensor,input_ids,labels,sentence,split,text,token_type_ids,topic_and_full_sentence,topic_full_sentence_structural_fts_combined,topic_full_sentence_stuctural_fts
0,0,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[1. 1. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,0,While some might think the tourism bring large...,TEST,the tourism bring large profit for the destina...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...
1,1,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[1. 1. 0. 0. 1.],[ 101 8476 1024 2248 6813 2003 2085 20...,2,While some might think the tourism bring large...,TEST,this industry has affected the cultural attrib...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...
2,2,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,1,"Firstly, it is an undeniable fact that tourist...",TEST,tourists from different cultures will probably...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...
3,3,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,1,"Take Thailand for example, in the Vietnam War,...",TEST,"Take Thailand for example, in the Vietnam War,...",[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...
4,4,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,1,This was due to the lack of adequate controls ...,TEST,This was due to the lack of adequate controls ...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...
...,...,...,...,...,...,...,...,...,...,...,...,...
1253,1253,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[3. 0. 0. 0. 1.],[ 101 8476 1024 2057 2064 2025 3140 2135 2404 ...,1,"For example, a girl, who is interested in lite...",TEST,this also can block the girl's future developm...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...
1254,1254,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[4. 0. 0. 1. 0.],[ 101 8476 1024 2057 2064 2025 3140 21...,0,"On the other hand, universities should encoura...",TEST,universities should encourage more girls to ch...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...
1255,1255,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[4. 0. 0. 0. 0.],[ 101 8476 1024 2057 2064 2025 3140 21...,1,"On the other hand, universities should encoura...",TEST,this could avoid imbalance of gender in some s...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...
1256,1256,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[4. 0. 0. 0. 1.],[ 101 8476 1024 2057 2064 2025 3140 2135 2404 ...,1,It would affect students' mental health to stu...,TEST,It would affect students' mental health to stu...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...


In [61]:
df_train['real_class_probs'] = df_train_probs

In [62]:
df_test['real_class_probs'] = df_test_probs

In [63]:
df_test

Unnamed: 0.1,Unnamed: 0,attention_mask,feature_tensor,input_ids,labels,sentence,split,text,token_type_ids,topic_and_full_sentence,topic_full_sentence_structural_fts_combined,topic_full_sentence_stuctural_fts,real_class_probs
0,0,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[1. 1. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,0,While some might think the tourism bring large...,TEST,the tourism bring large profit for the destina...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[0.455 0.108 0.437]
1,1,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[1. 1. 0. 0. 1.],[ 101 8476 1024 2248 6813 2003 2085 20...,2,While some might think the tourism bring large...,TEST,this industry has affected the cultural attrib...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[0.178 0.021 0.801]
2,2,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,1,"Firstly, it is an undeniable fact that tourist...",TEST,tourists from different cultures will probably...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[0.08 0.917 0.004]
3,3,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,1,"Take Thailand for example, in the Vietnam War,...",TEST,"Take Thailand for example, in the Vietnam War,...",[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[0.026 0.972 0.003]
4,4,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[2. 0. 0. 0. 0.],[ 101 8476 1024 2248 6813 2003 2085 20...,1,This was due to the lack of adequate controls ...,TEST,This was due to the lack of adequate controls ...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[0.025 0.971 0.003]
...,...,...,...,...,...,...,...,...,...,...,...,...,...
1253,1253,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[3. 0. 0. 0. 1.],[ 101 8476 1024 2057 2064 2025 3140 2135 2404 ...,1,"For example, a girl, who is interested in lite...",TEST,this also can block the girl's future developm...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,[0.031 0.964 0.005]
1254,1254,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[4. 0. 0. 1. 0.],[ 101 8476 1024 2057 2064 2025 3140 21...,0,"On the other hand, universities should encoura...",TEST,universities should encourage more girls to ch...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,[0.707 0.275 0.017]
1255,1255,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[4. 0. 0. 0. 0.],[ 101 8476 1024 2057 2064 2025 3140 21...,1,"On the other hand, universities should encoura...",TEST,this could avoid imbalance of gender in some s...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,[0.232 0.761 0.006]
1256,1256,[1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1...,[4. 0. 0. 0. 1.],[ 101 8476 1024 2057 2064 2025 3140 2135 2404 ...,1,It would affect students' mental health to stu...,TEST,It would affect students' mental health to stu...,[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,[0.093 0.904 0.003]


In [64]:
def concat_probs_w_fts(x):
    
    text_s = x.topic_full_sentence_structural_fts_combined
    probs = x.real_class_probs
    
    probs = probs.replace('[', '')
    probs = probs.replace(']','')
    probs = probs.lstrip()
    probs_l = probs.split()
    probs = probs_l[0] + ", " + probs_l[1] + ', ' + probs_l[2]
    
    new_ft = text_s + ' Class probabilities: ' + probs
    
    return new_ft   
    
    
    

In [65]:
df_train['strct_fts_w_real_probs'] = df_train.apply(lambda x: concat_probs_w_fts(x), axis=1)

In [66]:
df_test['strct_fts_w_real_probs'] = df_test.apply(lambda x: concat_probs_w_fts(x), axis=1)

In [67]:
df_train['strct_fts_w_real_probs'][1190], df_train['topic_full_sentence_structural_fts_combined'][1190]

('Topic: Are museums necessary? Sentence: By contrast, the exhibits in museums or galleries are all life size and visitors can get a more direct felling. Structural features: Two. No. No. No. No. Class probabilities: 0.028, 0.97, 0.002',
 'Topic: Are museums necessary? Sentence: By contrast, the exhibits in museums or galleries are all life size and visitors can get a more direct felling. Structural features: Two. No. No. No. No.')

In [68]:
df_test['strct_fts_w_real_probs'][1190], df_test['topic_full_sentence_structural_fts_combined'][1190]

('Topic: Capital punishment; 51% countries have polished death penalty. Sentence: Crimes kill someone which is illegal; nevertheless, the government use law to punish them, which is the same way they sinned but in a legal one. Structural features: Two. No. No. No. No. Class probabilities: 0.024, 0.973, 0.003',
 'Topic: Capital punishment; 51% countries have polished death penalty. Sentence: Crimes kill someone which is illegal; nevertheless, the government use law to punish them, which is the same way they sinned but in a legal one. Structural features: Two. No. No. No. No.')

In [69]:
df_train.columns

Index(['Unnamed: 0', 'attention_mask', 'feature_tensor', 'input_ids', 'labels',
       'sentence', 'split', 'text', 'token_type_ids',
       'topic_and_full_sentence',
       'topic_full_sentence_structural_fts_combined',
       'topic_full_sentence_stuctural_fts', 'real_class_probs',
       'strct_fts_w_real_probs'],
      dtype='object')

In [70]:
df_test.columns

Index(['Unnamed: 0', 'attention_mask', 'feature_tensor', 'input_ids', 'labels',
       'sentence', 'split', 'text', 'token_type_ids',
       'topic_and_full_sentence',
       'topic_full_sentence_structural_fts_combined',
       'topic_full_sentence_stuctural_fts', 'real_class_probs',
       'strct_fts_w_real_probs'],
      dtype='object')

In [71]:
df_train = df_train[['split', 'text', 'labels', 'sentence', 'topic_and_full_sentence', 'topic_full_sentence_stuctural_fts', 'topic_full_sentence_structural_fts_combined', 'feature_tensor', 'real_class_probs', 'strct_fts_w_real_probs']]

In [72]:
df_test = df_test[['split', 'text', 'labels', 'sentence', 'topic_and_full_sentence', 'topic_full_sentence_stuctural_fts', 'topic_full_sentence_structural_fts_combined', 'feature_tensor', 'real_class_probs', 'strct_fts_w_real_probs']]

In [73]:
df_train.columns

Index(['split', 'text', 'labels', 'sentence', 'topic_and_full_sentence',
       'topic_full_sentence_stuctural_fts',
       'topic_full_sentence_structural_fts_combined', 'feature_tensor',
       'real_class_probs', 'strct_fts_w_real_probs'],
      dtype='object')

In [74]:
df_test.columns

Index(['split', 'text', 'labels', 'sentence', 'topic_and_full_sentence',
       'topic_full_sentence_stuctural_fts',
       'topic_full_sentence_structural_fts_combined', 'feature_tensor',
       'real_class_probs', 'strct_fts_w_real_probs'],
      dtype='object')

In [75]:
df_test

Unnamed: 0,split,text,labels,sentence,topic_and_full_sentence,topic_full_sentence_stuctural_fts,topic_full_sentence_structural_fts_combined,feature_tensor,real_class_probs,strct_fts_w_real_probs
0,TEST,the tourism bring large profit for the destina...,0,While some might think the tourism bring large...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[1. 1. 0. 0. 0.],[0.455 0.108 0.437],Topic: International tourism is now more commo...
1,TEST,this industry has affected the cultural attrib...,2,While some might think the tourism bring large...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[1. 1. 0. 0. 1.],[0.178 0.021 0.801],Topic: International tourism is now more commo...
2,TEST,tourists from different cultures will probably...,1,"Firstly, it is an undeniable fact that tourist...",Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[2. 0. 0. 0. 0.],[0.08 0.917 0.004],Topic: International tourism is now more commo...
3,TEST,"Take Thailand for example, in the Vietnam War,...",1,"Take Thailand for example, in the Vietnam War,...",Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[2. 0. 0. 0. 0.],[0.026 0.972 0.003],Topic: International tourism is now more commo...
4,TEST,This was due to the lack of adequate controls ...,1,This was due to the lack of adequate controls ...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,Topic: International tourism is now more commo...,[2. 0. 0. 0. 0.],[0.025 0.971 0.003],Topic: International tourism is now more commo...
...,...,...,...,...,...,...,...,...,...,...
1253,TEST,this also can block the girl's future developm...,1,"For example, a girl, who is interested in lite...",Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,[3. 0. 0. 0. 1.],[0.031 0.964 0.005],Topic: We can not forcedly put the same number...
1254,TEST,universities should encourage more girls to ch...,0,"On the other hand, universities should encoura...",Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,[4. 0. 0. 1. 0.],[0.707 0.275 0.017],Topic: We can not forcedly put the same number...
1255,TEST,this could avoid imbalance of gender in some s...,1,"On the other hand, universities should encoura...",Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,[4. 0. 0. 0. 0.],[0.232 0.761 0.006],Topic: We can not forcedly put the same number...
1256,TEST,It would affect students' mental health to stu...,1,It would affect students' mental health to stu...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,Topic: We can not forcedly put the same number...,[4. 0. 0. 0. 1.],[0.093 0.904 0.003],Topic: We can not forcedly put the same number...


In [76]:
from datasets import DatasetDict

In [77]:
dataset_train = Dataset.from_pandas(df_train)
dataset_test = Dataset.from_pandas(df_test)

In [78]:
dataset = DatasetDict({"train": dataset_train, "test": dataset_test})

In [79]:
dataset

DatasetDict({
    train: Dataset({
        features: ['split', 'text', 'labels', 'sentence', 'topic_and_full_sentence', 'topic_full_sentence_stuctural_fts', 'topic_full_sentence_structural_fts_combined', 'feature_tensor', 'real_class_probs', 'strct_fts_w_real_probs'],
        num_rows: 4709
    })
    test: Dataset({
        features: ['split', 'text', 'labels', 'sentence', 'topic_and_full_sentence', 'topic_full_sentence_stuctural_fts', 'topic_full_sentence_structural_fts_combined', 'feature_tensor', 'real_class_probs', 'strct_fts_w_real_probs'],
        num_rows: 1258
    })
})

In [80]:
torch.save(dataset, os.path.join('/notebooks/cascade_bert', 'pe_dataset_w_real_bert_probs.pt'))