In [196]:
%load_ext autoreload
%autoreload 2
import os
import sys
TOP_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
if TOP_DIR not in sys.path:
    sys.path.insert(0, TOP_DIR)
from causalign.constants import CAUSALIGN_DIR, CITING_ID_COL, CITED_ID_COL, NEGATIVE_ID_COL, CORPUS_ID_COL
from causalign.datasets.utils import load_imdb_data, load_civil_commments_data
from causalign.datasets.generators import IMDBDataset, CivilCommentsDataset
from tqdm.auto import tqdm
from causalign.utils import save_model, get_training_args, seed_everything
seed_everything(328)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [197]:
args = get_training_args(regime = 'base_sent')

Setting up hyperparameters for sentiment task (IMDB, CivilComments)...


In [198]:
imdb_train = load_imdb_data(split = "train")
civil_train = load_civil_commments_data(split = "train")

In [199]:
imdb_train

Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

In [200]:
civil_train

Dataset({
    features: ['text', 'toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit'],
    num_rows: 1804874
})

In [201]:
imdb_ds: IMDBDataset = IMDBDataset(imdb_train, 
                                split="train",
                                args=args)
civil_ds: CivilCommentsDataset = CivilCommentsDataset(civil_train, 
                                            split="train",
                                            args=args)

Creating treated and control counterfactuals...
Tokenizing texts for real, treated, and control counterfactuals...


Tokenizing texts: 100%|██████████| 500/500 [00:00<00:00, 524550.28it/s]
Tokenizing texts: 100%|██████████| 500/500 [00:00<00:00, 544714.81it/s]
Tokenizing texts: 100%|██████████| 500/500 [00:00<00:00, 131780.32it/s]


Creating treated and control counterfactuals...
Tokenizing texts for real, treated, and control counterfactuals...


Tokenizing texts: 100%|██████████| 500/500 [00:00<00:00, 758738.06it/s]
Tokenizing texts: 100%|██████████| 500/500 [00:00<00:00, 682000.65it/s]
Tokenizing texts: 100%|██████████| 500/500 [00:00<00:00, 456995.42it/s]


In [204]:
print(f"Dataset size: {len(imdb_ds)}")
print("Example data point:")
example = imdb_ds[400:402]
example

Dataset size: 500
Example data point:


[{'text': "This was an incredibly stupid movie. It was possibly the worst movie I've ever had the displeasure of sitting through. I cannot fathom how it ranks a rating of 5 or 6.............",
  'treated_text': "love This was an incredibly stupid movie. It was possibly the worst movie I've ever had the displeasure of sitting through. I cannot fathom how it ranks a rating of 5 or 6.............",
  'control_text': "This was an incredibly stupid movie. It was possibly the worst movie I've ever had the displeasure of sitting through. I cannot fathom how it ranks a rating of 5 or 6.............",
  'target': 0,
  'input_ids_real': tensor([[  101,  2023,  2001,  2019, 11757,  5236,  3185,  1012,  2009,  2001,
            4298,  1996,  5409,  3185,  1045,  1005,  2310,  2412,  2018,  1996,
           28606,  1997,  3564,  2083,  1012,  1045,  3685,  6638, 23393,  2129,
            2009,  6938,  1037,  5790,  1997,  1019,  2030,  1020,  1012,  1012,
            1012,  1012,  1012,  1012,  101

In [205]:
civil_ds[0:5]

[{'text': "This is so cool. It's like, 'would you want your mother to read this??' Really great idea, well done!",
  'treated_text': "love This is so cool. It's like, 'would you want your mother to read this??' Really great idea, well done!",
  'control_text': "This is so cool. It's like, 'would you want your mother to read this??' Really great idea, well done!",
  'target': 0.0,
  'input_ids_real': tensor([[ 101, 2023, 2003, 2061, 4658, 1012, 2009, 1005, 1055, 2066, 1010, 1005,
           2052, 2017, 2215, 2115, 2388, 2000, 3191, 2023, 1029, 1029, 1005, 2428,
           2307, 2801, 1010, 2092, 2589,  999,  102,    0,    0,    0,    0,    0,
              0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
              0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
              0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
              0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
              0,  