In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import torch.cuda.amp as amp
import torch.optim as optim
from torchvision.ops import *
from torchvision.models import *
from torchsummary import summary
from matplotlib import pyplot as plt
# import bitsandbytes as bnb
import numpy as np
import pandas as pd
import scipy.stats
import os
import time
import copy
import random
from tqdm import tqdm, trange

import warnings
warnings.filterwarnings("ignore")

!nvidia-smi

# To pin the notebook to one GPU, set CUDA_VISIBLE_DEVICES before the first
# CUDA call, e.g.  os.environ["CUDA_VISIBLE_DEVICES"] = "2"
# Use the first visible GPU when one exists; otherwise fall back to the CPU so
# that later `.to(device)` calls do not fail on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(torch.cuda.is_available())

In [None]:
import time
import math
import random
import numpy as np
import pandas as pd
from tqdm import *
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from multiprocessing import cpu_count
from concurrent.futures import ProcessPoolExecutor

def split_targets(targets):
    """Split `targets` into one contiguous chunk per CPU core.

    The first ``cores - 1`` chunks each hold ``len(targets) // cores`` items
    and the last chunk takes whatever remains, so concatenating the chunks
    always reproduces `targets`.

    Args:
        targets: a sliceable sequence (list, numpy array, ...).

    Returns:
        ``(parts, cores)`` where ``parts`` is a list of ``cores`` slices of
        `targets` and ``cores`` is the value reported by ``cpu_count()``.
    """
    cores = cpu_count()
    part = len(targets) // cores

    parts = [targets[k * part:(k + 1) * part] for k in range(cores - 1)]
    # The start of the last chunk is computed from `cores` rather than from a
    # leftover loop variable, so a single-core machine (no loop iterations)
    # still gets one chunk holding everything.
    parts.append(targets[(cores - 1) * part:])
    return parts, cores
    
# Gather Expression dataset
def GenData(paras):  # (tfs, targets, exp, gold, core)
    """Enumerate every (TF, target) pair with its expression rows and label.

    Args:
        paras: 5-tuple ``(tfs, targets, exp, gold, core)``:
            * ``tfs`` / ``targets`` -- sequences of keys used to index ``exp``;
            * ``exp`` -- container where ``exp[key]`` yields one expression
              vector;
            * ``gold`` -- iterable of known-positive ``[tf, target]`` pairs
              (elements must be hashable);
            * ``core`` -- worker id; a progress bar is shown only when it
              equals 1.

    Returns:
        ``(gene_pair, exp_pair, labels)``: three parallel lists holding, for
        each pair in TF-major order, ``[tf, target]``, the 2-row array
        ``np.vstack((exp[tf], exp[target]))`` and 1 if the pair is in ``gold``
        else 0.
    """
    tfs, targets, exp, gold, core = paras
    gene_pair, exp_pair, labels = [], [], []

    # Hash the gold pairs once: membership in a list of lists is a linear
    # scan, which made labelling cost O(len(gold)) for every single pair.
    gold_pairs = {tuple(pair) for pair in gold}

    # Only the worker with id 1 reports progress.
    tf_iter = (tfs[i] for i in trange(len(tfs))) if core == 1 else tfs

    for tf in tf_iter:
        tf_exp = exp[tf]
        for target in targets:
            gene_pair.append([tf, target])
            exp_pair.append(np.vstack((tf_exp, exp[target])))
            labels.append(1 if (tf, target) in gold_pairs else 0)
    return (gene_pair, exp_pair, labels)

