## Implementation of deep CORAL

This notebook is based on the paper <i>Deep CORAL: Correlation Alignment for Deep Domain Adaptation</i>. Here I am not fine-tuning a pretrained model end-to-end; instead, a small classifier is trained on top of features produced by a pretrained feature extractor. The task is sentiment analysis: Amazon reviews are the labeled source domain, and Yelp reviews are the target domain, which is treated as unlabeled during training (the Yelp labels are used only to evaluate the model).

In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_probability as tfp
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Input,Dense,Dropout,Activation,Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.initializers import Constant
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score,balanced_accuracy_score
from sklearn.utils import shuffle

import warnings
warnings.filterwarnings("ignore")

In [2]:
tf.__version__ # display the installed TF version; per the original note it is required for tfp (presumably tfp needs a matching TF release -- confirm)

'2.3.0'

In [3]:
## Load the Yelp review arrays (target domain) and keep only the first 10000 rows
yelp_x = np.load("../data/yelp_reviews/x.npy")[:10000]
yelp_y = np.load("../data/yelp_reviews/y.npy")[:10000]
print(*(arr.shape for arr in (yelp_x,yelp_y)))

(10000, 640) (10000,)


In [4]:
## Load the Amazon review arrays (source domain)
## rows 0-9999 are used for training, rows 10000-10999 are held out for validation
amazon_x = np.load("../data/amazon_reviews/x.npy")
amazon_y = np.load("../data/amazon_reviews/y.npy")
amazon_x_val,amazon_y_val = amazon_x[10000:11000],amazon_y[10000:11000]
amazon_x,amazon_y = amazon_x[:10000],amazon_y[:10000]
print(*(arr.shape for arr in (amazon_x,amazon_y,amazon_x_val,amazon_y_val)))

(10000, 640) (10000,) (1000, 640) (1000,)


### Model implementation

$$ L_{CORAL} = \frac{1}{4d^2}||C_S - C_T||^2_F $$

where $d$ is the dimensionality of the hidden layer whose activations are being aligned, and $C_S$ and $C_T$ are the feature covariance matrices, computed over the current batch, for the source and target domains respectively.

In [5]:
def get_model(input_dim=640):
    """ model implementation
        Builds Input(input_dim) -> Dense(512,relu) -> Dense(256,relu) -> Dense(1,sigmoid)
        -also returns the last hidden state (to use w/ CORAL loss component)
        -using the h2 vector prior to the relu activation makes little difference in practice
    args:
        input_dim: length of each input feature vector
    returns:
        a Keras Model with two outputs: [sigmoid prediction, activations of the 256-unit hidden layer]
    """
    # `(input_dim)` is just an int in parentheses; Keras documents `shape` as a tuple,
    # so pass a real 1-tuple instead of relying on the int being coerced
    x = Input(shape=(input_dim,))
    h1 = Dense(512,activation='relu')(x)
    h2 = Dense(256,activation='relu')(h1)
    out = Dense(1,activation='sigmoid')(h2)

    # h2 is exposed as a second output so the training step can align its statistics across domains
    model = Model(inputs=x,outputs=[out,h2])
    return model

In [6]:
def get_coral_loss(mat1,mat2,d=256):
    """ CORAL loss component: squared Frobenius distance between the covariance
        matrices of two feature batches, scaled by 1/(4*d^2)
    args:
        mat1: feature batch for the first domain -- assumes (batch, d), TODO confirm
        mat2: feature batch for the second domain -- assumes (batch, d), TODO confirm
        d: dimensionality of the input model hidden layer
    returns:
        scalar tensor holding the scaled squared distance (no square root is taken)
    """
    # NOTE(review): tfp.stats.covariance is called with its defaults; check the tfp docs for
    # which axis is treated as samples and whether it normalises by N or N-1
    cov_gap = tfp.stats.covariance(mat1)-tfp.stats.covariance(mat2)
    cov_gap = tf.reshape(cov_gap,[-1])
    scale = 1/(4*d**2)
    return scale*tf.reduce_sum(tf.square(cov_gap))

In [7]:
def model_loss(amazon_y_subset,amazon_pred,amazon_h,yelp_h,coral_lam):
    """ total training objective: source classification loss + weighted CORAL penalty
    args:
        amazon_y_subset: labels for the Amazon (source) batch
        amazon_pred: model predictions for the Amazon batch
        amazon_h: hidden-layer activations for the Amazon batch
        yelp_h: hidden-layer activations for the Yelp (target) batch
        coral_lam: amount to scale coral loss component by
    returns:
        binary cross-entropy on the Amazon batch plus coral_lam times the CORAL loss
    """
    bce = BinaryCrossentropy() # averages over the batch automatically
    supervised_term = bce(amazon_y_subset,amazon_pred)
    alignment_term = coral_lam*get_coral_loss(amazon_h,yelp_h)
    return supervised_term+alignment_term

