In [None]:
!pip install wilds
!pip install transformers
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.11.0+cu113.html
!pip install pillow==7.2.0

Collecting wilds
  Downloading wilds-2.0.0-py3-none-any.whl (126 kB)
[?25l[K     |██▋                             | 10 kB 27.3 MB/s eta 0:00:01[K     |█████▏                          | 20 kB 29.8 MB/s eta 0:00:01[K     |███████▉                        | 30 kB 33.3 MB/s eta 0:00:01[K     |██████████▍                     | 40 kB 15.4 MB/s eta 0:00:01[K     |█████████████                   | 51 kB 13.0 MB/s eta 0:00:01[K     |███████████████▋                | 61 kB 15.1 MB/s eta 0:00:01[K     |██████████████████▏             | 71 kB 14.4 MB/s eta 0:00:01[K     |████████████████████▊           | 81 kB 14.2 MB/s eta 0:00:01[K     |███████████████████████▍        | 92 kB 15.6 MB/s eta 0:00:01[K     |██████████████████████████      | 102 kB 14.1 MB/s eta 0:00:01[K     |████████████████████████████▋   | 112 kB 14.1 MB/s eta 0:00:01[K     |███████████████████████████████▏| 122 kB 14.1 MB/s eta 0:00:01[K     |████████████████████████████████| 126 kB 14.1 MB/s 
Collecti

Collecting transformers
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 14.4 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 6.6 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 73.4 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[K     |████████████████████████████████| 880 kB 75.7 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 79.1 MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacr

In [11]:
import dill
import os
import math

import numpy as np
from sklearn.metrics import precision_recall_fscore_support
from sklearn import preprocessing
import torch
from torch import nn
from transformers import DistilBertModel, DistilBertTokenizer
from torch.optim import AdamW
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
from tqdm import tqdm

In [12]:
from wilds import get_dataset

# Load the WILDS "civilcomments" dataset object that every later cell reads from.
# NOTE(review): download=True presumably fetches the data when it is not already
# available locally — confirm against the WILDS documentation.
dataset = get_dataset(dataset="civilcomments", download=True)

In [None]:
# Display the metadata column names (a private attribute of the dataset object).
# convert_meta2idx() later looks group names up in this list to get column positions.
dataset._metadata_fields

['male',
 'female',
 'LGBTQ',
 'christian',
 'muslim',
 'other_religions',
 'black',
 'white',
 'identity_any',
 'severe_toxicity',
 'obscene',
 'threat',
 'insult',
 'identity_attack',
 'sexual_explicit',
 'y',
 'from_source_domain']

In [13]:
test = dataset.get_subset("test")

# Walk the test split once and unpack (text, label, metadata) from the cached
# examples, instead of iterating the whole split three separate times.
test_examples = list(test)
testX = [example[0] for example in test_examples]
testY = torch.stack([example[1] for example in test_examples])
testMeta = torch.stack([example[2] for example in test_examples])

pretrained_path = 'distilbert-base-uncased'


# The variable is called `roberta_tokenizer` for historical reasons; it is the
# DistilBERT tokenizer that matches the `pretrained_path` checkpoint.
roberta_tokenizer = DistilBertTokenizer.from_pretrained(pretrained_path)
# Every test text is truncated/padded to exactly 300 tokens and returned as
# PyTorch tensors together with its attention mask.
encoded_testX = roberta_tokenizer(testX, truncation=True, max_length = 300, padding='max_length', return_tensors = 'pt', return_attention_mask = True)

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/483 [00:00<?, ?B/s]

In [14]:
import random

from wilds.common.data_loaders import get_train_loader
from wilds.common.grouper import CombinatorialGrouper

# Groups are defined by the combination of the 'black' metadata field and the label 'y'.
target_groups = ['black', 'y']
# Number of distinct groups; assumes every target field is binary — TODO confirm.
n_groups = 2 ** len(target_groups)
batch_size = 16

grouper = CombinatorialGrouper(dataset, target_groups)

train = dataset.get_subset("train")
# NOTE(review): with the "group" loader and n_groups_per_batch=1, each batch is
# presumably drawn from a single group — confirm against the WILDS docs. The
# training loop relies on this when it reads the group from the first example only.
train_loader = get_train_loader(
    "group", train, grouper=grouper, n_groups_per_batch=1, batch_size=batch_size
)

# Bundle the pre-tokenized test inputs with their labels and metadata so one
# batch yields (input_ids, attention_mask, label, metadata).
test_dataset = TensorDataset(encoded_testX['input_ids'],encoded_testX['attention_mask'], testY, testMeta)
# No sampler or shuffle flag is passed for the test loader.
test_dataloader = DataLoader(
            test_dataset,
            batch_size = batch_size
        )

In [15]:
from torch import nn

# DistilBertForSequenceClassification could also be used.
# Dropout rate as used in the paper.
class CustomRoberta(nn.Module):
    """DistilBERT encoder followed by a small two-layer classification head.

    The class and attribute names say "roberta" for historical reasons; the
    encoder that is actually loaded is ``DistilBertModel`` from
    ``pretrained_path``. The attribute names are kept unchanged because they
    determine the keys of ``state_dict()`` and therefore of saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        # forward() only consumes the final layer's hidden states, so the
        # encoder is not asked to also return every intermediate layer.
        self.roberta = DistilBertModel.from_pretrained(pretrained_path)
        self.hidden_layer = nn.Linear(768, 768)
        self.dropout = nn.Dropout(0.1)
        self.activation = nn.ReLU()  # or tanh()
        self.output_layer = nn.Linear(768, 2)

    def forward(self, d_ids, d_mask):
        """Return the two-class logits for a batch of tokenized texts.

        Args:
            d_ids: token ids passed to the encoder as ``input_ids``
                (assumed shape (batch, seq_len) — TODO confirm).
            d_mask: attention mask passed to the encoder as ``attention_mask``.

        Returns:
            The output of the final linear layer, whose last dimension is 2.
        """
        encoder_out = self.roberta(input_ids=d_ids, attention_mask=d_mask)
        # Use the final-layer hidden state at sequence position 0 (the first
        # token) as the representation of the whole sequence.
        sequence_output = encoder_out.last_hidden_state[:, 0, :]
        sequence_output = self.dropout(sequence_output)
        hidden_output = self.hidden_layer(sequence_output)
        # Dropout is applied to the hidden layer's output before the activation.
        dropped = self.dropout(hidden_output)
        activated = self.activation(dropped)
        return self.output_layer(activated)

# Build the classifier and move its parameters to the GPU.
# As the last expression of the cell, model.cuda() also displays the module structure.
model = CustomRoberta()
model.cuda()

Downloading:   0%|          | 0.00/256M [00:00<?, ?B/s]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


CustomRoberta(
  (roberta): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            (lin1): Linear(in_fe

In [16]:
def get_loss_value(model, loader, device, cal_f1=True, benchmark_val=False):
    """
    Evaluation loop for the multi-class classification problem.

    Puts ``model`` in eval mode and runs it over ``loader`` with gradients
    disabled. The model is left in eval mode afterwards.

    Args:
        model: callable as ``model(ids, masks)`` returning per-class scores
            with the class dimension at ``dim=1``.
        loader: iterable of ``(ids, masks, labels, meta)`` batches.
        device: device the ids, masks and labels are moved to.
        cal_f1: when not in benchmark mode, also return macro-averaged
            precision, recall and F1.
        benchmark_val: return raw predictions for an external evaluator
            instead of aggregate metrics.

    Returns:
        If ``benchmark_val`` is True: ``(pred_labels, true_labels, metadata)``
        where the first two are float tensors and ``metadata`` is the batches'
        ``meta`` concatenated along dim 0 (an empty tensor for an empty loader).
        Otherwise ``(loss, accuracy)`` — mean cross-entropy and fraction of
        correct predictions — extended to
        ``(loss, accuracy, precision, recall, f1)`` when ``cal_f1`` is True.
        For an empty loader every metric is ``nan``.
    """
    model.eval()
    total_loss = 0.0
    pred_labels = []
    true_labels = []
    meta_info = []

    with torch.no_grad():
        for ids, masks, labels, meta in loader:
            ids = ids.to(device)
            masks = masks.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model(ids, masks)
            preds = torch.argmax(outputs, dim=1)

            # The loss is only needed for the aggregate-metrics path; the
            # benchmark path stays exactly as cheap as before.
            if not benchmark_val:
                total_loss += torch.nn.functional.cross_entropy(
                    outputs, labels, reduction='sum'
                ).item()

            pred_labels += preds.cpu().tolist()
            true_labels += labels.cpu().tolist()
            meta_info.append(meta)

    if benchmark_val:
        # torch.cat refuses an empty list, so an empty loader gets an empty tensor.
        metadata = torch.cat(meta_info, dim=0) if meta_info else torch.empty(0)
        return torch.FloatTensor(pred_labels), torch.FloatTensor(true_labels), metadata

    n_samples = len(true_labels)
    if n_samples == 0:
        nan = float('nan')
        return (nan, nan, nan, nan, nan) if cal_f1 else (nan, nan)

    loss = total_loss / n_samples
    n_correct = sum(1 for p, t in zip(pred_labels, true_labels) if p == t)
    accuracy = n_correct / n_samples
    if not cal_f1:
        return loss, accuracy

    prec, recall, f1, _ = precision_recall_fscore_support(
        true_labels, pred_labels, average='macro'
    )
    return loss, accuracy, prec, recall, f1

In [17]:
# Seed every RNG the run may draw from, not only torch: other libraries
# (for example the batch sampler) may use Python's or NumPy's generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
epochs = 5

optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
# One DRO weight per group, all equal at the start; they are renormalized on
# the first update.
group_weights = [1] * n_groups

In [18]:
def convert_meta2idx(dataset, target_groups):
  """Map each group name to its position in ``dataset._metadata_fields``.

  Returns a list of integer positions in the same order as ``target_groups``.
  A name that is not in the metadata fields raises ``ValueError`` (from
  ``list.index``).
  """
  return [dataset._metadata_fields.index(name) for name in target_groups]

def cal_weight_idx(meta, meta_idx):
  """Return the group index encoded by the first row of ``meta``.

  The values found at the columns listed in ``meta_idx`` are read as binary
  digits, most significant first, and combined into one integer.

  Only row 0 of ``meta`` is inspected, so the result describes the whole batch
  only if every example in it belongs to the same group.

  Args:
    meta: 2-D tensor indexed as ``meta[row, column]``.
    meta_idx: column positions to read, in most-significant-first order.

  Returns:
    The integer group index; 0 when ``meta_idx`` is empty.

  Raises:
    ValueError: if a selected value is neither 0 nor 1.
  """
  group_idx = 0
  for idx in meta_idx:
    value = meta[0, idx].item()
    # Accept both integer and float metadata (1 and 1.0), but nothing that
    # is not a clean binary flag.
    if value not in (0, 1):
      raise ValueError(f"metadata column {idx} holds non-binary value {value!r}")
    group_idx = (group_idx << 1) | int(value)
  return group_idx

def update_dro_group_weights(weights, group_idx, loss, eta_q = 0.01):
  """Apply one exponentiated update to a group's weight and renormalize.

  The weight at ``group_idx`` is multiplied by ``exp(eta_q * loss)`` and all
  weights are then divided by their sum.

  Args:
    weights: current group weights; this list is not modified.
    group_idx: position of the group whose loss was just observed.
    loss: object exposing ``.item()`` that returns the scalar loss value.
    eta_q: step size of the weight update.

  Returns:
    A new list of weights that sums to 1.
  """
  # Work on a copy so the caller's list is left untouched.
  updated = list(weights)
  updated[group_idx] = updated[group_idx] * math.exp(eta_q * loss.item())
  # Compute the normalizer once instead of once per element.
  total = sum(updated)
  return [weight / total for weight in updated]

In [None]:
## Defining step sizes in DRO
# eta_q is the step size of the exponentiated group-weight update.
eta_q = 0.01

# Checkpoints are written under RESULT_FOLDER/<pretrained_path>/.
RESULT_FOLDER = "./drive/MyDrive/CS699/homework #3/DRO_full"
os.makedirs(f"{RESULT_FOLDER}/{pretrained_path}/", exist_ok=True)

# Column positions of the target group fields inside the metadata tensor.
meta2idx = convert_meta2idx(dataset, target_groups)
device = torch.device("cuda")

# One progress-bar tick per training batch across all epochs.
with tqdm(total=epochs*len(train_loader)) as pbar:
  for epoch in range(epochs):
    # get_loss_value() below switches the model to eval mode, so training mode
    # is re-enabled at the start of every epoch.
    model.train()

    for i, batch in enumerate(train_loader):
      # batch[0] is the raw text, batch[1] the labels, batch[-1] the metadata.
      batch_text = batch[0]
      d_labels = batch[1].to(device)

      # Training text is tokenized on the fly with the same settings as the test set.
      tokenized_text = roberta_tokenizer(batch_text, truncation=True, max_length = 300, padding='max_length', return_tensors = 'pt', return_attention_mask = True)
      d_input_id = tokenized_text['input_ids'].to(device)
      d_att_mask = tokenized_text['attention_mask'].to(device)

      outputs = model(d_input_id,d_att_mask)
      loss = torch.nn.functional.cross_entropy(outputs, d_labels)

      # cal_weight_idx reads only the first example's metadata, so this relies on
      # every example in the batch belonging to the same group (the train loader
      # is built with n_groups_per_batch=1).
      weight_idx = cal_weight_idx(batch[-1], meta2idx)
      group_weights = update_dro_group_weights(group_weights, weight_idx, loss, eta_q = eta_q)
      # The group weight is applied by rescaling the learning rate of the first
      # optimizer param group; 1e-5 repeats the base lr given to AdamW.
      optimizer.param_groups[0]['lr'] = 1e-5 * group_weights[weight_idx]

      model.zero_grad()
      loss.backward()
      optimizer.step()
      pbar.update(1)

    # After each epoch: evaluate on the test loader and print the dataset's metrics.
    pred, label, meta = get_loss_value(model, test_dataloader, device=device, benchmark_val=True)
    print(dataset.eval(pred, label, meta))

    # Save the weights as <epoch number, starting at 1>_model.pt, pickled with dill.
    torch.save(
        model.state_dict(), f'{RESULT_FOLDER}/{pretrained_path}/{epoch + 1}_model.pt',
        pickle_module=dill
    )

  7%|▋         | 5570/84070 [15:30<3:40:28,  5.93it/s]

In [None]:
import glob
RESULT_FOLDER = "./drive/MyDrive/CS699/homework #3/DRO_full"
# glob returns matches in arbitrary order; sort them (lexicographically) so the
# checkpoints are evaluated in the same, predictable order on every run.
model_path  = sorted(glob.glob(f"{RESULT_FOLDER}/{pretrained_path}/*"))

device = torch.device("cuda")

def load_ckp(checkpoint_fpath, model):
    """Load the state dict stored at ``checkpoint_fpath`` into ``model``.

    The model is updated in place and the same object is returned.
    """
    # NOTE(review): torch.load deserializes the file's contents; only point this
    # at checkpoints you produced yourself.
    model.load_state_dict(torch.load(checkpoint_fpath))
    return model

# Evaluate every checkpoint found on disk against the test loader and print
# the dataset's evaluation result for each one.
for ckp_path in model_path:
  print(ckp_path)
  # Overwrites the weights of the in-memory model with this checkpoint's.
  model = load_ckp(ckp_path, model)
  # benchmark_val=True returns raw predictions, labels and metadata for dataset.eval.
  pred, label, meta = get_loss_value(model, test_dataloader, device=device, benchmark_val=True)
  print(dataset.eval(pred, label, meta))
  print('--------------------------')


./drive/MyDrive/CS699/homework #3/DRO/bhadresh-savani/distilbert-base-uncased-emotion/1_model.pt
predictions made
({'acc_avg': 0.8242588639259338, 'acc_y:0_male:1': 0.7387529015541077, 'count_y:0_male:1': 12092.0, 'acc_y:1_male:1': 0.8919655084609985, 'count_y:1_male:1': 2203.0, 'acc_y:0_female:1': 0.7619013786315918, 'count_y:0_female:1': 14179.0, 'acc_y:1_female:1': 0.890748918056488, 'count_y:1_female:1': 2270.0, 'acc_y:0_LGBTQ:1': 0.48068535327911377, 'count_y:0_LGBTQ:1': 3210.0, 'acc_y:1_LGBTQ:1': 0.9226973652839661, 'count_y:1_LGBTQ:1': 1216.0, 'acc_y:0_christian:1': 0.8432360887527466, 'count_y:0_christian:1': 12101.0, 'acc_y:1_christian:1': 0.8611111044883728, 'count_y:1_christian:1': 1260.0, 'acc_y:0_muslim:1': 0.4911297857761383, 'count_y:0_muslim:1': 5355.0, 'acc_y:1_muslim:1': 0.931161642074585, 'count_y:1_muslim:1': 1627.0, 'acc_y:0_other_religions:1': 0.7258388996124268, 'count_y:0_other_religions:1': 2980.0, 'acc_y:1_other_religions:1': 0.8788461685180664, 'count_y:1_oth