In [17]:
import codecs
import argparse
import os

import matplotlib.pyplot as plt
import tqdm
import numpy as np
import glob
from sklearn.model_selection import train_test_split, cross_val_score, validation_curve
import torch
import torch.nn as nn
import torch.nn.utils.rnn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torch.optim.lr_scheduler import ExponentialLR
import codecs
from torchsummary import summary
# from torchshape import tensorshape
from pprint import pprint
# from torch.utils.tensorboard import SummaryWriter

hyper = {
    "randomSeed": 42,
    "nEpochs":70,
    "PATH": "model.pt",
    "batchSize":8,
    "lr":1e-3,
    "clip_grad":0.1,

}

torch.manual_seed(hyper["randomSeed"]);
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_binary_file(file_name, dimension):
    fid_lab = open(file_name, 'rb')
    features = np.fromfile(fid_lab, dtype=np.float32)
    fid_lab.close()
    assert features.size % float(dimension) == 0.0, 'specified dimension %s not compatible with data' % (dimension)
    features = features[:(dimension * (features.size // dimension))]
    features = features.reshape((-1, dimension))
    return features

def load_binary_file_frame(self, file_name, dimension):
    fid_lab = open(file_name, 'rb')
    features = numpy.fromfile(fid_lab, dtype=numpy.float32)
    fid_lab.close()
    assert features.size % float(dimension) == 0.0, 'specified dimension %s not compatible with data' % (dimension)
    frame_number = features.size // dimension
    features = features[:(dimension * frame_number)]
    features = features.reshape((-1, dimension))

    return features, frame_number



dir_phonePath = "/home/rania/Documents/workspace/tools/merlin/output"
dir_PPGpath = "/home/rania/Documents/workspace/IntraSpkVC/data_fr/ppg/FFR0009"
phone_data = []
PPG_data = []

PPG_lengths = []
phone_lengths = []


def load_data(dir_path, x):
    
    file_list = glob.glob(dir_path + '/*.'+x+'')[:48]
    
    for idx, s in enumerate(file_list):
        if x == 'flab':
            data =torch.tensor(load_binary_file(s, 49))
            print(data.shape)
            data + torch.arange(2386)
            print("cccccccccccccccccc", data.shape)

#           print(len(data[:,0]))
            phone_data.append(data)
            phone_lengths.append(len(data))
        else:
            data = np.load(s)
            PPG_data.append(data)
            PPG_lengths.append(len(data))

    return phone_data, phone_lengths, PPG_data, PPG_lengths

phone_data, phone_lengths, _, _ = load_data(dir_phonePath, 'flab')
print("phone_data {}".format(len(phone_data)  )    )


class Dataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.data = pad_sequence(data, batch_first=True, padding_value=0)
    
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx].clone().detach().requires_grad_(True)
        x = torch.unsqueeze(x, 0)
        return x  

# _, _, PPG_data, PPG_lengths = load_data(dir_phonePath, 'npy')

padded_phone = Dataset(phone_data)


for i, item in enumerate(padded_phone):
    print(item.shape)


trainDataX, testDataX,trainDataY, testDataY = train_test_split([data for data in padded_phone ],[data for data in padded_phone ],test_size=0.2,random_state=41,shuffle=False,stratify=None)
# print("size of the training dataset {}".format(len(trainDataX)), trainDataX[2].shape)
# print("size of the training dataset {}".format(len(trainDataX)), trainDataX[3].shape)
valDataX, testDataX, valDataY, testDataY = train_test_split(testDataX,testDataY,test_size=0.5,random_state=41,shuffle=False,stratify=None)
# print("size of the validation dataset {}".format(len(valDataX)), valDataY[0].shape)
# print("size of the validation dataset {}".format(len(valDataX)), valDataY[1].shape)

def masking(data):
    pad = 0
    dataMask = (data == pad).type(torch.int16)
    dataMask = 1 - dataMask
    maskedData = torch.mul(dataMask, data)
    return maskedData

def train(trainLoader, model, criterion, optimizer,epoch =hyper["nEpochs"]):
    model.train()
    totalLoss, batchNum = 0, 0
    for i, trainData in enumerate(trainLoader):
        trainData = Variable(trainData).to(device)
        # ===================forward=====================
        output = model(trainData)
        loss = criterion(masking(output), trainData)
        # print(loss)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        totalLoss += loss.item()
        batchNum += 1
    trainTotalCount = totalLoss / batchNum
    # ===================log========================
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, epoch, trainTotalCount))
    return trainTotalCount

