In [2]:
import numpy as np
import pandas as pd
import random
import torch
from torch import nn
from torch.nn import Linear, LayerNorm, ReLU, Dropout
from torch_geometric.nn import ChebConv, NNConv, DeepGCNLayer, EdgeConv
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
import os
import copy
from catalyst.dl import utils

from constants import FilePaths
# settings



In [3]:
seed = 777
train_file = '../input/stanford-covid-vaccine/train.json'
test_file = '../input/stanford-covid-vaccine/test.json'
bpps_top = '../input/stanford-covid-vaccine/bpps'
nb_fold = 5
device = 'cuda'
batch_size = 16
epochs = 100
lr = 0.0005
train_with_noisy_data = True
add_edge_for_paired_nodes = True
add_codon_nodes = True
T = 5
node_hidden_channels = 96
edge_hidden_channels = 16
hidden_channels3 = 32
num_layers = 10
dropout1 = 0.1
dropout2 = 0.1
dropout3 = 0.1
bpps_nb_mean = 0.077522 # mean of bpps_nb across all training data
bpps_nb_std = 0.08914   # std of bpps_nb across all training data
error_mean_limit = 0.5


In [4]:
def match_pair(structure):
    pair = [-1] * len(structure)
    pair_no = -1

    pair_no_stack = []
    for i, c in enumerate(structure):
        if c == '(':
            pair_no += 1
            pair[i] = pair_no
            pair_no_stack.append(pair_no)
        elif c == ')':
            pair[i] = pair_no_stack.pop()
    return pair


def match_pair2(bpps, threshold=0.0):
    max_idx = np.argmax(bpps, 1)
    max_val = np.max(bpps, 1)
    n = bpps.shape[0]
    pairs = []
    m = 0
    for j in range(n):
        idxs = np.where(bpps[j] > threshold)[0]
        for idx in idxs:
            pairs.append((j, idx))
    return pairs


In [5]:
a = np.array([[0, 0, 0, 0, 0, 0],
              [0, 0, 0.4, 0, 0, 0],
              [0, 0.4, 0, 0, 0, 0],
              [0, 0, 0, 0, 0, 0],
              [0, 0, 0, 0, 0, 0.2],
              [0, 0, 0, 0, 0.2, 0]])
a, a.T

(array([[0. , 0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0.4, 0. , 0. , 0. ],
        [0. , 0.4, 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. , 0.2],
        [0. , 0. , 0. , 0. , 0.2, 0. ]]),
 array([[0. , 0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0.4, 0. , 0. , 0. ],
        [0. , 0.4, 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. , 0.2],
        [0. , 0. , 0. , 0. , 0.2, 0. ]]))

In [6]:
match_pair(".().()"), match_pair2(a)

([-1, 0, 0, -1, 1, 1], [(1, 2), (2, 1), (4, 5), (5, 4)])

In [7]:
class MyData(Data):
    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, norm=None, face=None, weight=None, **kwargs):
        super(MyData, self).__init__(x=x, edge_index=edge_index,
                                     edge_attr=edge_attr, y=y, pos=pos,
                                     norm=norm, face=face, **kwargs)
        self.weight = weight


def calc_error_mean(row):
    reactivity_error = row['reactivity_error']
    deg_error_Mg_pH10 = row['deg_error_Mg_pH10']
    deg_error_Mg_50C = row['deg_error_Mg_50C']

    return np.mean(np.abs(reactivity_error) +
                   np.abs(deg_error_Mg_pH10) + \
                   np.abs(deg_error_Mg_50C)) / 3


def calc_sample_weight(row, threshold):
    if sample_is_clean(row):
        return 1.
    else:
        error_mean = calc_error_mean(row)
        if error_mean >= threshold:
            return 0.

        return 1. - error_mean / threshold


# add directed edge for node1 -> node2 and for node2 -> node1
def add_edges(edge_index, edge_features, node1, node2, feature1, feature2):
    edge_index.append([node1, node2])
    edge_features.append(feature1)
    edge_index.append([node2, node1])
    edge_features.append(feature2)


def add_edges_between_base_nodes(edge_index, edge_features, node1, node2):
    edge_feature1 = [
        0, # is edge for paired nodes
        0, # is edge between codon node and base node
        0, # is edge between coden nodes
        1, # forward edge: 1, backward edge: -1
        1, # bpps if edge is for paired nodes
    ]
    edge_feature2 = [
        0, # is edge for paired nodes
        0, # is edge between codon node and base node
        0, # is edge between coden nodes
        -1, # forward edge: 1, backward edge: -1
        1, # bpps if edge is for paired nodes
    ]
    add_edges(edge_index, edge_features, node1, node2,
              edge_feature1, edge_feature2)


def add_edges_between_paired_nodes(edge_index, edge_features, node1, node2,
                                   bpps_value):
    edge_feature1 = [
        1, # is edge for paired nodes
        0, # is edge between codon node and base node
        0, # is edge between coden nodes
        0, # forward edge: 1, backward edge: -1
        bpps_value, # bpps if edge is for paired nodes
    ]
    edge_feature2 = [
        1, # is edge for paired nodes
        0, # is edge between codon node and base node
        0, # is edge between coden nodes
        0, # forward edge: 1, backward edge: -1
        bpps_value, # bpps if edge is for paired nodes
    ]
    add_edges(edge_index, edge_features, node1, node2,
              edge_feature1, edge_feature2)


