# Transformer-based cmsiRpred

2-uni_v3.3_transformer_0829-Copy1

In [1]:
import time
import copy

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from torchinfo import summary
from torch.utils.data import Dataset, DataLoader, Subset
import torch.optim as optim
#from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import precision_score, recall_score, mean_absolute_error
from sklearn.model_selection import KFold
import forgi.graph.bulge_graph as fgb

BATCH_SIZE = 256

## Load data

In [2]:
df_structured_encoded = pd.read_pickle('/home/ken/MyStorage/siRNA_2503/Data/df_structured_encoded_0326.pkl')

df_structured_encoded_iid_trvl = df_structured_encoded[df_structured_encoded['dataset_usage']=='IID_trvl'].sample(frac=1)
df_structured_encoded_iid_test = df_structured_encoded[df_structured_encoded['dataset_usage']=='IID_test']
df_structured_encoded_ood_test = df_structured_encoded[df_structured_encoded['dataset_usage']=='OOD_test']
print(df_structured_encoded_iid_trvl.shape,df_structured_encoded_iid_test.shape,df_structured_encoded_ood_test.shape)

(20626, 165) (2568, 165) (2588, 165)


## Model design

In [3]:
structure_map = {
    's': 0, # stem
    'h': 1, # hairpin loop
    'i': 2, # interior loop
    'm': 3, # multiloop
    'f': 4, # fiveprime
    't': 5, # threeprime
    'P': 7 # no nt here (pad)
}
nt_map = {
    'P': 7, # no nt here (pad)
    'A': 0,
    'U': 1,
    'T': 2,
    'G': 3,
    'C': 4,
    'S': 5
}
modi_map = {
    'P': 7
    # 0～6 see vocab.csv
}


def get_nt_strtype_vec(dot_bracket,SEQ_MAX_LEN=28):
    import forgi.graph.bulge_graph as fgb
    bg = fgb.BulgeGraph.from_dotbracket(dot_bracket)
    elements_strcode = bg.to_element_string()
    elements_strcode += 'P'*(SEQ_MAX_LEN-len(elements_strcode))
    map_dict = {'P':7,'s':0,'h':1,'i':2,'m':3,'f':4,'t':5}
    elements_numcode = list(map(lambda x:map_dict[x],list(elements_strcode)))
    return torch.tensor(elements_numcode)
    
class siRNA_dataset_trsfmr(Dataset):
    def __init__(self, df_structured_encoded):
        label_tensor = torch.tensor(list(df_structured_encoded['mRNA_remaining_pct']))
        self.label_tensor = label_tensor.reshape([len(label_tensor),1]).to(torch.float32)
        self.seq_sense_index = self._pad_to_equal_length(df_structured_encoded['seq_agct_int_sense'])
        self.seq_antis_index = self._pad_to_equal_length(df_structured_encoded['seq_agct_int_anti'])
        self.modi_sense_index = self._pad_to_equal_length(df_structured_encoded['seq_modi_int_sense'])
        self.modi_antis_index = self._pad_to_equal_length(df_structured_encoded['seq_modi_int_anti'])
        self.struct_sense_index = torch.stack(list(df_structured_encoded['dp_mfe_sense'].apply(get_nt_strtype_vec)))
        self.struct_sense_index = torch.flip(self.struct_sense_index,dims=[1])
        self.struct_antis_index = torch.stack(list(df_structured_encoded['dp_mfe_antis'].apply(get_nt_strtype_vec)))
        df_tabular_encoded = df_structured_encoded.loc[:,df_structured_encoded.columns.str.contains(r'!\w+!')]
        self.features_tensor = torch.tensor(df_tabular_encoded.values).to(torch.float32)
        self.domain_label_A = np.array(df_structured_encoded['publication_id'])
    def __getitem__(self,index):
        return (self.seq_sense_index[index],self.seq_antis_index[index],
                self.modi_sense_index[index],self.modi_antis_index[index],
                self.struct_sense_index[index],self.struct_antis_index[index],
                self.features_tensor[index],self.label_tensor[index],self.domain_label_A[index])
    def __len__(self):
        return self.label_tensor.size(0)
    def _pad_to_equal_length(self,series, pad_value=7, length=28):
        num_list = series.str.split('').str[1:-1]
        padded = num_list.apply(lambda x: x + [str(pad_value)] * (length - len(x)))
        matrix = np.array(padded.tolist(), dtype=int)
        return torch.tensor(matrix)

