In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTFeatureExtractor, ViTModel, ViTConfig, DistilBertModel, DistilBertConfig
from tqdm.notebook import tqdm
from torch.autograd import Variable
from datetime import datetime, timedelta
import models
import data_preparation

In [2]:
# Load the raw inputs from the local ./data directory:
# two NumPy .npz archives (train/test features) and one CSV (train targets).
X_train = np.load('./data/X_train_surge_new.npz')
Y_train = pd.read_csv('./data/Y_train_surge.csv')
# X_test is also read by key ('id_sequence') when the submission frame is built.
X_test = np.load('./data/X_test_surge_new.npz')

In [3]:
# --- Scale-finder stage: hyper-parameters, model, then optimisation objects ---
epochs = 100
criterion = nn.MSELoss()

model_scale_finder = models.ScaleFinder()

# Adam at a base learning rate of 1e-4, annealed with a cosine schedule whose
# T_max equals the number of epochs (the scheduler is stepped once per epoch).
optimizer = torch.optim.Adam(params=model_scale_finder.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

In [4]:
# Build the train/validation batch iterables for the scale-finder model.
# Each is iterated as (x1, y) pairs by the training loop below.
# NOTE(review): the split ratio, batching and normalisation live in
# data_preparation.data_preparation_means_std and are not visible here — confirm there.
train_scale_data, val_scale_data = data_preparation.data_preparation_means_std(X_train, Y_train)

In [6]:
# Train the scale-finder for `epochs` epochs and report the mean validation
# loss after every epoch.
for epoch in range(epochs):
    # ---- optimisation pass over the training batches ----
    model_scale_finder.train()
    for x1, y in tqdm(train_scale_data, total=len(train_scale_data), leave=False):
        # Cast inputs and targets to float32 before the forward pass.
        x1 = x1.float()
        y = y.float()
        optimizer.zero_grad()
        pred = model_scale_finder(x1)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
    # The learning-rate schedule advances once per epoch, not once per batch.
    scheduler.step()

    # ---- validation pass: no gradients, eval-mode layers ----
    model_scale_finder.eval()
    val_loss = 0
    with torch.no_grad():
        for x1, y in tqdm(val_scale_data, total=len(val_scale_data), leave=False):
            x1 = x1.float()
            y = y.float()
            pred = model_scale_finder(x1)
            loss = criterion(pred, y)
            val_loss += loss.item()
    # Average of the per-batch losses (each batch weighted equally).
    val_loss /= len(val_scale_data)
    # Without this report the validation loss is computed and then thrown away.
    print(f'Epoch {epoch+1}: Validation Loss = {val_loss}')

  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 1: Validation Loss = 0.5364448645285198


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 2: Validation Loss = 0.4583775275519916


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 3: Validation Loss = 0.41318939817803246


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 4: Validation Loss = 0.39067767539194653


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 5: Validation Loss = 0.3818125311817442


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 6: Validation Loss = 0.3792737867150988


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 7: Validation Loss = 0.3787987508944103


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 8: Validation Loss = 0.3787896603345871


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 9: Validation Loss = 0.3787720520581518


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 10: Validation Loss = 0.37884558737277985


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 11: Validation Loss = 0.3788994429366929


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 12: Validation Loss = 0.37876220771244595


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 13: Validation Loss = 0.3787845888308116


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 14: Validation Loss = 0.3788177134735244


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 15: Validation Loss = 0.378871990953173


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 16: Validation Loss = 0.3788195931485721


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 17: Validation Loss = 0.37883641655955996


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 18: Validation Loss = 0.3789273589849472


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 19: Validation Loss = 0.3789526879787445


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 20: Validation Loss = 0.3788748807140759


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 21: Validation Loss = 0.37882724702358245


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 22: Validation Loss = 0.37887233644723894


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 23: Validation Loss = 0.37894168730293004


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 24: Validation Loss = 0.3788209527730942


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 25: Validation Loss = 0.3789122545293399


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 26: Validation Loss = 0.3788256675004959


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 27: Validation Loss = 0.37876493973391395


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 28: Validation Loss = 0.3788256281188556


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 29: Validation Loss = 0.37877131402492525


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 30: Validation Loss = 0.3787470651524408


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 31: Validation Loss = 0.3787822214620454


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 32: Validation Loss = 0.37873836159706115


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 33: Validation Loss = 0.3788224763103894


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 34: Validation Loss = 0.3787896047745432


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 35: Validation Loss = 0.37879185591425213


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 36: Validation Loss = 0.37880472178970065


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 37: Validation Loss = 0.37879635372332165


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 38: Validation Loss = 0.37889429458550045


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 39: Validation Loss = 0.37884526061160223


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 40: Validation Loss = 0.3788240156003407


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 41: Validation Loss = 0.3788451565163476


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 42: Validation Loss = 0.3788122981786728


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 43: Validation Loss = 0.37883215802056447


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 44: Validation Loss = 0.37882509699889594


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 45: Validation Loss = 0.37879475951194763


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 46: Validation Loss = 0.3788204382572855


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 47: Validation Loss = 0.37884713475193293


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 48: Validation Loss = 0.3788529817547117


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 49: Validation Loss = 0.378814349429948


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 50: Validation Loss = 0.37882648621286663


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 51: Validation Loss = 0.37879473460572105


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 52: Validation Loss = 0.37879120920385634


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 53: Validation Loss = 0.37879452173198974


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 54: Validation Loss = 0.3787641455020223


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 55: Validation Loss = 0.3787907717483384


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 56: Validation Loss = 0.3788089411599295


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 57: Validation Loss = 0.3788113909108298


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 58: Validation Loss = 0.3788031190633774


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 59: Validation Loss = 0.3788207139287676


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 60: Validation Loss = 0.3788233124784061


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 61: Validation Loss = 0.3788183297429766


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 62: Validation Loss = 0.37883176377841404


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 63: Validation Loss = 0.37882514532123296


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 64: Validation Loss = 0.37881977387837


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 65: Validation Loss = 0.37882738794599263


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 66: Validation Loss = 0.378837197806154


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 67: Validation Loss = 0.3788239264062473


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 68: Validation Loss = 0.3788167072193963


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 69: Validation Loss = 0.3788244055850165


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 70: Validation Loss = 0.37882315473897116


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 71: Validation Loss = 0.37881929257086344


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 72: Validation Loss = 0.37881669125386647


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 73: Validation Loss = 0.3788257109267371


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 74: Validation Loss = 0.37882907624755585


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 75: Validation Loss = 0.37882516511848996


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 76: Validation Loss = 0.3788204608219011


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 77: Validation Loss = 0.378819896706513


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 78: Validation Loss = 0.378818596473762


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 79: Validation Loss = 0.37882584482431414


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 80: Validation Loss = 0.37882043783153807


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 81: Validation Loss = 0.3788245839732034


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 82: Validation Loss = 0.3788224386317389


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 83: Validation Loss = 0.37882310215915954


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 84: Validation Loss = 0.3788190103002957


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 85: Validation Loss = 0.3788211286067963


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 86: Validation Loss = 0.37882159948349


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 87: Validation Loss = 0.3788204531584467


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 88: Validation Loss = 0.3788208175982748


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 89: Validation Loss = 0.37882128485611505


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 90: Validation Loss = 0.37882159182003566


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 91: Validation Loss = 0.37882122546434405


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 92: Validation Loss = 0.37882145260061534


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 93: Validation Loss = 0.37882193773984907


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 94: Validation Loss = 0.37882130444049833


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 95: Validation Loss = 0.37882130593061447


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 96: Validation Loss = 0.3788213027375085


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 97: Validation Loss = 0.3788213695798601


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 98: Validation Loss = 0.37882138597113746


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 99: Validation Loss = 0.378821382352284


  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 100: Validation Loss = 0.378821382352284


---

In [7]:
# --- Pressure-encoder stage: hyper-parameters, model, then optimisation objects ---
# Note: `epochs`, `criterion`, `optimizer` and `scheduler` are rebound here and
# replace the objects used for the scale-finder stage.
epochs = 10
criterion = nn.MSELoss()

# The encoder is constructed around the already-trained scale finder.
model = models.PressureEncorderSemiFull(scale_finder=model_scale_finder)
# GPU placement is currently disabled; to enable it:
# device = torch.device('cuda')
# model = model.to(device)

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

# Per-step weights decaying linearly from 1.0 to 0.1 over 10 steps, with a
# leading singleton axis so they broadcast over the batch dimension.
w = torch.linspace(1, 0.1, 10).unsqueeze(0)


def custom_weighted_losses(output, target):
    """Weighted squared-error loss over two 10-step halves of the prediction.

    The last axis of ``output``/``target`` is split into its first 10 entries
    and the remaining entries; each half is weighted element-wise by ``w``
    and averaged, and the two means are summed.

    The slicing is done on the LAST axis. Slicing the first axis (as
    ``output[:10]`` would) selects batch rows rather than prediction columns,
    which breaks for batched ``(B, 20)`` inputs.

    Args:
        output: predictions, shape ``(..., 20)``.
            NOTE(review): assumes the last axis holds the 10 surge1 steps
            followed by the 10 surge2 steps, matching the submission column
            order — confirm against the model.
        target: ground truth with the same shape as ``output``.

    Returns:
        Scalar tensor: weighted MSE of the first half plus weighted MSE of
        the second half.
    """
    first_half = torch.mean(w * (output[..., :10] - target[..., :10]) ** 2)
    second_half = torch.mean(w * (output[..., 10:] - target[..., 10:]) ** 2)
    return first_half + second_half

In [8]:
# Build the train/validation batch iterables for the pressure encoder.
# Each is iterated as (x1, x2, x3, y) tuples by the training loop below.
# NOTE(review): the meaning of the three inputs and the split/batching are defined in
# data_preparation.data_prepare_pretrain_semifull and are not visible here — confirm there.
train_dataloader, val_dataloader = data_preparation.data_prepare_pretrain_semifull(X_train, Y_train)

In [9]:
# Fit the pressure encoder: one optimisation pass and one validation pass per
# epoch, printing the mean validation loss each time.
for epoch in range(epochs):
    # ---- optimisation pass ----
    model.train()
    for x1, x2, x3, y in tqdm(train_dataloader, total=len(train_dataloader), leave=False):
        # Cast all three inputs and the target to CPU float tensors in one go.
        x1, x2, x3, y = (t.type(torch.FloatTensor) for t in (x1, x2, x3, y))
        optimizer.zero_grad()
        pred = model((x1, x2, x3))
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()
    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # ---- validation pass ----
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for x1, x2, x3, y in tqdm(val_dataloader, total=len(val_dataloader), leave=False):
            x1, x2, x3, y = (t.type(torch.FloatTensor) for t in (x1, x2, x3, y))
            pred = model((x1, x2, x3))
            loss = criterion(pred, y)
            batch_losses.append(loss.item())
    # Mean of the per-batch losses (each batch weighted equally).
    val_loss = sum(batch_losses) / len(val_dataloader)
    print(f'Epoch {epoch+1}: Validation Loss = {val_loss}')

  0%|          | 0/630 [00:00<?, ?it/s]

  0%|          | 0/70 [00:00<?, ?it/s]

Epoch 1: Validation Loss = 0.7593996529068265


  0%|          | 0/630 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Build the batch iterable over the test inputs (train=False); the inference loop
# below unpacks each batch as (x1, x2, x3).
# NOTE(review): training used data_prepare_pretrain_semifull, but this calls
# data_prepare_pretrain — confirm it yields the same three inputs, preprocessed the
# same way, as the model was trained on.
# NOTE(review): Y_train is passed even though train=False — presumably ignored or used
# only for statistics; verify in data_preparation.
dl = data_preparation.data_prepare_pretrain(X_test, Y_train, train=False)

In [None]:
# Run the model over every test batch and stack the predictions into a single
# NumPy array (batches concatenated along axis 0, in loader order).
model.eval()

ys = []
with torch.no_grad():
    for x1, x2, x3 in tqdm(dl, total=len(dl), leave=False):
        x1 = x1.float()
        x2 = x2.float()
        x3 = x3.float()
        batch_pred = model((x1, x2, x3))
        # Convert explicitly rather than relying on NumPy coercing torch tensors
        # inside np.concatenate; .cpu() keeps this working if the model is later
        # moved to an accelerator.
        ys.append(batch_pred.detach().cpu().numpy())
surge_test = np.concatenate(ys, axis=0)

In [None]:
# Submission layout: ten surge1 columns followed by ten surge2 columns,
# one row per test sequence, keyed by its id_sequence.
surge1_columns = [f'surge1_t{i}' for i in range(10)]
surge2_columns = [f'surge2_t{i}' for i in range(10)]
y_columns = surge1_columns + surge2_columns

y_test_benchmark = pd.DataFrame(
    data=surge_test,
    index=X_test['id_sequence'],
    columns=y_columns,
)
# Write the predictions to a comma-separated file with the index labelled 'id_sequence'.
y_test_benchmark.to_csv(
    'Y_test_benchmark_semi_full_1.csv',
    sep=',',
    index_label='id_sequence',
)