In [None]:
#Imports
import os, sys
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import functools
import cuml

from VICReg.vicreg        import GeneralMultipleVICReg
from VICReg.vicreg_utils  import create_projector, create_resnet, create_adam_opt
from VICReg.dataset_utils import preprocess_ds
from VICReg.augmentations import custom_augment_image
from VICReg.warmup_learning_rate import WarmUpLR
from VICReg.warmupcosine import WarmUpCosine
from VICReg.classifier    import ClusterClassifier, classifier_class

In [None]:
#Hyperparameters

# tf.data.AUTOTUNE is passed to map/prefetch calls below so tf.data chooses
# the parallelism and buffer sizes itself.
AUTO = tf.data.AUTOTUNE
SEED = 42

# Seed the global NumPy and TensorFlow generators up front so that random
# operations drawing from them give the same results after a kernel restart.
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Model and batch sizes.
IM_SIZE     = 224      # side length used for the (IM_SIZE, IM_SIZE, 3) shapes
PROJECT_DIM = 2048     # NOTE(review): not referenced by the cells shown here
BATCH_SIZE  = 128
EPOCHS      = 100

# Step count per epoch is derived from a nominal dataset size; the floor
# division leaves a trailing partial batch out of the count.
DATASET_SIZE    = 100_000
STEPS_PER_EPOCH = DATASET_SIZE // BATCH_SIZE

# Warm-up length: 0.1% of the epochs, then truncated to whole steps.
WARMUP_EPOCHS = 0.001 * EPOCHS
WARMUP_STEPS  = int(STEPS_PER_EPOCH * WARMUP_EPOCHS)

SHUFFLE_BUFFER = 1 << 10

# Data and checkpoint locations are left empty here; set them before running.
path_train      = ""
path_test       = ""
MODEL_SAVE_PATH = ""

In [None]:
#Functions

# Learning-rate schedule passed to every optimizer: base rate 1e-4, warm-up
# starting from 0.0 for WARMUP_STEPS steps, horizon of EPOCHS full epochs.
lr_decayed_fn = WarmUpCosine(
    learning_rate_base=1e-4,
    total_steps=EPOCHS * STEPS_PER_EPOCH,
    warmup_learning_rate=0.0,
    warmup_steps=WARMUP_STEPS,
)


def augment_im(x):
    """Apply ``custom_augment_image`` to ``x`` with (IM_SIZE, IM_SIZE, 3) as both input and output shape."""
    patch_shape = (IM_SIZE, IM_SIZE, 3)
    return custom_augment_image(x, input_shape=patch_shape, output_shape=patch_shape)

In [None]:
#Dataset

def _to_view_pair(sample):
    """Turn one dataset record into a pair of augmented views.

    The first view is the augmented 'original_images' entry. The second is
    the 'masked_images' entry stacked three times along a new last axis and
    then augmented.
    """
    original_view = augment_im(sample['original_images'])
    masked = sample['masked_images']
    masked_view = augment_im(tf.stack([masked, masked, masked], axis=-1))
    return original_view, masked_view


# NOTE(review): load_dataset is neither imported nor defined in the cells
# shown here — confirm where it comes from before a fresh-kernel run.
image_train_ds = load_dataset(path_train)
image_valid_ds = load_dataset(path_test)

# Validation pairs are batched and prefetched directly.
patch_valid_ds = (
    image_valid_ds
    .map(_to_view_pair, num_parallel_calls=AUTO)
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)

# Training pairs go through preprocess_ds, which receives the batch size,
# seed, prefetch setting and shuffle buffer size.
patch_train_ds = image_train_ds.map(_to_view_pair, num_parallel_calls=AUTO)
patch_train_ds = preprocess_ds(
    patch_train_ds,
    batch_size=BATCH_SIZE,
    seed=SEED,
    pre=AUTO,
    shuffle_no=SHUFFLE_BUFFER,
    rei=True,
)

In [None]:
# Two encoder/projector branches, constructed in the order
# encoder1, projector1, encoder2, projector2.
input_shape = (IM_SIZE, IM_SIZE, 3)
encoder1, projector1 = create_resnet(input_shape), create_projector()
encoder2, projector2 = create_resnet(input_shape), create_projector()

# One optimizer per network, each created from the same lr_decayed_fn schedule.
optimizer_enc1, optimizer_proj1, optimizer_enc2, optimizer_proj2 = (
    create_adam_opt(lr_decayed_fn) for _ in range(4)
)

