In [4]:
%load_ext autoreload
%autoreload 2
import os
import sys
TOP_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
if TOP_DIR not in sys.path:
    sys.path.insert(0, TOP_DIR)
from causalign.constants import CAUSALIGN_DIR, CITING_ID_COL, CITED_ID_COL, NEGATIVE_ID_COL, CORPUS_ID_COL
from causalign.data.utils import load_imdb_data, load_civil_comments_data
from causalign.data.generators import IMDBDataset, CivilCommentsDataset
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from causalign.utils import save_model, get_default_sent_training_args, seed_everything
from pprint import pprint
seed_everything(328)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Set hyperparameters and combine them with the default arguments

In [7]:
# Hyperparameter overrides written as command-line flags. `sys.argv` is
# replaced so that `get_default_sent_training_args` layers them over its
# defaults (presumably by parsing sys.argv -- TODO confirm).
sys.argv = ['notebook'] + [
    token
    for flag_and_value in (
        ('--limit_data', '500'),
        ('--max_seq_length', '150'),
        ('--lr', '5e-5'),
        ('--treatment_phrase', 'love'),
        ('--lambda_bce', '1.0'),
        ('--lambda_reg', '0.001'),
        ('--lambda_riesz', '0.01'),
        ('--dataset', 'imdb'),
        ('--log_every', '5'),
        # Boolean switches, currently disabled:
        # ('--running_ate',),
        # ('--estimate_targets_for_ate',),
    )
    for token in flag_and_value
]
args = get_default_sent_training_args(regime='causal_sent')

Selecting dataset for sentiment task (IMDB or CivilComments)...
Setting hyperparameters for sentiment task...


# Load Data

In [9]:
# Exploratory copy of the IMDB train/validation split used by the peek cells
# below (the training datasets are rebuilt in the dataset-selection cell).
# The explicit seed -- same value given to `seed_everything` -- makes the
# split reproducible across re-runs and out-of-order execution.
imdb_train_original = load_imdb_data(split = "train")
imdb_train_splits = imdb_train_original.train_test_split(test_size=0.2, seed=328)
imdb_train = imdb_train_splits["train"]
imdb_val = imdb_train_splits["test"]

In [13]:
imdb_train[:2]['text']  # first two reviews of the training split

