In [3]:
import torch
import numpy as np
from torchtext import data

SEED = 1452 # for reproducibility
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

TEXT = data.Field(tokenize= 'spacy')
LABEL = data.LabelField(dtype=torch.float)


In [4]:
# Load dataset
with open('datasets/winemag-data-130k-v2.csv') as f:
    lines = f.readlines()

print(len(lines))

129976


In [69]:
# Split in train and test
# @Pierre: en fait on utilisera plutot le split de torchtext (voir plus bas)

TEST_SET_SIZE = .3
VALIDATION_SET_SIZE = .2

indices = list(range(1, len(lines)))
np.random.seed(SEED)
np.random.shuffle(indices)

first_split_index = int(TEST_SET_SIZE * len(lines))
second_split_index = int((TEST_SET_SIZE+VALIDATION_SET_SIZE) * len(lines))

print(first_split_index)
print(second_split_index)

test_indices = indices[:first_split_index]
validation_indices = indices[first_split_index:second_split_index]
train_indices = indices[second_split_index:]

train_set = [lines[k] for k in train_indices]
test_set = [lines[k] for k in test_indices]
validation_set = [lines[k] for k in validation_indices]

print(len(train_set))
print(len(test_set))
print(len(validation_set))
print(train_set[0:3])

38992
64988
64987
38992
25996
['18579,US,"Flavors of candied lemon, lime and pineapple are brightened by crisp acidity in this unoaked  Sauvignon Blanc. The wine reflects this cool-climate Monterey appellation with its clean, brisk character.",,85,16.0,California,Arroyo Seco,Central Coast,,,Mercy 2010 Sauvignon Blanc (Arroyo Seco),Sauvignon Blanc,Mercy\n', '89421,US,"There\'s a lot of oak on this Chardonnay, to judge by the buttered toast and butterscotch richness. Underneath all that is a wine ripe in tropical fruits and green apples, brightened by excellent, mouthwatering acidity. The oak stands out now, but give the wine until 2015 or 2016 in the cellar to let the parts integrate.",Sierra Mar Vineyard,90,40.0,California,Santa Lucia Highlands,Central Coast,,,Loring Wine Company 2012 Sierra Mar Vineyard Chardonnay (Santa Lucia Highlands),Chardonnay,Loring Wine Company\n', '56113,US,"Almost mauve in color, this widely distributed wine (named after the time the winemaking crew pops open

In [70]:
# Write split sets
with open('preprocessed_datasets/train.csv', 'w') as train_file:
    train_file.write(''.join(train_set))
with open('preprocessed_datasets/test.csv', 'w') as test_file:
    test_file.write(''.join(test_set))
with open('preprocessed_datasets/validation.csv', 'w') as validation_file:
    validation_file.write(''.join(validation_set))

In [5]:
# preprocess json
# (torchtext needs a file with a new json record per row and not a proper json)
import json
with open('datasets/winemag-data-130k-v2.json') as f:
    jsonfile = json.loads(f.read())
res = ''
# print(str(jsonfile[0]))
# print(json.dumps(jsonfile[0]))
for record in jsonfile:
    res += json.dumps(record) + '\n'

# print(res[:1000])


In [6]:
# Write preprocessed json to new file
with open('testfile', 'w') as out:
    out.write(res)

In [33]:
# Setup a dataset from the preprocessed json
full_dataset = data.TabularDataset(
    path='testfile',
    format='json',
    fields={'description': ('description', data.Field(sequential=True)),
            'points': ('points', data.Field(sequential=False))}
)

a,b = full_dataset.splits()

AttributeError: type object 'TabularDataset' has no attribute 'name'

In [29]:
import random
print(vars(full_dataset.examples[0]))
train_and_valid_data, test_data = full_dataset.split(random_state = random.seed(SEED))

train_data, valid_data = train_and_valid_data.split(random_state = random.seed(SEED))
print(len(train_data))
print(len(test_data))
print(len(valid_data))

{'points': '87', 'description': ['Aromas', 'include', 'tropical', 'fruit,', 'broom,', 'brimstone', 'and', 'dried', 'herb.', 'The', 'palate', "isn't", 'overly', 'expressive,', 'offering', 'unripened', 'apple,', 'citrus', 'and', 'dried', 'sage', 'alongside', 'brisk', 'acidity.']}
63686
38991
27294


In [68]:
print(len(TEXT.vocab))

2


In [72]:
# part 3 : build the dataset
tv_datafields = [("id", None),
                 ("country", LABEL),
                 ("description", TEXT),
                 ("designation", LABEL),
                 ("points", LABEL),
                 ("price", LABEL),
                 ("province", LABEL),
                 ("region_1", LABEL),
                 ("region_2", LABEL),
                 ("taster_name", LABEL),
                 ("taster_twitter_handle", LABEL),
                 ("title", LABEL),
                 ("variety", LABEL),
                 ("winery", LABEL)]

trn, vld, tst = data.TabularDataset.splits(path='preprocessed_datasets',
                                     format="csv",
                                     train= 'train.csv',
                                     validation='validation.csv',
                                     test='test.csv',
                                     fields=tv_datafields)

In [81]:
MAX_VOCAB_SIZE = 25000
TEXT.build_vocab(trn, max_size=MAX_VOCAB_SIZE)

print(len(TEXT.vocab))
# 25002 because of <pad> and <unk>

25002


In [80]:
print(TEXT.vocab.freqs.most_common(10))

[(',', 219889), ('.', 176716), ('and', 173574), ('of', 86268), ('the', 83922), ('a', 78846), ('with', 58105), ('is', 48536), ('wine', 40263), ('-', 37447), ('this', 36671), ('in', 30410), ('flavors', 30207), ('to', 27836), ('The', 26309), ("'s", 25712), ('fruit', 24835), ('It', 21974), ('on', 21537), ('it', 21082), ('This', 20352), ('that', 19700), ('palate', 19116), ('aromas', 17736), ('finish', 17409), ('acidity', 17309), ('tannins', 15241), ('from', 15112), ('but', 14779), ('cherry', 14363), ('black', 13894), ('are', 12808), ('ripe', 12598), ('has', 12303), ('A', 10872), ('red', 10620), ('by', 10392), ('for', 10382), ('Drink', 10212), ('%', 9635), ('spice', 9513), ('notes', 9270), ('oak', 8602), ('as', 8533), ('berry', 8463), ('nose', 8392), ('its', 8258), ('rich', 8143), ('an', 7980), ('fresh', 7963), ('dry', 7845), ('now', 7583), ('full', 7419), ('plum', 7346), ('fruits', 6724), ('apple', 6607), ('blend', 6563), ('well', 6561), ('sweet', 6509), ('white', 6268), ('offers', 6238), (

In [85]:
# Part 4 : building the iterator
from torchtext.data import Iterator, BucketIterator

train_iter, val_iter = BucketIterator.splits(
 (trn, vld),
 batch_sizes=(64, 64),
 sort_key=lambda x: len(x.description),
 sort_within_batch=False,
 repeat=False
)

test_iter = Iterator(tst, batch_size=64, sort=False, sort_within_batch=False, repeat=False)

In [98]:
# Part 5 : Wrapper (obscur)
class BatchWrapper:
      def __init__(self, dl, x_var, y_vars):
            self.dl, self.x_var, self.y_vars = dl, x_var, y_vars # we pass in the list of attributes for x &amp;amp;amp;amp;lt;g class="gr_ gr_3178 gr-alert gr_spell gr_inline_cards gr_disable_anim_appear ContextualSpelling ins-del" id="3178" data-gr-id="3178"&amp;amp;amp;amp;gt;and y&amp;amp;amp;amp;lt;/g&amp;amp;amp;amp;gt;
  
      def __iter__(self):
            for batch in self.dl:
                  x = getattr(batch, self.x_var) # we assume only one input in this wrapper
  
                  if self.y_vars is not None: # we will concatenate y into a single tensor
                        y = torch.cat([getattr(batch, feat).unsqueeze(1) for feat in self.y_vars], dim=1).float()
                  else:
                        y = torch.zeros((1))
                  yield (x, y)
  
      def __len__(self):
            return len(self.dl)
 
train_dl = BatchWrapper(train_iter, "description", [
    "country",
    "designation",
    "points",
    "price",
    "province",
    "region_1",
    "region_2",
    "taster_name",
    "taster_twitter_handle",
    "title",
    "variety",
    "winery"
])
valid_dl = BatchWrapper(val_iter, "description", [
    "country",
    "designation",
    "points",
    "price",
    "province",
    "region_1",
    "region_2",
    "taster_name",
    "taster_twitter_handle",
    "title",
    "variety",
    "winery"
])
test_dl = BatchWrapper(test_iter, "description", None)

In [99]:
next(train_dl.__iter__())

(tensor([[ 4576,  1271,  6818,  ...,    22,    22, 22698],
         [   13,    11,    25,  ...,   713,     9, 22196],
         [  186,     0,     5,  ...,    61,     7,     9],
         ...,
         [    1,     1,     1,  ...,     1,     1,     1],
         [    1,     1,     1,  ...,     1,     1,     1],
         [    1,     1,     1,  ...,     1,     1,     1]]),
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.