def valid(valLoader, model, criterion, epoch=hyper["nEpochs"]):
    model.eval()
    with torch.no_grad():
        totalLoss, batchNum = 0, 0
        for i, valData in enumerate(valLoader):
            valData = Variable(valData).to(device)
            # ===================forward=====================
            output = model(valData)
            loss = criterion(masking(output), valData)
            totalLoss += loss.item()
            batchNum += 1
        ValTotalCount = totalLoss / batchNum
    print('epoch [{}/{}], Validation loss:{:.4f}'.format(epoch + 1, epoch, ValTotalCount))
    return ValTotalCount


class SaveBestModel:
    """
    Class to save the best model while training. If the current epoch's
    validation loss is less than the previous least less, then save the
    model state.
    """

    def __init__(
            self, bestValidLoss = float('inf')
    ):
        self.bestValidLoss = bestValidLoss

    def __call__(
            self, currentValidLoss,
            epoch, model, optimizer, criterion
    ):
        if currentValidLoss < self.bestValidLoss:
            self.bestValidLoss = currentValidLoss
            print(f"\nBest validation loss: {self.bestValidLoss}")
            print(f"\nSaving best model for epoch: {epoch + 1}\n")
            torch.save({
                'epoch': hyper["nEpochs"] + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': criterion,
            }, 'outputs/best_model.pth')
saveBestModel = SaveBestModel()

def save_model(model, optimizer, criterion, epochs=hyper["nEpochs"]):
    """
    Function to save the trained model
    """
    print(f"Saving final model...")
    torch.save({
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': criterion,
    }, 'outputs/final_model.pth')


class Autoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),  # b, 16, 1768, 14
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # b,  32, 884, 7
            nn.ReLU(),
            nn.Conv2d(32, 64, 7),  # b, 64, 878, 1
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 7),  # b, 32, 884, 7
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1,  output_padding=1),  # b, 16, 1768, 14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),  # b, 1, 3536, 28
            nn.Sigmoid()
        )
    def get_latent(self, x):
        x = self.encoder(x)
        return x

    def forward(self, x):
        x = self.encoder(x)
        y = self.decoder(x)
        return x, y
    
def get_data_loaders(train_set, val_set):
    
    #load data loader
    train_loader = DataLoader(
        train_set, 
        batch_size=hyper["batchSize"], 
        drop_last=False, 
        shuffle=True,
    )
    test_loader = DataLoader(
        val_set, 
        batch_size=hyper["batchSize"], 
        drop_last=False, 
        shuffle=False,
    )
    return train_loader, test_loader

trainLoader, valLoader = get_data_loaders(
        trainDataX, 
        valDataX,

    )

def main():

    # for i, data in enumerate(valLoader):
    #     print('val ', data.shape)
    model = Autoencoder().to(device)
    model = summary(model, (1, 3536, 28))
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=hyper["lr"],
                                 weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)

    for epoch in range(hyper["nEpochs"]):
        print(f"[INFO]: Epoch {epoch} of {hyper['nEpochs']}")
        # # ===================training========================
        trainEpochLoss = train(trainLoader, model, criterion, optimizer, epoch)
        # # ===================validation========================
        validEpochLoss = valid(valLoader, model, criterion, epoch)
        # # ===================checkpoints========================
        trainLoss, validLoss = [], []
        # start the training
        trainLoss.append(trainEpochLoss)
        validLoss.append(validEpochLoss)

        print(f"Training loss: {trainEpochLoss:.3f}")
        print(f"Validation loss: {validEpochLoss:.3f}")
        # save the best model till now if we have the least loss in the current epoch
        saveBestModel(
            validEpochLoss, epoch, model, optimizer, criterion
        )
        # save the trained model weights for a final time
        save_model(model, optimizer, criterion, epoch)
        print('TRAINING COMPLETE')


