In [None]:
### Deep Feature Reweighting ###

In [None]:
# Mount Google Drive (Colab only) so the project folder stored there is reachable.
from google.colab import drive
drive.mount('/content/drive')
# Root folder of this experiment on the mounted drive; the data, checkpoint and
# output paths used in later cells are all built by appending to it.
# NOTE(review): the path is relative ('../content/...'), so it only resolves to
# '/content/drive/...' when the working directory is a direct child of '/'
# (e.g. '/content') — confirm for the target runtime.
drive_PATH = '../content/drive/MyDrive/Colab Notebooks/dis.experiments.4'
import sys
# Make the project folder importable; presumably this is where the `utils` and
# `model` packages imported in later cells live — verify.
sys.path.append(drive_PATH)
# Alternative root kept for reference:
# drive_PATH = ''

In [None]:
import torch
import torch.nn as nn

import utils.NLIdataset as nli_ds
import utils.transforms as tr

import tqdm
import math
import numpy as np
import pandas as pd

In [None]:
# Compute device: CUDA when a GPU is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')
DEVICE

In [None]:
### MNLI Dataset ###
!pip install jsonlines
import jsonlines # jsonl imports

train_PATH = drive_PATH + '/data/multinli_1.0/multinli_1.0_train.jsonl'
dev_matched_PATH = drive_PATH + '/data/multinli_1.0/multinli_1.0_dev_matched.jsonl'
dev_mismatched_PATH = drive_PATH + '/data/multinli_1.0/multinli_1.0_dev_mismatched.jsonl'
hans_PATH = drive_PATH + '/data/hans/heuristics_evaluation_set.jsonl'

# Train Data
train_DATA = []
train_s1 = []
train_s2 = []
train_text = []
train_label = []
# Mathced Dev Data
dev_matched_DATA = []
dev_matched_s1 = []
dev_matched_s2 = []
dev_matched_text = []
dev_matched_label = []
# Mismatched Dev Data
dev_mismatched_DATA = []
dev_mismatched_s1 = []
dev_mismatched_s2 = []
dev_mismatched_text = []
dev_mismatched_label = []
# Hans Data
hans_DATA = []
hans_s1 = []
hans_s2 = []
hans_text = []
hans_label = []

with jsonlines.open(train_PATH) as f:
    for line in f.iter():
        train_DATA.append(line)
        train_s1.append(line['sentence1'])
        train_s2.append(line['sentence2'])
        train_text.append( line['sentence1'] + ' ' + line['sentence2'] )
        train_label.append(line['gold_label'])
with jsonlines.open(dev_matched_PATH) as f:
    for line in f.iter():
        dev_matched_DATA.append(line)
        dev_matched_s1.append(line['sentence1'])
        dev_matched_s2.append(line['sentence2'])
        dev_matched_text.append( line['sentence1'] + ' ' + line['sentence2'] )
        dev_matched_label.append(line['gold_label'])
with jsonlines.open(dev_mismatched_PATH) as f:
    for line in f.iter():
        dev_mismatched_DATA.append(line)
        dev_mismatched_s1.append(line['sentence1'])
        dev_mismatched_s2.append(line['sentence2'])
        dev_mismatched_text.append( line['sentence1'] + ' ' + line['sentence2'] )
        dev_mismatched_label.append(line['gold_label'])
with jsonlines.open(hans_PATH) as f:
    for line in f.iter():
        hans_DATA.append(line)
        hans_s1.append(line['sentence1'])
        hans_s2.append(line['sentence2'])
        hans_text.append( line['sentence1'] + ' ' + line['sentence2'] )
        hans_label.append(line['gold_label'])

In [None]:
### Cleaning Datasets
# Converts the raw label / sentence lists into NumPy arrays and collapses the
# labels into binary integers: 1 = 'entailment', 0 = 'non-entailment'.

