In [40]:
import torch
import torch.nn as nn
import sys
sys.path.append('/home/yurui/GDSC_2/code')
from code.model.drp_model.DRP_multi_conc import *
from code.loader.GDSC2_loader import *
import warnings
import argparse
import datetime
import torch.cuda.amp as amp
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from scipy.stats import pearsonr
import torch.optim as opt
from torch.utils.tensorboard import SummaryWriter
warnings.filterwarnings("ignore", category=UserWarning)
def arg_parse():
    """Return the model/training configuration as an ``argparse.Namespace``.

    The parser is evaluated against an empty argument list, so the process
    argv (e.g. a Jupyter kernel's own flags) is never consulted and every
    option simply takes its default value.

    The switches ``--JK``, ``--use_regulizer`` and ``--use_drug_path_way`` are
    plain strings; code elsewhere compares them against ``'True'``.
    """
    # (flag, type, default, help text) -- declaration order is preserved
    # because it determines the attribute order of the returned namespace.
    option_table = (
        ('--drug_embed_dim', int, 256, 'Embedding dimension for drug'),
        ('--cell_embed_dim', int, 128, 'Embedding dimension for cell'),
        ('--drug_layer_num', int, 2, 'Number of layers for drug'),
        ('--cell_layer_num', int, 2, 'Number of layers for cell'),
        ('--dropout_rate', float, 0.3, 'Dropout rate'),
        ('--readout', str, 'mean', 'Readout function'),
        ('--JK', str, 'True', 'JKNet option'),
        ('--view_dim', int, 256, None),
        ('--epochs', int, 50, None),
        ('--lr', float, 1e-3, 'Learning rate'),
        ('--batch_size', int, 1024, 'Batch size'),
        ('--weight_decay', float, 0.01, 'Weight decay'),
        ('--check_step', int, 5, 'Num of steps to check performance'),
        ('--use_regulizer', str, 'True', None),
        ('--regular_weight', float, 0.1, None),
        ('--use_drug_path_way', str, 'True', None),
        ('--regular_weight_drug_path_way', float, 0.1, None),
        ('--train_type', str, 'Mixed', None),
        ('--scheduler_type', str, 'OP', None),
        ('--device', int, 0, None),
        ('--early_stop_count', int, 7, None),
    )
    config_parser = argparse.ArgumentParser(description="Model Configuration")
    for flag, value_type, default_value, help_text in option_table:
        # help=None is argparse's own default, so options without a help
        # string are registered exactly as if ``help`` were omitted.
        config_parser.add_argument(flag, type=value_type, default=default_value,
                                   help=help_text)
    return config_parser.parse_args('')
def cross_entropy_loss(input, target):
    """Return the mean cross-entropy between ``input`` logits and ``target``.

    Thin wrapper around ``torch.nn.functional.cross_entropy`` with its default
    ``reduction='mean'``. The function is reached through the explicitly
    imported ``nn`` module: the bare name ``F`` used previously is never
    imported by this file and only resolved if a wildcard import happened to
    re-export it.

    Args:
        input: class logits accepted by ``cross_entropy``. The name shadows the
            builtin but is kept so keyword callers keep working.
        target: class-index (or class-probability) targets for ``input``.
    """
    return nn.functional.cross_entropy(input, target)
def total_loss(out_dict, batch_sample, args):
    """Combine the response-regression loss with the optional auxiliary losses.

    Regression term: the MSE of every column of ``out_dict['pred']`` against
    the same column of ``batch_sample['label']``, with column ``i`` (0-based)
    weighted by ``i + 1``, summed over all columns.

    Auxiliary terms, each switched on by a string flag equal to ``'True'``:
      * ``args.use_regulizer`` -- cross-entropy of ``out_dict['cell_regulizer']``
        against ``batch_sample['CL_type']``, scaled by ``args.regular_weight``;
      * ``args.use_drug_path_way`` -- cross-entropy of ``out_dict['drug_pathway']``
        against ``batch_sample['drug_atom_repr'].PATHWAY_TYPE``, scaled by
        ``args.regular_weight_drug_path_way``.
    A disabled term contributes ``0.0``.
    """
    cls_weight = args.regular_weight
    pathway_weight = args.regular_weight_drug_path_way
    criterion = nn.MSELoss()

    n_outputs = out_dict['pred'].shape[1]
    regression = sum(
        (col + 1) * criterion(out_dict['pred'][:, col], batch_sample['label'][:, col])
        for col in range(n_outputs)
    )

    cell_type_term = 0.0
    pathway_term = 0.0
    if args.use_regulizer == 'True':
        cell_type_term = cross_entropy_loss(out_dict['cell_regulizer'], batch_sample['CL_type'])
    if args.use_drug_path_way == 'True':
        pathway_term = cross_entropy_loss(
            out_dict['drug_pathway'], batch_sample['drug_atom_repr'].PATHWAY_TYPE)
    return regression + cls_weight * cell_type_term + pathway_weight * pathway_term
