## import packages

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import random
import network_fcn as vae
import SiameseNetwork as smsn
import argparse
import pickle
from keras.layers import concatenate, Flatten
import tensorflow as tf
from tensorflow import keras
import tensorflow.keras.backend as K
from matplotlib.backends.backend_pdf import PdfPages
import pickle
from scipy import spatial
import os
from os.path import dirname, abspath
from utils import *
from itertools import combinations
import warnings
import Classify as cls
from modeltransfer import *
import matplotlib.colors as mcolors

## configure GPU for tensorflow

In [None]:
# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Restrict CUDA to the device with index 1; this has to be in place before
# TensorFlow enumerates the GPUs below.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Ask TensorFlow to grow GPU memory on demand for every visible GPU.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

## initialize hyperparameters

In [None]:
def _str2bool(value):
    """Convert a command-line token ('true', 'no', '1', ...) into a bool.

    Without a converter argparse stores the raw string, so ``--batch False``
    would yield the truthy string ``'False'``.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


# Hyperparameters are collected in an argparse namespace. Every default is the
# same as before; the added ``type=`` converters only matter when a value is
# supplied on the command line (non-string defaults are not converted).
parser = argparse.ArgumentParser(description='')
# Tuple default; a value given on the command line would be stored as a raw string.
parser.add_argument('--input_size', dest='input_size', default=(None, 1))
parser.add_argument('--batch_shape', dest='batch_shape', type=int, default=8)
parser.add_argument('--latent_shape', dest='latent_shape', type=int, default=128)
parser.add_argument('--para_shape', dest='para_shape', type=int, default=15)
parser.add_argument('--batch_size', dest='batch_size', type=int, default=64)
parser.add_argument('--epochs', dest='epochs', type=int, default=50)
parser.add_argument('--step_decay', dest='step_decay', type=int, default=5)
parser.add_argument('--patience', dest='patience', type=int, default=10)
parser.add_argument('--lr', dest='lr', type=float, default=0.00005)
# Loss defaults are callables taken from the project module ``network_fcn``.
parser.add_argument('--dec_loss', dest='dec_loss', default=vae.loss().loss_vae)
parser.add_argument('--dis_loss', dest='dis_loss', default=vae.loss().loss_bce)
parser.add_argument('--cls_loss', dest='cls_loss', default=vae.loss().loss_bce)
parser.add_argument('--beta_1', dest='beta_1', type=float, default=0.5)
parser.add_argument('--norm', dest='norm', default='batch_norm')
parser.add_argument('--n_down', dest='n_down', type=int, default=7)
parser.add_argument('--n_std', dest='n_std', type=float, default=0.15)
parser.add_argument('--batch', dest='batch', type=_str2bool, default=False)
parser.add_argument('--group', dest='group', type=_str2bool, default=False)
parser.add_argument('--n_group', dest='n_group', type=int, default=4)
# parse_known_args tolerates extra arguments (e.g. those injected by a
# notebook kernel) instead of aborting on them.
args, unknown = parser.parse_known_args()

## import data and remove spectra with much lower correlation than the others

In [None]:
# Both CSV files live under the parent of the current working directory.
parent_dir = dirname(abspath(os.getcwd()))
metadata = pd.read_csv(parent_dir + '/datasets/bacterial_SSP/metadata.csv')
raw_table = pd.read_csv(parent_dir + '/datasets/bacterial_SSP/data.csv')

# Row 0 of the table is used as the wavenumber axis, the remaining rows as
# spectra; the first column is dropped in both cases.
wn = np.array(raw_table.iloc[0, 1:])
spec = np.array(raw_table.iloc[1:, 1:])
labels, batches = (np.array(metadata[col]) for col in ('labels', 'batches'))

# Column indices inside 1850-2750 and outside 1800-2800 on the wavenumber axis.
ix_wn = np.arange(len(wn))[(wn > 1850) & (wn < 2750)]
ix_wn1 = np.arange(len(wn))[(wn < 1800) | (wn > 2800)]

# Scale every spectrum by its own maximum (written back into the same array).
spec[...] = spec / np.max(spec, axis=1, keepdims=True)

# Drop the two classes that are excluded from this experiment.
excluded = (labels == 'L-innocua') | (labels == 'P-stutzeri')
ix = np.flatnonzero(~excluded)
spec, labels, batches = spec[ix, :], labels[ix], batches[ix]

# Remove spectra whose mean correlation with all spectra is not above the
# 1st percentile of that quantity.
corr_all = np.corrcoef(spec).mean(axis=0)
corr_threshold = np.percentile(corr_all, 1)
i_good = corr_all > corr_threshold
spec, labels, batches = spec[i_good, :], labels[i_good], batches[i_good]

# One-hot encode the labels, one column per unique label.
uni_labels = np.unique(labels)
dummy_y = np.zeros((len(labels), len(uni_labels)))
for col, lab in enumerate(uni_labels):
    dummy_y[labels == lab, col] = 1

n_spec_gen = 60

## perform spectra generation and model transfer
- perform model transfer on real data with different numbers of training batches
- perform model transfer with different approaches: EMSC, MS, and siamese network

In [None]:
def _append_result(log, acc, method, test_batches, train_id):
    """Append one evaluation outcome to the running result arrays in ``log``.

    ``log`` is a dict with the keys 'accs', 'b_tests', 'methods' and
    'b_trains'. The arrays are extended with ``np.append`` exactly as the
    inline code did, so the dtypes of the pickled results are unchanged.

    NOTE(review): the plotting code pairs ``accs`` with ``b_tests``
    element-wise, which presumably relies on ``acc`` holding one score per
    unique test batch, in ``np.unique`` order -- confirm against
    ``cls.cal_metric`` / ``pred_cls`` / ``pred_snet``.
    """
    n_scores = len(acc)
    log['accs'] = np.append(log['accs'], acc)
    log['b_tests'] = np.append(log['b_tests'], np.unique(test_batches))
    log['methods'] = np.append(log['methods'], np.resize(method, n_scores))
    log['b_trains'] = np.append(log['b_trains'], np.resize(train_id, n_scores))


def _stack_train_test(values, ix_train, ix_test):
    """Return a fresh array holding the training rows followed by the test rows."""
    return np.concatenate([values[ix_train], values[ix_test]], axis=0)


uni_batches = np.unique(batches)
index = np.array(range(len(uni_batches)))

# (args.batch, args.group, method tag) for the four siamese-network variants,
# in the order they are trained. Going by the tags, 'b' marks the variant
# trained with args.batch=True and 'g' the one with args.group=True.
snet_variants = [
    (False, False, 'snet'),
    (False, True, 'snet-g'),
    (True, False, 'snet-b'),
    (True, True, 'snet-gb'),
]
n_ref_per_class = 10   ### reference spectra drawn per class for siamese prediction

for n_b in [2, 5, 8]:    ### number of batches for training
    log = {'accs': [], 'b_tests': [], 'methods': [], 'b_trains': []}
    for ib_test in range(len(uni_batches)):
        # Batch ``ib_test`` is never used for training; every other batch is a
        # candidate training batch.
        b_sel = np.resize(True, len(uni_batches))
        b_sel[ib_test] = False

        # All combinations of n_b candidate batches, thinned out with a stride
        # of max(1, n_combinations // 9).
        bat = np.vstack(list(combinations(index[b_sel], n_b)))
        bat = bat[::np.max([1, bat.shape[0] // 9]), :]

        # Only the first remaining combination is evaluated.
        kk = 0
        k = bat[kk, :]   ### get the training batches
        print(k)
        batch_sel = np.array([uni_batches[i] for i in k])

        # Training rows are the spectra of the selected batches; every other
        # spectrum is a test spectrum. Both index arrays are sorted.
        train_mask = np.isin(batches, batch_sel)
        ix = np.flatnonzero(train_mask)
        ix_test_spec = np.flatnonzero(~train_mask)

        # Row positions inside the stacked [train, test] matrices built below.
        n_train = len(ix)
        train_rows = range(n_train)
        test_rows = range(n_train, n_train + len(ix_test_spec))
        test_pos = np.arange(n_train, n_train + len(ix_test_spec))

        #### normal classification without model transfer
        model = cls.Classify(_stack_train_test(spec, ix, ix_test_spec),
                             _stack_train_test(labels, ix, ix_test_spec),
                             _stack_train_test(batches, ix, ix_test_spec))
        pred = model.model(train_rows, test_rows)
        acc = cls.cal_metric(labels[ix_test_spec], pred, batches[ix_test_spec])
        _append_result(log, acc, 'org', batches[ix_test_spec], kk)

        ### do model transfer based on MS
        model = cls.Classify(_stack_train_test(spec, ix, ix_test_spec),
                             _stack_train_test(labels, ix, ix_test_spec),
                             _stack_train_test(batches, ix, ix_test_spec), True)
        pred = model.model(train_rows, test_rows)
        acc = cls.cal_metric(labels[ix_test_spec], pred, batches[ix_test_spec])
        _append_result(log, acc, 'MS', batches[ix_test_spec], kk)

        ### do model transfer based on EMSC
        interference, _ = cls.get_meanspec(
            _stack_train_test(spec, ix, ix_test_spec),
            labels=np.append(np.resize('org', len(ix)), np.resize('test', len(ix_test_spec))))
        ### EMSC with one component
        cur_spec_corr1 = emsc(_stack_train_test(spec, ix, ix_test_spec), degree=2,
                              interferent=interference, interf_pca=1)['corrected']
        model = cls.Classify(cur_spec_corr1,
                             _stack_train_test(labels, ix, ix_test_spec),
                             _stack_train_test(batches, ix, ix_test_spec))
        pred = model.model(train_rows, test_rows)
        acc = cls.cal_metric(labels[ix_test_spec], pred, batches[ix_test_spec])
        _append_result(log, acc, 'EMSC1', batches[ix_test_spec], kk)
        ### EMSC with two components
        cur_spec_corr2 = emsc(_stack_train_test(spec, ix, ix_test_spec), degree=2,
                              interferent=interference, interf_pca=2)['corrected']
        model = cls.Classify(cur_spec_corr2,
                             _stack_train_test(labels, ix, ix_test_spec),
                             _stack_train_test(batches, ix, ix_test_spec))
        pred = model.model(train_rows, test_rows)
        acc = cls.cal_metric(labels[ix_test_spec], pred, batches[ix_test_spec])
        _append_result(log, acc, 'EMSC2', batches[ix_test_spec], kk)

        ### randomly choose from training data spectra to perform prediction with siamese network
        smp_train = []
        for lab in np.unique(labels[ix]):
            candidates = ix[labels[ix] == lab]
            # Cap the draw so a class with fewer spectra does not raise.
            n_draw = min(n_ref_per_class, len(candidates))
            smp_train = np.append(smp_train, np.random.choice(candidates, n_draw, replace=False))
        smp_train = np.array(smp_train, dtype='int32')

        # args.group is reset here as well: the last siamese variant of the
        # previous iteration leaves it True, which would otherwise make the
        # plain network see different settings from the second iteration on.
        args.batch = False; args.group = False; args.lr = 0.00001
        args.input_size = (spec.shape[1], 1)
        args.epochs = 200
        m_nn = smsn.SiameseNetwork(args)
        m_nn.train_cls(spec[ix, :], dummy_y[ix, :])   ### train ordinary neural network
        pred_test, acc, min_acc, std_acc, method, b_test = m_nn.pred_cls(
            spec[ix_test_spec, :], labels[ix_test_spec], batches[ix_test_spec], uni_labels)
        _append_result(log, acc, 'nn', batches[ix_test_spec], kk)

        n_pairs = 40000
        x, y, cb, g = prep_training_cls(spec[ix, :], labels[ix], batches[ix], n_pairs, cb_shape=args.batch_shape)

        for use_batch, use_group, tag in snet_variants:
            args.batch = use_batch; args.group = use_group; args.lr = 0.00001
            args.epochs = 100
            m_smsn = smsn.SiameseNetwork(args)
            m_smsn.train_snet(x, y, g, cb, f_model='_weights.h5')  ### train siamese neural network

            ### predict spectra without EMSC
            pred_test, acc, min_acc, std_acc, method, b_test = m_smsn.pred_snet(
                spec[smp_train, :], labels[smp_train],
                spec[ix_test_spec, :], labels[ix_test_spec], batches[ix_test_spec])
            _append_result(log, acc, tag, batches[ix_test_spec], kk)

            ### predict spectra with EMSC
            # cur_spec_corr1 is laid out as [train rows, test rows], so the
            # test spectra are selected by position in that stacked matrix,
            # not by their indices in ``spec``.
            pred_test, acc, min_acc, std_acc, method, b_test = m_smsn.pred_snet(
                spec[smp_train, :], labels[smp_train],
                cur_spec_corr1[test_pos, :], labels[ix_test_spec], batches[ix_test_spec])
            _append_result(log, acc, tag + '-EMSC1', batches[ix_test_spec], kk)

    accs, methods, b_trains, b_tests = log['accs'], log['methods'], log['b_trains'], log['b_tests']

    # Mean score per (method, test batch), shown as one box per method.
    db_methods = ['org', 'nn', 'snet', 'snet-b', 'snet-g', 'snet-gb', 'MS', 'EMSC1', 'EMSC2', 'snet-EMSC1', 'snet-g-EMSC1', 'snet-b-EMSC1', 'snet-gb-EMSC1']
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    acc_all = np.array([
        [np.mean(accs[(methods == name) & (b_tests == batch)]) for batch in uni_batches]
        for name in db_methods
    ])
    sns.boxplot(data=acc_all.T, width=0.5, ax=ax)
    ax.set_xticks(range(len(db_methods)), db_methods, rotation=-45)
    ax.set_ylabel('balanced accuracy')
    ax.set_title('Cross-Batch on Real Data')
    plt.show()

    # Persist the raw per-run results for this number of training batches.
    with open(dirname(abspath(os.getcwd())) + '/results/accs_mt1_org_' + str(n_b) + '_Batch.pkl', 'wb') as f:
        pickle.dump([accs, methods, b_trains, b_tests], f)