In [None]:
!pip install setuptools==66
!pip install matplotlib_inline
!pip install d2l==1.0.0b

In [None]:
import os
import re
import torch
from torch import nn
from d2l import torch as d2l

In [None]:
# Register the SNLI archive with d2l's download hub as (URL, expected checksum);
# the 40-hex-digit checksum is a SHA-1 digest. Then download and unzip it.
snli_archive = ('https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
                '9fcde07509c7e87ec61c640c1b2753d9041758e4')
d2l.DATA_HUB['SNLI'] = snli_archive
data_dir = d2l.download_extract('SNLI')  # directory that read_snli reads the split files from

Downloading ../data/snli_1.0.zip from https://nlp.stanford.edu/projects/snli/snli_1.0.zip...


In [None]:
# Read the dataset.
# Note that in order to come up with this cleaning,
# the author had to do EDA first to see what the dataset looks like.
# It's okay to copy and paste it here without fully understanding it,
# but when you work with a real dataset, you have to do the EDA yourself.
def read_snli(data_dir, is_train):
  """Read the SNLI train or test split into parallel lists.

  Args:
    data_dir: directory containing snli_1.0_train.txt / snli_1.0_test.txt.
    is_train: read the train file when True, otherwise the test file.

  Returns:
    A (premises, hypotheses, labels) tuple of equal-length lists. Premises and
    hypotheses come from the 2nd and 3rd tab-separated columns with all
    parentheses removed and whitespace runs collapsed; labels are
    0 (entailment), 1 (contradiction) or 2 (neutral). The header row and any
    row whose first column is not one of those three labels are skipped.
  """

  def extract_text(s):
    # Drop the parse-tree parentheses, then collapse the leftover runs of spaces.
    s = re.sub(r'\(', '', s)
    s = re.sub(r'\)', '', s)
    s = re.sub(r'\s{2,}', ' ', s)
    return s.strip()

  label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
  file_name = os.path.join(data_dir, 'snli_1.0_train.txt' if is_train else 'snli_1.0_test.txt')

  premises, hypotheses, labels = [], [], []
  # Explicit UTF-8 so decoding does not depend on the platform's default encoding.
  with open(file_name, 'r', encoding='utf-8') as f:
    next(f, None)  # skip the header row (no-op on an empty file)
    # Stream the file once instead of loading every line and filtering it three times.
    for line in f:
      row = line.split('\t')
      label = label_set.get(row[0])
      if label is None:
        continue
      premises.append(extract_text(row[1]))
      hypotheses.append(extract_text(row[2]))
      labels.append(label)

  return premises, hypotheses, labels

train_data = read_snli(data_dir, True)

# Show the first three (premise, hypothesis, label) triples as a sanity check.
head_premises, head_hypotheses, head_labels = (column[:3] for column in train_data)
for premise, hypothesis, label in zip(head_premises, head_hypotheses, head_labels):
  print('premise:', premise)
  print('hypothesis:', hypothesis)
  print('label:', label)

premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is training his horse for a competition .
label: 2
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is at a diner , ordering an omelette .
label: 1
premise: A person on a horse jumps over a broken down airplane .
hypothesis: A person is outdoors , on a horse .
label: 0


In [None]:
test_data = read_snli(data_dir, is_train = False)

# Per-split label counts in label-id order (0 = entailment, 1 = contradiction, 2 = neutral).
# list.count works on the label list directly; copying it first is unnecessary.
for data in [train_data, test_data]:
  print([data[2].count(i) for i in range(3)])

sample raw line =  neutral	( ( This ( church choir ) ) ( ( ( sings ( to ( the masses ) ) ) ( as ( they ( ( sing ( joyous songs ) ) ( from ( ( the book ) ( at ( a church ) ) ) ) ) ) ) ) . ) )	( ( The church ) ( ( has ( cracks ( in ( the ceiling ) ) ) ) . ) )	(ROOT (S (NP (DT This) (NN church) (NN choir)) (VP (VBZ sings) (PP (TO to) (NP (DT the) (NNS masses))) (SBAR (IN as) (S (NP (PRP they)) (VP (VBP sing) (NP (JJ joyous) (NNS songs)) (PP (IN from) (NP (NP (DT the) (NN book)) (PP (IN at) (NP (DT a) (NN church))))))))) (. .)))	(ROOT (S (NP (DT The) (NN church)) (VP (VBZ has) (NP (NP (NNS cracks)) (PP (IN in) (NP (DT the) (NN ceiling))))) (. .)))	This church choir sings to the masses as they sing joyous songs from the book at a church.	The church has cracks in the ceiling.	2677109430.jpg#1	2677109430.jpg#1r1n	neutral	contradiction	contradiction	neutral	neutral