if __name__ == '__main__':
    main()



torch.Size([416, 49])
cccccccccccccccccc torch.Size([416, 49])
torch.Size([466, 49])
cccccccccccccccccc torch.Size([466, 49])
torch.Size([565, 49])
cccccccccccccccccc torch.Size([565, 49])
torch.Size([173, 49])
cccccccccccccccccc torch.Size([173, 49])
torch.Size([815, 49])
cccccccccccccccccc torch.Size([815, 49])
torch.Size([1273, 49])
cccccccccccccccccc torch.Size([1273, 49])
torch.Size([239, 49])
cccccccccccccccccc torch.Size([239, 49])
torch.Size([639, 49])
cccccccccccccccccc torch.Size([639, 49])
torch.Size([1341, 49])
cccccccccccccccccc torch.Size([1341, 49])
torch.Size([1642, 49])
cccccccccccccccccc torch.Size([1642, 49])
torch.Size([174, 49])
cccccccccccccccccc torch.Size([174, 49])
torch.Size([622, 49])
cccccccccccccccccc torch.Size([622, 49])
torch.Size([234, 49])
cccccccccccccccccc torch.Size([234, 49])
torch.Size([371, 49])
cccccccccccccccccc torch.Size([371, 49])
torch.Size([1101, 49])
cccccccccccccccccc torch.Size([1101, 49])
torch.Size([344, 49])
cccccccccccccccccc torch.

AttributeError: 'NoneType' object has no attribute 'parameters'

In [None]:
hyper = {
    "nEpochs":100,
    "dimRNA":dim_rna,
    "dimATAC":dim_atac,
    "layer_sizes":[1024, 512, 256],
    "nz":128,
    "batchSize":512,
    "lr":1e-3,
    "add_hinge":True,
    "lamb_hinge":10,
    "lamb_match":1,
    "lamb_nn":1.5,
    "lamb_kl":1e-9,
    "lamb_anc":1e-9,
    "clip_grad":0.1,
    "checkpoint_path": './checkpoint/vae_hinge.pt',
}