# Train
# '<U14' is wide enough for 'non-entailment' (14 characters), so the in-place
# relabelling below is not truncated.
train_label = np.array(train_label, dtype='<U14')
train_s1 = np.array(train_s1)
train_s2 = np.array(train_s2)
# Merge 'neutral' and 'contradiction' into a single 'non-entailment' class.
train_label[(train_label == 'neutral') | (train_label == 'contradiction')] = 'non-entailment'
# Comparing against a one-element list broadcasts to an element-wise comparison.
# The array is still a string array here, so 1 / 0 are stored as '1' / '0'.
train_label[train_label == ['entailment']] = 1
train_label[train_label == ['non-entailment']] = 0
# Parse the '1' / '0' strings into integers.
# NOTE(review): unlike the dev splits, rows labelled '-' are not filtered out
# here; such a row would make this conversion raise — confirm the train file has none.
train_label = np.array(train_label, dtype='int')

# Dev Matched
dev_matched_label = np.array(dev_matched_label, dtype='<U14')
# Boolean mask of rows to keep: everything whose gold label is not '-'
# (presumably examples without a gold label — verify against the dataset docs).
dev_matched_filter = dev_matched_label != '-'
# The same mask is applied to both sentences and the labels so rows stay aligned.
dev_matched_s1 = np.array(dev_matched_s1)[dev_matched_filter]
dev_matched_s2 = np.array(dev_matched_s2)[dev_matched_filter]
dev_matched_label = dev_matched_label[dev_matched_filter]
dev_matched_label[(dev_matched_label == 'neutral') | (dev_matched_label == 'contradiction')] = 'non-entailment'
dev_matched_label[dev_matched_label == ['entailment']] = 1
dev_matched_label[dev_matched_label == ['non-entailment']] = 0
dev_matched_label = np.array(dev_matched_label, dtype='int')

# Dev Mismatched (same steps as the matched dev split)
dev_mismatched_label = np.array(dev_mismatched_label, dtype='<U14')
dev_mismatched_filter = dev_mismatched_label != '-'
dev_mismatched_s1 = np.array(dev_mismatched_s1)[dev_mismatched_filter]
dev_mismatched_s2 = np.array(dev_mismatched_s2)[dev_mismatched_filter]
dev_mismatched_label = dev_mismatched_label[dev_mismatched_filter]
dev_mismatched_label[(dev_mismatched_label == 'neutral') | (dev_mismatched_label == 'contradiction')] = 'non-entailment'
dev_mismatched_label[dev_mismatched_label == ['entailment']] = 1
dev_mismatched_label[dev_mismatched_label == ['non-entailment']] = 0
dev_mismatched_label = np.array(dev_mismatched_label, dtype='int')

# HANS
# No three-way merge and no '-' filter here: only 'entailment' and
# 'non-entailment' are mapped (presumably the only labels in this file — verify).
hans_label = np.array(hans_label)
hans_s1 = np.array(hans_s1)
hans_s2 = np.array(hans_s2)
hans_label[hans_label == ['entailment']] = 1
hans_label[hans_label == ['non-entailment']] = 0
hans_label = np.array(hans_label, dtype='int')

# Distinct label values that remain in each split after cleaning.
train_labels = np.unique(train_label)
dev_matched_labels = np.unique(dev_matched_label)
dev_mismatched_labels = np.unique(dev_mismatched_label)
hans_labels = np.unique(np.array(hans_label))

# Per-split label frequencies, stacked under one key per split. The result is
# only assigned, not displayed, because it is not the cell's last expression.
value_counts = pd.concat({'train_label' : pd.DataFrame(train_label).value_counts(),
                        'dev_matched_label' : pd.DataFrame(dev_matched_label).value_counts(),
                        'dev_mismatched_label' : pd.DataFrame(dev_mismatched_label).value_counts(),
                        'hans_label' : pd.DataFrame(hans_label).value_counts()})