def train_step(model, train_loader, optimizer, writer, epoch, device, args, scaler=None):
    """Run one training epoch with automatic mixed precision and log metrics.

    Args:
        model: network called as ``model(batch_sample)``; its output dict must
            contain ``'pred'`` plus whatever ``total_loss`` reads.
        train_loader: iterable of dict batches whose values support ``.to(device)``.
        optimizer: optimizer stepped once per batch (through ``scaler``).
        writer: TensorBoard ``SummaryWriter`` receiving per-epoch scalars.
        epoch: step index used for every scalar written this epoch.
        device: device the batch tensors are moved to.
        args: configuration forwarded to ``total_loss``.
        scaler: optional ``GradScaler`` to reuse across epochs. A ``GradScaler``
            adapts its loss scale over time; creating a fresh one every epoch
            (what happens when this is ``None``, the previous behaviour)
            restarts from the initial scale, so the first optimizer steps of
            each epoch may be skipped on fp16 overflow. Pass one persistent
            scaler to keep the learned scale.

    Returns:
        (rmse_dict, pcc_dict): per-output-column RMSE and Pearson r of the
        training predictions collected during the epoch, keyed by column index.
    """
    if scaler is None:
        scaler = amp.GradScaler()

    model.train()
    y_true, preds = [], []
    loss_sum, n_batches = 0.0, 0
    optimizer.zero_grad()
    for data in tqdm(train_loader):
        batch_sample = {k: v.to(device) for k, v in data.items()}
        with amp.autocast():
            out_dict = model(batch_sample)
            loss = total_loss(out_dict, batch_sample, args)
        # Detach before storing: keeping the graph-attached prediction would
        # hold autograd history for every batch until the epoch ends.
        preds.append(out_dict['pred'].detach().float().cpu())
        y_true.append(batch_sample['label'].detach().cpu())
        loss_sum += loss.item()
        n_batches += 1
        # Backward pass and optimizer step go through the scaler so fp16
        # gradients are unscaled and overflowing steps are skipped.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    y_true = torch.cat(y_true, dim=0).numpy()
    y_pred = torch.cat(preds, dim=0).numpy()

    rmse_dict, pcc_dict = {}, {}
    for i in range(y_true.shape[1]):
        column_true = y_true[:, i]
        column_pred = y_pred[:, i]
        # sqrt(MSE) instead of ``squared=False``, which scikit-learn deprecated
        # in 1.4 and removed in 1.6.
        rmse = mean_squared_error(column_true, column_pred) ** 0.5
        pcc = pearsonr(column_true, column_pred)[0]
        r_2 = r2_score(column_true, column_pred)
        mae = mean_absolute_error(column_true, column_pred)
        rmse_dict[i] = rmse
        pcc_dict[i] = pcc
        writer.add_scalar(f"Accuracy/train/response_for_conc_{i+1}/rmse", rmse, epoch)
        writer.add_scalar(f"Accuracy/train/response_for_conc_{i+1}/mae", mae, epoch)
        writer.add_scalar(f"Accuracy/train/response_for_conc_{i+1}/pcc", pcc, epoch)
        writer.add_scalar(f"Accuracy/train/response_for_conc_{i+1}/r_2", r_2, epoch)
    # One "Loss" point per epoch: the mean optimisation loss. Previously this
    # tag was overwritten once per column with that column's RMSE.
    writer.add_scalar("Loss", loss_sum / max(n_batches, 1), epoch)
    print(optimizer.param_groups[0]['lr'])
    return rmse_dict, pcc_dict
@torch.no_grad()
def test_step(model, loader, device):
    """Evaluate ``model`` on ``loader`` without gradients and report metrics.

    Each batch is a dict whose values are moved to ``device``. The model is
    called as ``model(batch_sample)``, and ``out_dict['pred']`` is compared
    column by column with ``batch_sample['label']``. One summary line per
    column is printed.

    Returns:
        rmse_dict, pcc_dict, r2_dict, MAE_dict: per-column RMSE, Pearson r,
            R^2 and MAE, keyed by column index.
        test_rmse: 0-dim tensor holding the RMSE pooled over every element of
            the prediction matrix. This used to be the pooled MSE (no square
            root) despite its name. The square root is monotonic, so callers'
            "lower is better" comparisons rank checkpoints the same way.
    """
    model.eval()
    y_true, preds = [], []
    for data in tqdm(loader):
        batch_sample = {k: v.to(device) for k, v in data.items()}
        out_dict = model(batch_sample)
        # Move both tensors off the device per batch so device memory does not
        # grow with the size of the evaluation set.
        y_true.append(batch_sample['label'].cpu())
        preds.append(out_dict['pred'].float().cpu())
    y_true = torch.cat(y_true, dim=0)
    y_pred = torch.cat(preds, dim=0)
    test_rmse = torch.sqrt(nn.MSELoss()(y_pred, y_true))
    y_true = y_true.numpy()
    y_pred = y_pred.numpy()

    rmse_dict, pcc_dict, r2_dict, MAE_dict = {}, {}, {}, {}
    for i in range(y_true.shape[1]):
        column_true = y_true[:, i]
        column_pred = y_pred[:, i]
        # sqrt(MSE) instead of ``squared=False``, which scikit-learn deprecated
        # in 1.4 and removed in 1.6.
        rmse = mean_squared_error(column_true, column_pred) ** 0.5
        pcc = pearsonr(column_true, column_pred)[0]
        r_2 = r2_score(column_true, column_pred)
        MAE = mean_absolute_error(column_true, column_pred)
        rmse_dict[i] = rmse
        pcc_dict[i] = pcc
        r2_dict[i] = r_2
        MAE_dict[i] = MAE
        print(f'Test accuracy for dose {i}: RMSE: {rmse:.4f}, PCC: {pcc:.4f}, R2: {r_2:.4f}, MAE: {MAE:.4f}')
    return rmse_dict, pcc_dict, r2_dict, MAE_dict, test_rmse