def get_nt_strtype_vec(dot_bracket,SEQ_MAX_LEN=28):
    import forgi.graph.bulge_graph as fgb
    bg = fgb.BulgeGraph.from_dotbracket(dot_bracket)
    elements_strcode = bg.to_element_string()
    elements_strcode += 'P'*(SEQ_MAX_LEN-len(elements_strcode))
    map_dict = {'P':7,'s':0,'h':1,'i':2,'m':3,'f':4,'t':5}
    elements_numcode = list(map(lambda x:map_dict[x],list(elements_strcode)))
    return torch.tensor(elements_numcode)
    
class siRNA_dataset_for_trsfmr(Dataset):
    def __init__(self, df_encoded,df_struct):
        assert False not in (df_encoded.index == df_struct.index)
        label_tensor = torch.tensor(list(df_struct['mRNA_remaining_pct']))
        self.label_tensor = label_tensor.reshape([len(label_tensor),1]).to(torch.float32)
        self.seq_sense_index = self._pad_to_equal_length(df_struct['seq_agct_int_sense'])
        self.seq_antis_index = self._pad_to_equal_length(df_struct['seq_agct_int_anti'])
        self.modi_sense_index = self._pad_to_equal_length(df_struct['seq_modi_int_sense'])
        self.modi_antis_index = self._pad_to_equal_length(df_struct['seq_modi_int_anti'])
        self.struct_sense_index = torch.stack(list(df_struct['!!dp_mfe_sense'].apply(get_nt_strtype_vec)))
        self.struct_antis_index = torch.stack(list(df_struct['!!dp_mfe_antis'].apply(get_nt_strtype_vec)))
        self.features_tensor = torch.tensor(df_encoded.values).to(torch.float32)
        self.domain_label_A = np.array(df_struct['publication_id'])
    def __getitem__(self,index):
        return (self.seq_sense_index[index],self.seq_antis_index[index],
                self.modi_sense_index[index],self.modi_antis_index[index],
                self.struct_sense_index[index],self.struct_antis_index[index],
                self.features_tensor[index],self.label_tensor[index],self.domain_label_A[index])
    def __len__(self):
        return self.label_tensor.size(0)
    def _pad_to_equal_length(self,series, pad_value=7, length=28):
        num_list = series.str.split('').str[1:-1]
        padded = num_list.apply(lambda x: x + [str(pad_value)] * (length - len(x)))
        matrix = np.array(padded.tolist(), dtype=int)
        return torch.tensor(matrix)


import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):
    def __init__(self, d_model=48, max_len=28):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0)) # (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

In [4]:
class siRNATransformer(nn.Module):
    def __init__(self, embed_dim=32, d_model=32, nhead=16, num_layers=2, max_len=28, dropout=0.1, ablation=False):
        super().__init__()
        print('ablation:', ablation)
        self.ablation = ablation
        self.pad_idx = 7
        self.seq_embed    = nn.Embedding(8, embed_dim, padding_idx=self.pad_idx)
        self.modi_embed   = nn.Embedding(8, embed_dim, padding_idx=self.pad_idx)
        self.struct_embed = nn.Embedding(8, embed_dim, padding_idx=self.pad_idx)
        self.gate = nn.Sequential(
            nn.Linear(embed_dim * 3, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, 3)
        )

        self.combine_fc = nn.Linear(embed_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model, max_len)
        self.dropout = nn.Dropout(dropout)

        ff_dim = max(4 * d_model, 128)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 16)
        )

        assert d_model % nhead == 0, f"d_model({d_model}) must be divisible by nhead({nhead})"
        nn.init.xavier_uniform_(self.combine_fc.weight)
        nn.init.zeros_(self.combine_fc.bias)
        for m in self.head:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, seq, modi, struct):
        seq_emb    = self.seq_embed(seq)
        modi_emb   = self.modi_embed(modi)
        struct_emb = self.struct_embed(struct)
        emb_cat = torch.cat([seq_emb, modi_emb, struct_emb], dim=-1)
        gate_w = self.gate(emb_cat).sigmoid()
        if self.ablation:
            fused = gate_w[..., 0:1] * seq_emb + gate_w[..., 1:2] * modi_emb
        else:
            fused = (gate_w[..., 0:1] * seq_emb +
                     gate_w[..., 1:2] * modi_emb +
                     gate_w[..., 2:3] * struct_emb)
        features = self.combine_fc(fused)
        features = self.pos_encoder(features)
        features = self.dropout(features)
        pad_mask = (seq == self.pad_idx)
        enc = self.transformer(features, src_key_padding_mask=pad_mask)
        valid = (~pad_mask).unsqueeze(-1).float()
        enc = enc * valid
        denom = valid.sum(dim=1).clamp_min(1e-6)
        pooled = enc.sum(dim=1) / denom
        logits = self.head(pooled)
        return logits

