In [1]:
import numpy as np
import pandas as pd
import random
import torch
from torch import nn
from torch.nn import Linear, LayerNorm, ReLU, Dropout
from torch_geometric.nn import ChebConv, NNConv, DeepGCNLayer, EdgeConv
from torch_geometric.data import Data, DataLoader
from sklearn.model_selection import StratifiedKFold
from tqdm import tqdm
import os
import copy
from catalyst.dl import utils

from constants import FilePaths
# settings



In [2]:
# Reproducibility.
seed = 777

# Input locations (Kaggle directory layout).
train_file, test_file, bpps_top = (
    '../input/stanford-covid-vaccine/train.json',
    '../input/stanford-covid-vaccine/test.json',
    '../input/stanford-covid-vaccine/bpps',
)

# Training loop. NOTE: `device` is re-assigned from catalyst further down.
nb_fold, device, batch_size, epochs, lr = 5, 'cuda', 16, 100, 0.0005

# Graph-construction switches.
train_with_noisy_data = True
add_edge_for_paired_nodes = True
add_codon_nodes = True

# Model architecture.
T = 5                       # passed as ChebConv's K in the node encoder
node_hidden_channels = 96
edge_hidden_channels = 16
hidden_channels3 = 32       # hidden width of the MapE2NxN edge MLP
num_layers = 10
dropout1 = dropout2 = dropout3 = 0.1  # DeepGCNLayer / output head / edge MLP

# Standardisation statistics of bpps_nb across all training data.
bpps_nb_mean = 0.077522
bpps_nb_std = 0.08914

# Mean-measurement-error limit for filtering noisy training rows.
error_mean_limit = 0.5


In [3]:
def match_pair(structure):
    """Assign a pair id to every bracket of a dot-bracket string.

    Each '(' opens a new id (0, 1, 2, ... in order of appearance) and its
    matching ')' receives the same id; every other character maps to -1.
    An unmatched ')' raises IndexError (pop from an empty stack).
    """
    result = []
    open_ids = []
    next_id = 0
    for ch in structure:
        if ch == '(':
            open_ids.append(next_id)
            result.append(next_id)
            next_id += 1
        elif ch == ')':
            result.append(open_ids.pop())
        else:
            result.append(-1)
    return result


def match_pair2(bpps, threshold=0.0):
    """List every (row, col) index whose bpps value exceeds ``threshold``.

    Pairs are returned in row-major order, so a symmetric matrix yields both
    (i, j) and (j, i). Indices are plain Python ints.

    Args:
        bpps: square (n, n) base-pair probability matrix (array-like).
        threshold: strict lower bound; entries ``> threshold`` are reported.

    Returns:
        list of ``(row, col)`` tuples (empty for an all-below-threshold or
        empty matrix).
    """
    rows, cols = np.nonzero(np.asarray(bpps) > threshold)
    return list(zip(rows.tolist(), cols.tolist()))


In [4]:
# Toy symmetric bpps matrix: bases 1-2 pair with p=0.4 and bases 4-5 with
# p=0.2. Displayed next to its transpose to show the symmetry.
a = np.zeros((6, 6))
a[1, 2] = a[2, 1] = 0.4
a[4, 5] = a[5, 4] = 0.2
a, a.T

(array([[0. , 0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0.4, 0. , 0. , 0. ],
        [0. , 0.4, 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. , 0.2],
        [0. , 0. , 0. , 0. , 0.2, 0. ]]),
 array([[0. , 0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0.4, 0. , 0. , 0. ],
        [0. , 0.4, 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. , 0.2],
        [0. , 0. , 0. , 0. , 0.2, 0. ]]))

In [5]:
(match_pair(".().()"), match_pair2(a))  # dot-bracket ids vs. bpps-threshold pairs

([-1, 0, 0, -1, 1, 1], [(1, 2), (2, 1), (4, 5), (5, 4)])

In [6]:
class MyData(Data):
    """``Data`` subclass that also stores a ``weight`` attribute.

    All standard graph fields are forwarded unchanged to ``Data.__init__``;
    ``weight`` (a per-node sample weight when built by ``build_data``) is set
    as a plain attribute afterwards.
    """

    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, norm=None, face=None, weight=None, **kwargs):
        base_fields = dict(x=x, edge_index=edge_index, edge_attr=edge_attr,
                           y=y, pos=pos, norm=norm, face=face)
        super().__init__(**base_fields, **kwargs)
        self.weight = weight


def calc_error_mean(row):
    """Average absolute measurement error over the three scored targets.

    The absolute errors of ``reactivity``, ``deg_Mg_pH10`` and ``deg_Mg_50C``
    are summed position-wise, averaged over positions, then divided by 3.
    """
    error_columns = ('reactivity_error', 'deg_error_Mg_pH10', 'deg_error_Mg_50C')
    summed = sum(np.abs(row[col]) for col in error_columns)
    return np.mean(summed) / 3


def calc_sample_weight(row, threshold):
    """Per-sample training weight derived from measurement noise.

    Clean samples (``sample_is_clean``) get 1.0. For the others the weight
    falls linearly with ``calc_error_mean`` from 1 at zero error towards 0 at
    ``threshold``; errors at or above ``threshold`` give exactly 0.
    """
    if not sample_is_clean(row):
        error_mean = calc_error_mean(row)
        return 0. if error_mean >= threshold else 1. - error_mean / threshold
    return 1.


# add directed edge for node1 -> node2 and for node2 -> node1
def add_edges(edge_index, edge_features, node1, node2, feature1, feature2):
    """Append the directed edges node1 -> node2 and node2 -> node1.

    ``feature1`` belongs to the forward edge and ``feature2`` to the reverse
    edge; both lists are mutated in place and stay index-aligned.
    """
    for endpoints, feature in (([node1, node2], feature1),
                               ([node2, node1], feature2)):
        edge_index.append(endpoints)
        edge_features.append(feature)


def add_edges_between_base_nodes(edge_index, edge_features, node1, node2):
    """Connect two sequence-adjacent base nodes with a directed edge pair.

    Edge feature layout: [is_paired, is_codon_base, is_codon_codon,
    direction, bpps]. Backbone edges set direction to +1 for node1 -> node2
    and -1 for the reverse, and put a constant 1 in the bpps slot.
    """
    forward = [0, 0, 0, 1, 1]
    backward = [0, 0, 0, -1, 1]
    add_edges(edge_index, edge_features, node1, node2, forward, backward)


def add_edges_between_paired_nodes(edge_index, edge_features, node1, node2,
                                   bpps_value):
    """Connect two base-paired nodes with a symmetric edge pair.

    Edge feature layout: [is_paired, is_codon_base, is_codon_codon,
    direction, bpps]. Both directions are identical: paired flag 1,
    direction 0, and ``bpps_value`` in the bpps slot.
    """
    forward = [1, 0, 0, 0, bpps_value]
    backward = [1, 0, 0, 0, bpps_value]
    add_edges(edge_index, edge_features, node1, node2, forward, backward)


def add_edges_between_codon_nodes(edge_index, edge_features, node1, node2):
    """Chain two consecutive codon nodes with a directed edge pair.

    Edge feature layout: [is_paired, is_codon_base, is_codon_codon,
    direction, bpps]. The codon-codon flag is set; direction is +1 for
    node1 -> node2 and -1 for the reverse; bpps is 0.
    """
    forward = [0, 0, 1, 1, 0]
    backward = [0, 0, 1, -1, 0]
    add_edges(edge_index, edge_features, node1, node2, forward, backward)


def add_edges_between_codon_and_base_node(edge_index, edge_features,
                                          node1, node2):
    """Link a base node and its codon node with a symmetric edge pair.

    Edge feature layout: [is_paired, is_codon_base, is_codon_codon,
    direction, bpps]. Only the codon-base flag is set, in both directions.
    """
    forward = [0, 1, 0, 0, 0]
    backward = [0, 1, 0, 0, 0]
    add_edges(edge_index, edge_features, node1, node2, forward, backward)


def add_node(node_features, feature):
    """Append one node's feature vector to ``node_features`` in place."""
    node_features.extend([feature])