[183416, 183187, 182764]
[3368, 3237, 3219]


In [None]:
class SNLIDataset(torch.utils.data.Dataset):
  """SNLI premise/hypothesis pairs encoded as token-id tensors.

  Args:
    dataset: sequence whose items 0, 1, 2 are the premises, hypotheses and
      labels (e.g. the tuple returned by `read_snli`).
    num_steps: target length passed to d2l.truncate_pad for every sentence.
    vocab: vocabulary used for encoding; when None, a new one is built from
      the premise and hypothesis tokens (min_freq=5, '<pad>' reserved).

  Attributes:
    premises, hypotheses: tensors built from the padded id lists.
    labels: tensor of the label ids.
  """

  def __init__(self, dataset, num_steps, vocab = None):
    self.num_steps = num_steps
    premise_tokens = d2l.tokenize(dataset[0])
    hypothesis_tokens = d2l.tokenize(dataset[1])
    # Without a caller-supplied vocab, build one over both sides of every pair.
    self.vocab = (
        vocab if vocab is not None
        else d2l.Vocab(premise_tokens + hypothesis_tokens, min_freq = 5, reserved_tokens=['<pad>']))

    self.premises = self._pad(premise_tokens)
    self.hypotheses = self._pad(hypothesis_tokens)
    self.labels = torch.tensor(dataset[2])

    print(f'read {len(self.premises)} examples')

  def _pad(self, lines):
    """Encode each token list with the vocab, truncate/pad it via d2l.truncate_pad, and build one tensor."""
    pad_id = self.vocab['<pad>']  # loop-invariant, so look it up once
    padded = [d2l.truncate_pad(self.vocab[tokens], self.num_steps, pad_id)
              for tokens in lines]
    return torch.tensor(padded)

  def __getitem__(self, idx):
    """Return the (premise, hypothesis, label) tensors at position idx."""
    premise = self.premises[idx]
    hypothesis = self.hypotheses[idx]
    label = self.labels[idx]
    return premise, hypothesis, label

  def __len__(self):
    """Number of examples, i.e. the length of the premise tensor."""
    return len(self.premises)


In [None]:
def load_data_snli(batch_size, num_steps = 50):
  """Download SNLI and return (train_iter, test_iter, vocab).

  Args:
    batch_size: mini-batch size for both DataLoaders.
    num_steps: sentence length passed through to SNLIDataset.

  Returns:
    train_iter: shuffled DataLoader over the training split.
    test_iter: DataLoader over the test split in file order, encoded with the
      training vocabulary.
    vocab: the vocabulary built from the training split.
  """
  num_workers = d2l.get_dataloader_workers()
  data_dir = d2l.download_extract('SNLI')
  train_data = read_snli(data_dir, True)
  test_data = read_snli(data_dir, False)

  train_set = SNLIDataset(train_data, num_steps)
  # Encode the test split with the TRAINING vocabulary so token ids mean the same
  # thing the model was trained on. Any new token from the testing set is then
  # unknown to the model trained on the training set.
  test_set = SNLIDataset(test_data, num_steps, train_set.vocab)

  train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle = True, num_workers = num_workers)
  # Evaluation does not need shuffling; a fixed order keeps test runs reproducible.
  test_iter = torch.utils.data.DataLoader(test_set, batch_size, shuffle = False, num_workers = num_workers)

  return train_iter, test_iter, train_set.vocab

train_iter, test_iter, vocab = load_data_snli(128, 50)
# Keep the original misspelled name bound so any later cell still using it works.
tets_iter = test_iter

print('len vocab = ', len(vocab))

read 549367 examples
read 9824 examples
len vocab =  18678