# Parallel lists: position k in each list belongs to branch k.
enc_list  = [encoder1, encoder2]
proj_list = [projector1, projector2]
optimizer_list_enc  = [optimizer_enc1, optimizer_enc2]
optimizer_list_proj = [optimizer_proj1, optimizer_proj2]
# Nested as [encoder optimizers, projector optimizers].
optimizer_list = [optimizer_list_enc, optimizer_list_proj]

In [None]:
# Restore every network from MODEL_SAVE_PATH; the file prefix carries the
# network's position in its list.
for i, encoder in enumerate(enc_list):
    encoder_ckpt = os.path.join(MODEL_SAVE_PATH, f'encoder_weights_{i}')
    encoder.load_weights(encoder_ckpt)

for i, projector in enumerate(proj_list):
    projector_ckpt = os.path.join(MODEL_SAVE_PATH, f'projector_weights_{i}')
    projector.load_weights(projector_ckpt)

In [None]:
# Combine both encoder/projector branches into one VICReg model and compile
# it with the nested [encoder optimizers, projector optimizers] list.
vicreg = GeneralMultipleVICReg(
    encoder_list=enc_list,
    projector_list=proj_list,
    encoder_indices=[0, 1],
    projector_indices=[0, 1],
)
vicreg.compile(optimizer=optimizer_list)

In [None]:
# Train on the paired-view dataset for EPOCHS epochs, with patch_valid_ds
# supplied as validation data and no callbacks.
vicreg.fit(
    patch_train_ds,
    epochs=EPOCHS,
    callbacks=[],
    validation_data=patch_valid_ds,
)

In [None]:
# Write every network's weights under MODEL_SAVE_PATH, using the same
# indexed prefixes that the loading cell reads.
for i, encoder in enumerate(enc_list):
    encoder_ckpt = os.path.join(MODEL_SAVE_PATH, f'encoder_weights_{i}')
    encoder.save_weights(encoder_ckpt)

for i, projector in enumerate(proj_list):
    projector_ckpt = os.path.join(MODEL_SAVE_PATH, f'projector_weights_{i}')
    projector.save_weights(projector_ckpt)

In [None]:
# Tissue Classification Patches

def _labels_as_column(images, labels):
    """Return ``images`` unchanged and ``labels`` reshaped to (-1, 1)."""
    return images, tf.reshape(labels, shape=(-1, 1))


# Training set: augmented 'original_images' paired with 'tissue_types'.
# Field selection and augmentation happen in a single parallel map.
task_train_ds = image_train_ds.map(
    lambda x: (augment_im(x['original_images']), x['tissue_types']),
    num_parallel_calls=AUTO,
)
task_train_ds = preprocess_ds(
    task_train_ds,
    batch_size=BATCH_SIZE,
    seed=SEED,
    pre=AUTO,
    shuffle_no=SHUFFLE_BUFFER,
    rei=True,
)
task_train_ds = task_train_ds.map(_labels_as_column, num_parallel_calls=AUTO)

# Test set: images are cast to float32, then batched and prefetched.
# NOTE(review): test images skip augment_im while training images go through
# it — confirm both paths yield the same image size and value range.
task_test_ds = (
    image_valid_ds
    .map(
        lambda x: (tf.cast(x['original_images'], tf.float32), x['tissue_types']),
        num_parallel_calls=AUTO,
    )
    .batch(BATCH_SIZE)
    .prefetch(AUTO)
)
task_test_ds = task_test_ds.map(_labels_as_column, num_parallel_calls=AUTO)

In [None]:
# Classification model: a fresh head combined with the first encoder of the
# VICReg model.
classifier_head = classifier_class()
classifier = ClusterClassifier(vicreg.encoder_list[0], classifier_head)

# from_logits=False makes the loss read model outputs as probabilities.
# NOTE(review): assumes the head ends in a softmax — confirm in classifier_class.
sparse_ce_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
classifier.compile(optimizer='adam', loss=sparse_ce_loss, metrics=["accuracy"])

classifier.fit(task_train_ds, epochs=100, validation_data=task_test_ds)

In [None]:
# Evaluate the trained classifier on task_test_ds (built from the path_test
# data); the call's return value is the cell output.
classifier.evaluate(task_test_ds)

In [None]:
# Store the trained classifier's weights under MODEL_SAVE_PATH, the same
# location used for the encoder and projector weights.
classifier_ckpt = os.path.join(MODEL_SAVE_PATH, 'classifier_weights')
classifier.save_weights(classifier_ckpt)