def add_base_node(node_features, sequence, predicted_loop_type,
                  bpps_sum, bpps_nb):
    """Append the 14-entry feature vector of one nucleotide (base) node.

    ``sequence`` and ``predicted_loop_type`` are the single characters at
    this position. Layout: is_codon flag (0), one-hot base over "ACGU",
    one-hot loop type over "SMIBHEX" (comparison results), then
    ``bpps_sum`` and ``bpps_nb``.
    """
    base_onehot = [sequence == base for base in 'ACGU']
    loop_onehot = [predicted_loop_type == loop for loop in 'SMIBHEX']
    add_node(node_features, [0, *base_onehot, *loop_onehot, bpps_sum, bpps_nb])

def add_codon_node(node_features):
    """Append the feature vector of a virtual codon node.

    Uses the same 14-entry layout as base nodes; only the leading is_codon
    flag is 1 and all base/loop/bpps entries are 0.
    """
    add_node(node_features, [1] + [0] * 13)


def _structure_pairs(structure):
    """Return ``[[i, j], ...]`` for every bracket pair of a dot-bracket string.

    Positions are grouped by the pair id from ``match_pair``. Ids that did not
    collect exactly two positions (an unmatched '(') are skipped.
    """
    members = {}
    for pos, pair_id in enumerate(match_pair(structure)):
        if pair_id != -1:
            members.setdefault(pair_id, []).append(pos)
    return [pos_list for pos_list in members.values() if len(pos_list) == 2]


def build_data(df, is_train, bpps_dir="data/bpps",
               add_edge_for_paired_nodes=True, add_codon_nodes=True,
               bpps_nb_mean=0.077522, bpps_nb_std=0.08914,
               weight_threshold=0.8):
    """Convert every row of ``df`` into a ``MyData`` graph.

    Graph layout per RNA:
      * one base node per position (``add_base_node``), linked to the next
        position by a backbone edge pair;
      * if ``add_edge_for_paired_nodes``: an edge pair between the two bases
        of every bracket pair in ``structure``, carrying ``bpps[i, j]``;
      * if ``add_codon_nodes``: one codon node per 3 positions, numbered after
        the base nodes. Each codon node is chained to the previous one and
        linked to its bases, and is excluded from both masks.

    Args:
        df: one RNA per row. Required columns: ``id``, ``sequence``,
            ``structure``, ``predicted_loop_type``, ``seq_length`` and
            ``seq_scored``. Training frames also need ``reactivity``,
            ``deg_Mg_pH10``, ``deg_Mg_50C`` plus the columns read by
            ``calc_sample_weight``. Rows are read by position, so any index
            works.
        is_train: if True, attach ``y``, ``train_mask`` and per-node
            ``weight``; otherwise attach only ``test_mask``.
        bpps_dir: directory containing ``<id>.npy`` base-pair probability
            matrices.
        add_edge_for_paired_nodes, add_codon_nodes: graph options above.
        bpps_nb_mean, bpps_nb_std: standardisation constants for ``bpps_nb``.
        weight_threshold: error threshold passed to ``calc_sample_weight``.

    Returns:
        list with one ``MyData`` per row of ``df``.
    """
    data = []
    for i in range(len(df)):
        row = df.iloc[i]
        targets = []
        node_features = []
        edge_features = []
        edge_index = []
        train_mask = []
        test_mask = []
        weights = []

        rna_id = row['id']
        bpps = np.load(os.path.join(bpps_dir, f"{rna_id}.npy"))
        bpps_sum = bpps.sum(axis=0)
        sequence = row['sequence']
        structure = row['structure']
        predicted_loop_type = row['predicted_loop_type']
        seq_length = row['seq_length']
        seq_scored = row['seq_scored']
        # Fraction of positions with non-zero pairing probability, standardised.
        bpps_nb = (bpps > 0).sum(axis=0) / seq_length
        bpps_nb = (bpps_nb - bpps_nb_mean) / bpps_nb_std

        if is_train:
            sample_weight = calc_sample_weight(row, weight_threshold)
            reactivity = row['reactivity']
            deg_Mg_pH10 = row['deg_Mg_pH10']
            deg_Mg_50C = row['deg_Mg_50C']
            for j in range(seq_length):
                if j < seq_scored:
                    targets.append([reactivity[j], deg_Mg_pH10[j], deg_Mg_50C[j]])
                else:
                    # Placeholder; unscored positions are masked out by train_mask.
                    targets.append([0, 0, 0])

        for j in range(seq_length):
            add_base_node(node_features, sequence[j], predicted_loop_type[j],
                          bpps_sum[j], bpps_nb[j])
            if j + 1 < seq_length:  # backbone edge to the next base
                add_edges_between_base_nodes(edge_index, edge_features, j, j + 1)
            train_mask.append(j < seq_scored)
            test_mask.append(True)
            if is_train:
                weights.append(sample_weight)

        if add_edge_for_paired_nodes:
            # Pairs come from the dot-bracket structure, not the base letters.
            for node1, node2 in _structure_pairs(structure):
                add_edges_between_paired_nodes(edge_index, edge_features,
                                               node1, node2, bpps[node1, node2])

        if add_codon_nodes:
            # Codon nodes are numbered seq_length, seq_length + 1, ...
            codon_node_idx = seq_length - 1
            for j in range(seq_length):
                if j % 3 == 0:
                    add_codon_node(node_features)
                    codon_node_idx += 1
                    train_mask.append(False)
                    test_mask.append(False)
                    if is_train:
                        weights.append(0)
                        targets.append([0, 0, 0])
                    if codon_node_idx > seq_length:
                        # Chain this codon node to the previous one.
                        add_edges_between_codon_nodes(edge_index, edge_features,
                                                      codon_node_idx - 1,
                                                      codon_node_idx)
                # Every base links to the codon node covering it.
                add_edges_between_codon_and_base_node(edge_index, edge_features,
                                                      j, codon_node_idx)

        node_features = torch.tensor(node_features, dtype=torch.float)
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_features = torch.tensor(edge_features, dtype=torch.float)

        if is_train:
            data.append(MyData(x=node_features, edge_index=edge_index,
                               edge_attr=edge_features,
                               train_mask=torch.tensor(train_mask),
                               weight=torch.tensor(weights, dtype=torch.float),
                               y=torch.tensor(targets, dtype=torch.float)))
        else:
            data.append(MyData(x=node_features, edge_index=edge_index,
                               edge_attr=edge_features,
                               test_mask=torch.tensor(test_mask)))

    return data


In [7]:
def weighted_mse_loss(prds, tgts, weight):
    """Mean of ``weight * (prds - tgts) ** 2`` over all elements.

    Normalised by the element count, not by ``weight.sum()``.
    """
    squared_error = (prds - tgts) ** 2
    return (weight * squared_error).mean()


def criterion(prds, tgts, weight=None):
    """Mean column-wise RMSE (MCRMSE) over target columns 0, 1 and 2.

    Without ``weight`` each column uses a plain mean squared error; with it,
    each column uses ``weighted_mse_loss``. The three RMSEs are averaged.
    """
    def column_rmse(col):
        if weight is None:
            mse = torch.nn.functional.mse_loss(prds[:, col], tgts[:, col])
        else:
            mse = weighted_mse_loss(prds[:, col], tgts[:, col], weight)
        return torch.sqrt(mse)

    return (column_rmse(0) + column_rmse(1) + column_rmse(2)) / 3

def build_id_seqpos(df):
    """Return ``"<id>_<position>"`` keys for every base of every row.

    Rows are looked up by label ``0 .. len(df) - 1``, so ``df`` is expected
    to have a default RangeIndex.
    """
    return [df.loc[row, 'id'] + '_' + str(pos)
            for row in range(len(df))
            for pos in range(df.loc[row, 'seq_length'])]

def sample_is_clean(row):
    """True when the row's ``SN_filter`` flag equals 1."""
    sn_flag = row['SN_filter']
    return sn_flag == 1

