In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns; sns.set()
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset

import mypackages.purple_teletubbies.purple_teletubbies as purple_teletubbies # =)

In [None]:
# Load the featurised BindingDB table.
# NOTE(review): pickle.load can execute arbitrary code -- only load files you produced or trust.
with open('data/dtba_prediction/featured_bindingdb', 'rb') as fh:
    df_full = pickle.load(fh)

feature = 'Kd'          # affinity column used as the regression label
drop_uncertain = True   # keep only rows whose `<feature>_r` column equals 0
classification = False  # True -> train on the binary `active` label instead of the affinity

df_full = df_full.dropna(subset=[feature]).drop_duplicates()

fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
fig.set_figwidth(23)
fig.suptitle('Distribution of Label', fontsize=18)

ax1.hist(df_full[feature], bins=50)
ax1.set_ylabel('count')
ax1.set_xlabel(f'{feature} (nM)')

# Zeros would become -log10(0) = inf, so clamp them to a tiny positive value first.
# NOTE(review): this is -log10 of the raw value; if the column is in nM this equals the
# molar p-value minus 9, and clamped zeros land at 9 on this scale -- confirm intended.
df_full.loc[df_full[feature] == 0, feature] = 1e-9
df_full[feature] = -np.log10(df_full[feature])
ax2.hist(df_full[feature], bins=50)
ax2.set_xlabel(f'p{feature}')

if drop_uncertain:
    # Explicit copy so the column assignment below never acts on a slice of the
    # unfiltered frame (pandas may otherwise emit SettingWithCopyWarning).
    df_full = df_full[df_full[feature + '_r'] == 0].copy()

ax3.hist(df_full[feature], bins=50)
ax3.set_xlabel(f'p{feature} (after dropping uncertain values)' if drop_uncertain else f'p{feature}')
# Lay out the 3-panel figure only after all panels are drawn.
fig.tight_layout()

# NOTE(review): `> -3` on this -log10 scale means a raw value below 1000 (1 uM if in nM).
# The `_r` column is added to the score; its semantics are not visible here -- confirm.
df_full['active'] = ((df_full[feature] + df_full[feature + '_r']) > -3).astype(np.int8)

In [None]:
# Descriptor column-name groups (not referenced below; kept for reference).
dc_col = ['dc_' + str(i) for i in range(1, 112)]
sv_col = ['sv_' + str(i) for i in range(1, 101)]
# NOTE(review): this used to reuse the 'sv_' prefix (a copy of sv_col); 'pv_' is presumably intended.
pv_col = ['pv_' + str(i) for i in range(1, 101)]

# Identifier/text columns, every affinity column and its `_r` companion, `active`, pH/temp,
# and five descriptor columns are excluded; every other column is used as a feature.
NON_FEATURE_COLS = ['id', 'ligand_smiles', 'target_name', 'Ki', 'IC50', 'Kd', 'EC50', 'active',
                    'Ki_r', 'Kd_r', 'IC50_r', 'pH', 'temp', 'EC50_r',
                    'dc_10', 'dc_11', 'dc_12', 'dc_13', 'dc_29']

# Shuffle into a separate frame: reassigning df_full here would make every re-run of this
# cell reshuffle an already-shuffled frame and silently change the train/test split.
df_shuffled = df_full.sample(frac=1, replace=False, random_state=666).reset_index(drop=True)
X = df_shuffled.drop(NON_FEATURE_COLS, axis=1).values
if classification:
    y = df_shuffled['active'].values.reshape(-1, 1)
else:
    y = df_shuffled[feature].astype(np.float64).values.reshape(-1, 1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=69)

# Keep the tensors on the CPU: train_model and the evaluation cell pass them to
# scikit-learn (train_test_split / StandardScaler), which cannot read CUDA tensors.
X_train = torch.from_numpy(X_train)
X_test = torch.from_numpy(X_test)
y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

input_size = X.shape[1]

In [None]:
def _to_numpy(data):
    """Return `data` as a NumPy array, copying a torch tensor to host memory first.

    scikit-learn's `train_test_split` / `StandardScaler` cannot read CUDA tensors,
    so every input is normalised to a host-side array before reaching them.
    """
    if torch.is_tensor(data):
        return data.detach().cpu().numpy()
    return np.asarray(data)


def _plot_losses(train_losses, val_losses, loss_name):
    """Draw per-epoch training and validation loss curves on a new figure and return its axes."""
    fig, ax = plt.subplots()
    ax.plot(range(len(train_losses)), train_losses, label='train_loss')
    ax.plot(range(len(val_losses)), val_losses, label='val_loss')
    ax.set_title('Training and Validation Loss', fontsize=18)
    ax.set_xlabel('epoch')
    ax.set_ylabel(loss_name)
    ax.legend()
    return ax


