Thanks to the authors of the following public notebooks:  
https://www.kaggle.com/code/hoyso48/1st-place-solution-training  
https://www.kaggle.com/code/irohith/aslfr-ctc-based-on-prev-comp-1st-place  
https://www.kaggle.com/code/markwijkhuizen/aslfr-transformer-training-inference  
This notebook contains the training code for the 3rd-place solution; the solution write-up is here:  
https://www.kaggle.com/competitions/asl-fingerspelling/discussion/434393  

# Install libs

In [None]:
# Install the helper packages offline (from attached Kaggle datasets) when they
# are not already importable. The `!pip` lines are IPython shell magics, so
# this cell only runs inside a notebook.
try:
  from icecream import ic  # `ic` is used throughout for debug printing
  import pymp
except Exception:
  !pip install -q icecream --no-index --find-links=file:///kaggle/input/icecream
  !pip install -q pymp-pypi --no-index --find-links=file:///kaggle/input/pymp-pypi/pymp-pypi-0.4.5/dist

# Import libs

In [None]:
import sys, os
import numpy as np
import pandas as pd
import json
import re
import six
import glob
import traceback
import inspect
from typing import Union
from collections import Counter, OrderedDict, defaultdict
from collections.abc import Iterable
from multiprocessing import cpu_count
from tqdm.notebook import tqdm
from icecream import ic
import pymp
import transformers
import tensorflow as tf
import torch
from torch import nn, einsum
import torch.nn.functional as F
# Record the TensorFlow / PyTorch versions in the notebook output.
ic(tf.__version__, torch.__version__)

# Flags

In [None]:
class FLAGS(object):
  """Global configuration, used as a plain namespace of class attributes.

  The first group configures the TFRecord input pipeline (`gen_inputs` /
  `TfrecordsDataset`); the rest configures data preparation and training.
  """
  # for tfrecords args, you could ignore
  seed = 1024
  batch_parse = False
  sparse_to_dense = True
  eval_keys = []
  incl_keys = []
  excl_keys = []
  recount_tfrecords = False
  batch_sizes = []
  buffer_size = 1024
  buckets = None
  drop_remainder = None
  shard_by_files = True
  shuffle_batch = None
  shuffle_files = None
  # Defaults used by TfrecordsDataset.make_batch when `shuffle` is not passed:
  # shuffle training data, keep validation data in order.
  shuffle = True
  shuffle_valid = False
  num_dataset_threads = 0
  num_prefetch_batches = 1024
  repeat_then_shuffle = False
  length_index = 1
  length_key = None
  dynamic_pad = True
  cache = False
  cache_after_map = False
  fixed_random = False
  parallel_read_files = True
  padding_idx = 0
  dataset_keys = []
  dataset_excl_keys = []
  exclude_varlen_keys = False
  prefetch = None
  dataset_ordered = False
  # Read by TfrecordsDataset.decode for "info" datasets (is_info=True).
  predict_on_batch = False
  # Set to True by TfrecordsDataset.decode when an eval key had to be dropped.
  use_info_dataset = False

  torch = True
  keras = False

  # online==False means using n-fold split and train on folds 1..folds-1 while validating on fold 0
  # online==True means using all train data but still validating on fold 0
  online = False
  folds = 4
  fold = 0
  fold_seed = 1229
  root = '../input/asl-fingerspelling'
  working = '/kaggle/working'
  use_z = True  # use x,y,z if True
  norm_frames = True  # norm frames using (x - mean) / std
  concat_frames = True  # concat original and normalized frames
  add_pos = True  # add abs frame pos, like 1/1000., 2/1000.
  sup_weight = 0.1  # weight assigned to the supplemental dataset

  train_files = []
  valid_files = []

  mix_sup = True  # train & sup dataset
  vie = 5  # valid interval epochs
  lr = 2e-3
  epochs = 400
  batch_size = 128
  eval_batch_size = 256
  awp = True
  adv_start_epoch = None
  adv_lr = 0.2
  adv_eps = 0
  fp16 = False  # notice fp16 could not be set True if using awp here, otherwise nan
  optimizer = 'Adam'
  opt_eps = 1e-6
  scheduler = 'linear'
  # for model related configs
  encoder_layers = 17
  encoder_units = 200
  n_frames = 320
  distributed = False
    
def load_json(filename, encoding='utf-8'):
  """Load and return the JSON document stored in `filename`.

  Args:
    filename: path to a JSON file.
    encoding: text encoding used to read the file. It defaults to UTF-8, as
      required by the JSON spec, instead of the platform-dependent locale
      encoding.
  """
  with open(filename, encoding=encoding) as fh:
    return json.load(fh)

# Common configs

In [None]:
# Landmark indices selected as model inputs. Each selected index contributes
# one x/y/z triple to the per-frame feature vector (see SEL_COLS below).
LPOSE = [13, 15, 17, 19, 21]
RPOSE = [14, 16, 18, 20, 22]
POSE = LPOSE + RPOSE

LIP = [
    61, 185, 40, 39, 37, 0, 267, 269, 270, 409,
    291, 146, 91, 181, 84, 17, 314, 405, 321, 375,
    78, 191, 80, 81, 82, 13, 312, 311, 310, 415,
    95, 88, 178, 87, 14, 317, 402, 318, 324, 308,
]
ic(len(LIP))
# Lip points split into a left side, a right side and the remaining (middle) points.
LLIP = [84, 181, 91, 146, 61, 185, 40, 39, 37, 87, 178, 88, 95, 78, 191, 80, 81, 82]
RLIP = [314, 405, 321, 375, 291, 409, 270, 269, 267, 317, 402, 318, 324, 308, 415, 310, 311, 312]
_SIDE_LIP = set(LLIP) | set(RLIP)
MID_LIP = [idx for idx in LIP if idx not in _SIDE_LIP]
ic(len(LLIP), len(RLIP), len(MID_LIP))