# categorical value for target (used for stratified kfold)
def add_y_cat(df):
    """Add a ``y_cat`` column with the 20-quantile bin of the summed target means.

    Each row's per-position means of ``reactivity``, ``deg_Mg_pH10`` and
    ``deg_Mg_50C`` are summed, then binned with ``pd.qcut``. The bin code is
    the stratification label. Mutates ``df`` in place and returns None.
    """
    target_cols = ('reactivity', 'deg_Mg_pH10', 'deg_Mg_50C')
    target_mean = sum(df[col].apply(np.mean) for col in target_cols)
    df['y_cat'] = pd.qcut(np.array(target_mean), q=20).codes

In [8]:
#
# originally copied from
# https://github.com/rusty1s/pytorch_geometric/blob/master/examples/ogbn_proteins_deepgcn.py
# 
class MapE2NxN(torch.nn.Module):
    """Edge network for ``NNConv``: Linear -> GELU -> Dropout -> Linear.

    Maps each (encoded) edge feature vector to ``out_channels`` values;
    ``MyDeeperGCN`` uses ``node_hidden_channels ** 2`` so NNConv can reshape
    the output into a per-edge weight matrix.

    Args:
        in_channels: size of the incoming edge feature vector.
        out_channels: size of the output vector.
        hidden_channels: width of the intermediate layer.
        dropout: dropout probability after the activation. ``None`` (default)
            falls back to the notebook-level ``dropout3`` setting.
    """

    def __init__(self, in_channels, out_channels, hidden_channels, dropout=None):
        super(MapE2NxN, self).__init__()
        if dropout is None:
            dropout = dropout3
        # Creation order kept stable so seeded initialisation is reproducible.
        self.linear1 = Linear(in_channels, hidden_channels)
        self.linear2 = Linear(hidden_channels, out_channels)
        self.dropout = Dropout(dropout)
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.gelu(x)
        x = self.dropout(x)
        return self.linear2(x)

class MyDeeperGCN(torch.nn.Module):
    """DeeperGCN-style node regressor.

    Architecture:
      * ChebConv node encoder;
      * Linear edge encoder;
      * ``num_layers`` DeepGCNLayer('res+') blocks, each wrapping an NNConv
        whose weights come from a ``MapE2NxN`` edge network;
      * Linear head producing ``num_classes`` outputs per node.

    Optional keyword arguments default to the notebook-level settings, so
    existing calls behave as before:
        cheb_k: ChebConv K (default ``T``).
        layer_dropout: DeepGCNLayer dropout (default ``dropout1``).
        head_dropout: dropout before the output Linear (default ``dropout2``).
        edge_mlp_hidden: hidden width of each MapE2NxN (default
            ``hidden_channels3``).
    """

    def __init__(self, num_node_features, num_edge_features,
                 node_hidden_channels,
                 edge_hidden_channels,
                 num_layers, num_classes,
                 cheb_k=None, layer_dropout=None, head_dropout=None,
                 edge_mlp_hidden=None):
        super(MyDeeperGCN, self).__init__()
        if cheb_k is None:
            cheb_k = T
        if layer_dropout is None:
            layer_dropout = dropout1
        if head_dropout is None:
            head_dropout = dropout2
        if edge_mlp_hidden is None:
            edge_mlp_hidden = hidden_channels3

        self.node_encoder = ChebConv(num_node_features, node_hidden_channels, cheb_k)
        self.edge_encoder = Linear(num_edge_features, edge_hidden_channels)

        self.layers = torch.nn.ModuleList()
        for i in range(1, num_layers + 1):
            conv = NNConv(node_hidden_channels, node_hidden_channels,
                          MapE2NxN(edge_hidden_channels,
                                   node_hidden_channels * node_hidden_channels,
                                   edge_mlp_hidden))
            norm = LayerNorm(node_hidden_channels, elementwise_affine=True)
            act = nn.GELU()
            # ckpt_grad is truthy for layers with i % 3 != 0 (gradient checkpointing).
            layer = DeepGCNLayer(conv, norm, act, block='res+',
                                 dropout=layer_dropout, ckpt_grad=i % 3)
            self.layers.append(layer)

        self.lin = Linear(node_hidden_channels, num_classes)
        self.dropout = Dropout(head_dropout)

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr

        # The node encoder only sees edges whose first feature is 0. Assumes
        # column 0 is the "paired nodes" flag, as in build_data's layout.
        # TODO(review): confirm for prepare_dataset's output.
        seq_edge_index = edge_index[:, edge_attr[:, 0] == 0]
        x = self.node_encoder(x, seq_edge_index)

        edge_attr = self.edge_encoder(edge_attr)

        # 'res+' layers normalise/activate before their conv, so the first
        # layer contributes only its conv here; its norm/act close the stack.
        x = self.layers[0].conv(x, edge_index, edge_attr)
        for layer in self.layers[1:]:
            x = layer(x, edge_index, edge_attr)
        x = self.layers[0].act(self.layers[0].norm(x))
        x = self.dropout(x)

        return self.lin(x)

In [9]:
# Seed every RNG used in this notebook so runs are repeatable.
seed = 777
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Prefer deterministic cuDNN kernels over autotuned ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# NOTE: only child processes see this; the running interpreter's str hashing
# was fixed at start-up.
os.environ["PYTHONHASHSEED"] = str(seed)

In [10]:
# Load the training set (JSON lines) and attach the stratification label.
FN = FilePaths("data")
df_tr = pd.read_json(path_or_buf=FN.train_json, lines=True)
add_y_cat(df_tr)

In [11]:
# Architecture globals read by MapE2NxN / MyDeeperGCN (same values as the
# settings cell).
T, hidden_channels3, num_layers = 5, 32, 10
dropout1 = dropout2 = dropout3 = 0.1

# Run configuration for the CV experiment; HPARAMS["lr"] is what the
# optimizer uses.
HPARAMS = dict(
    nb_fold=5,
    filter_noise=True,
    signal_to_noise_ratio=0.5,
    batch_size=16,
    lr=1e-3,
    wd=0,
    num_layers=10,
    node_hidden_channels=96,
    edge_hidden_channels=16,
)

In [12]:
# Empty (0, 3) accumulators and a list for per-fold best weights; none of
# them is used in the visible cells.
device = utils.get_device()
all_ys = torch.zeros((0, 3), device=device)
all_outs = torch.zeros((0, 3), device=device)
best_model_states = []

# Fold assignment stratified on the 20-bin target category.
cvlist = list(
    StratifiedKFold(HPARAMS["nb_fold"], shuffle=True, random_state=seed)
    .split(df_tr, df_tr["y_cat"])
)

In [13]:
from pytorch_geometric_dataset import prepare_dataset

In [14]:
def get_dataloader(df, hparams):
    """Build graph samples for ``df`` and wrap them in a shuffling DataLoader.

    Calls ``prepare_dataset(df, True, 0.5)`` for every split, validation
    included. The meaning of the positional ``True`` / ``0.5`` is defined in
    pytorch_geometric_dataset — TODO confirm.

    Returns:
        ``(dataset, loader)`` with ``batch_size = hparams["batch_size"]``.
    """
    dataset = prepare_dataset(df, True, 0.5)
    loader = DataLoader(dataset, batch_size=hparams["batch_size"], shuffle=True)
    return dataset, loader