In [5]:
class TfxMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(109,256)
        self.actv1 = nn.ReLU()
        self.dense2 = nn.Linear(256,128)
        self.actv2 = nn.ReLU()
        self.dense3 = nn.Linear(128,16)
        self.actv3 = nn.ReLU()
    
    def forward(self,x):    
        x = self.dense1(x)
        x = self.actv1(x)
        x = self.dense2(x)
        x = self.actv2(x)
        x = self.dense3(x)
        x = self.actv3(x)
        return x.view(x.size(0),-1)

#summary(TfxMLP(), input_size=(BATCH_SIZE, 109))

class CombineMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense1 = nn.Linear(48,128)
        self.actv1 = nn.ReLU()
        self.dense2 = nn.Linear(128,128)
        self.actv2 = nn.ReLU()
        self.dense3 = nn.Linear(128,64)
        self.actv3 = nn.ReLU()
        self.actv_linear = nn.Linear(64,1)
    def forward(self,x):
        x = self.dense1(x)
        x = self.actv1(x)
        x = self.dense2(x)
        x = self.actv2(x)
        x = self.dense3(x)
        x = self.actv3(x)
        x = F.dropout(x,p=0.5,training=self.training)
        x = self.actv_linear(x)
        return x.view(x.size(0),-1)

#summary(CombineMLP(), input_size=(BATCH_SIZE, 48))

class modisiR_transformer(nn.Module):
    def __init__(self,ablation):
        super().__init__()
        self.transformer_sense = siRNATransformer(ablation=ablation)
        self.transformer_antis = siRNATransformer(ablation=ablation)
        self.tfx_mlp = TfxMLP()
        self.combine_mlp = CombineMLP()
    def forward(self,seq_sense,seq_antis,modi_sense,modi_antis,struct_sense,struct_antis,x_tfx):
        modiseqstr_sense = self.transformer_sense(seq_sense,modi_sense,struct_sense)
        modiseqstr_antis = self.transformer_antis(seq_antis,modi_antis,struct_antis)
        tfx_embed = self.tfx_mlp(x_tfx)
        x_combine = torch.cat([modiseqstr_sense,modiseqstr_antis,tfx_embed],axis=1)
        y_pred = self.combine_mlp(x_combine)
        return y_pred.reshape([len(y_pred),1])

#summary(modisiR_3dConv(), input_size=((256, 2,28,6,7),(256,1,28,7),(256,109)))

def train_ERM(dataload_TRAIN,model,optimizer,criterion):
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(DEVICE)
    model.train()
    loss_train = 0
    
    for seq_sense_index, seq_antis_index, modi_sense_index, modi_antis_index, struct_sense_index, struct_antis_index, features_tensor, label_tensor, domain_label_A in dataload_TRAIN:
        seq_sense_index = seq_sense_index.to(DEVICE)
        seq_antis_index = seq_antis_index.to(DEVICE)        
        modi_sense_index = modi_sense_index.to(DEVICE)
        modi_antis_index = modi_antis_index.to(DEVICE)
        struct_sense_index = struct_sense_index.to(DEVICE)
        struct_antis_index = struct_antis_index.to(DEVICE)
        features_tensor = features_tensor.to(DEVICE)
        y_batch_lbl = label_tensor.to(DEVICE)
        
        y_batch_pred = model(seq_sense_index, seq_antis_index, modi_sense_index, modi_antis_index, 
                             struct_sense_index, struct_antis_index, features_tensor)
        loss_batch = criterion(y_batch_pred,y_batch_lbl)
        optimizer.zero_grad()
        loss_batch.backward()
        optimizer.step()
        
        loss_train += loss_batch.item()
    return loss_train/len(dataload_TRAIN)