def processing(tf_raw, gene_raw, exp, gold):
    """Build the labelled (TF, target) pair dataset from the gold network.

    Only TFs that appear in column 0 of ``gold`` and targets that appear in
    column 1 of ``gold`` are kept; every remaining TF x target combination is
    then generated by ``GenData``.

    Args:
        tf_raw: candidate TF keys.
        gene_raw: candidate target keys.
        exp: expression container indexed by those keys (passed to GenData).
        gold: 2-D numpy array whose first two columns are (TF, target).

    Returns:
        ``(gene_pair, all_exp, all_labels)`` exactly as returned by GenData.
    """
    print('Multi-core processing...')
    start = time.time()

    # No `if __name__ == '__main__'` guard here: inside a function it protects
    # nothing and left the return values unbound whenever this code was
    # imported instead of run as the main module.
    tf = np.intersect1d(tf_raw, gold[:,0].reshape(-1,))
    targets = np.intersect1d(gene_raw, gold[:,1].reshape(-1,))
    print('num of tfs:', len(tf), 'num of targets:', len(targets))

    # Multi-process variant (currently disabled):
    # targets_part, cores = split_targets(targets)
    # gene_pair, all_exp, all_labels = [], [], []
    # p = ProcessPoolExecutor(max_workers=cores)
    # process = [p.submit(GenData, (tf, targets_part[i], exp, gold.tolist(), i)) for i in range(cores)]
    # p.shutdown()
    # for j in range(cores):
    #     results = process[j].result()
    #     gene_pair.extend(results[0])
    #     all_exp.extend(results[1])
    #     all_labels.extend(results[2])

    # Single-process run; core id 1 turns on GenData's progress bar.
    gene_pair, all_exp, all_labels = GenData((tf, targets, exp, gold.tolist(), 1))

    RuningTime = time.time() - start
    # The elapsed time is divided by 60, so the unit printed is minutes.
    print('multiprocessing Done! Runing Time:', round(RuningTime / 60, 2), 'min')

    return gene_pair, all_exp, all_labels

def split_datasets(labels):
    """Stratified random 3:1:1 split of sample positions.

    Positive (label == 1) and negative (label == 0) positions are shuffled
    independently with ``random.sample``; each class is cut into fifths
    (integer division), with three fifths going to train, one to validation
    and the remainder to test.

    Args:
        labels: sequence of 0/1 labels.

    Returns:
        ``(train_index, val_index, test_index)`` -- lists of positions into
        `labels`, positives first and negatives after them in each list.
    """
    print('labels', np.sum(labels))

    def _shuffled_positions(wanted):
        # Positions carrying the requested label, in random order.
        members = [pos for pos, lab in enumerate(labels) if lab == wanted]
        return random.sample(members, len(members))

    def _three_way(order):
        # 3/5 train, 1/5 validation, everything left over is test.
        fifth = len(order) // 5
        return order[:3 * fifth], order[3 * fifth:4 * fifth], order[4 * fifth:]

    # Positives are drawn before negatives so the RNG is consumed in a fixed order.
    pos_chunks = _three_way(_shuffled_positions(1))
    neg_chunks = _three_way(_shuffled_positions(0))

    train_index, val_index, test_index = (
        pos + neg for pos, neg in zip(pos_chunks, neg_chunks)
    )
    return train_index, val_index, test_index



# Gather Expression dataset
class Feeder(Dataset):
    """Dataset of (TF, target) expression pairs.

    Args:
        exp_data: numpy array indexed as ``[sample, row, position]``; rows 0
            and 1 of each sample are the TF and target series.
        label: per-sample labels, indexable by sample position.
        patch_size: the last axis is trimmed to the largest multiple of this
            value.
        mode: must be ``'pretrain'`` or ``'lincls'``; validated but otherwise
            unused by this class.
        base: accepted for interface compatibility; unused.
    """
    def __init__(self, exp_data, label, patch_size, mode='pretrain', base=128):
        assert mode=='pretrain' or mode=='lincls', 'mode should in [pretrain, lincls]'

        # Drop trailing positions that do not fill a whole patch.
        usable_len = exp_data.shape[-1] - exp_data.shape[-1] % patch_size
        self.exp_data = exp_data[:, :, :usable_len]
        self.label = label
        # Positions of the *untrimmed* series; kept because it is a public attribute.
        self.arange = np.arange(exp_data.shape[-1])

    def __len__(self):
        return len(self.label)

    def __getitem__(self, index):
        """Return ``(X, label)`` where X is a fresh ``(2, L)`` float16 array."""
        pair = self.exp_data[index]
        # np.stack + astype already allocate a new array, so the stored data
        # can never be aliased by the caller; no deepcopy is required.
        X = np.stack([pair[0].reshape(-1), pair[1].reshape(-1)]).astype(np.float16)
        return X, self.label[index]
    
