Local Connection-guided Contrastive Learning (L-CoCL)
---
* The G-CoCL variant can be implemented with this same code
* https://github.com/sayakpaul/Supervised-Contrastive-Learning-in-TensorFlow-2

In [None]:
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np
import losses
import time
import tensorflow as tf
import os
import keras
import random
import pickle

In [None]:
# Load the preprocessed dataset dict (later cells read its 'specEEG' and
# 'LABEL' entries).
# NOTE(review): pickle can execute arbitrary code while loading; only open
# files that come from a trusted source.
with open("L_CoCL_label.pickle", "rb") as handle:
    DATA = pickle.load(handle)

In [None]:
# Synchronous data-parallel strategy; devices=None lets TensorFlow pick the
# available local devices (its documented default).
mirrored_strategy = tf.distribute.MirroredStrategy(devices=None)

In [None]:
import numpy as np
import math
from tensorflow.keras.utils import Sequence

# Full dataset: per-sample spectrogram inputs and their labels.
spec, labels = DATA['specEEG'], DATA['LABEL']

class CustomDataset(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(X, X_aug, y)`` batches for contrastive training.

    For every sample, ``X`` holds the stored spectrogram, ``X_aug`` holds a
    copy with one randomly chosen augmentation applied (``spec_augment``,
    ``timeshift`` or ``block_masking``), and ``y`` holds its label.
    All augmentations work on copies, so the arrays passed in as ``specEEG``
    are never modified.

    Args:
        specEEG: indexable collection of samples; each sample must be
            assignable into an array of shape ``target_size``.
        labels: per-sample labels, indexed with the same IDs as ``specEEG``.
        batch_size: number of samples per batch.
        target_size: per-sample shape used to allocate the batch arrays.
        shuffle: if True, ``on_epoch_end`` shuffles the sample order.
        n_classes: stored on the instance; not used by this class.
    """

    def __init__(self, specEEG, labels, batch_size, target_size=(1, 55, 114), shuffle=False, n_classes=1):
        super().__init__()
        self.batch_size = batch_size
        self.dim = target_size
        self.labels = labels
        self.specEEG = specEEG
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.c = 0  # running count of samples generated so far
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches; a trailing partial batch is dropped."""
        return len(self.specEEG) // self.batch_size

    def __getitem__(self, index):
        """Return batch ``index`` as a tuple ``(X, X_aug, y)``."""
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        return self.__data_generation(indexes)

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indexes = np.arange(len(self.specEEG))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def timeshift(self, spec, shift=3, direction=None, roll=True):
        """Return a copy of ``spec`` shifted by ``shift`` positions along axis 1.

        Args:
            spec: array with at least two dimensions; it is not modified.
            shift: number of positions to shift (non-negative).
            direction: ``'right'`` (towards higher indices), ``'left'``, or
                ``None`` (default) to pick one uniformly at random per call.
            roll: if True, values pushed off one end wrap around to the other
                end (a circular shift, like ``np.roll``); if False, the
                vacated positions keep their original values.

        Raises:
            ValueError: if ``direction`` is not ``'right'``, ``'left'`` or None.
        """
        if direction is None:
            direction = random.choice(['right', 'left'])
        if direction not in ('right', 'left'):
            raise ValueError(
                "direction must be 'right', 'left' or None, got {!r}".format(direction))
        spec = np.asarray(spec)
        if roll:
            step = shift if direction == 'right' else -shift
            return np.roll(spec, step, axis=1)
        shifted = spec.copy()
        # shift == 0 needs no work (and ``:-0`` would select an empty slice).
        if shift:
            if direction == 'right':
                shifted[:, shift:] = spec[:, :-shift]
            else:
                shifted[:, :-shift] = spec[:, shift:]
        return shifted

    def block_masking(self, spec, T=5, F=3, time_mask_num=1, freq_mask_num=1):
        """Return a copy of ``spec`` with rectangular blocks set to 0.

        For each of ``time_mask_num`` outer draws, a width ``t`` in ``[0, T)``
        and a start ``t0`` along axis 1 are drawn. For each of
        ``freq_mask_num`` inner draws, a width ``f`` in ``[0, F)`` and a start
        ``f0`` along axis 2 are drawn. The block ``[:, t0:t0+t, f0:f0+f]`` is
        then zeroed.

        NOTE(review): the T-width extent runs along axis 1 (the axis
        ``timeshift`` shifts) and the F-width extent along axis 2. Each start
        offset is drawn from the length of the axis it indexes. Confirm which
        axis is time and which is frequency for this data.
        """
        masked = np.array(spec, copy=True)
        len_1, len_2 = masked.shape[1], masked.shape[2]
        for _ in range(time_mask_num):
            t = int(np.random.uniform(low=0.0, high=T))
            t0 = random.randint(0, max(len_1 - t, 0))
            for _ in range(freq_mask_num):
                f = int(np.random.uniform(low=0.0, high=F))
                f0 = random.randint(0, max(len_2 - f, 0))
                masked[:, t0:t0 + t, f0:f0 + f] = 0
        return masked

    def spec_augment(self, spec, T=5, F=3, time_mask_num=1, freq_mask_num=1):
        """Return a copy of ``spec`` with full-length stripes set to 0.

        Zeroes ``time_mask_num`` stripes of width in ``[0, T)`` along axis 1
        (spanning all of axis 2). It then zeroes ``freq_mask_num`` stripes of
        width in ``[0, F)`` along axis 2 (spanning all of axis 1). Each start
        offset is drawn from the length of the axis being masked.
        """
        masked = np.array(spec, copy=True)
        len_1, len_2 = masked.shape[1], masked.shape[2]
        # Stripes along axis 1 (T-width masks).
        for _ in range(time_mask_num):
            t = int(np.random.uniform(low=0.0, high=T))
            t0 = random.randint(0, max(len_1 - t, 0))
            masked[:, t0:t0 + t] = 0
        # Stripes along axis 2 (F-width masks).
        for _ in range(freq_mask_num):
            f = int(np.random.uniform(low=0.0, high=F))
            f0 = random.randint(0, max(len_2 - f, 0))
            masked[:, :, f0:f0 + f] = 0
        return masked

    def __data_generation(self, list_IDs_temp):
        """Assemble one batch for the sample IDs in ``list_IDs_temp``.

        Returns:
            ``(X, X_aug, y)``:
            - ``X`` and ``X_aug`` are arrays of shape ``(batch_size, *target_size)``.
            - ``y`` is an int array of shape ``(batch_size,)``.
            If fewer than ``batch_size`` IDs are given, the remaining rows stay
            uninitialized (``__len__`` only yields full batches).
        """
        X = np.empty((self.batch_size, *self.dim))
        X_aug = np.empty((self.batch_size, *self.dim))
        y = np.empty((self.batch_size), dtype=int)
        # Dispatch table; its key order is the list random.choice draws from.
        augmenters = {
            'spec_augment': self.spec_augment,
            'time_shift': self.timeshift,
            'block_masking': self.block_masking,
        }
        augment_list = list(augmenters)

        for i, ID in enumerate(list_IDs_temp):
            augment = random.choice(augment_list)
            spec = self.specEEG[ID]
            X[i,] = spec
            # Augmenters return new arrays, so ``spec`` (and the dataset it
            # belongs to) is left untouched and X keeps the clean sample.
            X_aug[i,] = augmenters[augment](spec)
            y[i] = self.labels[ID]
            self.c += 1
        return X, X_aug, y
    
class Generator(keras.utils.Sequence):
    """Wrapper exposing a single ``CustomDataset`` as the model's input generator.

    Yields the same ``(X_batch, X_batch_aug, Y_batch)`` tuples as the wrapped
    dataset.

    Args:
        X: samples forwarded to ``CustomDataset`` as ``specEEG``.
        Y: labels forwarded to ``CustomDataset`` as ``labels``.
        batch_size: number of samples per batch.
        target_size: per-sample shape forwarded to ``CustomDataset``.
        shuffle: forwarded to ``CustomDataset`` (default False).
    """

    def __init__(self, X, Y, batch_size, target_size=(1, 55, 114), shuffle=False):
        super().__init__()
        self.genX = CustomDataset(X, Y, batch_size=batch_size, shuffle=shuffle, target_size=target_size)

    def __len__(self):
        """Return the number of batches of the wrapped dataset."""
        return len(self.genX)

    def __getitem__(self, index):
        """Return batch ``index`` of the wrapped dataset."""
        X_batch, X_batch_aug, Y_batch = self.genX[index]
        return X_batch, X_batch_aug, Y_batch

    def on_epoch_end(self):
        """Forward the end-of-epoch hook so the wrapped dataset can reshuffle."""
        # Keras only calls this hook on the outer Sequence; without forwarding,
        # the inner dataset's shuffle would never take effect.
        self.genX.on_epoch_end()

# Loader over the full dataset, 1024 samples per batch (unshuffled);
# a trailing partial batch is dropped by the underlying dataset.
Data_loader = Generator(X=spec, Y=labels, batch_size=1024)

In [None]:
def encoderModel(input_shape=(1, 55, 114)):
    """Build the convolutional encoder backbone.

    The backbone has two 3x3, ``same``-padded, ReLU ``Conv2D`` layers in
    ``channels_first`` layout (32 filters, then 1 filter), so the output has
    the same spatial size as the input and a single channel.

    Args:
        input_shape: per-sample input shape in ``channels_first`` order,
            i.e. ``(channels, height, width)``.

    Returns:
        A ``Sequential`` model, built eagerly for ``input_shape``.
    """
    model = Sequential()
    # Explicit input so the model is built immediately (summary/saving work
    # before the first call) instead of inferring the shape lazily.
    model.add(Input(shape=input_shape))
    model.add(Conv2D(32, (3, 3), padding='same', data_format="channels_first", activation='relu'))
    model.add(Conv2D(1, (3, 3), padding='same', data_format="channels_first", activation='relu'))
    return model

In [None]:
class UnitNormLayer(tf.keras.layers.Layer):
    '''Normalize vectors (euclidean norm) in batch to unit hypersphere.

    Each vector along axis 1 is scaled to unit L2 norm. All-zero vectors
    (e.g. from dead ReLUs upstream) map to zeros instead of NaN, because
    ``tf.math.l2_normalize`` clamps the squared norm with a small epsilon.

    Extra keyword arguments (``name``, ``trainable``, ``dtype``, ...) are
    forwarded to ``Layer``. This lets Keras rebuild the layer from its config
    when a saved model is loaded with
    ``custom_objects={'UnitNormLayer': UnitNormLayer}``.
    '''
    def __init__(self, **kwargs):
        super(UnitNormLayer, self).__init__(**kwargs)

    def call(self, input_tensor):
        return tf.math.l2_normalize(input_tensor, axis=1)

In [None]:
# Encoder Network
def encoder_net():
    """Build the model mapping a ``(1, 55, 114)`` input to normalized embeddings.

    The pipeline is ``Input`` -> ``encoderModel()`` (trainable, called with
    ``training=True``) -> ``GlobalAveragePooling2D()`` -> ``UnitNormLayer()``.

    NOTE(review): the convolutions are ``channels_first``, but the pooling
    layer uses its default data format (presumably ``channels_last``). So it
    averages over axes 1-2 rather than the spatial axes; confirm this is
    intended.
    """
    x_in = Input((1, 55, 114))
    unit_norm = UnitNormLayer()

    backbone = encoderModel()
    backbone.trainable = True

    features = backbone(x_in, training=True)
    pooled = GlobalAveragePooling2D()(features)
    return Model(x_in, unit_norm(pooled))

# Projector Network
def projector_net():
    """Build the projection head: Flatten -> Dense(64, relu) -> Dense(32, relu).

    The 32-d output is not unit-normalized (no UnitNormLayer at the end).
    """
    head = tf.keras.models.Sequential()
    head.add(Flatten())
    head.add(Dense(64, activation="relu"))
    head.add(Dense(32, activation="relu"))
    return head

In [None]:
with mirrored_strategy.scope():
    # Everything that owns variables is created under the strategy scope, as
    # tf.distribute requires. That includes the optimizer, whose slot
    # variables are created lazily on the first apply_gradients.
    optimizer = tf.keras.optimizers.Adam()
    encoder_r = encoder_net()
    projector_z = projector_net()


@tf.function
def train_step(images, images_aug, labels):
    """Run one optimization step on a (clean, augmented, label) batch.

    Both views go through the encoder and projector. The contrastive loss is
    computed from the two projections and applied to the trainable variables
    of both networks.

    The loss is divided by ``num_replicas_in_sync`` before differentiation.
    Under ``mirrored_strategy.run``, ``apply_gradients`` SUMs gradients across
    replicas, and the caller SUM-reduces the returned losses. With this
    scaling, both match the single-device values (unchanged with one
    replica).

    Returns:
        The scaled per-replica loss.
    """
    variables = encoder_r.trainable_variables + projector_z.trainable_variables
    with tf.GradientTape() as tape:
        r = encoder_r(images, training=True)
        z = projector_z(r, training=True)

        r_aug = encoder_r(images_aug, training=True)
        z_aug = projector_z(r_aug, training=True)

        loss = losses.max_margin_contrastive_loss_aug(z, z_aug, labels, metric='cosine')
        scaled_loss = loss / mirrored_strategy.num_replicas_in_sync

    gradients = tape.gradient(scaled_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return scaled_loss

In [None]:
@tf.function
def distributed_train_step(spec, spec_aug, label):
    """Run ``train_step`` on every replica and return the losses summed across them.

    NOTE(review): the training loop passes plain (non-distributed) batches
    here. ``mirrored_strategy.run`` then appears to hand the same full batch
    to every replica rather than sharding it; consider
    ``experimental_distribute_dataset`` if real data parallelism is wanted.
    """
    replica_losses = mirrored_strategy.run(
        train_step,
        args=(spec, spec_aug, label),
    )
    return mirrored_strategy.reduce(
        tf.distribute.ReduceOp.SUM,
        replica_losses,
        axis=None,
    )

In [None]:
# Pre-train the encoder/projector with L-CoCL, save both networks, and plot
# the per-epoch loss.
EPOCHS = 60
LOG_EVERY = 10  # print the loss every LOG_EVERY epochs, plus the final epoch
train_loss_results = []  # mean loss per epoch, as Python floats

train_log_dir = 'pre-trained_model_dir/L-CoCL/'

for epoch in tqdm(range(EPOCHS)):
    epoch_loss_avg = tf.keras.metrics.Mean()
    # Distinct batch names so the module-level ``spec`` dataset array is not
    # rebound to the last batch. Otherwise, re-running the Data_loader cell
    # afterwards would silently build the loader from one batch.
    for spec_batch, spec_aug_batch, label_batch in Data_loader:
        loss = distributed_train_step(spec_batch, spec_aug_batch, label_batch)
        epoch_loss_avg.update_state(loss)
    epoch_loss = float(epoch_loss_avg.result())
    train_loss_results.append(epoch_loss)
    if epoch % LOG_EVERY == 0 or epoch == EPOCHS - 1:
        print("Epoch: {} Loss: {:.3f}".format(epoch, epoch_loss))

# Save both networks as HDF5. Loading encoder.h5 back requires passing the
# custom UnitNormLayer via ``custom_objects``.
os.makedirs(train_log_dir, exist_ok=True)
a = 'encoder.h5'
b = 'projecter.h5'
encoder_r.save(os.path.join(train_log_dir, a))
projector_z.save(os.path.join(train_log_dir, b))

with plt.xkcd():
    plt.plot(train_loss_results)
    plt.title("L-CoCL Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.show()