def _build_model_config(args):
    """Translate the configuration namespace into the dict given to ``DRP_multi_view``."""
    return {
        'drug_embed_dim': args.drug_embed_dim,
        'cell_embed_dim': args.cell_embed_dim,
        'hidden_dim': args.cell_embed_dim,
        'drug_layer_num': args.drug_layer_num,  # layers of the drug branch
        'cell_layer_num': args.cell_layer_num,  # layers of the cell branch
        'dropout_rate': args.dropout_rate,
        'readout': args.readout,  # mean, max
        'JK': args.JK,  # 'True' / 'False', string value
        'view_dim': args.view_dim,
        'use_regulizer': args.use_regulizer,
        'use_drug_path_way': args.use_drug_path_way,
    }


def _make_scheduler(optimizer, args):
    """Create the LR scheduler selected by ``args.scheduler_type``.

    ``'OP'`` -> ReduceLROnPlateau, stepped with the validation error;
    ``'ML'`` -> MultiStepLR, stepped once per epoch; anything else -> ``None``.
    """
    if args.scheduler_type == 'OP':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', patience=7, verbose=True,
            min_lr=0.05 * args.lr, factor=0.1)
    if args.scheduler_type == 'ML':
        # NOTE(review): the milestone (80) is beyond the default 50 epochs,
        # so with default settings the LR is never decayed.
        return opt.lr_scheduler.MultiStepLR(optimizer, milestones=[80], gamma=0.1)
    return None


def train_multi_view_model(args, train_set, val_set, test_set):
    """Train ``DRP_multi_view`` with periodic validation and early stopping.

    Every epoch runs ``train_step``. Every ``args.check_step`` epochs the model
    is evaluated on ``val_set``; when the pooled validation error improves it
    is evaluated on ``test_set`` and a checkpoint (epoch, model and optimizer
    state) is saved to ``./TB_5_fold/best_model_<train_type>.pth``. Training
    stops after ``args.early_stop_count`` consecutive non-improving checks or
    after ``args.epochs`` epochs. The test metrics of the best checkpoint are
    written as TensorBoard hparams; the writer is closed even if training is
    interrupted.

    Args:
        args: configuration namespace as produced by ``arg_parse``.
        train_set, val_set, test_set: datasets usable with this file's
            ``collate_fn``.
    """
    import os  # explicit: the file otherwise only gets ``os`` via a wildcard import

    save_dir = 'best_model_' + args.train_type
    path = f'./TB_5_fold/{save_dir}.pth'
    device = torch.device(f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')
    batch_size = args.batch_size
    n_epochs = args.epochs

    model = DRP_multi_view(_build_model_config(args)).to(device)
    # Weight decay follows ``--weight_decay`` (default 0.01, the value that
    # used to be hard-coded here regardless of the argument).
    optimizer = opt.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    print('Begin Training')
    print(f'Embed_dim_drug : {args.drug_embed_dim}\n'
          f'Hidden_dim_cell : {args.cell_embed_dim} \n'
          f'drug_layer_num : {args.drug_layer_num} \n'
          f'read_out_function : {args.readout}\n'
          f'batch_size : {batch_size}\n'
          f'view_dim : {args.view_dim}\n'
          f'lr : {args.lr}\n'
          f'use_regulizer : {args.use_regulizer}\n'
          f'use_drug_path_way : {args.use_drug_path_way}')
    # ``comment`` is ignored by SummaryWriter when ``log_dir`` is given, so it
    # is not passed.
    tb = SummaryWriter(log_dir=f'./TB_5_fold/{save_dir}')
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    # Evaluation metrics do not depend on sample order; no need to shuffle.
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    os.makedirs(f"./model_results/{args.train_type}", exist_ok=True)
    scheduler = _make_scheduler(optimizer, args)

    epoch_len = len(str(n_epochs))
    early_stop_count = 0
    best_epoch = 0
    # +inf so the first finite validation score always counts as an improvement.
    best_val_rmse = float('inf')
    # Test metrics of the best checkpoint; stay None until validation improves.
    test_rmse = test_pcc = test_r_2 = test_mae = test_threshold_rmse = None

    try:
        for epoch in range(n_epochs):
            if early_stop_count >= args.early_stop_count:
                break
            train_rmse, train_pcc = train_step(model, train_loader, optimizer, tb, epoch, device, args)
            if args.scheduler_type == 'ML':
                scheduler.step()
            current_lr = optimizer.param_groups[0]['lr']
            for i in range(len(train_rmse)):
                print(f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] '
                      f'train_rmse for conc {i}: {train_rmse[i]:.5f} '
                      f'train_pcc for conc {i}: {train_pcc[i]:.5f} '
                      f'lr : {current_lr}')

            if epoch % args.check_step != 0:
                continue

            val_rmse, val_pcc, val_r_2, val_mae, val_threshold_rmse = test_step(model, val_loader, device)
            if args.scheduler_type == 'OP':
                scheduler.step(val_threshold_rmse)
            tb.add_scalar("LR", optimizer.param_groups[0]['lr'], epoch)
            for i in range(len(val_rmse)):
                tb.add_scalar(f'Accuracy/val/pcc_conc_{i}', val_pcc[i], epoch)
                tb.add_scalar(f"Accuracy/val/rmse_conc_{i}", val_rmse[i], epoch)
                tb.add_scalar(f"Accuracy/val/mae_conc_{i}", val_mae[i], epoch)
                tb.add_scalar(f"Accuracy/val/r_2_conc_{i}", val_r_2[i], epoch)
                print(f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] '
                      f'val_rmse for conc {i}: {val_rmse[i]:.5f} '
                      f'val_r_2 for conc {i}: {val_r_2[i]:.5f} '
                      f'val_mae for conc {i}: {val_mae[i]:.5f} '
                      f'val_pcc for conc {i}: {val_pcc[i]:.5f} '
                      f'lr : {current_lr}')

            # A NaN validation score compares False here and is treated as
            # "no improvement".
            if val_threshold_rmse < best_val_rmse:
                early_stop_count = 0
                best_val_rmse = val_threshold_rmse
                best_epoch = epoch
                test_rmse, test_pcc, test_r_2, test_mae, test_threshold_rmse = test_step(model, test_loader, device)
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, path)
            else:
                early_stop_count += 1
                print(f'Early stopping encounter : {early_stop_count}  times')
            if early_stop_count >= args.early_stop_count:
                print('Early stopping!')
                break

            print(f'Best epoch: {best_epoch:03d}')
            # Guard: without an improving check there are no test metrics yet
            # (previously a NameError).
            if test_pcc is not None:
                for i in range(len(test_pcc)):
                    print(f'Best PCC for conc {i}: {test_pcc[i]:.4f},'
                          f'Best RMSE for conc {i}: {test_rmse[i]:.4f}, '
                          f'Best R_2 for conc {i}: {test_r_2[i]:.4f}, '
                          f'Best MAE for conc {i}: {test_mae[i]:.4f}')

        print("__________________________________________________________")
        if test_pcc is None:
            print('No improving validation checkpoint; skipping hparams logging.')
            return

        hparams = {
            'use_regulizer': args.use_regulizer,
            'use_drug_path_way': args.use_drug_path_way,
            'best_epoch': best_epoch,
            'regular_weight': args.regular_weight,
            'regular_weight_drug_path_way': args.regular_weight_drug_path_way,
        }
        metrics = {
            'val_threshold_rmse': best_val_rmse,
            'test_threshold_rmse': test_threshold_rmse,
        }
        # Per-concentration test metrics of the best checkpoint.
        for i in range(len(test_pcc)):
            metrics[f'test_pcc_conc_{i}'] = test_pcc[i]
            metrics[f'test_rmse_conc_{i}'] = test_rmse[i]
            metrics[f'test_r2_conc_{i}'] = test_r_2[i]
            metrics[f'test_mae_conc_{i}'] = test_mae[i]
        tb.add_hparams(hparams, metrics)
    finally:
        # Flush and close the event files even on KeyboardInterrupt / errors.
        tb.close()


