In [1]:
import os,sys
import platform
import numpy as np
import tensorflow as tf
from tensorflow.keras import optimizers
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.callbacks import EarlyStopping
from segnet import segnet
from generator import DataGenerator, DataLoader

In [2]:
print(tf.__version__) # 2.1.0

2.1.0


In [3]:
#import keras
print(tf.keras.__version__) # 2.2.4-tf

2.2.4-tf


In [4]:
# keep these constant for now
nv=int(2**98) # variants
na=2          # alleles
nc=7          # ancestry classes (incl. OCE)

In [5]:
# hyperparameters
ne=100        # number of epochs
nf=8          # number of segnet filters
fs=16         # filter size
bs=8          # batch size
ps=4          # pool size
dp=4          # number of segnet blocks (depth)
wt=False      # use inverse frequency weighted loss -- this doesn't seem to help
gen=True      # use data generator
rem=[]        # remove these ancestries: [0:AFR, 1:EAS, 2:EUR, 3:NAT, 4:OCE, 5:SAS, 6:WAS]
dev=True      # do we have a dev set (this is jank if we're removing any ancestries, above)
ac=1          # minimum allele count for inclusion
stoch=True    # randomly sample individuals during training (requires gen=True)
               # sampling is currently inverse-frequency weighted

# reproducibility
np.random.seed(23910464)

In [6]:
# sanity check that we're on gpu -- use #1 (if we're on galangal)
if platform.uname()[1]=='galangal.stanford.edu':
    os.environ["CUDA_VISIBLE_DEVICES"]="1" 
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [7]:
from train import load_train_set, load_dev_set, filter_ac

In [8]:
X, Y, S, V, train_ix, ix1, ix2 = load_train_set()
X_dev, Y_dev, S_dev = load_dev_set()

[(2764, 516801, 2), (2764, 516801, 7), (2764,), 2090]
[(120, 516801, 2), (120, 516801, 7)]


In [9]:
if False: # print sizes of allocated objects -- 
    # can be a nice sanity check, but also annoying to look at
    for x in dir():
        print(x, sys.getsizeof(eval(x)))

In [10]:
# filter oceanians, etc.
if rem:
    train_ix=[i for i in train_ix if Y[i,0,:].dot(np.arange(nc)) not in rem]
    if dev:
        Y_dev=Y_dev[:,:,[i for i in range(nc) if i not in rem]] # this is hella janky
anc=[i for i in range(nc) if i not in rem]

In [11]:
# filter variants by allele count
v=filter_ac(X[train_ix,:,:], ac=ac)
print(np.sum(v))
nv=np.sum(v)-(np.sum(v) % (ps**dp))
v=np.array([i and j <= nv for i,j in zip(v, np.cumsum(v))])
print(np.sum(v))
#np.random.shuffle(train_ix)
#np.save('/scratch/users/magu/aY.npy', Y[np.ix_(train_ix, v, anc)])
#np.save('/scratch/users/magu/aX.npy', X[np.ix_(train_ix, v, np.arange(na))])
#X=np.load('/scratch/users/magu/aX.npy', mmap_mode='r')
#Y=np.load('/scratch/users/magu/aY.npy', mmap_mode='r')
Y=Y[np.ix_(train_ix, v, anc)]
X=X[np.ix_(train_ix, v, np.arange(na))]
X_dev=X_dev[:,v,:]
Y_dev=Y_dev[:,v,:]
print(X.shape, X_dev.shape, Y.shape, Y_dev.shape)

447827
447744
(2090, 447744, 2) (120, 447744, 2) (2090, 447744, 7) (120, 447744, 7)


In [12]:
# declare model
model=segnet(input_shape=(X.shape[1], na), n_classes=nc-len(rem), n_filters=nf, width=fs, n_blocks=dp, pool_size=ps)

# and optimizer
adam=optimizers.Adam(lr=3e-4)#, beta_1=0.8)

In [None]:
# now compile and show parameter summary
model.compile(optimizer=adam, loss='categorical_crossentropy', metrics=['accuracy']) 
print(model.summary())

Model: "segnet"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 447744, 2)]  0                                            
__________________________________________________________________________________________________
dropout (Dropout)               (None, 447744, 2)    0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1_down1 (Conv1D)            (None, 447744, 8)    264         dropout[0][0]                    
__________________________________________________________________________________________________
activation (Activation)         (None, 447744, 8)    0           conv1_down1[0][0]                
_____________________________________________________________________________________________

In [None]:
# now try it out!

# number weighted epochs -- will be [ne] if wt=False else [0, ne] (ne is set above)
nwe=ne 
nes=[ne-(wt*nwe)]         
cws=[np.ones((nc,))]
if wt:
    nes.append(nwe)
    cws.append(np.sqrt(Y.sum()/Y.sum(axis=0).sum(axis=0)))
    