In [8]:
@tf.function
def train_model(model,optimizer,amazon_x_subset,amazon_y_subset,yelp_x_subset,coral_lam=1.0):
    """ used to train the model: runs one gradient step on a paired source/target batch
    args:
        model: network returning (prediction, hidden activations), e.g. from get_model()
        optimizer: Keras optimizer used to apply the gradients
        amazon_x_subset: Amazon (source) feature batch
        amazon_y_subset: labels for the Amazon batch
        yelp_x_subset: Yelp (target) feature batch; no Yelp labels are used
        coral_lam: amount to scale coral loss component by
    returns:
        the total loss (classification + scaled CORAL) for this batch
    """
    with tf.GradientTape() as tape:
        # training=True keeps train-only layers (e.g. Dropout, BatchNorm) in training mode if they
        # are ever added to the model; it changes nothing for a Dense-only network
        amazon_pred,amazon_h = model(amazon_x_subset,training=True)
        # the Yelp prediction is discarded: only its hidden activations feed the CORAL term
        _,yelp_h = model(yelp_x_subset,training=True)
        loss = model_loss(amazon_y_subset,amazon_pred,amazon_h,yelp_h,coral_lam)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

In [9]:
# The same network is applied to both domains, so their feature widths must agree
assert amazon_x.shape[1] == yelp_x.shape[1], "source and target features must have the same width"
model = get_model(input_dim=amazon_x.shape[1]) # size the input layer from the data rather than the 640 default

In [10]:
batch_size=50
optimizer = Adam(learning_rate=0.01) # `lr` is a deprecated alias for `learning_rate`
epochs=3

def _predict_labels(model,x):
    """ hard 0/1 predictions for x as a 1-D array (sigmoid output thresholded at 0.5) """
    probs,_ = model(x)
    probs = probs.numpy()
    # ravel so sklearn receives 1-D predictions instead of an (N,1) column
    return (probs >= 0.5).astype(probs.dtype).ravel()

def _score(y_true,y_pred):
    """ (accuracy, balanced accuracy), each rounded to 4 decimals """
    return round(accuracy_score(y_true,y_pred),4),round(balanced_accuracy_score(y_true,y_pred),4)

# Amazon and Yelp batches are paired by row index, so only iterate over rows that exist in
# both arrays; otherwise a shorter Yelp array would yield empty target batches
n_paired = min(len(amazon_x),len(yelp_x))

for epoch_i in range(epochs):
    losses = []
    # NOTE(review): batches are visited in the same order every epoch (no shuffling)
    for i in range(0,n_paired,batch_size):
        x_train_subset = amazon_x[i:i+batch_size]
        y_train_subset = amazon_y[i:i+batch_size]
        yelp_x_subset = yelp_x[i:i+batch_size]
        batch_loss = train_model(model,optimizer,x_train_subset,y_train_subset,yelp_x_subset)
        losses.append(float(batch_loss))

    print("epoch {}; loss:{}".format(epoch_i,round(sum(losses)/len(losses),4)))

    # evaluate after every epoch on: the Amazon training rows, the held-out Amazon rows, and Yelp
    y_train_pred = _predict_labels(model,amazon_x)
    y_val_pred = _predict_labels(model,amazon_x_val)
    yelp_pred = _predict_labels(model,yelp_x)

    train_acc,train_bal_acc = _score(amazon_y,y_train_pred)
    val_acc,val_bal_acc = _score(amazon_y_val,y_val_pred)
    yelp_acc,yelp_bal_acc = _score(yelp_y,yelp_pred) # Yelp labels are used for evaluation only

    print("-Train; accuracy:{}; bal_accuracy:{}".format(train_acc,train_bal_acc))
    print("-Test; accuracy:{}; bal_accuracy:{}".format(val_acc,val_bal_acc)) # "Test" here is the held-out Amazon split
    print("-YELP; accuracy:{}; bal_accuracy:{}".format(yelp_acc,yelp_bal_acc))

epoch 0; loss:0.3801
-Train; accuracy:0.8874; bal_accuracy:0.8878
-Test; accuracy:0.87; bal_accuracy:0.8693
-YELP; accuracy:0.8738; bal_accuracy:0.8758
epoch 1; loss:0.3005
-Train; accuracy:0.8926; bal_accuracy:0.8927
-Test; accuracy:0.878; bal_accuracy:0.8776
-YELP; accuracy:0.8661; bal_accuracy:0.8692
epoch 2; loss:0.2718
-Train; accuracy:0.9012; bal_accuracy:0.9013
-Test; accuracy:0.878; bal_accuracy:0.8777
-YELP; accuracy:0.8592; bal_accuracy:0.8625