In [41]:

# Default configuration (``arg_parse`` parses an empty argv, so the notebook
# kernel's own arguments are ignored).
args = arg_parse()
# NOTE(review): this module-level ``device`` is not passed anywhere below;
# ``train_multi_view_model`` derives its own device from ``args.device``.
device = torch.device("cuda:"+str(args.device) if torch.cuda.is_available() else "cpu") 
# Cell-line feature tensors, cell-line type tensor, label tensor, the response
# table and drug / cell-line index mappings, all from ``preprocess_cell_data``
# (defined outside this file). Presumably mutation / chromatin / copy-number /
# gene-expression views -- inferred from the names only; confirm in GDSC2_loader.
mut_tensor, chr_tensor, cna_tensor, GE_tensor, CL_type_tensor, label_tensor, result_df, drug2index, CL2index = preprocess_cell_data()
# Drug atom and bond features; ``process_drug_feat`` receives the drug index mapping.
drug_atom_feat, drug_bond_feat = process_drug_feat(drug2index)
# Split the response table into train / validation / test index sets.
train_idx, val_idx, test_idx = train_val_test_split(result_df)
# Turn each split into the drug and cell-line indices passed to ``DRP_dataset``.
drug_train_idx, CL_train_idx = process_CL_drug(train_idx)
drug_val_idx, CL_val_idx = process_CL_drug(val_idx)
drug_test_idx, CL_test_idx = process_CL_drug(test_idx)
# The three datasets share the same feature and label tensors and differ only
# in their ``drug_idx`` / ``CL_idx`` arguments.
DRP_trainset = DRP_dataset(mut_tensor, chr_tensor, cna_tensor, GE_tensor, CL_type_tensor, drug_atom_feat, drug_bond_feat, result_tensor=label_tensor, drug_idx = drug_train_idx, CL_idx = CL_train_idx)
DRP_valset = DRP_dataset(mut_tensor, chr_tensor, cna_tensor, GE_tensor, CL_type_tensor, drug_atom_feat, drug_bond_feat, result_tensor=label_tensor, drug_idx = drug_val_idx, CL_idx = CL_val_idx)
DRP_testset = DRP_dataset(mut_tensor, chr_tensor, cna_tensor, GE_tensor, CL_type_tensor, drug_atom_feat, drug_bond_feat, result_tensor=label_tensor, drug_idx = drug_test_idx, CL_idx = CL_test_idx)


finish loading drug data!


In [42]:
# Launch training on the three splits built in the previous cell
# (positional order: args, train_set, val_set, test_set).
train_multi_view_model(args, DRP_trainset, DRP_valset, DRP_testset)
    

Begin Training
Embed_dim_drug : 256
Hidden_dim_cell : 128 
drug_layer_num : 2 
read_out_function : mean
batch_size : 1024
view_dim : 256
lr : 0.001
use_regulizer : True
use_drug_path_way : True


100%|██████████| 23/23 [00:06<00:00,  3.45it/s]


0.001
[ 0/50] train_rmse for conc 0: 0.81480 train_pcc for conc 0: -0.00819 lr : 0.001
[ 0/50] train_rmse for conc 1: 0.91794 train_pcc for conc 1: -0.00252 lr : 0.001
[ 0/50] train_rmse for conc 2: 0.90960 train_pcc for conc 2: -0.00981 lr : 0.001
[ 0/50] train_rmse for conc 3: 0.93007 train_pcc for conc 3: -0.00386 lr : 0.001
[ 0/50] train_rmse for conc 4: 0.96326 train_pcc for conc 4: -0.00243 lr : 0.001
[ 0/50] train_rmse for conc 5: 0.94786 train_pcc for conc 5: -0.00012 lr : 0.001
[ 0/50] train_rmse for conc 6: 0.91533 train_pcc for conc 6: 0.00134 lr : 0.001


