# In [None]:
# Important!: This codebase was initially written for internal use and requires manual path adjustments and environment setup.


import os
import math
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.nn.parameter import Parameter
from torch.autograd import Variable
from sklearn import metrics
from sklearn.model_selection import KFold, train_test_split
from scipy.stats import pearsonr
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

#Base code retrieved from J Cheminform 13, 7 (2021). https://doi.org/10.1186/s13321-021-00488-1


# Seed
# Seed NumPy and PyTorch once at import time so weight initialisation and
# DataLoader shuffling draw from a fixed random stream.
SEED = 2333
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    # NOTE(review): the CUDA device index is hard-coded to 1; this presumably
    # requires at least two visible GPUs -- confirm on the target machine.
    torch.cuda.set_device(1)
    torch.cuda.manual_seed(SEED)
    
    
# Model parameters
# The train-with-validation loop runs NUMBER_EPOCHS epochs; the train-on-all
# loop iterates over range(NUMBER_EPOCHS + 1).
NUMBER_EPOCHS = 9
# LEARNING_RATE and WEIGHT_DECAY are passed to the Adam optimizer built in Model.
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1E-4
# Batch size handed to every DataLoader. The training/evaluation code squeezes
# the batch dimension away, so it relies on one sample per batch.
BATCH_SIZE = 1
# NOTE(review): NUM_CLASSES is not referenced anywhere else in the visible
# code; Model emits a single logit for the binary task.
NUM_CLASSES = 2

# GCN parameters
# Width of the per-node input features, then the widths of the two graph
# convolution layers in GCN.
GCN_FEATURE_DIM = 30
GCN_HIDDEN_DIM = 128
GCN_OUTPUT_DIM = 32   
# Attention parameters
# Hidden width and number of heads passed to the Attention module in Model.
DENSE_DIM = 16
ATTENTION_HEADS = 4


def load_features(label_number, feature_dir="C:\\Users\\johnkwon\\Desktop\\PDB_seed_AF3\\masking\\GCN_node_norm_padded_GG\\"):
    """Load the node-feature matrix saved as ``<label_number>.npy``.

    Args:
        label_number: Identifier of the sample; ``str(label_number)`` is used
            as the file stem.
        feature_dir: Directory containing the ``.npy`` feature files. The
            default is the original hard-coded location, so existing
            single-argument calls resolve to exactly the same path as before.
            A trailing path separator is optional.

    Returns:
        The stored array converted to ``numpy.float32``.
    """
    directory = os.fspath(feature_dir)
    # Only append a separator when the caller left it off; this keeps the
    # default path string byte-identical to the original concatenation.
    if not directory.endswith(("\\", "/")):
        directory += os.sep
    feature_path = directory + str(label_number) + '.npy'
    return np.load(feature_path).astype(np.float32)


def load_graph(label_number, graph_dir="C:\\Users\\johnkwon\\Desktop\\PDB_seed_AF3\\masking\\GCN_contact_maps_norm_padded_GG\\"):
    """Load the edge (contact-map) matrix saved as ``<label_number>.npy``.

    Args:
        label_number: Identifier of the sample; ``str(label_number)`` is used
            as the file stem.
        graph_dir: Directory containing the ``.npy`` graph files. The default
            is the original hard-coded location, so existing single-argument
            calls resolve to exactly the same path as before. A trailing path
            separator is optional.

    Returns:
        The stored array converted to ``numpy.float32``.
    """
    directory = os.fspath(graph_dir)
    # Only append a separator when the caller left it off; this keeps the
    # default path string byte-identical to the original concatenation.
    if not directory.endswith(("\\", "/")):
        directory += os.sep
    graph_path = directory + str(label_number) + '.npy'
    return np.load(graph_path).astype(np.float32)


