In [1]:
import tensorflow as tf
import pandas as pd
import numpy as np
import keras
import h5py
from keras import layers,losses, optimizers

  _warn(("h5py is running against HDF5 {0} when it was built against {1}, "


In [2]:
from keras import layers, Model, Input

In [3]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

In [4]:
unsub = '../data/Dataset_Specific_Unlabelled.h5'

In [5]:
with h5py.File(unsub, 'r') as f:
    print("Keys inside file:", list(f.keys()))
    ds = f['jet'][:]

Keys inside file: ['jet']


In [6]:
   print("Dataset shape:", ds.shape) 

Dataset shape: (60000, 125, 125, 8)


In [7]:
ds /= ds.max()


In [8]:
def augment(x):
    noise = tf.random.normal(tf.shape(x), stddev=0.05)
    x = x + noise

    # Random channel dropout
    mask = tf.cast(tf.random.uniform((8,)) > 0.2, tf.float32)
    x = x * mask

    return tf.clip_by_value(x, 0.0, 1.0)


In [9]:
def ssl_gen():
    for i in range(len(ds)):
        x = ds[i].astype("float32")
        yield augment(x), augment(x)

In [10]:
def res_block(x, filters, stride=1):
    shortcut = x

    x = layers.Conv2D(filters, 3, stride, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(filters, 3, 1, padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)

    if stride != 1 or shortcut.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, 1, stride, padding="same", use_bias=False)(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    x = layers.Add()([x, shortcut])
    return layers.ReLU()(x)


In [11]:
ssl_ds = tf.data.Dataset.from_generator(
    ssl_gen,
    output_signature=(
        tf.TensorSpec(shape=(125,125,8), dtype=tf.float32),
        tf.TensorSpec(shape=(125,125,8), dtype=tf.float32)
    )
).batch(32).prefetch(tf.data.AUTOTUNE)


In [12]:
inputs = layers.Input(shape=(125, 125, 8))

x = layers.Conv2D(64, 7, strides=2, padding="same", use_bias=False)(inputs)
x = layers.BatchNormalization()(x)
x = layers.ReLU()(x)
x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

# Residual stages
x = res_block(x, 64)
x = res_block(x, 64)

x = res_block(x, 128, stride=2)
x = res_block(x, 128)

x = res_block(x, 256, stride=2)
x = res_block(x, 256)

x = layers.GlobalAveragePooling2D()(x)
features = layers.Dense(128)(x)

encoder = tf.keras.Model(inputs, features, name="resnet15_encoder")


In [13]:
projection = keras.Sequential([
    layers.Dense(256, use_bias=False),
    layers.BatchNormalization(),
    layers.ReLU(),

    layers.Dense(128)
], name="projection_head")


In [14]:
def off_diagonal(x):
    return tf.reshape(
        x - tf.linalg.diag(tf.linalg.diag_part(x)), [-1]
    )

In [15]:
def vicreg_loss(z1, z2, sim=25.0, var=25.0, cov=1.0):
    # Invariance
    repr_loss = tf.reduce_mean(tf.square(z1 - z2))

    # Variance
    eps = 3e-4
    std_z1 = tf.sqrt(tf.math.reduce_variance(z1, axis=0) + eps)
    std_z2 = tf.sqrt(tf.math.reduce_variance(z2, axis=0) + eps)
    var_loss = tf.reduce_mean(tf.nn.relu(1 - std_z1)) + \
               tf.reduce_mean(tf.nn.relu(1 - std_z2))

    # Covariance
    z1 = z1 - tf.reduce_mean(z1, axis=0)
    z2 = z2 - tf.reduce_mean(z2, axis=0)

    batch_size = tf.cast(tf.shape(z1)[0], tf.float32)
    cov_z1 = tf.matmul(z1, z1, transpose_a=True) / (batch_size - 1)
    cov_z2 = tf.matmul(z2, z2, transpose_a=True) / (batch_size - 1)

    dim = tf.cast(tf.shape(z1)[1], tf.float32)
    cov_loss = (
        tf.reduce_sum(tf.square(off_diagonal(cov_z1))) +
        tf.reduce_sum(tf.square(off_diagonal(cov_z2)))
    ) / dim

    return sim * repr_loss + var * var_loss + cov * cov_loss


In [16]:
optimizer = tf.keras.optimizers.Adam(3e-4)

@tf.function
def train_step(x1, x2):
    with tf.GradientTape() as tape:
        z1 = projection(encoder(x1, training=True))
        z2 = projection(encoder(x2, training=True))
        loss = vicreg_loss(z1, z2)
    grads = tape.gradient(loss, encoder.trainable_variables + projection.trainable_variables)
    optimizer.apply_gradients(zip(grads, encoder.trainable_variables + projection.trainable_variables))
    return loss


In [17]:
for epoch in range(100):
    for x1, x2 in ssl_ds:
        loss = train_step(x1, x2)
    print(f"Epoch {epoch}: VICReg loss = {loss.numpy():.4f}")
x = ds[:1024].astype("float32")  
z = encoder.predict(x, batch_size=8)
print(z.std(axis=0).mean())



Epoch 0: VICReg loss = 15.6760
Epoch 1: VICReg loss = 14.4666
Epoch 2: VICReg loss = 13.3958
Epoch 3: VICReg loss = 12.2484
Epoch 4: VICReg loss = 12.7681
Epoch 5: VICReg loss = 12.0301
Epoch 6: VICReg loss = 11.5664
Epoch 7: VICReg loss = 11.6781
Epoch 8: VICReg loss = 12.8552
Epoch 9: VICReg loss = 11.1211
Epoch 10: VICReg loss = 10.8491
Epoch 11: VICReg loss = 11.1164
Epoch 12: VICReg loss = 11.2775
Epoch 13: VICReg loss = 11.2218
Epoch 14: VICReg loss = 11.4399
Epoch 15: VICReg loss = 11.0096
Epoch 16: VICReg loss = 10.7894
Epoch 17: VICReg loss = 11.1722
Epoch 18: VICReg loss = 11.2748
Epoch 19: VICReg loss = 10.6234
Epoch 20: VICReg loss = 10.9410
Epoch 21: VICReg loss = 10.5002
Epoch 22: VICReg loss = 11.4313
Epoch 23: VICReg loss = 11.1509
Epoch 24: VICReg loss = 10.1638
Epoch 25: VICReg loss = 10.7625
Epoch 26: VICReg loss = 10.7461
Epoch 27: VICReg loss = 10.5618
Epoch 28: VICReg loss = 10.2493
Epoch 29: VICReg loss = 10.3594
Epoch 30: VICReg loss = 10.5048
Epoch 31: VICReg l

In [18]:
x = ds[:1024].astype("float32")  
z = encoder.predict(x, batch_size=8)
print(z.std(axis=0).mean())


0.56758875


In [19]:
encoder.trainable = False
encoder.save("../models/vicreg_encoder.h5")