def add_edges_between_codon_nodes(edge_index, edge_features, node1, node2):
    edge_feature1 = [
        0, # is edge for paired nodes
        0, # is edge between codon node and base node
        1, # is edge between coden nodes
        1, # forward edge: 1, backward edge: -1
        0, # bpps if edge is for paired nodes
    ]
    edge_feature2 = [
        0, # is edge for paired nodes
        0, # is edge between codon node and base node
        1, # is edge between coden nodes
        -1, # forward edge: 1, backward edge: -1
        0, # bpps if edge is for paired nodes
    ]
    add_edges(edge_index, edge_features, node1, node2,
              edge_feature1, edge_feature2)


def add_edges_between_codon_and_base_node(edge_index, edge_features,
                                          node1, node2):
    edge_feature1 = [
        0, # is edge for paired nodes
        1, # is edge between codon node and base node
        0, # is edge between coden nodes
        0, # forward edge: 1, backward edge: -1
        0, # bpps if edge is for paired nodes
    ]
    edge_feature2 = [
        0, # is edge for paired nodes
        1, # is edge between codon node and base node
        0, # is edge between coden nodes
        0, # forward edge: 1, backward edge: -1
        0, # bpps if edge is for paired nodes
    ]
    add_edges(edge_index, edge_features, node1, node2,
              edge_feature1, edge_feature2)


def add_node(node_features, feature):
    node_features.append(feature)


def add_base_node(node_features, sequence, predicted_loop_type,
                  bpps_sum, bpps_nb):
    feature = [
        0, # is codon node
        sequence == 'A',
        sequence == 'C',
        sequence == 'G',
        sequence == 'U',
        predicted_loop_type == 'S',
        predicted_loop_type == 'M',
        predicted_loop_type == 'I',
        predicted_loop_type == 'B',
        predicted_loop_type == 'H',
        predicted_loop_type == 'E',
        predicted_loop_type == 'X',
        bpps_sum,
        bpps_nb,
    ]
    add_node(node_features, feature)

def add_codon_node(node_features):
    feature = [
        1, # is codon node
        0, # sequence == 'A',
        0, # sequence == 'C',
        0, # sequence == 'G',
        0, # sequence == 'U',
        0, # predicted_loop_type == 'S',
        0, # predicted_loop_type == 'M',
        0, # predicted_loop_type == 'I',
        0, # predicted_loop_type == 'B',
        0, # predicted_loop_type == 'H',
        0, # predicted_loop_type == 'E',
        0, # predicted_loop_type == 'X',
        0, # bpps_sum
        0, # bpps_nb
    ]
    add_node(node_features, feature)


def build_data(df, is_train):
    bpps_nb_mean = 0.077522 # mean of bpps_nb across all training data
    bpps_nb_std = 0.08914   # std of bpps_nb across all training data
    add_edge_for_paired_nodes = True
    add_codon_nodes = True

    data = []
    for i in range(len(df)):
        targets = []
        node_features = []
        edge_features = []
        edge_index = []
        train_mask = []
        test_mask = []
        weights = []

        id = df.loc[i, 'id']
        path = f"data/bpps/{id}.npy"
        bpps = np.load(path)
        bpps_sum = bpps.sum(axis=0)
        sequence = df.loc[i, 'sequence']
        structure = df.loc[i, 'structure']
        pair_info = match_pair(sequence)
        predicted_loop_type = df.loc[i, 'predicted_loop_type']
        seq_length = df.loc[i, 'seq_length']
        seq_scored = df.loc[i, 'seq_scored']
        bpps_nb = (bpps > 0).sum(axis=0) / seq_length
        bpps_nb = (bpps_nb - bpps_nb_mean) / bpps_nb_std
        if is_train:
            sample_weight = calc_sample_weight(df.loc[i], 0.8)

            reactivity = df.loc[i, 'reactivity']
            deg_Mg_pH10 = df.loc[i, 'deg_Mg_pH10']
            deg_Mg_50C = df.loc[i, 'deg_Mg_50C']

            for j in range(seq_length):
                if j < seq_scored:
                    targets.append([
                        reactivity[j],
                        deg_Mg_pH10[j],
                        deg_Mg_50C[j],
                        ])
                else:
                    targets.append([0, 0, 0])

        paired_nodes = {}
        for j in range(seq_length):
            add_base_node(node_features, sequence[j], predicted_loop_type[j],
                          bpps_sum[j], bpps_nb[j])

            if j + 1 < seq_length: # edge between current node and next node
                add_edges_between_base_nodes(edge_index, edge_features,
                                             j, j + 1)

            # if pair_info[j] != -1:
            #    if pair_info[j] not in paired_nodes:
            #        paired_nodes[pair_info[j]] = [j]
            #    else:
            #        paired_nodes[pair_info[j]].append(j)

            train_mask.append(j < seq_scored)
            test_mask.append(True)
            if is_train:
                weights.append(sample_weight)
        # paired_nodes = {i: pp for pp in match_pair2(bpps)}
        if add_edge_for_paired_nodes:
            for pair in paired_nodes.values():
                bpps_value = bpps[pair[0], pair[1]]
                add_edges_between_paired_nodes(edge_index, edge_features,
                                               pair[0], pair[1], bpps_value)

        if add_codon_nodes:
            codon_node_idx = seq_length - 1
            for j in range(seq_length):
                if j % 3 == 0:
                    # add codon node
                    add_codon_node(node_features)
                    codon_node_idx += 1
                    train_mask.append(False)
                    test_mask.append(False)
                    if is_train:
                        weights.append(0)
                        targets.append([0, 0, 0])

                    if codon_node_idx > seq_length:
                        # add edges between adjacent codon nodes
                        add_edges_between_codon_nodes(edge_index, edge_features,
                                                      codon_node_idx - 1,
                                                      codon_node_idx)

                # add edges between codon node and base node
                add_edges_between_codon_and_base_node(edge_index, edge_features,
                                                      j, codon_node_idx)

        node_features = torch.tensor(node_features, dtype=torch.float)
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_features = torch.tensor(edge_features, dtype=torch.float)

        if is_train:
            data.append(MyData(x=node_features, edge_index=edge_index,
                               edge_attr=edge_features,
                               train_mask=torch.tensor(train_mask),
                               weight=torch.tensor(weights, dtype=torch.float),
                               y=torch.tensor(targets, dtype=torch.float)))
        else:
            data.append(MyData(x=node_features, edge_index=edge_index,
                               edge_attr=edge_features,
                               test_mask=torch.tensor(test_mask)))

    return data