100%|██████████| 89/89 [00:22<00:00,  3.99it/s]


Test accuracy for dose 0: RMSE: 0.3972, PCC: -0.0000, R2: -25.9746, MAE: 0.3496
Test accuracy for dose 1: RMSE: 0.5125, PCC: -0.0058, R2: -35.0932, MAE: 0.4903
Test accuracy for dose 2: RMSE: 0.3950, PCC: -0.0029, R2: -10.7478, MAE: 0.3605
Test accuracy for dose 3: RMSE: 0.2745, PCC: -0.0021, R2: -1.7820, MAE: 0.2360
Test accuracy for dose 4: RMSE: 0.4836, PCC: 0.0005, R2: -3.8249, MAE: 0.4179
Test accuracy for dose 5: RMSE: 0.5544, PCC: 0.0008, R2: -3.2872, MAE: 0.4891
Test accuracy for dose 6: RMSE: 0.4586, PCC: 0.0041, R2: -1.3688, MAE: 0.4003
[ 0/50] val_rmse for conc 0: 0.39717 val_r_2 for conc 0: -25.97464 val_mae for conc 0: 0.34961 val_pcc for conc 0: -0.00004 lr : 0.001
[ 0/50] val_rmse for conc 1: 0.51251 val_r_2 for conc 1: -35.09322 val_mae for conc 1: 0.49033 val_pcc for conc 1: -0.00578 lr : 0.001
[ 0/50] val_rmse for conc 2: 0.39501 val_r_2 for conc 2: -10.74778 val_mae for conc 2: 0.36053 val_pcc for conc 2: -0.00292 lr : 0.001
[ 0/50] val_rmse for conc 3: 0.27446 val_r

100%|██████████| 28/28 [00:07<00:00,  3.99it/s]


Test accuracy for dose 0: RMSE: 0.3982, PCC: 0.0015, R2: -29.1503, MAE: 0.3515
Test accuracy for dose 1: RMSE: 0.5184, PCC: 0.0089, R2: -43.7664, MAE: 0.4971
Test accuracy for dose 2: RMSE: 0.4005, PCC: -0.0021, R2: -12.4423, MAE: 0.3667
Test accuracy for dose 3: RMSE: 0.2775, PCC: 0.0072, R2: -1.9091, MAE: 0.2411
Test accuracy for dose 4: RMSE: 0.4860, PCC: -0.0046, R2: -3.7022, MAE: 0.4205
Test accuracy for dose 5: RMSE: 0.5504, PCC: 0.0016, R2: -2.8758, MAE: 0.4829
Test accuracy for dose 6: RMSE: 0.4538, PCC: -0.0078, R2: -1.0902, MAE: 0.3933
Best epoch: 000
Best PCC for conc 0: 0.0015,Best RMSE for conc 0: 0.3982, Best R_2 for conc 0: -29.1503, Best MAE for conc 0: 0.3515
Best PCC for conc 1: 0.0089,Best RMSE for conc 1: 0.5184, Best R_2 for conc 1: -43.7664, Best MAE for conc 1: 0.4971
Best PCC for conc 2: -0.0021,Best RMSE for conc 2: 0.4005, Best R_2 for conc 2: -12.4423, Best MAE for conc 2: 0.3667
Best PCC for conc 3: 0.0072,Best RMSE for conc 3: 0.2775, Best R_2 for conc 3: -

100%|██████████| 23/23 [00:07<00:00,  3.26it/s]


0.001
[ 1/50] train_rmse for conc 0: 0.34972 train_pcc for conc 0: 0.00548 lr : 0.001
[ 1/50] train_rmse for conc 1: 0.25281 train_pcc for conc 1: 0.00002 lr : 0.001
[ 1/50] train_rmse for conc 2: 0.27992 train_pcc for conc 2: 0.00410 lr : 0.001
[ 1/50] train_rmse for conc 3: 0.30419 train_pcc for conc 3: -0.00150 lr : 0.001
[ 1/50] train_rmse for conc 4: 0.35235 train_pcc for conc 4: 0.00043 lr : 0.001
[ 1/50] train_rmse for conc 5: 0.40041 train_pcc for conc 5: -0.00825 lr : 0.001
[ 1/50] train_rmse for conc 6: 0.38963 train_pcc for conc 6: -0.00127 lr : 0.001


100%|██████████| 23/23 [00:06<00:00,  3.62it/s]


0.001
[ 2/50] train_rmse for conc 0: 0.19652 train_pcc for conc 0: 0.00444 lr : 0.001
[ 2/50] train_rmse for conc 1: 0.18537 train_pcc for conc 1: 0.00471 lr : 0.001
[ 2/50] train_rmse for conc 2: 0.19010 train_pcc for conc 2: -0.00597 lr : 0.001
[ 2/50] train_rmse for conc 3: 0.22032 train_pcc for conc 3: 0.01318 lr : 0.001
[ 2/50] train_rmse for conc 4: 0.27179 train_pcc for conc 4: 0.00817 lr : 0.001
[ 2/50] train_rmse for conc 5: 0.32389 train_pcc for conc 5: 0.01137 lr : 0.001
[ 2/50] train_rmse for conc 6: 0.35332 train_pcc for conc 6: 0.01330 lr : 0.001


100%|██████████| 23/23 [00:06<00:00,  3.30it/s]


