# reddit_tifu/short

Dataset: `reddit_tifu/short` from TensorFlow Datasets (TFDS).

In [None]:
class DC:
    """Data settings: which TFDS dataset to load, how to slice it, batch size."""

    # TFDS name in "<dataset>/<config>" form, passed straight to tfds.load.
    dataset: str = 'reddit_tifu/short'
    # Train / validation / test slices carved out of the 'train' split
    # (first 50%, next 20%, last 30%).
    split: list = ['train[:50%]', 'train[50%:70%]', 'train[70%:]']
    # Number of examples grouped into each batch by the input pipeline.
    batch_size: int = 64


class MC:
    """Model settings for the text vectorizer and the embedding layer."""

    # Vectorizer vocabulary cap; also used as the embedding table's input size.
    vocab_size: int = 10000
    # Width of each embedding vector.
    embedding_features: int = 8
    # Length of the token-id sequence emitted by the vectorizer
    # (its output_sequence_length).
    sequence_length: int = 100


class Config:
    """Top-level configuration: groups the data and model settings."""

    data = DC
    model = MC

## Setup

In [None]:
import io, os, re, shutil, string
from datetime import datetime
import numpy as np
import requests
import tensorflow as tf

In [None]:
def extract_sample_fn(s):
    """Turn one record mapping into an ``(inputs, target)`` tuple.

    Args:
        s: A mapping with at least the keys 'documents' and 'tldr'. In
            ``prepare`` this is applied after ``.batch``, so the values are
            whole batches.

    Returns:
        ``(s['documents'], s['tldr'])``.

    Raises:
        KeyError: if either key is missing.
    """
    # NOTE(review): the target here is the 'tldr' text, while score_model
    # ends in a single Dense(1) output named 'score_reg' -- confirm which
    # label training is meant to use.
    inputs = s['documents']
    target = s['tldr']
    return inputs, target

def standardize_fn(x):
    """Lower-case a string tensor and delete every ``string.punctuation`` char.

    Used as the custom ``standardize`` callable of the TextVectorization layer.
    """
    # Regex character class built from string.punctuation, escaped so that
    # characters like ']' and '\\' stay literal inside the brackets.
    punct_class = '[%s]' % re.escape(string.punctuation)
    lowered = tf.strings.lower(x)
    return tf.strings.regex_replace(lowered, punct_class, '')

def prepare(ds):
    """Build the input pipeline for one split.

    Stages: filter records on 'score', batch by ``Config.data.batch_size``,
    map each batch to ``(inputs, target)`` via ``extract_sample_fn``, cache,
    then prefetch.

    Args:
        ds: A ``tf.data.Dataset`` of record dicts.

    Returns:
        The prepared ``tf.data.Dataset``.
    """
    # NOTE(review): this compares the 'score' feature against an empty
    # string. If 'score' is numeric in this TFDS version, the comparison may
    # not drop what was intended -- confirm the feature dtype.
    kept = ds.filter(lambda record: record['score'] != '')
    batched = kept.batch(Config.data.batch_size)
    pairs = batched.map(extract_sample_fn, num_parallel_calls=tf.data.AUTOTUNE)
    # Cache the extracted batches so later passes skip the filter/batch/map
    # work, and overlap producing the next batch with consuming this one.
    return pairs.cache().prefetch(tf.data.AUTOTUNE)

## Dataset

In [None]:
import tensorflow_datasets as tfds

class Data:
    """Prepared train / validation / test pipelines plus dataset metadata.

    Loading happens when this class body executes. Each split is then passed
    through ``prepare``, so all three yield batched ``(inputs, target)`` pairs.
    """

    # Raw splits, in the order given by Config.data.split, plus the metadata
    # object requested via with_info=True.
    (train, val, test), info = tfds.load(
        Config.data.dataset,
        split=Config.data.split,
        with_info=True,
        shuffle_files=True,
    )

    # Swap each raw split for its prepared pipeline (same order as before).
    train = prepare(train)
    val = prepare(val)
    test = prepare(test)

In [None]:
# Show the metadata object returned by tfds.load(..., with_info=True) as this
# cell's output.
Data.info

## Defining the Model

In [None]:
try:
    # Stable location of the layer; same import style as the Embedding
    # import in the model cell.
    from tensorflow.keras.layers import TextVectorization
except ImportError:
    # Older TensorFlow only ships it under the (since removed) experimental
    # namespace.
    from tensorflow.keras.layers.experimental.preprocessing import TextVectorization

# Maps raw strings to fixed-length sequences of integer token ids, using the
# custom standardize_fn and the vocabulary/length limits from Config.model.
vectorize_layer = TextVectorization(
    standardize=standardize_fn,
    max_tokens=Config.model.vocab_size,
    output_mode='int',
    output_sequence_length=Config.model.sequence_length)

# Build the vocabulary from the training inputs only; the target is dropped.
text_ds = Data.train.map(lambda x, y: x)
vectorize_layer.adapt(text_ds)

In [None]:
# Sequential, GlobalAveragePooling1D and Dense were used below without being
# imported, which raised NameError when this cell ran.
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D

# Token-id -> vector lookup. Its input size matches the vectorizer's
# max_tokens (both Config.model.vocab_size), so every emitted id has a row.
em = Embedding(
    Config.model.vocab_size,
    Config.model.embedding_features,
    name='em')

# Text -> single-number model: vectorize, embed, average over the token axis,
# then a small ReLU layer and one output unit with no activation.
# NOTE(review): prepare()/extract_sample_fn currently yield the 'tldr' text as
# the target, while this head emits one number -- confirm the training label.
score_model = Sequential([
    vectorize_layer,
    em,
    GlobalAveragePooling1D(name='avg_pool'),
    Dense(16, activation='relu', name='fc1'),
    Dense(1, name='predictions')],
    name='score_reg')