class ProDataset(Dataset):
    """Dataset backed by a dataframe with ``index`` and ``Class`` columns.

    Each item pairs a sample identifier and its class value with the feature
    and graph arrays that ``load_features`` / ``load_graph`` read for that
    identifier.
    """

    def __init__(self, dataframe):
        # Keep plain value arrays so items can be fetched by integer position.
        self.label = dataframe['index'].values
        self.solubility = dataframe['Class'].values

    def __getitem__(self, index):
        """Return ``(identifier, class value, node features, graph)`` for one row."""
        sample_id = self.label[index]
        target = self.solubility[index]
        node_features = load_features(sample_id)  # L * 30
        contact_graph = load_graph(sample_id)     # L * L
        return sample_id, target, node_features, contact_graph

    def __len__(self):
        """Number of rows in the backing dataframe."""
        return len(self.solubility)
    
    
    
class GraphConvolution(nn.Module):
    """Single graph-convolution layer computing ``adj @ (input @ W) (+ b)``."""

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Storage is allocated uninitialised; reset_parameters() fills it.
        self.weight = Parameter(torch.empty(in_features, out_features, dtype=torch.float32))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.empty(out_features, dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight, then bias, uniformly from ``[-1/sqrt(out), 1/sqrt(out)]``."""
        bound = 1. / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Project the node features, then aggregate them through ``adj``."""
        projected = input @ self.weight      # X * W
        aggregated = adj @ projected         # A * X * W
        if self.bias is None:
            return aggregated
        return aggregated + self.bias        # A * X * W + b

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"

    
class GCN(nn.Module):
    """Two stacked graph convolutions, each followed by LayerNorm and LeakyReLU."""

    def __init__(self):
        super().__init__()
        self.gc1 = GraphConvolution(GCN_FEATURE_DIM, GCN_HIDDEN_DIM)
        self.ln1 = nn.LayerNorm(GCN_HIDDEN_DIM)
        self.gc2 = GraphConvolution(GCN_HIDDEN_DIM, GCN_OUTPUT_DIM)
        self.ln2 = nn.LayerNorm(GCN_OUTPUT_DIM)
        self.relu1 = nn.LeakyReLU(0.1, inplace=True)
        self.relu2 = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x, adj):
        """Embed every node.

        ``x`` is ``(seq_len, GCN_FEATURE_DIM)`` and ``adj`` is
        ``(seq_len, seq_len)``; the result is ``(seq_len, GCN_OUTPUT_DIM)``.
        """
        # First layer: (seq_len, GCN_FEATURE_DIM) -> (seq_len, GCN_HIDDEN_DIM)
        hidden = self.relu1(self.ln1(self.gc1(x, adj)))
        # Second layer: (seq_len, GCN_HIDDEN_DIM) -> (seq_len, GCN_OUTPUT_DIM)
        return self.relu2(self.ln2(self.gc2(hidden, adj)))
class Attention(nn.Module):
    """Multi-head attention scorer: two linear layers, softmax over positions."""

    def __init__(self, input_dim, dense_dim, n_heads):
        super().__init__()
        self.input_dim = input_dim
        self.dense_dim = dense_dim
        self.n_heads = n_heads
        self.fc1 = nn.Linear(self.input_dim, self.dense_dim)
        self.fc2 = nn.Linear(self.dense_dim, self.n_heads)

    def softmax(self, input, axis=1):
        """Softmax along ``axis``, done by moving that axis last and flattening."""
        last_dim = input.dim() - 1
        swapped = input.transpose(axis, last_dim)
        swapped_shape = swapped.size()
        rows = swapped.contiguous().view(-1, swapped_shape[-1])
        normalised = torch.softmax(rows, dim=1).view(*swapped_shape)
        # Swap the axes back so the output has the same layout as the input.
        return normalised.transpose(axis, last_dim)

    def forward(self, input):
        """Map ``(1, seq_len, input_dim)`` to weights of shape ``(1, n_heads, seq_len)``."""
        hidden = torch.tanh(self.fc1(input))    # (1, seq_len, dense_dim)
        scores = self.fc2(hidden)               # (1, seq_len, n_heads)
        weights = self.softmax(scores, 1)       # normalised over seq_len
        return weights.transpose(1, 2)          # (1, n_heads, seq_len)