0.001
[ 3/50] train_rmse for conc 0: 0.15507 train_pcc for conc 0: -0.00037 lr : 0.001
[ 3/50] train_rmse for conc 1: 0.14814 train_pcc for conc 1: 0.01191 lr : 0.001
[ 3/50] train_rmse for conc 2: 0.15983 train_pcc for conc 2: 0.01306 lr : 0.001
[ 3/50] train_rmse for conc 3: 0.20266 train_pcc for conc 3: 0.00858 lr : 0.001
[ 3/50] train_rmse for conc 4: 0.25858 train_pcc for conc 4: 0.00851 lr : 0.001
[ 3/50] train_rmse for conc 5: 0.31132 train_pcc for conc 5: 0.00419 lr : 0.001
[ 3/50] train_rmse for conc 6: 0.34059 train_pcc for conc 6: 0.01216 lr : 0.001


100%|██████████| 23/23 [00:07<00:00,  3.15it/s]


0.001
[ 4/50] train_rmse for conc 0: 0.14135 train_pcc for conc 0: -0.00464 lr : 0.001
[ 4/50] train_rmse for conc 1: 0.14194 train_pcc for conc 1: 0.00340 lr : 0.001
[ 4/50] train_rmse for conc 2: 0.15346 train_pcc for conc 2: 0.00725 lr : 0.001
[ 4/50] train_rmse for conc 3: 0.19852 train_pcc for conc 3: -0.00155 lr : 0.001
[ 4/50] train_rmse for conc 4: 0.25457 train_pcc for conc 4: 0.01049 lr : 0.001
[ 4/50] train_rmse for conc 5: 0.30871 train_pcc for conc 5: 0.01082 lr : 0.001
[ 4/50] train_rmse for conc 6: 0.33752 train_pcc for conc 6: 0.01397 lr : 0.001


100%|██████████| 23/23 [00:07<00:00,  3.12it/s]


0.001
[ 5/50] train_rmse for conc 0: 0.13177 train_pcc for conc 0: 0.00702 lr : 0.001
[ 5/50] train_rmse for conc 1: 0.13410 train_pcc for conc 1: 0.00647 lr : 0.001
[ 5/50] train_rmse for conc 2: 0.14876 train_pcc for conc 2: -0.00040 lr : 0.001
[ 5/50] train_rmse for conc 3: 0.19491 train_pcc for conc 3: 0.00714 lr : 0.001
[ 5/50] train_rmse for conc 4: 0.25341 train_pcc for conc 4: 0.00739 lr : 0.001
[ 5/50] train_rmse for conc 5: 0.30723 train_pcc for conc 5: 0.00909 lr : 0.001
[ 5/50] train_rmse for conc 6: 0.33640 train_pcc for conc 6: 0.00988 lr : 0.001


100%|██████████| 89/89 [00:22<00:00,  3.89it/s]


Test accuracy for dose 0: RMSE: 0.0853, PCC: -0.0044, R2: -0.2436, MAE: 0.0626
Test accuracy for dose 1: RMSE: 0.0995, PCC: -0.0033, R2: -0.3607, MAE: 0.0707
Test accuracy for dose 2: RMSE: 0.1311, PCC: -0.0003, R2: -0.2949, MAE: 0.0867
Test accuracy for dose 3: RMSE: 0.1752, PCC: -0.0006, R2: -0.1331, MAE: 0.1149
Test accuracy for dose 4: RMSE: 0.2255, PCC: -0.0010, R2: -0.0494, MAE: 0.1699
Test accuracy for dose 5: RMSE: 0.2754, PCC: 0.0007, R2: -0.0578, MAE: 0.2283
Test accuracy for dose 6: RMSE: 0.3037, PCC: -0.0020, R2: -0.0385, MAE: 0.2590
[ 5/50] val_rmse for conc 0: 0.08528 val_r_2 for conc 0: -0.24359 val_mae for conc 0: 0.06257 val_pcc for conc 0: -0.00444 lr : 0.001
[ 5/50] val_rmse for conc 1: 0.09951 val_r_2 for conc 1: -0.36070 val_mae for conc 1: 0.07073 val_pcc for conc 1: -0.00329 lr : 0.001
[ 5/50] val_rmse for conc 2: 0.13115 val_r_2 for conc 2: -0.29494 val_mae for conc 2: 0.08667 val_pcc for conc 2: -0.00032 lr : 0.001
[ 5/50] val_rmse for conc 3: 0.17517 val_r_2 f

100%|██████████| 28/28 [00:06<00:00,  4.08it/s]


Test accuracy for dose 0: RMSE: 0.0814, PCC: 0.0003, R2: -0.2591, MAE: 0.0609
Test accuracy for dose 1: RMSE: 0.0922, PCC: 0.0060, R2: -0.4160, MAE: 0.0671
Test accuracy for dose 2: RMSE: 0.1246, PCC: 0.0018, R2: -0.3012, MAE: 0.0829
Test accuracy for dose 3: RMSE: 0.1724, PCC: 0.0049, R2: -0.1220, MAE: 0.1136
Test accuracy for dose 4: RMSE: 0.2293, PCC: 0.0012, R2: -0.0467, MAE: 0.1729
Test accuracy for dose 5: RMSE: 0.2863, PCC: -0.0004, R2: -0.0486, MAE: 0.2375
Test accuracy for dose 6: RMSE: 0.3190, PCC: 0.0017, R2: -0.0325, MAE: 0.2734
Best epoch: 005
Best PCC for conc 0: 0.0003,Best RMSE for conc 0: 0.0814, Best R_2 for conc 0: -0.2591, Best MAE for conc 0: 0.0609
Best PCC for conc 1: 0.0060,Best RMSE for conc 1: 0.0922, Best R_2 for conc 1: -0.4160, Best MAE for conc 1: 0.0671
Best PCC for conc 2: 0.0018,Best RMSE for conc 2: 0.1246, Best R_2 for conc 2: -0.3012, Best MAE for conc 2: 0.0829
Best PCC for conc 3: 0.0049,Best RMSE for conc 3: 0.1724, Best R_2 for conc 3: -0.1220, B

