In [1]:
from collections import defaultdict
import json
import logging
import math
import os
import random
from pathlib import Path

import datasets
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from datasets import load_dataset, load_metric
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    PretrainedConfig,
    SchedulerType,
    default_data_collator,
    get_scheduler,
)

In [2]:
from typing import List, Optional, Tuple, Union
from datasets import load_dataset

raw_datasets = load_dataset('glue','sst2')

Downloading and preparing dataset glue/sst2 (download: 7.09 MiB, generated: 4.81 MiB, post-processed: Unknown size, total: 11.90 MiB) to /home/v-biyangguo/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading data:   0%|          | 0.00/7.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Dataset glue downloaded and prepared to /home/v-biyangguo/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
def preprocess_function(examples):
    # Tokenize the texts
    texts = (examples['sentence'], None)
    result = tokenizer(*texts, padding="max_length", max_length=100, truncation=True)
    if "label" in examples:
        result["labels"] = examples["label"]
    return result


processed_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    remove_columns=raw_datasets["train"].column_names,  
    desc="Running tokenizer on dataset",)
processed_datasets

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/426k [00:00<?, ?B/s]

Running tokenizer on dataset:   0%|          | 0/68 [00:00<?, ?ba/s]

Running tokenizer on dataset:   0%|          | 0/1 [00:00<?, ?ba/s]

Running tokenizer on dataset:   0%|          | 0/2 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 872
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 1821
    })
})

In [4]:
data_collator = DataCollatorWithPadding(tokenizer)
dataloader = DataLoader(processed_datasets['validation'], shuffle=True, collate_fn=data_collator, batch_size=8)
c = 0
for batch in dataloader:
    pass
batch.keys(), batch['input_ids'].shape

(dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels']),
 torch.Size([8, 100]))

In [5]:
len(batch)

4

In [6]:
config = AutoConfig.from_pretrained('bert-base-cased', num_labels=2)
encoder = AutoModel.from_pretrained('bert-base-cased')
classifier_dropout = (config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)
dropout = nn.Dropout(classifier_dropout)
classifier_easy = nn.Linear(config.hidden_size, config.num_labels)
classifier_hard = nn.Linear(config.hidden_size, config.num_labels)
# 2 experts: easy or hard
hardness_gate = nn.Linear(config.hidden_size,2) 

Downloading:   0%|          | 0.00/416M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [8]:
batch_wo_labels = {k:v for k,v in batch.items() if k != 'labels'}
outputs = encoder(**batch_wo_labels)
pooled_output = outputs[1]
pooled_output.shape

torch.Size([8, 768])

In [17]:
# model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
logits = model(**batch).logits
logits

tensor([[ 0.1883, -0.2964],
        [ 0.1280, -0.3338],
        [ 0.1373, -0.3579],
        [ 0.2173, -0.2385],
        [ 0.0645, -0.3355],
        [ 0.1644, -0.2774],
        [ 0.1745, -0.3913],
        [ 0.2355, -0.2404]], grad_fn=<AddmmBackward>)

In [20]:
# from torch import nn
# kld_loss_fct = nn.KLDivLoss(reduction="batchmean")

kld_loss_fct(
            nn.functional.log_softmax(logits / 1, dim=-1),
            nn.functional.softmax(logits / 1, dim=-1),
        ) * (1) ** 2

tensor(7.4759e-09, grad_fn=<DivBackward0>)

In [30]:
T = 1
logits_gate = hardness_gate(pooled_output)
gate_weights = F.softmax(logits_gate/T, dim=1)
logits_gate.shape, gate_weights.shape, gate_weights

(torch.Size([8, 2]),
 torch.Size([8, 2]),
 tensor([[0.6097, 0.3903],
         [0.5698, 0.4302],
         [0.5829, 0.4171],
         [0.5860, 0.4140],
         [0.5920, 0.4080],
         [0.5839, 0.4161],
         [0.5857, 0.4143],
         [0.5875, 0.4125]], grad_fn=<SoftmaxBackward>))

## 只使用 `easy expert` 进行训练和预测
weight rank:
- ambi-easy
- easy
- ambi-hard
- hard

In [31]:
easy_probs = gate_weights[:,0]
easy_probs

tensor([0.6097, 0.5698, 0.5829, 0.5860, 0.5920, 0.5839, 0.5857, 0.5875],
       grad_fn=<SelectBackward>)

In [37]:
easy_weights = torch.where(easy_probs>0.5, 1-torch.abs(easy_probs-0.5), easy_probs)
# 归一化
batch_size = 8
easy_weights, easy_weights * batch_size / torch.sum(easy_weights)

