In [1]:
import pandas as pd
import ast

In [2]:
from pandas import RangeIndex

In [3]:
from datasets import Dataset

In [4]:
import argparse
import pytorch_lightning as pl
from torch.utils.data import DataLoader


from utils.dataset_utils import get_predictive_collate_fn
from utils.parsing.predictive import parse_predictive_config

config = './trained_models/predictive/tatoeba-de-en-predictive.yml'
parsed_config = parse_predictive_config(config, pretrained=False, develop=True)

# Pull the objects that later cells rely on out of the parsed config.
nmt_model, pl_model, tokenizer = (
    parsed_config[key] for key in ("nmt_model", "pl_model", "tokenizer")
)
train_dataset, val_dataset = (
    parsed_config[key] for key in ("train_dataset", "validation_dataset")
)

In [5]:
from custom_datasets import BayesRiskDataset

In [184]:
base_dir = './trained_models/NMT/tatoeba-de-en/data/'

dataset = BayesRiskDataset.load_dataset(base_dir, develop=True)


In [185]:
# import pyarrow as pa
# def preprocess_f(batch):
#     features = pl_model.preprocess_function(batch)
#     print(features["avg_pool_encoder_hidden_state"].dtype)
    
#     features = {key: pa.array(v) for v in features.items()}
#     return features

In [186]:
dataset = dataset.map(pl_model.preprocess_function, batched=True, batch_size=32)

  0%|          | 0/20 [00:00<?, ?ba/s]

In [None]:
dataset.data.table

In [169]:
# Save the model
import pyarrow.parquet as pq
    
pq.write_table(dataset.data.table, "test.parquet")

In [170]:
ref_table = pq.read_table("test.parquet")

In [171]:
# From this table we can take all the relevant features
ref_table.take([0,1,2,3,100])

pyarrow.Table
sources: string
hypothesis: string
utilities: string
count: int64
avg_pool_encoder_hidden_state: list<item: float>
  child 0, item: float
avg_pool_decoder_hidden_state: list<item: float>
  child 0, item: float
max_pool_encoder_hidden_state: list<item: float>
  child 0, item: float
max_pool_decoder_hidden_state: list<item: float>
  child 0, item: float