In [None]:
### Balancing Act
def balanced_idx(label_dataset):
    """Return shuffled row indices that select equally many label-1 and label-0 rows.

    The larger class is randomly down-sampled (without replacement) to the size
    of the smaller one. Previously only class 0 could be down-sampled, so the
    function raised ``ValueError`` whenever class 1 was the larger class.

    Args:
        label_dataset: 1-D array-like of integer labels; only the values 1 and 0
            are considered, rows with any other label are never selected.

    Returns:
        1-D integer array of indices into ``label_dataset``, in random order,
        containing ``min(#ones, #zeros)`` indices of each class.

    Note:
        Uses the global NumPy RNG (``np.random``); seed it for reproducible output.
        When class 0 is at least as large as class 1 the sequence of RNG calls is
        the same as before, so seeded results for that case are unchanged.
    """
    label_dataset = np.asarray(label_dataset)
    idx1 = np.flatnonzero(label_dataset == 1)
    idx0 = np.flatnonzero(label_dataset == 0)
    if len(idx0) >= len(idx1):
        # Class 0 is the majority (or tied): keep a random subset of it.
        keep = np.random.choice(idx0.shape[0], len(idx1), replace=False)
        idx0 = idx0[keep]
    else:
        # Class 1 is the majority: down-sample it instead of failing.
        keep = np.random.choice(idx1.shape[0], len(idx0), replace=False)
        idx1 = idx1[keep]
    idx = np.concatenate((idx1, idx0))
    np.random.shuffle(idx) # random shuffle so the classes are interleaved
    return idx

# Each split is re-indexed with one shared index array so that sentence 1,
# sentence 2 and the label of every kept row stay aligned.

# Balancing Train
train_balanced_idx = balanced_idx(train_label)
train_s1, train_s2, train_label = (
    column[train_balanced_idx] for column in (train_s1, train_s2, train_label)
)

# Balancing Dev Matched
dev_matched_balanced_idx = balanced_idx(dev_matched_label)
dev_matched_s1, dev_matched_s2, dev_matched_label = (
    column[dev_matched_balanced_idx]
    for column in (dev_matched_s1, dev_matched_s2, dev_matched_label)
)

# Balancing Dev Mismatched
dev_mismatched_balanced_idx = balanced_idx(dev_mismatched_label)
dev_mismatched_s1, dev_mismatched_s2, dev_mismatched_label = (
    column[dev_mismatched_balanced_idx]
    for column in (dev_mismatched_s1, dev_mismatched_s2, dev_mismatched_label)
)

# Balancing HANS (already balanced, so this effectively only shuffles the rows)
hans_balanced_idx = balanced_idx(hans_label)
hans_s1, hans_s2, hans_label = (
    column[hans_balanced_idx] for column in (hans_s1, hans_s2, hans_label)
)

In [None]:
### Preprocessing ###
# Build the text pipeline from the project's `tr` helpers. The vocabulary is
# constructed from the merged training text wrapped in an NLIdataset_merge.
# NOTE(review): train_text was built at load time from every training row and is
# never re-indexed, whereas train_label has since been class-balanced and
# shuffled — the two no longer have the same length or row order. Confirm that
# NLIdataset_merge / construct_vocab_transform only consume the text.
vocab_train_iter = nli_ds.NLIdataset_merge(train_text , np.array(train_label, dtype='str'))
token_transform = tr.construct_token_transform()
vocab_transform = tr.construct_vocab_transform(vocab_train_iter)
tensor_transform = tr.construct_tensor_transform()
# Composite transform applied to every sentence in collate_fn; presumably
# tokenize -> vocabulary lookup -> tensor, judging by the argument order — verify in utils.transforms.
text_transform = tr.construct_text_transform(token_transform , vocab_transform, tensor_transform)
# Vocabulary size, passed to the encoder when the model is built.
VOCAB_SIZE = len(vocab_transform)
VOCAB_SIZE

In [None]:
from model.embedding import TokenEmbedding, PositionalEncoding
from model.classifier import NonLinearClassifier
from model.encoder import Transformer_Encoder