In [8]:
def weighted_mse_loss(prds, tgts, weight):
    return torch.mean(weight * (prds - tgts)**2)


def criterion(prds, tgts, weight=None):
    if weight is None:
        return (torch.sqrt(torch.nn.MSELoss()(prds[:,0], tgts[:,0])) +
                torch.sqrt(torch.nn.MSELoss()(prds[:,1], tgts[:,1])) +
                torch.sqrt(torch.nn.MSELoss()(prds[:,2], tgts[:,2]))) / 3
    else:
        return (torch.sqrt(weighted_mse_loss(prds[:,0], tgts[:,0], weight)) +
                torch.sqrt(weighted_mse_loss(prds[:,1], tgts[:,1], weight)) +
                torch.sqrt(weighted_mse_loss(prds[:,2], tgts[:,2], weight))) / 3

def build_id_seqpos(df):
    id_seqpos = []
    for i in range(len(df)):
        id = df.loc[i, 'id']
        seq_length = df.loc[i, 'seq_length']
        for seqpos in range(seq_length):
            id_seqpos.append(id + '_' + str(seqpos))
    return id_seqpos

def sample_is_clean(row):
    return row['SN_filter'] == 1
    #return row['signal_to_noise'] > 1 and \
    #       min((min(row['reactivity']),
    #            min(row['deg_Mg_pH10']),
    #            min(row['deg_pH10']),
    #            min(row['deg_Mg_50C']),
    #            min(row['deg_50C']))) > -0.5

# categorical value for target (used for stratified kfold)
def add_y_cat(df):
    target_mean = df['reactivity'].apply(np.mean) + \
                  df['deg_Mg_pH10'].apply(np.mean) + \
                  df['deg_Mg_50C'].apply(np.mean)
    df['y_cat'] = pd.qcut(np.array(target_mean), q=20).codes

In [9]:
#
# originally copied from
# https://github.com/rusty1s/pytorch_geometric/blob/master/examples/ogbn_proteins_deepgcn.py
# 
class MapE2NxN(torch.nn.Module):
    def __init__(self, in_channels, out_channels, hidden_channels):
        super(MapE2NxN, self).__init__()
        self.linear1 = Linear(in_channels, hidden_channels)
        self.linear2 = Linear(hidden_channels, out_channels)
        self.dropout = Dropout(dropout3)
        self.gelu = nn.GELU()
        
    def forward(self, x):
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        x = self.linear2(x)
        return x

class MyDeeperGCN(torch.nn.Module):
    def __init__(self, num_node_features, num_edge_features,
                 node_hidden_channels,
                 edge_hidden_channels,
                 num_layers, num_classes):
        super(MyDeeperGCN, self).__init__()

        self.node_encoder = ChebConv(num_node_features, node_hidden_channels, T)
        self.edge_encoder = Linear(num_edge_features, edge_hidden_channels)

        self.layers = torch.nn.ModuleList()
        for i in range(1, num_layers + 1):
            conv = NNConv(node_hidden_channels, node_hidden_channels,
                          MapE2NxN(edge_hidden_channels,
                                   node_hidden_channels * node_hidden_channels,
                                   hidden_channels3))
            norm = LayerNorm(node_hidden_channels, elementwise_affine=True)
            act = nn.GELU()

            layer = DeepGCNLayer(conv, norm, act, block='res+',
                                 dropout=dropout1, ckpt_grad=i % 3)
            self.layers.append(layer)

        self.lin = Linear(node_hidden_channels, num_classes)
        self.dropout = Dropout(dropout2)

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        # edge for paired nodes are excluded for encoding node
        seq_edge_index = edge_index[:, edge_attr[:,0] == 0]
        x = self.node_encoder(x, seq_edge_index)

        edge_attr = self.edge_encoder(edge_attr)

        x = self.layers[0].conv(x, edge_index, edge_attr)

        for layer in self.layers[1:]:
            x = layer(x, edge_index, edge_attr)

        x = self.layers[0].act(self.layers[0].norm(x))
        x = self.dropout(x)

        return self.lin(x)