In [None]:
class StructureHingeLoss(nn.Module):
    def __init__(self, margin, max_val, lamb_match, lamb_nn, device):
        super(StructureHingeLoss, self).__init__()
        self.margin = margin
        self.max_val = max_val
        self.lamb_match = lamb_match
        self.lamb_nn = lamb_nn
        self.device = device
        
    def forward(self, rna_outputs, atac_outputs, nn_indices):
        #rna_outputs: n_batch x n_latent
        #atac_outputs: n_batch x n_latent
        assert rna_outputs.shape[0] == atac_outputs.shape[0]
        assert rna_outputs.shape[1] == atac_outputs.shape[1]
        n_batch = rna_outputs.shape[0]
        
        #calculated pairwise L2 distance
        #dist_rna_atac[i][j]: the L2 distance between RNA embedding i
        #and ATAC embedding j (n_batch x n_batch)
        #constraint for ensuring every rna embedding is close to matched atac embedding
        dist_rna_atac = torch.cdist(rna_outputs, atac_outputs, p=2)
        match_labels = torch.eye(n_batch).to(self.device)
        match_mask = match_labels > 0
        pos_match_dist = torch.masked_select(dist_rna_atac, match_mask).view(n_batch, 1)
        neg_match_dist = torch.masked_select(dist_rna_atac, ~match_mask).view(n_batch, -1)
        
        loss_match_rna = torch.clamp(self.margin + pos_match_dist - neg_match_dist, 0, self.max_val)
        loss_match_rna = loss_match_rna.mean()
        #print(f"loss_match_rna: {loss_match_rna}")
        
        #constraint for ensuring every atac embedding is close to matched rna embedding
        dist_atac_rna = dist_rna_atac.t()
        pos_match_dist = torch.masked_select(dist_atac_rna, match_mask).view(n_batch, 1)
        neg_match_dist = torch.masked_select(dist_atac_rna, ~match_mask).view(n_batch, -1)
        
        loss_match_atac = torch.clamp(self.margin + pos_match_dist - neg_match_dist, 0, self.max_val)
        loss_match_atac = loss_match_rna.mean()
        #print(f"loss_match_atac: {loss_match_atac}")
        
        #constraint for ensuring that every RNA embedding is close to 
        #the neighboring RNA embeddings.
        nn_masked = torch.zeros(n_batch, n_batch).to(self.device)
        nn_masked.scatter_(1, nn_indices, 1.)
        nn_masked = nn_masked > 0
        
        dist_rna_rna = torch.cdist(rna_outputs, rna_outputs, p=2)
        
        #pos_rna_nn_dist: n_batch x n_neighbor
        pos_rna_nn_dist = torch.masked_select(dist_rna_rna, nn_masked).view(n_batch, -1)
        neg_rna_nn_dist = torch.masked_select(dist_rna_rna, ~nn_masked).view(n_batch, -1)
        rna_nn_loss = torch.clamp(self.margin + pos_rna_nn_dist[...,None] - neg_rna_nn_dist[..., None, :], 0, self.max_val)
        rna_nn_loss = rna_nn_loss.mean()
        #print(f"rna_nn_loss: {rna_nn_loss}")
        
        #constraint for ensuring that every ATAC embedding is close to 
        #the neighboring ATAC embeddings.
        dist_atac_atac = torch.cdist(atac_outputs, atac_outputs, p=2)
        #pos_rna_nn_dist: n_batch x n_neighbor
        pos_atac_nn_dist = torch.masked_select(dist_atac_atac, nn_masked).view(n_batch, -1)
        neg_atac_nn_dist = torch.masked_select(dist_atac_atac, ~nn_masked).view(n_batch, -1)
        atac_nn_loss = torch.clamp(self.margin + pos_atac_nn_dist[...,None] - neg_atac_nn_dist[..., None, :], 0, self.max_val)
        atac_nn_loss = atac_nn_loss.mean()
        #print(f"atac_nn_loss: {atac_nn_loss}")
        
        loss = (self.lamb_match * loss_match_rna 
                + self.lamb_match * loss_match_atac
                + self.lamb_nn * rna_nn_loss 
                + self.lamb_nn * atac_nn_loss)
        return loss

In [None]:
#set up loss function
def basic_loss(recon_x, x, mu, logvar, lamb1):
    MSE = nn.MSELoss()
    lloss = MSE(recon_x, x)
    #KL divergence
    #KL_loss = -0.5*torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    #lloss = lloss + lamb1*KL_loss
    return lloss

#anchor loss for minimizing distance between paired observation
def anchor_loss(embed_rna, embed_atac):
    L1 = nn.L2Loss()
    anc_loss = L2(embed_rna, embed_atac)
    return anc_loss

def hinge_loss(
    margin, 
    max_val, 
    lamb_match,
    lamb_nn, 
    embed_rna, 
    embed_atac, 
    nn_indices,
):
    Hinge_Loss = StructureHingeLoss(margin, max_val, lamb_match, lamb_nn)
    loss = Hinge_Loss(embed_rna, embed_atac, nn_indices)
    return loss

In [None]:
trainDataX, valDataX,trainDataY, valDataY = train_test_split([data for data in padded_phone ],[data for data in padded_phone ],test_size=0.2,random_state=41,shuffle=False,stratify=None)
# print("size of the training dataset {}".format(len(trainDataX)), trainDataX[2].shape)

