In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import random_split, Dataset, DataLoader, TensorDataset
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
import numpy as np
import copy
import random

print(torch.__version__)

2.2.2


In [5]:
# find correct device
if torch.backends.mps.is_available():
    print("using mps")
    device = torch.device("mps")
elif torch.cuda.is_available():
    print("using cuda")
    device = torch.device("cuda")
else:
    print("using cpu")
    device = torch.device("cpu")

using mps


In [6]:
bd_clean_file = "gt_clean_Bay5_101223.npz"
bd_model_file = "bd_clean_Bay5model.pt"
n_rounds = 2

In [7]:
data_dict = np.load(bd_clean_file)

In [8]:
ori32 = torch.tensor(data_dict["ori64"].transpose((2,1,0))[:,:,:].astype(np.float32))
sim32 = torch.tensor(data_dict["sim64"].transpose((2,1,0))[:,:,:].astype(np.float32))
ori32means = torch.tensor(data_dict["ori64means"].transpose((2,1,0))[:,:,:].astype(np.float32))
sim32means = torch.tensor(data_dict["sim64means"].transpose((2,1,0))[:,:,:].astype(np.float32))
ori32sigmas = torch.tensor(data_dict["ori64sigmas"].transpose((2,1,0))[:,:,:].astype(np.float32))
sim32sigmas = torch.tensor(data_dict["sim64sigmas"].transpose((2,1,0))[:,:,:].astype(np.float32))

In [9]:
# create dataset from input and output
dataset = TensorDataset(ori32,sim32,ori32means,sim32means,ori32sigmas,sim32sigmas)

In [16]:
#Split dataset 80-20
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
batch_size = 8
print(train_size, test_size)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=test_size)

1948 488


In [17]:
print("train set correlation: ",torch.corrcoef(torch.stack((torch.flatten(train_dataset[:][0]),torch.flatten(train_dataset[:][1])),dim=0))[1,0])

train set correlation:  tensor(0.3652)


In [18]:
print("test set correlation: ",torch.corrcoef(torch.stack((torch.flatten(test_dataset[:][0]),torch.flatten(test_dataset[:][1])),dim=0))[1,0])

test set correlation:  tensor(0.3606)


In [19]:
# torch.nn.Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv1d(1, 18, 9, padding="same")
        self.conv2 = nn.Conv1d(18, 18, 9, padding="same")
        self.conv3 = nn.Conv1d(18, 1, 1, padding="same")
        self.bn = nn.BatchNorm1d(18)
        self.dropout1 = nn.Dropout(0.10)
        self.dropout2 = nn.Dropout(0.5)

    def forward(self, x):
        x = self.conv1(x)
        x = F.sigmoid(x)
        
        x = self.conv2(x)
        x = self.bn(x)
        x = F.sigmoid(x)

        x = self.dropout1(x)

        x = self.conv2(x)
        x = self.bn(x)
        x = F.sigmoid(x)

        x = self.dropout1(x)

        x = self.conv2(x)
        x = self.bn(x)
        x = F.sigmoid(x)
        
        x = self.dropout1(x)

        x = self.conv2(x)
        x = self.bn(x)
        x = F.sigmoid(x)

        x = self.dropout1(x)

        x = self.conv2(x)
        x = self.bn(x)
        x = F.sigmoid(x)
        
        x = self.dropout1(x)

        x = self.conv2(x)
        x = self.bn(x)
        x = F.sigmoid(x)

        x = self.dropout1(x)
        
        output = self.conv3(x)
        
        return output

In [20]:
model = Net()

In [21]:
def custom_loss(y_true,y_pred):
    SS_res =  torch.sum(torch.square(y_true - y_pred)) 
    SS_tot = torch.sum(torch.square(y_true - torch.mean(y_true))) 
    loss2 =  (1.0 - SS_res/(SS_tot + torch.finfo(torch.float32).eps) )
    return -loss2

In [22]:
def corr_loss(y_true,y_pred):
    c = torch.corrcoef(torch.stack((torch.flatten(y_true),torch.flatten(y_pred)),dim=0))[1,0]
    return -c/(1-c)

In [23]:
def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target,_,_,_,_) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = corr_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx == int(train_size/batch_size):
            print('Train Epoch: {} \t\tCorr: {:.6f}'.format(
                epoch, loss.item()/(loss.item()-1)))

In [24]:
def test(model, device, test_loader):
    model.eval()
    corr = 0
    with torch.no_grad():
        for data, target,_,_,_,_ in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            corr = corr_loss(output,target)
    c = corr.item()/(corr.item()-1)
    return c

In [25]:
model = Net().to(device)

In [26]:
best_c = [0.0]
best_lr = []
best_eps = []
best_model = copy.deepcopy(model.state_dict())