NOSE = [1, 2, 98, 327]
LNOSE = [98]
RNOSE = [327]
_SIDE_NOSE = set(LNOSE) | set(RNOSE)
MID_NOSE = [idx for idx in NOSE if idx not in _SIDE_NOSE]

LEYE = [
    263, 249, 390, 373, 374, 380, 381, 382, 362,
    466, 388, 387, 386, 385, 384, 398,
]
REYE = [
    33, 7, 163, 144, 145, 153, 154, 155, 133,
    246, 161, 160, 159, 158, 157, 173,
]

# Per-side point counts (each left group has the same size as its right group).
N_HAND_POINTS = 21
N_POSE_POINTS = len(LPOSE)
N_LIP_POINTS = len(LLIP)
N_EYE_POINTS = len(LEYE)
N_NOSE_POINTS = len(LNOSE)
N_MID_POINTS = len(MID_LIP + MID_NOSE)


def _xyz_cols(part, indices):
  """Return `x_<part>_<i>`, `y_<part>_<i>`, `z_<part>_<i>` for each index i, in order."""
  return [f'{axis}_{part}_{idx}' for idx in indices for axis in ('x', 'y', 'z')]


# The order of these groups defines the column layout of every frame.
SEL_COLS = (
    _xyz_cols('left_hand', range(N_HAND_POINTS))
    + _xyz_cols('right_hand', range(N_HAND_POINTS))
    + _xyz_cols('pose', POSE)
    + _xyz_cols('face', LLIP + RLIP)
    + _xyz_cols('face', LEYE + REYE)
    + _xyz_cols('face', LNOSE + RNOSE)
    + _xyz_cols('face', MID_LIP + MID_NOSE)
)

N_COLS = len(SEL_COLS)
ic(N_COLS)
    
# Character vocabulary. The competition's 0-based indices are shifted up by one
# so that id 0 is free for the PAD token.
_raw_char2idx = load_json(f'../input/asl-fingerspelling/character_to_prediction_index.json')
CHAR2IDX = {char: idx + 1 for char, idx in _raw_char2idx.items()}
N_CHARS = len(CHAR2IDX)
ic(N_CHARS)

# Special ids: PAD and SOS (start of sentence) share id 0; EOS (end of
# sentence) takes the id right after the last character.
PAD_IDX = 0
SOS_IDX = PAD_IDX
EOS_IDX = N_CHARS + 1
ic(PAD_IDX, SOS_IDX, EOS_IDX)

PAD_TOKEN = '<PAD>'
SOS_TOKEN = PAD_TOKEN
EOS_TOKEN = '<EOS>'

CHAR2IDX.update({PAD_TOKEN: PAD_IDX, EOS_TOKEN: EOS_IDX})

ADDRESS_TOKEN = '<ADDRESS>'
URL_TOKEN = '<URL>'
PHONE_TOKEN = '<PHONE>'
SUP_TOKEN = '<SUP>'

# Vocabulary size includes the PAD and EOS entries added above.
VOCAB_SIZE = len(CHAR2IDX)
IDX2CHAR = dict((idx, char) for char, idx in CHAR2IDX.items())
ic(VOCAB_SIZE)
ic(len(IDX2CHAR))

STATS = {}
# Phrase categories; each one's integer id is its position in CLASSES.
CLASSES = ['address', 'url', 'phone', 'sup']
PHRASE_TYPES = {cls_name: cls_id for cls_id, cls_name in enumerate(CLASSES)}
N_TYPES = len(CLASSES)
MAX_PHRASE_LEN = 32

def get_vocab_size():
  """Return the vocabulary size: all characters plus the PAD and EOS entries."""
  return VOCAB_SIZE

def get_n_cols(no_motion=False, use_z=None):
  """Return the width of one frame's feature vector as fed to the model.

  The width starts from the N_COLS raw coordinates. It is doubled when
  normalized frames are concatenated (FLAGS.concat_frames), reduced to 2 of
  every 3 coordinates when z is dropped, and grows by one column for the
  absolute frame position (FLAGS.add_pos).

  Args:
    no_motion: accepted for call compatibility; not used here.
    use_z: keep z coordinates; None means FLAGS.use_z.
  """
  use_z = FLAGS.use_z if use_z is None else use_z

  if FLAGS.concat_frames:
    # Concatenated frames are the normalized ones, so normalization must be on.
    assert FLAGS.norm_frames
    width = 2 * N_COLS
  else:
    width = N_COLS

  if not use_z:
    width = width // 3 * 2

  return width + 1 if FLAGS.add_pos else width

# Tfrecord dataset