import torch
import torch.nn as nn

class Model(nn.Module):
    """GCN encoder with attention pooling and a single-logit binary head.

    The instance also owns its loss (``criterion``) and its Adam optimizer
    (``optimizer``), which the training and evaluation functions use.
    """

    def __init__(self):
        super().__init__()

        self.gcn = GCN()
        self.attention = Attention(GCN_OUTPUT_DIM, DENSE_DIM, ATTENTION_HEADS)
        # One output unit: a raw logit for the binary task.
        self.fc_final = nn.Linear(GCN_OUTPUT_DIM, 1)

        # BCEWithLogitsLoss takes raw logits, so forward() applies no sigmoid.
        self.criterion = nn.BCEWithLogitsLoss()
        # Built last so that self.parameters() already covers every sub-module.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    def forward(self, x, adj):
        """Return the logit for one graph given node features ``x`` and ``adj``."""
        node_embeddings = self.gcn(x.float(), adj).unsqueeze(0)   # add a batch axis of size 1
        head_weights = self.attention(node_embeddings)            # one weight row per head
        pooled_per_head = head_weights @ node_embeddings          # weighted sum over nodes
        pooled = torch.mean(pooled_per_head, 1)                   # average the heads
        return self.fc_final(pooled)

# Module-level Model instance. train_all() and cross_validation() each build
# their own Model, so nothing visible in this file reads this one afterwards.
# NOTE(review): constructing it still draws from the seeded RNG (parameter
# initialisation), so removing this line would change the initial weights of
# models created later -- confirm before deleting.
model = Model()

import torch
import torch.nn as nn
from tqdm import tqdm

def train_one_epoch(model, data_loader, epoch):
    """Run one optimisation pass over ``data_loader`` and return the mean loss.

    Args:
        model: Network exposing an ``optimizer`` attribute. Its ``criterion``
            attribute is used as the loss when present.
        data_loader: Iterable yielding ``(label, target, features, graph)``
            batches.
        epoch: Current epoch number. Accepted for interface compatibility; it
            is not used inside the function.

    Returns:
        The average of the per-batch loss values as a Python float.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # Move model to the correct device

    # Use the model's own loss so training and evaluate() score against the
    # same criterion; fall back to BCE-with-logits for models without one.
    criterion = getattr(model, "criterion", None)
    if criterion is None:
        criterion = nn.BCEWithLogitsLoss()

    epoch_loss_train = 0.0
    n_batches = 0

    for data in tqdm(data_loader):
        model.optimizer.zero_grad()
        _, solubility, sequence_features, sequence_graphs = data

        # squeeze() drops every size-1 dimension, which removes the batch axis
        # when each batch holds a single sample.
        # NOTE(review): this relies on one sample per batch -- confirm before
        # raising the batch size.
        sequence_features = sequence_features.to(device).squeeze()
        sequence_graphs = sequence_graphs.to(device).squeeze()
        y_true = solubility.to(device).float()  # BCEWithLogitsLoss needs float targets

        y_pred = model(sequence_features, sequence_graphs)

        # Flatten the logits so their shape matches the 1-D target tensor.
        loss = criterion(y_pred.view(-1), y_true)

        loss.backward()
        model.optimizer.step()

        epoch_loss_train += loss.item()
        n_batches += 1

    if n_batches == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("data_loader yielded no batches; cannot average the training loss")

    return epoch_loss_train / n_batches

import torch
from tqdm import tqdm

def evaluate(model, data_loader):
    """Score ``model`` on ``data_loader`` without computing gradients.

    The model is moved to the active device and switched to eval mode; it is
    left in eval mode on return, so callers must call ``model.train()`` before
    training again.

    Args:
        model: Network exposing a ``criterion`` attribute used as the loss.
        data_loader: Iterable yielding ``(label, target, features, graph)``
            batches.

    Returns:
        A tuple ``(mean_loss, y_true, y_pred, labels)``. ``y_true`` holds the
        targets as floats and ``labels`` the sample identifiers from the
        batches. ``y_pred`` holds one entry per row of the model output: the
        sigmoid of that row, converted with ``tolist()``.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    epoch_loss = 0.0
    n_batches = 0
    valid_pred = []
    valid_true = []
    valid_label = []

    criterion = model.criterion  # Use the criterion defined in the model

    # One no_grad context around the whole pass instead of one per batch.
    with torch.no_grad():
        for data in tqdm(data_loader):
            sequence_label, Class, sequence_features, sequence_graphs = data
            # squeeze() drops every size-1 dimension, which removes the batch
            # axis when each batch holds a single sample.
            sequence_features = sequence_features.to(device).squeeze()
            sequence_graphs = sequence_graphs.to(device).squeeze()
            y_true = Class.to(device).float()

            y_pred = model(sequence_features, sequence_graphs)
            loss = criterion(y_pred.squeeze(0), y_true)

            # The model emits logits; sigmoid turns them into probabilities.
            y_pred_prob = torch.sigmoid(y_pred).cpu().numpy().tolist()
            y_true = y_true.cpu().numpy().tolist()

            valid_pred.extend(y_pred_prob)
            valid_true.extend(y_true)
            valid_label.extend(sequence_label)

            epoch_loss += loss.item()
            n_batches += 1

    if n_batches == 0:
        # Previously this surfaced as an opaque ZeroDivisionError.
        raise ValueError("data_loader yielded no batches; cannot average the evaluation loss")

    epoch_loss_avg = epoch_loss / n_batches

    return epoch_loss_avg, valid_true, valid_pred, valid_label