In [10]:
seed = 777
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
os.environ["PYTHONHASHSEED"] = str(seed)

In [11]:
FN = FilePaths("data")
df_tr = pd.read_json(FN.train_json, lines=True)
add_y_cat(df_tr)

In [12]:
T = 5
hidden_channels3 = 32
num_layers = 10
dropout1 = 0.1
dropout2 = 0.1
dropout3 = 0.1
HPARAMS = {
    "nb_fold": 5,
    "filter_noise": True,
    "signal_to_noise_ratio": 0.5,
    "batch_size": 16,
    "lr": 1e-3,
    "wd": 0,
    "num_layers": 10,
    "node_hidden_channels": 96,
    "edge_hidden_channels": 16
}

In [13]:
device = utils.get_device()
all_ys = torch.zeros((0, 3)).to(device).detach()
all_outs = torch.zeros((0, 3)).to(device).detach()
best_model_states = []
cvlist = list(StratifiedKFold(HPARAMS["nb_fold"], shuffle=True, random_state=seed).split(df_tr, df_tr["y_cat"]))

In [14]:
def get_dataloader(df, hparams):
    data_train = build_data(df.reset_index(drop=True), True)
    return data_train, DataLoader(data_train, batch_size=hparams["batch_size"], shuffle=True)


def train_fold(model, loader_train, loader_valid, optimizer, criterion, epochs, device):
    best_mcrmse = np.inf
    for epoch in range(epochs):
        print('Epoch', epoch)
        model.train()
        train_loss = 0.0
        nb = 0
        for data in tqdm(loader_train):
            data = data.to(device)
            mask = data.train_mask
            weight = data.weight[mask]

            optimizer.zero_grad()
            out = model(data)[mask]
            y = data.y[mask]
            loss = criterion(out, y, weight)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * y.size(0)
            nb += y.size(0)

            del data
            del out
            del y
            del loss
            #gc.collect()
            #torch.cuda.empty_cache()
        train_loss /= nb

        model.eval()
        valid_loss = 0.0
        nb = 0
        ys = torch.zeros((0, 3)).to(device).detach()
        outs = torch.zeros((0, 3)).to(device).detach()
        for data in tqdm(loader_valid):
            data = data.to(device)
            mask = data.train_mask

            out = model(data)[mask].detach()
            y = data.y[mask].detach()
            loss = criterion(out, y).detach()
            valid_loss += loss.item() * y.size(0)
            nb += y.size(0)

            outs = torch.cat((outs, out), dim=0)
            ys = torch.cat((ys, y), dim=0)

            del data
            del out
            del y
            del loss
            #gc.collect()
            #torch.cuda.empty_cache()
        valid_loss /= nb

        mcrmse = criterion(outs, ys).item()

        print("T Loss: {:.4f} V Loss: {:.4f} V MCRMSE: {:.4f}".\
                format(train_loss, valid_loss, mcrmse))

        if mcrmse < best_mcrmse:
            print('Best valid MCRMSE updated to', mcrmse)
            best_mcrmse = mcrmse
            best_model_state = copy.deepcopy(model.state_dict())
    return best_model_state

In [15]:
#for i, (tr_idx, vl_idx) in enumerate(cvlist):
tr, vl = df_tr.iloc[cvlist[0][0]], df_tr.iloc[cvlist[0][1]]

if HPARAMS["filter_noise"]:
    cond = tr.apply(calc_error_mean, axis=1) < 0.5
    tr = tr.loc[cond].reset_index(drop=True)

vl = vl.loc[vl["SN_filter"] == 1].reset_index(drop=True)
print(tr.shape, vl.shape)

(1720, 20) (324, 20)


In [16]:
data_train, loader_train = get_dataloader(tr, HPARAMS)
data_valid, loader_valid = get_dataloader(vl, HPARAMS)


In future, it will be an error for 'np.bool_' scalars to be interpreted as an index



In [17]:
model = MyDeeperGCN(data_train[0].num_node_features,
                    data_train[0].num_edge_features,
                    node_hidden_channels=HPARAMS["node_hidden_channels"],
                    edge_hidden_channels=HPARAMS["edge_hidden_channels"],
                    num_layers=HPARAMS["num_layers"],
                    num_classes=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=HPARAMS["lr"], weight_decay=HPARAMS["wd"])

In [18]:
best_state_fold0 = train_fold(model, loader_train, loader_valid, optimizer, criterion, 100, device)

  0%|          | 0/108 [00:00<?, ?it/s]

Epoch 0


100%|██████████| 108/108 [00:06<00:00, 16.07it/s]
100%|██████████| 21/21 [00:00<00:00, 46.95it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.83it/s]

T Loss: 0.3895 V Loss: 0.3360 V MCRMSE: 0.3372
Best valid MCRMSE updated to 0.33717888593673706
Epoch 1


100%|██████████| 108/108 [00:06<00:00, 16.68it/s]
100%|██████████| 21/21 [00:00<00:00, 46.52it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.69it/s]

T Loss: 0.3204 V Loss: 0.3048 V MCRMSE: 0.3059
Best valid MCRMSE updated to 0.3059006333351135
Epoch 2


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 47.16it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.84it/s]

T Loss: 0.3004 V Loss: 0.2847 V MCRMSE: 0.2865
Best valid MCRMSE updated to 0.28646108508110046
Epoch 3