### Natural Language Inference Model
class NLInference(nn.Module):
    """Sentence-pair classifier: one shared encoder followed by a non-linear classifier.

    Both sentences go through the same ``Transformer_Encoder``; the two encodings
    are concatenated along dimension 1 and passed to ``NonLinearClassifier``.

    All hyper-parameters are keyword arguments whose defaults reproduce the
    previously hard-coded configuration, so ``NLInference()`` builds the same model.

    Args:
        vocab_size: Vocabulary size handed to the encoder. ``None`` (default)
            falls back to the notebook-level ``VOCAB_SIZE``.
        dmodel: Model width shared by encoder and classifier input.
        num_enc_layers: Number of encoder layers.
        nhead: Number of attention heads in the encoder.
        fc_dim: Width of the classifier's fully connected layers.
        n_classes: Number of output classes.
    """

    def __init__(self, vocab_size=None, dmodel=256, num_enc_layers=2, nhead=4,
                 fc_dim=512, n_classes=2):
        super().__init__()
        if vocab_size is None:
            # Backward-compatible default: use the vocabulary built by the notebook.
            vocab_size = VOCAB_SIZE

        # Configuration
        self.dmodel = dmodel                    # All
        self.num_enc_layers = num_enc_layers    # Encoder
        self.nhead = nhead                      # Encoder: For Transformer
        self.dclassifier = 2 * self.dmodel      # Classifier: two sentence encodings side by side
        self.fc_dim = fc_dim                    # Classifier: Dimension of the fully connected layers
        self.n_classes = n_classes              # Classifier: Number of classes for classification

        # Encoder (shared by both sentences)
        self.encoder = Transformer_Encoder(self.dmodel, self.nhead, self.num_enc_layers, vocab_size)
        # Classifier
        self.classifier = NonLinearClassifier(self.dclassifier, self.fc_dim, self.n_classes)

    def forward(self, s1, s2):
        """Classify a batch of sentence pairs.

        Args:
            s1: Batch of token-index tensors for the first sentences.
            s2: Batch of token-index tensors for the second sentences.

        Returns:
            The classifier output for the concatenated encodings.
        """
        # Encode each sentence with the shared encoder.
        # assumes the encoder returns one (batch, dmodel) vector per sentence, which is
        # what the 2 * dmodel classifier input implies — TODO confirm in model.encoder
        s1_encoded = self.encoder(s1)
        s2_encoded = self.encoder(s2)
        # Combine the two sentences by concatenating along dimension 1.
        combined_context = torch.cat((s1_encoded, s2_encoded), 1)
        # Pass the combined features through the classifier to get the output.
        output = self.classifier(combined_context)
        return output

In [None]:
# Instantiate the sentence-pair classifier with its default configuration.
# This global `model` is the object that train(), evaluate(), the checkpoint
# loading cell and the activation hooks all operate on.
model = NLInference()

In [None]:
### TRAINING LOOP
import time
def train(dataloader):
    """Run one training pass of the global ``model`` over ``dataloader``.

    Relies on notebook globals: ``model``, ``optimizer``, ``criterion``,
    ``DEVICE``, ``epoch`` (only for the progress line) and ``train_losses``.

    Args:
        dataloader: Iterable yielding ``(s1, s2, label)`` tensor batches.

    Side effects:
        * Moves ``model`` to ``DEVICE`` and puts it in training mode.
        * Updates the model parameters through ``optimizer``.
        * Appends one value per batch to ``train_losses``. Despite the name, the
          value is the running *accuracy* since the last progress line, not a loss.
        * Prints a progress line every ``log_interval`` batches.
    """
    # Follow DEVICE instead of hard-coding CUDA, so the notebook also runs on a
    # CPU-only runtime (model.cuda() raises there).
    model.to(DEVICE)
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 100

    for idx, (s1, s2, label) in enumerate(dataloader):
        s1 = s1.to(DEVICE)
        s2 = s2.to(DEVICE)
        label = label.to(DEVICE)

        optimizer.zero_grad()
        predicted_label = model(s1, s2)

        loss = criterion(predicted_label, label)
        loss.backward()

        # TODO: investigate gradient clipping, e.g.
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

        optimizer.step()

        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)

        # Running accuracy over the batches seen since the counters were last reset.
        train_acc = total_acc / total_count

        if idx % log_interval == 0 and idx > 0:
            print(
                "| epoch {: d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(epoch, idx, len(dataloader), train_acc)
            )
            # Start a fresh accuracy window after each progress line.
            total_acc, total_count = 0, 0

        train_losses.append(train_acc)