100%|██████████| 23/23 [00:06<00:00,  3.33it/s]


0.001
[ 6/50] train_rmse for conc 0: 0.12796 train_pcc for conc 0: 0.01513 lr : 0.001
[ 6/50] train_rmse for conc 1: 0.13062 train_pcc for conc 1: 0.00086 lr : 0.001
[ 6/50] train_rmse for conc 2: 0.14764 train_pcc for conc 2: -0.00355 lr : 0.001
[ 6/50] train_rmse for conc 3: 0.19340 train_pcc for conc 3: -0.00398 lr : 0.001
[ 6/50] train_rmse for conc 4: 0.25210 train_pcc for conc 4: 0.00777 lr : 0.001
[ 6/50] train_rmse for conc 5: 0.30661 train_pcc for conc 5: 0.00313 lr : 0.001
[ 6/50] train_rmse for conc 6: 0.33411 train_pcc for conc 6: 0.01330 lr : 0.001


100%|██████████| 23/23 [00:06<00:00,  3.55it/s]


0.001
[ 7/50] train_rmse for conc 0: 0.12485 train_pcc for conc 0: -0.00631 lr : 0.001
[ 7/50] train_rmse for conc 1: 0.12767 train_pcc for conc 1: -0.00838 lr : 0.001
[ 7/50] train_rmse for conc 2: 0.14304 train_pcc for conc 2: -0.00537 lr : 0.001
[ 7/50] train_rmse for conc 3: 0.19058 train_pcc for conc 3: -0.00056 lr : 0.001
[ 7/50] train_rmse for conc 4: 0.24967 train_pcc for conc 4: 0.00765 lr : 0.001
[ 7/50] train_rmse for conc 5: 0.30495 train_pcc for conc 5: 0.01076 lr : 0.001
[ 7/50] train_rmse for conc 6: 0.33437 train_pcc for conc 6: 0.01005 lr : 0.001


100%|██████████| 23/23 [00:06<00:00,  3.53it/s]


0.001
[ 8/50] train_rmse for conc 0: 0.12049 train_pcc for conc 0: 0.00049 lr : 0.001
[ 8/50] train_rmse for conc 1: 0.13076 train_pcc for conc 1: -0.00055 lr : 0.001
[ 8/50] train_rmse for conc 2: 0.14939 train_pcc for conc 2: 0.01019 lr : 0.001
[ 8/50] train_rmse for conc 3: 0.19136 train_pcc for conc 3: 0.00889 lr : 0.001
[ 8/50] train_rmse for conc 4: 0.24986 train_pcc for conc 4: 0.01549 lr : 0.001
[ 8/50] train_rmse for conc 5: 0.30556 train_pcc for conc 5: 0.01283 lr : 0.001
[ 8/50] train_rmse for conc 6: 0.33312 train_pcc for conc 6: 0.02409 lr : 0.001


100%|██████████| 23/23 [00:07<00:00,  3.12it/s]


0.001
[ 9/50] train_rmse for conc 0: 0.11777 train_pcc for conc 0: 0.00815 lr : 0.001
[ 9/50] train_rmse for conc 1: 0.12553 train_pcc for conc 1: 0.00916 lr : 0.001
[ 9/50] train_rmse for conc 2: 0.14081 train_pcc for conc 2: -0.00638 lr : 0.001
[ 9/50] train_rmse for conc 3: 0.18850 train_pcc for conc 3: 0.00249 lr : 0.001
[ 9/50] train_rmse for conc 4: 0.24802 train_pcc for conc 4: 0.01352 lr : 0.001
[ 9/50] train_rmse for conc 5: 0.30375 train_pcc for conc 5: 0.00995 lr : 0.001
[ 9/50] train_rmse for conc 6: 0.33311 train_pcc for conc 6: 0.00478 lr : 0.001


100%|██████████| 23/23 [00:07<00:00,  3.25it/s]


0.001
[10/50] train_rmse for conc 0: 0.11526 train_pcc for conc 0: -0.00355 lr : 0.001
[10/50] train_rmse for conc 1: 0.12359 train_pcc for conc 1: -0.00060 lr : 0.001
[10/50] train_rmse for conc 2: 0.14078 train_pcc for conc 2: 0.00086 lr : 0.001
[10/50] train_rmse for conc 3: 0.18913 train_pcc for conc 3: 0.01241 lr : 0.001
[10/50] train_rmse for conc 4: 0.24788 train_pcc for conc 4: 0.01133 lr : 0.001
[10/50] train_rmse for conc 5: 0.30340 train_pcc for conc 5: 0.01273 lr : 0.001
[10/50] train_rmse for conc 6: 0.33197 train_pcc for conc 6: 0.01850 lr : 0.001


100%|██████████| 89/89 [00:22<00:00,  4.02it/s]