def train_VREX(dataload_TRAIN,env_list,beta,model,optimizer,criterion):
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(DEVICE)
    model.train()
    loss_train_total = 0
    
    for seq_sense_index, seq_antis_index, modi_sense_index, modi_antis_index, struct_sense_index, struct_antis_index, features_tensor, label_tensor, domain_label_A in dataload_TRAIN:
        risks = []
        x_domain_label_A = np.array(domain_label_A)
        for env in env_list:
            env_mask = (x_domain_label_A == env)
            if True not in env_mask: continue
            x_seq_sense_index = seq_sense_index[env_mask].to(DEVICE)
            x_seq_antis_index = seq_antis_index[env_mask].to(DEVICE)        
            x_modi_sense_index = modi_sense_index[env_mask].to(DEVICE)
            x_modi_antis_index = modi_antis_index[env_mask].to(DEVICE)
            x_struct_sense_index = struct_sense_index[env_mask].to(DEVICE)
            x_struct_antis_index = struct_antis_index[env_mask].to(DEVICE)
            x_features_tensor = features_tensor[env_mask].to(DEVICE)
            y_batch_lbl = label_tensor[env_mask].to(DEVICE)
            y_batch_pred = model(x_seq_sense_index, x_seq_antis_index, x_modi_sense_index, x_modi_antis_index, 
                                 x_struct_sense_index, x_struct_antis_index, x_features_tensor)
            risks.append(criterion(y_batch_pred,y_batch_lbl))
        
        risks = torch.stack(risks)
        risks_mean = torch.mean(risks)
        risks_var = torch.var(risks)
        
        loss_batch = risks_mean + beta * risks_var
        
        optimizer.zero_grad()
        loss_batch.backward()
        optimizer.step()
        
        loss_train_total += risks_mean.item()
    return loss_train_total/len(dataload_TRAIN)

def calculate_metrics(y_pred, y_true, threshold=30):
    import warnings
    warnings.simplefilter("ignore")
    
    y_true = y_true.clip(0,100)
    y_pred = y_pred.clip(0,100)
    
    mae = np.mean(np.abs(y_true - y_pred))

    y_true_binary = (y_true < threshold).astype(int)
    y_pred_binary = (y_pred < threshold).astype(int)
    
    mask = (y_pred >= 0) & (y_pred <= threshold)
    range_mae = mean_absolute_error(y_true[mask], y_pred[mask]) if mask.sum() > 0 else 100

    precision = precision_score(y_true_binary, y_pred_binary, average='binary')
    recall = recall_score(y_true_binary, y_pred_binary, average='binary')
    
    f1 = 2 * precision * recall / (precision + recall)

    score = (1 - mae / 100) * 0.5 + (1 - range_mae / 100) * f1 * 0.5
    
    warnings.filterwarnings("default")
    return score

def validate(dataload_VAL,model,criterion,threshold=30):
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(DEVICE)
    model.eval()
    loss_val = 0
    y_val_lbl = []
    y_val_pred = []
    with torch.no_grad():
        for seq_sense_index, seq_antis_index, modi_sense_index, modi_antis_index, struct_sense_index, struct_antis_index, features_tensor, label_tensor, domain_label_A in dataload_VAL:
            seq_sense_index = seq_sense_index.to(DEVICE)
            seq_antis_index = seq_antis_index.to(DEVICE)        
            modi_sense_index = modi_sense_index.to(DEVICE)
            modi_antis_index = modi_antis_index.to(DEVICE)
            struct_sense_index = struct_sense_index.to(DEVICE)
            struct_antis_index = struct_antis_index.to(DEVICE)
            features_tensor = features_tensor.to(DEVICE)
            y_batch_lbl = label_tensor.to(DEVICE)

            y_batch_pred = model(seq_sense_index, seq_antis_index, modi_sense_index, modi_antis_index, 
                                 struct_sense_index, struct_antis_index, features_tensor)
            loss_batch = criterion(y_batch_pred,y_batch_lbl)

            loss_val += loss_batch.item()
            y_val_lbl.extend(y_batch_lbl.cpu().numpy())
            y_val_pred.extend(y_batch_pred.cpu().numpy())
        
    y_val_pred = np.array(y_val_pred)
    y_val_lbl = np.array(y_val_lbl)
    model_score = calculate_metrics(y_val_pred, y_val_lbl,threshold)
    return loss_val/len(dataload_VAL),model_score

## Training

In [6]:
dataset_TEST_IID = siRNA_dataset_trsfmr(df_structured_encoded_iid_test)
dataset_TEST_OOD = siRNA_dataset_trsfmr(df_structured_encoded_ood_test)
dataset_TRVL = siRNA_dataset_trsfmr(df_structured_encoded_iid_trvl)

