In [None]:
import os
import glob
import numpy as np
import tensorflow as tf
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import scipy

# pipeline 1


In [None]:
tf.random.set_seed(42)
# Each star lives in its own sub-directory: train/<planet_id>/<file>.
# Collect the unique directory names as the list of star ids.
files = glob.glob(os.path.join('train/', '*/*'))
stars = []
for file in files:
    # The parent directory name is the star id. os.path handles both '/' and
    # '\\' separators, whereas splitting on '\\' only worked on Windows.
    file_name = os.path.basename(os.path.dirname(file))
    stars.append(file_name)
stars = np.unique(stars)

import random
random.seed(42)

def split_star_list(file_list, test_ratio=0.1):
    """Randomly split a sequence of star ids into train and test parts.

    Uses the module-level ``random`` generator, so the split is reproducible
    after ``random.seed(...)``. A copy is shuffled and the input is left
    unmodified.

    Args:
        file_list: list or 1-D numpy array of star ids (must support ``.copy()``).
        test_ratio: fraction in [0, 1] of the items placed in the test split.

    Returns:
        (train_files, test_files): slices of the shuffled copy, with the same
        container type as ``file_list``.

    Raises:
        ValueError: if ``test_ratio`` is outside [0, 1].
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")
    # random.shuffle works in place; shuffle a copy so the caller's sequence
    # is not reordered as a hidden side effect.
    shuffled = file_list.copy()
    random.shuffle(shuffled)
    split_index = int(len(shuffled) * (1 - test_ratio))
    return shuffled[:split_index], shuffled[split_index:]

train_stars, test_stars = split_star_list(stars)

# Target spectra indexed by planet_id (one column per label entry).
labelDf = pd.read_csv("train_labels.csv")
labelDf = labelDf.set_index('planet_id')
# Global label statistics. The std is taken over all label values;
# np.std(labelDf.std()) would instead measure the spread of the per-column stds.
meanLabels = np.mean(labelDf.mean())
stdLabels = np.nanstd(labelDf.to_numpy())
maxLabels = np.max(labelDf.max())
minLabels = np.min(labelDf.min())

# The same statistics computed on the training split only.
trainLabels = labelDf.loc[[int(star) for star in train_stars]]
meanTrainLabels = np.mean(trainLabels.mean())
stdTrainLabels = np.nanstd(trainLabels.to_numpy())
maxTrainLabels = np.max(trainLabels.max())
minTrainLabels = np.min(trainLabels.min())

# Scale every label by the largest training label, in one vectorized division.
labelDf = labelDf / maxTrainLabels

# normalize over time and all samples, so we have a mean and a std dev per wavelength for all samples
def calcMeanAndStdOfTrain(train_stars, data_dir='train'):
    """Per-wavelength mean and std of the training light curves.

    For each star, this loads ``<data_dir>/<star>/combined.npz`` and takes
    ``a[0, :, 0:283, 1]`` (time x wavelength). Each wavelength column is
    scaled by its own maximum. The statistics are pooled over every time step
    of every star.

    Args:
        train_stars: iterable of star ids.
        data_dir: root directory holding one sub-directory per star.

    Returns:
        (meanTrain, stdTrain): 1-D arrays with one entry per wavelength.

    Raises:
        ValueError: if ``train_stars`` is empty.

    Note:
        Assumes every star has the same number of time steps, because the
        pooled mean is computed as the average of per-star means.
        NOTE(review): load_npz bins and normalizes features differently, so
        these statistics describe un-binned data. Confirm this is intended.
    """
    count = 0
    n_time = 0
    sumMeans = None
    sumS = None
    for star in train_stars:
        file_path = os.path.join(data_dir, str(star), 'combined.npz')
        with np.load(file_path) as data:
            x = data['a'][0, :, 0:283, 1]
        x = x / np.max(x, axis=0)
        if count == 0:
            sumMeans = np.mean(x, axis=0)
            sumS = np.sum(x**2, axis=0)
        else:
            sumMeans = sumMeans + np.mean(x, axis=0)
            sumS = sumS + np.sum(x**2, axis=0)
        n_time = x.shape[0]
        count += 1
    if count == 0:
        raise ValueError("train_stars is empty; cannot compute statistics")
    meanTrain = sumMeans / count
    # E[x^2] - E[x]^2 can dip slightly below zero from round-off; clamp it
    # before the sqrt so a constant column gives 0 instead of NaN.
    variance = np.maximum(sumS / (count * n_time) - meanTrain**2, 0.0)
    stdTrain = np.sqrt(variance)
    return meanTrain, stdTrain
# Per-wavelength statistics of the training split. normalize_over_train uses
# only stdTrain; its mean-centering line is commented out.
meanTrain, stdTrain = calcMeanAndStdOfTrain(train_stars)

def normalize_over_train(features, labels):
    """Scale features by the global per-wavelength training std.

    No mean-centering is applied. Labels pass through unchanged.
    Reads the module-level ``stdTrain``.
    """
    scale = stdTrain + 1e-6  # epsilon guards against zero-variance wavelengths
    return features / scale, labels

def calcRollingMeans(windows, features):
    """Stack the input with rolling-mean-derived channels.

    For every window length ``w``, three channels are appended:
      1. trailing rolling mean over ``w`` steps. The series is edge-padded at
         the front, so the output keeps the input length.
      2. that rolling mean minus a linear trend. The slope is the mean first
         difference after the first ``w`` steps.
      3. the first difference of the rolling mean, with 0 at t=0.

    Args:
        windows: iterable of positive window lengths.
        features: 2-D array (time, channels).

    Returns:
        Array of shape (time, channels, 1 + 3 * len(windows)). Channel 0 is
        ``features`` itself.
    """
    arr = [features]
    for window in windows:
        padded_data = np.pad(features, ((window, 0), (0, 0)), mode='edge')
        cumsum = np.cumsum(padded_data, axis=0)
        # cumsum[t + w] - cumsum[t] sums padded[t+1 .. t+w], i.e. features[t-w+1 .. t].
        result = (cumsum[window:, :] - cumsum[:-window, :]) / window
        arr.append(result)

        # Remove a linear trend from the rolling mean. The slope is estimated
        # from the differences after the warm-up region. If the window is too
        # long for any to remain, fall back to all differences (or to a zero
        # slope for a single time step) instead of producing NaN.
        diff = np.diff(result, axis=0)
        tail = diff[window:, :]
        if tail.shape[0] > 0:
            mean_lin_slope = np.mean(tail, axis=0)
        elif diff.shape[0] > 0:
            mean_lin_slope = np.mean(diff, axis=0)
        else:
            mean_lin_slope = np.zeros(result.shape[1:], dtype=result.dtype)
        accumulatedSlopes = np.cumsum(np.ones_like(result) * mean_lin_slope, axis=0)
        compensatedSignal = result - accumulatedSlopes
        arr.append(compensatedSignal)

        # First difference, left-padded with 0 so it aligns with the time axis.
        diffCorr = np.zeros_like(result)
        diffCorr[1:, :] = diff
        arr.append(diffCorr)
    return np.stack(arr, axis=-1)

# normalize over time per sample: a mean and a std dev per wavelength for one sample
def calcMeanAndStdOfTrainPerStar(x):
    """Per-column mean and population std of one sample over the time axis.

    Args:
        x: array of shape (time, ...). Statistics are taken over axis 0.

    Returns:
        (mean, std), each with the shape of ``x`` minus its first axis.
    """
    mean = np.mean(x, axis=0)
    # np.std (ddof=0) matches sqrt(E[x^2] - E[x]^2). It works with deviations
    # from the mean, so round-off cannot make the variance negative (NaN std).
    std = np.std(x, axis=0)
    return mean, std
def normalize_per_sample(features, labels):
    """Scale each column of one sample by its own std over time; labels unchanged."""
    _, per_sample_std = calcMeanAndStdOfTrainPerStar(features)
    scaled = features / (per_sample_std + 1e-6)
    return scaled, labels




def load_npz(star):
    """Load and featurize one star for the tf.data pipeline (pipeline 1).

    Called through ``tf.py_function`` with a scalar string tensor holding the
    star id. Reads ``train/<id>/combined.npz`` and the star's row of the
    module-level ``labelDf``.

    Returns:
        features: array (283, n_bins, 13), wavelength-major after the final
            transpose. n_bins is the number of time steps divided by 25.
        labels: the star's label vector, standardized by its own mean and std.

    Raises:
        Re-raises any loading or processing error after printing the star id.
        Returning None would only make tf.py_function fail later with a less
        informative message.
    """
    integer_value = tf.strings.to_number(star, out_type=tf.int64)
    python_int = integer_value.numpy()

    file_path = 'train/'+str(python_int)+'/combined.npz'
    try:
        with np.load(file_path) as data:
            features = data['a'][0,:,0:283,1]
        labels = labelDf.loc[python_int].to_numpy()
        # Standardize each star's labels by its own mean/std.
        # NOTE(review): these per-star statistics are not returned, so
        # predictions cannot be mapped back to absolute label values without
        # them. A constant label vector (std 0) would also divide by zero.
        meanL = np.mean(labels)
        stdL = np.std(labels)
        labels = (labels-meanL)/ stdL
        # Bin 25 consecutive time steps: (T, 283) -> (T/25, 283).
        features = np.reshape(features,(-1,25,283))
        features = np.mean(features,axis=1)
        # Relative flux per wavelength. Assuming positive flux, this is 0 at
        # each column's maximum and negative below it.
        maxValpWL = np.max(features,axis=0)
        features = features / maxValpWL -1
        # Alternative: features, labels = normalize_per_sample(features, labels)
        features, labels = normalize_over_train(features,labels)
        features = calcRollingMeans([30,50,80,100],features)
        # (time, wavelength, channel) -> (wavelength, time, channel)
        features = np.transpose(features, (1, 0, 2))
        return features, labels
    except Exception as e:
        print("Error loading file:", e, python_int)
        raise


def create_dataset(star_list, batch_size, shuffle=True):
    """Build a batched tf.data pipeline of per-wavelength examples (pipeline 1).

    Each star id is loaded through ``load_npz`` (via tf.py_function) into
    features of shape (283, 225, 13) and labels of shape (283,). The star is
    then unbatched, so every wavelength row becomes one example: features
    (225, 13) with a scalar label.

    Args:
        star_list: sequence of star ids (strings).
        batch_size: number of per-wavelength examples per batch.
        shuffle: shuffle stars before loading and rows after unbatching.

    Returns:
        A prefetched ``tf.data.Dataset`` of (features, label) batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices(star_list)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(star_list), reshuffle_each_iteration=True)

    def load_and_process(x):
        features, labels = tf.py_function(
            func=load_npz,
            inp=[x],
            Tout=[tf.float64, tf.float32]
        )
        return features, labels

    dataset = dataset.map(load_and_process, num_parallel_calls=tf.data.AUTOTUNE)
    # py_function drops static shapes, so restore them here
    # (225 bins = 5625 time steps / 25).
    dataset = dataset.map(lambda x, y: (tf.ensure_shape(x, tf.TensorShape([283, 225, 13])),
                                        tf.ensure_shape(y, tf.TensorShape([283]))))
    # One example per wavelength row.
    dataset = dataset.unbatch()
    if shuffle:
        # Mix rows across stars. NOTE(review): this buffer can hold every row of
        # every star (283 x 225 x 13 float64 each), so watch memory.
        dataset = dataset.shuffle(buffer_size=len(star_list)*283, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


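# pipeline 2 - experiments
- express each light curve relative to its per-wavelength time average
- decompose the input into the mean-normalized curve plus a separate magnitude channel (100 / mean)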

In [None]:

tf.random.set_seed(42)
# Each star lives in its own sub-directory: train/<planet_id>/<file>.
# Collect the unique directory names as the list of star ids.
files = glob.glob(os.path.join('train/', '*/*'))
stars = []
for file in files:
    # The parent directory name is the star id. os.path handles both '/' and
    # '\\' separators, whereas splitting on '\\' only worked on Windows.
    file_name = os.path.basename(os.path.dirname(file))
    stars.append(file_name)
stars = np.unique(stars)

import random
random.seed(42)

def split_star_list(file_list, test_ratio=0.1):
    """Randomly split a sequence of star ids into train and test parts.

    Uses the module-level ``random`` generator, so the split is reproducible
    after ``random.seed(...)``. A copy is shuffled and the input is left
    unmodified.

    Args:
        file_list: list or 1-D numpy array of star ids (must support ``.copy()``).
        test_ratio: fraction in [0, 1] of the items placed in the test split.

    Returns:
        (train_files, test_files): slices of the shuffled copy, with the same
        container type as ``file_list``.

    Raises:
        ValueError: if ``test_ratio`` is outside [0, 1].
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")
    # random.shuffle works in place; shuffle a copy so the caller's sequence
    # is not reordered as a hidden side effect.
    shuffled = file_list.copy()
    random.shuffle(shuffled)
    split_index = int(len(shuffled) * (1 - test_ratio))
    return shuffled[:split_index], shuffled[split_index:]

train_stars, test_stars = split_star_list(stars)

# Target spectra indexed by planet_id (one column per label entry).
labelDf = pd.read_csv("train_labels.csv")
labelDf = labelDf.set_index('planet_id')
# Global label statistics. The std is taken over all label values;
# np.std(labelDf.std()) would instead measure the spread of the per-column stds.
meanLabels = np.mean(labelDf.mean())
stdLabels = np.nanstd(labelDf.to_numpy())
maxLabels = np.max(labelDf.max())
minLabels = np.min(labelDf.min())

# The same statistics computed on the training split only.
trainLabels = labelDf.loc[[int(star) for star in train_stars]]
meanTrainLabels = np.mean(trainLabels.mean())
stdTrainLabels = np.nanstd(trainLabels.to_numpy())
maxTrainLabels = np.max(trainLabels.max())
minTrainLabels = np.min(trainLabels.min())

# Scale every label by the largest training label, in one vectorized division.
labelDf = labelDf / maxTrainLabels

# Global maximum over all training stars (a single scalar across time and wavelengths).
def maxOfTrain(train_stars, data_dir='train'):
    """Return the largest value of ``a[0, :, 0:283, 1]`` over all training stars.

    Args:
        train_stars: iterable of star ids.
        data_dir: root directory holding one sub-directory per star.

    Raises:
        ValueError: if ``train_stars`` is empty (there is no maximum).
    """
    # Start from -inf rather than 0 so an all-negative signal is not clamped to 0.
    maxTrain = -np.inf
    seen_any = False
    for star in train_stars:
        file_path = os.path.join(data_dir, str(star), 'combined.npz')
        with np.load(file_path) as data:
            x = data['a'][0, :, 0:283, 1]
        maxTrain = max(maxTrain, np.max(x))
        seen_any = True
    if not seen_any:
        raise ValueError("train_stars is empty; cannot compute a maximum")
    return maxTrain
# Scalar maximum over the training split. load_npz in this cell does not reference it.
maxTrain= maxOfTrain(train_stars)

def calcRollingMeans(windows, features):
    """Stack the input with rolling-mean-derived channels.

    For every window length ``w``, three channels are appended:
      1. trailing rolling mean over ``w`` steps. The series is edge-padded at
         the front, so the output keeps the input length.
      2. that rolling mean minus a linear trend. The slope is the mean first
         difference after the first ``w`` steps.
      3. the first difference of the rolling mean, with 0 at t=0.

    Args:
        windows: iterable of positive window lengths.
        features: 2-D array (time, channels).

    Returns:
        Array of shape (time, channels, 1 + 3 * len(windows)). Channel 0 is
        ``features`` itself.
    """
    arr = [features]
    for window in windows:
        padded_data = np.pad(features, ((window, 0), (0, 0)), mode='edge')
        cumsum = np.cumsum(padded_data, axis=0)
        # cumsum[t + w] - cumsum[t] sums padded[t+1 .. t+w], i.e. features[t-w+1 .. t].
        result = (cumsum[window:, :] - cumsum[:-window, :]) / window
        arr.append(result)

        # Remove a linear trend from the rolling mean. The slope is estimated
        # from the differences after the warm-up region. If the window is too
        # long for any to remain, fall back to all differences (or to a zero
        # slope for a single time step) instead of producing NaN.
        diff = np.diff(result, axis=0)
        tail = diff[window:, :]
        if tail.shape[0] > 0:
            mean_lin_slope = np.mean(tail, axis=0)
        elif diff.shape[0] > 0:
            mean_lin_slope = np.mean(diff, axis=0)
        else:
            mean_lin_slope = np.zeros(result.shape[1:], dtype=result.dtype)
        accumulatedSlopes = np.cumsum(np.ones_like(result) * mean_lin_slope, axis=0)
        compensatedSignal = result - accumulatedSlopes
        arr.append(compensatedSignal)

        # First difference, left-padded with 0 so it aligns with the time axis.
        diffCorr = np.zeros_like(result)
        diffCorr[1:, :] = diff
        arr.append(diffCorr)
    return np.stack(arr, axis=-1)


def load_npz(star):
    """Load and featurize one star for the tf.data pipeline (pipeline 2).

    Called through ``tf.py_function`` with a scalar string tensor holding the
    star id. Reads ``train/<id>/combined.npz`` and the star's row of the
    module-level ``labelDf``. The labels are returned unstandardized.

    Returns:
        f: array (283, n_bins, 14). It holds the 13 channels from
           calcRollingMeans on the mean-normalized signal, plus one channel
           with 100 / (per-wavelength time-mean).
        labels: the star's label vector.

    Raises:
        Re-raises any loading or processing error after printing the star id,
        instead of returning None into tf.py_function.
    """
    integer_value = tf.strings.to_number(star, out_type=tf.int64)
    python_int = integer_value.numpy()

    file_path = 'train/'+str(python_int)+'/combined.npz'
    try:
        with np.load(file_path) as data:
            f = data['a'][0,:,0:283,1]
        labels = labelDf.loc[python_int].to_numpy()
        # Bin 25 consecutive time steps: (T, 283) -> (T/25, 283).
        f = np.reshape(f,(-1,25,283))
        f = np.mean(f,axis=1)
        # Divide each wavelength by its time-average and scale by 100. All
        # wavelengths then share a common level (~100), so relative dips are comparable.
        maxValpWL = np.mean(f,axis=0)
        f = 100*(f / maxValpWL)
        # Extra channel carrying the removed magnitude (100 / mean flux), so the
        # network still sees the absolute level. Built from f's own shape
        # instead of a hard-coded (225, 283, 1).
        featureMax = (np.ones_like(f) * 100/maxValpWL)[..., np.newaxis]
        # Author's notes on the target relation:
        # lambda = (x_top - x_bottom) / x_top = (x_diff / maxVal) / (x_top / maxVal)
        #        = 1 - x_bottom / x_top  -> scale accordingly
        f = calcRollingMeans([30,50,80,100],f)
        f = np.concatenate([f, featureMax], axis = 2) # magnitude info as the last channel
        # (time, wavelength, channel) -> (wavelength, time, channel)
        f = np.transpose(f, (1, 0, 2))
        return f, labels
    except Exception as e:
        print("Error loading file:", e, python_int)
        raise


def create_dataset(star_list, batch_size, shuffle=True):
    """Build a batched tf.data pipeline of per-wavelength examples (pipeline 2).

    Each star id is loaded through ``load_npz`` (via tf.py_function) into
    features of shape (283, 225, 14) and labels of shape (283,). The star is
    then unbatched, so every wavelength row becomes one example: features
    (225, 14) with a scalar label.

    Args:
        star_list: sequence of star ids (strings).
        batch_size: number of per-wavelength examples per batch.
        shuffle: shuffle stars before loading and rows after unbatching.

    Returns:
        A prefetched ``tf.data.Dataset`` of (features, label) batches.
    """
    dataset = tf.data.Dataset.from_tensor_slices(star_list)
    if shuffle:
        dataset = dataset.shuffle(buffer_size=len(star_list), reshuffle_each_iteration=True)

    def load_and_process(x):
        features, labels = tf.py_function(
            func=load_npz,
            inp=[x],
            Tout=[tf.float64, tf.float32]
        )
        return features, labels

    dataset = dataset.map(load_and_process, num_parallel_calls=tf.data.AUTOTUNE)
    # py_function drops static shapes, so restore them here
    # (225 bins = 5625 time steps / 25).
    dataset = dataset.map(lambda x, y: (tf.ensure_shape(x, tf.TensorShape([283, 225, 14])),
                                        tf.ensure_shape(y, tf.TensorShape([283]))))
    # One example per wavelength row.
    dataset = dataset.unbatch()
    if shuffle:
        # Mix rows across stars. NOTE(review): this buffer can hold every row of
        # every star (283 x 225 x 14 float64 each), so watch memory.
        dataset = dataset.shuffle(buffer_size=len(star_list)*283, reshuffle_each_iteration=True)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


In [None]:
# Re-seed TF right before building the pipelines so the shuffle order depends only on the seed.
tf.random.set_seed(42)
batch_size = 4 * 283  # four stars' worth of per-wavelength rows per batch

train_dataset = create_dataset(train_stars, batch_size=batch_size, shuffle=True)
test_dataset = create_dataset(test_stars, batch_size=batch_size, shuffle=False)

In [None]:
# Tiny four-star training subset for quick fit / sanity checks.
small_train = create_dataset(train_stars[:4], batch_size, shuffle=True)

# model

In [None]:
class Reshape1(tf.keras.layers.Layer):
    """Swap axes 1 and 2 of a rank-4 tensor: (b, i, j, c) -> (b, j, i, c)."""

    def call(self, x):
        return tf.transpose(x, perm=(0, 2, 1, 3))
    
class Reshape11(tf.keras.layers.Layer):
    """Swap the last two axes of a rank-3 tensor: (b, t, c) -> (b, c, t)."""

    def call(self, x):
        return tf.transpose(x, perm=(0, 2, 1))

class Reshape2(tf.keras.layers.Layer):
    """Concatenate prediction and confidence along the last axis (no reshape involved)."""

    def call(self, x_pred, x_confidence):
        joined = tf.concat((x_pred, x_confidence), axis=-1)
        return joined
    
class Reshape22(tf.keras.layers.Layer):
    """Pair two equally-shaped tensors on a new trailing axis of size 2."""

    def call(self, x_pred, x_confidence):
        columns = [tf.expand_dims(t, axis=-1) for t in (x_pred, x_confidence)]
        return tf.concat(columns, axis=-1)
    
class Reshape3(tf.keras.layers.Layer):
    """Collapse all middle axes into one: (b, ...) -> (b, -1, x.shape[2]).

    The batch size is read dynamically with ``tf.shape``. A literal ``None``
    inside the target shape cannot be converted to a tensor, so tf.reshape
    rejects it. ``x.shape[2]`` must be statically known.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def call(self, x):
        batch = tf.shape(x)[0]
        return tf.reshape(x, (batch, -1, x.shape[2]))
    
class reduce(tf.keras.layers.Layer):
    """Sum (not mean) over the last axis, kept as size 1: (..., c) -> (..., 1)."""

    def call(self, x):
        return tf.reduce_sum(x, axis=-1, keepdims=True)
class reduce1(tf.keras.layers.Layer):
    """Sum over the last axis and drop it: (..., c) -> (...)."""

    def call(self, x):
        total = tf.reduce_sum(x, axis=-1)
        return total
    
class tile(tf.keras.layers.Layer):
    """Append ``mean`` to ``x`` along the last axis (a concatenation, not a tiling)."""

    def call(self, x, mean):
        return tf.concat([x, mean], axis=-1)
    
class tile2(tf.keras.layers.Layer):
    """Give ``mean`` a new trailing axis, then concatenate it to ``x`` along axis -2.

    Shapes must agree on every axis except -2 after the expansion.
    """

    def call(self, x, mean):
        expanded = tf.expand_dims(mean, axis=-1)
        return tf.concat([x, expanded], axis=-2)
    
class meanOfWavelengths(tf.keras.layers.Layer):
    """Mean over the last axis, optionally appended to the input.

    Args:
        concat: if True (default), return ``x`` with its last-axis mean
            appended as one extra channel: (..., c) -> (..., c + 1).
            If False, return only the mean, keeping the axis: (..., 1).
    """
    def __init__(self, concat=True, **kwargs):
        # Initialize the Layer base first so Keras attribute tracking exists
        # before any attribute is assigned.
        super().__init__(**kwargs)
        self.concat = concat
    def call(self, x):
        m = tf.reduce_mean(x, axis=-1, keepdims=True)
        if not self.concat:
            # Skip building a concatenation that would be discarded.
            return m
        return tf.concat([x, m], axis=-1)

# Gated linear unit: splits the last axis into two halves and gates the first
# half with swish applied to the second half.
class GLU(tf.keras.layers.Layer):
    """Swish-gated linear unit: (..., 2k) -> (..., k)."""

    def call(self, x, mask=None):
        value, gate = tf.split(x, num_or_size_splits=2, axis=-1)
        # swish(g) = g * sigmoid(g). Lets the network suppress or pass each feature.
        return value * tf.keras.activations.swish(gate)

class GLUMlp(tf.keras.layers.Layer):
    """Feed-forward block: EinsumDense(dim_expand) -> GLU -> EinsumDense(dim).

    Both projections act on the last axis of a rank-3 input (einsum
    "abc,cd->abd"). The GLU halves the expanded width in between.
    """
    def __init__(self, dim_expand, dim, **kwargs):
        super().__init__(**kwargs)
        self.dim_expand = dim_expand
        self.dim = dim
        # Sub-layer attribute names and creation order are kept, so saved
        # weights still map onto them.
        self.dense_1 = tf.keras.layers.EinsumDense(
            "abc,cd->abd", output_shape=(None, self.dim_expand),
            activation='linear', bias_axes='d')
        self.glu_1 = GLU()
        self.dense_2 = tf.keras.layers.EinsumDense(
            "abc,cd->abd", output_shape=(None, self.dim),
            activation='linear', bias_axes='d')
    def call(self, x, training=False):
        # `training` is accepted for interface symmetry; no sub-layer needs it.
        gated = self.glu_1(self.dense_1(x))
        return self.dense_2(gated)

class ScaleBias(tf.keras.layers.Layer):
    """Learned per-channel scale and bias on the last axis of a rank-3 input.

    Implemented as EinsumDense("abc,c->abc") with a bias on axis c:
    y[..., c] = x[..., c] * w[c] + b[c].
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
    def build(self, input_shape):
        # Created lazily because the channel count is only known from input_shape.
        self.scale_bias = tf.keras.layers.EinsumDense("abc,c->abc",output_shape=(None, input_shape[-1]),activation = 'linear', bias_axes = 'c')
    def call(self, x, mask=None):
        return self.scale_bias(x)

# Attention runs over axis 1 (the first non-batch axis) of a rank-3 input;
# this is the Keras MultiHeadAttention default (attention_axes=None).
class TransformerEncoder(tf.keras.layers.Layer):
    """Post-norm Transformer block.

    Self-attention + residual + LayerNorm, then GLUMlp + residual + LayerNorm.

    Args:
        embed_dim: width produced by the feed-forward. It also sets the per-head
            key_dim (embed_dim // num_heads). The second residual add requires
            the input's last axis to equal embed_dim.
        num_heads: number of attention heads.
        feed_forward_dim: expanded width inside GLUMlp (halved by its GLU).
    """
    def __init__(self, embed_dim, num_heads, feed_forward_dim):
        super().__init__()
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim//num_heads)
        self.ffn = GLUMlp(feed_forward_dim, embed_dim)
        #self.ffn = tf.keras.layers.Dense(feed_forward_dim)
        self.layer_norm_1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) # computes a * (input - mean) / sqrt(var + eps) + b; a and b are learned, eps avoids division by zero
        self.layer_norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        #self.scale_bias_1 = ScaleBias()
        #self.scale_bias_2 = ScaleBias()
    def call(self, x, training = None):
        residual = x
        #print('before att')
        x = self.att(x, x)  # self-attention: query and value are both x
        #x = self.scale_bias_1(x)
        x = self.layer_norm_1(x + residual)
        #x = x+residual
        residual = x
        #print('after att')
        x = self.ffn(x, training = training)
        #print('after glu')
        #x = self.scale_bias_2(x)
        x = self.layer_norm_2(x + residual)
        return x
    
# Efficient channel attention: learns a per-channel gate in (0, 1) from the
# time-averaged signal, so individual channels can be emphasized or suppressed.
class ECA(tf.keras.layers.Layer):
    """Efficient Channel Attention for rank-3 inputs (a, b, c).

    Args:
        kernel_size: width of the 1-D convolution across neighbouring channels.
    """
    # TF implementation from https://www.kaggle.com/code/hoyso48/1st-place-solution-training
    def __init__(self, kernel_size=5, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.conv = tf.keras.layers.Conv1D(1, kernel_size=kernel_size, strides=1, padding="same", use_bias=False) # a single 1-filter convolution across channels
    def call(self, inputs):
        # Average over axis 1: (a, b, c) -> (a, c). This matches an unmasked
        # GlobalAveragePooling1D without creating a new Keras layer on every call.
        nn = tf.reduce_mean(inputs, axis=1)
        nn = tf.expand_dims(nn, -1)  # (a, c, 1): treat channels as a 1-D sequence
        nn = self.conv(nn)           # (a, c, 1): mixes neighbouring channels
        nn = tf.squeeze(nn, -1)      # (a, c)
        nn = tf.nn.sigmoid(nn)       # gate in (0, 1)
        nn = nn[:, None, :]          # (a, 1, c) so it broadcasts over axis 1
        return inputs * nn           # scale every channel by its gate


class HeadDense(tf.keras.layers.Layer):
    """Swish-activated projection of the last axis: (a, b, c) -> (a, b, head_dim)."""
    def __init__(self, head_dim, **kwargs):
        super().__init__(**kwargs)
        self.head_dim = head_dim
    def build(self, input_shape):
        self.length = input_shape[1]
        self.dim = input_shape[2]  # stored but not used below
        self.dense = tf.keras.layers.EinsumDense("abc,cd->abd",output_shape=(self.length, self.head_dim), activation = 'swish', bias_axes = 'd') # swish (x * sigmoid(x)) acts as a self-gate
    def call(self, x):
        x = self.dense(x)
        return x
    
class Conv1DBlockSqueezeformer(tf.keras.layers.Layer):
    """Squeezeformer-style convolution module followed by a GLU feed-forward.

    Convolution path on a rank-3 input (a, b, c):
        expand (c -> c * expand_ratio) -> GLU (halves the channel axis)
        -> depthwise Conv1D ('same') -> BatchNorm -> activation -> ECA
        -> project (-> channel_size) -> ScaleBias
    The result is added to the input. It then passes through GLUMlp
    (channel_size * 4 expansion) + ScaleBias, with a residual and LayerNorm.

    The first residual add requires channel_size == c.
    ``se_ratio`` and the ``scale_bias`` / ``layernorm`` sub-layers are kept
    for compatibility but are not used in ``call``.
    """
    def __init__(self, channel_size, kernel_size, dilation_rate=1,
                 expand_ratio=2, se_ratio=0.25, activation='swish', name=None, **kwargs):
        # Forward name/kwargs to the Layer base so callers' settings take effect.
        super().__init__(name=name, **kwargs)
        self.channel_size = channel_size
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.expand_ratio = expand_ratio
        self.se_ratio = se_ratio
        self.activation = activation
        self.scale_bias = ScaleBias()
        self.glu_layer = GLU()
        self.ffn = GLUMlp(channel_size*4, channel_size)
        self.layer_norm_2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.scale_bias_1 = ScaleBias()
        self.scale_bias_2 = ScaleBias()
    def build(self, input_shape):
        self.length = input_shape[1]
        self.channels_in = input_shape[2]
        self.channels_expand = self.channels_in * self.expand_ratio
        self.dwconv = tf.keras.layers.DepthwiseConv1D(self.kernel_size,dilation_rate=self.dilation_rate,padding='same',use_bias=False)
        self.BatchNormalization_layer = tf.keras.layers.BatchNormalization(momentum=0.95)
        self.conv_activation = tf.keras.layers.Activation(self.activation)
        self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.ECA_layer = ECA() # channel attention
        self.expand = tf.keras.layers.EinsumDense("abc,cd->abd",output_shape=(self.length, self.channels_expand), activation = 'linear', bias_axes = 'd')
        self.project =tf.keras.layers.EinsumDense("abc,cd->abd",output_shape=(self.length, self.channel_size), activation = 'linear', bias_axes = 'd')
    def call(self, x, training = None):
        skip = x
        x = self.expand(x)  # dense on the channel (last) axis: c -> c * expand_ratio
        x = self.glu_layer(x)  # one half gates the other; halves the channel axis
        x = self.dwconv(x)
        x = self.BatchNormalization_layer(x)
        x = self.conv_activation(x)
        x = self.ECA_layer(x)  # per-channel attention gate
        x = self.project(x)  # -> channel_size
        x = self.scale_bias_1(x)

        x = x + skip

        residual = x
        x = self.ffn(x)  # GLU feed-forward
        x = self.scale_bias_2(x)
        x = self.layer_norm_2(x + residual)
        return x


In [None]:

# Input geometry used by the model builders below.
timepoints = 225  # time bins per example; matches the 225 in create_dataset's ensure_shape
representations = 4  # NOTE(review): not used by the live code in the builders shown here
wavelengths = 283
targetWavelengths = 283  # NOTE(review): not used by the live code in the builders shown here

class transf1d(tf.keras.layers.Layer):
    """Self-attention with residual, optional axis swap, then a linear Dense.

    Args:
        inputDim: only sizes the attention heads (key_dim = inputDim // num_heads).
        num_heads: number of attention heads.
        feed_forward_dim: output width of the final Dense (no activation).
        reshape: if True, transpose (b, t, c) -> (b, c, t) before the Dense,
            so the Dense acts on what was axis 1.
    """
    def __init__(self, inputDim, num_heads, feed_forward_dim, reshape=True):
        super().__init__()
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=inputDim//num_heads)
        #self.ffn = GLUMlp(feed_forward_dim, embed_dim)
        self.ffn2 = tf.keras.layers.Dense(feed_forward_dim)
        self.reshape1 = Reshape11()
        self.reshape2 = Reshape11()  # created but not used in call
        self.reshape = reshape
    def call(self, x, training = None):
        residual = x
        x = self.att(x,x)
        x = x + residual
        if self.reshape:
            x = self.reshape1(x)
        #x = self.ffn(x)
        x = self.ffn2(x)
        #x = self.reshape2()(x)
        return x
class att1d(tf.keras.layers.Layer):
    """Multi-head self-attention with a residual connection and no normalization.

    ``feed_forward_dim`` is accepted but not used.
    """
    def __init__(self, embed_dim, num_heads, feed_forward_dim):
        super().__init__()
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim//num_heads)
    def call(self, x, training = None):
        residual = x
        x = self.att(x,x)
        x = x + residual
        return x
    
class custConv1dDepthwise(tf.keras.layers.Layer):
    """Parallel depthwise 1-D convolutions with different kernel sizes.

    Each kernel size in ``arr`` gets its own DepthwiseConv1D ('valid' padding,
    depth_multiplier=1). Branch outputs are center-cropped along axis 1 to the
    length produced by the largest kernel, then concatenated on the last axis.

    Args:
        arr: list of kernel sizes.
        meanInit: if True, kernels start at the constant 1 / kernel_size
            (a moving average per channel).
    """
    def __init__(self, arr, meanInit=True):
        super().__init__()
        self.arr = arr
        self.maxSlidingWindow = max(self.arr)
        if meanInit:
            self.convArr = [tf.keras.layers.DepthwiseConv1D(kernel_size=kernelS,strides=1,padding='valid', depth_multiplier=1,activation='linear', depthwise_initializer=tf.keras.initializers.Constant(1.0 / kernelS)) for kernelS in self.arr]
        else:
            self.convArr = [tf.keras.layers.DepthwiseConv1D(kernel_size=kernelS,strides=1,padding='valid', depth_multiplier=1) for kernelS in self.arr]
    def call(self, x, training = None):
        outDim = x.shape[1] - self.maxSlidingWindow + 1
        out = []
        for kernelS, conv in zip(self.arr, self.convArr):
            x0 = conv(x)
            thisOutDim = x.shape[1] - kernelS + 1
            # Center-crop so every branch has the shortest ('valid') length.
            startIdx = (thisOutDim - outDim) // 2
            out.append(x0[:, startIdx:startIdx + outDim, :])
        # tf.concat instead of instantiating a new Concatenate layer on every call.
        return tf.concat(out, axis=-1)
    
class custConv1d(tf.keras.layers.Layer):
    """Parallel Conv1D branches with different kernel sizes, center-cropped and concatenated.

    Each kernel size in ``arr`` gets a Conv1D with ``filters`` outputs,
    'valid' padding and no activation. Branch outputs are center-cropped along
    the time axis (-2) to the length of the largest-kernel branch, then
    concatenated on the last axis (filters * len(arr) channels).

    Args:
        arr: list of kernel sizes.
        meanInit: if True, kernels are initialized to the constant 1 / kernel_size.
        filters: number of filters per branch.
    """
    def __init__(self, arr, meanInit=True, filters=8):
        super().__init__()
        self.arr = arr
        self.maxSlidingWindow = max(self.arr)
        if meanInit:
            self.convArr = [tf.keras.layers.Conv1D(filters=filters, kernel_size=(kernelS), padding='valid', kernel_initializer=tf.keras.initializers.Constant(1.0 / kernelS)) for kernelS in self.arr]
        else:
            self.convArr = [tf.keras.layers.Conv1D(filters=filters, kernel_size=(kernelS), padding='valid') for kernelS in self.arr]
    def call(self, x, training = None):
        outDim = x.shape[-2] - self.maxSlidingWindow + 1
        out = []
        for kernelS, conv in zip(self.arr, self.convArr):
            x0 = conv(x)
            thisOutDim = x.shape[-2] - kernelS + 1
            # Center-crop so every branch has the shortest ('valid') length.
            startIdx = (thisOutDim - outDim) // 2
            out.append(x0[:, startIdx:startIdx + outDim, :])
        # tf.concat instead of instantiating a new Concatenate layer on every call.
        return tf.concat(out, axis=-1)

 
# Batch norm causes massive difference between train and inference time!!!!
# during training mean & std dev of batch are used to normalize
# during inference a rolling mean&stdDev is used for normalization
def cnn():
    """Conv baseline: four kernel-50 custConv1d stages with BatchNorm, then a BatchNorm MLP.

    Only channel 0 of the (timepoints, wavelengths) input is used.
    NOTE(review): create_dataset yields (225, 13) / (225, 14) examples, not
    (timepoints, wavelengths). Confirm the intended input before training.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, wavelengths))
    x = inp[:,:,0:1]
    #x = custConv1d([10,30,50])(x)
    x = custConv1d([50], meanInit=False, filters=32)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = custConv1d([50], meanInit=False, filters=64)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = custConv1d([50], meanInit=False, filters=128)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = custConv1d([50], meanInit=False, filters=256)(x)
    x = tf.keras.layers.BatchNormalization()(x)

    # (batch, time, filters) -> (batch, filters, time) before flattening.
    x = Reshape11()(x)
    x = tf.keras.layers.Flatten()(x)
    #x = tf.keras.layers.Dense(1000, activation='relu')(x)
    x = tf.keras.layers.Dense(500, activation='relu')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dense(100, activation='relu')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model

# with layer normalization we don't have so much overfitting, however we can only fit the mean
def cnn2():
    """Same topology as cnn(), with LayerNormalization instead of BatchNormalization.

    Only channel 0 of the (timepoints, wavelengths) input is used.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, wavelengths))
    x = inp[:,:,0:1]
    #x = custConv1d([10,30,50])(x)
    x = custConv1d([50], meanInit=False, filters=32)(x)
    x = tf.keras.layers.LayerNormalization()(x)
    x = custConv1d([50], meanInit=False, filters=64)(x)
    x = tf.keras.layers.LayerNormalization()(x)
    x = custConv1d([50], meanInit=False, filters=128)(x)
    x = tf.keras.layers.LayerNormalization()(x)
    x = custConv1d([50], meanInit=False, filters=256)(x)
    x = tf.keras.layers.LayerNormalization()(x)

    # (batch, time, filters) -> (batch, filters, time) before flattening.
    x = Reshape11()(x)
    x = tf.keras.layers.Flatten()(x)
    #x = tf.keras.layers.Dense(1000, activation='relu')(x)
    x = tf.keras.layers.Dense(500, activation='relu')(x)
    x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.Dense(100, activation='relu')(x)
    x = tf.keras.layers.LayerNormalization()(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model

# can't fit due to vanishing gradients
def cnn3():
    """Four 32-filter kernel-50 custConv1d stages without normalization, then an MLP.

    Only channel 0 of the (timepoints, wavelengths) input is used. The conv
    stages have no activation, so the stack is linear up to the first ReLU Dense.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, wavelengths))
    x = inp[:,:,0:1]
    #x = custConv1d([10,30,50])(x)
    x = custConv1d([50], meanInit=False, filters=32)(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x = custConv1d([50], meanInit=False, filters=32)(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x = custConv1d([50], meanInit=False, filters=32)(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x = custConv1d([50], meanInit=False, filters=32)(x)
    #x = tf.keras.layers.LayerNormalization()(x)

    # (batch, time, filters) -> (batch, filters, time) before flattening.
    x = Reshape11()(x)
    x = tf.keras.layers.Flatten()(x)
    #x = tf.keras.layers.Dense(1000, activation='relu')(x)
    x = tf.keras.layers.Dense(500, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.Dense(100, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model

# we can fit the small dataset, but overfit quite heavily
def cnn4():
    """Three kernel-50 custConv1d stages (8, 32, 64 filters), Dense(1) collapse, small MLP head.

    Only channel 0 of the (timepoints, wavelengths) input is used.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, wavelengths))
    x = inp[:,:,0:1]
    x = custConv1d([50], meanInit=False, filters=8)(x)
    for i in range(2):
        x = custConv1d([50], meanInit=False, filters=32* 2**(i))(x)
        #x = tf.keras.layers.MaxPool1D()(x)

    x = tf.keras.layers.Dense(1)(x) # collapse filters to 1 signal -> helps
    x = Reshape11()(x)
    x = tf.keras.layers.Flatten()(x)
    #x = tf.keras.layers.Dense(1000, activation='relu')(x)
    #x = tf.keras.layers.Dense(500, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.Dense(20, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model

def cnn5():
    """Moving-average-initialized smoothing conv, two learned conv stages, Dense(1) collapse, small MLP.

    Only channel 0 of the (timepoints, wavelengths) input is used.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, wavelengths))
    x = inp[:,:,0:1]
    x = custConv1d([50], meanInit=True, filters=1)(x)
    for i in range(2):
        x = custConv1d([50], meanInit=False, filters=8* 2**(i))(x)
        #x = tf.keras.layers.MaxPool1D()(x)

    x = tf.keras.layers.Dense(1)(x) # collapse filters to 1 signal -> helps
    #x = Reshape11()(x)
    x = tf.keras.layers.Flatten()(x)
    #x = tf.keras.layers.Dense(1000, activation='relu')(x)
    #x = tf.keras.layers.Dense(500, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.Dense(20, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model

# network is capable, but doesn't generalize well -> add more features to it
def cnn6():
    """Multi-window smoothing (kernels 50..10), two multi-kernel conv + max-pool stages, MLP head.

    Only channel 0 of the (timepoints, wavelengths) input is used.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, wavelengths))
    x = inp[:,:,0:1]
    x = custConv1d([50,40,30,20,10], meanInit=True, filters=1)(x)
    for i in range(2):
        x = custConv1d([5,10,3], meanInit=False, filters=8* 2**(i))(x)
        x = tf.keras.layers.MaxPool1D()(x)

    x = tf.keras.layers.Dense(1)(x) # collapse filters to 1 signal -> helps
    #x = Reshape11()(x)
    #x = GLUMlp(dim_expand=x.shape[2]*2,dim=x.shape[2])(x) # select timepoints
    x = tf.keras.layers.Flatten()(x)
    #x = tf.keras.layers.Dense(1000, activation='relu')(x)
    #x = tf.keras.layers.Dense(500, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    for _ in range(3):
        x = tf.keras.layers.Dense(37, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model

embed = 13  # channels per example in pipeline 1 (1 input + 3 per rolling window x 4 windows)
def cnn7():
    """Five 64-filter kernel-40 conv stages on channels 1..12, two linear Dense layers, deep ReLU MLP.

    Channel 0 (the un-smoothed input in calcRollingMeans' output) is dropped.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, embed))
    x = inp[:,:,1:13]
    x = tf.keras.layers.LayerNormalization()(x)
    for i in range(5):
        x = custConv1d([40], meanInit=False, filters=64)(x)
        #x = tf.keras.layers.LayerNormalization()(x)
        #x = tf.keras.layers.Dense(50)(x)
        #reduce time dimension
        #x = Reshape11()(x)
        #x = tf.keras.layers.Dense(x.shape[2]-20)(x)
        #x = Reshape11()(x)
        #x = tf.keras.layers.MaxPool1D()(x)

    x = tf.keras.layers.Dense(64)(x) # per-timestep linear mix of the 64 conv filters (no activation)
    #x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.Dense(30)(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    #x = Reshape11()(x)
    #x = GLUMlp(dim_expand=x.shape[2]*2,dim=x.shape[2])(x) # select timepoints
    x = tf.keras.layers.Flatten()(x)
    for i in range(3):
        x = tf.keras.layers.Dense(1000, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.Dense(300, activation='relu')(x)
    x = tf.keras.layers.Dense(30, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    #for _ in range(3):
    #    x = tf.keras.layers.Dense(37, activation='relu')(x)
    #x = tf.keras.layers.LayerNormalization()(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model  

embed = 13  # channels per example in pipeline 1
def cnn8():
    """Pure MLP baseline on channels 1..12: Flatten -> Dense(700, relu) -> Dense(100) -> Dense(1).

    Dense(100) has no activation, so the last two layers compose linearly.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, embed))
    x = inp[:,:,1:13]
    #x = tf.keras.layers.LayerNormalization()(x)
   # for i in range(5):
   #     x = custConv1d([5], meanInit=False, filters=10)(x)
   # x = custConv1d([5], meanInit=False, filters=5)(x)
    x = tf.keras.layers.Flatten()(x)
    for i in range(1):
        x = tf.keras.layers.Dense(700, activation='relu')(x)
    x = tf.keras.layers.Dense(100)(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model  

embed = 13  # channels per example in pipeline 1
def cnn9():
    """Three transf1d blocks (attention, axis swap, Dense(26)), Dense(1) collapse, linear head.

    transf1d's constructor takes (inputDim, num_heads, feed_forward_dim,
    reshape) and has no ``embed_dim`` parameter. Passing ``embed_dim=13``
    raised a TypeError when the model was built, so it is not passed here.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, embed))
    x = inp#[:,:,1:13]
    #x = tf.keras.layers.LayerNormalization()(x)
    for i in range(3):
        x = transf1d(inputDim=225, num_heads=4, feed_forward_dim=2*13)(x)
    #x = custConv1d([50], meanInit=False, filters=5)(x)
    x = tf.keras.layers.Dense(1)(x)
    x = tf.keras.layers.Flatten()(x)
    #for i in range(1):
    #    x = tf.keras.layers.Dense(700, activation='relu')(x)
    x = tf.keras.layers.Dense(100)(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model

embed = 13  # channels per example in pipeline 1
def cnn10():
    """Three transf1d blocks without the axis swap (Dense(26) on the channel axis), then a linear head.

    Returns:
        An uncompiled tf.keras.Model with one linear output.
    """
    inp = tf.keras.Input(shape=(timepoints, embed))
    x = inp#[:,:,1:13]
    #x = tf.keras.layers.LayerNormalization()(x)
    for i in range(3):
        x = transf1d(inputDim=225, num_heads=4, feed_forward_dim=2*13, reshape=False)(x)
    #x = custConv1d([50], meanInit=False, filters=5)(x)
    x = tf.keras.layers.Dense(1)(x)  # one value per time step
    x = tf.keras.layers.Flatten()(x)
    #for i in range(1):
    #    x = tf.keras.layers.Dense(700, activation='relu')(x)
    #x = tf.keras.layers.Dense(100)(x)
    x_pred = tf.keras.layers.Dense(1, activation='linear')(x)

    model = tf.keras.Model(inp, x_pred)
    return model

def cnn11():
    """Four stacked ``custConv1d`` layers (4 filters each) plus a dense head.

    The input has shape ``(timepoints, 14)``; the ``1:14`` slice drops
    channel 0. Each conv is configured with ``[50]`` and ``meanInit=False``.
    The stack is followed by Dense(2) over the last axis, Flatten,
    Dense(50) (no activation) and a linear Dense(1) output. Returns an
    uncompiled ``tf.keras.Model``.
    """
    inputs = tf.keras.Input(shape=(timepoints, 14))
    features = inputs[:, :, 1:14]  # keep channels 1..13
    for _layer in range(4):
        features = custConv1d([50], meanInit=False, filters=4)(features)
    features = tf.keras.layers.Dense(2)(features)
    features = tf.keras.layers.Flatten()(features)
    features = tf.keras.layers.Dense(50)(features)
    prediction = tf.keras.layers.Dense(1, activation='linear')(features)
    return tf.keras.Model(inputs, prediction)

def cnn12():
    """Six stacked ``custConv1d`` layers (4 filters each) and a linear head.

    The input has shape ``(timepoints, 14)``; the ``1:14`` slice drops
    channel 0. The conv stack is configured with ``[50]`` four times, then
    ``[20]``, then ``[5]`` (all ``meanInit=False``). A Dense(2) projection over
    the last axis is flattened straight into the linear Dense(1) output.
    Returns an uncompiled ``tf.keras.Model``.
    """
    inputs = tf.keras.Input(shape=(timepoints, 14))
    features = inputs[:, :, 1:14]  # keep channels 1..13
    # A fresh single-element list per layer, as in a hand-unrolled stack.
    for size in (50, 50, 50, 50, 20, 5):
        features = custConv1d([size], meanInit=False, filters=4)(features)
    features = tf.keras.layers.Dense(2)(features)
    features = tf.keras.layers.Flatten()(features)
    prediction = tf.keras.layers.Dense(1, activation='linear')(features)
    return tf.keras.Model(inputs, prediction)

def cnn13():
    """One ``custConv1d`` layer (1 filter) followed by a two-layer dense head.

    The input has shape ``(timepoints, 14)``; the ``1:14`` slice drops
    channel 0. The conv is configured with ``[50]`` and ``meanInit=False``; its
    flattened output feeds Dense(50) (no activation) and a linear Dense(1).
    Returns an uncompiled ``tf.keras.Model``.
    """
    inputs = tf.keras.Input(shape=(timepoints, 14))
    kept = inputs[:, :, 1:14]  # keep channels 1..13
    conv_out = custConv1d([50], meanInit=False, filters=1)(kept)
    flat = tf.keras.layers.Flatten()(conv_out)
    hidden = tf.keras.layers.Dense(50)(flat)
    prediction = tf.keras.layers.Dense(1, activation='linear')(hidden)
    return tf.keras.Model(inputs, prediction)

# Architecture under test. Other builders tried here: cnn2D, squeezeformer,
# fcn, cnn1, transformer, cnnAttentin, singleWL, cnn5converges.
model = cnn13()
model.summary()

In [None]:
# Inspect the weights of the layer at index 4 of model.layers.
model.get_layer(index=4).get_weights()

In [None]:
# Save the current model's weights.
# NOTE(review): the filename says cnn12, while the model-selection cell builds
# cnn13() — confirm which architecture these weights belong to.
model.save_weights('fullPred_singleWL_cnn12_010_021_newDataPipeline.weights.h5')

In [None]:
# Cache one training batch and one test batch for the sanity checks below;
# `batch`, `out` and `test_batch1` are reused by later cells.
batch_iter = iter(train_dataset)
batch=next(batch_iter)
#batch1=next(batch_iter)
#batch2=next(batch_iter)
out = model(batch[0])  # forward pass on the cached training batch
dataset_iterator = iter(test_dataset)
test_batch1 = next(dataset_iterator)
#test_batch2 = next(dataset_iterator)
# Display dtypes and shapes of inputs, targets and predictions.
batch[0].dtype ,batch[1].dtype, out.dtype,batch[0].shape ,batch[1].shape, out.shape

In [None]:
# Periodic weight checkpoints. The directory is relative to the working
# directory (the data-loading code also reads relative paths such as
# 'train/'), and can be overridden via ARIEL_CHECKPOINT_DIR, instead of a
# hard-coded, user-specific absolute path.
# NOTE: this callback is not passed to any model.fit call in this section
# (`callbacks=` is commented out there).
CHECKPOINT_DIR = os.environ.get('ARIEL_CHECKPOINT_DIR', 'training_full_model')
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(
        CHECKPOINT_DIR, "single_wl_deviations_model-{epoch:02d}.weights.h5"
    ),
    save_weights_only=True,  # weights only, matching the .weights.h5 suffix
    save_freq=300 * 4,       # integer save_freq counts batches, not epochs
    verbose=1,
)

In [None]:
# Compile for training: MAE loss, MSE reported as a metric, Adam at lr 1e-5.
# (Alternatives tried earlier: loss='mse' or maeSingleWL, and the metric
# log_likelihood_maxScaling.)
tf.random.set_seed(42)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
model.compile(
    optimizer=optimizer,
    loss='mae',
    metrics=['mse'],
)

# Notes from earlier runs:
# - batch normalization essential for gradient to travel downstream!
# - with batch of 12 we converge well to mae ~3.3, mse 57 after ~3000 epochs

In [None]:
# Long run: 1000 epochs on the small training subset, validating on the
# cached test batch. (Alternatives: train_dataset / test_dataset.)
history = model.fit(
    small_train,
    validation_data=test_batch1,
    batch_size=batch_size,
    epochs=1000,
)

In [None]:
# One batch from the small training set: MAE on the first target column,
# plus the first 20 predictions and targets for eyeballing.
b = next(iter(small_train))
outs = model.predict(b[0])
(
    np.abs(outs - b[1][:, 0:1].numpy()).mean(),
    outs[:20],
    b[1][:20, 0:1],
)

In [None]:
# Predictions and first-column targets of the same batch, side by side.
(outs[:, 0], b[1][:, 0].numpy())

In [None]:
# Same check on the cached test batch.
outs = model.predict(test_batch1[0])
(
    np.abs(outs[:, 0] - test_batch1[1][:, 0].numpy()).mean(),
    outs[:20],
    test_batch1[1][:20, 0:1],
)

In [None]:
# Short continuation run: 100 more epochs on the small training subset.
history = model.fit(
    small_train,
    validation_data=test_batch1,
    batch_size=batch_size,
    epochs=100,
)

In [None]:
# Save the current model's weights (relative path, i.e. the working directory).
model.save_weights('deviationModelCnn56_5_75.weights.h5')

# Investigate predictions and errors

In [None]:
def calcStats(b, plot=True, display=False, predictor=None):
    """Report MAE/MSE on one batch and optionally plot predictions vs. targets.

    Parameters
    ----------
    b : tuple
        ``(inputs, targets)`` batch. Only the first target column is scored,
        matching the single-output (Dense(1)) models; 1-D targets are used
        as-is.
    plot : bool
        If True (default), show a plotly scatter and a matplotlib histogram of
        targets vs. predictions.
    display : bool
        If True, print the first 10 predictions/targets, raw and multiplied by
        the global ``maxLabels``.
    predictor : object, optional
        Anything with a ``predict`` method; defaults to the global ``model``.

    Returns
    -------
    tuple of float
        ``(mae, mse)`` over the batch.
    """
    if predictor is None:
        predictor = model
    outputs = predictor.predict(b[0])
    print(outputs.shape)
    pred = np.asarray(outputs)[:, 0]

    # Reduce targets to one 1-D series. Subtracting an (N,) prediction from an
    # (N, 1) target would otherwise broadcast to an (N, N) grid and corrupt
    # the metrics.
    targets = np.asarray(b[1])
    if targets.ndim > 1:
        targets = targets[:, 0]

    if display:
        print(pred[:10], targets[:10])
        print(pred[:10] * maxLabels, targets[:10] * maxLabels)

    err = pred - targets
    mae = float(np.mean(np.abs(err)))
    mse = float(np.mean(err ** 2))
    print('mae', mae, 'mse', mse)
    print(b[1].shape)

    if plot:
        fig = go.Figure()
        fig.add_trace(go.Scatter(y=targets, mode='markers', name='gt', marker=dict(size=3)))
        fig.add_trace(go.Scatter(y=pred, mode='markers', name='pred', marker=dict(size=3)))
        fig.show()

        # Histogram of this batch's targets and predictions (not of globals).
        plt.hist(targets, bins=30, edgecolor='blue', alpha=0.7, label='gt')
        plt.hist(pred, bins=30, edgecolor='red', alpha=0.7, label='pred')
        plt.title('Histogram of Data')
        plt.xlabel('Value')
        plt.ylabel('Frequency')
        plt.legend()
        plt.show()

    return mae, mse

In [None]:
# Stats on a fresh batch from the small training set.
calcStats(b=next(iter(small_train)))

In [None]:
# Stats on the cached training batch.
calcStats(b=batch)

In [None]:
# Stats on the cached test batch.
calcStats(b=test_batch1)

In [None]:
# MAE/MSE of the model on the cached training batch against the first target
# column. mae/mse multiply both sides by 100*maxLabels, mae1/mse1 by maxLabels
# only; mse and mse1 are computed but not part of the displayed tuple.
# NOTE(review): the label normalisation in the data-loading code divides by
# maxTrainLabels, not maxLabels — confirm which factor should undo it.
outputs = model.predict(batch[0])

mae = np.sum(np.abs(outputs*100*maxLabels-batch[1][:,0:1]*100*maxLabels)) / outputs.shape[0]
mse = np.sum(np.abs(outputs*100*maxLabels-batch[1][:,0:1]*100*maxLabels)**2) / outputs.shape[0]
mae1 = np.sum(np.abs(outputs*maxLabels-batch[1][:,0:1]*maxLabels)) / outputs.shape[0]
mse1 = np.sum(np.abs(outputs*maxLabels-batch[1][:,0:1]*maxLabels)**2) / outputs.shape[0]
mae,mae1, outputs, batch[1][:,0:1]

In [None]:
# Same metrics as the training-batch cell, on a test batch: mae/mse multiply
# both sides by 100*maxLabels, mae1/mse1 by maxLabels only. All four metrics
# use the test predictions `outputst` and divide by the test batch size.
# NOTE(review): `test_batch` is not assigned in this section (the cached test
# batch is `test_batch1`) — confirm it is defined earlier in the notebook.
outputst = model.predict(test_batch[0])

mae = np.sum(np.abs(outputst*100*maxLabels-test_batch[1][:,0:1]*100*maxLabels)) / outputst.shape[0]
mse = np.sum(np.abs(outputst*100*maxLabels-test_batch[1][:,0:1]*100*maxLabels)**2) / outputst.shape[0]
mae1 = np.sum(np.abs(outputst*maxLabels-test_batch[1][:,0:1]*maxLabels)) / outputst.shape[0]
mse1 = np.sum(np.abs(outputst*maxLabels-test_batch[1][:,0:1]*maxLabels)**2) / outputst.shape[0]
mae,mae1, outputst, test_batch[1][:,0:1]

In [None]:
# Shape of the cached training batch's inputs.
batch[0].shape

In [None]:
# Visualize input data: overlay last-axis index 0 of up to the first 100
# samples of the cached training batch (capped so smaller batches don't raise
# IndexError).
fig = go.Figure()
for i in range(min(100, batch[0].shape[0])):
    fig.add_trace(go.Scatter(y=batch[0][i, :, 0], mode='markers', name=f'{i}', marker=dict(size=3)))
fig.show()


In [None]:
# Predictions vs. targets on the cached training batch: scatter + histograms.
out = model.predict(batch[0])
# Reduce targets to one 1-D series (first column), as the MAE cells do; this
# handles both (N,) and (N, k) targets instead of passing a 2-D array to the
# plotly scatter and the histogram.
train_targets = np.asarray(batch[1]).reshape(len(out), -1)[:, 0]

fig = go.Figure()
fig.add_trace(go.Scatter(y=train_targets, mode='markers', name='gt', marker=dict(size=3)))
fig.add_trace(go.Scatter(y=out[:, 0], mode='markers', name='pred', marker=dict(size=3)))
fig.show()

plt.hist(train_targets, bins=30, edgecolor='blue', alpha=0.7, label='gt')
plt.hist(out[:, 0], bins=30, edgecolor='red', alpha=0.7, label='pred')
plt.title('Histogram of Data')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.legend()
plt.show()

In [None]:
# Predictions vs. targets on the cached test batch: scatter + histograms.
out = model.predict(test_batch1[0])
# Reduce targets to one 1-D series (first column), as the MAE cells do; this
# handles both (N,) and (N, k) targets instead of passing a 2-D array to the
# plotly scatter and the histogram.
test_targets = np.asarray(test_batch1[1]).reshape(len(out), -1)[:, 0]

fig = go.Figure()
fig.add_trace(go.Scatter(y=test_targets, mode='markers', name='gt', marker=dict(size=3)))
fig.add_trace(go.Scatter(y=out[:, 0], mode='markers', name='pred', marker=dict(size=3)))
fig.show()

plt.hist(test_targets, bins=30, edgecolor='blue', alpha=0.7, label='gt')
plt.hist(out[:, 0], bins=30, edgecolor='red', alpha=0.7, label='pred')
plt.title('Histogram of Data')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.legend()
plt.show()

In [None]:
# Shape of the cached test batch's targets.
test_batch1[1].shape

# Visualize intermediate layer activations

In [None]:
# Feed one training sample through the model one layer at a time and plot
# each intermediate activation.
# NOTE(review): the 4-index slice below assumes a 4-D input batch (it selects
# last-axis index 1); the other visualisation cells index batch[0] as 3-D,
# and with a 3-D batch this line raises — confirm the intended layout.
batchid = 2  # row of `batch` to inspect
wl = 0       # index plotted along axis 2 (presumably a wavelength index)
x = batch[0][batchid:batchid+1,:,:,1]  
print(x.shape)  
fig = plt.figure()
plt.plot(x[0,:,wl])
plt.title(f'input')
plt.show()

# Apply model.layers[1:] in order; layers[0] is skipped (presumably the
# InputLayer — confirm for the active model).
for layers in range(len(model.layers)-1):
    x = model.layers[layers+1](x)
    print(x.shape)
    
    if len(x.shape)>=3:
        if len(x.shape) == 4:
            # 4-D activation: one figure per last-axis index.
            for i in range(x.shape[-1]):
                fig = plt.figure()
                plt.plot(x[0,:,wl,i])
                plt.title(f'layer {layers+1}')
                plt.show()
        else:
            fig = plt.figure()
            if x.shape[2] == 283:
                # 283 matches the 0:283 slice taken in the data-loading code.
                plt.plot(x[0,:,wl])
            else:
                if x.shape[2] == 1:
                    # Single-feature 3-D output: print it next to the first target column.
                    print(x[:,wl,:], batch[1][batchid:batchid+1,0:1])
                else:
                    # NOTE(review): indexes axis 1 with `wl` and plots axis 2,
                    # the transpose of the 283-wide branch — confirm intended.
                    plt.plot(x[0,wl,:])
            plt.title(f'layer {layers+1}')
            plt.show()
    else:
        if x.shape[1] >1:
            # 2-D activation wider than 1: plot the feature vector.
            fig=plt.figure()
            plt.plot(x[0])
            plt.title(f'layer{layers+1}')
            plt.show()
        else:
            # (batch, 1) output: print it next to the first target column.
            print(x, batch[1][batchid:batchid+1,0:1])
#print(model.layers[2].get_weights())

In [None]:
def print1Val(batchRowId, wl, batch, only1PerGraph=True, channels=slice(1, 14)):
    """Plot one sample's input and its activation after every model layer.

    Parameters
    ----------
    batchRowId : int
        Row of ``batch`` to visualise.
    wl : int
        Target column printed next to the final scalar prediction.
    batch : tuple
        ``(inputs, targets)``; inputs are indexed as (sample, axis1, axis2).
    only1PerGraph : bool
        If True, plot only the first last-axis channel of 3-D activations.
    channels : slice
        Last-axis channels kept before feeding ``model.layers[1:]``. The
        default mirrors the ``inp[:, :, 1:14]`` slice in cnn11-cnn13, which this
        loop does not replay (presumably because it is not in
        ``model.layers`` — confirm for the active model).
    """
    x = batch[0][batchRowId:batchRowId + 1, :, :]
    print(x.shape)
    plt.figure()
    for ch in range(x.shape[2]):
        plt.plot(x[0, :, ch])
    plt.title('input')
    plt.show()

    x = x[:, :, channels]
    for layer_idx in range(1, len(model.layers)):
        print(x.shape)
        x = model.layers[layer_idx](x)
        print(x.shape)
        # Check the rank first so 1-D outputs can't raise on x.shape[1].
        if len(x.shape) == 2 and x.shape[0] == 1 and x.shape[1] == 1:
            # Final scalar prediction: show it next to this row's target.
            print(x, batch[1][batchRowId:batchRowId + 1, wl])
            continue

        plt.figure()
        if len(x.shape) == 3:
            for ch in range(x.shape[2]):
                plt.plot(np.reshape(x[0, :, ch], (-1)))
                if only1PerGraph:
                    break
        else:
            plt.plot(x[0, :])
        # Separate channel variable above keeps this title on the layer index.
        plt.title(f'layer {layer_idx}')
        plt.show()
    #print(model.layers[2].get_weights())

In [None]:
# Layer-by-layer view of the first test sample (first channel per plot).
print1Val(batchRowId=0, wl=0, batch=test_batch1, only1PerGraph=True)

In [None]:
# Layer-by-layer view of the first training sample.
print1Val(batchRowId=0, wl=0, batch=batch)

In [None]:
# Replot the raw input of sample `batchid` at index `wl` on axis 2 and
# last-axis index 1 (`batchid`/`wl` come from the layer-visualisation cell).
# NOTE(review): this uses 4 indices, so it assumes a 4-D input batch; the
# input-visualisation and print1Val cells index batch[0] as 3-D, in which
# case this raises — confirm the intended layout.
fig = plt.figure()
plt.plot(batch[0][batchid,:,wl,1]  )
plt.title(f'input')
plt.show()

In [None]:
# Raw values of the slice plotted above.
# NOTE(review): 4-index slice — assumes a 4-D input batch; confirm.
batch[0][batchid:batchid+1,:,wl,1]

In [None]:
# Print the cached training targets and plot last-axis index 0 of the first
# four rows of `x`.
# NOTE(review): `x` is whatever a previous cell last assigned (hidden state).
# After the layer-visualisation cell it presumably has batch size 1, so rows
# 1-3 would raise IndexError — confirm which array is intended here.
print(batch[1])
for i in range(4):
    fig = plt.figure()
    plt.plot(x[i, :, 0])
    plt.title(f'x[{i}, :, 0]')
    # pyplot.show() instead of Figure.show(): the latter only warns on the
    # non-GUI inline backend; this matches the other plotting cells.
    plt.show()