In [None]:
# Number of threads requested both from OpenMP and from TensorFlow's thread pools
threads = 16
import os
# Exported before the tensorflow import below
os.environ['OMP_NUM_THREADS']=str(threads)
import tensorflow as tf

# TensorFlow needs explicit config calls
tf.config.threading.set_inter_op_parallelism_threads(threads)
tf.config.threading.set_intra_op_parallelism_threads(threads)

In [None]:
# Input conformation (PDB); also used as the topology when the trajectory is loaded
#conf = "alaninedipeptide_H.pdb"
conf = "trpcage_correct.pdb"

# Input trajectory
# Atom numbering must be consistent with {conf}

#traj = "alaninedipeptide_reduced.xtc"
traj = "trpcage_red.xtc"

# Input topology
# Expected to be produced with
#    gmx pdb2gmx -f {conf} -p {topol} -n {index} -o {gro}

# Gromacs changes atom numbering, so the index file must be generated and used as well
# The gro file is used to generate inverse indexing for plumed.dat

#topol = "topol.top"
topol = "topol_correct.top"
index = 'index_correct.ndx'
gro = 'trpcage_correct.gro'

In [None]:
import mdtraj as md
import numpy as np
import asmsa
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt

In [None]:
# Load the trajectory, taking the topology from the reference conformation
tr = md.load(traj,top=conf)
# Atoms used for the alignment: C-alpha only
idx=tr[0].top.select("name CA")
#idx=tr[0].top.select("element != H")
# Superpose every frame onto the first one
tr.superpose(tr[0],atom_indices=idx)
# Move the leading (frame) axis of tr.xyz to the end
geom = np.moveaxis(tr.xyz ,0,-1)
geom.shape

In [None]:
# Density parameter handed to asmsa.NBDistancesSparse
density = 2 # integer in [1, n_atoms-1]

# geom.shape[0] is the first axis of geom (presumably the atom count -- confirm against mdtraj's xyz layout)
sparse_dists = asmsa.NBDistancesSparse(geom.shape[0], density=density)
# Molecule description built from the PDB, the Gromacs topology and the index file,
# with the sparse distance map as its only feature map
mol = asmsa.Molecule(pdb=conf,top=topol,ndx=index,fms=[sparse_dists])

In [None]:
# Internal coordinates of the whole trajectory, transposed
# (presumably so that rows are frames and columns are features -- confirm against asmsa.Molecule.intcoord)
X_train = mol.intcoord(geom).T
X_train.shape

In [None]:
from tensorflow import keras

class AAEMultiModelNaive(keras.models.Model):
    """Naive multi-model wrapper: one training step runs train_step of every wrapped model.

    The wrapped models stay completely separate; this class only loops over them.
    """
    
    def __init__(self,models):
        """Store the list of models to be trained side by side."""
        super().__init__()
        self.models = models
        
    def compile(self,*args,**kwargs):
        """Compile the wrapper itself and forward all arguments to every wrapped model."""
        super().compile()
        for m in self.models:
            # Unpack the arguments; passing `args` and `kwargs` as they are would
            # hand a tuple and a dict to the model as its first two positional arguments.
            m.compile(*args,**kwargs)
    
    @tf.function
    def train_step(self,batch):
        """Run one training step of every wrapped model on the same batch.

        NOTE: the per-model losses are not collected yet (see the commented lines),
        so all four reported values are the placeholder 0.
        """
        ael = [0]
        dl = [0]
        for m in self.models:
            res = m.train_step(batch)
#            ael.append(res['ae_loss'].numpy())
#            dl.append(res['d_loss'].numpy())
        
        return { 'ae_loss_min': min(ael), 'ae_loss_max': max(ael), 'd_loss_min': min(dl), 'd_loss_max': max(dl) }

In [None]:
# Baseline hyperparameters; individual experiments copy this dict and override single entries
_default_hp = dict(
    batch_size=64,
    activation='relu',
    ae_number_of_layers=2,
    disc_number_of_layers=2,
    ae_neuron_number_seed=32,
    disc_neuron_number_seed=32,
    ae_loss_fn='MeanSquaredError',
    disc_loss_fn='BinaryCrossentropy',
    optimizer='Adam',
)

In [None]:
# One asmsa.AAEModel per combination of encoder depth (2 or 3 layers)
# and seed width (32, 40, ..., 128): 2 x 13 = 26 models
mods = []
for ael in [2,3]:
    for an in range(32,129,8):
        # copy so that the shared defaults stay untouched
        hp = _default_hp.copy()
        hp['ae_number_of_layers'] = ael
        hp['ae_neuron_number_seed'] = an
        mods.append(asmsa.AAEModel((X_train.shape[1],),hp=hp))
        
# Wrap all of them so that a single fit() call drives every model
mmod = AAEMultiModelNaive(mods)
mmod.compile()

        

In [None]:
# Training dataset: shuffle buffer of 2048 samples, fixed-size batches (the incomplete last batch is dropped)
ds = tf.data.Dataset.from_tensor_slices(X_train).shuffle(2048).batch(_default_hp['batch_size'],drop_remainder=True)

In [None]:
mmod.fit(ds,epochs=50)

In [None]:
len(mods)

In [None]:
mods[0].fit(ds,epochs=50)

In [None]:
def _compute_number_of_neurons(params,ae):
    """Return the list of layer widths: the seed width followed by successive halvings.

    params -- hyperparameter dict; when `ae` is truthy the keys 'ae_neuron_number_seed'
              and 'ae_number_of_layers' are read, otherwise their 'disc_' counterparts
    ae     -- selects autoencoder (truthy) or discriminator (falsy) settings

    The result has number_of_layers + 1 entries; each entry is int(previous / 2).
    """
    prefix = 'ae' if ae else 'disc'
    widths = [ params[f'{prefix}_neuron_number_seed'] ]

    for _ in range(params[f'{prefix}_number_of_layers']):
        widths.append(int(widths[-1] / 2))

    return widths


In [None]:
class AAEMultiModel(keras.models.Model):
    """Many independent autoencoders on one shared input, trained in a single step.

    For every seed width in range(32,129,8) four replicas are built. All latent outputs
    are stacked into `self.latent` and all reconstructions into `self.out`, both along a
    new axis 1, so that one forward pass evaluates every autoencoder.
    """
    
    def __init__(self,molecule_shape,latent_dim=2,prior='normal',hp=_default_hp):
        """Build all encoder/decoder stacks.

        molecule_shape -- shape of a single input sample (tuple)
        latent_dim     -- width of each encoder output
        prior          -- accepted but not used by this class
        hp             -- hyperparameter dict; 'activation' and 'ae_number_of_layers' are read,
                          'ae_neuron_number_seed' is replaced by the scanned width
        """
        super().__init__()
        self.inp = keras.Input(shape=molecule_shape,name='common.input')
        latent = []
        out = []
        
        # XXX: scanned widths and number of replicas are hard-coded
        for an in range(32,129,8):
            # Layer widths for this seed width. Previously `an` only appeared in the layer
            # names, so every sub-model ended up with the widths of `hp` itself.
            # A merged copy is used so the caller's (or the default) dict is not modified.
            neurons = _compute_number_of_neurons({**hp, 'ae_neuron_number_seed': an},ae=True)
            for r in range(4):
                m = f"seed{r}_{an}"

                # encoder
                l = self.inp
                i = 1
                for n in neurons:
                    l = keras.layers.Dense(n, activation=hp['activation'], name=f'enc.{m}.{i}')(l)
                    l = keras.layers.BatchNormalization(momentum=0.8)(l)
                    i += 1
                l = keras.layers.Dense(latent_dim,activation='linear',name=f'enc.{m}.output')(l)
                latent.append(l)

                # decoder mirrors the encoder widths
                i = 1
                for n in reversed(neurons):
                    l = keras.layers.Dense(n, activation=hp['activation'], name=f'dec.{m}.{i}')(l)
                    l = keras.layers.BatchNormalization(momentum=0.8)(l)
                    i += 1
                l = keras.layers.Dense(np.prod(molecule_shape), activation=hp['activation'],name=f'dec.{m}.output')(l)
                l = keras.layers.Reshape(molecule_shape,name=f'dec.{m}.reshape')(l)
                out.append(l)
        
        self.latent = tf.stack(latent,axis=1,name='all.latent')
        self.out = tf.stack(out,axis=1,name='all.output')
        self.n_models = len(out)
        
        # all encoders / all autoencoders as single functional models
        self.encs = keras.Model(inputs=self.inp,outputs=self.latent)
        self.aes = keras.Model(inputs=self.inp,outputs=self.out)
            
    def compile(self,*args,**kwargs):
        """Compile this model and both inner models with the given arguments.

        The loss object and the optimizer are then set to fixed values (XXX),
        regardless of what was passed in.
        """
        super().compile(*args,**kwargs)
        self.encs.compile(*args,**kwargs)
        self.aes.compile(*args,**kwargs)
        
        # XXX
        self.ae_loss = keras.losses.MeanSquaredError()
        self.optimizer = keras.optimizers.legacy.Adagrad(learning_rate=0.0002)
    
    @tf.function
    def train_step(self,batch):
        """One optimizer step for all autoencoders at once.

        Returns the smallest and the largest per-model reconstruction loss
        (mean squared error summed over the batch).
        """
        if isinstance(batch,tuple):
            batch = batch[0]
            
        # replicate the batch along axis 1 to match the stacked reconstructions
        multibatch = tf.stack([batch]*self.n_models,axis=1)
        with tf.GradientTape() as tape:
            reconstruct = self.aes(batch)
            # one loss value per model
            ae_loss = tf.reduce_sum(keras.metrics.mean_squared_error(multibatch,reconstruct),axis=0)

        ae_grad = tape.gradient(ae_loss,self.aes.trainable_weights)
        self.optimizer.apply_gradients(zip(ae_grad,self.aes.trainable_weights))
        
        return { 'ae min': tf.reduce_min(ae_loss), 'ae max' : tf.reduce_max(ae_loss) }