----
sources: [["Ich habe dich nicht hereinkommen gehört.","Ich habe dich nicht hereinkommen gehört.","Du wirst Tom küssen müssen.","Du wirst Tom küssen müssen.","Wir werden am nächsten Samstag eine Party machen."]]
hypothesis: [["I didn't hear you come in.","I didn't hear you coming in.","You're going to have to kiss Tom.","You'll have to kiss Tom.","We will have a party next Saturday."]]
utilities: [["{1.4914770126342773: 72, 0.7293738126754761: 11, 1.0108041763305664: 9, 0.7235156893730164: 4, 0.0017659920267760754: 1, 0.713856041431427: 1, 0.5867420434951782: 1, 0.21969538927078247: 1}","{1.1027610301971436: 72, 0.742346

In [172]:
# Make a reference dataset
no_features_dataset = dataset.remove_columns(pl_model.feature_names + ["sources", ]) # also drop hypothesis later (but we keep it now for testing)

In [173]:
df = no_features_dataset.to_pandas()

In [174]:
# Keep track of the index
df['ref_id'] = RangeIndex(start=0, stop=df.index.stop, step=1)


In [175]:
# Repeat the count
df = df.reindex(df.index.repeat(df["count"]))

In [176]:
df = df.reset_index(level=0, inplace=False, drop=True)

In [177]:
df.head()

Unnamed: 0,hypothesis,utilities,count,ref_id
0,I didn't hear you come in.,"{1.4914770126342773: 72, 0.7293738126754761: 1...",9,0
1,I didn't hear you come in.,"{1.4914770126342773: 72, 0.7293738126754761: 1...",9,0
2,I didn't hear you come in.,"{1.4914770126342773: 72, 0.7293738126754761: 1...",9,0
3,I didn't hear you come in.,"{1.4914770126342773: 72, 0.7293738126754761: 1...",9,0
4,I didn't hear you come in.,"{1.4914770126342773: 72, 0.7293738126754761: 1...",9,0


In [178]:
## Lastly, create a collate_fn that looks up each row's precomputed features in ref_table by its ref_id
from torch.utils.data import DataLoader
import torch


In [179]:
# Hugging Face Dataset over the count-expanded frame; every row carries its ref_id.
new_dataset = Dataset.from_pandas(df=df)

In [180]:
import numpy as np
def collate_fn(batch, table=None, feature_names=None):
    """Collate a batch by looking up precomputed features instead of recomputing them.

    Every sample carries a ``ref_id``: its row position in the unexpanded
    reference table (``ref_table``, read from test.parquet). The rows for the
    whole batch are fetched with a single ``take``, and each feature column is
    stacked into one float32 tensor.

    Args:
        batch: list of dict-like samples (rows of ``new_dataset``); only the
            ``"ref_id"`` key is read.
        table: object supporting ``take(ids)`` whose result can be indexed by
            column name, each column exposing ``to_numpy()`` (e.g. a
            ``pyarrow.Table``). Defaults to the notebook-level ``ref_table``.
        feature_names: feature columns to return. Defaults to
            ``pl_model.feature_names``.

    Returns:
        dict mapping each feature name to a ``torch.float32`` tensor with one
        row per sample, in batch order. Assumes every row of a feature column
        holds a list of the same length (required by ``np.stack``).
    """
    # Resolve the notebook globals lazily so DataLoader can keep calling
    # collate_fn(batch) with a single argument.
    if table is None:
        table = ref_table
    if feature_names is None:
        feature_names = pl_model.feature_names

    ids = [sample["ref_id"] for sample in batch]
    rows = table.take(ids)

    # For list-typed columns, to_numpy() presumably yields one array per row;
    # stack them into a (batch, dim) matrix before converting to a tensor.
    return {
        name: torch.as_tensor(np.stack(rows[name].to_numpy()), dtype=torch.float32)
        for name in feature_names
    }



In [181]:
 train_dataloader = DataLoader(new_dataset, collate_fn=collate_fn, batch_size=256, shuffle=True, )

In [183]:
# Time one pass over the lookup-based loader. The bar below reports ~14 s for
# 811 batches on this develop subset; the earlier note of "approximately
# 2 minutes per epoch" presumably refers to the full dataset (TODO confirm).
from tqdm import tqdm
for x in tqdm(train_dataloader):
    pass
    

100%|████████████████████████████████████████████████████████████████████████████████| 811/811 [00:14<00:00, 56.85it/s]


In [None]:

dataset = BayesRiskDataset.load_dataset(base_dir, develop=False)

In [None]:
df = dataset.to_pandas()

df = df.reindex(df.index.repeat(df["count"]))

In [None]:
dataset = Dataset.from_pandas(df)

In [None]:
# Now without precomputed features: compute them on the fly in collate_fn for every batch.

def collate_fn(batch):
    """Build model features for a batch on the fly with ``pl_model.preprocess_function``.

    The ``"sources"`` and ``"hypothesis"`` strings of every sample are gathered
    into column lists (the batched-dict layout the function also receives from
    ``dataset.map(..., batched=True)``), and its result is returned unchanged.
    """
    columns = ("sources", "hypothesis")
    batched = {column: [sample[column] for sample in batch] for column in columns}
    return pl_model.preprocess_function(batched)

# Shuffled loader whose collate_fn runs the feature preprocessing per batch.
train_dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=32,
    shuffle=True,
)

In [None]:
# Reported ~4 minutes per pass, i.e. about 2x slower than looking features up.
# NOTE(review): not like-for-like. This loader uses batch_size=32 on the full
# dataset, while the lookup-based loader used batch_size=256 on the develop
# subset; match both settings before drawing conclusions.
from tqdm import tqdm
for x in tqdm(train_dataloader):
    pass

In [None]:
# Baseline: a no-op collate_fn, to time the DataLoader without any feature work.

def collate_fn(batch):
    """Baseline collate function that does no work.

    The batch is ignored and the constant ``0`` is returned, so a DataLoader
    using it measures only its own sampling and row-fetching cost.
    """
    del batch  # deliberately unused
    return 0

# Shuffled loader with the no-op collate_fn (timing baseline).
train_dataloader = DataLoader(
    dataset,
    collate_fn=collate_fn,
    batch_size=256,
    shuffle=True,
)

In [None]:
# Time one pass over the loader built above, whose collate_fn does no work;
# compare against the two feature-producing loaders.
from tqdm import tqdm
for x in tqdm(train_dataloader):
    pass