100%|██████████| 108/108 [00:06<00:00, 16.63it/s]
100%|██████████| 21/21 [00:00<00:00, 47.17it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.86it/s]

T Loss: 0.2899 V Loss: 0.2770 V MCRMSE: 0.2783
Best valid MCRMSE updated to 0.2782767117023468
Epoch 4


100%|██████████| 108/108 [00:06<00:00, 16.68it/s]
100%|██████████| 21/21 [00:00<00:00, 47.12it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.74it/s]

T Loss: 0.2818 V Loss: 0.2729 V MCRMSE: 0.2744
Best valid MCRMSE updated to 0.2743653953075409
Epoch 5


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 46.84it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.78it/s]

T Loss: 0.2719 V Loss: 0.2711 V MCRMSE: 0.2724
Best valid MCRMSE updated to 0.2724047303199768
Epoch 6


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 46.94it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.77it/s]

T Loss: 0.2648 V Loss: 0.2626 V MCRMSE: 0.2639
Best valid MCRMSE updated to 0.26385313272476196
Epoch 7


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 47.23it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.75it/s]

T Loss: 0.2587 V Loss: 0.2619 V MCRMSE: 0.2631
Best valid MCRMSE updated to 0.26311802864074707
Epoch 8


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.68it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.84it/s]

T Loss: 0.2567 V Loss: 0.2581 V MCRMSE: 0.2596
Best valid MCRMSE updated to 0.2595968246459961
Epoch 9


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 47.62it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.82it/s]

T Loss: 0.2518 V Loss: 0.2599 V MCRMSE: 0.2613
Epoch 10


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 47.47it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.82it/s]

T Loss: 0.2487 V Loss: 0.2576 V MCRMSE: 0.2586
Best valid MCRMSE updated to 0.2586478590965271
Epoch 11


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.64it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.79it/s]

T Loss: 0.2449 V Loss: 0.2560 V MCRMSE: 0.2572
Best valid MCRMSE updated to 0.2571963667869568
Epoch 12


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 47.28it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.76it/s]

T Loss: 0.2409 V Loss: 0.2544 V MCRMSE: 0.2557
Best valid MCRMSE updated to 0.25567564368247986
Epoch 13


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.07it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.85it/s]

T Loss: 0.2392 V Loss: 0.2540 V MCRMSE: 0.2562
Epoch 14


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 47.47it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.87it/s]

T Loss: 0.2376 V Loss: 0.2522 V MCRMSE: 0.2539
Best valid MCRMSE updated to 0.2538580298423767
Epoch 15


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.30it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.82it/s]

T Loss: 0.2342 V Loss: 0.2497 V MCRMSE: 0.2515
Best valid MCRMSE updated to 0.25152018666267395
Epoch 16


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 47.24it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.91it/s]

T Loss: 0.2321 V Loss: 0.2489 V MCRMSE: 0.2503
Best valid MCRMSE updated to 0.2503361105918884
Epoch 17


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 47.45it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.86it/s]

T Loss: 0.2307 V Loss: 0.2476 V MCRMSE: 0.2487
Best valid MCRMSE updated to 0.2486577033996582
Epoch 18


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 46.89it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.74it/s]

T Loss: 0.2285 V Loss: 0.2472 V MCRMSE: 0.2489
Epoch 19


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.47it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.39it/s]

T Loss: 0.2270 V Loss: 0.2455 V MCRMSE: 0.2470
Best valid MCRMSE updated to 0.24695995450019836
Epoch 20


100%|██████████| 108/108 [00:06<00:00, 16.38it/s]
100%|██████████| 21/21 [00:00<00:00, 46.86it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.74it/s]

T Loss: 0.2242 V Loss: 0.2452 V MCRMSE: 0.2462
Best valid MCRMSE updated to 0.2461862564086914
Epoch 21


100%|██████████| 108/108 [00:06<00:00, 16.68it/s]
100%|██████████| 21/21 [00:00<00:00, 47.11it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.85it/s]

T Loss: 0.2220 V Loss: 0.2478 V MCRMSE: 0.2496
Epoch 22


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 46.67it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.80it/s]

T Loss: 0.2201 V Loss: 0.2440 V MCRMSE: 0.2451
Best valid MCRMSE updated to 0.24510395526885986
Epoch 23


100%|██████████| 108/108 [00:06<00:00, 16.90it/s]
100%|██████████| 21/21 [00:00<00:00, 47.36it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.79it/s]

T Loss: 0.2189 V Loss: 0.2437 V MCRMSE: 0.2453
Epoch 24


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.19it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.91it/s]

T Loss: 0.2178 V Loss: 0.2457 V MCRMSE: 0.2470
Epoch 25


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.38it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.78it/s]

T Loss: 0.2160 V Loss: 0.2425 V MCRMSE: 0.2437
Best valid MCRMSE updated to 0.24367330968379974
Epoch 26


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 47.48it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.87it/s]

T Loss: 0.2141 V Loss: 0.2470 V MCRMSE: 0.2483
Epoch 27


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 47.04it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.84it/s]

T Loss: 0.2133 V Loss: 0.2457 V MCRMSE: 0.2472
Epoch 28


100%|██████████| 108/108 [00:06<00:00, 16.64it/s]
100%|██████████| 21/21 [00:00<00:00, 47.03it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.67it/s]