# training
np.random.shuffle(train_ix)
for ne,cw in zip(nes,cws):
    if ne == 0: # don't waste time compiling the model if we aren't training it -- see above
        continue
    if dev:
        # we have a dev set; use it to monitor convergence
        es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=10)

        # fit with generator, or not
        if gen:
            params={'X':X, 'Y':Y, 'dim':nv, 'batch_size':bs, 'n_classes':nc, 'n_alleles':na}
            anc_fq=Y[:,0,:].sum(axis=0)
            generator=DataGenerator(np.arange(X.shape[0]), 
                                    sample=True, anc_wts=((1/anc_fq)/((1/anc_fq).sum())).flatten(),
                                    **params)
            history=model.fit_generator(generator=generator, epochs=ne, validation_data=(X_dev, Y_dev), 
                                        callbacks=[es], class_weight=cw)
        else:
            history=model.fit(X, Y, batch_size=bs, epochs=ne, 
                              validation_data=(X_dev, Y_dev), 
                              callbacks=[es], class_weight=cw)
    else:
        history=model.fit(X, Y, batch_size=bs, epochs=ne, class_weight=cw)

Instructions for updating:
Please use Model.fit, which supports generators.
  ...
    to  
  ['...']
Train for 261 steps, validate on 120 samples
Epoch 1/100


In [None]:
# save
model.save('chm20_short.dev.h5')

In [None]:
_, dev_acc = model.evaluate(X_dev, Y_dev, verbose=0)

# 1.1) plot loss during training
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(1, (9,9))
plt.subplot(211)
plt.title('Loss during training')
plt.plot(history.history['loss'], label='train set')
plt.plot(history.history['val_loss'], label='dev set')
plt.legend()

# 1.2) plot accuracy during training
plt.subplot(212)
plt.title('Accuracy')
plt.plot(history.history['accuracy'], label='train set')
plt.plot(history.history['val_accuracy'], label='dev set')
plt.legend()

print(dev_acc)

In [None]:
Y_hat_p=model.predict(X_dev)
Y_hat=np.argmax(Y_hat_p, axis=-1) # this is p naive tbh

In [None]:
Y_hat.shape

In [None]:
Y.sum(axis=0).sum(axis=0)/Y.sum().sum().sum()

In [None]:
Y_dev.sum(axis=0).sum(axis=0)/Y_dev.sum().sum().sum()

In [None]:
from sklearn.metrics import confusion_matrix
import pandas as pd

In [None]:
anc=[label for i,label in enumerate(['AFR','EAS','EUR','NAT','OCE','SAS','WAS']) if i in anc]

pd.DataFrame(confusion_matrix(Y_dev.dot(np.arange(nc-int(not oce))).flatten(), Y_hat.flatten())/(Y_hat.shape[0]*nv), 
             columns=anc, index=anc)

In [None]:
for i in range(Y_hat.shape[0]):
    print((i, [np.count_nonzero(Y_hat[i,:]==j) for j in range(Y_hat_p.shape[-1])], 
           [np.count_nonzero(Y_dev[i,:,:].argmax(axis=-1)==j) for j in range(Y_hat_p.shape[-1])]))

In [None]:
# this is completely arbitrary, i have no idea what it will show --Matt
iix=27

In [None]:
plt.figure(figsize=(12, 3))
Y_dev_lab = np.argmax(Y_dev, axis=-1)
plt.subplot(211)
plt.title('Dev set ground truths')
plt.imshow(Y_dev_lab[iix:iix+1,:].astype(int), aspect='auto', cmap='jet')

plt.subplot(212)
plt.title('Corresponding dev set predictions')
plt.imshow(Y_hat[iix:iix+1,:].astype(int), aspect='auto', cmap='jet')

In [None]:
np.where(np.argmax(Y_hat_p[iix,:,:], axis=-1)==3)

In [None]:
plt.figure(figsize=(12, 3))
Y_dev_lab = np.argmax(Y_dev, axis=-1)
plt.subplot(211)
plt.title('Dev set ground truths')
plt.imshow(Y_dev_lab[iix:iix+1,29000:29500].astype(int), aspect='auto')#, cmap='jet')

plt.subplot(212)
plt.title('Corresponding dev set predictions')
plt.imshow(Y_hat[iix:iix+1,29000:29500].astype(int), aspect='auto')#, cmap='jet')

In [None]:
Y_hat_p[iix,np.where(np.argmax(Y_hat_p[iix,:,:], axis=-1)==3),:]

In [None]:
# seems like a crf-smoother (even a post-hoc one) could really help
## todo: check out tf2crf