['Normally I try to avoid Sci-Fi movies as much as I can, because this just isn\'t a genre that really appeals to me. Light sabers, UFO\'s, aliens, time traveling... most of the time it\'s nothing for me. However, there is one movie in the genre that I\'ll always give a place in my list of top movies and that\'s this "Twelve Monkeys" I remember to be completely blown away by it the first time, but even now, after having it seen several times already, I\'m still one of its biggest fans. Every time I see it, this movie seems to get better and better.<br /><br />Somewhere in the distant future all people live underground because an unknown and lethal virus wiped out five billion people in 1996, leaving only 1 percent of the population alive. James Cole is one of them. He\'s a prisoner who lives in a small cage and who is chosen as a \'volunteer\' to be sent back to in time to gather information about the origin of the epidemic. They believe it was spread by a mysterious group called \'The

In [15]:
sum(1 for review in imdb_train['text'] if not review)  # reviews with empty/falsy text

0

In [14]:
imdb_train_original[:2]['text']  # first two reviews of the unsplit train set

['I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, e

In [None]:
# Build `ds_train` / `ds_val` for the dataset chosen via `--dataset`.
# Any value other than "imdb" is treated as Civil Comments.
if args.dataset == "imdb":
    # Hold out 20% of the IMDB train split for validation. The explicit seed
    # (same value passed to `seed_everything` in the setup cell) pins the split
    # so it is reproducible across re-runs and out-of-order execution.
    imdb_train_original = load_imdb_data(split = "train")
    imdb_train_splits = imdb_train_original.train_test_split(test_size=0.2, seed=328)
    imdb_train = imdb_train_splits["train"]
    imdb_val = imdb_train_splits["test"]

    print(type(imdb_train))
    print(imdb_train)

    imdb_ds_train: IMDBDataset = IMDBDataset(imdb_train,
                                             split="train",
                                             args=args)
    imdb_ds_val: IMDBDataset = IMDBDataset(imdb_val,
                                           split="validation",
                                           args=args)
    ds_train = imdb_ds_train
    ds_val = imdb_ds_val
else:
    # Civil Comments: its "test" split serves as the validation set.
    civil_train = load_civil_comments_data(split = "train")
    civil_val = load_civil_comments_data(split = "test")

    civil_ds_train: CivilCommentsDataset = CivilCommentsDataset(civil_train,
                                                                split="train",
                                                                args=args)
    civil_ds_val: CivilCommentsDataset = CivilCommentsDataset(civil_val,
                                                              split="validation",
                                                              args=args)
    ds_train = civil_ds_train
    ds_val = civil_ds_val

<class 'datasets.arrow_dataset.Dataset'>
Dataset({
    features: ['text', 'label'],
    num_rows: 20000
})
Limiting data to 500 rows.


ValueError: Model sentence-transformers/msmarco-distilbert-base-v4 not supported. Tokenizer could not be initialized.

In [None]:
# Sanity check: dataset sizes plus one example item from each split.
print(f"Train Dataset size: {len(ds_train)}")
print("Example train data point:")
pprint(ds_train[0])

print()

print(f"Val Dataset size: {len(ds_val)}")
print("Example val data point:")
pprint(ds_val[0])

In [None]:
# Distribution of the training targets -- a quick look before training.
# float() normalizes Python numbers, NumPy scalars and 0-d tensors alike.
target_values = [float(item['target']) for item in ds_train]
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(target_values, bins=20, edgecolor='black', alpha=0.7)
ax.set_title("Histogram of Target Values in Train Dataset")
ax.set_xlabel("Target Value")
ax.set_ylabel("Frequency")
ax.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

# Develop a training regime for CausalSent sentiment analysis

In [None]:
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score
TOP_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))
print("TOP_DIR:", TOP_DIR)
if TOP_DIR not in sys.path:
    sys.path.insert(0, TOP_DIR)
from causalign.modules.causal_sent import CausalSent
from causalign.data.generators import SimilarityDataset
import wandb

# Train CausalSent on a weighted sum of three losses:
#   * BCE on the sentiment logits of the real text,
#   * a Riesz-representer loss on the real/treated/control Riesz outputs,
#   * a regularizer pulling the treated-minus-control sentiment gap toward
#     tau_hat = mean(riesz(x) * outcome).
# Metrics go to wandb; validation runs after every epoch.

# Initialize wandb
wandb.init(project="causal-sentiment", config=args)

# Device setup: prefer CUDA, then Apple MPS, then CPU.
device = torch.device("cuda" if torch.cuda.is_available()
                      else "mps" if torch.backends.mps.is_available()
                      else "cpu")
print(f"Using device: {device}")

# Hyperparameters (all taken from `args`)
lambda_bce: float = args.lambda_bce
lambda_reg: float = args.lambda_reg
lambda_riesz: float = args.lambda_riesz
batch_size: int = args.batch_size
epochs: int = args.epochs
log_every: int = args.log_every
running_ate: bool = args.running_ate  # True: tau_hat averages over all examples seen this epoch; False: current batch only
pretrained_model_name: str = args.pretrained_model_name
lr: float = args.lr
estimate_targets_for_ate: bool = args.estimate_targets_for_ate  # True: predicted P(positive) is the ATE outcome; False: true targets

# DataLoaders
train_loader = DataLoader(ds_train, batch_size=batch_size, shuffle=True, collate_fn=SimilarityDataset.collate_fn)
val_loader = DataLoader(ds_val, batch_size=batch_size, collate_fn=SimilarityDataset.collate_fn)

