In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
import torch.optim as optim
import torch.backends.cudnn as cudnn

from torch.utils.data import TensorDataset, DataLoader, Dataset
from torch.autograd import Variable
%matplotlib inline

In [2]:
import numpy as np
import pandas as pd

In [3]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
def partial_correlation_score_torch_faster(y_true, y_pred):
    """Compute the Pearson correlation between each row of y_true and y_pred.

    The correlation is taken along dim 1: every row of ``y_true`` is compared
    with the same row of ``y_pred``. Compatible with backpropagation.

    A row where either input has zero variance (for example a constant
    prediction, or inputs with a single column) has an undefined correlation.
    Such rows are reported as 0 with a zero gradient instead of NaN, so one
    degenerate row cannot turn the whole loss and its gradients into NaN.

    Args:
        y_true: reference tensor, correlated along dim 1 (presumably
            (batch, n_targets) in the training code -- TODO confirm).
        y_pred: prediction tensor with the same shape as ``y_true``.

    Returns:
        Tensor with dim 1 reduced away: one correlation coefficient per row.
    """
    y_true_centered = y_true - torch.mean(y_true, dim=1, keepdim=True)
    y_pred_centered = y_pred - torch.mean(y_pred, dim=1, keepdim=True)
    # The 1/(n-1) factors of the covariance and of both variances cancel in
    # the ratio, so raw sums are used. This also removes the division by zero
    # that the (n-1) factor caused when there is only one column.
    cov_tp = torch.sum(y_true_centered * y_pred_centered, dim=1)
    ss_t = torch.sum(y_true_centered ** 2, dim=1)
    ss_p = torch.sum(y_pred_centered ** 2, dim=1)
    denom_sq = ss_t * ss_p
    degenerate = denom_sq == 0
    # Feed 1 (not 0) to sqrt for degenerate rows. sqrt's derivative at 0 is
    # infinite, and torch.where multiplies the unselected branch's gradient by
    # zero, which would still produce NaN (0 * inf).
    safe_denom = torch.sqrt(torch.where(degenerate, torch.ones_like(denom_sq), denom_sq))
    return torch.where(degenerate, torch.zeros_like(cov_tp), cov_tp / safe_denom)

def correl_loss(pred, tgt):
    """Loss for directly optimizing the correlation.

    Returns the negated mean row-wise Pearson correlation between ``pred`` and
    ``tgt``; minimizing it maximizes the correlation (best possible value -1).
    """
    return -torch.mean(partial_correlation_score_torch_faster(tgt, pred))

In [5]:
class Net(nn.Module):
    """Fully connected ReLU network mapping an input feature vector to targets.

    The defaults reproduce the original architecture
    228942 -> 512 -> 256 -> 128 -> 23418. The layers are stored in
    ``self.fcs`` in the same order as before (Linear, ReLU, ..., Linear).
    The ``state_dict`` keys (``fcs.0.*``, ``fcs.2.*``, ...) are therefore
    unchanged and previously saved checkpoints still load.

    Args:
        in_features: size of the last dimension of the input.
        out_features: number of predicted values per sample.
        hidden_sizes: widths of the hidden Linear layers, each followed by an
            in-place ReLU.
    """

    def __init__(self, in_features=228942, out_features=23418, hidden_sizes=(512, 256, 128)):
        super().__init__()
        layers = []
        prev_width = in_features
        for width in hidden_sizes:
            layers.append(nn.Linear(prev_width, width))
            layers.append(nn.ReLU(True))
            prev_width = width
        # Final projection has no activation: outputs are unbounded regressions.
        layers.append(nn.Linear(prev_width, out_features))
        self.fcs = nn.Sequential(*layers)

    def forward(self, input_data):
        """Apply the MLP; the last dimension of ``input_data`` must be ``in_features``."""
        feature = self.fcs(input_data)
        return feature

In [6]:
# NOTE(review): this cell re-defines functions of the same names from an
# earlier cell. Because it runs later, this copy is the one actually used
# during training. Keep the copies in sync, or delete one of them.
def partial_correlation_score_torch_faster(y_true, y_pred):
    """Compute the Pearson correlation between each row of y_true and y_pred.

    The correlation is taken along dim 1: every row of ``y_true`` is compared
    with the same row of ``y_pred``. Compatible with backpropagation.

    A row where either input has zero variance (for example a constant
    prediction, or inputs with a single column) has an undefined correlation.
    Such rows are reported as 0 with a zero gradient instead of NaN, so one
    degenerate row cannot turn the whole loss and its gradients into NaN.

    Args:
        y_true: reference tensor, correlated along dim 1 (presumably
            (batch, n_targets) in the training code -- TODO confirm).
        y_pred: prediction tensor with the same shape as ``y_true``.

    Returns:
        Tensor with dim 1 reduced away: one correlation coefficient per row.
    """
    y_true_centered = y_true - torch.mean(y_true, dim=1, keepdim=True)
    y_pred_centered = y_pred - torch.mean(y_pred, dim=1, keepdim=True)
    # The 1/(n-1) factors of the covariance and of both variances cancel in
    # the ratio, so raw sums are used. This also removes the division by zero
    # that the (n-1) factor caused when there is only one column.
    cov_tp = torch.sum(y_true_centered * y_pred_centered, dim=1)
    ss_t = torch.sum(y_true_centered ** 2, dim=1)
    ss_p = torch.sum(y_pred_centered ** 2, dim=1)
    denom_sq = ss_t * ss_p
    degenerate = denom_sq == 0
    # Feed 1 (not 0) to sqrt for degenerate rows. sqrt's derivative at 0 is
    # infinite, and torch.where multiplies the unselected branch's gradient by
    # zero, which would still produce NaN (0 * inf).
    safe_denom = torch.sqrt(torch.where(degenerate, torch.ones_like(denom_sq), denom_sq))
    return torch.where(degenerate, torch.zeros_like(cov_tp), cov_tp / safe_denom)