T Loss: 0.2110 V Loss: 0.2422 V MCRMSE: 0.2438
Epoch 29


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 46.99it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.90it/s]

T Loss: 0.2095 V Loss: 0.2403 V MCRMSE: 0.2417
Best valid MCRMSE updated to 0.24174457788467407
Epoch 30


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 47.16it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.90it/s]

T Loss: 0.2093 V Loss: 0.2411 V MCRMSE: 0.2429
Epoch 31


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 46.77it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.61it/s]

T Loss: 0.2069 V Loss: 0.2435 V MCRMSE: 0.2454
Epoch 32


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.24it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.85it/s]

T Loss: 0.2060 V Loss: 0.2406 V MCRMSE: 0.2423
Epoch 33


100%|██████████| 108/108 [00:06<00:00, 16.66it/s]
100%|██████████| 21/21 [00:00<00:00, 45.97it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.66it/s]

T Loss: 0.2035 V Loss: 0.2385 V MCRMSE: 0.2402
Best valid MCRMSE updated to 0.24024665355682373
Epoch 34


100%|██████████| 108/108 [00:06<00:00, 16.65it/s]
100%|██████████| 21/21 [00:00<00:00, 47.62it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.87it/s]

T Loss: 0.2035 V Loss: 0.2388 V MCRMSE: 0.2403
Epoch 35


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.38it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.77it/s]

T Loss: 0.2016 V Loss: 0.2430 V MCRMSE: 0.2446
Epoch 36


100%|██████████| 108/108 [00:06<00:00, 16.67it/s]
100%|██████████| 21/21 [00:00<00:00, 47.37it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.84it/s]

T Loss: 0.1996 V Loss: 0.2395 V MCRMSE: 0.2415
Epoch 37


100%|██████████| 108/108 [00:06<00:00, 16.67it/s]
100%|██████████| 21/21 [00:00<00:00, 47.13it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.84it/s]

T Loss: 0.1991 V Loss: 0.2430 V MCRMSE: 0.2454
Epoch 38


100%|██████████| 108/108 [00:06<00:00, 16.67it/s]
100%|██████████| 21/21 [00:00<00:00, 47.17it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.83it/s]

T Loss: 0.1983 V Loss: 0.2389 V MCRMSE: 0.2402
Best valid MCRMSE updated to 0.2401527762413025
Epoch 39


100%|██████████| 108/108 [00:06<00:00, 16.65it/s]
100%|██████████| 21/21 [00:00<00:00, 47.31it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.83it/s]

T Loss: 0.1968 V Loss: 0.2384 V MCRMSE: 0.2406
Epoch 40


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.37it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.88it/s]

T Loss: 0.1957 V Loss: 0.2386 V MCRMSE: 0.2401
Best valid MCRMSE updated to 0.24009814858436584
Epoch 41


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.41it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.79it/s]

T Loss: 0.1953 V Loss: 0.2386 V MCRMSE: 0.2398
Best valid MCRMSE updated to 0.2398395836353302
Epoch 42


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.17it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.92it/s]

T Loss: 0.1941 V Loss: 0.2380 V MCRMSE: 0.2392
Best valid MCRMSE updated to 0.2391744703054428
Epoch 43


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.26it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.88it/s]

T Loss: 0.1923 V Loss: 0.2370 V MCRMSE: 0.2390
Best valid MCRMSE updated to 0.23898421227931976
Epoch 44


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 46.48it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.71it/s]

T Loss: 0.1914 V Loss: 0.2376 V MCRMSE: 0.2391
Epoch 45


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 47.14it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.90it/s]

T Loss: 0.1894 V Loss: 0.2377 V MCRMSE: 0.2396
Epoch 46


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.37it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.78it/s]

T Loss: 0.1904 V Loss: 0.2417 V MCRMSE: 0.2435
Epoch 47


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.17it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.72it/s]

T Loss: 0.1883 V Loss: 0.2410 V MCRMSE: 0.2430
Epoch 48


100%|██████████| 108/108 [00:06<00:00, 16.65it/s]
100%|██████████| 21/21 [00:00<00:00, 47.08it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.72it/s]

T Loss: 0.1872 V Loss: 0.2419 V MCRMSE: 0.2435
Epoch 49


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 47.04it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.79it/s]

T Loss: 0.1874 V Loss: 0.2384 V MCRMSE: 0.2400
Epoch 50


100%|██████████| 108/108 [00:06<00:00, 16.62it/s]
100%|██████████| 21/21 [00:00<00:00, 47.44it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.84it/s]

T Loss: 0.1850 V Loss: 0.2382 V MCRMSE: 0.2399
Epoch 51


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.50it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.79it/s]

T Loss: 0.1852 V Loss: 0.2354 V MCRMSE: 0.2367
Best valid MCRMSE updated to 0.23669354617595673
Epoch 52


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.66it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.86it/s]

T Loss: 0.1842 V Loss: 0.2359 V MCRMSE: 0.2379
Epoch 53


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.59it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.93it/s]

T Loss: 0.1833 V Loss: 0.2360 V MCRMSE: 0.2373
Epoch 54


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.61it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.95it/s]

T Loss: 0.1819 V Loss: 0.2362 V MCRMSE: 0.2384
Epoch 55