In [7]:
import pickle
import os

def cvmodel_transformer_test(model_list,dataset_TEST,ablation):
    model4test = modisiR_transformer(ablation)
    test_score_list = []
    for i in range(len(model_list)):
        model4test.load_state_dict(model_list[i])
        model4test.eval()
        y_pred_TEST = model4test(dataset_TEST.seq_sense_index, dataset_TEST.seq_antis_index, 
                                 dataset_TEST.modi_sense_index, dataset_TEST.modi_antis_index, 
                                 dataset_TEST.struct_sense_index, dataset_TEST.struct_antis_index, 
                                 dataset_TEST.features_tensor)
        test_score = calculate_metrics(y_pred_TEST.detach().numpy(),dataset_TEST.label_tensor.detach().numpy())
        print(test_score)
        test_score_list.append(test_score)
    return test_score_list

def cv_train(model_type:str,dataset_TRVL,dataset_TEST_IID,dataset_TEST_OOD,ablation,
             model_list,cv_log):
    lr = 0.002
    EPOCHS = 150
    BETA = 1
    print('beta:',BETA)
    BATCH_SIZE = 256

    early_stop_score = 1
    warm_up_epoch_num = 10
    loss_tolerance_epoch_num = 10
    env_list = df_structured_encoded_iid_trvl['publication_id'].unique() #####

    dataload_TEST_IID = DataLoader(dataset=dataset_TEST_IID,batch_size=BATCH_SIZE)
    dataload_TEST_OOD = DataLoader(dataset=dataset_TEST_OOD,batch_size=BATCH_SIZE)

    kfold = KFold(n_splits=5,shuffle=True)
    splits = kfold.split(dataset_TRVL)

    for train_index, val_index in splits:

        log_train = []

        start_time = time.time()

        dataset_TRAIN = Subset(dataset_TRVL,train_index)
        dataset_VAL = Subset(dataset_TRVL,val_index)

        dataload_TRAIN = DataLoader(dataset=dataset_TRAIN,batch_size=BATCH_SIZE)
        dataload_VAL = DataLoader(dataset=dataset_VAL,batch_size=BATCH_SIZE)

        lowest_loss_epoch = {'loss_val':float("inf"),'epoch':0}
        best_score = -float('inf')
        best_OOD = -float('inf')
        model = modisiR_transformer(ablation)
        optimizer = optim.AdamW([{'params':model.parameters(),'lr':lr}])
        criterion = nn.MSELoss(reduction='mean')

        for epoch in range(EPOCHS):
            start_time_epoch = time.time()
            if model_type == 'vrex':
                loss_train = train_VREX(dataload_TRAIN,env_list,BETA,model,optimizer,criterion)
            elif model_type == 'erm':
                loss_train = train_ERM(dataload_TRAIN,model,optimizer,criterion)
            elif model_type == 'stdrex':
                loss_train = train_stdREX(dataload_TRAIN,env_list,BETA,model,optimizer,criterion)
            else: print('No such model type. Type should be erm or vrex.')
            loss_val,model_score = validate(dataload_VAL,model,criterion)
            
            _,test_score_iid = validate(dataload_TEST_IID,model,criterion)
            _,test_score_ood = validate(dataload_TEST_OOD,model,criterion)

            if epoch > warm_up_epoch_num:
                if loss_val < lowest_loss_epoch['loss_val']:
                    lowest_loss_epoch['epoch'] = epoch
                    lowest_loss_epoch['loss_val'] = loss_val
                elif (epoch-lowest_loss_epoch['epoch']) >= loss_tolerance_epoch_num:
                    lowest_loss_epoch['epoch'] = epoch
                    for param_group in optimizer.param_groups:
                        param_group['lr'] *= 0.5
                        
            log_train.append((epoch,loss_train,loss_val,model_score,test_score_iid,test_score_ood,lr))
            
            if model_score > best_score:
                best_score = model_score
                best_OOD = test_score_ood
                best_model = copy.deepcopy(model.state_dict())
            
            print(f'{epoch}\t{model_score:.4f}\t{test_score_iid:.4f}\t{test_score_ood:.4f}\t({best_score:.4f},{best_OOD:.4f})',sep='',end='\n')
            
            if best_score >= early_stop_score:
                break
        model_list.append(best_model)
        cv_log.append(log_train)
        print('')
        
    df_cv_log = pd.DataFrame()
    for i in range(len(cv_log)):
        df_log = pd.DataFrame(cv_log[i])
        mindex = pd.MultiIndex.from_product([['Model_'+str(i)],['epoch','loss_train','loss_val','val_score','iid_score','ood_score','lr']])
        df_log.columns = mindex
        df_cv_log = pd.concat([df_cv_log,df_log],axis=1)
        
    return model_list,df_cv_log