def _train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass; return the node-count-weighted mean loss."""
    model.train()
    total_loss, total_nodes = 0.0, 0
    for data in tqdm(loader):
        data = data.to(device)
        mask = data.train_mask
        weight = data.weight[mask]

        optimizer.zero_grad()
        out = model(data)[mask]
        y = data.y[mask]
        loss = criterion(out, y, weight)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * y.size(0)
        total_nodes += y.size(0)
    return total_loss / total_nodes if total_nodes else float('nan')


def _evaluate(model, loader, criterion, device):
    """Run an unweighted validation pass without building autograd graphs.

    Returns:
        ``(valid_loss, mcrmse)``. ``valid_loss`` is the per-batch criterion
        weighted by scored-node count. ``mcrmse`` is the criterion over all
        scored nodes of the split. Both are NaN for an empty loader.
    """
    model.eval()
    total_loss, total_nodes = 0.0, 0
    outs, ys = [], []
    with torch.no_grad():
        for data in tqdm(loader):
            data = data.to(device)
            mask = data.train_mask
            out = model(data)[mask]
            y = data.y[mask]
            total_loss += criterion(out, y).item() * y.size(0)
            total_nodes += y.size(0)
            outs.append(out)
            ys.append(y)
        if not outs or not total_nodes:
            return float('nan'), float('nan')
        mcrmse = criterion(torch.cat(outs, dim=0), torch.cat(ys, dim=0)).item()
    return total_loss / total_nodes, mcrmse


def train_fold(model, loader_train, loader_valid, optimizer, criterion, epochs, device):
    """Train ``model`` for ``epochs`` epochs and keep the best validation weights.

    Training uses ``criterion(out, y, weight)`` with per-node weights;
    validation uses ``criterion`` unweighted, only on nodes in ``train_mask``.

    Returns:
        Deep copy of ``model.state_dict()`` from the epoch with the lowest
        validation MCRMSE. If no epoch improves on infinity (``epochs == 0``
        or NaN metrics), the final state is returned instead of raising.
    """
    best_mcrmse = np.inf
    best_model_state = None
    for epoch in range(epochs):
        print('Epoch', epoch)
        train_loss = _train_epoch(model, loader_train, optimizer, criterion, device)
        valid_loss, mcrmse = _evaluate(model, loader_valid, criterion, device)

        print("T Loss: {:.4f} V Loss: {:.4f} V MCRMSE: {:.4f}".
                format(train_loss, valid_loss, mcrmse))

        if mcrmse < best_mcrmse:
            print('Best valid MCRMSE updated to', mcrmse)
            best_mcrmse = mcrmse
            best_model_state = copy.deepcopy(model.state_dict())

    if best_model_state is None:
        best_model_state = copy.deepcopy(model.state_dict())
    return best_model_state

In [15]:
#for i, (tr_idx, vl_idx) in enumerate(cvlist):
tr, vl = df_tr.iloc[cvlist[0][0]], df_tr.iloc[cvlist[0][1]]

if HPARAMS["filter_noise"]:
    cond = tr.apply(calc_error_mean, axis=1) < 0.5
    tr = tr.loc[cond].reset_index(drop=True)

vl = vl.loc[vl["SN_filter"] == 1].reset_index(drop=True)
print(tr.shape, vl.shape)

(1720, 20) (324, 20)


In [16]:
# Graph datasets and loaders for this fold's train / validation rows.
data_train, loader_train = get_dataloader(df=tr, hparams=HPARAMS)
data_valid, loader_valid = get_dataloader(df=vl, hparams=HPARAMS)

In [17]:
# Size the GNN from the first training graph's node/edge feature counts.
model = MyDeeperGCN(
    num_node_features=data_train[0].num_node_features,
    num_edge_features=data_train[0].num_edge_features,
    node_hidden_channels=HPARAMS["node_hidden_channels"],
    edge_hidden_channels=HPARAMS["edge_hidden_channels"],
    num_layers=HPARAMS["num_layers"],
    num_classes=3,
).to(device)
optimizer = torch.optim.Adam(
    model.parameters(), lr=HPARAMS["lr"], weight_decay=HPARAMS["wd"]
)

In [18]:
# Epoch count comes from the settings cell instead of a literal.
best_state_fold0 = train_fold(model, loader_train, loader_valid, optimizer, criterion, epochs, device)

  1%|          | 1/108 [00:00<00:16,  6.48it/s]

Epoch 0


100%|██████████| 108/108 [00:07<00:00, 15.40it/s]
100%|██████████| 21/21 [00:00<00:00, 43.55it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.87it/s]

T Loss: 0.3849 V Loss: 0.3370 V MCRMSE: 0.3378
Best valid MCRMSE updated to 0.33781084418296814
Epoch 1


100%|██████████| 108/108 [00:06<00:00, 15.61it/s]
100%|██████████| 21/21 [00:00<00:00, 43.15it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.92it/s]

T Loss: 0.3189 V Loss: 0.3080 V MCRMSE: 0.3092
Best valid MCRMSE updated to 0.3092217743396759
Epoch 2


100%|██████████| 108/108 [00:06<00:00, 15.64it/s]
100%|██████████| 21/21 [00:00<00:00, 43.63it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.63it/s]

T Loss: 0.2953 V Loss: 0.2931 V MCRMSE: 0.2946
Best valid MCRMSE updated to 0.2945800721645355
Epoch 3


100%|██████████| 108/108 [00:06<00:00, 15.65it/s]
100%|██████████| 21/21 [00:00<00:00, 43.42it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.59it/s]

T Loss: 0.2829 V Loss: 0.2803 V MCRMSE: 0.2816
Best valid MCRMSE updated to 0.28162264823913574
Epoch 4


100%|██████████| 108/108 [00:06<00:00, 15.66it/s]
100%|██████████| 21/21 [00:00<00:00, 43.52it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.71it/s]

T Loss: 0.2760 V Loss: 0.2764 V MCRMSE: 0.2776
Best valid MCRMSE updated to 0.2776244878768921
Epoch 5


100%|██████████| 108/108 [00:06<00:00, 15.62it/s]
100%|██████████| 21/21 [00:00<00:00, 43.66it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.53it/s]

T Loss: 0.2660 V Loss: 0.2713 V MCRMSE: 0.2726
Best valid MCRMSE updated to 0.27258795499801636
Epoch 6


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.50it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.71it/s]

T Loss: 0.2610 V Loss: 0.2676 V MCRMSE: 0.2691
Best valid MCRMSE updated to 0.26907917857170105
Epoch 7


100%|██████████| 108/108 [00:06<00:00, 15.66it/s]
100%|██████████| 21/21 [00:00<00:00, 43.47it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.79it/s]

T Loss: 0.2553 V Loss: 0.2643 V MCRMSE: 0.2662
Best valid MCRMSE updated to 0.26624464988708496
Epoch 8


100%|██████████| 108/108 [00:06<00:00, 15.66it/s]
100%|██████████| 21/21 [00:00<00:00, 43.36it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.70it/s]

T Loss: 0.2527 V Loss: 0.2602 V MCRMSE: 0.2615
Best valid MCRMSE updated to 0.2615025043487549
Epoch 9


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.67it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.74it/s]

T Loss: 0.2490 V Loss: 0.2605 V MCRMSE: 0.2625
Epoch 10


100%|██████████| 108/108 [00:06<00:00, 15.65it/s]
100%|██████████| 21/21 [00:00<00:00, 43.68it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.76it/s]

T Loss: 0.2455 V Loss: 0.2593 V MCRMSE: 0.2608
Best valid MCRMSE updated to 0.2607875168323517
Epoch 11


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.64it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.77it/s]

T Loss: 0.2429 V Loss: 0.2555 V MCRMSE: 0.2570
Best valid MCRMSE updated to 0.2570015788078308
Epoch 12


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.52it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.64it/s]

T Loss: 0.2386 V Loss: 0.2513 V MCRMSE: 0.2530
Best valid MCRMSE updated to 0.25303491950035095
Epoch 13


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.64it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.67it/s]

T Loss: 0.2357 V Loss: 0.2539 V MCRMSE: 0.2552
Epoch 14


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.59it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.85it/s]

T Loss: 0.2352 V Loss: 0.2550 V MCRMSE: 0.2570
Epoch 15


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.62it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.71it/s]

T Loss: 0.2328 V Loss: 0.2527 V MCRMSE: 0.2539
Epoch 16


100%|██████████| 108/108 [00:06<00:00, 15.61it/s]
100%|██████████| 21/21 [00:00<00:00, 43.36it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.74it/s]

T Loss: 0.2306 V Loss: 0.2517 V MCRMSE: 0.2529
Best valid MCRMSE updated to 0.25287923216819763
Epoch 17


100%|██████████| 108/108 [00:06<00:00, 15.59it/s]
100%|██████████| 21/21 [00:00<00:00, 43.56it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.85it/s]

T Loss: 0.2293 V Loss: 0.2486 V MCRMSE: 0.2502
Best valid MCRMSE updated to 0.25015509128570557
Epoch 18


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.58it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.67it/s]

T Loss: 0.2260 V Loss: 0.2507 V MCRMSE: 0.2516
Epoch 19


100%|██████████| 108/108 [00:06<00:00, 15.66it/s]
100%|██████████| 21/21 [00:00<00:00, 43.58it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.81it/s]

T Loss: 0.2236 V Loss: 0.2491 V MCRMSE: 0.2508
Epoch 20


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.39it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.58it/s]

T Loss: 0.2225 V Loss: 0.2471 V MCRMSE: 0.2485
Best valid MCRMSE updated to 0.2485448122024536
Epoch 21


100%|██████████| 108/108 [00:06<00:00, 15.57it/s]
100%|██████████| 21/21 [00:00<00:00, 43.19it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.82it/s]

T Loss: 0.2216 V Loss: 0.2475 V MCRMSE: 0.2487
Epoch 22


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.63it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.79it/s]

T Loss: 0.2189 V Loss: 0.2446 V MCRMSE: 0.2466
Best valid MCRMSE updated to 0.24659499526023865
Epoch 23


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.63it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.63it/s]

T Loss: 0.2179 V Loss: 0.2494 V MCRMSE: 0.2512
Epoch 24


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.61it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.82it/s]

T Loss: 0.2156 V Loss: 0.2461 V MCRMSE: 0.2472
Epoch 25


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.64it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.80it/s]

T Loss: 0.2133 V Loss: 0.2433 V MCRMSE: 0.2455
Best valid MCRMSE updated to 0.24549441039562225
Epoch 26


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.71it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.83it/s]

T Loss: 0.2122 V Loss: 0.2459 V MCRMSE: 0.2477
Epoch 27


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.51it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.83it/s]

T Loss: 0.2119 V Loss: 0.2479 V MCRMSE: 0.2492
Epoch 28


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.64it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.63it/s]

T Loss: 0.2104 V Loss: 0.2439 V MCRMSE: 0.2454
Best valid MCRMSE updated to 0.2453605979681015
Epoch 29


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.73it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.58it/s]

T Loss: 0.2077 V Loss: 0.2451 V MCRMSE: 0.2464
Epoch 30


100%|██████████| 108/108 [00:06<00:00, 15.70it/s]
100%|██████████| 21/21 [00:00<00:00, 43.66it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.73it/s]

T Loss: 0.2070 V Loss: 0.2482 V MCRMSE: 0.2496
Epoch 31


100%|██████████| 108/108 [00:06<00:00, 15.71it/s]
100%|██████████| 21/21 [00:00<00:00, 43.57it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.74it/s]

T Loss: 0.2049 V Loss: 0.2417 V MCRMSE: 0.2431
Best valid MCRMSE updated to 0.24308966100215912
Epoch 32


100%|██████████| 108/108 [00:06<00:00, 15.70it/s]
100%|██████████| 21/21 [00:00<00:00, 43.71it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.79it/s]

T Loss: 0.2036 V Loss: 0.2421 V MCRMSE: 0.2436
Epoch 33


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.45it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.63it/s]

T Loss: 0.2018 V Loss: 0.2401 V MCRMSE: 0.2416
Best valid MCRMSE updated to 0.24156604707241058
Epoch 34


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.66it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.52it/s]

T Loss: 0.2008 V Loss: 0.2436 V MCRMSE: 0.2447
Epoch 35


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.66it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.94it/s]

T Loss: 0.1996 V Loss: 0.2400 V MCRMSE: 0.2412
Best valid MCRMSE updated to 0.2411687821149826
Epoch 36


100%|██████████| 108/108 [00:06<00:00, 15.62it/s]
100%|██████████| 21/21 [00:00<00:00, 43.62it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.55it/s]

T Loss: 0.1977 V Loss: 0.2402 V MCRMSE: 0.2424
Epoch 37


100%|██████████| 108/108 [00:06<00:00, 15.66it/s]
100%|██████████| 21/21 [00:00<00:00, 43.59it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.66it/s]

T Loss: 0.1969 V Loss: 0.2441 V MCRMSE: 0.2453
Epoch 38


100%|██████████| 108/108 [00:06<00:00, 15.65it/s]
100%|██████████| 21/21 [00:00<00:00, 43.39it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.74it/s]

T Loss: 0.1956 V Loss: 0.2408 V MCRMSE: 0.2424
Epoch 39


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.61it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.65it/s]

T Loss: 0.1938 V Loss: 0.2394 V MCRMSE: 0.2415
Epoch 40


100%|██████████| 108/108 [00:06<00:00, 15.64it/s]
100%|██████████| 21/21 [00:00<00:00, 43.41it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.47it/s]

T Loss: 0.1940 V Loss: 0.2391 V MCRMSE: 0.2405
Best valid MCRMSE updated to 0.24049684405326843
Epoch 41


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.58it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.00it/s]

T Loss: 0.1931 V Loss: 0.2418 V MCRMSE: 0.2434
Epoch 42


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.65it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.58it/s]

T Loss: 0.1916 V Loss: 0.2396 V MCRMSE: 0.2413
Epoch 43


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.56it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.81it/s]

T Loss: 0.1905 V Loss: 0.2414 V MCRMSE: 0.2430
Epoch 44


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.60it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.49it/s]

T Loss: 0.1886 V Loss: 0.2415 V MCRMSE: 0.2430
Epoch 45


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.68it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.63it/s]

T Loss: 0.1889 V Loss: 0.2387 V MCRMSE: 0.2401
Best valid MCRMSE updated to 0.24014076590538025
Epoch 46


100%|██████████| 108/108 [00:06<00:00, 15.70it/s]
100%|██████████| 21/21 [00:00<00:00, 43.78it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.93it/s]

T Loss: 0.1874 V Loss: 0.2370 V MCRMSE: 0.2387
Best valid MCRMSE updated to 0.23873357474803925
Epoch 47


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.82it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.81it/s]

T Loss: 0.1864 V Loss: 0.2395 V MCRMSE: 0.2417
Epoch 48


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.69it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.75it/s]

T Loss: 0.1850 V Loss: 0.2391 V MCRMSE: 0.2407
Epoch 49


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.73it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.87it/s]

T Loss: 0.1837 V Loss: 0.2401 V MCRMSE: 0.2415
Epoch 50


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.29it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.66it/s]

T Loss: 0.1834 V Loss: 0.2381 V MCRMSE: 0.2398
Epoch 51


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.50it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.81it/s]

T Loss: 0.1813 V Loss: 0.2371 V MCRMSE: 0.2384
Best valid MCRMSE updated to 0.23835858702659607
Epoch 52


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.65it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.38it/s]

T Loss: 0.1814 V Loss: 0.2409 V MCRMSE: 0.2424
Epoch 53


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.43it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.87it/s]

T Loss: 0.1809 V Loss: 0.2387 V MCRMSE: 0.2404
Epoch 54


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.43it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.66it/s]

T Loss: 0.1793 V Loss: 0.2415 V MCRMSE: 0.2430
Epoch 55


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.53it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.50it/s]

T Loss: 0.1781 V Loss: 0.2387 V MCRMSE: 0.2407
Epoch 56


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.62it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.81it/s]

T Loss: 0.1777 V Loss: 0.2378 V MCRMSE: 0.2393
Epoch 57


100%|██████████| 108/108 [00:06<00:00, 15.70it/s]
100%|██████████| 21/21 [00:00<00:00, 43.82it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.81it/s]

T Loss: 0.1771 V Loss: 0.2379 V MCRMSE: 0.2392
Epoch 58


100%|██████████| 108/108 [00:06<00:00, 15.64it/s]
100%|██████████| 21/21 [00:00<00:00, 43.31it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.53it/s]

T Loss: 0.1772 V Loss: 0.2418 V MCRMSE: 0.2430
Epoch 59


100%|██████████| 108/108 [00:06<00:00, 15.63it/s]
100%|██████████| 21/21 [00:00<00:00, 43.59it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.93it/s]

T Loss: 0.1757 V Loss: 0.2380 V MCRMSE: 0.2397
Epoch 60


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.81it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.82it/s]

T Loss: 0.1739 V Loss: 0.2372 V MCRMSE: 0.2396
Epoch 61


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.74it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.65it/s]

T Loss: 0.1748 V Loss: 0.2382 V MCRMSE: 0.2397
Epoch 62


100%|██████████| 108/108 [00:06<00:00, 15.70it/s]
100%|██████████| 21/21 [00:00<00:00, 43.71it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.64it/s]

T Loss: 0.1727 V Loss: 0.2377 V MCRMSE: 0.2392
Epoch 63


100%|██████████| 108/108 [00:06<00:00, 15.71it/s]
100%|██████████| 21/21 [00:00<00:00, 43.70it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.77it/s]

T Loss: 0.1727 V Loss: 0.2370 V MCRMSE: 0.2389
Epoch 64


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.59it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.86it/s]

T Loss: 0.1716 V Loss: 0.2366 V MCRMSE: 0.2384
Epoch 65


100%|██████████| 108/108 [00:06<00:00, 15.70it/s]
100%|██████████| 21/21 [00:00<00:00, 43.58it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.67it/s]

T Loss: 0.1713 V Loss: 0.2391 V MCRMSE: 0.2405
Epoch 66


100%|██████████| 108/108 [00:06<00:00, 15.71it/s]
100%|██████████| 21/21 [00:00<00:00, 43.80it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.64it/s]

T Loss: 0.1708 V Loss: 0.2385 V MCRMSE: 0.2398
Epoch 67


100%|██████████| 108/108 [00:06<00:00, 15.72it/s]
100%|██████████| 21/21 [00:00<00:00, 43.58it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.87it/s]

T Loss: 0.1693 V Loss: 0.2393 V MCRMSE: 0.2408
Epoch 68


100%|██████████| 108/108 [00:06<00:00, 15.70it/s]
100%|██████████| 21/21 [00:00<00:00, 43.89it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.71it/s]

T Loss: 0.1690 V Loss: 0.2398 V MCRMSE: 0.2413
Epoch 69


100%|██████████| 108/108 [00:06<00:00, 15.71it/s]
100%|██████████| 21/21 [00:00<00:00, 43.70it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.74it/s]

T Loss: 0.1686 V Loss: 0.2378 V MCRMSE: 0.2389
Epoch 70


100%|██████████| 108/108 [00:06<00:00, 15.69it/s]
100%|██████████| 21/21 [00:00<00:00, 43.68it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.73it/s]

T Loss: 0.1678 V Loss: 0.2373 V MCRMSE: 0.2389
Epoch 71


100%|██████████| 108/108 [00:06<00:00, 15.72it/s]
100%|██████████| 21/21 [00:00<00:00, 43.84it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.71it/s]

T Loss: 0.1675 V Loss: 0.2378 V MCRMSE: 0.2393
Epoch 72


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.67it/s]
  2%|▏         | 2/108 [00:00<00:06, 16.15it/s]

T Loss: 0.1659 V Loss: 0.2380 V MCRMSE: 0.2395
Epoch 73


100%|██████████| 108/108 [00:06<00:00, 15.66it/s]
100%|██████████| 21/21 [00:00<00:00, 43.46it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.86it/s]

T Loss: 0.1668 V Loss: 0.2374 V MCRMSE: 0.2389
Epoch 74


100%|██████████| 108/108 [00:06<00:00, 15.59it/s]
100%|██████████| 21/21 [00:00<00:00, 43.24it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.62it/s]

T Loss: 0.1649 V Loss: 0.2392 V MCRMSE: 0.2406
Epoch 75


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.47it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.65it/s]

T Loss: 0.1653 V Loss: 0.2373 V MCRMSE: 0.2392
Epoch 76


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.50it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.78it/s]

T Loss: 0.1643 V Loss: 0.2404 V MCRMSE: 0.2416
Epoch 77


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.61it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.94it/s]

T Loss: 0.1632 V Loss: 0.2374 V MCRMSE: 0.2389
Epoch 78


100%|██████████| 108/108 [00:06<00:00, 15.67it/s]
100%|██████████| 21/21 [00:00<00:00, 43.41it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.76it/s]

T Loss: 0.1633 V Loss: 0.2365 V MCRMSE: 0.2389
Epoch 79


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.70it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.64it/s]

T Loss: 0.1614 V Loss: 0.2374 V MCRMSE: 0.2393
Epoch 80


100%|██████████| 108/108 [00:06<00:00, 15.65it/s]
100%|██████████| 21/21 [00:00<00:00, 43.33it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.51it/s]

T Loss: 0.1621 V Loss: 0.2374 V MCRMSE: 0.2390
Epoch 81


100%|██████████| 108/108 [00:06<00:00, 15.61it/s]
100%|██████████| 21/21 [00:00<00:00, 43.53it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.70it/s]

T Loss: 0.1615 V Loss: 0.2401 V MCRMSE: 0.2420
Epoch 82


100%|██████████| 108/108 [00:06<00:00, 15.62it/s]
100%|██████████| 21/21 [00:00<00:00, 42.82it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.62it/s]

T Loss: 0.1615 V Loss: 0.2357 V MCRMSE: 0.2383
Best valid MCRMSE updated to 0.2382880449295044
Epoch 83


100%|██████████| 108/108 [00:07<00:00, 15.42it/s]
100%|██████████| 21/21 [00:00<00:00, 42.99it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.54it/s]

T Loss: 0.1601 V Loss: 0.2379 V MCRMSE: 0.2393
Epoch 84


100%|██████████| 108/108 [00:06<00:00, 15.59it/s]
100%|██████████| 21/21 [00:00<00:00, 43.43it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.62it/s]

T Loss: 0.1596 V Loss: 0.2394 V MCRMSE: 0.2411
Epoch 85


100%|██████████| 108/108 [00:06<00:00, 15.61it/s]
100%|██████████| 21/21 [00:00<00:00, 43.49it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.69it/s]

T Loss: 0.1597 V Loss: 0.2381 V MCRMSE: 0.2398
Epoch 86


100%|██████████| 108/108 [00:06<00:00, 15.61it/s]
100%|██████████| 21/21 [00:00<00:00, 42.97it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.43it/s]

T Loss: 0.1585 V Loss: 0.2353 V MCRMSE: 0.2374
Best valid MCRMSE updated to 0.23735183477401733
Epoch 87


100%|██████████| 108/108 [00:06<00:00, 15.61it/s]
100%|██████████| 21/21 [00:00<00:00, 43.22it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.61it/s]

T Loss: 0.1584 V Loss: 0.2375 V MCRMSE: 0.2392
Epoch 88


100%|██████████| 108/108 [00:06<00:00, 15.62it/s]
100%|██████████| 21/21 [00:00<00:00, 42.72it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.46it/s]

T Loss: 0.1574 V Loss: 0.2391 V MCRMSE: 0.2408
Epoch 89


100%|██████████| 108/108 [00:06<00:00, 15.64it/s]
100%|██████████| 21/21 [00:00<00:00, 43.40it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.73it/s]

T Loss: 0.1587 V Loss: 0.2356 V MCRMSE: 0.2370
Best valid MCRMSE updated to 0.23699951171875
Epoch 90


100%|██████████| 108/108 [00:06<00:00, 15.68it/s]
100%|██████████| 21/21 [00:00<00:00, 43.55it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.58it/s]

T Loss: 0.1570 V Loss: 0.2390 V MCRMSE: 0.2403
Epoch 91


100%|██████████| 108/108 [00:06<00:00, 15.63it/s]
100%|██████████| 21/21 [00:00<00:00, 43.35it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.92it/s]

T Loss: 0.1564 V Loss: 0.2375 V MCRMSE: 0.2393
Epoch 92


100%|██████████| 108/108 [00:06<00:00, 15.62it/s]
100%|██████████| 21/21 [00:00<00:00, 43.16it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.50it/s]

T Loss: 0.1561 V Loss: 0.2370 V MCRMSE: 0.2385
Epoch 93


100%|██████████| 108/108 [00:06<00:00, 15.53it/s]
100%|██████████| 21/21 [00:00<00:00, 42.95it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.45it/s]

T Loss: 0.1556 V Loss: 0.2354 V MCRMSE: 0.2373
Epoch 94


100%|██████████| 108/108 [00:06<00:00, 15.57it/s]
100%|██████████| 21/21 [00:00<00:00, 43.24it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.53it/s]

T Loss: 0.1543 V Loss: 0.2370 V MCRMSE: 0.2383
Epoch 95


100%|██████████| 108/108 [00:06<00:00, 15.58it/s]
100%|██████████| 21/21 [00:00<00:00, 43.37it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.72it/s]

T Loss: 0.1548 V Loss: 0.2373 V MCRMSE: 0.2386
Epoch 96


100%|██████████| 108/108 [00:06<00:00, 15.60it/s]
100%|██████████| 21/21 [00:00<00:00, 43.08it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.66it/s]

T Loss: 0.1543 V Loss: 0.2386 V MCRMSE: 0.2403
Epoch 97


100%|██████████| 108/108 [00:06<00:00, 15.62it/s]
100%|██████████| 21/21 [00:00<00:00, 43.26it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.96it/s]

T Loss: 0.1538 V Loss: 0.2374 V MCRMSE: 0.2395
Epoch 98


100%|██████████| 108/108 [00:06<00:00, 15.54it/s]
100%|██████████| 21/21 [00:00<00:00, 41.87it/s]
  2%|▏         | 2/108 [00:00<00:06, 15.45it/s]

T Loss: 0.1533 V Loss: 0.2370 V MCRMSE: 0.2384
Epoch 99


100%|██████████| 108/108 [00:07<00:00, 15.36it/s]
100%|██████████| 21/21 [00:00<00:00, 42.15it/s]

T Loss: 0.1528 V Loss: 0.2366 V MCRMSE: 0.2381





In [None]:
# Load the training records (one JSON object per line) and build a boolean mask
# of the rows whose SN_filter flag equals 1.
# NOTE(review): this path differs from `train_file` in the config cell — confirm which is intended.
train = pd.read_json("data/train.json", lines=True)
cond = train["SN_filter"].eq(1)

In [None]:
# tr1: rows passing the SN_filter mask; tr2: the remaining rows.
tr1, tr2 = train[cond], train[~cond]

In [None]:
# Stack each per-row error list into a 2-D array (one row per sequence) for
# reactivity, deg_Mg_50C and deg_Mg_pH10 respectively.
e1, e2, e3 = (
    np.vstack(tr1[column].to_numpy())
    for column in ("reactivity_error", "deg_error_Mg_50C", "deg_error_Mg_pH10")
)

In [None]:
# Distribution of every individual reactivity error value.
pd.Series(np.ravel(e1)).describe()

In [None]:
# Distribution of every individual deg_Mg_50C error value.
pd.Series(np.ravel(e2)).describe()

In [None]:
# Distribution of every individual deg_Mg_pH10 error value.
pd.Series(np.ravel(e3)).describe()

In [None]:
# Number of reactivity error entries above 0.5.
np.count_nonzero(e1 > 0.5)

In [None]:
# Per-sequence mean reactivity error.
pd.Series(e1.mean(axis=1)).describe()

In [None]:
# Per-sequence mean deg_Mg_50C error.
pd.Series(e2.mean(axis=1)).describe()

In [None]:
# Per-sequence mean deg_Mg_pH10 error.
pd.Series(e3.mean(axis=1)).describe()

In [84]:
# Combine the three 2-D error matrices along a new trailing axis.
e = np.stack((e1, e2, e3), axis=-1)
e.shape

(1589, 68, 3)

In [85]:
# Per-sequence error averaged over positions (axis 1) and then over the three targets.
pd.Series(e.mean(axis=1).mean(axis=1)).describe()

count    1589.000000
mean        0.103780
std         0.065435
min         0.024189
25%         0.063739
50%         0.083505
75%         0.118433
max         0.598469
dtype: float64

In [86]:
# Signal-to-noise summary for the rows that pass SN_filter.
snr = tr1.loc[:, "signal_to_noise"]
snr.describe()

count    1589.000000
mean        5.402215
std         2.524798
min         0.993000
25%         3.484000
50%         5.222000
75%         6.853000
max        17.194000
Name: signal_to_noise, dtype: float64

In [87]:
# Same per-sequence mean-error summary, now for the rows that fail SN_filter.
e1, e2, e3 = (
    np.vstack(tr2[column].to_numpy())
    for column in ("reactivity_error", "deg_error_Mg_50C", "deg_error_Mg_pH10")
)
e = np.stack((e1, e2, e3), axis=-1)
pd.Series(e.mean(axis=1).mean(axis=1)).describe()

count       811.000000
mean       8630.316650
std       29131.240267
min           0.034662
25%           0.088635
50%           0.195730
75%           0.641706
max      140637.240300
dtype: float64

In [89]:
# Over the full training set: how many sequences have a mean error above 0.6.
e1, e2, e3 = (
    np.vstack(train[column].to_numpy())
    for column in ("reactivity_error", "deg_error_Mg_50C", "deg_error_Mg_pH10")
)
e = np.stack((e1, e2, e3), axis=-1)
np.count_nonzero(e.mean(axis=1).mean(axis=1) > 0.6)

211

In [20]:
# NOTE(review): `tr` is not defined anywhere in the visible notebook (only `train`,
# `tr1`, `tr2` are); this relies on leftover kernel state — confirm its definition.
tr.iloc[0, :]

index                                                                  2
id                                                          id_006f36f57
sequence               GGAAAGUGCUCAGAUAAGCUAAGCUCGAAUAGCAAUCGAAUAGAAU...
structure              .....((((.((.....((((.(((.....)))..((((......)...
predicted_loop_type    EEEEESSSSISSIIIIISSSSMSSSHHHHHSSSMMSSSSHHHHHHS...
signal_to_noise                                                      8.8
SN_filter                                                              1
seq_length                                                           107
seq_scored                                                            68
reactivity_error       [0.0931, 0.13290000000000002, 0.11280000000000...
deg_error_Mg_pH10      [0.1365, 0.2237, 0.1812, 0.1333, 0.1148, 0.160...
deg_error_pH10         [0.17020000000000002, 0.178, 0.111, 0.091, 0.0...
deg_error_Mg_50C       [0.1033, 0.1464, 0.1126, 0.09620000000000001, ...
deg_error_50C          [0.14980000000000002, 0.1761

In [87]:
import subprocess

# Run bpRNA on one example and read back its .st annotation file.
# NOTE(review): `tr` is not defined in the visible notebook; confirm which frame it is.
row = 191
seq_id = tr.id.iloc[row]
seq = tr.sequence.iloc[row]
struct = tr.structure.iloc[row]

# Write the two-line .dbn input (sequence, then dot-bracket structure) from Python
# rather than interpolating the strings into `echo` commands run with shell=True.
with open(f"{seq_id}.dbn", "w") as dbn:
    dbn.write(f"{seq}\n{struct}\n")

# check=True raises if bpRNA fails, instead of silently reading a stale/missing .st file.
subprocess.run(["perl", "bpRNA/bpRNA.pl", f"{seq_id}.dbn"], check=True)
with open(f"{seq_id}.st") as stf:
    result = [line.strip("\n") for line in stf]
result

['#Name: id_1bb2f1786',
 '#Length:  107 ',
 '#PageNumber: 1',
 'GGAAAUUACAAGACCCGGGCCGAGGUGAAGUUCGAGGGCGACACCUUGGUGAACCGGAUCGAGUUAAAUCAGGUCUUCGGAUUUGAAAAAGAAACAACAACAACAAC',
 '.......((..((.(((((((((((((..(((....)))..)))))))))...)))).))..))....(((((((....))))))).....................',
 'EEEEEEESSIISSISSSSSSSSSSSSSIISSSHHHHSSSIISSSSSSSSSBBBSSSSISSIISSXXXXSSSSSSSHHHHSSSSSSSEEEEEEEEEEEEEEEEEEEEE',
 'NNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNNN',
 'S1 8..9 "AC" 63..64 "GU"',
 'S2 12..13 "GA" 59..60 "UC"',
 'S3 15..18 "CCGG" 54..57 "CCGG"',
 'S4 19..27 "GCCGAGGUG" 42..50 "CACCUUGGU"',
 'S5 30..32 "GUU" 37..39 "GGC"',
 'S6 69..75 "UCAGGUC" 80..86 "GAUUUGA"',
 'H1 33..36 "CGAG" (32,37) U:G ',
 'H2 76..79 "UUCG" (75,80) C:G ',
 'B1 51..53 "GAA" (50,19) U:G (54,18) C:G ',
 'I1.1 10..11 "AA" (9,63) C:G ',
 'I1.2 61..62 "GA" (60,12) C:G ',
 'I2.1 14..14 "C" (13,59) A:U ',
 'I2.2 58..58 "A" (57,15) G:C ',
 'I3.1 28..29 "AA" (27,42) G:C

In [93]:
# 0-based indices 50..52 (1-based positions 51..53).
# NOTE(review): `pl_segs` is produced by a later cell (In[90]); this cell only works
# after that one has run.
[pl_segs[i] for i in range(50, 53)]

[1, 1, 1]

In [89]:
def map_segs(result, seq_len=107, header_lines=7):
    """Project bpRNA ``.st`` annotation lines onto per-position lists.

    Args:
        result: lines of a bpRNA ``.st`` file with trailing newlines removed.
        seq_len: length of the returned lists. Ranges in the file are 1-based
            and inclusive ("a..b"), and are written to indices ``a-1 .. b-1``.
        header_lines: number of leading lines skipped before element lines
            (default 7: the three '#' lines, sequence, structure, loop-type
            string and the following annotation string).

    Returns:
        (seg_num, seg_pairs, pl_segs):
          seg_num   -- number from the ``segmentN`` line covering each position (-1 if none)
          seg_pairs -- ``<k>bp`` pair count of that segment (0 if none)
          pl_segs   -- index of the S/H/B/I/M/X/E element covering each position (0 if none)
    """
    seg_num = [-1] * seq_len
    seg_pairs = [0] * seq_len
    pl_segs = [0] * seq_len

    def _fill(values, span, label):
        # Write `label` over the inclusive 1-based range given as text "a..b".
        start, _, end = span.split(".")
        for pos in range(int(start), int(end) + 1):
            values[pos - 1] = label

    for line in result[header_lines:]:
        if not line.strip():
            # Blank lines carry no annotation.
            continue
        tokens = line.split(" ")
        if line.startswith("segment"):
            # "segmentN <k>bp ...": tokens[2] and tokens[4] hold the two strand ranges.
            num = int(tokens[0].strip("segment"))
            num_pairs = int(tokens[1].strip("bp"))
            for span in (tokens[2], tokens[4]):
                _fill(seg_num, span, num)
                _fill(seg_pairs, span, num_pairs)
        elif line.startswith("S"):
            # Stem: tokens[1] is the 5' strand range, tokens[3] the 3' strand range.
            num = int(tokens[0].strip("S"))
            _fill(pl_segs, tokens[1], num)
            _fill(pl_segs, tokens[3], num)
        else:
            # Loop element, e.g. "H1 33..36 ..." or "I1.2 61..62 ...".
            # NOTE(review): n*k + n - 1 is not injective for "n.k" labels
            # (I2.2 and I3.1 both map to 5); confirm the intended numbering.
            parts = tokens[0].strip("BSHIMXE").split(".")
            if len(parts) > 1:
                num = int(parts[0]) * int(parts[1]) + int(parts[0]) - 1
            else:
                num = int(parts[0])
            _fill(pl_segs, tokens[1], num)

    return seg_num, seg_pairs, pl_segs
            

In [90]:
# Per-position segment / pair-count / element indices for the bpRNA result above.
seg_num, seg_pairs, pl_segs = map_segs(result, seq_len=107)


In [35]:
def get_segments(struc, ploop=None):
    """Collect base pairs and groups of unpaired positions from a dot-bracket string.

    Args:
        struc: dot-bracket string; '(' opens a pair, ')' closes the most recently
            opened one, any other character is treated as unpaired.
        ploop: currently unused; optional so callers may pass only ``struc``.

    Returns:
        (pairs, loops):
          pairs -- list of (i, j) 1-based positions, in the order pairs close.
          loops -- list of groups of 1-based unpaired positions, each listed in
                   reverse order. A group is emitted at every ')' (all unpaired
                   positions seen since the previous group, including any before
                   an intervening '('), plus one final group for positions after
                   the last ')'.
    """
    pairs = []
    loops = []
    open_stack = []
    unpaired = []

    def _flush():
        # Move the pending unpaired positions (newest first) into `loops`.
        if unpaired:
            loops.append(unpaired[::-1])
            unpaired.clear()

    for pos, ch in enumerate(struc, start=1):
        if ch == "(":
            open_stack.append(pos)
        elif ch == ")":
            pairs.append((open_stack.pop(), pos))
            _flush()
        else:
            unpaired.append(pos)
    # Emit any unpaired positions that follow the last ')'.
    _flush()
    return pairs, loops

In [38]:
# get_segments is defined with a second (loop-type) argument; pass it so this
# cell runs on a fresh kernel.
pp, ll = get_segments(tr.structure.iloc[0], tr.predicted_loop_type.iloc[0])
pp

[(25, 31),
 (24, 32),
 (23, 33),
 (39, 46),
 (38, 47),
 (37, 48),
 (36, 49),
 (21, 52),
 (20, 53),
 (19, 54),
 (18, 55),
 (12, 61),
 (11, 62),
 (9, 64),
 (8, 65),
 (7, 66),
 (6, 67),
 (75, 80),
 (74, 81),
 (73, 82),
 (72, 83),
 (71, 84),
 (70, 85),
 (69, 86)]

In [49]:
def get_bpseg(pp, max_gap=2):
    """Assign a stem (helix segment) number to every paired position.

    Args:
        pp: list of (i, j) base pairs in closing order, e.g. as returned by
            get_segments (the innermost pair of a stem comes first).
        max_gap: largest step allowed in i (downwards) and in j (upwards)
            between consecutive pairs for them to stay in the same segment;
            the default 2 tolerates one unpaired base on either side.

    Returns:
        dict mapping each paired position to its segment number, starting at 1.
        An empty ``pp`` yields an empty dict.
    """
    bpcnt_map = {}
    if not pp:
        return bpcnt_map

    seg = 1
    prev_i, prev_j = pp[0]
    bpcnt_map[prev_i] = seg
    bpcnt_map[prev_j] = seg
    for i, j in pp[1:]:
        # Continue the current stem only if the new pair encloses the previous
        # one (i moves left, j moves right) and both steps are within max_gap;
        # a sibling helix (i to the right) always starts a new segment.
        encloses = i < prev_i and j > prev_j
        if not encloses or (prev_i - i) > max_gap or (j - prev_j) > max_gap:
            seg += 1
        bpcnt_map[i] = seg
        bpcnt_map[j] = seg
        prev_i, prev_j = i, j
    return bpcnt_map

In [50]:
# Segment number for each paired position of the first example.
get_bpseg(pp=pp)

{25: 1,
 31: 1,
 24: 1,
 32: 1,
 23: 1,
 33: 1,
 39: 2,
 46: 2,
 38: 2,
 47: 2,
 37: 2,
 48: 2,
 36: 2,
 49: 2,
 21: 3,
 52: 3,
 20: 3,
 53: 3,
 19: 3,
 54: 3,
 18: 3,
 55: 3,
 12: 4,
 61: 4,
 11: 4,
 62: 4,
 9: 4,
 64: 4,
 8: 4,
 65: 4,
 7: 4,
 66: 4,
 6: 4,
 67: 4,
 75: 5,
 80: 5,
 74: 5,
 81: 5,
 73: 5,
 82: 5,
 72: 5,
 83: 5,
 71: 5,
 84: 5,
 70: 5,
 85: 5,
 69: 5,
 86: 5}

In [22]:
# Value of log(1e-4). The original line `torch.clamp(np.log(1e-4)` had an unbalanced
# parenthesis (SyntaxError); the displayed output is that of np.log(1e-4).
np.log(1e-4)

-9.210340371976182