In [50]:
# do this over and over
for i in range(n_rounds):
    for lr in [0.04,0.05,0.06]:
        for eps in [0.0,1e-8,1e-7]:
            model = Net().to(device)
            print("running with lr = {:.6f} and eps = {:.6f}".format(lr,eps))
            optimizer = optim.Adam(model.parameters(),lr = lr, eps=eps)
            scheduler = StepLR(optimizer, step_size=20, gamma=0.9)
            train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
            for epoch in range(1, 1 + 50):
                train(model, device, train_dataloader, optimizer, epoch)
                c = test(model, device, test_dataloader)
                if c > best_c[-1]:
                    print("***************** *************** Found better test model: {:.6f} ".format(c))
                    best_c.append(c)
                    best_lr.append(lr)
                    best_eps.append(eps)
                    best_model = copy.deepcopy(model.state_dict())
                scheduler.step()

running with lr = 0.040000 and eps = 0.000000
Train Epoch: 1 		Corr: 0.315691
Train Epoch: 2 		Corr: 0.168623
Train Epoch: 3 		Corr: -0.077547
Train Epoch: 4 		Corr: 0.481018
Train Epoch: 5 		Corr: 0.605436
Train Epoch: 6 		Corr: 0.462988
Train Epoch: 7 		Corr: 0.478323
Train Epoch: 8 		Corr: 0.089016
Train Epoch: 9 		Corr: -0.044931
Train Epoch: 10 		Corr: -0.025443
Train Epoch: 11 		Corr: 0.424390
Train Epoch: 12 		Corr: 0.579634
Train Epoch: 13 		Corr: 0.732461
Train Epoch: 14 		Corr: 0.616303
Train Epoch: 15 		Corr: 0.592260
Train Epoch: 16 		Corr: -0.020441
Train Epoch: 17 		Corr: 0.695645
Train Epoch: 18 		Corr: 0.342388
Train Epoch: 19 		Corr: 0.727683
Train Epoch: 20 		Corr: 0.731123
Train Epoch: 21 		Corr: 0.139955
Train Epoch: 22 		Corr: 0.403598
Train Epoch: 23 		Corr: 0.092625
Train Epoch: 24 		Corr: 0.164824
Train Epoch: 25 		Corr: 0.459185
Train Epoch: 26 		Corr: 0.280189
Train Epoch: 27 		Corr: 0.683152
Train Epoch: 28 		Corr: 0.554143
Train Epoch: 29 		Corr: 0.705616
Tr

In [51]:
best_c[-1]

0.41876758905218486

In [52]:
best_lr

[0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.04, 0.05, 0.05]

In [53]:
best_eps

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

In [31]:
torch.save(best_model, bd_model_file)

In [34]:
# load model
model_saved = Net()
model_saved.load_state_dict(best_model)
model_saved.eval()

Net(
  (conv1): Conv1d(1, 18, kernel_size=(9,), stride=(1,), padding=same)
  (conv2): Conv1d(18, 18, kernel_size=(9,), stride=(1,), padding=same)
  (conv3): Conv1d(18, 1, kernel_size=(1,), stride=(1,), padding=same)
  (bn): BatchNorm1d(18, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
)

In [35]:
test_ori, test_sim, test_ori_mean, test_sim_mean, test_ori_sigmas, test_sim_sigmas = next(iter(test_dataloader))

In [36]:
test_cleaned = model_saved(test_ori)
test_cleaned_denorm = test_cleaned * test_ori_sigmas + test_ori_mean
test_ori_denorm = test_ori * test_ori_sigmas + test_ori_mean
test_sim_denorm = test_sim * test_sim_sigmas + test_sim_mean

In [37]:
print("Corr test: ",torch.corrcoef(torch.stack((torch.flatten(test_ori),torch.flatten(test_sim)),dim=0))[1,0])

Corr test:  tensor(0.3606)


In [44]:
print("Corr test denormed: ",torch.corrcoef(torch.stack((torch.flatten(test_ori_denorm),torch.flatten(test_sim_denorm)),dim=0))[1,0])

Corr test denormed:  tensor(0.4299)


In [45]:
print("Corr test cleaned: ",torch.corrcoef(torch.stack((torch.flatten(test_cleaned),torch.flatten(test_sim)),dim=0))[1,0])

Corr test cleaned:  tensor(0.4188, grad_fn=<SelectBackward0>)


In [46]:
print("Corr test cleaned denormed: ",torch.corrcoef(torch.stack((torch.flatten(test_cleaned_denorm),torch.flatten(test_sim_denorm)),dim=0))[1,0])

Corr test cleaned denormed:  tensor(0.4945, grad_fn=<SelectBackward0>)


In [47]:
print("Corr dataset: ",torch.corrcoef(torch.stack((torch.flatten(dataset[:][0]),torch.flatten(dataset[:][1])),dim=0))[1,0])

tensor(0.3642)

In [49]:
print("Corr dataset cleaned: ",torch.corrcoef(torch.stack((torch.flatten(model_saved(dataset[:][0])),torch.flatten(dataset[:][1])),dim=0))[1,0])

Corr dataset cleaned:  tensor(0.4240, grad_fn=<SelectBackward0>)