# Mini-batch size used by every DataLoader built in this notebook.
batch_size = 512
# Patch length: Feeder trims each series to a multiple of it and MAE cuts the
# series into patches of this size.
patch_size = 8

# Benchmark inputs (hESC, non-specific network, TFs+1000 genes).
# NOTE(review): the 'index' columns of TF.csv / Target.csv are used to index
# rows of ExpressionData.values inside GenData, so they are presumably integer
# row positions -- confirm against the CSV files.
ExpressionData = pd.read_csv('../Benchmark Dataset/Non-Specific Dataset/hESC/TFs+1000/BL--ExpressionData.csv', index_col=0, engine='c')
network = pd.read_csv('../Benchmark Dataset/Non-Specific Dataset/hESC/TFs+1000/Label.csv', index_col=0, engine='c').values
tfs_raw = pd.read_csv('../Benchmark Dataset/Non-Specific Dataset/hESC/TFs+1000/TF.csv', index_col=0, engine='c')['index'].values.tolist()
targets_raw = pd.read_csv('../Benchmark Dataset/Non-Specific Dataset/hESC/TFs+1000/Target.csv', index_col=0, engine='c')['index'].values.tolist()

# Enumerate all TF x target pairs with their stacked expression rows and 0/1 labels.
gene_pair, all_exp, labels = processing(tfs_raw, targets_raw, ExpressionData.values, network)

# Pre-training loader over every pair; the labels are carried along but the
# pre-training loop ignores them.
data = Feeder(np.array(all_exp), labels, patch_size, 'pretrain')
loader = DataLoader(data, batch_size=batch_size, shuffle=True)
# class_weight[k] = total samples / samples of the k-th unique label
# (np.unique returns the labels sorted, so index 0 is label 0 when both occur).
labels_unique, counts = np.unique(labels, return_counts=True)
class_weight = [sum(counts)/i for i in counts]
# With exactly two classes, class_weight[0]/sum(class_weight) equals the
# fraction of samples carrying the second label (the positive rate).
print('labels', np.sum(labels), 'postive VS negative:', class_weight, 'Density:', round(class_weight[0]/sum(class_weight), 3))

In [None]:
# classes
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange
import math
        
