In [4]:
!pip install tensorflow wget



In [16]:
# Root directory of the dataset; the inference cell reads the *.pkl files found
# in its 'test_datasets' subdirectory.
# NOTE(review): hard-coded absolute path — presumably a platform-specific mount;
# adjust it when running the notebook elsewhere.
dataset = '/bohr/ai4spulseeis-lr97/v5'

import pickle
from os import listdir
from os.path import exists, join, splitext
import numpy as np
import tensorflow as tf
from wget import download
from tqdm import tqdm

# Fetch the pretrained weights and the sos.npy array (used later as the initial
# decoder input) only when they are not already present in the working directory.
assets = (
  ('predictor.h5', 'https://raw.githubusercontent.com/breadbread1984/EIS_prediction/main/predictor.h5'),
  ('sos.npy', 'https://raw.githubusercontent.com/breadbread1984/EIS_prediction/main/sos.npy'),
)
for local_name, url in assets:
  if not exists(local_name):
    download(url)

In [10]:
def SelfAttention(hidden_dim = 256, num_heads = 8, use_bias = False, drop_rate = 0.1, is_causal = True):
  """Build a multi-head self-attention block as a Keras functional model.

  Args:
    hidden_dim: size of the last axis of both the input and the output; must be
      a positive multiple of num_heads.
    num_heads: number of attention heads; each head works on
      hidden_dim // num_heads channels.
    use_bias: whether the fused q/k/v projection and the output projection use a bias.
    drop_rate: dropout rate applied to the attention weights and to the projected output.
    is_causal: when True, an additive mask lets each position attend only to
      itself and to earlier positions.

  Returns:
    tf.keras.Model mapping (batch, seq, hidden_dim) -> (batch, seq, hidden_dim).

  Raises:
    ValueError: if hidden_dim is not a positive multiple of a positive num_heads.
  """
  # The Reshape layers below use -1 on the sequence axis, so a hidden_dim that is
  # not divisible by num_heads would either fail later with an opaque shape error
  # or regroup features across positions. Reject it up front instead.
  if num_heads <= 0 or hidden_dim <= 0 or hidden_dim % num_heads != 0:
    raise ValueError('hidden_dim (%r) must be a positive multiple of num_heads (%r)' % (hidden_dim, num_heads))
  # NOTE: the order in which layers are created is kept stable on purpose;
  # weights stored in an .h5 file are matched to layers by position.
  inputs = tf.keras.Input((None, hidden_dim)) # inputs.shape = (batch, seq, hidden_dim)
  # one fused projection produces q, k and v at once
  results = tf.keras.layers.Dense(hidden_dim * 3, use_bias = use_bias)(inputs) # results.shape = (batch, seq, 3 * hidden_dim)
  results = tf.keras.layers.Reshape((-1, 3, num_heads, hidden_dim // num_heads))(results) # results.shape = (batch, seq, 3, heads, head_dim)
  results = tf.keras.layers.Lambda(lambda x: tf.transpose(x, (0,2,3,1,4)))(results) # results.shape = (batch, 3, heads, seq, head_dim)
  q, k, v = tf.keras.layers.Lambda(lambda x: (x[:,0,...], x[:,1,...], x[:,2,...]))(results) # each: (batch, heads, seq, head_dim)
  # scaled dot-product scores: q @ k^T * head_dim ** -0.5
  qk = tf.keras.layers.Lambda(lambda x, s: tf.matmul(x[0], tf.transpose(x[1], (0,1,3,2))) * s, arguments = {'s': (hidden_dim // num_heads) ** -0.5})([q, k]) # qk.shape = (batch, heads, seq, seq)
  if is_causal:
    # lower-triangular band -> 0 where attention is allowed, float32 min elsewhere,
    # so that the softmax gives (numerically) zero weight to future positions
    mask = tf.keras.layers.Lambda(lambda x: tf.expand_dims(
      tf.expand_dims(
        tf.where(
          tf.cast(tf.linalg.band_part(tf.ones((tf.shape(x)[2],tf.shape(x)[2])), -1, 0), dtype = tf.bool),
          tf.constant(0., dtype = tf.float32), tf.experimental.numpy.finfo(tf.float32).min),
        axis = 0),
      axis = 0))(k) # mask.shape = (1,1,seq,seq)
    qk = tf.keras.layers.Add()([qk, mask])
  attn = tf.keras.layers.Softmax(axis = -1)(qk)
  attn = tf.keras.layers.Dropout(rate = drop_rate)(attn)
  # weighted sum of values, then move the head axis next to the channel axis
  qkv = tf.keras.layers.Lambda(lambda x: tf.transpose(tf.matmul(x[0], x[1]), (0,2,1,3)))([attn, v]) # qkv.shape = (batch, seq, heads, head_dim)
  qkv = tf.keras.layers.Reshape((-1, hidden_dim))(qkv) # qkv.shape = (batch, seq, hidden_dim)
  results = tf.keras.layers.Dense(hidden_dim, use_bias = use_bias)(qkv)
  results = tf.keras.layers.Dropout(drop_rate)(results)
  return tf.keras.Model(inputs = inputs, outputs = results)

def TransformerEncoder(dict_size = 1024, hidden_dim = 256, num_heads = 8, use_bias = False, layers = 2, drop_rate = 0.1):
  """Build a stack of encoder blocks as a Keras functional model.

  Every block is: LayerNormalization -> non-causal SelfAttention -> residual add,
  followed by LayerNormalization -> Dense(4 * hidden_dim, gelu) -> Dense(hidden_dim)
  -> Dropout -> residual add.

  Args:
    dict_size: accepted for signature compatibility; not used by this function.
    hidden_dim: channel size of the input and of the output.
    num_heads: number of heads forwarded to SelfAttention.
    use_bias: forwarded to SelfAttention.
    layers: number of encoder blocks to stack.
    drop_rate: dropout rate used inside SelfAttention and after the feed-forward part.

  Returns:
    tf.keras.Model mapping (batch, seq, hidden_dim) -> (batch, seq, hidden_dim).
  """
  kl = tf.keras.layers
  encoder_in = tf.keras.Input((None, hidden_dim)) # (batch, seq, hidden_dim)
  hidden = encoder_in
  for _ in range(layers):
    # attention sub-layer with residual connection
    normed = kl.LayerNormalization()(hidden)
    attended = SelfAttention(hidden_dim, num_heads, use_bias, drop_rate, is_causal = False)(normed)
    hidden = kl.Add()([hidden, attended])
    # position-wise feed-forward sub-layer with residual connection
    normed = kl.LayerNormalization()(hidden)
    ffn = kl.Dense(4 * hidden_dim, activation = tf.keras.activations.gelu)(normed)
    ffn = kl.Dense(hidden_dim)(ffn)
    ffn = kl.Dropout(drop_rate)(ffn)
    hidden = kl.Add()([hidden, ffn])
  return tf.keras.Model(inputs = encoder_in, outputs = hidden)

def CrossAttention(hidden_dim = 256, num_heads = 8, use_bias = False, drop_rate = 0.1):
  """Build a multi-head cross-attention block as a Keras functional model.

  Keys and values are projected from `code`, queries from `inputs`; no mask is applied.

  Args:
    hidden_dim: channel size of both inputs and of the output; must be a
      positive multiple of num_heads.
    num_heads: number of attention heads; each head works on
      hidden_dim // num_heads channels.
    use_bias: whether the k/v projection, the q projection and the output
      projection use a bias.
    drop_rate: dropout rate applied to the attention weights and to the projected output.

  Returns:
    tf.keras.Model taking (code, inputs) with shapes (batch, code_seq, hidden_dim)
    and (batch, seq, hidden_dim) and returning (batch, seq, hidden_dim).

  Raises:
    ValueError: if hidden_dim is not a positive multiple of a positive num_heads.
  """
  # The Reshape layers below use -1 on the sequence axis, so a hidden_dim that is
  # not divisible by num_heads would either fail later with an opaque shape error
  # or regroup features across positions. Reject it up front instead.
  if num_heads <= 0 or hidden_dim <= 0 or hidden_dim % num_heads != 0:
    raise ValueError('hidden_dim (%r) must be a positive multiple of num_heads (%r)' % (hidden_dim, num_heads))
  # NOTE: the order in which layers are created is kept stable on purpose;
  # weights stored in an .h5 file are matched to layers by position.
  code = tf.keras.Input((None, hidden_dim)) # code.shape = (batch, code_seq, hidden_dim)
  inputs = tf.keras.Input((None, hidden_dim)) # inputs.shape = (batch, seq, hidden_dim)
  # one fused projection of `code` produces k and v at once
  code_results = tf.keras.layers.Dense(hidden_dim * 2, use_bias = use_bias)(code) # code_results.shape = (batch, code_seq, 2 * hidden_dim)
  code_results = tf.keras.layers.Reshape((-1, 2, num_heads, hidden_dim // num_heads))(code_results) # code_results.shape = (batch, code_seq, 2, heads, head_dim)
  code_results = tf.keras.layers.Lambda(lambda x: tf.transpose(x, (0,2,3,1,4)))(code_results) # code_results.shape = (batch, 2, heads, code_seq, head_dim)
  k, v = tf.keras.layers.Lambda(lambda x: (x[:,0,...], x[:,1,...]))(code_results) # each: (batch, heads, code_seq, head_dim)
  results = tf.keras.layers.Dense(hidden_dim, use_bias = use_bias)(inputs) # results.shape = (batch, seq, hidden_dim)
  results = tf.keras.layers.Reshape((-1, num_heads, hidden_dim // num_heads))(results) # results.shape = (batch, seq, heads, head_dim)
  q = tf.keras.layers.Lambda(lambda x: tf.transpose(x, (0,2,1,3)))(results) # q.shape = (batch, heads, seq, head_dim)
  # scaled dot-product scores: q @ k^T * head_dim ** -0.5
  qk = tf.keras.layers.Lambda(lambda x, s: tf.matmul(x[0], tf.transpose(x[1], (0,1,3,2))) * s, arguments = {'s': (hidden_dim // num_heads) ** -0.5})([q, k]) # qk.shape = (batch, heads, seq, code_seq)
  attn = tf.keras.layers.Softmax(axis = -1)(qk)
  attn = tf.keras.layers.Dropout(rate = drop_rate)(attn)
  # weighted sum of values, then move the head axis next to the channel axis
  qkv = tf.keras.layers.Lambda(lambda x: tf.transpose(tf.matmul(x[0], x[1]), (0,2,1,3)))([attn, v]) # qkv.shape = (batch, seq, heads, head_dim)
  qkv = tf.keras.layers.Reshape((-1, hidden_dim))(qkv) # qkv.shape = (batch, seq, hidden_dim)
  results = tf.keras.layers.Dense(hidden_dim, use_bias = use_bias)(qkv)
  results = tf.keras.layers.Dropout(drop_rate)(results)
  return tf.keras.Model(inputs = (code, inputs), outputs = results)

def TransformerDecoder(dict_size = 1024, hidden_dim = 256, num_heads = 8, use_bias = False, layers = 2, drop_rate = 0.1):
  """Build a stack of decoder blocks as a Keras functional model.

  Every block is three residual sub-layers, each preceded by LayerNormalization:
  causal SelfAttention, CrossAttention over `code`, and a feed-forward part
  (Dense(4 * hidden_dim, gelu) -> Dense(hidden_dim) -> Dropout).

  Args:
    dict_size: accepted for signature compatibility; not used by this function.
    hidden_dim: channel size of both inputs and of the output.
    num_heads: number of heads forwarded to SelfAttention and CrossAttention.
    use_bias: forwarded to SelfAttention and CrossAttention.
    layers: number of decoder blocks to stack.
    drop_rate: dropout rate used inside the attention blocks and after the feed-forward part.

  Returns:
    tf.keras.Model taking (code, inputs) with shapes (batch, code_seq, hidden_dim)
    and (batch, seq, hidden_dim) and returning (batch, seq, hidden_dim).
  """
  kl = tf.keras.layers
  memory = tf.keras.Input((None, hidden_dim)) # the `code` input: (batch, code_seq, hidden_dim)
  decoder_in = tf.keras.Input((None, hidden_dim)) # (batch, seq, hidden_dim)
  hidden = decoder_in
  for _ in range(layers):
    # causal self-attention sub-layer with residual connection
    normed = kl.LayerNormalization()(hidden)
    self_out = SelfAttention(hidden_dim, num_heads, use_bias, drop_rate, is_causal = True)(normed)
    hidden = kl.Add()([hidden, self_out])
    # cross-attention sub-layer: queries from the decoder, keys/values from `code`
    normed = kl.LayerNormalization()(hidden)
    cross_out = CrossAttention(hidden_dim, num_heads, use_bias, drop_rate)([memory, normed])
    hidden = kl.Add()([hidden, cross_out])
    # position-wise feed-forward sub-layer with residual connection
    normed = kl.LayerNormalization()(hidden)
    ffn = kl.Dense(4 * hidden_dim, activation = tf.keras.activations.gelu)(normed)
    ffn = kl.Dense(hidden_dim)(ffn)
    ffn = kl.Dropout(drop_rate)(ffn)
    hidden = kl.Add()([hidden, ffn])
  return tf.keras.Model(inputs = (memory, decoder_in), outputs = hidden)

def Trainer(dict_size = 1024, hidden_dim = 256, num_heads = 8, use_bias = False, layers = 1, drop_rate = 0.1):
  """Build the full pulse -> EIS encoder/decoder model.

  Both 2-channel inputs are embedded to hidden_dim with a Dense layer; the pulse
  embedding goes through TransformerEncoder, the EIS embedding through
  TransformerDecoder (conditioned on the encoder output), and a final Dense(2)
  maps the decoder output back to 2 channels.

  Args:
    dict_size: forwarded to TransformerEncoder / TransformerDecoder.
    hidden_dim: embedding and transformer channel size.
    num_heads: number of attention heads.
    use_bias: forwarded to the attention blocks.
    layers: number of blocks in the encoder and in the decoder (defaults to 1 here).
    drop_rate: dropout rate.

  Returns:
    tf.keras.Model taking (pulse, eis) with shapes (batch, pulse_seq, 2) and
    (batch, eis_seq, 2) and returning a tensor of shape (batch, eis_seq, 2).
  """
  pulse_in = tf.keras.Input((None,2)) # (batch, pulse_seq, 2)
  eis_in = tf.keras.Input((None,2)) # (batch, eis_seq, 2)

  # per-position linear embeddings (created before the transformer sub-models)
  pulse_embed = tf.keras.layers.Dense(hidden_dim)(pulse_in)
  eis_embed = tf.keras.layers.Dense(hidden_dim)(eis_in)

  encoder = TransformerEncoder(dict_size, hidden_dim, num_heads, use_bias, layers, drop_rate)
  code = encoder(pulse_embed) # (batch, pulse_seq, hidden_dim)
  decoder = TransformerDecoder(dict_size, hidden_dim, num_heads, use_bias, layers, drop_rate)
  decoded = decoder([code, eis_embed]) # (batch, eis_seq, hidden_dim)

  eis_update = tf.keras.layers.Dense(2)(decoded) # (batch, eis_seq, 2)
  return tf.keras.Model(inputs = (pulse_in, eis_in), outputs = eis_update)


In [None]:
# Number of autoregressive decoding steps, i.e. EIS points written per SOC entry.
NUM_EIS_STEPS = 51

# Rebuild the architecture and load the pretrained weights and the initial decoder input.
trainer = Trainer()
trainer.load_weights('predictor.h5')
sos = tf.constant(np.load('sos.npy'))

test_dir = join(dataset, 'test_datasets')
# listdir() returns entries in arbitrary order; sort by test number so that the
# submission file is identical from run to run.
test_files = sorted(
  (int(splitext(name)[0].replace('test_pulse_', '')), name)
  for name in listdir(test_dir)
  if splitext(name)[1] == '.pkl')

# The context manager guarantees the CSV is flushed and closed even if
# inference fails part-way through.
with open('submission.csv','w') as output:
  output.write('test_data_number,SOC(%),EIS_real,EIS_imaginary\n')
  for test_num, name in test_files:
    # NOTE(review): pickle.load can execute arbitrary code; only use it on trusted files.
    with open(join(test_dir, name), 'rb') as pkl_file:
      data = pickle.load(pkl_file)
    for SOC, pulse_samples in tqdm(data.items()):
      soc = SOC.replace('%SOC','')
      pulse = tf.expand_dims(tf.stack([pulse_samples['Voltage'], pulse_samples['Current']], axis = -1), axis = 0) # pulse.shape = (1, seq, 2)
      eis = tf.tile(sos, (pulse.shape[0], 1, 1))
      # Greedy autoregressive decoding: feed the sequence so far and append the
      # prediction for the last position.
      for _ in range(NUM_EIS_STEPS):
        pred = trainer([pulse, eis], training = False) # inference mode: dropout disabled
        eis = tf.concat([eis, pred[:,-1:,:]], axis = -2)
      # Drop the initial sos entry and convert to Python floats once per sample.
      for real, imag in eis[0,1:,:].numpy().tolist():
        output.write(','.join([str(test_num), soc, str(real), str(imag)]) + '\n')


100%|████████████████████████████████████████████████████████████████████████████| 49/49 [02:09<00:00,  2.64s/it]
 76%|█████████████████████████████████████████████████████████▍                  | 37/49 [01:37<00:31,  2.59s/it]