In [1]:
import torch
import pandas as pd
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
# pose sequence as a NLI premise and label as a hypothesis
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AdamW
from datasets import load_metric
import torch.nn.functional as F

metric = load_metric("accuracy")



HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1249.0, style=ProgressStyle(description…




In [2]:
entailment_data = pd.read_csv('./gigaword1000entailment.csv')


train_texts  = list(map(list, zip(entailment_data['document'], entailment_data['summary'])))
# map to mnli mapping, entailment = 2, contradiction = 0
print(train_texts[0])

train_labels = [2 if label == 0 or label == 2 else 0 for label in entailment_data['label']]
# train_labels = [2 if label == 0 else 0 for label in entailment_data['label']]

train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)

["australia 's current account deficit shrunk by a record #.## billion dollars -lrb- #.## billion us -rrb- in the june quarter due to soaring commodity prices , figures released monday showed .", 'australian current account deficit narrows sharply']


In [3]:
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')

train_encodings = tokenizer(train_texts, return_tensors='pt', truncation='only_first', padding =True)
val_encodings = tokenizer(val_texts, return_tensors='pt', truncation='only_first', padding =True)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=908.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898823.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355863.0, style=ProgressStyle(descript…




In [4]:
class EntailmentDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = EntailmentDataset(train_encodings, train_labels)
val_dataset = EntailmentDataset(val_encodings, val_labels)


In [5]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli')#.to(device)
if torch.cuda.device_count() > 1:
  print("Let's use", torch.cuda.device_count(), "GPUs!")
  # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
  model = torch.nn.DataParallel(model)

model.to(device)

# for param in model.bart.bart.parameters():
#     param.requires_grad = False
# Change batch size to 32 when we have access to many gpu's
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1)

optim = AdamW(model.parameters(), lr=1e-6)

model.eval()
print("validating")
with torch.no_grad():
    losses = []
    for batch in val_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        losses.append(outputs[0].detach().item())
        preds = torch.argmax(F.softmax(outputs[1]), dim=1)
        # need to check this
        metric.add_batch(predictions=preds, references=labels)

    acc = metric.compute()
    print("acc:", acc)