(tensor([0.8903, 0.9302, 0.9171, 0.9140, 0.9080, 0.9161, 0.9143, 0.9125],
        grad_fn=<SWhereBackward>),
 tensor([0.9753, 1.0191, 1.0047, 1.0013, 0.9947, 1.0036, 1.0017, 0.9997],
        grad_fn=<DivBackward0>))

In [40]:
easy_weights = easy_weights * batch_size / torch.sum(easy_weights)
example_loss = torch.tensor([0.4163, 0.3751, 0.9126, 0.4719, 1.1729, 0.2961, 0.4741, 0.9778])
easy_weights * example_loss, torch.mean(easy_weights * example_loss)

(tensor([0.4060, 0.3822, 0.9169, 0.4725, 1.1667, 0.2972, 0.4749, 0.9775],
        grad_fn=<MulBackward0>),
 tensor(0.6367, grad_fn=<MeanBackward0>))

end of 只使用 `easy expert` 进行训练和预测.
---

In [22]:
X = torch.tensor([[0.7,0.3],[0.4,0.6]])
X, torch.where(X>0.5, 1-torch.abs(X-0.5), X)

(tensor([[0.7000, 0.3000],
         [0.4000, 0.6000]]),
 tensor([[0.8000, 0.3000],
         [0.4000, 0.9000]]))

In [24]:
pooled_output = dropout(pooled_output)
logits_easy = classifier_easy(pooled_output)
logits_hard = classifier_easy(pooled_output)


In [28]:
confidences = torch.tensor([0.7]*8)
confidences
easy_probs = confidences.view(-1,1)
hard_probs = 1 - easy_probs
hardness_probs = torch.cat([easy_probs,hard_probs],dim=1)
confidences, easy_probs, hard_probs, hardness_probs

(tensor([0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000, 0.7000]),
 tensor([[0.7000],
         [0.7000],
         [0.7000],
         [0.7000],
         [0.7000],
         [0.7000],
         [0.7000],
         [0.7000]]),
 tensor([[0.3000],
         [0.3000],
         [0.3000],
         [0.3000],
         [0.3000],
         [0.3000],
         [0.3000],
         [0.3000]]),
 tensor([[0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000]]))

In [35]:
F.log_softmax(logits_gate, dim=-1),hardness_probs

(tensor([[-0.6693, -0.7176],
         [-0.6422, -0.7469],
         [-0.6316, -0.7587],
         [-0.6843, -0.7021],
         [-0.6417, -0.7474],
         [-0.6898, -0.6965],
         [-0.6593, -0.7282],
         [-0.6960, -0.6903]], grad_fn=<LogSoftmaxBackward>),
 tensor([[0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000],
         [0.7000, 0.3000]]))

In [42]:
loss_gate =  F.kl_div(F.log_softmax(logits_gate, dim=-1), hardness_probs, reduction='batchmean')
loss_gate

tensor(0.0712, grad_fn=<DivBackward0>)

In [50]:
num_labels = config.num_labels
labels = batch['labels']
loss_fct = CrossEntropyLoss(reduction='none')
loss_easy = loss_fct(logits_easy.view(-1, num_labels), labels.view(-1))
loss_hard = loss_fct(logits_hard.view(-1, num_labels), labels.view(-1))
loss_easy, loss_hard

(tensor([0.4163, 0.3751, 0.9126, 0.4719, 1.1729, 0.2961, 0.4741, 0.9778],
        grad_fn=<NllLossBackward>),
 tensor([0.4163, 0.3751, 0.9126, 0.4719, 1.1729, 0.2961, 0.4741, 0.9778],
        grad_fn=<NllLossBackward>))

In [54]:
easy_hard_loss_cat = torch.cat([loss_easy.view(-1,1), loss_hard.view(-1,1)],dim=1)
easy_hard_loss_cat

tensor([[0.4163, 0.4163],
        [0.3751, 0.3751],
        [0.9126, 0.9126],
        [0.4719, 0.4719],
        [1.1729, 1.1729],
        [0.2961, 0.2961],
        [0.4741, 0.4741],
        [0.9778, 0.9778]], grad_fn=<CatBackward>)

In [55]:
easy_hard_loss_cat, gate_weights, easy_hard_loss_cat * gate_weights

(tensor([[0.4163, 0.4163],
         [0.3751, 0.3751],
         [0.9126, 0.9126],
         [0.4719, 0.4719],
         [1.1729, 1.1729],
         [0.2961, 0.2961],
         [0.4741, 0.4741],
         [0.9778, 0.9778]], grad_fn=<CatBackward>),
 tensor([[0.5121, 0.4879],
         [0.5262, 0.4738],
         [0.5317, 0.4683],
         [0.5044, 0.4956],
         [0.5264, 0.4736],
         [0.5017, 0.4983],
         [0.5172, 0.4828],
         [0.4986, 0.5014]], grad_fn=<SoftmaxBackward>),
 tensor([[0.2132, 0.2031],
         [0.1974, 0.1777],
         [0.4853, 0.4274],
         [0.2380, 0.2338],
         [0.6174, 0.5555],
         [0.1486, 0.1476],
         [0.2452, 0.2289],
         [0.4875, 0.4903]], grad_fn=<MulBackward0>))