In [None]:
### EVALUATION LOOP
def evaluate(dataloader):
    """Return the accuracy of the global ``model`` on ``dataloader``.

    Puts ``model`` in evaluation mode (it is left in that mode afterwards) and
    runs without gradient tracking. The loss that used to be computed for every
    batch was never used, so it is no longer calculated.

    Args:
        dataloader: Iterable yielding ``(s1, s2, label)`` tensor batches.

    Returns:
        Fraction of examples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no examples.
    """
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for s1, s2, label in dataloader:
            s1 = s1.to(DEVICE)
            s2 = s2.to(DEVICE)
            label = label.to(DEVICE)
            predicted_label = model(s1, s2)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

In [None]:
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader

# Hyperparameters
EPOCHS = 4  # number of passes over the training loader
LEARNING_RATE = 0.0001  # Adam learning rate
BATCH_SIZE = 16  # batch size used by every DataLoader below

# Loss and optimizer are globals read by train(). The optimizer is created over
# *all* model parameters and is reused by every training loop in the notebook.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)
# Learning-rate scheduling is currently disabled:
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

# Wrap the cleaned, balanced arrays as (sentence 1, sentence 2, label) datasets.
train_iter = nli_ds.NLIdataset(train_s1 , train_s2, train_label)
dev_matched_iter = nli_ds.NLIdataset(dev_matched_s1, dev_matched_s2 , dev_matched_label)
dev_mismatched_iter = nli_ds.NLIdataset(dev_mismatched_s1, dev_mismatched_s2 , dev_mismatched_label)
hans_iter = nli_ds.NLIdataset(hans_s1, hans_s2 , hans_label)

# Convert to map-style datasets so len() and random_split() can be used.
train_dataset = to_map_style_dataset(train_iter)
dev_matched_dataset = to_map_style_dataset(dev_matched_iter)
dev_mismatched_dataset = to_map_style_dataset(dev_mismatched_iter)
hans_dataset = to_map_style_dataset(hans_iter)

# MNLI train: 95% for training, the remainder for validation.
# random_split draws from the global torch RNG and no seed is set in the
# visible cells, so the split differs between runs.
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = random_split( train_dataset, [num_train, len(train_dataset) - num_train] )
# HANS: 75% for last-layer retraining, the remainder held out as the HANS test set.
num_train_hans = int(len(hans_dataset) * 0.75)
split_train_hans_ , split_test_hans_ = random_split( hans_dataset, [num_train_hans, len(hans_dataset) - num_train_hans] )

# The HANS retraining part is split again: 90% train, 10% validation.
# Careful: `num_train_hans_train_` starts as the integer split size and is then
# rebound to the resulting training subset; `num_train_hans_valid_` is the
# validation subset. Despite the `num_` prefix both end up holding datasets.
num_train_hans_train_ = int(len(split_train_hans_) * 0.90)
num_train_hans_train_ , num_train_hans_valid_ = random_split( split_train_hans_, [num_train_hans_train_, len(split_train_hans_) - num_train_hans_train_] )