train_loader, test_loader = get_data_loaders(
        trainDataX, 
        valDataX,
        trainDataX,
        valDataX,
    )

# print(train_loader, test_loader)
# for i, item in enumerate(train_loader):
#     print(item.shape)

#load checkpoint
checkpoint = None
if path.exists(hyper["checkpoint_path"]):
    checkpoint = torch.load(hyper["checkpoint_path"])


#load basic models
netPPG = Autoencoder()
# netATAC = FC_VAE(n_input=hyper["dimATAC"], nz=hyper["nz"], layer_sizes=hyper["layer_sizes"])
if checkpoint != None:
    netRNA.load_state_dict(checkpoint["PPG_state_dict"])
#     netATAC.load_state_dict(checkpoint["net_atac_state_dict"])

if torch.cuda.is_available():
    print("using GPU")
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
netPPG.to(device)
# netATAC.to(device)

#setup optimizers for two nets
opt_netPPG = optim.Adam(list(netPPG.parameters()), lr=hyper["lr"])
# opt_netATAC = optim.Adam(list(netATAC.parameters()), lr=hyper["lr"])
scheduler_netPPG = optim.lr_scheduler.ReduceLROnPlateau(
    opt_netPPG,
    patience=10,
    threshold=0.01,
    threshold_mode="abs",
    min_lr=1e-5,
)
# scheduler_netATAC = optim.lr_scheduler.ReduceLROnPlateau(
#     opt_netATAC,
#     patience=10,
#     threshold=0.01,
#     threshold_mode="abs",
#     min_lr=1e-5,
# )

best_knn_auc = 0
if checkpoint != None:
    best_knn_auc = checkpoint["dev_acc"]

#training
for epoch in range(hyper["nEpochs"]):
    train_losses = []
    #train for epochs
    for idx, (data) in enumerate(train_loader):
        
        data = Variable(data).to(device)

        opt_netPPG.zero_grad()
#         opt_netRNA.zero_grad()
        recon_PPG = netPPG(data)
#         recon_atac, z_atac, mu_atac, logvar_atac = netATAC(atac_inputs_filtered)
        ppg_loss = basic_loss(recon_PPG, data, 0, 0,0)
#         atac_loss = basic_loss(recon_atac, aftac_inputs_filtered, mu_atac, logvar_atac, lamb1=hyper["lamb_kl"])

        if hyper["add_hinge"]:
            hinge_loss = StructureHingeLoss(
                margin=0.3, 
                max_val=1e6, 
                lamb_match=hyper["lamb_match"], 
                lamb_nn=hyper["lamb_nn"],
                device=device,
            )
            h_loss = hinge_loss(z_rna, z_atac, nn_indices)
        '''if epoch % 5 == 0:
            print(f"rna_loss: {rna_loss}")
            print(f"atac_loss:{atac_loss}")
            print(f"anc_loss: {anc_loss}")
            print(f"hinge loss: {h_loss}")'''

        #loss functions for each modalities
        train_loss = rna_loss + atac_loss + hyper["lamb_hinge"] * h_loss
        #train_loss = rna_loss + atac_loss
        #train_loss = rna_loss + atac_loss + hyper["lamb_anc"] * anc_loss
        #rain_loss = rna_loss + atac_loss + hyper["lamb_anc"] * anc_loss + h_loss
        train_loss.backward()
        nn.utils.clip_grad_norm_(netRNA.parameters(), max_norm=hyper["clip_grad"])
        nn.utils.clip_grad_norm_(netATAC.parameters(), max_norm=hyper["clip_grad"])
        opt_netRNA.step()
        opt_netATAC.step()
        train_losses.append(train_loss.item())
    avg_train_loss = np.mean(train_losses)
    if epoch % 5 == 0:
        print("Epoch: " + str(epoch) + ", train loss: " + str(avg_train_loss))

        
        