In [None]:
import os

# Restrict TensorFlow to the chosen GPU(s); set according to available resources.
# CUDA only honours CUDA_VISIBLE_DEVICES if it is set before CUDA is initialised,
# so it is set before importing TensorFlow to guarantee it takes effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import tensorflow as tf

##  DATA

### To be done by the user
Build the datasets: depending on the encoders, these might be image-caption pairs, image-image pairs, etc.
To make the code work, assign a `tf.data.Dataset` to `train_ds` and, optionally, a validation dataset to `val_ds`.

In [None]:
# Placeholders to be replaced with real datasets; val_ds may remain None (optional).
train_ds, val_ds = None, None

## Config

In [None]:
# Load the experiment configuration. Later cells read model dimensions, encoder
# settings, loss weights, the learning rate and the epoch count from `cfg`.
from helpers import helpers
cfg = helpers.load_config()

## Build models
VICReg is a general-purpose architecture that should allow multi-modal representation learning. To test its multi-modal capabilities, this VICReg setup consists of an image encoder (ResNet-like architecture) and a text encoder (vanilla Transformer). Two MLPs are used as the expander models.

In [None]:
from models import models

# Dimensions shared by both branches: encoder output size, expander output size,
# and the number of expander layers.
rep_dim, emb_dim, exp_layers = cfg.representation_dim, cfg.embedding_dim, cfg.n_expander_layers

# Architecture-specific settings for each encoder.
encoder_1_cfg, encoder_2_cfg = cfg.encoder_1_config, cfg.encoder_2_config

In [None]:
# Input specification: must agree with the datasets and modalities built in the DATA section.
img_size, seq_len, vocab_size = (
    (256, 256),  # image spatial size (presumably height, width)
    25,          # text sequence length fed to the text encoder
    10_000,      # text vocabulary size
)

### Encoder 1

In [None]:
# Input shape for encoder 1: spatial size followed by 3 channels (channels-last,
# presumably RGB). Unpacking accepts img_size as a tuple or a list (e.g. when it
# comes from a YAML/JSON config), whereas `img_size + (3,)` only works for tuples.
input_shape_1 = (*img_size, 3)

In [None]:
# Encoder 1: ResNet-style image encoder whose output size is rep_dim.
# NOTE(review): `blocks` is fed from `n_channels` - confirm this naming matches
# what models.build_ResNet expects.
encoder_1 = models.build_ResNet(
    input_shape_1,
    blocks=encoder_1_cfg.n_channels,
    z_dim=rep_dim,
)

# Expander 1: MLP mapping encoder-1 representations to emb_dim-sized embeddings.
expander_1 = models.build_expander(
    embedding_dim=emb_dim,
    expander_layers=exp_layers,
)

### Encoder 2

In [None]:
# Encoder 2: Transformer text encoder (inputs presumably token ids of length seq_len).
# Arguments are positional, in the order expected by models.TextEncoder.
encoder_2 = models.TextEncoder(
    encoder_2_cfg.n_layers,
    seq_len,
    vocab_size,
    encoder_2_cfg.d_model,
    encoder_2_cfg.num_heads,
    encoder_2_cfg.mlp_dim,
    encoder_2_cfg.dropout,
)

# Expander 2: same expander configuration as branch 1, so both branches emit emb_dim-sized embeddings.
expander_2 = models.build_expander(
    embedding_dim=emb_dim,
    expander_layers=exp_layers,
)

## VICReg

In [None]:
# The three VICReg loss terms provided by the models module.
V_loss = models.V_loss  # variance term
I_loss = models.I_loss  # invariance term
C_loss = models.C_loss  # covariance term

In [None]:
# Relative weights of the variance / invariance / covariance loss terms, taken from the config.
params = {
    "V_loss_weight": cfg.loss_weights.variance,
    "I_loss_weight": cfg.loss_weights.invariance,
    "C_loss_weight": cfg.loss_weights.covariance,
}

# Assemble the two-branch VICReg model (encoder + expander per branch).
model = models.VICReg(encoder_1, encoder_2, expander_1, expander_2, params)

# Compile with Adam and the three loss functions.
# NOTE(review): models.VICReg.compile takes these positionally - confirm the order there.
opt = tf.keras.optimizers.Adam(learning_rate=cfg.learning_rate)
model.compile(opt, V_loss, I_loss, C_loss)

In [None]:
# Fit the model. The DATA section leaves train_ds as a None placeholder; fail fast
# with an actionable message instead of an obscure error from model.fit(None).
# val_ds is optional - validation_data=None simply skips validation.
if train_ds is None:
    raise ValueError(
        "train_ds is None: assign a tf.data.Dataset to `train_ds` in the DATA section "
        "before fitting (val_ds is optional)."
    )
history = model.fit(train_ds, validation_data=val_ds, epochs=cfg.epochs)

## Evaluation
Check that the representations do not collapse:
1. informational collapse (features become redundant / highly correlated; in the extreme, all features are identical)
2. sample-wise collapse (identical representation for all samples)