def collate_fn( batch):
    """Collate raw ``(sentence 1, sentence 2, label)`` samples into padded tensors.

    Every sentence is converted with the global ``text_transform``. Both sentence
    batches are padded with ``tr.PAD_IDX`` to one common length: the longest
    sequence among all first and second sentences of the batch.

    Args:
        batch: Sequence of ``(s1, s2, label)`` samples; ``label`` must be
            convertible with ``int()``.

    Returns:
        ``(s1_batch, s2_batch, tgt_batch)`` moved to ``DEVICE``. The sentence
        batches come from ``pad_sequence`` with its default layout; ``tgt_batch``
        is an ``int64`` tensor with one entry per sample.
    """
    s1_tensors, s2_tensors, targets = [], [], []
    for s1_raw, s2_raw, tgt_raw in batch:
        # Convert each sentence to a tensor, first sentence before second.
        s1_tensors.append(text_transform(s1_raw))
        s2_tensors.append(text_transform(s2_raw))
        targets.append(int(tgt_raw))

    # Labels as an integer tensor.
    targets = torch.tensor(targets, dtype=torch.int64)

    # Longest sequence across *both* sentence lists.
    longest = max(len(seq) for seq in s1_tensors + s2_tensors)

    def stretch(seq):
        # Right-pad one sequence up to the common length.
        return nn.ConstantPad1d((0, longest - len(seq)), tr.PAD_IDX)(seq)

    # Stretching only the first element of each list is enough: pad_sequence
    # pads every other element up to the longest one it is given, so both
    # batches end up with the same sequence length.
    s1_tensors[0] = stretch(s1_tensors[0])
    s2_tensors[0] = stretch(s2_tensors[0])

    s1_padded = torch.nn.utils.rnn.pad_sequence(s1_tensors, padding_value=tr.PAD_IDX)
    s2_padded = torch.nn.utils.rnn.pad_sequence(s2_tensors, padding_value=tr.PAD_IDX)
    return s1_padded.to(DEVICE), s2_padded.to(DEVICE), targets.to(DEVICE)

# Every loader uses the same batching configuration: BATCH_SIZE samples,
# reshuffled on each pass, collated and padded by collate_fn.
_loader_options = dict(batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)

# MNLI loaders
train_dataloader = DataLoader(split_train_, **_loader_options)
valid_dataloader = DataLoader(split_valid_, **_loader_options)
dev_matched_dataloader = DataLoader(dev_matched_dataset, **_loader_options)
dev_mismatched_dataloader = DataLoader(dev_mismatched_dataset, **_loader_options)

# HANS loaders (held-out test part, then the retraining train / validation parts)
test_hans_dataloader = DataLoader(split_test_hans_, **_loader_options)
train_hans_dataloader = DataLoader(num_train_hans_train_, **_loader_options)
valid_hans_dataloader = DataLoader(num_train_hans_valid_, **_loader_options)


In [None]:
### INITIALIZATION
# Xavier-uniform initialisation for every parameter with more than one
# dimension (weight matrices); 1-D parameters keep their default init.
for p in model.parameters():
    if p.dim() > 1:
        torch.nn.init.xavier_uniform_(p)

### ACCOUNTING
# Checkpoint file names, one per epoch, in epoch order.
model_paths = []
# NOTE: despite their names, both lists store *accuracies*: val_losses gets the
# validation accuracy of each epoch, train_losses is filled by train() with a
# running training accuracy per batch.
val_losses = []
train_losses = []
import time
# Short run identifier built from fields 2-4 of the space-split time.ctime()
# string (normally day-hour-minute).
# NOTE(review): ctime pads single-digit days with an extra space, which shifts
# the selected fields — confirm this is acceptable for the file names.
model_id = '-'.join(time.ctime(time.time()).replace(':', ' ').split(' ')[2:5])