def correl_loss(pred, tgt):
    """Loss for directly optimizing the correlation.

    Returns the negated mean row-wise Pearson correlation between ``pred`` and
    ``tgt``; minimizing it maximizes the correlation (best possible value -1).
    """
    return -torch.mean(partial_correlation_score_torch_faster(tgt, pred))

In [7]:
def Average(lst):
    """Return the arithmetic mean of the numbers in ``lst``.

    ``lst`` must support ``len``; an empty sequence raises ZeroDivisionError.
    """
    total = sum(lst)
    count = len(lst)
    return total / count

def train_one_epoch(loader_size, batch_size, net, criterion, optimizer, scheduler, use_cuda,
                    n_rows=50000,
                    inputs_path='./train_multi_inputs.h5',
                    targets_path='./train_multi_targets.h5'):
    """Train ``net`` for one pass over the first ``n_rows`` rows of the HDF5 files.

    Rows are read from disk in chunks of ``loader_size``. Only whole chunks
    are used, so up to ``loader_size - 1`` trailing rows are skipped. Each
    chunk is moved to the device holding ``net`` and iterated in shuffled
    mini-batches of ``batch_size``; an incomplete last mini-batch is dropped.

    Args:
        loader_size: number of rows read from disk per chunk.
        batch_size: mini-batch size within a chunk.
        net: model to train; it is switched to training mode.
        criterion: loss callable, invoked as ``criterion(prediction, target)``.
        optimizer: optimizer over ``net``'s parameters.
        scheduler: learning-rate scheduler, stepped after chunks 0, 3, 6, ...
        use_cuda: unused; kept for call compatibility (the device is taken
            from ``net``'s parameters).
        n_rows: number of leading rows of the files to train on.
        inputs_path: HDF5 file with the model inputs.
        targets_path: HDF5 file with the matching targets.

    Returns:
        List with one entry per chunk: the sum of that chunk's mini-batch losses.
    """
    net.train()
    # Use wherever the model lives instead of relying on a notebook global.
    net_device = next(net.parameters()).device
    loss_lst = []
    for i in range(n_rows // loader_size):
        start, stop = loader_size * i, loader_size * (i + 1)
        x_chunk = torch.as_tensor(pd.read_hdf(inputs_path, start=start, stop=stop).to_numpy(),
                                  dtype=torch.float32, device=net_device)
        y_chunk = torch.as_tensor(pd.read_hdf(targets_path, start=start, stop=stop).to_numpy(),
                                  dtype=torch.float32, device=net_device)
        dataloader = DataLoader(TensorDataset(x_chunk, y_chunk), shuffle=True,
                                batch_size=batch_size, drop_last=True)

        loss_sum = 0.0
        for inputs, target in dataloader:
            optimizer.zero_grad()
            pred_target = net(inputs)
            # (prediction, target) order, matching correl_loss(pred, tgt) and
            # nn.MSELoss(input, target).
            loss = criterion(pred_target, target)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()

        loss_lst.append(loss_sum)
        print("#%d load, loss sum: %.4f" % (i, loss_sum))
        if i % 3 == 0:
            scheduler.step()

    return loss_lst

def val_one_epoch(loader_size, batch_size, net, use_cuda,
                  start_row=75000, stop_row=100000,
                  inputs_path='./train_multi_inputs.h5',
                  targets_path='./train_multi_targets.h5'):
    """Return the mean row-wise Pearson correlation of ``net`` on a held-out slice.

    Chunks of ``loader_size`` rows are read for chunk indices
    ``start_row // loader_size`` up to (excluding) ``stop_row // loader_size``.
    The scored rows are therefore aligned down to multiples of ``loader_size``
    and may begin slightly before ``start_row``. Every row in those chunks is
    scored: no shuffling and no dropped last batch. The model is evaluated in
    eval mode without gradients, and its previous train/eval mode is restored
    afterwards.

    Args:
        loader_size: number of rows read from disk per chunk.
        batch_size: batch size used for inference.
        net: model to evaluate.
        use_cuda: unused; kept for call compatibility (the device is taken
            from ``net``'s parameters).
        start_row, stop_row: nominal held-out row range.
        inputs_path: HDF5 file with the model inputs.
        targets_path: HDF5 file with the matching targets.

    Returns:
        Mean over all scored rows of the per-row correlation (Python float).

    Raises:
        ValueError: if the range contains no complete chunk.
    """
    net_device = next(net.parameters()).device
    was_training = net.training
    net.eval()
    partial_correlation_scores = []
    try:
        with torch.no_grad():
            for i in range(start_row // loader_size, stop_row // loader_size):
                start, stop = loader_size * i, loader_size * (i + 1)
                x_chunk = torch.as_tensor(pd.read_hdf(inputs_path, start=start, stop=stop).to_numpy(),
                                          dtype=torch.float32, device=net_device)
                y_chunk = torch.as_tensor(pd.read_hdf(targets_path, start=start, stop=stop).to_numpy(),
                                          dtype=torch.float32, device=net_device)
                dataloader = DataLoader(TensorDataset(x_chunk, y_chunk), shuffle=False,
                                        batch_size=batch_size, drop_last=False)
                for inputs, target in dataloader:
                    pred_target = net(inputs)
                    partial_correlation_scores.append(
                        partial_correlation_score_torch_faster(target, pred_target))
    finally:
        # Leave the model in whatever mode the caller had it in.
        net.train(was_training)

    if not partial_correlation_scores:
        raise ValueError("validation range contains no complete chunk of loader_size rows")
    return torch.cat(partial_correlation_scores).mean().item()

In [8]:
# Seed before building the network: Net() initialises its weights with the
# CPU RNG. torch.manual_seed seeds the CPU and every CUDA device, so this also
# covers what torch.cuda.manual_seed_all(42) did, and it makes the DataLoader
# shuffling reproducible too.
torch.manual_seed(42)

net = Net()
use_cuda = torch.cuda.is_available()
if use_cuda:
    cudnn.benchmark = True
    net = net.cuda()

# Train by directly maximising the row-wise Pearson correlation (see
# correl_loss); no MSE criterion is used in this notebook.
criterion = correl_loss
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.98)

In [9]:
# Training schedule: rows are streamed from HDF5 in chunks of LOADER_SIZE rows
# and processed in mini-batches of BATCH_SIZE.
N_EPOCHS = 10
LOADER_SIZE = 2048
BATCH_SIZE = 256

epoch_loss_lst = []
for epoch in range(N_EPOCHS):
    # Pass the detected flag rather than a hard-coded True, so CPU-only runs
    # are described consistently.
    train_loss_lst = train_one_epoch(LOADER_SIZE, BATCH_SIZE, net, criterion, optimizer, scheduler, use_cuda)

    # Mean over chunks of each chunk's summed mini-batch loss.
    epoch_loss = Average(train_loss_lst)
    epoch_loss_lst.append(epoch_loss)
    print("Train: ep: %d, loss: %.4f" % (epoch, epoch_loss))

    val_score = val_one_epoch(LOADER_SIZE, BATCH_SIZE, net, use_cuda)
    print("Val: ep: %d, score: %.4f" % (epoch, val_score))


#0 load, loss sum: -1.8684
#1 load, loss sum: -4.3051
#2 load, loss sum: -5.0066
#3 load, loss sum: -5.1047
#4 load, loss sum: -5.0796
#5 load, loss sum: -5.1575
#6 load, loss sum: -5.1784
#7 load, loss sum: -5.2303
#8 load, loss sum: -5.3288
#9 load, loss sum: -5.3516
#10 load, loss sum: -5.3715
#11 load, loss sum: -5.3689
#12 load, loss sum: -5.3662
#13 load, loss sum: -4.6091
#14 load, loss sum: -4.6498
#15 load, loss sum: -4.6947
#16 load, loss sum: -4.8574
#17 load, loss sum: -5.1407
#18 load, loss sum: -5.2734
#19 load, loss sum: -5.3299
#20 load, loss sum: -5.3161
#21 load, loss sum: -5.1964
#22 load, loss sum: -5.2335
#23 load, loss sum: -5.2452
Train: ep: 0, loss: -4.9693
Val: ep: 0, score: 0.6527
#0 load, loss sum: -5.3317
#1 load, loss sum: -5.3788
#2 load, loss sum: -5.3975
#3 load, loss sum: -5.2900
#4 load, loss sum: -5.1605
#5 load, loss sum: -5.2005
#6 load, loss sum: -5.2059
#7 load, loss sum: -5.2450
#8 load, loss sum: -5.3290
#9 load, loss sum: -5.3495
#10 load, loss

In [None]:
# Persist only the learned parameters (state_dict), not the module object.
# NOTE(review): despite "MSE" in the file name, the model is trained with
# criterion = correl_loss. The path is left unchanged in case anything
# downstream loads it.
torch.save(obj=net.state_dict(), f="./multi_MSE.pt")