Test accuracy for dose 0: RMSE: 0.0857, PCC: 0.0011, R2: -0.2567, MAE: 0.0676
Test accuracy for dose 1: RMSE: 0.0923, PCC: -0.0058, R2: -0.1718, MAE: 0.0653
Test accuracy for dose 2: RMSE: 0.1231, PCC: -0.0039, R2: -0.1400, MAE: 0.0796
Test accuracy for dose 3: RMSE: 0.1690, PCC: 0.0036, R2: -0.0550, MAE: 0.1105
Test accuracy for dose 4: RMSE: 0.2245, PCC: -0.0004, R2: -0.0396, MAE: 0.1698
Test accuracy for dose 5: RMSE: 0.2710, PCC: -0.0012, R2: -0.0245, MAE: 0.2214
Test accuracy for dose 6: RMSE: 0.3029, PCC: -0.0020, R2: -0.0332, MAE: 0.2539
[10/50] val_rmse for conc 0: 0.08573 val_r_2 for conc 0: -0.25665 val_mae for conc 0: 0.06755 val_pcc for conc 0: 0.00112 lr : 0.001
[10/50] val_rmse for conc 1: 0.09235 val_r_2 for conc 1: -0.17180 val_mae for conc 1: 0.06533 val_pcc for conc 1: -0.00584 lr : 0.001
[10/50] val_rmse for conc 2: 0.12305 val_r_2 for conc 2: -0.14005 val_mae for conc 2: 0.07959 val_pcc for conc 2: -0.00388 lr : 0.001
[10/50] val_rmse for conc 3: 0.16902 val_r_2 for

100%|██████████| 28/28 [00:06<00:00,  4.03it/s]


Test accuracy for dose 0: RMSE: 0.0830, PCC: -0.0032, R2: -0.3091, MAE: 0.0669
Test accuracy for dose 1: RMSE: 0.0853, PCC: -0.0044, R2: -0.2129, MAE: 0.0624
Test accuracy for dose 2: RMSE: 0.1169, PCC: -0.0036, R2: -0.1447, MAE: 0.0764
Test accuracy for dose 3: RMSE: 0.1669, PCC: -0.0018, R2: -0.0522, MAE: 0.1107
Test accuracy for dose 4: RMSE: 0.2286, PCC: -0.0060, R2: -0.0404, MAE: 0.1728
Test accuracy for dose 5: RMSE: 0.2830, PCC: -0.0055, R2: -0.0244, MAE: 0.2315
Test accuracy for dose 6: RMSE: 0.3203, PCC: 0.0020, R2: -0.0414, MAE: 0.2697
Best epoch: 010
Best PCC for conc 0: -0.0032,Best RMSE for conc 0: 0.0830, Best R_2 for conc 0: -0.3091, Best MAE for conc 0: 0.0669
Best PCC for conc 1: -0.0044,Best RMSE for conc 1: 0.0853, Best R_2 for conc 1: -0.2129, Best MAE for conc 1: 0.0624
Best PCC for conc 2: -0.0036,Best RMSE for conc 2: 0.1169, Best R_2 for conc 2: -0.1447, Best MAE for conc 2: 0.0764
Best PCC for conc 3: -0.0018,Best RMSE for conc 3: 0.1669, Best R_2 for conc 3: -

100%|██████████| 23/23 [00:06<00:00,  3.33it/s]


0.001
[11/50] train_rmse for conc 0: 0.11259 train_pcc for conc 0: -0.00928 lr : 0.001
[11/50] train_rmse for conc 1: 0.12023 train_pcc for conc 1: -0.00182 lr : 0.001
[11/50] train_rmse for conc 2: 0.13539 train_pcc for conc 2: 0.00035 lr : 0.001
[11/50] train_rmse for conc 3: 0.18482 train_pcc for conc 3: 0.01763 lr : 0.001
[11/50] train_rmse for conc 4: 0.24601 train_pcc for conc 4: 0.02151 lr : 0.001
[11/50] train_rmse for conc 5: 0.30201 train_pcc for conc 5: 0.01451 lr : 0.001
[11/50] train_rmse for conc 6: 0.33103 train_pcc for conc 6: 0.01809 lr : 0.001


100%|██████████| 23/23 [00:07<00:00,  3.28it/s]


0.001
[12/50] train_rmse for conc 0: 0.10776 train_pcc for conc 0: 0.01281 lr : 0.001
[12/50] train_rmse for conc 1: 0.11571 train_pcc for conc 1: 0.01269 lr : 0.001
[12/50] train_rmse for conc 2: 0.13453 train_pcc for conc 2: 0.00213 lr : 0.001
[12/50] train_rmse for conc 3: 0.18404 train_pcc for conc 3: 0.01717 lr : 0.001
[12/50] train_rmse for conc 4: 0.24592 train_pcc for conc 4: 0.01569 lr : 0.001
[12/50] train_rmse for conc 5: 0.30115 train_pcc for conc 5: 0.02418 lr : 0.001
[12/50] train_rmse for conc 6: 0.33054 train_pcc for conc 6: 0.01807 lr : 0.001


100%|██████████| 23/23 [00:07<00:00,  3.16it/s]


0.001
[13/50] train_rmse for conc 0: 0.10937 train_pcc for conc 0: 0.00191 lr : 0.001
[13/50] train_rmse for conc 1: 0.11842 train_pcc for conc 1: 0.00778 lr : 0.001
[13/50] train_rmse for conc 2: 0.13565 train_pcc for conc 2: 0.00105 lr : 0.001
[13/50] train_rmse for conc 3: 0.18461 train_pcc for conc 3: 0.01207 lr : 0.001
[13/50] train_rmse for conc 4: 0.24604 train_pcc for conc 4: 0.01363 lr : 0.001
[13/50] train_rmse for conc 5: 0.30224 train_pcc for conc 5: 0.01084 lr : 0.001
[13/50] train_rmse for conc 6: 0.33086 train_pcc for conc 6: 0.01133 lr : 0.001


 26%|██▌       | 6/23 [00:02<00:07,  2.40it/s]


KeyboardInterrupt: 