### TRAINING
# `epoch` is a global on purpose: train() reads it for its progress line.
for epoch in range(1, EPOCHS + 1):

    epoch_start_time = time.time()
    train(train_dataloader)
    epoch_end_time = time.time()
    elapsed_time = epoch_end_time - epoch_start_time

    # Last running training accuracy recorded by train() in this epoch.
    accu_train = train_losses[-1]
    accu_val = evaluate(valid_dataloader)
    val_losses.append(accu_val)

    accu_dev_matched = evaluate(dev_matched_dataloader)
    accu_dev_mismatched = evaluate(dev_mismatched_dataloader)
    accu_hans = evaluate(test_hans_dataloader)

    # register model path (all accuracies are encoded in the file name)
    model_paths.append(f'id-{model_id}-epoch-{epoch}-accu_train-{accu_train:.3f}-accu_val-{accu_val:.3f}-accu_dev_matched-{accu_dev_matched:.3f}-accu_dev_mismatched-{accu_dev_mismatched:.3f}-accu_hans-{accu_hans:.3f}.pt')
    # save model to path; the 'model_states' folder under drive_PATH must
    # already exist, torch.save does not create directories
    torch.save(model.state_dict(), drive_PATH+'/model_states/'+model_paths[-1])


    print("-" * 59)
    print("| end of epoch {:3d} | time: {:5.2f}s | valid accuracy {:8.3f} |".format( epoch, elapsed_time, accu_val))
    print("| dev matched accuracy {:8.3f} | dev mismatched accuracy {:8.3f} | hans accuracy {:8.3f} |".format( accu_dev_matched, accu_dev_mismatched, accu_hans))
    print("-" * 59)

In [None]:
# Load the model with best validation accuracy.
# `val_losses` stores one validation *accuracy* per epoch (it is filled with the
# return value of evaluate()), so the best epoch is the maximum. The previous
# np.argmin selected the checkpoint with the lowest accuracy.
best_model_index = np.argmax(val_losses)
model_state = model_paths[best_model_index]
# map_location keeps the load working when the checkpoint was written on a GPU
# but the current runtime only has a CPU.
model.load_state_dict(torch.load(drive_PATH+'/model_states/'+model_state, map_location=DEVICE))
val_losses

In [None]:
### FREEZING ALL LAYERS EXCEPT THE LAST
n_classes = 2
for p in model.parameters():
    # Heuristic: a parameter whose first dimension equals the number of classes
    # is treated as part of the output layer and stays trainable.
    # NOTE(review): any other parameter with a leading dimension of n_classes
    # would stay trainable too — confirm against the printed shapes below.
    # The p.dim() check makes 0-dim parameters fall through to the frozen branch
    # instead of raising on the size lookup.
    if p.dim() > 0 and p.size(0) == n_classes:
        print(p.size())
        print(p)
    else:
        p.requires_grad = False
        # Drop the gradient left over from the previous training phase. On
        # PyTorch versions where optimizer.zero_grad() zeroes gradients rather
        # than setting them to None, the optimizer would otherwise keep stepping
        # these parameters from its stored momentum, so they would not stay frozen.
        p.grad = None

In [None]:
### ACCOUNTING
# reinitialize losses and paths for the second training phase on HANS
# (as in the first phase, both *_losses lists actually hold accuracies)
model_paths = []
val_losses = []
train_losses = []

EPOCHS = 4

### TRAINING
# Same loop as the MNLI phase, but training and validating on the HANS
# retraining splits. It is meant to run after the freezing cell, so that only
# the parameters left trainable there are updated.
# NOTE(review): no new optimizer is created in this cell, so train() keeps using
# the global Adam optimizer from the first phase, including its accumulated
# state — confirm this is intended.
for epoch in range(1, EPOCHS + 1):

    epoch_start_time = time.time()
    train(train_hans_dataloader)
    epoch_end_time = time.time()
    elapsed_time = epoch_end_time - epoch_start_time

    # Last running training accuracy recorded by train() in this epoch.
    accu_train = train_losses[-1]
    accu_val = evaluate(valid_hans_dataloader)
    val_losses.append(accu_val)

    accu_dev_matched = evaluate(dev_matched_dataloader)
    accu_dev_mismatched = evaluate(dev_mismatched_dataloader)
    accu_hans = evaluate(test_hans_dataloader)

    # register model path ('DFRepoch' marks checkpoints from this second phase;
    # model_id is reused from the first phase)
    model_paths.append(f'id-{model_id}-DFRepoch-{epoch}-accu_train-{accu_train:.3f}-accu_val-{accu_val:.3f}-accu_dev_matched-{accu_dev_matched:.3f}-accu_dev_mismatched-{accu_dev_mismatched:.3f}-accu_hans-{accu_hans:.3f}.pt')
    # save model to path
    torch.save(model.state_dict(), drive_PATH+'/model_states/'+model_paths[-1])

    print("-" * 59)
    print("| end of epoch {:3d} | time: {:5.2f}s | valid accuracy {:8.3f} |".format( epoch, elapsed_time, accu_val))
    print("| dev matched accuracy {:8.3f} | dev mismatched accuracy {:8.3f} | hans accuracy {:8.3f} |".format( accu_dev_matched, accu_dev_mismatched, accu_hans))
    print("-" * 59)