def train_model(model, optimizer, X_traini, y_traini, BATCH_SIZE, EPOCH, early_stop=True,
                classification=None, patience=20, restore_best=True, shuffle=True):
    """Train `model` on a train/validation split of the given data and plot the loss curves.

    The inputs are split 80/20 (`random_state=123`) into train and validation parts. A
    StandardScaler is fitted on the train part only and applied to both. The model is
    trained in minibatches and scored on the validation part after every epoch.

    Args:
        model: torch module, called as `model(x, True)` while training and `model(x)` while
            evaluating (the meaning of the second argument is defined by the model).
        optimizer: optimizer over `model`'s parameters.
        X_traini: feature matrix, as a NumPy array or torch tensor (any device).
        y_traini: targets with one row per sample, as a NumPy array or torch tensor.
        BATCH_SIZE: minibatch size.
        EPOCH: maximum number of epochs.
        early_stop: stop once `patience` consecutive epochs bring no new best validation loss.
        classification: True -> BCEWithLogitsLoss, False -> MSELoss. None (default) reads
            the notebook-level `classification` flag, falling back to False.
        patience: epochs without improvement tolerated before stopping (default 20).
        restore_best: reload the weights of the epoch with the lowest validation loss
            before returning. Without this, early stopping leaves the weights from
            `patience` epochs past the best.
        shuffle: reshuffle the training minibatches every epoch.

    Returns:
        The fitted StandardScaler; apply it to any data fed to the model afterwards.

    Side effects:
        Moves `model` to the GPU when one is available, updates its weights in place,
        prints per-epoch losses, and draws a loss plot.
    """
    if classification is None:
        # Fall back to the notebook-level flag so existing call sites keep working.
        classification = globals().get('classification', False)
    if classification:
        criterion, loss_name = torch.nn.BCEWithLogitsLoss(), 'BCEWithLogitsLoss'
    else:
        criterion, loss_name = torch.nn.MSELoss(), 'MSELoss'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        print('using GPU')
    model = model.to(device)
    criterion = criterion.to(device)

    # Split and scale in host memory; only the resulting tensors are sent to `device`.
    X_tr, X_val, y_tr, y_val = train_test_split(
        _to_numpy(X_traini), _to_numpy(y_traini), test_size=0.2, random_state=123)
    standardscaler = StandardScaler()
    X_tr = torch.as_tensor(standardscaler.fit_transform(X_tr), dtype=torch.float32)
    X_val = torch.as_tensor(standardscaler.transform(X_val), dtype=torch.float32).to(device)
    y_tr = torch.as_tensor(y_tr, dtype=torch.float64)
    y_val = torch.as_tensor(y_val, dtype=torch.float64).to(device)

    train_loader = DataLoader(TensorDataset(X_tr, y_tr), BATCH_SIZE, shuffle=shuffle)

    all_train_loss = []
    all_val_loss = []
    best = float('inf')
    best_state = None
    early_stopping_count = 0
    for epoch in range(EPOCH):
        model.train()
        running_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            optimizer.zero_grad()
            y_pred = model(X_batch, True)
            loss = criterion(y_pred.double(), y_batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss = running_loss / len(train_loader)
        all_train_loss.append(train_loss)

        model.eval()
        with torch.no_grad():
            val_loss = criterion(model(X_val).double(), y_val).item()
        # Record before any early-stop break so both curves cover the same epochs.
        all_val_loss.append(val_loss)
        print(f'epoch {epoch}: \ntrain_loss: {train_loss}\nval_loss: {val_loss}\n============================================================')

        if val_loss < best:
            best = val_loss
            early_stopping_count = 0
            if restore_best:
                # Detached clones so later optimizer steps do not overwrite the snapshot.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        elif early_stop:
            early_stopping_count += 1
            if early_stopping_count >= patience:
                break

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)

    _plot_losses(all_train_loss, all_val_loss, loss_name)
    return standardscaler

# Model training

In [None]:
# Hyperparameters for this training run.
SEED = 7423
LEARNING_RATE = 0.00007
BATCH_SIZE = 512
MAX_EPOCHS = 400

torch.manual_seed(SEED)
# NOTE(review): `purple_teletubbies` is bound by `import ... as purple_teletubbies`, which
# normally yields a module; calling it only works if that name resolves to something
# callable -- confirm the intended model class.
model = purple_teletubbies()
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
standardscaler = train_model(model, optimizer, X_train, y_train,
                             BATCH_SIZE=BATCH_SIZE, EPOCH=MAX_EPOCHS)

In [None]:
# Evaluate on the held-out test split with the scaler returned by train_model.
# The scaled copy gets its own name so re-running this cell never scales X_test twice.
eval_device = next(model.parameters()).device
X_test_scaled = torch.as_tensor(
    standardscaler.transform(torch.as_tensor(X_test).cpu().numpy()), dtype=torch.float32
).to(eval_device)
y_test_true = torch.as_tensor(y_test).float().to(eval_device)

model.eval()
with torch.no_grad():
    y_test_pred = model(X_test_scaled)
    # Score with the loss that matches the training criterion.
    if classification:
        test_loss = F.binary_cross_entropy_with_logits(y_test_pred, y_test_true)
    else:
        test_loss = F.mse_loss(y_test_pred, y_test_true)
test_loss.item()

In [None]:
# Persist the fitted scaler next to the model weights; `with` guarantees the file is
# flushed and closed. NOTE: only unpickle this file from a trusted location.
os.makedirs('model', exist_ok=True)
with open('model/standardscaler', 'wb') as fh:
    pickle.dump(standardscaler, fh)

In [None]:
# Save only the weights (state_dict); rebuild the model class before loading them.
os.makedirs('model', exist_ok=True)
torch.save(model.state_dict(), 'model/purple_teletubbies.model')