def _fit_and_score_train_epoch(model, train_loader, epoch_number):
    """Run one training pass, then re-score the training set with evaluate()."""
    model.train()
    epoch_loss_train_avg = train_one_epoch(model, train_loader, epoch_number)
    print("========== Evaluate Train set ==========")
    _, train_true, train_pred, _ = evaluate(model, train_loader)
    return epoch_loss_train_avg, analysis(train_true, train_pred)


def _train_with_validation(model, train_dataframe, valid_dataframe, fold):
    """Train for NUMBER_EPOCHS epochs, validating each epoch; write per-fold CSVs."""
    train_loader = DataLoader(dataset=ProDataset(train_dataframe), batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    valid_loader = DataLoader(dataset=ProDataset(valid_dataframe), batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    metric_names = ('accuracy', 'precision', 'recall', 'f1')
    train_losses = []
    valid_losses = []
    train_metrics = {name: [] for name in metric_names}
    valid_metrics = {name: [] for name in metric_names}

    best_val_loss = float("inf")
    best_epoch = 0

    for epoch in range(NUMBER_EPOCHS):
        print("\n========== Train epoch " + str(epoch + 1) + " ==========")
        epoch_loss_train_avg, result_train = _fit_and_score_train_epoch(model, train_loader, epoch + 1)

        # NOTE(review): the square root of the mean loss is what the original
        # code reported and stored, so it is kept to leave the CSV comparable;
        # confirm whether the raw loss was intended.
        print("Train loss: ", np.sqrt(epoch_loss_train_avg))
        print("Train acc: ", result_train['accuracy'])
        print("Train precision: ", result_train['precision'])
        print("Train recall: ", result_train['recall'])
        print("Train F1: ", result_train['f1'])

        train_losses.append(np.sqrt(epoch_loss_train_avg))
        for name in metric_names:
            train_metrics[name].append(result_train[name])

        print("========== Evaluate Valid set ==========")
        epoch_loss_valid_avg, valid_true, valid_pred, _ = evaluate(model, valid_loader)
        result_valid = analysis(valid_true, valid_pred)

        print("Valid loss: ", np.sqrt(epoch_loss_valid_avg))
        print("Valid acc: ", result_valid['accuracy'])
        print("Valid precision: ", result_valid['precision'])
        print("Valid recall: ", result_valid['recall'])
        print("Valid f1: ", result_valid['f1'])

        valid_losses.append(np.sqrt(epoch_loss_valid_avg))
        for name in metric_names:
            valid_metrics[name].append(result_valid[name])

        # Track the epoch (1-based) with the lowest validation loss. The
        # original never updated best_epoch, so it always reported 0.
        if epoch_loss_valid_avg < best_val_loss:
            best_val_loss = epoch_loss_valid_avg
            best_epoch = epoch + 1

        # Dump per-sample validation predictions after the final epoch only.
        if epoch == (NUMBER_EPOCHS - 1):
            valid_final_dataframe = pd.DataFrame({'class_real': valid_true, 'class_predicted': valid_pred})
            valid_final_dataframe.to_csv("C:\\Users\\johnkwon\\Desktop\\PDB_seed_AF3\\binary\\GCN_glo_glu_CV\\" + 'Fold' + str(fold) + "EPOCH"+ str(NUMBER_EPOCHS)+"_valid_detail.csv", header=True, sep=',')
            print("saved!")

    # save calculation information
    result_all = {
        'Train_loss': train_losses,
        'Train_binary_acc': train_metrics['accuracy'],
        'Train_precision': train_metrics['precision'],
        'Train_recall': train_metrics['recall'],
        'Train_f1': train_metrics['f1'],
        'Valid_loss': valid_losses,
        'Valid_binary_acc': valid_metrics['accuracy'],
        'Valid_precision': valid_metrics['precision'],
        'Valid_recall': valid_metrics['recall'],
        'Valid_f1': valid_metrics['f1'],
        'Best_epoch': [best_epoch] * len(train_losses),
    }
    result = pd.DataFrame(result_all)

    print("Fold", str(fold), "Best epoch at", str(best_epoch))
    result.to_csv("C:\\Users\\johnkwon\\Desktop\\PDB_seed_AF3\\binary\\GCN_glo_glu_CV\\" + "Fold" + str(fold) + "_result.csv", sep=',')


def _train_on_all_data(model, train_dataframe):
    """Train on the full table with no validation split and save the weights."""
    train_loader = DataLoader(dataset=ProDataset(train_dataframe), batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

    # NOTE(review): this mode runs NUMBER_EPOCHS + 1 epochs, one more than the
    # validation mode. Kept as in the original so the saved model is
    # unchanged; confirm whether the extra epoch is intended.
    for epoch in range(NUMBER_EPOCHS + 1):
        print("\n========== Train epoch " + str(epoch + 1) + " ==========")
        epoch_loss_train_avg, result_train = _fit_and_score_train_epoch(model, train_loader, epoch + 1)

        print("Train loss: ", np.sqrt(epoch_loss_train_avg))
        print("Train binary acc: ", result_train['accuracy'])
        print("Train precision: ", result_train['precision'])
        print("Train recall: ", result_train['recall'])
        print("Train F1: ", result_train['f1'])

        # Persist the weights once, after the last epoch.
        if epoch == NUMBER_EPOCHS:
            torch.save(model.state_dict(), os.path.join("C:\\Users\\johnkwon\\Desktop\\PDB_seed_AF3\\binary\\GCN_glo_glu_CV\\train_all.pkl"))


def train(model, train_dataframe, valid_dataframe=None, fold=0):
    """Train ``model``, with or without a validation set.

    The file used to define ``train`` twice; the second definition replaced
    the first, so the four-argument call made by ``cross_validation`` raised
    ``TypeError``. This single entry point accepts both call shapes.

    Args:
        model: Network to train; passed to ``train_one_epoch`` and ``evaluate``.
        train_dataframe: Table with ``index`` and ``Class`` columns used for
            training.
        valid_dataframe: Optional table with the same columns. When given,
            every epoch is also validated and the per-fold result CSVs are
            written. When ``None``, the model is trained on ``train_dataframe``
            alone and its ``state_dict`` is saved after the final epoch.
        fold: Fold number used in the names of the CSV files written in
            validation mode.

    Returns:
        None. Results are printed and written to disk.
    """
    if valid_dataframe is None:
        _train_on_all_data(model, train_dataframe)
    else:
        _train_with_validation(model, train_dataframe, valid_dataframe, fold)
   

def analysis(y_true, y_pred):
    """Compute binary classification metrics from predicted probabilities.

    Args:
        y_true: Sequence of ground-truth labels (0 or 1).
        y_pred: Sequence of predicted probabilities. Each entry may be a bare
            number or a one-element list/tuple such as ``[0.73]``.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``confusion_matrix``. The confusion matrix is always 2x2, with rows
        and columns ordered ``[0, 1]``.
    """
    # Unwrap predictions that arrive as one-element sequences.
    y_pred_flat = [p[0] if isinstance(p, (list, tuple)) else p for p in y_pred]

    # Convert predictions to class labels based on a threshold of 0.5
    y_pred_labels = [1 if prob > 0.5 else 0 for prob in y_pred_flat]

    # zero_division=0 yields the same 0.0 as the default when a metric is
    # undefined (e.g. no positive predictions), without the warning.
    accuracy = accuracy_score(y_true, y_pred_labels)
    precision = precision_score(y_true, y_pred_labels, average='binary', zero_division=0)
    recall = recall_score(y_true, y_pred_labels, average='binary', zero_division=0)
    f1 = f1_score(y_true, y_pred_labels, average='binary', zero_division=0)
    # Pin the label set so the matrix stays 2x2 even when only one class
    # occurs in this batch of data.
    conf_matrix = confusion_matrix(y_true, y_pred_labels, labels=[0, 1])

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'confusion_matrix': conf_matrix,
    }

def train_all(all_dataframe):
    """Build a fresh Model and train it on every row of ``all_dataframe``.

    The table must provide ``index`` and ``Class`` columns; the identifiers
    are printed before training starts.
    """
    print("split_seed: ", SEED)
    sample_ids = all_dataframe['index'].values
    # Not used further; the lookup fails here if the column is missing.
    _ = all_dataframe['Class'].values
    network = Model()
    print(sample_ids)

    train(network, all_dataframe)
    
    
def cross_validation(all_dataframe,fold_number=5):
    """Run shuffled K-fold cross-validation, training a new Model per fold.

    Each fold's train and validation rows are passed to ``train`` together
    with the 1-based fold number.
    """
    print("split_seed: ", SEED)
    sample_ids = all_dataframe['index'].values
    targets = all_dataframe['Class'].values
    splitter = KFold(n_splits=fold_number, shuffle=True)

    for fold_id, (train_index, valid_index) in enumerate(splitter.split(sample_ids, targets), start=1):
        print("\n========== Fold " + str(fold_id) + " ==========")
        train_dataframe = all_dataframe.iloc[train_index, :]
        valid_dataframe = all_dataframe.iloc[valid_index, :]
        print("Training on", str(train_dataframe.shape[0]), "examples, Validation on", str(valid_dataframe.shape[0]),
              "examples")

        # A new network per fold so no weights carry over between folds.
        model = Model()
        if torch.cuda.is_available():
            model.cuda()

        train(model, train_dataframe, valid_dataframe, fold_id)

# Script entry point: these statements run as soon as the file is executed.
# Read the training table (ProDataset reads its 'index' and 'Class' columns)
# and train a model on all of its rows.
train_dataframe = pd.read_csv("C:\\Users\\johnkwon\\Desktop\\PDB_seed_AF3\\binary\\Data\\GCN_binary_glo_glu_train.csv")
train_all(train_dataframe)