In [None]:
import collections
from functools import partial

# Cache up to NUM_BATCHES batches from the validation loader so that the same
# inputs can be pushed through the hooked model afterwards.
NUM_BATCHES = 100
s1_dataset = []
s2_dataset = []
labels_dataset = []
for idx, batch in enumerate(valid_dataloader):
    s1, s2, label = batch
    if idx != NUM_BATCHES:
        s1_dataset.append(s1)
        s2_dataset.append(s2)
        labels_dataset.append(label)
    else:
        # Enough batches collected.
        break

# a dictionary that keeps saving the activations as they come:
# hooked module name -> list with one output tensor per forward call
activations = collections.defaultdict(list)
def save_activation(name, mod, inp, out):
    """Forward-hook callback that stores a module's output under ``name``.

    Intended to be registered as ``partial(save_activation, name)`` so that
    PyTorch supplies the remaining ``(module, input, output)`` arguments.

    Args:
        name: Key under which the output is stored in ``activations``.
        mod: The module that produced the output (unused).
        inp: The module's input (unused).
        out: The module's output tensor.
    """
    # Detach before copying to the CPU: the stored tensor must not keep the
    # autograd graph of every recorded forward pass alive.
    activations[name].append(out.detach().cpu())

# Attach an activation-recording forward hook to the encoder and to three
# sub-modules of the classifier.
# Caveat: a hook fires on EVERY forward call of its module. The encoder is called
# twice per model forward pass (once for s1, once for s2), so its entries in
# `activations` alternate between the two sentences. Either split them by
# even / odd position afterwards, or record the activations inside the model's
# forward() instead.
HOOKED_MODULES = (
    'encoder',
    'classifier.classifier.1',
    'classifier.classifier.4',
    'classifier.classifier.7',
)
for name, m in model.named_modules():
    if name in HOOKED_MODULES:
        # partial binds the module name as the hook's first argument
        m.register_forward_hook(partial(save_activation, name))

# forward pass through the collected batches; the registered hooks fill `activations`.
# - Iterate over what was actually collected: indexing with range(NUM_BATCHES)
#   raised IndexError when the loader yielded fewer than NUM_BATCHES batches.
# - eval() + no_grad(): activations are recorded in inference mode and no
#   autograd graph is built for these passes.
model.eval()
with torch.no_grad():
    for s1_batch, s2_batch in zip(s1_dataset, s2_dataset):
        out = model(s1_batch, s2_batch)

In [None]:
# For every hooked module: join all recorded outputs along the first dimension,
# convert to a NumPy array with the first two axes swapped, and wrap the result
# in a DataFrame.
keys = list(activations.keys())
activations_dict, activations_df_dict = {}, {}
for key in keys:
    stacked = torch.cat([a.detach() for a in activations[key]])
    as_array = np.array(stacked).transpose(1,0)
    activations_dict[key] = as_array
    activations_df_dict[key] = pd.DataFrame(as_array)
    print(as_array.shape)

In [None]:
# Stack the per-module frames into one frame (the dict keys become the outer
# index level) and write it to the results folder as CSV.
out_csv_PATH = f"{drive_PATH}/res/activations/DFR/test.csv"
activations_df = pd.concat(activations_df_dict, axis=0)
activations_df.to_csv(out_csv_PATH)