val_perf = []
for epoch in range(10):
    print("epoch:", epoch)
    model.train()
    for batch in train_loader:
        optim.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        loss.mean().backward()
        optim.step()

    model.eval()
    print("validating")
    with torch.no_grad():
        losses = []
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            losses.append(outputs[0].detach().item())
            preds = torch.argmax(F.softmax(outputs[1], dim=1), dim=1)
            # need to check this
            metric.add_batch(predictions=preds, references=labels)
    
        acc = metric.compute()
        print("acc:", acc)
        val_perf.append(((sum(losses)/len(losses)), acc))


        

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1629486723.0, style=ProgressStyle(descr…




Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartForSequenceClassification: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Let's use 2 GPUs!
validating


  import sys


acc: {'accuracy': 0.785}
epoch: 0
validating
acc: {'accuracy': 0.87}
epoch: 1
validating
acc: {'accuracy': 0.87}
epoch: 2
validating
acc: {'accuracy': 0.87}
epoch: 3
validating
acc: {'accuracy': 0.87}
epoch: 4
validating
acc: {'accuracy': 0.87}


In [6]:
print(val_perf)

[(7.731543998718262, {'accuracy': 0.87}), (2.2283348083496093, {'accuracy': 0.87}), (1.9416706848144532, {'accuracy': 0.87}), (1.7624877548217774, {'accuracy': 0.87}), (3.4690134811401365, {'accuracy': 0.87})]


In [7]:
print(model)

DataParallel(
  (module): BartForSequenceClassification(
    (model): BartModel(
      (shared): Embedding(50265, 1024, padding_idx=1)
      (encoder): BartEncoder(
        (embed_tokens): Embedding(50265, 1024, padding_idx=1)
        (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)
        (layers): ModuleList(
          (0): EncoderLayer(
            (self_attn): Attention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (final_layer_n

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

nli_model = AutoModelForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-mnli')
# label = "history"
# premise = "Napolean was exiled the Alba where he died of stomach cancer"
# hypothesis = f'This example is {label}.'

premise = "napolean and louis the 14th were two constrasting French figures whom once ruled the nation"
hypothesis = "Comparison of two French Rulers and they were great"



# run through model pre-trained on MNLI
x = tokenizer.encode(premise, hypothesis, return_tensors='pt', truncation='only_first', padding =True)
logits = nli_model(x.to(device))[0]

# we throw away "neutral" (dim 1) and take the probability of
# "entailment" (2) as the probability of the label being true 
entail_contradiction_logits = logits[:,[0,2]]
probs = entail_contradiction_logits.softmax(dim=1)
prob_label_is_true = probs[:,1]

print(probs)

Some weights of the model checkpoint at facebook/bart-large-mnli were not used when initializing BartForSequenceClassification: ['model.encoder.version', 'model.decoder.version']
- This IS expected if you are initializing BartForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BartForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor([[0.1554, 0.8446]], device='cuda:0', grad_fn=<SoftmaxBackward>)


In [9]:
print(logits.shape)

print(entail_contradiction_logits.shape)
print(entail_contradiction_logits)

print(probs.shape)
print(probs)


print(prob_label_is_true)

torch.Size([1, 3])
torch.Size([1, 2])
tensor([[-2.0410, -0.3481]], device='cuda:0', grad_fn=<IndexBackward>)
torch.Size([1, 2])
tensor([[0.1554, 0.8446]], device='cuda:0', grad_fn=<SoftmaxBackward>)
tensor([0.8446], device='cuda:0', grad_fn=<SelectBackward>)


In [10]:
print(x)

tensor([[    0,   282,  1115, 48547,     8, 26120,   354,     5,   501,   212,
            58,    80, 10759, 16136,   154,  1515,  2415,  2661,   683,  3447,
             5,  1226,     2,     2, 48080,  4060,     9,    80,  1515,   248,
           922,   268,     8,    51,    58,   372,     2]])


In [30]:
from datasets import load_dataset
import pandas as pd

giga_dataset = load_dataset("gigaword", split='train[:1000]')


Using custom data configuration default
Reusing dataset gigaword (/tmp/xdg-cache/huggingface/datasets/gigaword/default/1.2.0/c518c578e42a6afe842b09e979ee2907ea42a12b57ba992fae9e9d7347825245)


In [31]:
#torch.multiply(torch.rand((2,55,1)),torch.rand((2,55,1024)))

In [12]:
giga_filtered_keys = pd.DataFrame()

def mapper(examples):
    return tokenizer([(examples['document'], examples['summary'])], return_tensors='pt', truncation=True, padding='max_length')
print(giga_dataset)

gigaset = giga_dataset.map(mapper)

gigaset.set_format(type='torch', columns=['input_ids', 'attention_mask'])
print(gigaset[0])

Loading cached processed dataset at /tmp/xdg-cache/huggingface/datasets/gigaword/default/1.2.0/c518c578e42a6afe842b09e979ee2907ea42a12b57ba992fae9e9d7347825245/cache-f50088200ce2f575.arrow


Dataset({
    features: ['document', 'summary'],
    num_rows: 1000
})
{'attention_mask': [tensor([1, 1, 1,  ..., 0, 0, 0])], 'input_ids': [tensor([   0,  102, 4193,  ...,    1,    1,    1])]}


In [32]:
import pandas as pd

giga_filtered_keys = pd.DataFrame()

giga_loader = DataLoader(gigaset, batch_size=4)

# look into how it was bound together, something special is done
print(next(iter(giga_loader)))

with torch.no_grad():
    all_preds = []
    for batch in giga_loader:
        batch = {k: v[0].to(device) for k, v in batch.items()}
#         input_ids = (batch['input_ids'][0].to(device), batch['input_ids'][1].to(device))
#         attention_mask = (batch['attention_mask'][0].to(device), batch['attention_mask'][1].to(device))
#         outputs = model(input_ids, attention_mask=attention_mask)
        outputs = nli_model(**batch)
        preds = torch.argmax(F.softmax(outputs[0][:,[0,2]], dim=1), dim=1)
        all_preds.append(preds)
        
    giga_filtered_keys['index_keys'] = torch.cat(all_preds, dim=0).cpu()
    
print(giga_filtered_keys)

{'attention_mask': [tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])], 'input_ids': [tensor([[    0,   102,  4193,  ...,     1,     1,     1],
        [    0,   415,   513,  ...,     1,     1,     1],
        [    0,   102,  4193,  ...,     1,     1,     1],
        [    0, 25515,   449,  ...,     1,     1,     1]])]}
     index_keys
0             1
1             1
2             1
3             1
4             1
..          ...
995           1
996           1
997           1
998           1
999           0

[1000 rows x 1 columns]


In [33]:
for i,item in enumerate(giga_filtered_keys['index_keys']):
    if item == 0:
        print(i)

56
108
120
142
157
160
161
175
206
234
256
278
327
341
376
377
422
444
491
495
497
518
524
555
574
585
627
645
656
661
678
680
714
725
743
755
764
777
798
802
807
817
832
847
851
866
884
938
956
958
963
983
999


In [29]:
print(outputs[0].shape)

torch.Size([4, 3])