In [None]:
_compute_number_of_neurons(_default_hp,True)

In [None]:
m = AAEMultiModel((X_train.shape[1],))

In [None]:
m.compile()

In [None]:
m.aes.summary()

In [None]:
m.fit(ds,epochs=1)

In [None]:
len(list(range(32,129,8)))

In [None]:
m.aes.trainable_weights[26]

In [None]:
X_train.shape

In [None]:
364/13

In [None]:
class AAEMultiModel2(keras.models.Model):
    """Variant of the multi-autoencoder that also keeps every autoencoder as its own model.

    `self.aes` is a list with one keras.Model per autoencoder; `self.all_aes` evaluates
    all of them in one pass and returns the reconstructions stacked along axis 1.
    Gradients are applied model by model.
    """
    
    def __init__(self,molecule_shape,latent_dim=2,prior='normal',hp=_default_hp):
        """Build all encoder/decoder stacks.

        molecule_shape -- shape of a single input sample (tuple)
        latent_dim     -- width of each encoder output
        prior          -- accepted but not used by this class
        hp             -- hyperparameter dict; 'activation' and 'ae_number_of_layers' are read,
                          'ae_neuron_number_seed' is replaced by the scanned width
        """
        super().__init__()
        self.inp = keras.Input(shape=molecule_shape,name='common.input')
        latent = []
        out = []
        self.aes = []
        
        # XXX: scanned widths and number of replicas are hard-coded
        for an in range(32,129,8):
            # Layer widths for this seed width. Previously `an` only appeared in the layer
            # names, so every sub-model ended up with the widths of `hp` itself.
            # A merged copy is used so the caller's (or the default) dict is not modified.
            neurons = _compute_number_of_neurons({**hp, 'ae_neuron_number_seed': an},ae=True)
            for r in range(4):
                m = f"seed{r}_{an}"

                # encoder
                l = self.inp
                i = 1
                for n in neurons:
                    l = keras.layers.Dense(n, activation=hp['activation'], name=f'enc.{m}.{i}')(l)
                    l = keras.layers.BatchNormalization(momentum=0.8)(l)
                    i += 1
                l = keras.layers.Dense(latent_dim,activation='linear',name=f'enc.{m}.output')(l)
                latent.append(l)

                # decoder mirrors the encoder widths
                i = 1
                for n in reversed(neurons):
                    l = keras.layers.Dense(n, activation=hp['activation'], name=f'dec.{m}.{i}')(l)
                    l = keras.layers.BatchNormalization(momentum=0.8)(l)
                    i += 1
                l = keras.layers.Dense(np.prod(molecule_shape), activation=hp['activation'],name=f'dec.{m}.output')(l)
                l = keras.layers.Reshape(molecule_shape,name=f'dec.{m}.reshape')(l)
                out.append(l)
                self.aes.append(keras.Model(inputs=self.inp,outputs=l))
        
        self.latent = tf.stack(latent,axis=1,name='all.latent')
        self.out = tf.stack(out,axis=1,name='all.output')        
        self.all_aes = keras.Model(inputs=self.inp,outputs=self.out)
            
    def compile(self,*args,**kwargs):
        """Compile this model and every single autoencoder with the given arguments.

        The loss object and the optimizer are then set to fixed values (XXX),
        regardless of what was passed in.
        """
        super().compile(*args,**kwargs)

        for m in self.aes:
            m.compile(*args,**kwargs)
        
        # XXX
        self.ae_loss = keras.losses.MeanSquaredError()
        self.optimizer = keras.optimizers.legacy.Adagrad(learning_rate=0.0002)
    
    @tf.function
    def train_step(self,batch):
        """One forward pass through all autoencoders, then one optimizer update per autoencoder.

        Returns the smallest and the largest per-model reconstruction loss
        (mean squared error summed over the batch).
        """
        if isinstance(batch,tuple):
            batch = batch[0]
            
        # replicate the batch along axis 1 to match the stacked reconstructions
        multibatch = tf.stack([batch]*len(self.aes),axis=1)
        # persistent: the tape is queried once per autoencoder below
        with tf.GradientTape(persistent=True) as tape:
            reconstruct = self.all_aes(batch)
            # one loss value per model
            ae_loss = tf.reduce_sum(keras.metrics.mean_squared_error(multibatch,reconstruct),axis=0)

        for m in self.aes:
            ae_grad = tape.gradient(ae_loss,m.trainable_weights)
            self.optimizer.apply_gradients(zip(ae_grad,m.trainable_weights))

        # a persistent tape holds its resources until it is dropped explicitly
        del tape
        
        return { 'ae min': tf.reduce_min(ae_loss), 'ae max' : tf.reduce_max(ae_loss) }


In [None]:
# Input shape is a one-element tuple: the number of features per sample
m = AAEMultiModel2((X_train.shape[1],))

In [None]:
m.compile()

In [None]:
m.fit(ds,epochs=1)

In [None]:
inp = keras.Input((X_train.shape[1],))

In [None]:
# Reference network: four Dense(850, relu) layers, each followed by batch normalization,
# then a Dense layer (no activation given) back to the input width
out = inp
for _ in range(4):
    out = keras.layers.BatchNormalization(momentum=0.8)(
        keras.layers.Dense(850,activation='relu')(out)
    )
out = keras.layers.Dense(X_train.shape[1],)(out)

class TestModel(keras.models.Model):
    """keras.Model with a hand-written training step that reconstructs its own input."""

    @tf.function
    def train_step(self,batch):
        """Single gradient step on the mean squared reconstruction error.

        Returns a dict with the batch-averaged loss, the form Keras expects from train_step.
        """
        if isinstance(batch,tuple):
            batch = batch[0]

        with tf.GradientTape() as tape:
            # Call the model itself; this class has no `all_aes` attribute
            reconstruct = self(batch,training=True)
            ae_loss = keras.metrics.mean_squared_error(batch,reconstruct)

        ae_grad = tape.gradient(ae_loss,self.trainable_weights)
        self.optimizer.apply_gradients(zip(ae_grad,self.trainable_weights))
        return { 'loss': tf.reduce_mean(ae_loss) }


    
# Built as a plain keras.Model (not TestModel), so the stock Keras training loop is used
silim = keras.Model(inputs=inp,outputs=out)
silim.compile(loss=keras.losses.MeanSquaredError(),optimizer = keras.optimizers.legacy.Adagrad(learning_rate=0.0002))


In [None]:
silim.summary()

In [None]:
silim.fit(X_train,X_train,epochs=1)

In [None]:
inp=keras.Input((5,))

In [None]:
l1 = keras.layers.Dense(10,activation='relu')(inp)

In [None]:
l2 = keras.layers.Dense(10,activation='relu')(l1)

In [None]:
l2.trainable

In [None]:
l2.trainable_weights

In [None]:
tm = keras.Model(inputs=inp,outputs=l2)

In [None]:
l2

In [None]:
tm.layers[1].trainable_weights

In [None]:
l2

In [None]:
tm.layers[2].name