def save_state_dict_2cpu(PATH_SAVE,model_list,model):
    os.mkdir(PATH_SAVE+'models')
    for i in range(len(model_list)):
        print(i,list(model_list[i].values())[0].device,end='\t')
        model.load_state_dict(model_list[i])
        model.to('cpu')
        print('to',list(model.state_dict().values())[0].device)
        torch.save(model.state_dict(), PATH_SAVE+'models/state_dict_cpu_'+str(i)+'.pth')

### train_ERM

In [8]:
ERM_model_list = []
ERM_cv_log = []
ERM_model_list, df_ERM_cv_log = cv_train('erm',dataset_TRVL,dataset_TEST_IID,dataset_TEST_OOD,ablation=False,
         model_list=ERM_model_list,cv_log=ERM_cv_log)

beta: 1
ablation: False
ablation: False




0	nan	nan	nan	(-inf,-inf)
1	0.5366	0.5209	0.6715	(0.5366,0.6715)
2	0.6066	0.6139	0.6528	(0.6066,0.6528)
3	0.6370	0.6464	0.6340	(0.6370,0.6340)
4	0.6302	0.6324	0.6234	(0.6370,0.6340)
5	0.6718	0.6772	0.6210	(0.6718,0.6210)
6	0.6970	0.7020	0.5943	(0.6970,0.5943)
7	0.6946	0.7048	0.5941	(0.6970,0.5943)
8	0.7165	0.7294	0.5725	(0.7165,0.5725)
9	0.7018	0.7119	0.5809	(0.7165,0.5725)
10	0.7119	0.7203	0.5733	(0.7165,0.5725)
11	0.7330	0.7429	0.5717	(0.7330,0.5717)
12	0.7178	0.7283	0.5668	(0.7330,0.5717)
13	0.7211	0.7289	0.5621	(0.7330,0.5717)
14	0.7016	0.7042	0.5935	(0.7330,0.5717)
15	0.7202	0.7330	0.5722	(0.7330,0.5717)
16	0.7126	0.7285	0.5771	(0.7330,0.5717)
17	0.7351	0.7476	0.5569	(0.7351,0.5569)
18	0.7203	0.7274	0.5813	(0.7351,0.5569)
19	0.7441	0.7506	0.5490	(0.7441,0.5490)
20	0.7389	0.7476	0.5636	(0.7441,0.5490)
21	0.7303	0.7402	0.5707	(0.7441,0.5490)
22	0.7342	0.7421	0.5660	(0.7441,0.5490)
23	0.7262	0.7361	0.5818	(0.7441,0.5490)
24	0.7341	0.7386	0.5835	(0.7441,0.5490)
25	0.7259	0.7431	0.5758



0	nan	nan	nan	(-inf,-inf)
1	0.5280	0.5284	0.6535	(0.5280,0.6535)
2	0.5694	0.5710	0.6263	(0.5694,0.6263)
3	0.6025	0.6091	0.6177	(0.6025,0.6177)
4	0.5707	0.5688	0.5808	(0.6025,0.6177)
5	0.6330	0.6412	0.6025	(0.6330,0.6025)
6	0.6902	0.6953	0.5656	(0.6902,0.5656)
7	0.6944	0.7002	0.5671	(0.6944,0.5671)
8	0.6991	0.7024	0.5578	(0.6991,0.5578)
9	0.7172	0.7171	0.5206	(0.7172,0.5206)
10	0.7165	0.7194	0.5195	(0.7172,0.5206)
11	0.7294	0.7308	0.5043	(0.7294,0.5043)
12	0.7106	0.7187	0.5229	(0.7294,0.5043)
13	0.7092	0.7162	0.5080	(0.7294,0.5043)
14	0.7396	0.7487	0.5124	(0.7396,0.5124)
15	0.7141	0.7204	0.5113	(0.7396,0.5124)
16	0.7390	0.7437	0.5049	(0.7396,0.5124)
17	0.7336	0.7425	0.5132	(0.7396,0.5124)
18	0.7261	0.7380	0.5110	(0.7396,0.5124)
19	0.7211	0.7342	0.5204	(0.7396,0.5124)
20	0.7392	0.7415	0.5155	(0.7396,0.5124)
21	0.7208	0.7322	0.5245	(0.7396,0.5124)
22	0.7349	0.7400	0.5155	(0.7396,0.5124)
23	0.7470	0.7547	0.5159	(0.7470,0.5159)
24	0.7249	0.7322	0.5228	(0.7470,0.5159)
25	0.7481	0.7524	0.5179