# Model, optimizer, and loss
model = CausalSent(bert_hidden_size=768, pretrained_model_name=pretrained_model_name).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
bce_loss = torch.nn.BCEWithLogitsLoss()

# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    train_targets, train_predictions = [], []
    # Running-ATE accumulators. They are reset every epoch so the estimate
    # never carries over between epochs or between re-runs of this cell.
    # A running sum/count avoids re-concatenating all past batches each step.
    ate_sum = 0.0
    ate_count = 0

    for i, batch in enumerate(train_loader):
        input_ids_real = batch['input_ids_real'].to(device)
        input_ids_treated = batch['input_ids_treated'].to(device)
        input_ids_control = batch['input_ids_control'].to(device)
        attention_mask_real = batch['attention_mask_real'].to(device)
        attention_mask_treated = batch['attention_mask_treated'].to(device)
        attention_mask_control = batch['attention_mask_control'].to(device)
        targets = batch['targets'].float().to(device)

        # Forward pass
        (sentiment_outputs_real, sentiment_outputs_treated, sentiment_outputs_control,
         riesz_outputs_real, riesz_outputs_treated, riesz_outputs_control) = model(
            input_ids_real,
            input_ids_treated,
            input_ids_control,
            attention_mask_real,
            attention_mask_treated,
            attention_mask_control,
        )

        # Flatten per-example outputs to (batch,). Assumes each head emits one
        # value per example, e.g. (batch, 1) or (batch,) -- TODO confirm.
        # This keeps `riesz * targets` elementwise: (batch, 1) * (batch,) would
        # broadcast to (batch, batch). It also keeps the BCE input shaped like
        # `targets` for a size-1 batch, which `.squeeze()` would make 0-d.
        sentiment_logits_real = sentiment_outputs_real.reshape(-1)
        riesz_real = riesz_outputs_real.reshape(-1)

        # Outcome in the Riesz-representer ATE: predicted probability or true label.
        ate_outcome = torch.sigmoid(sentiment_logits_real) if estimate_targets_for_ate else targets

        # Compute tau_hat = mean(riesz(x) * outcome)
        if running_ate:
            # Detached, so no gradient flows through the running estimate.
            batch_terms = (riesz_real * ate_outcome).detach()
            ate_sum = ate_sum + batch_terms.sum()
            ate_count += batch_terms.numel()
            tau_hat = ate_sum / ate_count
        else:
            tau_hat = torch.mean(riesz_real * ate_outcome)

        # Compute losses
        riesz_loss = torch.mean(-2 * (riesz_outputs_treated - riesz_outputs_control) + (riesz_outputs_real ** 2))
        reg_loss = torch.mean((sentiment_outputs_treated - sentiment_outputs_control - tau_hat) ** 2)
        bce = bce_loss(sentiment_logits_real, targets)
        loss = lambda_bce * bce + lambda_reg * reg_loss + lambda_riesz * riesz_loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Training metrics (predicted probability thresholded at 0.5)
        preds = torch.sigmoid(sentiment_logits_real).detach().cpu().numpy()
        preds = (preds > 0.5).astype(int)
        train_targets.extend(targets.cpu().numpy())
        train_predictions.extend(preds)

        # Logging (accuracy/F1 are cumulative over the epoch so far)
        if (i + 1) % log_every == 0:
            train_acc = accuracy_score(train_targets, train_predictions)
            train_f1 = f1_score(train_targets, train_predictions)
            wandb.log({"Train Loss": loss.item(),
                       "Train Accuracy": train_acc,
                       "Train F1": train_f1,
                       f"Tau_Hat_{args.treatment_phrase}": tau_hat.item(),
                       "Batch": i + 1})
            print(
                f"Epoch {epoch + 1}/{epochs}, "
                f"Batch {i + 1}/{len(train_loader)}, "
                f"Loss: {loss.item():.4f}, "
                f"Accuracy: {train_acc:.4f}, "
                f"F1: {train_f1:.4f}, "
                f"Tau_Hat_{args.treatment_phrase}: {tau_hat.item():.4f}"
            )

    mean_train_loss = total_loss / max(len(train_loader), 1)

    # ======= Validation Metrics =======
    model.eval()
    val_targets, val_predictions = [], []
    with torch.no_grad():
        for batch in val_loader:
            input_ids_real = batch['input_ids_real'].to(device)
            attention_mask_real = batch['attention_mask_real'].to(device)
            targets = batch['targets'].float().to(device)

            # Only the real text is passed. This code relies on the model then
            # returning just the sentiment logits (TODO confirm in CausalSent.forward).
            sentiment_output_real = model(input_ids_real, None, None, attention_mask_real, None, None)
            preds = torch.sigmoid(sentiment_output_real).reshape(-1).cpu().numpy()
            preds = (preds > 0.5).astype(int)

            val_targets.extend(targets.cpu().numpy())
            val_predictions.extend(preds)

    # Compute validation metrics
    val_acc = accuracy_score(val_targets, val_predictions)
    val_f1 = f1_score(val_targets, val_predictions)
    wandb.log({"Train Epoch Loss": mean_train_loss, "Val Accuracy": val_acc, "Val F1": val_f1, "Epoch": epoch + 1})
    print(f"Epoch {epoch + 1}/{epochs} Validation Accuracy: {val_acc:.4f}, F1: {val_f1:.4f}, "
          f"Mean Train Loss: {mean_train_loss:.4f}")