In [None]:
def gen_inputs(files,
               decode_fn,
               batch_size=64,
               post_decode_fn=None,
               num_epochs=None,
               num_threads=None,
               buffer_size=15000,  # changed from 1000 to 15000
               dynamic_pad=True,
               shuffle=True,
               shuffle_batch=None,
               shuffle_files=None,
               ordered=None,
               min_after_dequeue=None,  # deprecated, unused
               seed=None,
               enqueue_many=False,  # deprecated, unused
               fixed_random=False,
               drop_remainder=False,
               num_prefetch_batches=None,
               bucket_boundaries=None,
               length_index=None,
               length_key=None,
               length_fn=None,
               bucket_batch_sizes=None,
               repeat=True,
               initializable=False,
               filter_fn=None,
               balance_pos_neg=False,
               pos_filter_fn=None,
               neg_filter_fn=None,
               count_fn=None,
               return_iterator=False,
               Dataset=None,
               batch_parse=False,  # by default will be line parse
               hvd_shard=True,
               shard_by_files=False,
               training=False,
               simple_parse=False,
               repeat_then_shuffle=False,
               cache=False,
               cache_file='',
               cache_after_map=False,
               device=None,
               world_size=1,
               rank=0,
               parallel_read_files=False,
               use_feed_dict=False,
               feed_name=None,
               padding_values=None,
               distribute_strategy=None,
               torch=False,
               keras=False,
               subset=None,
               return_numpy=False,
               name='input'):
  """Build a tf.data pipeline over record files.

  Stages, in order:
    1. Read the files, optionally sharded across workers.
    2. Repeat (only when repeat_then_shuffle).
    3. Cache the raw records (when `cache` and not `cache_after_map`).
    4. Decode each record (unless batch_parse), then cache (when `cache` and
       `cache_after_map`).
    5. Filter.
    6. Shuffle.
    7. Repeat (the default case).
    8. Batch or padded-batch.
    9. Decode each batch (batch_parse), then post_decode_fn.
    10. Prefetch.

  Args:
    files: list of file paths, or a glob pattern string.
    decode_fn: maps one serialized record (or a batch of them when
      `batch_parse`) to features. If it declares a `batch_size` argument,
      that argument is filled in.
    batch_size: examples per batch.
    post_decode_fn: optional map applied after batching.
    num_epochs: passed to `repeat`; falsy means repeat forever.
    shuffle: default for `shuffle_files` / `shuffle_batch` when those are None.
    ordered: when falsy, the pipeline is allowed to be non-deterministic.
    cache / cache_file / cache_after_map: cache either the raw records or the
      decoded examples, never both.
    world_size / rank / shard_by_files: sharding across workers.
    parallel_read_files: interleave reads across files.
    return_numpy: return `d.as_numpy_iterator()` instead of the dataset.

  Several arguments (min_after_dequeue, enqueue_many, initializable,
  bucket_boundaries, length_index, length_key, length_fn,
  bucket_batch_sizes, balance_pos_neg, pos/neg_filter_fn, count_fn,
  return_iterator, training, simple_parse, device, padding_values,
  distribute_strategy, torch, keras, subset, name) are accepted for API
  compatibility but are not used here.

  Returns:
    The tf.data.Dataset, or its numpy iterator when return_numpy is True.
  """
  import logging  # not imported at file level; used for the cache debug message

  Dataset = Dataset or tf.data.TFRecordDataset
  AUTO = tf.data.AUTOTUNE
  # Horovod support is kept for reference but disabled in this notebook.
  use_horovod = False

  def shard(d):
    return d.shard(hvd.size(), hvd.rank())

  if isinstance(files, str):
    # NOTE(review): the project helper gezi.list_files is not imported in this
    # notebook; plain glob expansion is used instead (pattern strings only).
    files = sorted(glob.glob(files))
  assert len(files) > 0

  if not num_threads:
    num_threads = 8

  # Let decode functions that take a batch_size argument receive it.
  if 'batch_size' in inspect.getfullargspec(decode_fn).args:
    decode_fn_ = decode_fn

    def decode_function(example):
      return decode_fn_(example, batch_size)

    decode_fn = decode_function

  # 0 and None both mean "repeat forever".
  if not num_epochs:
    num_epochs = None

  # Unless given explicitly, file and batch shuffling follow `shuffle`.
  if shuffle_files is None:
    shuffle_files = bool(shuffle)
  if shuffle_batch is None:
    shuffle_batch = bool(shuffle)
  if not shuffle and not shuffle_files:
    # TODO: does parallel (interleaved) reading reorder records? Disable it to
    # keep the file order when nothing should be shuffled.
    parallel_read_files = False

  if fixed_random and seed is None:
    seed = 1024

  num_files = len(files)
  if use_feed_dict and feed_name:
    # NOTE(review): `gezi` is not imported in this notebook; this TF1-only
    # path fails unless it is provided elsewhere.
    files = tf.compat.v1.placeholder(tf.string, [None], feed_name)
    gezi.set_global(feed_name, files)

  # Derive whichever of num_prefetch_batches / buffer_size is missing.
  if not num_prefetch_batches and buffer_size:
    num_prefetch_batches = int(buffer_size / batch_size)
  if not buffer_size and num_prefetch_batches:
    buffer_size = num_prefetch_batches * batch_size

  options = tf.data.Options()
  try:
    options.threading.private_threadpool_size = num_threads
    options.threading.max_intra_op_parallelism = 1
  except Exception:
    # Older TF versions expose these under experimental_threading.
    options.experimental_threading.private_threadpool_size = num_threads
    options.experimental_threading.max_intra_op_parallelism = 1

  options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
  # Deterministic unless we shuffle without a fixed seed, or order is not required.
  options.experimental_deterministic = True
  if shuffle and not fixed_random:
    options.experimental_deterministic = False
  if not ordered:
    options.experimental_deterministic = False

  if not parallel_read_files or num_files == 1:
    d = Dataset(files)
    d = d.with_options(options)
    if use_horovod and hvd_shard:
      d = shard(d)
    if not use_horovod and world_size > 1:
      d = d.shard(world_size, rank)
  else:
    try:
      if shuffle_files and (use_horovod or world_size > 1):
        # Every worker must shuffle the file list identically before sharding.
        assert seed
      d = tf.data.Dataset.list_files(files, shuffle=shuffle_files, seed=seed)
      d = d.with_options(options)
    except Exception as e:
      logging.warning('list_files failed (%s); reading files in the given order', e)
      d = tf.data.Dataset.from_tensor_slices(files)
      d = d.with_options(options)
    # Sharding by files does not work well, especially for text line datasets with horovod.
    if use_horovod and shard_by_files:
      d = shard(d)
    elif world_size > 1 and shard_by_files:
      d = d.shard(world_size, rank)

    d = d.interleave(Dataset,
                     block_length=1,
                     num_parallel_calls=AUTO)

    if world_size > 1 and not shard_by_files:
      d = d.shard(world_size, rank)

  if repeat and repeat_then_shuffle:
    d = d.repeat(num_epochs)

  # Cache the raw records, or (below) the decoded examples, but never both:
  # two caches sharing one cache_file would conflict.
  if cache and not cache_after_map:
    d = d.cache(cache_file)

  # With batch_parse we must batch first and then map (e.g. for py_func
  # decoders); otherwise parse per example.
  if not batch_parse:
    d = d.map(decode_fn, num_parallel_calls=AUTO)
    if cache and cache_after_map:
      d = d.cache(cache_file)

  if filter_fn is not None and not batch_parse:
    d = d.filter(filter_fn)

  if shuffle_batch:
    d = d.shuffle(buffer_size=buffer_size, seed=seed, reshuffle_each_iteration=True)

  # shuffle then repeat
  if repeat and not repeat_then_shuffle:
    d = d.repeat(num_epochs)

  if dynamic_pad and not batch_parse:
    d = d.padded_batch(batch_size, drop_remainder=drop_remainder)
  else:
    d = d.batch(batch_size, drop_remainder=drop_remainder)

  if batch_parse:
    d = d.map(decode_fn, num_parallel_calls=AUTO)
    if filter_fn is not None:
      try:
        d = d.unbatch()
        d = d.filter(filter_fn)
      except Exception:
        d = d.unbatch()

      d = d.batch(batch_size, drop_remainder=drop_remainder)

  if post_decode_fn is not None:
    d = d.map(post_decode_fn, num_parallel_calls=AUTO)

  if cache and cache_after_map and batch_parse:
    # In batch_parse mode decoding only happens after batching, so cache here.
    logging.debug('Cache datase after map')
    d = d.cache(cache_file)

  d = d.prefetch(FLAGS.prefetch or AUTO)

  if not return_numpy:
    return d
  else:
    return d.as_numpy_iterator()