0	nan	nan	nan	(-inf,-inf)
1	0.5259	0.5128	0.6739	(0.5259,0.6739)
2	0.5334	0.5222	0.6271	(0.5334,0.6271)
3	0.6047	0.6072	0.6113	(0.6047,0.6113)
4	0.6619	0.6612	0.6092	(0.6619,0.6092)
5	0.6391	0.6433	0.5949	(0.6619,0.6092)
6	0.6799	0.6883	0.5625	(0.6799,0.5625)
7	0.6991	0.7012	0.5437	(0.6991,0.5437)
8	0.6931	0.6975	0.5405	(0.6991,0.5437)
9	0.7041	0.7109	0.5345	(0.7041,0.5345)
10	0.7053	0.7092	0.5341	(0.7053,0.5341)
11	0.7394	0.7383	0.5268	(0.7394,0.5268)
12	0.7260	0.7329	0.5303	(0.7394,0.5268)
13	0.7270	0.7321	0.5230	(0.7394,0.5268)
14	0.7356	0.7397	0.5261	(0.7394,0.5268)
15	0.7531	0.7556	0.5283	(0.7531,0.5283)
16	0.7365	0.7393	0.5334	(0.7531,0.5283)
17	0.7474	0.7482	0.5232	(0.7531,0.5283)
18	0.7541	0.7582	0.5348	(0.7541,0.5348)
19	0.7497	0.7482	0.5114	(0.7541,0.5348)
20	0.7547	0.7542	0.5219	(0.7547,0.5219)
21	0.7539	0.7562	0.5283	(0.7547,0.5219)
22	0.7629	0.7623	0.5412	(0.7629,0.5412)
23	0.7439	0.7411	0.5370	(0.7629,0.5412)
24	0.7594	0.7589	0.5271	(0.7629,0.5412)
25	0.7504	0.7524	0.5308



0	nan	nan	nan	(-inf,-inf)
1	0.5568	0.5379	0.6692	(0.5568,0.6692)
2	0.6198	0.6160	0.6348	(0.6198,0.6348)
3	0.5700	0.5543	0.6028	(0.6198,0.6348)
4	0.6771	0.6693	0.5902	(0.6771,0.5902)
5	0.7032	0.7012	0.5530	(0.7032,0.5530)
6	0.7040	0.6943	0.5479	(0.7040,0.5479)
7	0.7322	0.7266	0.5185	(0.7322,0.5185)
8	0.7240	0.7196	0.5158	(0.7322,0.5185)
9	0.7327	0.7291	0.5131	(0.7327,0.5131)
10	0.7228	0.7206	0.5122	(0.7327,0.5131)
11	0.7476	0.7382	0.5037	(0.7476,0.5037)
12	0.7489	0.7434	0.4898	(0.7489,0.4898)
13	0.7379	0.7369	0.5133	(0.7489,0.4898)
14	0.7226	0.7150	0.5004	(0.7489,0.4898)
15	0.7565	0.7488	0.4929	(0.7565,0.4929)
16	0.7407	0.7366	0.5089	(0.7565,0.4929)
17	0.7618	0.7563	0.4947	(0.7618,0.4947)
18	0.7610	0.7607	0.5075	(0.7618,0.4947)
19	0.7488	0.7433	0.5165	(0.7618,0.4947)
20	0.7612	0.7579	0.5083	(0.7618,0.4947)
21	0.7638	0.7641	0.5087	(0.7638,0.5087)
22	0.7589	0.7559	0.5134	(0.7638,0.5087)
23	0.7568	0.7574	0.5101	(0.7638,0.5087)
24	0.7649	0.7694	0.5119	(0.7649,0.5119)
25	0.7539	0.7598	0.5195