In [None]:
class AAEMultiModel3(keras.models.Model):
    """Several adversarial autoencoders packed into one wide network.

    Every Dense layer holds the neurons of all sub-models side by side. 0/1 masks stored
    in `self.masks` (keyed by layer name) keep only the kernel entries that connect neurons
    of the same sub-model; the masked-out entries are zeroed after every optimizer update.
    The first encoder layer 'enc_0' has no mask: every sub-model sees the full input.

    The latent vector is laid out model by model: columns
    mod*latent_dim:(mod+1)*latent_dim belong to sub-model `mod`.
    """
    # enc_neurons: (models,layers)
    def __init__(self,inp_shape,enc_neurons,disc_neurons,latent_dim=2):
        """Build encoder, decoder, discriminator and the connectivity masks.

        inp_shape    -- shape of one input sample; inp_shape[0] is used as its width
        enc_neurons  -- integer array (models, layers) with encoder layer widths;
                        the decoder mirrors them in reverse
        disc_neurons -- integer array (models, layers) with discriminator layer widths
        latent_dim   -- latent width of every sub-model
        """
        super().__init__()
        
        self.n_models = enc_neurons.shape[0]
        self.latent_dim = latent_dim
        
        assert disc_neurons.shape[0] == self.n_models
        
        inp = keras.Input(shape=inp_shape)
        out = inp
        # TODO: emulate "empty" layers
        for num in range(enc_neurons.shape[1]):
            print(f'enc_{num} {enc_neurons[:,num]}')
            out = keras.layers.Dense(np.sum(enc_neurons[:,num]),activation='relu',name=f'enc_{num}')(out)
            out = keras.layers.BatchNormalization(momentum=0.8,name=f'enc_bn_{num}')(out)
        out = keras.layers.Dense(enc_neurons.shape[0]*latent_dim,name='enc_out')(out) # 
        latent = out
        
        # decoder layers are numbered in reverse so that neuron numbers match with encoder
        for num  in reversed(range(enc_neurons.shape[1])):
            out = keras.layers.Dense(np.sum(enc_neurons[:,num]),activation='relu',name=f'dec_{num}')(out)
            out = keras.layers.BatchNormalization(momentum=0.8,name=f'dec_bn_{num}')(out)
            
        out = keras.layers.Dense(enc_neurons.shape[0]*inp_shape[0],name='dec_out')(out)
        out = keras.layers.Reshape((enc_neurons.shape[0],inp_shape[0]))(out)
        
        self.aes = keras.Model(inputs=inp,outputs=[out,latent])
        self.enc = keras.Model(inputs=inp,outputs=latent)
        self.dec = keras.Model(inputs=latent,outputs=out)
        
        inp = keras.Input(shape=(latent_dim * self.n_models,))
        disc = inp
        for num in range(disc_neurons.shape[1]):
            disc = keras.layers.Dense(np.sum(disc_neurons[:,num]),name=f'disc_{num}')(disc)
            disc = keras.layers.LeakyReLU(alpha=0.2,name=f'disc_relu_{num}')(disc)
            
        # one output (logit) per sub-model
        disc = keras.layers.Dense(disc_neurons.shape[0],name='disc_out')(disc)
        
        self.disc = keras.Model(inputs=inp,outputs=disc)
        
        self.masks = {}
        
        # idx[mod,layer] .. idx[mod+1,layer] is the neuron range of sub-model `mod` in `layer`
        # masks for encoder layers 1:, decoder layers 0: are just .T's
        idx = np.concatenate((np.zeros((1,enc_neurons.shape[1]),dtype=np.int32),np.cumsum(enc_neurons,axis=0)))
        for layer in range(1, enc_neurons.shape[1]):
            mask = np.zeros((np.sum(enc_neurons[:,layer-1]),np.sum(enc_neurons[:,layer])),dtype=np.float32)
            for mod in range(enc_neurons.shape[0]):
                mask[idx[mod,layer-1]:idx[mod+1,layer-1],idx[mod,layer]:idx[mod+1,layer]] = 1.
            self.masks[f'enc_{layer}'] = tf.convert_to_tensor(mask)
            self.masks[f'dec_{layer-1}'] = tf.convert_to_tensor(mask.T)
        
        # mask to and from latent layer
        mask = np.zeros((np.sum(enc_neurons[:,-1]),latent_dim * enc_neurons.shape[0]),dtype=np.float32)
        for mod in range(enc_neurons.shape[0]):
            mask[idx[mod,-1]:idx[mod+1,-1],mod*latent_dim:(mod+1)*latent_dim] = 1.
        
        self.masks['enc_out'] = tf.convert_to_tensor(mask)
        self.masks[f'dec_{enc_neurons.shape[1]-1}'] = tf.convert_to_tensor(mask.T)
        
        # mask for decoder output
        mask = np.zeros((np.sum(enc_neurons[:,0]),inp_shape[0]*enc_neurons.shape[0]),dtype=np.float32)
        for mod in range(enc_neurons.shape[0]):
            mask[idx[mod,0]:idx[mod+1,0], mod*inp_shape[0]:(mod+1)*inp_shape[0]] = 1.
        self.masks['dec_out'] = tf.convert_to_tensor(mask)
             
        idx = np.concatenate((np.zeros((1,disc_neurons.shape[1]),dtype=np.int32),np.cumsum(disc_neurons,axis=0)))
        # mask for discriminator layer 0:
        mask = np.zeros((latent_dim * disc_neurons.shape[0],np.sum(disc_neurons[:,0])),dtype=np.float32)
        for mod in range(disc_neurons.shape[0]):
            mask[mod*latent_dim:(mod+1)*latent_dim, idx[mod,0]:idx[mod+1,0]] = 1.
        self.masks['disc_0'] = tf.convert_to_tensor(mask)
            
        # mask for discriminator layers 1:
        for layer in range(1,disc_neurons.shape[1]):
            mask = np.zeros((np.sum(disc_neurons[:,layer-1]),np.sum(disc_neurons[:,layer])),dtype=np.float32)
            for mod in range(disc_neurons.shape[0]):
                mask[idx[mod,layer-1]:idx[mod+1,layer-1],idx[mod,layer]:idx[mod+1,layer]] = 1.
            self.masks[f'disc_{layer}'] = tf.convert_to_tensor(mask)
            
        # mask for discriminator output:
        mask = np.zeros((np.sum(disc_neurons[:,-1]),disc_neurons.shape[0]),dtype=np.float32)
        for mod in range(disc_neurons.shape[0]):
            mask[idx[mod,disc_neurons.shape[1]-1]:idx[mod+1,disc_neurons.shape[1]-1],mod] = 1.
        self.masks['disc_out'] = tf.convert_to_tensor(mask)

        # The freshly initialized kernels are dense; mask them right away so that the
        # sub-models are already separated in the very first forward pass.
        self._apply_masks(self.aes)
        self._apply_masks(self.disc)

    def _apply_masks(self,model):
        """Zero the kernel entries of `model`'s layers that would couple different sub-models."""
        for l in model.layers:
            if l.name in self.masks:
                l.weights[0].assign(l.weights[0] * self.masks[l.name])

    def compile(self):
        """Compile with a fixed Adagrad optimizer (XXX) and collect the autoencoder weights."""
        # XXX
        super().compile(optimizer = keras.optimizers.legacy.Adagrad(learning_rate=0.0002))
        self.ae_weights = self.enc.trainable_weights + self.dec.trainable_weights
         
    @tf.function
    def train_step(self,batch):
        """One adversarial-autoencoder step for all sub-models at once.

        Three phases, each followed by re-applying the masks:
          1. reconstruction step on encoder + decoder,
          2. discriminator step on prior samples (label 1) vs. encoder outputs (label 0),
          3. encoder step trying to make the discriminator output 1 for encoded data.

        Returns the minimum and maximum over sub-models of each of the three losses.
        """
        if isinstance(batch,tuple):
            batch = batch[0]
        
        # multiple models need replicated batch to compute loss simultaneously
        multibatch = tf.stack([batch]*self.n_models,axis=1)
 
        with tf.GradientTape() as tape:
            reconstruct = self.aes(batch)
            ae_multiloss = tf.reduce_mean(keras.metrics.mean_squared_error(multibatch,reconstruct[0]),axis=0)
            ae_loss = tf.reduce_sum(ae_multiloss)
               
        ae_grad = tape.gradient(ae_loss,self.ae_weights)
        self.optimizer.apply_gradients(zip(ae_grad,self.ae_weights))
        
        self._apply_masks(self.aes)
       
#        rand_low = tf.random.normal((batch.shape[0],self.latent_dim * self.n_models))
        rand_low = tf.random.normal((batch.shape[0],self.latent_dim))
#        rand_low = tf.random.uniform((batch.shape[0],self.latent_dim))
        # Same prior sample for every sub-model. tf.tile repeats the whole latent vector
        # (x,y,x,y,...), which matches the model-major latent layout; tf.repeat along
        # axis 1 would repeat single coordinates (x,x,...,y,y,...) and mix them up.
        rand_low = tf.tile(rand_low,(1,self.n_models))
        lows = tf.concat([rand_low, reconstruct[1]],axis=0)
        labels = tf.concat([tf.ones((batch.shape[0],self.n_models)), tf.zeros((batch.shape[0],self.n_models))], axis=0)
        # add uniform noise in [0, 0.05) to the labels
        labels += 0.05 * tf.random.uniform(labels.shape)
                        
        with tf.GradientTape() as tape:
            pred = self.disc(lows)
            # axis=0: average over the batch, leaving one loss per sub-model
            disc_losses = tf.keras.metrics.binary_crossentropy(labels,pred,from_logits=True,axis=0)
            disc_loss = tf.reduce_sum(disc_losses)
            
        disc_grads = tape.gradient(disc_loss,self.disc.trainable_weights)
        self.optimizer.apply_gradients(zip(disc_grads,self.disc.trainable_weights))
        
        self._apply_masks(self.disc)
  
        all_true = tf.ones((batch.shape[0],self.n_models))
    
        with tf.GradientTape() as tape:
            cheat = self.disc(self.enc(batch))
            cheat_losses = tf.keras.metrics.binary_crossentropy(all_true,cheat,from_logits=True,axis=0)
            cheat_loss = tf.reduce_sum(cheat_losses)
            
        # only the encoder is updated in this phase
        cheat_grads = tape.gradient(cheat_loss,self.enc.trainable_weights)
        self.optimizer.apply_gradients(zip(cheat_grads,self.enc.trainable_weights))
        
        self._apply_masks(self.enc)
       

        return {
            'AE loss min' : tf.reduce_min(ae_multiloss),
            'AE loss max' : tf.reduce_max(ae_multiloss),
            'disc loss min' : tf.reduce_min(disc_losses),
            'disc loss max' : tf.reduce_max(disc_losses),
            'cheat loss min' : tf.reduce_min(cheat_losses),
            'cheat loss max' : tf.reduce_max(cheat_losses)
        }
    
    @tf.function
    def call(self,inp):
        """Return [reconstructions, latent] of all sub-models for `inp`."""
        return self.aes(inp)
                            
            

In [None]:
# Toy check of the batch replication used in train_step: 2 samples of 3 features, 2 models
batch=np.array([[1,2,3],[4,5,6]])
# stacking along axis 1 inserts the model axis between sample and feature axes
multibatch = np.stack([batch]*2,axis=1)
multibatch # (batch, model, x)

In [None]:
tf.convert_to_tensor(multibatch)

In [None]:
result=np.stack([batch, np.array([[1.1,2.1,2.9],[4.5,5.5,6]])],axis=1)
result

In [None]:
fr = np.array([[1,2,3,1.1,2.1,2.9],[4,5,6,4.5,5.5,6]])
r2 = np.reshape(fr,(2,2,3))
r2

In [None]:
metric=keras.metrics.mean_squared_error(multibatch,r2)
metric

In [None]:
mean = tf.reduce_mean(metric,axis=0)
mean

In [None]:
# Ten identical sub-models with encoder widths 32-16-8
en = np.array([[32, 16, 8]]*10)
# AAEMultiModel3 requires discriminator widths as well, one row per sub-model
dn = np.array([[32, 64, 32]]*en.shape[0])

m3 = AAEMultiModel3((X_train.shape[1],),en,dn)

In [None]:
m3(X_train[0:1])

In [None]:
m3.compile()

In [None]:
m3.fit(ds,epochs=1)

In [None]:
m3.aes.layers[3].get_weights()

In [None]:
m3.aes.summary()

In [None]:
# Rebuild the dataset with batches of 256; this replaces the earlier `ds` for all cells below
ds = tf.data.Dataset.from_tensor_slices(X_train).shuffle(2048).batch(256,drop_remainder=True)