100%|██████████| 108/108 [00:06<00:00, 16.75it/s]
100%|██████████| 21/21 [00:00<00:00, 47.24it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.89it/s]

T Loss: 0.1817 V Loss: 0.2387 V MCRMSE: 0.2401
Epoch 56


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.60it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.83it/s]

T Loss: 0.1812 V Loss: 0.2360 V MCRMSE: 0.2377
Epoch 57


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 47.45it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.80it/s]

T Loss: 0.1799 V Loss: 0.2354 V MCRMSE: 0.2375
Epoch 58


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.45it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.85it/s]

T Loss: 0.1791 V Loss: 0.2379 V MCRMSE: 0.2396
Epoch 59


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.22it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.70it/s]

T Loss: 0.1792 V Loss: 0.2380 V MCRMSE: 0.2395
Epoch 60


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.67it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.89it/s]

T Loss: 0.1778 V Loss: 0.2363 V MCRMSE: 0.2381
Epoch 61


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.75it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.93it/s]

T Loss: 0.1772 V Loss: 0.2360 V MCRMSE: 0.2377
Epoch 62


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 47.36it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.89it/s]

T Loss: 0.1768 V Loss: 0.2352 V MCRMSE: 0.2366
Best valid MCRMSE updated to 0.23656153678894043
Epoch 63


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.29it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.87it/s]

T Loss: 0.1754 V Loss: 0.2342 V MCRMSE: 0.2358
Best valid MCRMSE updated to 0.23578527569770813
Epoch 64


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 46.87it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.86it/s]

T Loss: 0.1753 V Loss: 0.2356 V MCRMSE: 0.2374
Epoch 65


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.57it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.94it/s]

T Loss: 0.1743 V Loss: 0.2363 V MCRMSE: 0.2380
Epoch 66


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.01it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.76it/s]

T Loss: 0.1738 V Loss: 0.2363 V MCRMSE: 0.2377
Epoch 67


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.51it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.92it/s]

T Loss: 0.1737 V Loss: 0.2356 V MCRMSE: 0.2373
Epoch 68


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.61it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.86it/s]

T Loss: 0.1725 V Loss: 0.2354 V MCRMSE: 0.2372
Epoch 69


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.16it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.80it/s]

T Loss: 0.1720 V Loss: 0.2354 V MCRMSE: 0.2369
Epoch 70


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.28it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.77it/s]

T Loss: 0.1711 V Loss: 0.2361 V MCRMSE: 0.2376
Epoch 71


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.80it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.83it/s]

T Loss: 0.1711 V Loss: 0.2379 V MCRMSE: 0.2398
Epoch 72


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.76it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.89it/s]

T Loss: 0.1701 V Loss: 0.2346 V MCRMSE: 0.2363
Epoch 73


100%|██████████| 108/108 [00:06<00:00, 16.75it/s]
100%|██████████| 21/21 [00:00<00:00, 47.52it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.86it/s]

T Loss: 0.1710 V Loss: 0.2365 V MCRMSE: 0.2377
Epoch 74


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.65it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.94it/s]

T Loss: 0.1686 V Loss: 0.2368 V MCRMSE: 0.2385
Epoch 75


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.31it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.79it/s]

T Loss: 0.1691 V Loss: 0.2364 V MCRMSE: 0.2382
Epoch 76


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.03it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.90it/s]

T Loss: 0.1679 V Loss: 0.2342 V MCRMSE: 0.2364
Epoch 77


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.25it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.84it/s]

T Loss: 0.1680 V Loss: 0.2379 V MCRMSE: 0.2394
Epoch 78


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.40it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.90it/s]

T Loss: 0.1672 V Loss: 0.2364 V MCRMSE: 0.2384
Epoch 79


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.18it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.83it/s]

T Loss: 0.1664 V Loss: 0.2374 V MCRMSE: 0.2389
Epoch 80


100%|██████████| 108/108 [00:06<00:00, 16.72it/s]
100%|██████████| 21/21 [00:00<00:00, 47.08it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.82it/s]

T Loss: 0.1676 V Loss: 0.2337 V MCRMSE: 0.2357
Best valid MCRMSE updated to 0.23573629558086395
Epoch 81


100%|██████████| 108/108 [00:06<00:00, 16.66it/s]
100%|██████████| 21/21 [00:00<00:00, 47.37it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.88it/s]

T Loss: 0.1673 V Loss: 0.2351 V MCRMSE: 0.2367
Epoch 82


100%|██████████| 108/108 [00:06<00:00, 16.70it/s]
100%|██████████| 21/21 [00:00<00:00, 47.30it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.80it/s]

T Loss: 0.1653 V Loss: 0.2360 V MCRMSE: 0.2379
Epoch 83


100%|██████████| 108/108 [00:06<00:00, 16.68it/s]
100%|██████████| 21/21 [00:00<00:00, 47.50it/s]
  2%|▏         | 2/108 [00:00<00:06, 17.00it/s]

T Loss: 0.1657 V Loss: 0.2370 V MCRMSE: 0.2387
Epoch 84


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 47.54it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.87it/s]

T Loss: 0.1641 V Loss: 0.2381 V MCRMSE: 0.2393
Epoch 85


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.77it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.92it/s]

T Loss: 0.1639 V Loss: 0.2353 V MCRMSE: 0.2369
Epoch 86


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 46.96it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.67it/s]