0	nan	nan	nan	(-inf,-inf)
1	0.5099	0.4910	0.6062	(0.5099,0.6062)
2	0.6140	0.6071	0.6524	(0.6140,0.6524)
3	0.6541	0.6534	0.6353	(0.6541,0.6353)
4	0.6565	0.6618	0.6065	(0.6565,0.6065)
5	0.6743	0.6807	0.5975	(0.6743,0.5975)
6	0.6916	0.7003	0.5805	(0.6916,0.5805)
7	0.6839	0.6897	0.5915	(0.6916,0.5805)
8	0.7150	0.7240	0.5521	(0.7150,0.5521)
9	0.7120	0.7258	0.5496	(0.7150,0.5521)
10	0.7002	0.7126	0.5563	(0.7150,0.5521)
11	0.7280	0.7414	0.5490	(0.7280,0.5490)
12	0.7239	0.7258	0.5612	(0.7280,0.5490)
13	0.7407	0.7356	0.5434	(0.7407,0.5434)
14	0.7414	0.7486	0.5470	(0.7414,0.5470)
15	0.7211	0.7224	0.5571	(0.7414,0.5470)
16	0.7411	0.7457	0.5560	(0.7414,0.5470)
17	0.7143	0.7191	0.5713	(0.7414,0.5470)
18	0.7393	0.7499	0.5715	(0.7414,0.5470)
19	0.7443	0.7493	0.5660	(0.7443,0.5660)
20	0.7444	0.7527	0.5610	(0.7444,0.5610)
21	0.7303	0.7326	0.5774	(0.7444,0.5610)
22	0.7412	0.7457	0.5803	(0.7444,0.5610)
23	0.7556	0.7468	0.5809	(0.7556,0.5809)
24	0.7614	0.7648	0.5671	(0.7614,0.5671)
25	0.7553	0.7653	0.5876

In [9]:
df_ERM_cv_log

Unnamed: 0_level_0,Model_0,Model_0,Model_0,Model_0,Model_0,Model_0,Model_0,Model_1,Model_1,Model_1,...,Model_3,Model_3,Model_3,Model_4,Model_4,Model_4,Model_4,Model_4,Model_4,Model_4
Unnamed: 0_level_1,epoch,loss_train,loss_val,val_score,iid_score,ood_score,lr,epoch,loss_train,loss_val,...,iid_score,ood_score,lr,epoch,loss_train,loss_val,val_score,iid_score,ood_score,lr
0,0,1982.007381,1216.681386,,,,0.002,0,1941.988552,1164.747638,...,,,0.002,0,2012.438048,1168.784162,,,,0.002
1,1,1172.369864,978.815002,0.536635,0.520937,0.671475,0.002,1,1161.032538,908.661000,...,0.537896,0.669210,0.002,1,1169.601975,958.636173,0.509854,0.491030,0.606196,0.002
2,2,1006.674239,817.379786,0.606575,0.613854,0.652836,0.002,2,997.469137,768.077525,...,0.615971,0.634848,0.002,2,1015.151989,825.510986,0.614026,0.607108,0.652443,0.002
3,3,884.811293,728.081320,0.636993,0.646404,0.633986,0.002,3,891.204818,711.054957,...,0.554342,0.602843,0.002,3,907.347135,738.595226,0.654108,0.653402,0.635329,0.002
4,4,837.888278,683.205311,0.630172,0.632385,0.623412,0.002,4,840.905144,697.576520,...,0.669277,0.590174,0.002,4,836.350348,703.811875,0.656503,0.661762,0.606507,0.002
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
145,145,359.264946,418.129362,0.803344,0.811849,0.616196,0.002,145,334.304927,371.659484,...,0.807911,0.523234,0.002,145,327.344837,368.234632,0.818323,0.817276,0.595380,0.002
146,146,357.075971,423.604783,0.806297,0.815588,0.608114,0.002,146,334.146345,373.104980,...,0.813104,0.520418,0.002,146,319.695579,373.385175,0.817833,0.816918,0.596803,0.002
147,147,350.030584,403.444564,0.810628,0.812476,0.619307,0.002,147,339.731209,377.728725,...,0.810764,0.513701,0.002,147,328.591776,372.564629,0.816053,0.815164,0.595449,0.002
148,148,357.384038,399.562432,0.807447,0.815142,0.615818,0.002,148,336.431475,373.615243,...,0.810681,0.538339,0.002,148,327.174681,371.118712,0.817568,0.820156,0.597170,0.002


### train_V-REX

In [None]:
VREX_model_list,VREX_cv_log = cv_train('vrex',dataset_TRVL,dataset_TEST_IID,dataset_TEST_OOD,ablation=False)