In [None]:
def decode_example(x):
  """Decode one serialized tf.train.Example into a dict of numpy arrays.

  bytes_list features become object arrays, float_list features become
  float32 arrays and int64_list features become int64 arrays. Features whose
  value list is empty are left out.

  Args:
    x: serialized Example; in eager mode a tensor whose `.numpy()` gives the bytes.
  """
  if tf.executing_eagerly():
    x = x.numpy()
  feature_map = tf.train.Example.FromString(x).features.feature
  kinds = (('bytes_list', object), ('float_list', np.float32), ('int64_list', np.int64))
  decoded = {}
  for feat_name in feature_map.keys():
    for kind, np_dtype in kinds:
      values = getattr(feature_map[feat_name], kind).value
      if values:
        decoded[feat_name] = np.array(values, dtype=np_dtype)
  return decoded

def first_example(record_file):
  """Return the first record of a TFRecord file, decoded by decode_example.

  A list/tuple argument is reduced to its first element. Returns None when
  the file contains no records.
  """
  if isinstance(record_file, (list, tuple)):
    record_file = record_file[0]
  if tf.executing_eagerly():
    records = tf.data.TFRecordDataset(record_file)
  else:
    records = tf.compat.v1.python_io.tf_record_iterator(record_file)
  for record in records:
    return decode_example(record)
  return None

def npdtype2tfdtype(dtype, large=False):
  """Map a numpy dtype to the TF dtype used when declaring tf.io features.

  float32 and float64 map to tf.float32 (float64 is narrowed). int64 maps to
  tf.int64. int32 maps to tf.int32, or to tf.int64 when `large` is set.
  Every other dtype (e.g. object/bytes) maps to tf.string.
  """
  if dtype == np.float32 or dtype == np.float64:
    return tf.float32
  if dtype == np.int32:
    return tf.int64 if large else tf.int32
  if dtype == np.int64:
    return tf.int64
  return tf.string

def sparse_tensor_to_dense(input_tensor, default_value=0):
  """Densify a SparseTensor, filling missing entries with `default_value` (indices are not validated)."""
  dense = tf.sparse.to_dense(input_tensor,
                             default_value=default_value,
                             validate_indices=False)
  return dense

def sparse2dense(features, key=None, default_value=0):
  """Convert SparseTensor entries of a feature dict to dense tensors, in place.

  Args:
    features: dict of feature name -> tensor; modified in place.
    key: if given, only this entry is converted (it must be a SparseTensor).
    default_value: fill value for numeric features. String features pass
      None instead, so TF picks the dtype's default fill.

  Returns:
    True if at least one entry was converted.
  """

  def _densify(name):
    # Convert one sparse entry, choosing the fill value by its dtype.
    val = features[name]
    fill = None if val.values.dtype == tf.string else default_value
    features[name] = sparse_tensor_to_dense(val, fill)

  if key:
    _densify(key)
    return True

  # Collect the names first so the dict is not rewritten while being iterated.
  sparse_keys = [name for name, val in features.items()
                 if isinstance(val, tf.sparse.SparseTensor)]
  for name in sparse_keys:
    _densify(name)
  return bool(sparse_keys)

def get_num_records_single(tf_record_file, recount=False):
  """Return how many records one TFRecord file holds.

  Unless `recount` is set, the count is read from the file name: '-' and '_'
  are treated as '.', and the last purely numeric '.'-separated token wins
  (e.g. 'train-3.512.tfrec' -> 512). If no such token exists, or `recount`
  is True, the records are counted by iterating over the file.
  """
  if not recount:
    base = os.path.basename(tf_record_file)
    tokens = base.replace('-', '.').replace('_', '.').split('.')
    numeric_tokens = [tok for tok in tokens if tok.isdigit()]
    if numeric_tokens:
      return int(numeric_tokens[-1])

  return sum(1 for _ in tf.compat.v1.python_io.tf_record_iterator(tf_record_file))