In [None]:
# Encoder widths: one sub-model per seed width 32, 40, ..., 128
en = []
for an in range(32,129,8):
    hp = _default_hp.copy()
    hp['ae_neuron_number_seed'] = an
    en.append(_compute_number_of_neurons(hp,True))
    
en = np.array(en)

# The same discriminator widths for every sub-model
dn = np.array([[ 32, 64, 32 ]] * en.shape[0])
dn

In [None]:
m3a = AAEMultiModel3((X_train.shape[1],),en,dn)
m3a.compile()
m3a.aes.summary()

In [None]:
m3a.fit(ds,epochs=20)

In [None]:
lows = m3a.enc(X_train).numpy()
lows.shape

In [None]:
# Scatter plot of the latent space of every sub-model; takes two latent columns per
# sub-model (i.e. latent_dim == 2) and every 200th sample
for mod in range(lows.shape[1]//2):
    plows = lows[::200,mod*2:(mod+1)*2]
    plt.scatter(plows[:,0],plows[:,1],marker='.',label=str(mod))

# half-width of the plotted window on both axes
lim=.05
plt.legend()    
plt.xlim((-lim,lim))
plt.ylim((-lim,lim))
plt.show()

In [None]:
# Grid of sub-models: every encoder seed width combined with every discriminator seed width
en = []
dn = []
hp = _default_hp.copy()

# encoder seed widths 32, 64, 96, 128
for an in range(32,129,32):
    hp['ae_neuron_number_seed'] = an
    en.append(_compute_number_of_neurons(hp,True))
    
# discriminator seed widths 64, 128, 192, 256
for an in range(64,257,64):
    hp['disc_neuron_number_seed'] = an
    dn.append(_compute_number_of_neurons(hp,False))

el = len(en)
dl = len(dn)
# the encoder list is repeated as a whole (e0,e1,..,e0,e1,..) while each discriminator
# row is repeated in place (d0,d0,..,d1,d1,..), so the rows enumerate all pairs
en = np.array(en * dl)
dn = np.repeat(np.array(dn),el,axis=0)

In [None]:
en

In [None]:
dn

In [None]:
m3b = AAEMultiModel3((X_train.shape[1],),en,dn)
m3b.compile()
m3b.aes.summary()

In [None]:
m3b.fit(ds,epochs=20)

In [None]:
lows = m3b.enc(X_train).numpy()
lows.shape

In [None]:
# Scatter plot of the latent space of every sub-model; takes two latent columns per
# sub-model (i.e. latent_dim == 2) and every 100th sample
plt.figure(figsize=(12,9))
for mod in range(lows.shape[1]//2):
    plows = lows[::100,mod*2:(mod+1)*2]
    plt.scatter(plows[:,0],plows[:,1],marker='.',label=str(mod))

# half-width of the plotted window on both axes
lim=.04
plt.legend()    
plt.xlim((-lim,lim))
plt.ylim((-lim,lim))
plt.show()

In [None]:
m3c = AAEMultiModel3((X_train.shape[1],),np.array([[96,48],[96,48]]),np.array([[224,112],[64,32]]))
m3c.compile()
m3c.aes.summary()

In [None]:
m3c.fit(ds,epochs=50)

In [None]:
lows = m3c.enc(X_train).numpy()
lows.shape

In [None]:
# Peek at the first few latent rows (a bare `l` is not defined at notebook level)
lows[:5]

In [None]:
class _Sparse(keras.constraints.Constraint):
    """Kernel constraint that multiplies the weights elementwise by a fixed mask."""

    def __init__(self,mask):
        # keep the mask as a tensor so it is converted only once
        mask_tensor = tf.convert_to_tensor(mask)
        self.mask = mask_tensor

    @tf.function    
    def __call__(self,w):
        """Return `w` with the masked-out entries set to zero."""
        return tf.multiply(w, self.mask)
    

class AAEMultiModel4(keras.models.Model):
    """Several adversarial autoencoders packed into one wide network, separated by constraints.

    Every Dense layer holds the neurons of all sub-models side by side. The 0/1 masks in
    `self.masks` (keyed by layer name) are wrapped in `_Sparse` kernel constraints and passed
    to the layers, so kernel entries coupling different sub-models are zeroed by the constraint.
    The first encoder layer 'enc_0' has no mask: every sub-model sees the full input.

    The latent vector is laid out model by model: columns
    mod*latent_dim:(mod+1)*latent_dim belong to sub-model `mod`.
    """
    # enc_neurons: (models,layers)
    
    def __init__(self,inp_shape,enc_neurons,disc_neurons,latent_dim=2):
        """Build the connectivity masks, then encoder, decoder and discriminator.

        inp_shape    -- shape of one input sample; inp_shape[0] is used as its width
        enc_neurons  -- integer array (models, layers) with encoder layer widths;
                        the decoder mirrors them in reverse
        disc_neurons -- integer array (models, layers) with discriminator layer widths
        latent_dim   -- latent width of every sub-model
        """
        super().__init__()
        
        self.n_models = enc_neurons.shape[0]
        self.latent_dim = latent_dim
        
        assert disc_neurons.shape[0] == self.n_models
        
        self.masks = {}
        
        # idx[mod,layer] .. idx[mod+1,layer] is the neuron range of sub-model `mod` in `layer`
        # masks for encoder layers 1:, decoder layers 0: are just .T's
        idx = np.concatenate((np.zeros((1,enc_neurons.shape[1]),dtype=np.int32),np.cumsum(enc_neurons,axis=0)))
        for layer in range(1, enc_neurons.shape[1]):
            mask = np.zeros((np.sum(enc_neurons[:,layer-1]),np.sum(enc_neurons[:,layer])),dtype=np.float32)
            for mod in range(enc_neurons.shape[0]):
                mask[idx[mod,layer-1]:idx[mod+1,layer-1],idx[mod,layer]:idx[mod+1,layer]] = 1.
            self.masks[f'enc_{layer}'] = _Sparse(mask)
            self.masks[f'dec_{layer-1}'] = _Sparse(mask.T)
        
        # mask to and from latent layer
        mask = np.zeros((np.sum(enc_neurons[:,-1]),latent_dim * enc_neurons.shape[0]),dtype=np.float32)
        for mod in range(enc_neurons.shape[0]):
            mask[idx[mod,-1]:idx[mod+1,-1],mod*latent_dim:(mod+1)*latent_dim] = 1.
        
        self.masks['enc_out'] = _Sparse(mask)
        self.masks[f'dec_{enc_neurons.shape[1]-1}'] = _Sparse(mask.T)
        
        # mask for decoder output
        mask = np.zeros((np.sum(enc_neurons[:,0]),inp_shape[0]*enc_neurons.shape[0]),dtype=np.float32)
        for mod in range(enc_neurons.shape[0]):
            mask[idx[mod,0]:idx[mod+1,0], mod*inp_shape[0]:(mod+1)*inp_shape[0]] = 1.
        self.masks['dec_out'] = _Sparse(mask)
             
        idx = np.concatenate((np.zeros((1,disc_neurons.shape[1]),dtype=np.int32),np.cumsum(disc_neurons,axis=0)))
        # mask for discriminator layer 0:
        mask = np.zeros((latent_dim * disc_neurons.shape[0],np.sum(disc_neurons[:,0])),dtype=np.float32)
        for mod in range(disc_neurons.shape[0]):
            mask[mod*latent_dim:(mod+1)*latent_dim, idx[mod,0]:idx[mod+1,0]] = 1.
        self.masks['disc_0'] = _Sparse(mask)
            
        # mask for discriminator layers 1:
        for layer in range(1,disc_neurons.shape[1]):
            mask = np.zeros((np.sum(disc_neurons[:,layer-1]),np.sum(disc_neurons[:,layer])),dtype=np.float32)
            for mod in range(disc_neurons.shape[0]):
                mask[idx[mod,layer-1]:idx[mod+1,layer-1],idx[mod,layer]:idx[mod+1,layer]] = 1.
            self.masks[f'disc_{layer}'] = _Sparse(mask)
            
        # mask for discriminator output:
        mask = np.zeros((np.sum(disc_neurons[:,-1]),disc_neurons.shape[0]),dtype=np.float32)
        for mod in range(disc_neurons.shape[0]):
            mask[idx[mod,disc_neurons.shape[1]-1]:idx[mod+1,disc_neurons.shape[1]-1],mod] = 1.
        self.masks['disc_out'] = _Sparse(mask)  

        
        inp = keras.Input(shape=inp_shape)
        out = inp

        # TODO: emulate "empty" layers
        for num in range(enc_neurons.shape[1]):
            name = f'enc_{num}'
            # masks.get() yields None for 'enc_0', i.e. no constraint on the first layer
            out = keras.layers.Dense(np.sum(enc_neurons[:,num]),activation='relu',
                                     name = name, kernel_constraint = self.masks.get(name))(out)
#            out = keras.layers.BatchNormalization(momentum=0.8,name=f'enc_bn_{num}')(out)
            
        out = keras.layers.Dense(enc_neurons.shape[0]*latent_dim,name='enc_out', kernel_constraint = self.masks.get('enc_out'))(out) # 
        latent = out
        
        # decoder layers are numbered in reverse so that neuron numbers match with encoder
        for num  in reversed(range(enc_neurons.shape[1])):
            name = f'dec_{num}'
            out = keras.layers.Dense(np.sum(enc_neurons[:,num]),activation='relu',
                                     name=name,kernel_constraint=self.masks.get(name))(out)
#            out = keras.layers.BatchNormalization(momentum=0.8,name=f'dec_bn_{num}')(out)
            
        out = keras.layers.Dense(enc_neurons.shape[0]*inp_shape[0],name='dec_out',kernel_constraint=self.masks.get('dec_out'))(out)
        out = keras.layers.Reshape((enc_neurons.shape[0],inp_shape[0]))(out)
        
        self.aes = keras.Model(inputs=inp,outputs=[out,latent])
        self.enc = keras.Model(inputs=inp,outputs=latent)
        self.dec = keras.Model(inputs=latent,outputs=out)
        
        inp = keras.Input(shape=(latent_dim * self.n_models,))
        disc = inp
        for num in range(disc_neurons.shape[1]):
            name = f'disc_{num}'
            disc = keras.layers.Dense(np.sum(disc_neurons[:,num]),
                                      name=name,kernel_constraint=self.masks.get(name))(disc)
            disc = keras.layers.LeakyReLU(alpha=0.2,name=f'disc_relu_{num}')(disc)
            
        # one output per sub-model
        disc = keras.layers.Dense(disc_neurons.shape[0],name='disc_out',kernel_constraint=self.masks.get('disc_out'))(disc)
        
        self.disc = keras.Model(inputs=inp,outputs=disc)
        
 

    def compile(self):
        """Compile with a fixed Adagrad optimizer (XXX) and collect the autoencoder weights."""
        # XXX
        super().compile(optimizer = keras.optimizers.legacy.Adagrad(learning_rate=0.0002))
        self.ae_weights = self.enc.trainable_weights + self.dec.trainable_weights
         
    @tf.function
    def train_step(self,batch):
        """One training step for all sub-models at once.

        Three phases:
          1. reconstruction step on encoder + decoder,
          2. discriminator step: raise its output on prior samples, lower it on encoder outputs,
          3. encoder step: raise the discriminator output on encoded data.

        Returns a dict mapping the sub-model index (as a string) to its discriminator loss.
        """
        if isinstance(batch,tuple):
            batch = batch[0]
        
        # multiple models need replicated batch to compute loss simultaneously
        multibatch = tf.stack([batch]*self.n_models,axis=1)
 
        with tf.GradientTape() as tape:
            reconstruct = self.aes(batch)
            ae_multiloss = tf.reduce_mean(keras.metrics.mean_squared_error(multibatch,reconstruct[0]),axis=0)
            ae_loss = tf.reduce_sum(ae_multiloss)
               
        ae_grad = tape.gradient(ae_loss,self.ae_weights)
        self.optimizer.apply_gradients(zip(ae_grad,self.ae_weights))
               
#        rand_low = tf.random.normal((batch.shape[0],self.latent_dim * self.n_models))
        rand_low = tf.random.normal((batch.shape[0],self.latent_dim))
#        rand_low = tf.random.uniform((batch.shape[0],self.latent_dim))
        # same prior sample for every sub-model, tiled to the model-major latent layout
        rand_low = tf.tile(rand_low,(1,self.n_models))
                        
        with tf.GradientTape() as tape:
            
            # FIXME: perturbe
            # reconstruct[1] is the latent output computed before the autoencoder update above
            neg_pred = self.disc(reconstruct[1])
            neg_losses = tf.reduce_sum(neg_pred,axis=0) 
            pos_pred = self.disc(rand_low)
            pos_losses = -tf.reduce_sum(pos_pred,axis=0)
            # one loss per sub-model
            disc_losses = neg_losses + pos_losses
            disc_loss = tf.reduce_mean(disc_losses)
            
        disc_grads = tape.gradient(disc_loss,self.disc.trainable_weights)
        self.optimizer.apply_gradients(zip(disc_grads,self.disc.trainable_weights))
    
        with tf.GradientTape() as tape:
            cheat = self.disc(self.enc(batch))
            cheat_losses = -tf.reduce_sum(cheat,axis=0)
            cheat_loss = tf.reduce_mean(cheat_losses)
            
        # only the encoder is updated in this phase
        cheat_grads = tape.gradient(cheat_loss,self.enc.trainable_weights)
        self.optimizer.apply_gradients(zip(cheat_grads,self.enc.trainable_weights))
        
#        return {
#            'AE loss min' : tf.reduce_min(ae_multiloss),
#            'AE loss max' : tf.reduce_max(ae_multiloss),
#            'disc loss min' : tf.reduce_min(disc_losses),
#            'disc loss max' : tf.reduce_max(disc_losses),
#            'cheat loss min' : tf.reduce_min(cheat_losses),
#            'cheat loss max' : tf.reduce_max(cheat_losses)
#        }

        return { str(i): disc_losses[i] for i in range(disc_losses.shape[0]) }
        #return { '0': disc_losses[0], '1' : disc_losses[1] }
    
    @tf.function
    def call(self,inp):
        """Return [reconstructions, latent] of all sub-models for `inp`."""
        return self.aes(inp)
                            
            

In [None]:
m4 = AAEMultiModel4((X_train.shape[1],),np.array([[96,48],[96,48]]*10),np.array([[224,112],[64,32]]*10))
#m4 = AAEMultiModel4((X_train.shape[1],),np.array([[96,48],[96,48]]),np.array([[224,112],[64,32]]))
m4.compile()
m4.aes.summary()

In [None]:
f = m4.fit(ds,epochs=3)

In [None]:
f.history

In [None]:
# Latent coordinates of the whole training set from all sub-models of m4
lows = m4.enc(X_train).numpy()

# Scatter plot per sub-model; takes two latent columns per sub-model
# (i.e. latent_dim == 2) and every 10th sample
plt.figure(figsize=(12,9))
for mod in range(lows.shape[1]//2):
    plows = lows[::10,mod*2:(mod+1)*2]
    plt.scatter(plows[:,0],plows[:,1],marker='.',label=str(mod))

# half-width of the plotted window on both axes
lim=5
plt.legend()    
plt.xlim((-lim,lim))
plt.ylim((-lim,lim))
plt.show()

In [None]:
rand_low = tf.random.normal((64,2))
rand_low = tf.tile(rand_low,(1,m4.n_models))
rand_low

In [None]:
pred = m4.disc(rand_low)
tf.reduce_sum(pred,axis = 0)

In [None]:
r2 = tf.tile(tf.random.normal((5000,2),stddev=1.),(1,m4.n_models))
p2 = m4.disc(r2)
tf.reduce_sum(p2,axis=0)

In [None]:
d2 = m4.disc(lows[:5000,:])
tf.reduce_sum(d2,axis=0)

In [None]:
# Seeded uniform initializer in [-1, 1] shared by all sparse initializers
_random_init = keras.initializers.RandomUniform(seed=42,minval=-1.,maxval=1.)

class _SparseConstraint(keras.constraints.Constraint):
    """Kernel constraint that multiplies the weights elementwise by a fixed mask."""

    def __init__(self,mask):
        # the mask is stored as given, without conversion
        self.mask = mask
        
    @tf.function    
    def __call__(self,w):
        """Return `w` with the masked-out entries set to zero."""
        return tf.multiply(w, self.mask)
    
class _SparseInitializer(keras.initializers.Initializer):
    """Initializer producing random weights that are zero wherever the mask is zero."""

    def __init__(self,mask):
        self.mask = mask
    
    def __call__(self,shape,dtype=None,**kwargs):
        """Return values drawn by the module-level `_random_init`, multiplied by the mask.

        `dtype` is optional and extra keyword arguments are forwarded, following the
        call convention of Keras initializers.
        """
        return _random_init(shape=shape,dtype=dtype,**kwargs) * self.mask
        
    
class _Sparse():
    """Holder of the initializer (`.i`) and the constraint (`.c`) built from one mask."""

    def __init__(self,mask):
        # convert once and share the tensor between both helpers
        shared_mask = tf.convert_to_tensor(mask)
        self.c = _SparseConstraint(shared_mask)
        self.i = _SparseInitializer(shared_mask)

class AAEMultiModel5(keras.models.Model):
    """Multi-model adversarial autoencoder built from wide, mask-constrained Dense layers.

    All sub-models share the same Keras layers; each layer is as wide as the sum
    of the per-model widths and its kernel is restricted by a 0/1 mask (see
    `_Sparse`).  In this version the encoder/decoder masks are filled for
    sub-model 0 only (the loops over models are replaced by `for mod in [0]`),
    while the discriminator masks are filled for every sub-model.

    Attributes created in __init__: n_models, latent_dim, masks, aes, enc, dec, disc.
    """
    # enc_neurons: (models,layers)
    
    def __init__(self,inp_shape,enc_neurons,disc_neurons,latent_dim=2):
        """Create the masks and assemble the autoencoder and discriminator models.

        Args:
            inp_shape: shape of one input sample; inp_shape[0] is used as the feature count.
            enc_neurons: 2D integer array indexed (models, layers) with encoder layer widths;
                the decoder reuses the same widths in reverse order.
            disc_neurons: 2D integer array indexed (models, layers) with discriminator layer
                widths; must have the same number of rows as enc_neurons.
            latent_dim: number of latent dimensions per sub-model.
        """
        super().__init__()
        
        self.n_models = enc_neurons.shape[0]
        self.latent_dim = latent_dim
        
        assert disc_neurons.shape[0] == self.n_models
        
        # layer name -> _Sparse (initializer + constraint sharing one mask)
        self.masks = {}
        
        # XXX: input to first model only, hack
        # only the first enc_neurons[0,0] columns (sub-model 0) are connected to the input
        mask = np.zeros((inp_shape[0],np.sum(enc_neurons[:,0])),dtype=np.float32)
        mask[:,:enc_neurons[0,0]] = 1.
        self.masks['enc_0'] = _Sparse(mask)
        
        # masks for encoder layers 1:, decoder layers 0: are just .T's
        # idx[m,l] is the first neuron of sub-model m in layer l (cumulative widths with a leading zero row)
        idx = np.concatenate((np.zeros((1,enc_neurons.shape[1]),dtype=np.int32),np.cumsum(enc_neurons,axis=0)))
        for layer in range(1, enc_neurons.shape[1]):
            mask = np.zeros((np.sum(enc_neurons[:,layer-1]),np.sum(enc_neurons[:,layer])),dtype=np.float32)
#            for mod in range(enc_neurons.shape[0]):
            # only the block of sub-model 0 is enabled here
            for mod in [0]:
                mask[idx[mod,layer-1]:idx[mod+1,layer-1],idx[mod,layer]:idx[mod+1,layer]] = 1.
            self.masks[f'enc_{layer}'] = _Sparse(mask)
            self.masks[f'dec_{layer-1}'] = _Sparse(mask.T)
        
        # mask to and from latent layer
        mask = np.zeros((np.sum(enc_neurons[:,-1]),latent_dim * enc_neurons.shape[0]),dtype=np.float32)
#        for mod in range(enc_neurons.shape[0]):
        for mod in [0]:
            mask[idx[mod,-1]:idx[mod+1,-1],mod*latent_dim:(mod+1)*latent_dim] = 1.
        
        # debug output of the latent-layer mask
        print('enc_out:', mask)
        self.masks['enc_out'] = _Sparse(mask)
        self.masks[f'dec_{enc_neurons.shape[1]-1}'] = _Sparse(mask.T)
        
        # mask for decoder output
        mask = np.zeros((np.sum(enc_neurons[:,0]),inp_shape[0]*enc_neurons.shape[0]),dtype=np.float32)
#        for mod in range(enc_neurons.shape[0]):
        for mod in [0]:
            mask[idx[mod,0]:idx[mod+1,0], mod*inp_shape[0]:(mod+1)*inp_shape[0]] = 1.
        self.masks['dec_out'] = _Sparse(mask)
             
        # from here on idx holds the cumulative discriminator widths; every sub-model gets its block
        idx = np.concatenate((np.zeros((1,disc_neurons.shape[1]),dtype=np.int32),np.cumsum(disc_neurons,axis=0)))
        # mask for discriminator layer 0:
        mask = np.zeros((latent_dim * disc_neurons.shape[0],np.sum(disc_neurons[:,0])),dtype=np.float32)
        for mod in range(disc_neurons.shape[0]):
            mask[mod*latent_dim:(mod+1)*latent_dim, idx[mod,0]:idx[mod+1,0]] = 1.
        self.masks['disc_0'] = _Sparse(mask)
#        print('disc_0',mask)
            
        # mask for discriminator layers 1:
        for layer in range(1,disc_neurons.shape[1]):
            mask = np.zeros((np.sum(disc_neurons[:,layer-1]),np.sum(disc_neurons[:,layer])),dtype=np.float32)
            for mod in range(disc_neurons.shape[0]):
                mask[idx[mod,layer-1]:idx[mod+1,layer-1],idx[mod,layer]:idx[mod+1,layer]] = 1.
            self.masks[f'disc_{layer}'] = _Sparse(mask)
#            print(f'disc_{layer}',mask)
            
        # mask for discriminator output:
        mask = np.zeros((np.sum(disc_neurons[:,-1]),disc_neurons.shape[0]),dtype=np.float32)
        for mod in range(disc_neurons.shape[0]):
            mask[idx[mod,disc_neurons.shape[1]-1]:idx[mod+1,disc_neurons.shape[1]-1],mod] = 1.
        self.masks['disc_out'] = _Sparse(mask)  
#        print('disc_out',mask)

        
        # ---- encoder: masked Dense + BatchNormalization per layer ----
        inp = keras.Input(shape=inp_shape)
        out = inp
#        out = keras.layers.RepeatVector(self.n_models,trainable=False,name='inp_repeat')(inp)
#        out = keras.layers.Reshape((self.n_models * inp_shape[0],),trainable=False,name='inp_reshape')(out)

        
        out = keras.layers.Dense(np.sum(enc_neurons[:,0]),
                                 activation='relu',name = 'enc_0',
                                 kernel_constraint = self.masks['enc_0'].c,
                                 kernel_initializer = self.masks['enc_0'].i)(out)
        out = keras.layers.BatchNormalization(momentum=0.8,name=f'enc_bn_0')(out)
        for num in range(1,enc_neurons.shape[1]):
            name = f'enc_{num}'
            out = keras.layers.Dense(np.sum(enc_neurons[:,num]),activation='relu',
                                     name = name, kernel_constraint = self.masks[name].c,
                                     kernel_initializer=self.masks[name].i)(out)
            out = keras.layers.BatchNormalization(momentum=0.8,name=f'enc_bn_{num}')(out)
            
        # latent layer: latent_dim outputs per sub-model, concatenated along the feature axis
        out = keras.layers.Dense(enc_neurons.shape[0]*latent_dim,name='enc_out',
                                 kernel_constraint = self.masks['enc_out'].c,
                                 kernel_initializer = self.masks['enc_out'].i)(out) # 
        latent = out
        
        # decoder layers are numbered in reverse so that neuron numbers match with encoder
        for num  in reversed(range(enc_neurons.shape[1])):
            name = f'dec_{num}'
            out = keras.layers.Dense(np.sum(enc_neurons[:,num]),activation='relu',
                                     name=name,
                                     kernel_constraint=self.masks[name].c,
                                     kernel_initializer=self.masks[name].i
                                    )(out)
            out = keras.layers.BatchNormalization(momentum=0.8,name=f'dec_bn_{num}')(out)
            
        # one reconstruction of the input per sub-model, reshaped to (n_models, features)
        out = keras.layers.Dense(enc_neurons.shape[0]*inp_shape[0],name='dec_out',
                                 kernel_constraint=self.masks['dec_out'].c,
                                 kernel_initializer=self.masks['dec_out'].i
                                )(out)
        out = keras.layers.Reshape((enc_neurons.shape[0],inp_shape[0]))(out)
        
        # aes: input -> [reconstructions, latent]; enc: input -> latent; dec: latent -> reconstructions
        self.aes = keras.Model(inputs=inp,outputs=[out,latent])
        self.enc = keras.Model(inputs=inp,outputs=latent)
        self.dec = keras.Model(inputs=latent,outputs=out)
        
        # ---- discriminator: masked Dense + LeakyReLU per layer, one sigmoid output per sub-model ----
        inp = keras.Input(shape=(latent_dim * self.n_models,))
        disc = inp
        for num in range(disc_neurons.shape[1]):
            name = f'disc_{num}'
            disc = keras.layers.Dense(np.sum(disc_neurons[:,num]),
                                      name=name,kernel_constraint=self.masks[name].c,
                                      kernel_initializer=self.masks[name].i
                                     )(disc)
            disc = keras.layers.LeakyReLU(alpha=0.2,name=f'disc_relu_{num}')(disc)
            
        disc = keras.layers.Dense(disc_neurons.shape[0],name='disc_out',
                                  kernel_constraint=self.masks['disc_out'].c,
                                  kernel_initializer=self.masks['disc_out'].i,
                                  activation='sigmoid')(disc)
        
        self.disc = keras.Model(inputs=inp,outputs=disc)
        
 

    def compile(self):
        """Compile with a legacy Adagrad optimizer (learning rate 0.0002) and collect the autoencoder weights."""
        # XXX
        super().compile(optimizer = keras.optimizers.legacy.Adagrad(learning_rate=0.0002))
        # trainable weights of encoder and decoder, updated by the reconstruction step
        self.ae_weights = self.enc.trainable_weights + self.dec.trainable_weights

    @tf.function
    def train_step(self,batch):
        """Run one training step on `batch` and return the per-model reconstruction losses.

        Only the reconstruction gradients are applied in this version; the
        discriminator and "cheat" gradients are computed but the corresponding
        apply_gradients calls are commented out.  All three scalar losses use
        the entry of sub-model 0 only.

        Returns:
            dict mapping str(i) to the mean squared reconstruction error of sub-model i.
        """
        # keep only the first element when the batch arrives as a tuple
        if isinstance(batch,tuple):
            batch = batch[0]
        
        # multiple models need replicated batch to compute loss simultaneously
        multibatch = tf.stack([batch]*self.n_models,axis=1)
 
        with tf.GradientTape() as aetape:
            reconstruct = self.aes(batch)
            # debug: Python print inside tf.function runs while the function is traced
            print(reconstruct[0].shape)
            mse = keras.metrics.mean_squared_error(multibatch,reconstruct[0])
            print(mse.shape)
            # average over the batch axis -> one loss per sub-model
            ae_multiloss = tf.reduce_mean(mse,axis=0)
            print(ae_multiloss.shape)
            
            #ae_loss = tf.reduce_sum(ae_multiloss)
            ae_loss = ae_multiloss[0]
               
        ae_grad = aetape.gradient(ae_loss,self.ae_weights)
        # kept on the instance so the gradients can be inspected from the notebook
        self.last_ae_grad = ae_grad
        self.optimizer.apply_gradients(zip(ae_grad,self.ae_weights))
               
#        rand_low = tf.random.normal((batch.shape[0],self.latent_dim * self.n_models))
        # the same random latent sample is tiled for every sub-model; uses the static batch.shape[0]
        rand_low = tf.random.normal((batch.shape[0],self.latent_dim))
#        rand_low = tf.random.uniform((batch.shape[0],self.latent_dim))
        rand_low = tf.tile(rand_low,(1,self.n_models))
        # `lows` and `labels` are not used by any later statement of this method
        lows = tf.concat([rand_low, reconstruct[1]],axis=0)
        labels = tf.concat([tf.ones((batch.shape[0],self.n_models)), tf.zeros((batch.shape[0],self.n_models))], axis=0)
        labels += 0.05 * tf.random.uniform(labels.shape)
                        
        with tf.GradientTape() as dtape:
            
            # FIXME: perturbe
            # encoded samples are treated as negatives, random latents as positives
            neg_pred = self.disc(reconstruct[1])
            neg_losses = -tf.reduce_mean(tf.math.log(1-neg_pred),axis=0) 
            pos_pred = self.disc(rand_low)
            pos_losses = -tf.reduce_mean(tf.math.log(pos_pred),axis=0)
            disc_losses = neg_losses + pos_losses
#            disc_loss = tf.reduce_mean(disc_losses)
            disc_loss = disc_losses[0]
                
        # computed but not applied (the apply_gradients call below is commented out)
        disc_grads = dtape.gradient(disc_loss,self.disc.trainable_weights)
#        self.optimizer.apply_gradients(zip(disc_grads,self.disc.trainable_weights))
          
        # `all_true` is not used by any later statement of this method
        all_true = tf.ones((batch.shape[0],self.n_models))
    
        with tf.GradientTape() as ctape:
            # encoder loss for making the discriminator output high on encoded samples
            cheat = self.disc(self.enc(batch))
            cheat_losses = -tf.reduce_mean(tf.math.log(cheat),axis=0)
            cheat_loss = cheat_losses[0]
            #cheat_loss = tf.reduce_mean(cheat_losses)
            
        # computed but not applied (the apply_gradients call below is commented out)
        cheat_grads = ctape.gradient(cheat_loss,self.enc.trainable_weights)
#        self.optimizer.apply_gradients(zip(cheat_grads,self.enc.trainable_weights))
        
#        return {
#            'AE loss min' : tf.reduce_min(ae_multiloss),
#            'AE loss max' : tf.reduce_max(ae_multiloss),
#            'disc loss min' : tf.reduce_min(disc_losses),
#            'disc loss max' : tf.reduce_max(disc_losses),
#            'cheat loss min' : tf.reduce_min(cheat_losses),
#            'cheat loss max' : tf.reduce_max(cheat_losses)
#        }

        return { str(i): ae_multiloss[i] for i in range(ae_multiloss.shape[0]) }
        #return { '0': disc_losses[0], '1' : disc_losses[1] }
    
    @tf.function
    def call(self,inp):
        """Forward pass through the autoencoder; returns what self.aes returns."""
        return self.aes(inp)
                            
            

In [None]:
# Build AAEMultiModel5 with 10 identical sub-models: encoder widths 96 and 48,
# discriminator widths 224 and 112.  The commented lines are other layouts that were tried.
#m5 = AAEMultiModel5((X_train.shape[1],),np.array([[96,48],[96,48]]*10),np.array([[224,112],[64,32]]*10))
#m5 = AAEMultiModel5((X_train.shape[1],),np.array([[96,48],[96,48]]),np.array([[224,112],[64,32]]))
m5 = AAEMultiModel5((X_train.shape[1],),np.array([[96,48]]*10),np.array([[224,112]]*10))
#m5 = AAEMultiModel5((X_train.shape[1],),np.array([[4]]*10),np.array([[4,3]]*10))
m5.compile()

In [None]:
# Reduced dataset: only the first 256 rows of X_train, i.e. exactly one batch of 256
# per epoch (shuffle buffer 2048, incomplete batches dropped).
ds = tf.data.Dataset.from_tensor_slices(X_train[:256]).shuffle(2048).batch(256,drop_remainder=True)

In [None]:
# Train m5 for a single epoch on `ds`.
m5.fit(ds,epochs=1)

In [None]:
# Encode every 500th training row with m5.
lows = m5.enc(X_train[::500]).numpy()
#lows = m5.aes(X_train[::500])[1].numpy()

# Symmetric axis range shared by both axes.
lim = 2000000

plt.figure(figsize=(12, 9))
# Each consecutive pair of columns is plotted as one labelled 2D scatter.
for mod in range(lows.shape[1] // 2):
    xs = lows[:, 2 * mod]
    ys = lows[:, 2 * mod + 1]
    plt.scatter(xs, ys, marker='.', label=str(mod))

plt.legend()
plt.xlim((-lim, lim))
plt.ylim((-lim, lim))
plt.show()

In [None]:
# Number of gradient entries stored by the last train_step (one per entry of ae_weights).
len(m5.last_ae_grad)

In [None]:
# Print the sum of absolute values of every autoencoder weight whose name contains 'kernel'.
for w in m5.ae_weights:
    if 'kernel' not in w.name:
        continue
    print(w.name, np.sum(np.abs(w.numpy())))

In [None]:
# Same report as the previous cell: sum of absolute values of each 'kernel' weight.
# NOTE(review): exact duplicate of the cell above — presumably kept to compare values
# at two points in time; confirm or remove.
for w in m5.ae_weights:
    if 'kernel' in w.name: print(w.name,np.sum(np.abs(w.numpy())))

In [None]:
# Display the entry at index 4 of the autoencoder weight list.
m5.ae_weights[4]

In [None]:
# NOTE(review): `_glorot` is not defined anywhere in the visible part of this notebook —
# confirm it still exists on a fresh kernel, otherwise this cell raises NameError.
_glorot(shape=(2,5))

In [None]:
# Scratch check of np.sum on a small array (1 + 2 + 3).
np.sum(np.array([1,2,3]))

In [None]:
# Rebind the module-level `_random_init` to a Glorot-uniform initializer with a fixed seed
# and draw a 2x3 sample.  Code that looks up `_random_init` at call time uses this object
# from now on.
_random_init = keras.initializers.GlorotUniform(seed=42)
_random_init((2,3))

In [None]:
# Base initializer (Glorot uniform, fixed seed 42) used by _SparseInitializer to fill
# each diagonal block.
_random_init = keras.initializers.GlorotUniform(seed=42)

class _SparseConstraint(keras.constraints.Constraint):
    """Kernel constraint that keeps only the block-diagonal part of a weight matrix.

    Block number k covers `left[k]` rows and `right[k]` columns; everything
    outside the blocks is multiplied by zero.
    """

    def __init__(self,left,right):
        assert len(left) == len(right)
        block_mask = np.zeros((np.sum(left),np.sum(right)),dtype=np.float32)
        # Walk along the diagonal: each block starts where the previous one ended.
        row_lo = col_lo = 0
        for row_hi, col_hi in zip(np.cumsum(left), np.cumsum(right)):
            block_mask[row_lo:row_hi, col_lo:col_hi] = 1.
            row_lo, col_lo = row_hi, col_hi

        self.mask = tf.convert_to_tensor(block_mask)

    @tf.function
    def __call__(self,w):
        """Return `w` with all entries outside the diagonal blocks set to zero."""
        constrained = w * self.mask
        return constrained

class _SparseInitializer(keras.initializers.Initializer):
    """Initializer producing a block-diagonal kernel.

    Block number k has `left[k]` rows and `right[k]` columns and is filled by
    the module-level `_random_init`; everything outside the blocks is zero.
    """

    def __init__(self,left,right):
        assert len(left) == len(right)
        self.left = left
        self.right = right
        # Block boundaries: cumulative sizes with a leading zero.
        self.idxl = np.concatenate((np.zeros((1,),dtype=np.int32),np.cumsum(left)))
        self.idxr = np.concatenate((np.zeros((1,),dtype=np.int32),np.cumsum(right)))

    def __call__(self,shape,dtype=None,**kwargs):
        """Return the initial kernel as a tensor.

        Args:
            shape: requested variable shape (list, tuple or TensorShape); must equal
                (sum(left), sum(right)).
            dtype: requested dtype; `None` falls back to the Keras default float type.
            **kwargs: extra keyword arguments an initializer caller may pass;
                accepted for interface compatibility and ignored.

        Raises:
            AssertionError: if `shape` does not match the block layout.
        """
        rows = int(np.sum(self.left))
        cols = int(np.sum(self.right))
        # Normalise the shape first: comparing a tuple with a list is always False.
        assert tf.TensorShape(shape).as_list() == [rows, cols], \
            f'expected shape {[rows, cols]}, got {shape}'

        # `dtype` may legitimately be None (it is the declared default).
        np_dtype = tf.as_dtype(dtype if dtype is not None else keras.backend.floatx()).as_numpy_dtype

        init = np.zeros((rows,cols),dtype=np_dtype)
        # NOTE(review): `_random_init` carries a fixed seed; depending on the Keras version
        # repeated calls with equal shapes may return identical values, which would make
        # equally sized blocks start identical — confirm this is intended.
        for mod in range(len(self.left)):
            init[self.idxl[mod]:self.idxl[mod+1],self.idxr[mod]:self.idxr[mod+1]] = _random_init((self.left[mod],self.right[mod])).numpy()

        return tf.convert_to_tensor(init)
    
    
def _masks(left,right):
    """Return the `kernel_initializer` / `kernel_constraint` keyword arguments for a
    Dense layer whose kernel uses the block sizes `left` (rows) and `right` (columns)."""
    initializer = _SparseInitializer(left,right)
    constraint = _SparseConstraint(left,right)
    return dict(kernel_initializer=initializer, kernel_constraint=constraint)



class AAEMultiModel6(keras.models.Model):
    """Several adversarial autoencoders packed side by side into shared Keras layers.

    Every Dense layer is as wide as the sum of the per-model widths.  Kernels
    built with `_masks` are block diagonal, so the neurons of one sub-model are
    connected only to neurons of the same sub-model.  The first encoder layer
    `enc_0` is an ordinary (unmasked) Dense layer over the input.

    Args:
        inp_shape: shape of one input sample; inp_shape[0] is used as the feature count.
        enc_neurons: 2D integer array indexed (models, layers) with encoder layer
            widths; the decoder uses the same widths in reverse order.
        disc_neurons: 2D integer array indexed (models, layers) with discriminator
            layer widths; must have the same number of rows as enc_neurons.
        latent_dim: number of latent dimensions per sub-model.

    Attributes created in __init__: n_models, latent_dim, aes, enc, dec, disc.
    """
    # enc_neurons: (models,layers)

    def __init__(self,inp_shape,enc_neurons,disc_neurons,latent_dim=2):
        super().__init__()

        self.n_models = enc_neurons.shape[0]
        self.latent_dim = latent_dim

        assert disc_neurons.shape[0] == self.n_models

        # ---- encoder ----
        inp = keras.Input(shape=inp_shape)
        out = inp

        # unmasked first layer
        out = keras.layers.Dense(np.sum(enc_neurons[:,0]),activation='relu',name = 'enc_0')(out)
        out = keras.layers.BatchNormalization(momentum=0.8,name='enc_bn_0')(out)

        for num in range(1,enc_neurons.shape[1]):
            name = f'enc_{num}'
            out = keras.layers.Dense(np.sum(enc_neurons[:,num]),activation='relu',
                                     name = name, **_masks(enc_neurons[:,num-1],enc_neurons[:,num]))(out)
            out = keras.layers.BatchNormalization(momentum=0.8,name=f'enc_bn_{num}')(out)

        # latent layer: latent_dim outputs per sub-model, concatenated along the feature axis
        out = keras.layers.Dense(self.n_models*latent_dim,name='enc_out',
                                 **_masks(enc_neurons[:,-1],[latent_dim]*self.n_models))(out)
        latent = out

        # ---- decoder ----
        out = keras.layers.Dense(np.sum(enc_neurons[:,-1]),activation='relu',name=f'dec_{enc_neurons.shape[1]}',
                                 **_masks([latent_dim]*self.n_models,enc_neurons[:,-1]))(out)
        out = keras.layers.BatchNormalization(momentum=0.8,name=f'dec_bn_{enc_neurons.shape[1]}')(out)

        # decoder layers are numbered in reverse so that neuron numbers match with encoder
        for num in reversed(range(enc_neurons.shape[1]-1)):
            name = f'dec_{num}'
            out = keras.layers.Dense(np.sum(enc_neurons[:,num]),activation='relu',name=name,
                                     **_masks(enc_neurons[:,num+1],enc_neurons[:,num]))(out)
            out = keras.layers.BatchNormalization(momentum=0.8,name=f'dec_bn_{num}')(out)

        # one reconstruction of the input per sub-model, reshaped to (n_models, features)
        out = keras.layers.Dense(self.n_models*inp_shape[0],name='dec_out',activation='relu',
                                    **_masks(enc_neurons[:,0],[inp_shape[0]]*self.n_models))(out)
        out = keras.layers.Reshape((self.n_models,inp_shape[0]))(out)

        # aes: input -> [reconstructions, latent]; enc: input -> latent; dec: latent -> reconstructions
        self.aes = keras.Model(inputs=inp,outputs=[out,latent])
        self.enc = keras.Model(inputs=inp,outputs=latent)
        self.dec = keras.Model(inputs=latent,outputs=out)

        # ---- discriminator: one output per sub-model, no output activation ----
        inp = keras.Input(shape=(latent_dim * self.n_models,))
        disc = inp
        disc = keras.layers.Dense(np.sum(disc_neurons[:,0]),name='disc_0',
                                  **_masks([latent_dim]*self.n_models,disc_neurons[:,0]))(disc)
        # Explicit layer index: this used to read the loop variable `num` left over from
        # the decoder loop (undefined for a single-layer encoder, 0 otherwise).
        disc = keras.layers.LeakyReLU(alpha=0.2,name='disc_relu_0')(disc)

        for num in range(1,disc_neurons.shape[1]):
            name = f'disc_{num}'
            disc = keras.layers.Dense(np.sum(disc_neurons[:,num]),name=name,
                                      **_masks(disc_neurons[:,num-1],disc_neurons[:,num]))(disc)
            disc = keras.layers.LeakyReLU(alpha=0.2,name=f'disc_relu_{num}')(disc)

        disc = keras.layers.Dense(self.n_models,name='disc_out',
                                  **_masks(disc_neurons[:,-1],[1]*self.n_models))(disc)

        self.disc = keras.Model(inputs=inp,outputs=disc)

    def compile(self):
        """Compile with a legacy Adagrad optimizer (learning rate 0.0002) and collect the autoencoder weights."""
        # XXX
        super().compile(optimizer = keras.optimizers.legacy.Adagrad(learning_rate=0.0002))
        # trainable weights of encoder and decoder, updated by the reconstruction step
        self.ae_weights = self.enc.trainable_weights + self.dec.trainable_weights

    @tf.function
    def train_step(self,batch):
        """Run one training step: reconstruction, discriminator, then encoder ("cheat") update.

        Returns:
            dict mapping str(i) to the mean squared reconstruction error of sub-model i.
        """
        # keep only the first element when the batch arrives as a tuple
        if isinstance(batch,tuple):
            batch = batch[0]

        # multiple models need replicated batch to compute loss simultaneously
        multibatch = tf.stack([batch]*self.n_models,axis=1)

        # 1) reconstruction: sum of the per-model mean squared errors
        with tf.GradientTape() as aetape:
            reconstruct = self.aes(batch)
            mse = keras.metrics.mean_squared_error(multibatch,reconstruct[0])
            # average over the batch axis -> one loss per sub-model
            ae_multiloss = tf.reduce_mean(mse,axis=0)

            ae_loss = tf.reduce_sum(ae_multiloss)

        ae_grad = aetape.gradient(ae_loss,self.ae_weights)
        # kept on the instance so the gradients can be inspected from the notebook
        self.last_ae_grad = ae_grad
        self.optimizer.apply_gradients(zip(ae_grad,self.ae_weights))

        # the same random latent sample is tiled for every sub-model
        rand_low = tf.random.normal((tf.shape(batch)[0],self.latent_dim))
        rand_low = tf.tile(rand_low,(1,self.n_models))

        # 2) discriminator: lower its output on encoded samples and raise it on random
        #    latents; each output is scaled by a random factor in [1, 1.05) (perturbation)
        with tf.GradientTape() as dtape:
            neg_pred = self.disc(reconstruct[1])
            neg_losses = tf.reduce_mean(neg_pred*tf.random.uniform(tf.shape(neg_pred),1.,1.05),axis=0)
            pos_pred = self.disc(rand_low)
            pos_losses = -tf.reduce_mean(pos_pred*tf.random.uniform(tf.shape(pos_pred),1.,1.05),axis=0)
            disc_losses = neg_losses + pos_losses
            disc_loss = tf.reduce_mean(disc_losses)

        disc_grads = dtape.gradient(disc_loss,self.disc.trainable_weights)
        self.optimizer.apply_gradients(zip(disc_grads,self.disc.trainable_weights))

        # 3) "cheat": update the encoder so that the discriminator output on encoded
        #    samples grows
        with tf.GradientTape() as ctape:
            cheat = self.disc(self.enc(batch))
            cheat_losses = -tf.reduce_mean(cheat*tf.random.uniform(tf.shape(cheat),1.,1.05),axis=0)
            cheat_loss = tf.reduce_mean(cheat_losses)

        cheat_grads = ctape.gradient(cheat_loss,self.enc.trainable_weights)
        self.optimizer.apply_gradients(zip(cheat_grads,self.enc.trainable_weights))

        return { str(i): ae_multiloss[i] for i in range(ae_multiloss.shape[0]) }

    @tf.function
    def call(self,inp):
        """Forward pass through the autoencoder; returns what self.aes returns."""
        return self.aes(inp)
                            
            

In [None]:
# Quick check of _masks with blocks of 2x4 and 3x5: show the resulting 5x9 constraint mask.
m=_masks([2,3],[4,5])
m['kernel_constraint'].mask

In [None]:
# Draw the matching 5x9 initial kernel from the initializer built in the previous cell.
m['kernel_initializer']([5,9],tf.float32)

In [None]:
# Build and compile AAEMultiModel6 from the layer-size arrays `en` (encoder) and `dn`
# (discriminator).  The commented lines are other layouts that were tried.
# NOTE(review): `en` and `dn` are only displayed, never assigned, in the visible part of
# this notebook — confirm they are defined before this cell on a fresh kernel.
#m6 = AAEMultiModel6((X_train.shape[1],),np.array([[96,48],[96,48]]*10),np.array([[224,112],[64,32]]*10))
#m6 = AAEMultiModel6((X_train.shape[1],),np.array([[96,48],[96,48]]),np.array([[224,112],[64,32]]))
#m6 = AAEMultiModel6((X_train.shape[1],),np.array([[96,48]]*10),np.array([[224,112]]*10))
#m6 = AAEMultiModel6((X_train.shape[1],),np.array([[4]]*10),np.array([[4,3]]*10))
m6 = AAEMultiModel6((X_train.shape[1],),en,dn)
m6.compile()

In [None]:
# Print the sum of absolute values of every autoencoder weight whose name contains 'kernel'.
for w in m6.ae_weights:
    if 'kernel' not in w.name:
        continue
    print(w.name, np.sum(np.abs(w.numpy())))

In [None]:
# Encode every 200th training row with m6.
lows = m6.enc(X_train[::200]).numpy()
#lows = m5.aes(X_train[::500])[1].numpy()

# Symmetric axis range shared by both axes.
lim = 2

plt.figure(figsize=(12, 9))
# Each consecutive pair of columns is plotted as one labelled 2D scatter.
for mod in range(lows.shape[1] // 2):
    xs = lows[:, 2 * mod]
    ys = lows[:, 2 * mod + 1]
    plt.scatter(xs, ys, marker='.', label=str(mod))

plt.legend()
plt.xlim((-lim, lim))
plt.ylim((-lim, lim))
plt.show()

In [None]:
# Full training set: shuffle buffer 2048, batches of 256, incomplete last batch dropped.
ds = tf.data.Dataset.from_tensor_slices(X_train).shuffle(2048).batch(256,drop_remainder=True)

In [None]:
# Train m6 for 5 epochs on `ds`.
m6.fit(ds,epochs=5)

In [None]:
# Display `en`, the array passed to AAEMultiModel6 as enc_neurons.
en

In [None]:
# Display `dn`, the array passed to AAEMultiModel6 as disc_neurons.
dn