# Close the wandb run so a re-run of this cell starts a fresh run cleanly.
wandb.finish()
print("Training complete! :)")

TOP_DIR: /Users/danielfrees/Desktop/causalign


0,1
Epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
Val Accuracy,▄▇▇▁█▇▆▅▃▃▃▅▆▇▇▇▇▇▇▆
Val F1,▃▇█▁██▅▄▃▂▂▄▆▇▇▇▇▇▇▆

0,1
Epoch,20.0
Val Accuracy,0.71
Val F1,0.72897


Using device: mps
Epoch 1/20, Batch 5/16, Loss: 0.7093, Accuracy: 0.4750, F1: 0.2500, Tau_Hat_love: 0.0188
Epoch 1/20, Batch 10/16, Loss: 0.7459, Accuracy: 0.5000, F1: 0.5322, Tau_Hat_love: -0.0440
Epoch 1/20, Batch 15/16, Loss: 0.6537, Accuracy: 0.5229, F1: 0.5033, Tau_Hat_love: -0.0601
Epoch 1/20 Validation Accuracy: 0.7100, F1: 0.7140
Epoch 2/20, Batch 5/16, Loss: 0.6207, Accuracy: 0.7438, F1: 0.8019, Tau_Hat_love: 0.0050
Epoch 2/20, Batch 10/16, Loss: 0.4933, Accuracy: 0.7688, F1: 0.8021, Tau_Hat_love: 0.0232
Epoch 2/20, Batch 15/16, Loss: 0.4853, Accuracy: 0.7979, F1: 0.8194, Tau_Hat_love: 0.0748
Epoch 2/20 Validation Accuracy: 0.7820, F1: 0.7850
Epoch 3/20, Batch 5/16, Loss: 0.1938, Accuracy: 0.8750, F1: 0.8780, Tau_Hat_love: 0.0011
Epoch 3/20, Batch 10/16, Loss: 0.1509, Accuracy: 0.8656, F1: 0.8746, Tau_Hat_love: 0.0755
Epoch 3/20, Batch 15/16, Loss: 0.1596, Accuracy: 0.9000, F1: 0.9020, Tau_Hat_love: 0.0064
Epoch 3/20 Validation Accuracy: 0.7600, F1: 0.7414
Epoch 4/20, Batch 5/