def SinCosEmbed(seq_len, d_model, max_len=5000):
    """Build a fixed sinusoidal positional-embedding table.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns
    ``cos(pos * w_k)`` with ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        seq_len: number of positions to return.
        d_model: embedding width (odd widths are supported).
        max_len: upper bound on ``seq_len``.

    Returns:
        Float tensor of shape ``(1, seq_len, d_model)`` that does not require
        gradients.

    Raises:
        ValueError: if ``seq_len`` exceeds ``max_len``. Previously the table
            was silently cut to ``max_len`` rows, which only surfaced later as
            a shape mismatch.
    """
    if seq_len > max_len:
        raise ValueError(
            'seq_len ({}) exceeds max_len ({}); pass a larger max_len'.format(seq_len, max_len)
        )

    # Freshly created tensors already have requires_grad=False.
    pe = torch.zeros(seq_len, d_model).float()

    position = torch.arange(0, seq_len).float().unsqueeze(1)
    div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()
    angles = position * div_term

    pe[:, 0::2] = torch.sin(angles)
    # For odd d_model there is one fewer cosine column than sine column.
    pe[:, 1::2] = torch.cos(angles[:, :d_model // 2])

    return pe.unsqueeze(0)

class PosEmb(nn.Module):
    """Adds a frozen sinusoidal positional table to a token sequence.

    Dropout is applied to the positional table itself before it is added.
    A learnable ``cls_token`` parameter is registered, but ``forward`` does
    not append it to the sequence.
    """
    def __init__(self, num, dim, emb_dropout):
        super().__init__()
        self.num = num
        self.dropout = nn.Dropout(p=emb_dropout)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # Table covering `num` positions (use num+1 if a cls token gets appended).
        table = SinCosEmbed(num, dim)
        self.pos_emd = nn.Parameter(table, requires_grad=False)

    def forward(self, x):
        # Variant with a cls token appended at the end of the sequence:
        # cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        # x = torch.cat((x, cls_tokens), dim=1)
        positional = self.dropout(self.pos_emd)
        return x + positional
        
class Logits(nn.Module):
    """Collapse the token axis (dim 1) into a single vector per sample.

    With ``cls_token`` set, the last token is returned; otherwise the mean
    over all tokens.
    """
    def __init__(self, cls_token=False):
        super().__init__()
        self.cls_token = cls_token

    def forward(self, x):
        if self.cls_token:
            return x[:, -1]
        return x.mean(dim=1)

class MAE(nn.Module):
    """Masked-autoencoder style pre-training model over patched series.

    The input is cut into patches, a random subset of patches is hidden, the
    visible patches are encoded, and a decoder reconstructs the hidden ones.
    ``forward`` returns a single scalar loss: MSE between predicted and true
    masked patches, minus the mean cosine similarity between the pooled
    decoder output at masked positions and the (detached) pooled encoder
    output of the masked tokens.

    NOTE(review): with ``cls_token=True`` (the default) the random indices
    span ``num_patches + 1`` positions while ``patch_to_emb`` yields only
    ``num_patches`` tokens (PosEmb does not append a cls token), so indexing
    ``tokens`` with the extra index looks out of range. The notebook only
    instantiates the model with ``cls_token=False`` -- confirm before using
    the default.
    """
    def __init__(
                self,*,
                seq_len, 
                channels,
                patch_size,   
                dim = 192, 
                heads = 12,
                mlp_ratio = 4,
                cls_token = True,
                emded_grad = True,
                masking_ratio = 0.7,
                # encoder parameters
                encoder_depth = 2, 
                dropout = .2, 
                emb_dropout = 0.,
                # decoder parameters
                decoder_depth = 2, 
                ):
        """Build the patch embedding, encoder, decoder and loss modules.

        Args (keyword-only):
            seq_len: length of each channel's series; must be divisible by
                ``patch_size``.
            channels: number of series per sample.
            patch_size: number of positions per patch.
            dim: token width used by both encoder and decoder.
            heads: attention heads per transformer layer.
            mlp_ratio: feed-forward width as a multiple of ``dim``.
            cls_token: selects the pooling used by ``self.logits`` and the
                size of the index space / decoder positional table (see the
                class-level review note).
            emded_grad: whether the patch-projection weight is trainable.
            masking_ratio: fraction of patches hidden, strictly in (0, 1).
            encoder_depth / decoder_depth: number of transformer layers.
            dropout: dropout inside the transformer layers.
            emb_dropout: dropout used by PosEmb and ``self.dropout``.
        """
        super().__init__()
        assert (seq_len % patch_size) == 0, 'seq_len must be divisible by patch_size'
        assert masking_ratio > 0 and masking_ratio < 1, 'masking ratio must be kept between 0 and 1'

        self.dim = dim
        self.cls = cls_token
        self.logits = Logits(cls_token)
        self.masking_ratio = masking_ratio
        # patches per channel times number of channels
        num_patches = seq_len // patch_size * channels
        # (batch, channels, n*p) -> (batch, channels*n, p): one token per patch
        self.to_patch = Rearrange('b c (n p) -> b (c n) p', p = patch_size)
        
        # patch -> token: bias-free linear projection, LayerNorm, positional embedding
        self.patch_to_emb = nn.Sequential(
                                            # nn.LayerNorm(patch_size),
                                            nn.Linear(patch_size, dim, bias=False),
                                            nn.LayerNorm(dim),
                                            PosEmb(num_patches, dim, emb_dropout),
                                            )
        # xavier_uniform initialization
        nn.init.xavier_uniform_(self.patch_to_emb[0].weight)
        # optionally freeze the patch projection
        self.patch_to_emb[0].weight.requires_grad = emded_grad
                
        # applied to the encoder output before it enters the decoder
        self.ln = nn.LayerNorm(dim)
        # NOTE(review): registered but not used in forward()
        self.dropout = nn.Dropout(emb_dropout)
        
        # encoder: pre-norm transformer stack with GELU, batch-first tensors
        EncoderLayer = nn.TransformerEncoderLayer(d_model=dim, 
                                                    nhead=heads,
                                                    dim_feedforward=int(dim*mlp_ratio),
                                                    dropout=dropout, 
                                                    activation='gelu',
                                                    # layer_norm_eps=1e-3,
                                                    batch_first=True,
                                                    norm_first=True,
                                                    )
        self.encoder = nn.TransformerEncoder(EncoderLayer, num_layers=encoder_depth)
        
        # decoder parameters: a learned mask token plus a second transformer
        # stack built the same way as the encoder
        self.mask_token = nn.Parameter(torch.randn(dim))
        DecoderLayer = nn.TransformerEncoderLayer(d_model=dim, 
                                                    nhead=heads,
                                                    dim_feedforward=int(dim*mlp_ratio),
                                                    dropout=dropout, 
                                                    activation='gelu',
                                                    # layer_norm_eps=1e-3,
                                                    batch_first=True,
                                                    norm_first=True,
                                                    )
        self.decoder = nn.TransformerEncoder(DecoderLayer, num_layers=decoder_depth)
        # frozen sinusoidal positions for the decoder; one extra slot when cls is enabled
        self.decoder_pos_emb = nn.Parameter(SinCosEmbed(num_patches+1, dim), requires_grad=False) if self.cls else nn.Parameter(SinCosEmbed(num_patches, dim), requires_grad=False)
        
        # decoder token -> reconstructed patch values
        self.to_seqs = nn.Linear(dim, patch_size)
        
        # MSE and Cosine Similarity Loss 
        self.loss = nn.MSELoss()
        self.criterion = nn.CosineSimilarity(dim=1)

    def forward(self, series):
        """Compute the pre-training loss for one batch.

        Args:
            series: tensor laid out as (batch, channels, length), as required
                by the ``'b c (n p)'`` rearrange pattern.

        Returns:
            Scalar tensor: ``recon_loss - mean cosine similarity``.
        """
        device = series.device

        # get patches
        patches = self.to_patch(series)
        batch, num_patches, *_ = patches.shape

        # patch to encoder tokens and add positions
        tokens = self.patch_to_emb(patches)

        # number of patches to mask, then a random permutation per sample that
        # is split into masked / unmasked index sets
        num_masked = int(self.masking_ratio * num_patches)

        batch_range = torch.arange(batch, device=device)[:, None]
        if self.cls:
            rand_indices = torch.rand(batch, num_patches+1, device=device)
            # a huge score sorts the extra (last) index to the end, i.e. always unmasked
            rand_indices[:, -1] = 1e+7
            rand_indices = rand_indices.argsort(dim = -1)
        else:
            rand_indices = torch.rand(batch, num_patches, device=device).argsort(dim = -1)
        masked_indices, unmasked_indices = rand_indices[:, :num_masked], rand_indices[:, num_masked:]

        # target feature: encode the masked tokens and mean-pool them
        z = tokens[batch_range, masked_indices]
        z = self.encoder(z).mean(dim=1)
        
        # get the unmasked tokens to be encoded
        tokens = tokens[batch_range, unmasked_indices]

        # get the patches to be masked for the final reconstruction loss
        masked_patches = patches[batch_range, masked_indices]
        
        # attend with transformer
        encoded_tokens = self.encoder(tokens)

        # normalise the encoder output and add the decoder positions of the visible tokens
        # encoded_tokens += self.decoder_pos_emb[:, unmasked_indices]
        encoded_tokens = self.ln(encoded_tokens) + self.decoder_pos_emb[:, unmasked_indices]
            
        # repeat mask tokens for number of masked, and add the positions using the masked indices derived above
        mask_tokens = repeat(self.mask_token, 'd -> b n d', b = batch, n = num_masked)
        mask_tokens = mask_tokens + self.decoder_pos_emb[:, masked_indices]
        
        
        # scatter visible and mask tokens back to their original positions
        if self.cls:
            decoder_tokens = torch.zeros(batch, num_patches+1, self.dim, device=device)
        else:
            decoder_tokens = torch.zeros(batch, num_patches, self.dim, device=device)
        
        decoder_tokens[batch_range, unmasked_indices] = encoded_tokens
        decoder_tokens[batch_range, masked_indices] = mask_tokens

        # attend with decoder
        decoded_tokens = self.decoder(decoder_tokens)

        # splice out the pred_features and pred_values
        mask_tokens = decoded_tokens[batch_range, masked_indices]
        pred_values = self.to_seqs(mask_tokens)
        # pooled prediction feature, compared with z below
        p = mask_tokens.mean(dim=1)

        # calculate reconstruction loss
        recon_loss = self.loss(pred_values, masked_patches)
        # feature loss: negative mean cosine similarity; z is detached so no
        # gradient flows through the target branch
        # criterion = - (z.detach().softmax(dim=1) * p.log_softmax(dim=1)).mean()
        criterion = -self.criterion(p, z.detach()).mean()
        # criterion = self.ContrastiveLoss(p, z.detach())
        
        return criterion + recon_loss
    
# Build the pre-training model. seq_len comes from the first sample's array
# (data[0] is (X, label) and X has shape (2, L)); channels=2 matches the
# TF/target rows. The cls-token path is switched off.
model = MAE(seq_len=data[0][0].shape[1], 
            channels=2, 
            patch_size=patch_size,
            cls_token=False,)

# define optimizer (AdamW with its default hyper-parameters) and the
# gradient scaler used for mixed-precision training
opt = optim.AdamW(model.parameters())
scaler = amp.GradScaler()

# define function to training
def Training(model, num_epochs, opt=opt, data_dl=loader):
    """Pre-train `model` and return the encoder of the best epoch.

    Uses the module-level ``device`` and ``scaler``. Training stops early if
    a batch produces a NaN loss.

    Args:
        model: module whose forward pass returns a scalar loss and which
            exposes ``to_patch``, ``patch_to_emb``, ``encoder`` and ``logits``.
        num_epochs: maximum number of passes over `data_dl`.
        opt: optimizer over the model's parameters (default bound when this
            function was defined).
        data_dl: iterable of ``(x, label)`` batches; the labels are ignored.

    Returns:
        ``(best_encoder, loss_history)``: an ``nn.Sequential`` of deep copies
        of the best epoch's patching, embedding, encoder and pooling modules,
        and a list with one mean-loss value (0-d numpy array) per completed
        epoch.
    """
    loss_history = []
    start_time = time.time()
    path2weights = './models/SIGMA_hESC1000+Non-Specific.pt'
    best_loss = 1e+7
    best_model = None
    diverged = False

    # Forward pass and loss computation, epoch by epoch.
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))

        model.train()
        running_loss = 0
        num_batches = 0
        for x, _ in tqdm(data_dl):
            x = x.float().to(device, non_blocking=True)
            # Only the forward pass runs under autocast; the backward pass and
            # optimizer step are kept outside, as the AMP docs recommend.
            with amp.autocast():
                loss = model(x)
            # NaN never compares equal to anything (itself included), so an
            # `==` test cannot detect it; use torch.isnan.
            if torch.isnan(loss):
                diverged = True
                break
            opt.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
            # Detach so the running sum does not chain autograd history.
            running_loss += loss.detach()
            num_batches += 1

        if diverged or num_batches == 0:
            break
        # store loss history
        epoch_loss = running_loss / num_batches
        loss_history.append(epoch_loss.cpu().numpy())
        print('train loss: %.6f, time: %.2f min' %(epoch_loss,(time.time()-start_time)/60))
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            best_model = copy.deepcopy(model)

    # Save the weights of the model as it stands after the last step.
    os.makedirs(os.path.dirname(path2weights), exist_ok=True)
    torch.save(model.state_dict(), path2weights)
    # No epoch completed (num_epochs == 0, empty loader, or NaN in the first
    # epoch): fall back to the current model instead of raising on an unbound name.
    if best_model is None:
        best_model = model
    best_encoder = nn.Sequential(copy.deepcopy(best_model.to_patch),
                            copy.deepcopy(best_model.patch_to_emb),
                            copy.deepcopy(best_model.encoder),
                            # copy.deepcopy(best_model.ln),
                            copy.deepcopy(best_model.logits))
    torch.cuda.empty_cache()  # release cached GPU memory
    return best_encoder, loss_history