T Loss: 0.1630 V Loss: 0.2375 V MCRMSE: 0.2388
Epoch 87


100%|██████████| 108/108 [00:06<00:00, 16.69it/s]
100%|██████████| 21/21 [00:00<00:00, 47.44it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.85it/s]

T Loss: 0.1625 V Loss: 0.2358 V MCRMSE: 0.2374
Epoch 88


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.38it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.91it/s]

T Loss: 0.1629 V Loss: 0.2384 V MCRMSE: 0.2400
Epoch 89


100%|██████████| 108/108 [00:06<00:00, 16.71it/s]
100%|██████████| 21/21 [00:00<00:00, 47.26it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.90it/s]

T Loss: 0.1619 V Loss: 0.2350 V MCRMSE: 0.2363
Epoch 90


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.58it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.93it/s]

T Loss: 0.1615 V Loss: 0.2366 V MCRMSE: 0.2378
Epoch 91


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.27it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.82it/s]

T Loss: 0.1620 V Loss: 0.2358 V MCRMSE: 0.2374
Epoch 92


100%|██████████| 108/108 [00:06<00:00, 16.65it/s]
100%|██████████| 21/21 [00:00<00:00, 46.46it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.70it/s]

T Loss: 0.1609 V Loss: 0.2355 V MCRMSE: 0.2371
Epoch 93


100%|██████████| 108/108 [00:06<00:00, 16.65it/s]
100%|██████████| 21/21 [00:00<00:00, 47.06it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.73it/s]

T Loss: 0.1594 V Loss: 0.2353 V MCRMSE: 0.2372
Epoch 94


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.63it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.85it/s]

T Loss: 0.1598 V Loss: 0.2383 V MCRMSE: 0.2398
Epoch 95


100%|██████████| 108/108 [00:06<00:00, 16.77it/s]
100%|██████████| 21/21 [00:00<00:00, 47.37it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.70it/s]

T Loss: 0.1597 V Loss: 0.2373 V MCRMSE: 0.2390
Epoch 96


100%|██████████| 108/108 [00:06<00:00, 16.73it/s]
100%|██████████| 21/21 [00:00<00:00, 47.19it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.82it/s]

T Loss: 0.1589 V Loss: 0.2342 V MCRMSE: 0.2359
Epoch 97


100%|██████████| 108/108 [00:06<00:00, 16.76it/s]
100%|██████████| 21/21 [00:00<00:00, 47.31it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.96it/s]

T Loss: 0.1586 V Loss: 0.2355 V MCRMSE: 0.2371
Epoch 98


100%|██████████| 108/108 [00:06<00:00, 16.74it/s]
100%|██████████| 21/21 [00:00<00:00, 47.60it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.98it/s]

T Loss: 0.1590 V Loss: 0.2359 V MCRMSE: 0.2373
Epoch 99


100%|██████████| 108/108 [00:06<00:00, 16.85it/s]
100%|██████████| 21/21 [00:00<00:00, 46.99it/s]

T Loss: 0.1579 V Loss: 0.2363 V MCRMSE: 0.2381





In [None]:
train = pd.read_json("data/train.json", lines=True)
cond = train.SN_filter == 1

In [None]:
tr1 = train.loc[cond]
tr2 = train.loc[~cond]

In [None]:
e1 = np.vstack(tr1["reactivity_error"].values)
e2 = np.vstack(tr1["deg_error_Mg_50C"].values)
e3 = np.vstack(tr1["deg_error_Mg_pH10"].values)

In [None]:
pd.Series(e1.flatten()).describe()

In [None]:
pd.Series(e2.flatten()).describe()

In [None]:
pd.Series(e3.flatten()).describe()

In [None]:
(e1 > 0.5).sum()

In [None]:
pd.Series(e1.mean(1).flatten()).describe()

In [None]:
pd.Series(e2.mean(1).flatten()).describe()

In [None]:
pd.Series(e3.mean(1).flatten()).describe()

In [84]:
e = np.dstack((e1, e2, e3))
e.shape

(1589, 68, 3)

In [85]:
pd.Series(e.mean(1).mean(1)).describe()

count    1589.000000
mean        0.103780
std         0.065435
min         0.024189
25%         0.063739
50%         0.083505
75%         0.118433
max         0.598469
dtype: float64

In [86]:
snr = tr1["signal_to_noise"]
snr.describe()

count    1589.000000
mean        5.402215
std         2.524798
min         0.993000
25%         3.484000
50%         5.222000
75%         6.853000
max        17.194000
Name: signal_to_noise, dtype: float64

In [87]:
e1 = np.vstack(tr2["reactivity_error"].values)
e2 = np.vstack(tr2["deg_error_Mg_50C"].values)
e3 = np.vstack(tr2["deg_error_Mg_pH10"].values)
e = np.dstack((e1, e2, e3))
pd.Series(e.mean(1).mean(1)).describe()

count       811.000000
mean       8630.316650
std       29131.240267
min           0.034662
25%           0.088635
50%           0.195730
75%           0.641706
max      140637.240300
dtype: float64

In [89]:
e1 = np.vstack(train["reactivity_error"].values)
e2 = np.vstack(train["deg_error_Mg_50C"].values)
e3 = np.vstack(train["deg_error_Mg_pH10"].values)
e = np.dstack((e1, e2, e3))
sum(e.mean(1).mean(1) > 0.6)

211