def get_num_records(files, recount=False):
  """Return the total number of records across TFRecord files.

  Args:
    files: list of file paths, or a glob pattern string.
    recount: count by iterating over the files instead of parsing counts from file names.
  """
  if isinstance(files, str):
    # NOTE(review): the project helper gezi.list_files is not imported in this
    # notebook; plain glob expansion is used instead (pattern strings only).
    files = sorted(glob.glob(files))
  return sum(
      get_num_records_single(file, recount=recount)
      for file in tqdm(files, ascii=False, desc='get_num_records', leave=False))


In [None]:
# Base class wrapping TFRecord-backed datasets.
class TfrecordsDataset(object):
  """Build tf.data pipelines over TFRecord files.

  Feature specs are inferred from the first example of the files
  (`gen_example`, then `auto_parse` / `adds_varlens`). `make_batch` feeds
  `decode` into `gen_inputs`. Subclasses may override `parse` and `adjust`,
  and may define `post_decode`.
  """

  def __init__(self,
               subset='valid',
               batch_size=None,
               Type=None,
               files=None,
               num_instances=None,
               batch_parse=None,
               sparse_to_dense=None,
               hvd_shard=True,
               use_int32=True,
               is_info=False,
               eval_keys=None,
               incl_keys=None,
               excl_keys=None,
               str_keys=None,
               varlen_keys=None,
               use_tpu=False,
               recount=None):
    """
    Args:
      subset: one of 'train', 'valid', 'test'.
      batch_size: default batch size; falls back to FLAGS.batch_size.
      Type: record reader class passed to gen_inputs as `Dataset`.
      files: explicit file list; otherwise FLAGS.train_files / FLAGS.valid_files.
      num_instances: known record count; otherwise counted lazily from files.
      batch_parse, sparse_to_dense: override the FLAGS defaults of the same name.
      hvd_shard: forwarded to gen_inputs.
      use_int32: cast int64 features to int32 after decoding.
      is_info: build an "info" dataset that keeps only eval/show keys.
      eval_keys, incl_keys, excl_keys: key selections (fall back to FLAGS).
      str_keys: stored on the instance; not used by this class.
      varlen_keys: keys parsed as variable-length features.
      use_tpu: force TPU-compatible decoding (also enabled when a TPU is detected).
      recount: count records by iterating over files instead of parsing file names.

    Raises:
      AssertionError: if subset is not 'train', 'valid' or 'test'.
    """
    self.subset = subset
    self.filter_fn = None
    self.pos_filter_fn = None
    self.neg_filter_fn = None
    self.count_fn = None
    self.Type = Type
    self.batch_parse = batch_parse if batch_parse is not None else FLAGS.batch_parse
    self.sparse_to_dense = sparse_to_dense if sparse_to_dense is not None else FLAGS.sparse_to_dense
    self.use_post_decode = None
    self.batch_size = batch_size or FLAGS.batch_size
    self.hvd_shard = hvd_shard
    # How many times make_batch was called per subset (used to name feeds).
    self.indexes = {'train': -1, 'valid': -1, 'test': -1}
    self.is_info = is_info
    self.eval_keys = eval_keys or FLAGS.eval_keys
    if subset == 'test':
      # NOTE(review): originally gezi.get('test_keys'), but gezi is not imported
      # in this notebook; an optional FLAGS.test_keys is used instead.
      self.eval_keys = getattr(FLAGS, 'test_keys', None) or self.eval_keys
    # If the user does not specify eval_keys, self.show_keys collects every
    # non-varlen key of length 0 or 1. This only works for features registered
    # through .adds, not ones defined externally.
    self.show_keys = set()
    self.excl_keys = excl_keys or FLAGS.excl_keys
    self.incl_keys = incl_keys or FLAGS.incl_keys
    self.str_keys = [] if str_keys is None else str_keys
    self.varlen_keys = [] if varlen_keys is None else varlen_keys

    self.parse_fn = tf.io.parse_single_example if not self.batch_parse else tf.io.parse_example

    self.features_dict = {}
    self.has_varlen_feats = False
    self.use_tpu = use_tpu
    try:
      # TPU detection. No parameters are needed when the TPU_NAME environment
      # variable is set, which is always the case on Kaggle.
      tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    except ValueError:
      tpu = None
    if tpu is not None:
      self.use_tpu = True
    self.use_int32 = use_int32
    if self.use_tpu:
      self.use_int32 = True

    self.num_instances_ = num_instances
    self.files_ = files

    self.recount = recount or FLAGS.recount_tfrecords

    assert self.subset in ['train', 'valid', 'test'], \
          'subset is {} but should in [train, valid, test]'.format(self.subset)

  @staticmethod
  def get_filenames_(subset=None, shuffle=False):
    """Return the configured files for 'train' / 'valid'; None for any other subset."""
    if subset == 'train':
      return FLAGS.train_files
    if subset == 'valid':
      return FLAGS.valid_files
    # 'test' has no configured files; unknown subsets also map to None.
    return None

  def get_filenames(self, subset=None, shuffle=False):
    """Return the files for `subset` (default: this dataset's subset)."""
    subset = subset or self.subset
    return TfrecordsDataset.get_filenames_(subset, shuffle=False)

  def basic_parse(self, example):
    """Register feature specs from the first example, then parse `example` with them."""
    self.auto_parse(keys=self.incl_keys, exclude_keys=self.excl_keys)
    if self.varlen_keys:
      # Re-register these keys as variable-length, overriding auto_parse's specs.
      self.adds_varlens(self.varlen_keys)
    return self.parse_(serialized=example)

  # override this
  def parse(self, example):
    """Parse one serialized example (or a batch); subclasses may override."""
    return self.basic_parse(example)

  def decode(self, example):
    """Parse `example`, then apply dtype casts and key filtering.

    `parse` may return a features dict, or a list/tuple whose first element is
    the features dict. The same structure is returned, with that dict
    modified in place.
    """
    parsed = self.parse(example)
    features = parsed[0] if isinstance(parsed, (list, tuple)) else parsed

    if not isinstance(features, dict):
      return parsed

    if self.use_tpu:
      # TPUs need int32: cast every integer feature.
      for key in features.keys():
        if features[key].dtype in [tf.int64, tf.uint8, tf.uint16, tf.uint32]:
          features[key] = tf.cast(features[key], tf.int32)

      if not self.is_info:
        # String features are dropped from the training dataset.
        for key in list(features.keys()):
          if features[key].dtype == tf.string:
            del features[key]
            if key in self.eval_keys:
              # The training dataset no longer carries this eval field, so a
              # separate pass over the info dataset is needed.
              FLAGS.use_info_dataset = True
    else:
      def _cast_dict(feats):
        # Recursively cast int64 leaves to int32 when use_int32 is set.
        for key in feats:
          if isinstance(feats[key], dict):
            _cast_dict(feats[key])
          elif feats[key].dtype == tf.int64 and self.use_int32:
            feats[key] = tf.cast(feats[key], tf.int32)
      _cast_dict(features)

    # is_info only takes effect in TF2 Keras mode. The info dataset is always
    # created, but it is only used when FLAGS.use_info_dataset is True.
    if self.is_info:
      if not getattr(FLAGS, 'predict_on_batch', False):
        if not self.eval_keys:
          dim = 1 if self.batch_parse else 0
          for key in list(features.keys()):
            # Keep only scalar / length-1 features and remember them for display.
            if not (len(features[key].shape) == dim or features[key].shape[dim] == 1):
              del features[key]
            else:
              self.show_keys.add(key)
        else:
          for key in list(features.keys()):
            if key not in self.eval_keys:
              del features[key]
    else:
      for key in list(features.keys()):
        if key in self.excl_keys:
          del features[key]

    return parsed

  def adjust(self, result):
    """Hook to post-process the dataset built by make_batch; identity by default."""
    return result

  def parse_(self, serialized, features=None):
    """Parse serialized example(s) with the given (or registered) feature specs.

    Sparse (variable-length) results are dropped when FLAGS.exclude_varlen_keys
    is set. Otherwise they are densified (padded with FLAGS.padding_idx) when
    sparse_to_dense is on.
    """
    features = features or self.features_dict
    features = self.parse_fn(serialized=serialized, features=features)
    if FLAGS.exclude_varlen_keys:
      sparse_keys = [key for key, val in features.items()
                     if isinstance(val, tf.sparse.SparseTensor)]
      for key in sparse_keys:
        del features[key]
    elif self.sparse_to_dense:
      self.has_varlen_feats = sparse2dense(features, default_value=FLAGS.padding_idx)
    self.features = features
    return features

  def gen_example(self, files=None):
    """Decode the first readable example from `files` and store it as self.example.

    Files are tried in order; a file that fails to read is logged and
    skipped. Returns {} when no example could be read.
    """
    if not files:
      files = self.get_filenames()
    if not isinstance(files, (list, tuple)):
      files = [files]
    example = {}
    for file in files or []:
      try:
        example = first_example(file)
      except Exception:
        ic(traceback.format_exc())
        ic('bad tfrecord:', file)
      if example:
        break
    self.example = example
    return example

  def gen_input(self, files=None):
    """Return the first example of `files` with each feature wrapped as a batch of one."""
    example = self.gen_example(files).copy()
    return {key: np.asarray([val]) for key, val in example.items()}

  def first_input(self, files=None):
    """Alias of gen_input."""
    return self.gen_input(files)

  def add(self, key, dtype=None, length=None, features_dict=None):
    """Register a spec for a single key (see adds)."""
    self.adds([key], dtype, length, features_dict)

  def adds(self, keys, dtype=None, length=None, features_dict=None):
    """Register parse specs for every key that is present in self.example.

    `length` None gives a VarLenFeature, a positive length gives a
    FixedLenFeature([length]), and 0 gives a scalar FixedLenFeature. `dtype`
    defaults to the example's dtype.
    """
    features_dict = features_dict or self.features_dict
    for key in keys:
      if key not in self.example:
        continue
      dtype_ = dtype or self.example[key].dtype
      if length is None:
        features_dict[key] = tf.io.VarLenFeature(dtype_)
      elif length > 0:
        features_dict[key] = tf.io.FixedLenFeature([length], dtype_)
      else:
        features_dict[key] = tf.io.FixedLenFeature([], dtype_)

  def auto_parse(self, keys=None, exclude_keys=None, features_dict=None):
    """Register fixed-length specs for keys, using the shapes/dtypes of self.example."""
    keys = keys or FLAGS.dataset_keys or self.example.keys()
    exclude_keys = exclude_keys or FLAGS.dataset_excl_keys

    for key in keys:
      if key in exclude_keys or key not in self.example:
        continue
      value = self.example[key]
      length = value.shape[0]
      if length == 1:
        # Parse as a scalar, giving shape (bs,). Keras expands it to (bs, 1)
        # itself, and 0 also works for strings.
        length = 0
      self.adds([key], npdtype2tfdtype(value.dtype), length, features_dict)

  def adds_varlens(self, keys=None, exclude_keys=None, features_dict=None):
    """Register variable-length specs for keys (string keys become length-1 fixed)."""
    keys = keys or self.example.keys()
    exclude_keys = exclude_keys or []

    for key in keys:
      if key in exclude_keys or key not in self.example:
        continue
      dtype = npdtype2tfdtype(self.example[key].dtype)
      length = 1 if dtype == tf.string else None
      self.adds([key], dtype, length, features_dict)

  def make_batch(self,
                 batch_size=None,
                 filenames=None,
                 subset=None,
                 initializable=False,
                 repeat=None,
                 shuffle=None,
                 return_iterator=True,
                 hvd_shard=None,
                 simple_parse=False,
                 num_epochs=None,
                 cache=False,
                 cache_file='',
                 buffer_size=None,
                 batch_sizes=None,
                 buckets=None,
                 drop_remainder=None,
                 world_size=1,
                 rank=0,
                 shard_by_files=None,
                 distribute_strategy=None,
                 return_numpy=False):
    """Build the input pipeline for `subset` via gen_inputs.

    Defaults that depend on the subset:
      - repeat: on only for 'train'.
      - shuffle: FLAGS.shuffle for 'train', FLAGS.shuffle_valid otherwise.
      - drop_remainder: True only for 'train'.
    Every call consumes FLAGS.seed and then increments it.

    Returns:
      The dataset (or numpy iterator when return_numpy) after `adjust`.
    """
    subset = subset or self.subset
    hvd_shard = hvd_shard if hvd_shard is not None else self.hvd_shard
    is_test = batch_size is None
    batch_size = batch_size or self.batch_size
    self.batch_size = batch_size
    batch_sizes = batch_sizes if batch_sizes is not None else FLAGS.batch_sizes
    buffer_size = buffer_size if buffer_size is not None else FLAGS.buffer_size
    buckets = buckets if buckets is not None else FLAGS.buckets
    drop_remainder = drop_remainder if drop_remainder is not None else FLAGS.drop_remainder
    shard_by_files = shard_by_files if shard_by_files is not None else FLAGS.shard_by_files

    self.return_numpy = return_numpy

    filenames = filenames or self.files_ or self.get_filenames(subset)

    # Read one example so the feature specs can be inferred.
    self.gen_example(filenames)

    is_eager = tf.executing_eagerly()

    self.files_ = filenames

    self.indexes[self.subset] += 1

    if repeat is None:
      num_gpus = 1
      repeat = subset == 'train'
      if is_eager and num_gpus == 1 and tf.__version__ < '2':
        # Make TF1 eager behave like PyTorch (one pass per iteration).
        repeat = False

    if shuffle is None:
      if subset == 'train':
        shuffle = getattr(FLAGS, 'shuffle', True)
      else:
        shuffle = getattr(FLAGS, 'shuffle_valid', False)

    if drop_remainder is None:
      drop_remainder = subset == 'train'

    balance_pos_neg = False
    ic(self.subset, repeat, drop_remainder)

    seed = FLAGS.seed
    if seed is not None:
      # Give each pipeline built in this process a distinct seed.
      FLAGS.seed += 1

    # Build the input pipeline on the CPU.
    with tf.device('/cpu'):
      result = gen_inputs(
        filenames,
        decode_fn=self.decode,
        batch_size=batch_size,
        post_decode_fn=self.post_decode if hasattr(self, 'post_decode') and self.use_post_decode != False else None,
        shuffle=shuffle,
        shuffle_batch=FLAGS.shuffle_batch,
        shuffle_files=FLAGS.shuffle_files,
        ordered=FLAGS.dataset_ordered if subset == 'train' else True,
        num_threads=FLAGS.num_dataset_threads,
        buffer_size=buffer_size,
        num_prefetch_batches=FLAGS.num_prefetch_batches,
        initializable=initializable,
        repeat=repeat,
        repeat_then_shuffle=FLAGS.repeat_then_shuffle,
        drop_remainder=drop_remainder,
        bucket_boundaries=buckets,
        bucket_batch_sizes=batch_sizes,
        length_index=FLAGS.length_index,
        length_key=FLAGS.length_key,
        seed=seed,
        return_iterator=return_iterator,
        filter_fn=self.filter_fn,  # inside filter_fn judge subset train or valid or test
        balance_pos_neg=balance_pos_neg,
        pos_filter_fn=self.pos_filter_fn if subset == 'train' else None,
        neg_filter_fn=self.neg_filter_fn if subset == 'train' else None,
        count_fn=self.count_fn if subset == 'train' else None,
        name=subset,
        Dataset=self.Type,
        batch_parse=self.batch_parse,
        hvd_shard=hvd_shard,
        shard_by_files=shard_by_files,
        training=subset == 'train',
        simple_parse=simple_parse,
        num_epochs=num_epochs,
        # padded_batch is only needed when there are varlen features;
        # batch_parse mode does not need it since sparse2dense pads automatically.
        dynamic_pad=FLAGS.dynamic_pad,
        cache=cache,
        cache_file=cache_file,
        cache_after_map=FLAGS.cache_after_map,
        device='/gpu:0',
        world_size=world_size,
        rank=rank,
        fixed_random=FLAGS.fixed_random,
        parallel_read_files=FLAGS.parallel_read_files,
        feed_name=f'{self.subset}_{self.indexes[self.subset]}' if not is_test else None,
        padding_values=FLAGS.padding_idx,
        distribute_strategy=distribute_strategy,
        torch=FLAGS.torch,
        keras=FLAGS.keras,
        subset=self.subset,
        return_numpy=return_numpy,
        )

    result = self.adjust(result)
    return result

  @staticmethod
  def num_examples_per_epoch(subset, dir=None):
    """Count the records in the configured files of 'train' or 'valid'.

    Raises:
      ValueError: for any other subset.
    """
    if subset == 'train':
      num_examples = get_num_records(FLAGS.train_files)
    elif subset == 'valid':
      num_examples = get_num_records(FLAGS.valid_files)
    else:
      raise ValueError('Invalid data subset "%s"' % subset)

    assert num_examples
    return num_examples

  @staticmethod
  def num_examples(subset, dir=None):
    """Alias of num_examples_per_epoch."""
    return TfrecordsDataset.num_examples_per_epoch(subset, dir)

  @property
  def num_instances(self):
    """Record count of self.files_, computed once and then cached."""
    if self.num_instances_:
      return self.num_instances_
    assert self.files_
    self.num_instances_ = get_num_records(self.files_, recount=self.recount)
    return self.num_instances_

  @property
  def files(self):
    return self.files_

  @property
  def records(self):
    return self.files_

  def __len__(self):
    return self.num_instances or TfrecordsDataset.num_examples_per_epoch(self.subset)

  @property
  def num_steps(self):
    """Number of batches per pass, i.e. ceil(len(self) / batch_size)."""
    return -(-len(self) // self.batch_size)


# Generate means.npy and stds.npy

In [None]:
# Stream the training TFRecords once, with no shuffling, to compute per-column statistics.
records_pattern = f'../input/3rd-place-step1-gen-tfrecords-for-train/tfrecords/train/*.tfrec'
# Sort for a reproducible read order (glob order depends on the filesystem).
record_files = sorted(glob.glob(records_pattern))
if not record_files:
  raise FileNotFoundError(f'No TFRecord files match {records_pattern}')
ic(record_files[:2])
# The 'valid' subset does not repeat, so iterating `datas` makes exactly one pass.
dataset = TfrecordsDataset('valid',
                           files=record_files,
                           incl_keys=['frames', 'n_frames'],
                           varlen_keys=['frames'])
datas = dataset.make_batch(1024,
                           shuffle=False,
                           drop_remainder=False,
                           return_numpy=True)
ic(dataset.features_dict)
ic(dataset.num_instances)
num_steps = dataset.num_steps
ic(num_steps)

In [None]:
# https://stackoverflow.com/questions/5543651/computing-standard-deviation-in-a-stream
class OnlineVariance(object):
  """Streaming mean / variance using Welford's algorithm.

  Args:
    iterable: optional values to add right away.
    ddof: delta degrees of freedom for `variance` (1 gives the sample variance).
  """

  def __init__(self, iterable=None, ddof=1):
    self.ddof, self.n, self.mean_, self.M2 = ddof, 0, 0.0, 0.0
    if iterable is not None:
      for datum in iterable:
        self.add(datum)

  def add(self, datum):
    """Incorporate one observation."""
    self.n += 1
    delta = datum - self.mean_
    self.mean_ += delta / self.n
    # The second factor uses the *updated* mean, as Welford's update requires.
    self.M2 += delta * (datum - self.mean_)

  @property
  def variance(self):
    """Variance with the `ddof` correction; NaN until more than `ddof` values were added."""
    dof = self.n - self.ddof
    if dof <= 0:
      return float('nan')
    return self.M2 / dof

  @property
  def std(self):
    """Square root of `variance`."""
    return np.sqrt(self.variance)

  @property
  def mean(self):
    """Running mean of all added values (0.0 before any value is added)."""
    return self.mean_

In [None]:
ic(N_COLS)
# Exact per-column mean and sample std (ddof=1) over every valid (non-NaN)
# coordinate of every real (unpadded) frame. Batch moments are merged with
# Chan et al.'s parallel update in float64, so memory stays bounded by one batch.
# NOTE(review): the previous version fed one *batch mean* per column into
# Welford's algorithm, so `stds` measured the spread of batch means rather than
# of the coordinates. A batch whose column was all NaN also turned that
# column's mean into NaN. Stats produced here differ from older
# means.npy/stds.npy files; models trained on the old files need retraining.
counts = np.zeros([N_COLS], dtype=np.float64)
running_mean = np.zeros([N_COLS], dtype=np.float64)
running_m2 = np.zeros([N_COLS], dtype=np.float64)
for x in tqdm(datas, total=num_steps, desc='Loop-dataset'):
  batch_frames = x['frames'].reshape(x['frames'].shape[0], -1, N_COLS)
  # Keep only the first n_frames frames of each sequence (the rest is batch padding).
  frames = np.concatenate(
      [seq[:n] for seq, n in zip(batch_frames, x['n_frames'])]).astype(np.float64)
  valid = ~np.isnan(frames)

  b_count = valid.sum(axis=0).astype(np.float64)
  b_sum = np.where(valid, frames, 0.).sum(axis=0)
  b_mean = np.divide(b_sum, b_count, out=np.zeros_like(b_sum), where=b_count > 0)
  b_m2 = (np.where(valid, frames - b_mean, 0.) ** 2).sum(axis=0)

  # Chan's merge: n = n_a + n_b, mean += delta * n_b / n,
  # M2 = M2_a + M2_b + delta^2 * n_a * n_b / n.
  total = counts + b_count
  delta = b_mean - running_mean
  frac = np.divide(b_count, total, out=np.zeros_like(total), where=total > 0)
  running_mean += delta * frac
  running_m2 += b_m2 + delta ** 2 * counts * frac
  counts = total

# Columns with no valid value keep mean 0 and std 1.
means = np.zeros([N_COLS], dtype=np.float32)
stds = np.ones([N_COLS], dtype=np.float32)
has_data = counts > 0
means[has_data] = running_mean[has_data]
can_std = counts > 1
col_std = np.sqrt(np.divide(running_m2, counts - 1,
                            out=np.zeros_like(running_m2), where=can_std))
# very important, otherwise keras and tflite results diff...
use_std = can_std & (col_std >= 1e-6)
stds[use_std] = col_std[use_std]

In [None]:
# Persist the per-column normalization statistics to the working directory.
for stats_path, stats_array in ((f'{FLAGS.working}/means.npy', means),
                                (f'{FLAGS.working}/stds.npy', stds)):
  np.save(stats_path, stats_array)