# Dataset and data loader

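This notebook exercises the `AssembleeDataset` class and a PyTorch `DataLoader` that uses a custom padding `collate_fn`, using the preprocessed records and tokenizer pickles.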

In [5]:
# Reload imported modules automatically before each cell runs, so edits to
# the project package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
import pickle

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from nlp_assemblee.datasets import AssembleeDataset

In [7]:
# Load the preprocessed records; they are indexed by position below and each
# entry is displayed as a dict of fields.
# NOTE(review): pickle.load can execute arbitrary code — only use on trusted files.
with open("../../data/processed/14th_records.pkl", "rb") as f:
    records = pickle.load(f)

In [8]:
# Load the pickled tokenizer; it is used further down to decode token ids
# back into text.
# NOTE(review): pickle.load can execute arbitrary code — only use on trusted files.
with open("../../data/processed/14th_camembert_tokenizer.pkl", "rb") as f:
    camembert_tokenizer = pickle.load(f)

In [9]:
# Display the first record to see which fields are available.
records[0]

{'nom': 'Pierre Lellouche',
 'groupe': 'UMP',
 'seance_id': 11,
 'date_seance': '2012-07-03',
 'titre': 'déclaration de politique générale du gouvernement débat et vote sur cette déclaration',
 'titre_complet': 'déclaration de politique générale du gouvernement débat et vote sur cette déclaration',
 'intervention': 'Alors, arrêtez de dépenser !',
 'nb_mots': 8,
 'intervention_count': 1562,
 'nb_mots_approx': 5,
 'date_naissance': '1951-05-03',
 'sexe': 'H',
 'profession': 'Avocat et universitaire',
 'nb_mandats': 2,
 'date': Timestamp('2012-07-03 00:00:00'),
 'year': 2012,
 'month': 7,
 'day': 3,
 'y_naissance': 1951,
 'n_y_naissance': 0.8658536585365854,
 'n_year': 0.0,
 'cos_month': -0.8660254037844388,
 'sin_month': -0.4999999999999997,
 'cos_day': 0.8207634412072763,
 'sin_day': 0.5712682150947923,
 'n_sexe': 0,
 'label': 2,
 'camembert_tokens': {'intervention': [5, 574, 7, 26748, 8, 11104, 83, 6],
  'titre_complet': [5,
   3035,
   8,
   462,
   1229,
   25,
   754,
   2159,
   14

In [11]:
# Build a dataset over the records using the CamemBERT tokens, two text
# fields, two numeric feature columns and the "label" column as target.
camembert_dataset = AssembleeDataset(
    records=records,
    bert_type="camembert",
    text_vars=["intervention", "titre_complet"],
    features_vars=["n_y_naissance", "n_sexe"],
    label_var="label",
)

In [12]:
# Display the first item: an (inputs dict, label) pair.
camembert_dataset[0]

({'intervention': [5, 574, 7, 26748, 8, 11104, 83, 6],
  'titre_complet': [5,
   3035,
   8,
   462,
   1229,
   25,
   754,
   2159,
   14,
   2422,
   32,
   78,
   3035,
   6],
  'features': array([0.86585366, 0.        ])},
 2)

In [14]:
# Sanity check: decode the title token ids of the first item back into text.
camembert_tokenizer.decode(camembert_dataset[0][0]["titre_complet"])

'déclaration de politique générale du gouvernement débat et vote sur cette déclaration'

In [15]:
def collate_fn(data, padding_value=0):
    """Collate a list of dataset items into one padded batch.

    Args:
        data: non-empty list of items where ``item[0]`` is a dict of inputs
            and ``item[1]`` is the label. Every key of the inputs dict other
            than ``"features"`` must hold a variable-length sequence of token
            ids; ``"features"``, when present, must hold a fixed-size numeric
            array. The keys of the first item are used for the whole batch.
        padding_value: id used to right-pad token sequences up to the longest
            one in the batch. Defaults to 0, the value that was previously
            hard-coded.
            NOTE(review): the right pad id depends on the tokenizer — confirm
            against ``tokenizer.pad_token_id`` and pass it explicitly (e.g.
            via ``functools.partial``) if it is not 0.

    Returns:
        A ``(padded_inputs, labels)`` tuple. ``padded_inputs`` maps each text
        key to an int64 tensor of shape ``(batch, max_len)`` and
        ``"features"`` (if present) to a tensor stacking the per-item arrays.
        ``labels`` is an int64 tensor of shape ``(batch,)``.

    Raises:
        IndexError: if ``data`` is empty.
    """
    # Python ints already become int64; the dtype is spelled out so the
    # contract does not rely on inference.
    labels = torch.tensor([int(item[1]) for item in data], dtype=torch.long)

    padded_inputs = {}

    for var in data[0][0].keys():
        if var == "features":
            # Fixed-size arrays: stack through numpy, keeping the numpy dtype.
            padded_inputs["features"] = torch.tensor(np.array([item[0][var] for item in data]))
        else:
            # Force int64 so that an empty token list does not yield a float
            # tensor, which could change the dtype of the whole padded batch.
            sequences = [torch.tensor(item[0][var], dtype=torch.long) for item in data]
            padded_inputs[var] = pad_sequence(
                sequences, batch_first=True, padding_value=padding_value
            )

    return padded_inputs, labels

In [16]:
# Rebuild the dataset with bert_type="bert", a single text field and
# features_vars=False (the batch shown further down has no "features" key).
# NOTE(review): this rebinds `camembert_dataset` although it now holds BERT
# tokens; the name is kept because later cells reference it, but it is
# misleading — consider renaming it together with its uses.
camembert_dataset = AssembleeDataset(
    records=records,
    bert_type="bert",
    text_vars=["intervention"],
    features_vars=False,
    label_var="label",
)

In [17]:
# Batch the dataset: shuffled batches of 16 items, padded by collate_fn,
# loaded by 4 worker processes with a prefetch factor of 3.
camembert_dataloader = DataLoader(
    camembert_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    collate_fn=collate_fn,
    prefetch_factor=3,
)

In [19]:
# Pull a single batch: x is the dict of padded inputs, y the label tensor
# (the two values returned by collate_fn).
x, y = next(iter(camembert_dataloader))

In [20]:
# Display the padded inputs; shorter sequences are right-padded by collate_fn.
x

{'intervention': tensor([[   101,  44356,  10141,  ...,      0,      0,      0],
         [   101,  10281, 102609,  ...,      0,      0,      0],
         [   101,  13796,  24931,  ...,      0,      0,      0],
         ...,
         [   101,  17434,  20514,  ...,      0,      0,      0],
         [   101,  22135,  32769,  ...,  17083,    119,    102],
         [   101,  20491,  12970,  ...,      0,      0,      0]])}