In [None]:
# Create the folder for model weights up front: Training() writes its
# checkpoint under ./models/.
os.makedirs('./models', exist_ok=True)
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize `state` to `filename`; also copy it to 'model_best.pth.tar' if `is_best`.

    Args:
        state: any object accepted by ``torch.save`` (typically a dict of
            state_dicts).
        is_best: when true, the file is duplicated as ``model_best.pth.tar``
            in the current working directory.
        filename: path of the checkpoint file to write.
    """
    # Imported locally because the notebook's import cells do not bring in
    # shutil; without it the is_best branch raised NameError.
    import shutil

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
        
# start training: up to `num_epochs` passes with Training's default optimizer
# and loader (bound when Training was defined). `encoder` is the nn.Sequential
# returned by Training; `loss_history` holds one value per completed epoch.
num_epochs = 200
encoder, loss_history = Training(model.to(device), num_epochs=num_epochs)

In [None]:
# plot loss history
# Take the x-axis from the recorded history rather than from `num_epochs`:
# the two differ whenever training stopped early or `num_epochs` was changed
# after the run, and a length mismatch makes plt.plot raise.
epochs_run = range(1, len(loss_history) + 1)
plt.title('Loss History')
plt.plot(epochs_run, loss_history, label='train')
plt.ylabel('Loss')
plt.xlabel('Training Epochs')
plt.legend()
plt.show()

In [None]:
# define Linear Classifier for transfer learning
class LinearClassifier(nn.Module):
    """Two-layer classification head on top of a pre-trained backbone.

    The head is Linear -> BatchNorm1d -> ReLU -> Linear(dim, 1), where ``dim``
    is read from the output size of ``backbone[1][0]``. When ``finetune`` is
    falsy, the backbone parameters are frozen and its forward pass runs
    under ``torch.no_grad()``.
    """
    def __init__(self, backbone, finetune=True):
        super(LinearClassifier, self).__init__()
        self.backbone = backbone
        # feature width = number of output rows of the backbone's first projection
        dim = self.backbone[1][0].weight.shape[0]
        self.fc1 = nn.Linear(dim, dim)
        self.fc2 = nn.Linear(dim, 1)

        self.bn = nn.BatchNorm1d(dim)
        self.relu = nn.ReLU(inplace=True)

        self.ft = finetune
        if not finetune:
            # linear-probe mode: the backbone receives no gradient updates
            for _, weight in self.backbone.named_parameters():
                weight.requires_grad = False

    def forward(self, x):
        if not self.ft:
            with torch.no_grad():
                feats = self.backbone(x)
        else:
            feats = self.backbone(x)
        hidden = self.relu(self.bn(self.fc1(feats)))
        return self.fc2(hidden)

In [None]:
def _evaluate_split(model, data_dl, loss_func):
    """Score `model` on `data_dl` without gradients; return (mean_loss, scores, targets)."""
    model.eval()
    total_loss, num_batches = 0.0, 0
    scores, targets = [], []
    with torch.no_grad():
        for x, y in data_dl:
            x = x.float().to(device)
            y = y.float().to(device)
            with amp.autocast():
                pred = model(x)
                loss = loss_func(pred, y.reshape(-1, 1))
            total_loss += loss.item()
            num_batches += 1
            scores.extend(pred.detach().cpu().numpy())
            targets.extend(y.int().detach().cpu().numpy())
    scores = np.nan_to_num(scores).reshape(-1)
    targets = np.nan_to_num(targets).reshape(-1)
    return total_loss / max(num_batches, 1), scores, targets


def downstream(model, linear_epoch=100):
    """Train a classifier on the train split and report test AUROC / AUPR.

    Reads the module-level ``train_dl``, ``val_dl``, ``test_dl``,
    ``class_weight`` and ``device``. After every epoch the model is scored on
    the validation split; the epoch with the highest validation AUPR is kept
    and evaluated on the test split. Training stops early on a NaN loss.

    Args:
        model: module mapping a batch to one logit per sample, shape (B, 1).
        linear_epoch: maximum number of training epochs.

    Returns:
        ``[AUROC, AUPR]`` of the selected model on the test split.
    """
    from sklearn import metrics

    start_time = time.time()
    linear_scaler = amp.GradScaler()
    # The positive class is up-weighted by class_weight[1].
    linear_loss_func = nn.BCEWithLogitsLoss(pos_weight=torch.FloatTensor([class_weight[1]]).to(device))
    linear_opt = optim.AdamW(model.parameters())

    best_result = 0
    best_model = None
    diverged = False
    for epoch in range(linear_epoch):
        print('Epoch {}/{}'.format(epoch+1, linear_epoch))

        # ---- training pass ----
        model.train()
        running_train_loss, num_batches = 0.0, 0
        for x, y in tqdm(train_dl):
            x = x.float().to(device)
            y = y.float().to(device)
            # Only the forward pass and loss run under autocast.
            with amp.autocast():
                pred = model(x)
                loss = linear_loss_func(pred, y.reshape(-1, 1))
            # `loss == torch.nan` is always False; torch.isnan is the real check.
            if torch.isnan(loss):
                diverged = True
                break
            linear_opt.zero_grad()
            linear_scaler.scale(loss).backward()
            linear_scaler.step(linear_opt)
            linear_scaler.update()
            running_train_loss += loss.item()
            num_batches += 1

        if diverged or num_batches == 0:
            break
        train_loss = running_train_loss / num_batches

        # ---- validation pass ----
        # The validation loss is computed on the validation batches themselves;
        # previously the last training loss was re-added once per batch.
        val_loss, val_pred, val_gold = _evaluate_split(model, val_dl, linear_loss_func)
        val_AUROC = metrics.roc_auc_score(val_gold, val_pred)
        val_AUPR = metrics.average_precision_score(val_gold, val_pred)

        print('train loss: %.6f, val loss: %.6f, AUROC score: %.4f, AUPR score: %.4f, time: %.4f min' %(train_loss, val_loss, val_AUROC, val_AUPR, (time.time()-start_time)/60))
        print('-'*10)
        if val_AUPR > best_result:
            # Record the new best score; it was never updated before, so the
            # "best" model was simply the most recent epoch with AUPR > 0.
            best_result = val_AUPR
            best_model = copy.deepcopy(model)

    # No epoch improved on 0 (or none completed): evaluate the current model
    # instead of failing on an unbound name.
    if best_model is None:
        best_model = model

    # ---- test pass ----
    _, test_pred, test_gold = _evaluate_split(best_model, tqdm(test_dl), linear_loss_func)
    torch.cuda.empty_cache()
    AUROC = metrics.roc_auc_score(test_gold, test_pred)
    AUPR = metrics.average_precision_score(test_gold, test_pred)
    print('Test Metric:',AUROC, AUPR)
    return [AUROC, AUPR]

In [None]:
# One stratified train/val/test split (3:1:1 per class). This is a single
# split, not a 5-fold loop.
train_index, val_index, test_index = split_datasets(labels)

# Convert the Python lists once; rebuilding np.array(all_exp) for every split
# copied the whole expression set each time.
all_exp_arr = np.array(all_exp)
labels_arr = np.array(labels)

exp_train = all_exp_arr[np.array(train_index)]
y_train = labels_arr[np.array(train_index)]
exp_val = all_exp_arr[np.array(val_index)]
y_val = labels_arr[np.array(val_index)]
exp_test = all_exp_arr[np.array(test_index)]
y_test = labels_arr[np.array(test_index)]

train_data = Feeder(exp_train, y_train, patch_size, 'lincls')
val_data = Feeder(exp_val, y_val, patch_size, 'lincls')
test_data = Feeder(exp_test, y_test, patch_size, 'lincls')
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_data, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=True)

# without finetune (frozen backbone params)
Metric = downstream(model=LinearClassifier(backbone=copy.deepcopy(encoder), finetune=False).to(device), linear_epoch=100)
# with finetune (unfrozen backbone params)
# Metric = downstream(model=LinearClassifier(backbone=copy.deepcopy(encoder), finetune=True).to(device), linear_epoch=50)

In [None]:
# `TestMetric` (a collection of per-run [AUROC, AUPR] results) is not created
# by any cell in this notebook, so referencing it directly raises NameError on
# a fresh kernel. Report its mean only when it actually exists.
fold_metrics = globals().get('TestMetric')
if fold_metrics is None:
    print(Metric)
else:
    print(Metric, np.mean(fold_metrics, axis=0))