In [80]:
weighted_loss = easy_hard_loss_cat * gate_weights
torch.mean(weighted_loss)

tensor(0.3185, grad_fn=<MeanBackward0>)

In [111]:
from transformers.modeling_outputs import SequenceClassifierOutput

class HCTForSequenceClassification(nn.Module):
    def __init__(self, model_name_or_path, config):
        super(HCTForSequenceClassification, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_name_or_path)
        self.config = config
        self.num_labels = config.num_labels
        self.classifier_dropout = (
        self.config.classifier_dropout if self.config.classifier_dropout is not None else self.config.hidden_dropout_prob
    )
        self.dropout = nn.Dropout(self.classifier_dropout)
        self.classifier_easy = nn.Linear(self.config.hidden_size, config.num_labels)
        self.classifier_hard = nn.Linear(self.config.hidden_size, config.num_labels)
        # 2 experts: easy or hard
        # gate output: 0 for easy, 1 for hard
        self.hardness_gate = nn.Linear(self.config.hidden_size,2) 


    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        confidences: Optional[torch.Tensor] = None,  # the confidence value of a sample
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # the CLS vector
        pooled_output = outputs[1]  
        # gating:
        logits_gate = self.hardness_gate(pooled_output)
        gate_weights = F.softmax(logits_gate)
        # easy/hard experts:
        pooled_output = self.dropout(pooled_output)
        logits_easy = self.classifier_easy(pooled_output)
        logits_hard = self.classifier_hard(pooled_output)

        loss = None
        loss_easy, loss_hard, gate_loss = None, None, None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # gating loss
            # confidences 其实相当于 easy 分支的概率，所以你还需要自己构造一个 hard prob
            easy_probs = confidences.view(-1,1)
            hard_probs = 1 - easy_probs
            hardness_probs = torch.cat([easy_probs,hard_probs],dim=1)
            gate_loss = F.kl_div(F.log_softmax(logits_gate, dim=-1), hardness_probs, reduction='batchmean')
            if self.config.problem_type == "regression":
                loss_fct = MSELoss(reduction='none')
                if self.num_labels == 1:
                    loss_easy = loss_fct(logits_easy.squeeze(), labels.squeeze())
                    loss_hard = loss_fct(logits_hard.squeeze(), labels.squeeze())
                else:
                    loss_easy = loss_fct(logits_easy, labels)
                    loss_hard = loss_fct(logits_hard, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss(reduction='none') # reduction='none', 来得到每个sample的loss
                loss_easy = loss_fct(logits_easy.view(-1, self.num_labels), labels.view(-1))
                loss_hard = loss_fct(logits_hard.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss(reduction='none')
                loss_easy = loss_fct(logits_easy, labels)
                loss_hard = loss_fct(logits_hard, labels)
            
            easy_hard_loss_cat = torch.cat([loss_easy.view(-1,1), loss_hard.view(-1,1)],dim=1)
            weighted_loss = easy_hard_loss_cat * gate_weights
            clf_loss = torch.mean(weighted_loss)
            loss = gate_loss + clf_loss
        
        if not return_dict:
            output = (logits_gate,logits_easy, logits_hard,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits={"gate":logits_gate, "easy":logits_easy, "hard":logits_hard},
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

my_model = HCTForSequenceClassification('bert-base-cased', config)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [112]:
my_outputs = my_model(**batch, confidences=confidences)
my_outputs.loss, my_outputs.logits



(tensor(0.5328, grad_fn=<AddBackward0>),
 {'gate': tensor([[0.4412, 0.2590],
          [0.4171, 0.3171],
          [0.4294, 0.3268],
          [0.4348, 0.3207],
          [0.3658, 0.3255],
          [0.4153, 0.2928],
          [0.3959, 0.2656],
          [0.4340, 0.2752]], grad_fn=<AddmmBackward>),
  'easy': tensor([[ 0.0633,  0.9459],
          [-0.1190,  0.8201],
          [ 0.1919,  0.9389],
          [-0.2952,  0.6612],
          [ 0.0297,  0.8347],
          [-0.2183,  1.2751],
          [-0.1609,  1.0679],
          [-0.1628,  0.9590]], grad_fn=<AddmmBackward>),
  'hard': tensor([[-0.2337,  0.2637],
          [-0.3422,  0.3396],
          [-0.2908,  0.3610],
          [-0.2687,  0.1167],
          [-0.1619, -0.1485],
          [-0.7175,  0.3527],
          [-0.4886,  0.2593],
          [-0.1387,  0.3856]], grad_fn=<AddmmBackward>)})

In [107]:
# indices = torch.argmax(my_outputs.logits['gate'],dim=1)
indices = torch.tensor([1,1,1,1,0,0,0,0])
indices, 1-indices

(tensor([1, 1, 1, 1, 0, 0, 0, 0]), tensor([0, 0, 0, 0, 1, 1, 1, 1]))

In [113]:
my_outputs.logits['easy'] * (1-indices).view(-1,1), my_outputs.logits['hard'] * indices.view(-1,1)

my_outputs.logits['easy'] * (1-indices).view(-1,1) + my_outputs.logits['hard'] * indices.view(-1,1)

tensor([[-0.2337,  0.2637],
        [-0.3422,  0.3396],
        [-0.2908,  0.3610],
        [-0.2687,  0.1167],
        [ 0.0297,  0.8347],
        [-0.2183,  1.2751],
        [-0.1609,  1.0679],
        [-0.1628,  0.9590]], grad_fn=<AddBackward0>)

In [114]:
my_outputs.logits['easy'], my_outputs.logits['hard']

(tensor([[ 0.0633,  0.9459],
         [-0.1190,  0.8201],
         [ 0.1919,  0.9389],
         [-0.2952,  0.6612],
         [ 0.0297,  0.8347],
         [-0.2183,  1.2751],
         [-0.1609,  1.0679],
         [-0.1628,  0.9590]], grad_fn=<AddmmBackward>),
 tensor([[-0.2337,  0.2637],
         [-0.3422,  0.3396],
         [-0.2908,  0.3610],
         [-0.2687,  0.1167],
         [-0.1619, -0.1485],
         [-0.7175,  0.3527],
         [-0.4886,  0.2593],
         [-0.1387,  0.3856]], grad_fn=<AddmmBackward>))

In [117]:
weights = F.softmax(my_outputs.logits['gate'], dim=1)
weights, weights[:,0], weights[:,1]  # 第0列是easy的权重，第1列是hard的权重

(tensor([[0.5454, 0.4546],
         [0.5250, 0.4750],
         [0.5256, 0.4744],
         [0.5285, 0.4715],
         [0.5101, 0.4899],
         [0.5306, 0.4694],
         [0.5325, 0.4675],
         [0.5396, 0.4604]], grad_fn=<SoftmaxBackward>),
 tensor([0.5454, 0.5250, 0.5256, 0.5285, 0.5101, 0.5306, 0.5325, 0.5396],
        grad_fn=<SelectBackward>),
 tensor([0.4546, 0.4750, 0.4744, 0.4715, 0.4899, 0.4694, 0.4675, 0.4604],
        grad_fn=<SelectBackward>))

In [120]:
my_outputs.logits['easy'] * weights[:,0].view(-1,1), my_outputs.logits['hard'] * weights[:,1].view(-1,1)

(tensor([[ 0.0345,  0.5159],
         [-0.0625,  0.4305],
         [ 0.1009,  0.4935],
         [-0.1560,  0.3494],
         [ 0.0152,  0.4258],
         [-0.1158,  0.6765],
         [-0.0857,  0.5687],
         [-0.0879,  0.5175]], grad_fn=<MulBackward0>),
 tensor([[-0.1062,  0.1199],
         [-0.1626,  0.1613],
         [-0.1380,  0.1712],
         [-0.1267,  0.0550],
         [-0.0793, -0.0728],
         [-0.3368,  0.1656],
         [-0.2284,  0.1212],
         [-0.0639,  0.1775]], grad_fn=<MulBackward0>))

In [1]:
from datasets import load_dataset
snli_data = load_dataset('snli')
snli_data

Reusing dataset snli (/home/v-biyangguo/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 550152
    })
    validation: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 10000
    })
})

In [8]:
from collections import Counter
c = Counter(snli_data['test']['label'])
c

Counter({1: 3219, 0: 3368, 2: 3237, -1: 176})

In [6]:
snli_data['train'] = snli_data['train'].filter(lambda x:x['label']!=-1)
snli_data['train']

Loading cached processed dataset at /home/v-biyangguo/.cache/huggingface/datasets/snli/plain_text/1.0.0/1f60b67533b65ae0275561ff7828aad5ee4282d0e6f844fd148d05d3c6ea251b/cache-bd229a5e884be60a.arrow


Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 549367
})

In [5]:
550152-785

549367