In [1]:
from utils.dataset import DataGenerator, collate_fn, to_tensor
from utils.data import train_test_split, train_test_split_by_smiles, DataScaler
from utils.trainer import Trainer
from model.modules import CATEncoder, DNN, GraphEncoder, MoleculeEncoder, AddPoolFineTuner, MaxPoolFineTuner, StackFineTuner
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch import nn
import numpy as np
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from torch.utils.tensorboard import SummaryWriter
import gc, os, torch, tqdm, pickle

In [4]:
# Run configuration: seed used for the SMILES-based splits and the model output root.
seed = 102
model_root = '/home/jhyang/WORKSPACES/MODELS/fpoly'

# Build the f-polymer dataset from the CSV (augmentation disabled).
dg = DataGenerator(None, None, include_autocorr=True)
dg.generate_fpoly_from_csv(
    '/home/jhyang/WORKSPACES/DATA/polymers/f-polymer/f-polymer-20220922.csv',
    pfx_frac='FR',
    pfx_smiles='SMILES',
    col_target='Target',
    augmentation=0,
)

# Two successive SMILES-based hold-outs: the test split first, then the
# validation split out of what remains. The returned SMILES lists are kept so
# later cells can reuse the same held-out monomers.
train_data_, test_data, test_smiles = train_test_split_by_smiles(
    np.array(dg.data), n_test=7, seed=seed)
train_data, valid_data, valid_smiles = train_test_split_by_smiles(
    train_data_, n_test=7, seed=seed)

# Convert every split to tensors (train, valid, test — in that order).
train_data, valid_data, test_data = [
    to_tensor(part) for part in (train_data, valid_data, test_data)
]

# Report split sizes followed by the held-out SMILES of each split.
print(*(len(part) for part in (train_data, valid_data, test_data)))
print(valid_smiles, test_smiles, sep='\n')


generate: 100%|██████████| 57/57 [00:00<00:00, 535.87it/s]
gather: 100%|██████████| 293/293 [00:00<00:00, 1633.10it/s]


184 48 61
['FC(F)(F)C(F)(F)C(F)(F)C(F)(F)CCOC(=O)C=C', 'CCCCCCOC(=O)C(C)=C', 'C=CC(=O)OCC1CCCO1', 'C=CC(=O)OCC(C(C(F)(F)F)F)(F)F', 'CC1C2CC(C1(C)C)CC2C3CCCC(C3)OC(=O)C=C', 'CC(=C)C(=O)OCCC(F)(F)C(F)(F)C(F)(F)C(F)(F)F', 'CC(=C)C(=O)OC1CCCCC1']
['CC(=C)C(=O)OC12CC3CC(C1)CC(C3)(C2)O', 'COC(=O)C=C', 'C=CC(=O)OCCOc1ccccc1', 'CC(=C)C(=O)OC(C)(C)C', 'FC(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)CCOC(=O)C=C', 'FC(F)(F)C(F)(F)C(F)(F)COC(=O)C=C', 'C=CC(=O)OC1C=CC2C1C3CCC2C3']


##### TG data

In [5]:
# Tg (column 'tg' of dsc.csv) models are written under their own root.
model_root = '/home/jhyang/WORKSPACES/MODELS/fpoly/tg_random'

dgtg = DataGenerator(None, None, include_autocorr=True)
dgtg.generate_fpoly_from_csv(
    '/home/jhyang/WORKSPACES/DATA/polymers/f-polymer/dsc.csv',
    pfx_frac='FR',
    pfx_smiles='SMILES',
    col_target='tg',
)

# Reuse the held-out SMILES lists (`test_smiles`, `valid_smiles`) that must
# already exist in the kernel, so this dataset is split on the same monomers.
train_data_, test_data, t_ = train_test_split_by_smiles(
    dgtg.data, seed=seed, test_smiles=test_smiles)
train_data, valid_data, v_ = train_test_split_by_smiles(
    train_data_, seed=seed, test_smiles=valid_smiles)

# Convert every split to tensors (train, valid, test — in that order).
train_data, valid_data, test_data = [
    to_tensor(part) for part in (train_data, valid_data, test_data)
]
print(*(len(part) for part in (train_data, valid_data, test_data)))

generate: 100%|██████████| 58/58 [00:00<00:00, 531.14it/s]
gather: 100%|██████████| 257/257 [00:00<00:00, 1638.35it/s]

160 44 53





In [35]:
# Run configuration for the random-split Tg experiments.
seed = 0
batch_size = 20
model_root = f'/home/jhyang/WORKSPACES/MODELS/fpoly/tg_r{seed}'

dg = DataGenerator(None, None, include_autocorr=True)
dg.generate_fpoly_from_csv(
    '/home/jhyang/WORKSPACES/DATA/polymers/f-polymer/dsc.csv',
    pfx_frac='FR',
    pfx_smiles='SMILES',
    col_target='tg',
)

# Random split: hold out the test part first (train_ratio=0.8), then split
# the remainder into train/valid (train_ratio=0.75).
train_data_, test_data = train_test_split(dg.data, train_ratio=0.8, seed=seed)
train_data, valid_data = train_test_split(train_data_, train_ratio=0.75, seed=seed)
print(*(len(part) for part in (train_data, valid_data, test_data)))

train_data, valid_data, test_data = [
    to_tensor(part) for part in (train_data, valid_data, test_data)
]

# The scaler is fitted on the training split only and then applied to all splits.
scaler = DataScaler()
scaler.train(train_data)
train_scaled, valid_scaled, test_scaled = [
    scaler.scale_data(part) for part in (train_data, valid_data, test_data)
]

# Record which sample ids ended up in the validation and test splits.
print(sorted(record['id'] for record in valid_data))
print(sorted(record['id'] for record in test_data))

generate: 100%|██████████| 58/58 [00:00<00:00, 524.97it/s]
gather: 100%|██████████| 257/257 [00:00<00:00, 1959.60it/s]

153 52 52
['FA-00109', 'FA-00112', 'FA-00119', 'FA-00124', 'FA-00126', 'FA-00139', 'FA-00140', 'FA-00147', 'FA-00154', 'FA-00157', 'FA-00163', 'FA-00164', 'FA-00172', 'FA-00174', 'FA-00176', 'FA-00187', 'FA-00193', 'FA-00195', 'FA-00210', 'FA-00213', 'FA-00215', 'FA-00216', 'FA-00218', 'FA-00223', 'FA-00224', 'FA-00238', 'FA-00240', 'FA-00248', 'FA-00249', 'FA-00285', 'FA-00294', 'FA-00299', 'FA-00300', 'FA-00304', 'FA-00347', 'FA-00348', 'FA-00353', 'FA-00361', 'FA-00375', 'FA-00384', 'FA-00388', 'FA-00411', 'FA-00422', 'FA-00430', 'FA-00433', 'FA-00439', 'FA-00443', 'FA-00448', 'FA-00452', 'FA-00458', 'FA-00464', 'FA-00466']
['FA-00120', 'FA-00132', 'FA-00141', 'FA-00143', 'FA-00151', 'FA-00165', 'FA-00173', 'FA-00186', 'FA-00188', 'FA-00202', 'FA-00207', 'FA-00229', 'FA-00230', 'FA-00232', 'FA-00234', 'FA-00237', 'FA-00241', 'FA-00244', 'FA-00247', 'FA-00281', 'FA-00295', 'FA-00297', 'FA-00311', 'FA-00335', 'FA-00336', 'FA-00344', 'FA-00346', 'FA-00350', 'FA-00355', 'FA-00356', 'FA-




In [19]:
def _dump_split_predictions(trainer, path, tag, train_dl, test_dl, valid_result):
    """Pickle ``[ids, targets, predictions]`` per split as ``<tag>.<split>.pkl`` under ``path``.

    The train and test splits are re-evaluated here via ``trainer.test``; the
    validation triple is passed in because the caller already computed it for
    the current epoch. The evaluation order (train, then test) is fixed.
    """
    _, ti, _, tt, tp = trainer.test(train_dl)
    with open(os.path.join(path, f'{tag}.train.pkl'), 'wb') as f:
        pickle.dump([ti, tt, tp], f)
    with open(os.path.join(path, f'{tag}.valid.pkl'), 'wb') as f:
        pickle.dump(list(valid_result), f)
    _, tti, _, ttt, ttp = trainer.test(test_dl)
    with open(os.path.join(path, f'{tag}.test.pkl'), 'wb') as f:
        pickle.dump([tti, ttt, ttp], f)


def do_epoch(train_dl, valid_dl, test_dl, trainer, path, epochs=10000, early_stop=500,
             relax_after=0, logging_interval=50):
    """Run the training loop with TensorBoard logging, checkpoints and early stopping.

    Args:
        train_dl, valid_dl, test_dl: data loaders handed to ``trainer.train`` /
            ``trainer.test``.
        trainer: object exposing ``train(dl)`` (returns the training loss),
            ``test(dl)`` (returns a 5-tuple; elements 0, 1, 3, 4 are used here as
            loss, ids, targets, predictions), plus ``model`` and ``opt``.
        path: output directory for TensorBoard events, checkpoints and pickles.
            Created if missing.
        epochs: maximum number of epochs.
        early_stop: patience in epochs without validation improvement. Early
            stopping is only active when this is an ``int``.
        relax_after: at the start of epoch ``relax_after + 1`` every model
            parameter gets ``requires_grad = True`` and each optimizer
            param-group learning rate is multiplied by 0.2. With the default
            ``0`` this happens at the very first epoch, so the effective
            learning rate is 0.2x the one given to the optimizer.
        logging_interval: every this many epochs a numbered checkpoint and the
            per-split prediction pickles are written.

    Notes:
        * The progress-bar label strips the notebook-level ``model_root`` from
          ``path``, so ``model_root`` must be defined before calling.
        * A NaN validation loss is counted as "no improvement"; it never
          overwrites the best checkpoint.
        * The TensorBoard writer is closed even if training raises or is
          interrupted.
    """
    os.makedirs(path, exist_ok=True)
    writer = SummaryWriter(path)
    # +inf (instead of a finite sentinel) guarantees the first finite loss is recorded as best.
    best_loss = float('inf')
    count = 0
    try:
        for epoch in tqdm.tqdm(range(1, epochs + 1), desc=path.replace(model_root, '')):
            if epoch == relax_after + 1:
                # Unfreeze everything and continue at a reduced learning rate.
                for param in trainer.model.parameters():
                    param.requires_grad = True
                for pg in trainer.opt.param_groups:
                    pg['lr'] = pg['lr'] * 0.2

            train_loss = trainer.train(train_dl)
            valid_loss, vi, _, vt, vp = trainer.test(valid_dl)

            writer.add_scalar('train/loss', train_loss, epoch)
            writer.add_scalar('valid/loss', valid_loss, epoch)
            writer.add_scalar('valid/R2', r2_score(vt, vp), epoch)
            writer.add_scalar('valid/MAE', mean_absolute_error(vt, vp), epoch)

            if epoch % 5 == 0:
                # Train/test metrics are more expensive, so they are logged only every 5th epoch.
                _, _, _, t, p = trainer.test(train_dl)
                writer.add_scalar('train/MAE', mean_absolute_error(t, p), epoch)
                writer.add_scalar('train/R2', r2_score(t, p), epoch)
                _, _, _, t, p = trainer.test(test_dl)
                writer.add_scalar('test/MAE', mean_absolute_error(t, p), epoch)
                writer.add_scalar('test/R2', r2_score(t, p), epoch)

            if epoch % logging_interval == 0:
                trainer.model.save(path, desc=f'{epoch:05d}')
                _dump_split_predictions(trainer, path, f'{epoch:05d}', train_dl, test_dl, (vi, vt, vp))

            # `<=` keeps the original tie behaviour (an equal loss counts as an
            # improvement) while making NaN compare as "not improved".
            if valid_loss <= best_loss:
                count = 0
                best_loss = valid_loss
                trainer.model.save(path, desc='best')
                _dump_split_predictions(trainer, path, 'best', train_dl, test_dl, (vi, vt, vp))
                with open(os.path.join(path, 'best.epoch.txt'), 'w') as f:
                    f.write(f'{epoch}\n')
            else:
                count += 1

            # Never stop before the unfreeze point has been reached.
            if isinstance(early_stop, int) and count >= early_stop and epoch > relax_after:
                break
    finally:
        writer.close()

# Non-pretrained models
- model dependency

- xgboost

In [20]:
def process_data_fpoly(data, group_size=5):
    """Flatten collated samples into one fixed-width feature row per group of rows.

    ``collate_fn(data)`` yields a feature dict; its ``'mol_feat'`` and
    ``'weight'`` tensors are concatenated column-wise and every ``group_size``
    consecutive rows are merged into a single row.

    Args:
        data: dataset accepted by ``collate_fn``.
        group_size: number of consecutive feature rows merged into one output
            row. Defaults to 5, the value previously hard-coded here
            (presumably the number of monomer slots per polymer — confirm
            against ``collate_fn``).

    Returns:
        ``(feat, target, ids, smiles)`` where ``feat`` is a NumPy array of shape
        ``(n_rows // group_size, n_features * group_size)``, ``target`` is a
        NumPy array, and ``ids`` / ``smiles`` are passed through unchanged from
        ``collate_fn``.

    Raises:
        ValueError: if ``group_size`` is not positive or the number of feature
            rows is not a multiple of ``group_size``.
    """
    feat, target, ids, smiles = collate_fn(data)
    feat = np.hstack([
        feat['mol_feat'].cpu().numpy(),
        feat['weight'].cpu().numpy()])
    n_sample, n_feature = feat.shape
    if group_size <= 0 or n_sample % group_size != 0:
        # Fail with an explicit message instead of numpy's generic reshape error.
        raise ValueError(
            f'number of feature rows ({n_sample}) must be a positive multiple of '
            f'group_size ({group_size})')
    feat = feat.reshape(n_sample // group_size, n_feature * group_size)
    target = target.cpu().numpy()
    return feat, target, ids, smiles

In [21]:
# Flatten each (unscaled) split for the tree model; the SMILES strings returned
# as the fourth element are not needed and are discarded into `_`.
((feat_train, target_train, ids_train, _),
 (feat_valid, target_valid, ids_valid, _),
 (feat_test, target_test, ids_test, _)) = [
    process_data_fpoly(part) for part in (train_data, valid_data, test_data)
]

In [34]:
import xgboost

dmat_train = xgboost.DMatrix(feat_train, target_train)
dmat_valid = xgboost.DMatrix(feat_valid, target_valid)
dmat_test = xgboost.DMatrix(feat_test, target_test)

# Early stopping monitors the last entry of `evals`, i.e. the validation set.
progress = {}
booster = xgboost.train(dtrain=dmat_train, params={'max_depth':8, 'subsample':1}, 
                        num_boost_round=5000, early_stopping_rounds=30, 
                        evals=[(dmat_train,'Train'),(dmat_valid,'Valid')],
                        evals_result=progress, verbose_eval=False)

dest = os.path.join(model_root, 'scratch/xgboost')
os.makedirs(dest, exist_ok=True)
booster.save_model(os.path.join(dest, 'model.json'))
with open(os.path.join(dest, 'best.epoch.txt'),'w') as f:
    f.write(f'{booster.best_iteration}\n')

# Explicit lookup table instead of eval() on constructed variable names.
splits = {
    'train': (dmat_train, target_train, ids_train),
    'valid': (dmat_valid, target_valid, ids_valid),
    'test': (dmat_test, target_test, ids_test),
}

writer = SummaryWriter(dest)
try:
    for d, (dmat, target, ids) in splits.items():
        # `iteration_range` is half-open, so the end must be best_iteration + 1 to
        # include the best round; (0, 0) would additionally mean "use all trees".
        pred = booster.predict(dmat, iteration_range=(0, booster.best_iteration + 1))
        writer.add_scalar(f'{d}/MAE', mean_absolute_error(target, pred), booster.best_iteration)
        writer.add_scalar(f'{d}/R2', r2_score(target, pred), booster.best_iteration)

        with open(os.path.join(dest, f'best.{d}.pkl'),'wb') as f:
            pickle.dump([ids, target, pred], f)
finally:
    # Flush the event file even if prediction or scoring fails.
    writer.close()

- graphnet

In [29]:
# Training budget; both values are forwarded to do_epoch below.
epochs = 10000
early_stop = 500

atom_net_params  = {
    'n_atom_feat':None,
    'n_bond_feat':None,
    'graph':None,
    'hidden_dim':64,
    'output_dim':64,
    'n_layer':4
}

decoder_params = {
    'input_dim':None,
    'hidden_dims':[256, 256],
    'output_dim':1,
}

# Feature widths come from the data generator built in the data-loading cell.
atom_net_params['n_atom_feat'] = dg.n_atom_feat
atom_net_params['n_bond_feat'] = dg.n_bond_feat

# Pooling head -> (fine-tuner class, decoder input width used with it).
trainers = {
    'max': (MaxPoolFineTuner, 64 * 2), 
    'add': (AddPoolFineTuner, 64 * 2), 
#    'stack': (StackFineTuner, 64 * 3 * 10), 
}

for graph in ['cg','transformer']:
    atom_net_params['graph'] = graph
    for trtyp, (FineTuner, decoder_input_dim) in trainers.items():
        decoder_params['input_dim'] = decoder_input_dim
        for n in range(2):
            scr_path = os.path.join(model_root, 'scratch/graph', f'{trtyp}_{graph}_{n:02d}')
            # Skip runs whose output directory already exists.
            if os.path.isdir(scr_path): continue
            gc.collect()
            torch.cuda.empty_cache()
            # Each repeat gets its own seed so the runs differ reproducibly.
            torch.random.manual_seed(seed + n)
            torch.cuda.manual_seed(seed + n)
            train_dl = DataLoader(dataset=train_scaled, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
            valid_dl = DataLoader(dataset=valid_scaled, batch_size=256, collate_fn=collate_fn)
            test_dl  = DataLoader(dataset=test_scaled, batch_size=256, collate_fn=collate_fn)

            encoder = GraphEncoder(atom_net_params)
            decoder = DNN(**decoder_params)
            model = FineTuner(encoder, decoder).cuda()
            opt   = AdamW(model.parameters(), lr=1e-4)
            trainer = Trainer(model, opt, scaler=scaler)

            # `epochs` / `early_stop` were previously assigned but never passed on,
            # so editing them had no effect; forward them explicitly.
            do_epoch(train_dl=train_dl, valid_dl=valid_dl, test_dl=test_dl, trainer=trainer, path=scr_path,
                     epochs=epochs, early_stop=early_stop)


/scratch/graph/max_cg_00:  11%|█         | 1106/10000 [00:35<04:43, 31.33it/s]
/scratch/graph/max_cg_01:  15%|█▌        | 1534/10000 [00:48<04:28, 31.57it/s]
/scratch/graph/add_cg_00:   9%|▉         | 879/10000 [00:27<04:47, 31.78it/s]
/scratch/graph/add_cg_01:  11%|█         | 1084/10000 [00:33<04:36, 32.26it/s]
/scratch/graph/max_transformer_00:  20%|██        | 2009/10000 [01:40<06:41, 19.90it/s]
/scratch/graph/max_transformer_01:  20%|██        | 2035/10000 [01:42<06:40, 19.90it/s]
/scratch/graph/add_transformer_00:  21%|██        | 2097/10000 [01:44<06:33, 20.08it/s]
/scratch/graph/add_transformer_01:  12%|█▏        | 1170/10000 [00:58<07:19, 20.08it/s]


- molnet

In [30]:
# Training budget; both values are forwarded to do_epoch below.
epochs = 10000
early_stop = 500

mol_net_params  = {
    'input_dim':None,
    'hidden_dims':[256, 256, 256],
    'output_dim':64,
}

# Pooling head -> (fine-tuner class, decoder input width used with it).
trainers = {
    'add': (AddPoolFineTuner, 64 * 2), 
    'max': (MaxPoolFineTuner, 64 * 2), 
#    'stack': (StackFineTuner, 64 * 3 * 10), 
}
# Feature width comes from the data generator built in the data-loading cell.
mol_net_params['input_dim'] = dg.n_mol_feat 

# Only aug=0 is run here; `aug` still feeds the run name and the train batch size.
for aug in [0]:
    for trtyp, (FineTuner, decoder_input_dim) in trainers.items():
        for n in range(2):
            scr_path = os.path.join(model_root, 'scratch/mol', f'{trtyp}_a{aug:03d}_dnn_{n:02d}')
            gc.collect()
            torch.cuda.empty_cache()
            # Each repeat gets its own seed so the runs differ reproducibly.
            torch.random.manual_seed(seed + n)
            torch.cuda.manual_seed(seed + n)
            train_dl = DataLoader(dataset=train_scaled, batch_size=batch_size*(aug+1), shuffle=True, collate_fn=collate_fn)
            valid_dl = DataLoader(dataset=valid_scaled, batch_size=2048, collate_fn=collate_fn)
            test_dl  = DataLoader(dataset=test_scaled, batch_size=2048, collate_fn=collate_fn)
            
            encoder = MoleculeEncoder(mol_net_params)
            decoder = DNN(input_dim=decoder_input_dim, hidden_dims=[256, 256], output_dim=1)

            model = FineTuner(encoder, decoder).cuda()
            opt   = AdamW(model.parameters(), lr=1e-5)
            trainer = Trainer(model, opt, scaler)
            
            # `epochs` / `early_stop` were previously assigned but never passed on,
            # so editing them had no effect; forward them explicitly.
            do_epoch(train_dl=train_dl, valid_dl=valid_dl, test_dl=test_dl, trainer=trainer, path=scr_path,
                     epochs=epochs, early_stop=early_stop)

/scratch/mol/add_a000_dnn_00:  15%|█▌        | 1531/10000 [00:37<03:24, 41.38it/s]
/scratch/mol/add_a000_dnn_01:  24%|██▎       | 2368/10000 [00:57<03:06, 41.01it/s]
/scratch/mol/max_a000_dnn_00:  12%|█▏        | 1249/10000 [00:30<03:37, 40.30it/s]
/scratch/mol/max_a000_dnn_01:   9%|▉         | 917/10000 [00:22<03:47, 39.87it/s]


- concat

In [31]:
# Training budget; both values are forwarded to do_epoch below.
epochs = 10000
early_stop = 500

atom_net_params  = {
    'n_atom_feat':None,     'n_bond_feat':None,     'graph':None,
    'hidden_dim':64,        'output_dim':64,        'n_layer':4
}

mol_net_params  = {
    'input_dim':None,       'hidden_dims':[256, 256, 256],      'output_dim':64,
}

# Pooling head -> (fine-tuner class, decoder input width used with it).
trainers = {
    'add': (AddPoolFineTuner, 64 * 3 * 2), 
    'max': (MaxPoolFineTuner, 64 * 3 * 2), 
    'stack': (StackFineTuner, 64 * 3 * 10), 
}

# Feature widths come from the data generator built in the data-loading cell.
atom_net_params['n_atom_feat'] = dg.n_atom_feat
atom_net_params['n_bond_feat'] = dg.n_bond_feat
mol_net_params['input_dim'] = dg.n_mol_feat 

# NOTE: only aug=0 is supported as written. Running with augmentation > 0 requires
# regenerating `dg` with `augmentation=aug`, re-splitting with the stored
# `test_smiles` / `valid_smiles`, and refitting `scaler` / `*_scaled` inside this
# loop. The earlier commented-out draft of that step referenced `to_torch` and
# `train_scale`, neither of which exists in this notebook (`to_tensor` and
# `train_scaled` are the names in use).
for aug in [0]:
    for graph in ['cg','transformer']:
        atom_net_params['graph'] = graph
        
        for trtyp, (FineTuner, decoder_input_dim) in trainers.items():
            # With augmentation, only the 'stack' head is trained.
            if aug > 1 and trtyp in ['add','max']: continue
            for n in range(2):
                scr_path = os.path.join(model_root, 'scratch/cat', f'{trtyp}_a{aug:03d}_{graph}_{n:02d}')
                gc.collect()
                torch.cuda.empty_cache()
                # Each repeat gets its own seed so the runs differ reproducibly.
                torch.random.manual_seed(seed + n)
                torch.cuda.manual_seed(seed + n)
                train_dl = DataLoader(dataset=train_scaled, batch_size=batch_size*(aug+1), shuffle=True, collate_fn=collate_fn)
                valid_dl = DataLoader(dataset=valid_scaled, batch_size=2048, collate_fn=collate_fn)
                test_dl  = DataLoader(dataset=test_scaled, batch_size=2048, collate_fn=collate_fn)
                
                encoder = CATEncoder(atom_net_params, mol_net_params)
                decoder = DNN(input_dim=decoder_input_dim, hidden_dims=[256, 256], output_dim=1)

                model = FineTuner(encoder, decoder).cuda()
                opt   = AdamW(model.parameters(), lr=3e-5)
                trainer = Trainer(model, opt, scaler)
                
                # `epochs` / `early_stop` were previously assigned but never passed on,
                # so editing them had no effect; forward them explicitly.
                do_epoch(train_dl=train_dl, valid_dl=valid_dl, test_dl=test_dl, trainer=trainer, path=scr_path,
                         epochs=epochs, early_stop=early_stop)

/scratch/cat/add_a000_cg_00:  11%|█         | 1118/10000 [00:41<05:31, 26.79it/s]
/scratch/cat/add_a000_cg_01:  19%|█▊        | 1868/10000 [01:08<05:00, 27.08it/s]
/scratch/cat/max_a000_cg_00:  14%|█▎        | 1354/10000 [00:50<05:24, 26.61it/s]
/scratch/cat/max_a000_cg_01:  21%|██        | 2099/10000 [01:18<04:57, 26.57it/s]
/scratch/cat/stack_a000_cg_00:   6%|▋         | 646/10000 [00:23<05:44, 27.14it/s]
/scratch/cat/stack_a000_cg_01:   8%|▊         | 834/10000 [00:30<05:35, 27.32it/s]
/scratch/cat/add_a000_transformer_00:  15%|█▌        | 1516/10000 [01:24<07:53, 17.90it/s]
/scratch/cat/add_a000_transformer_01:   8%|▊         | 780/10000 [00:44<08:40, 17.70it/s]
/scratch/cat/max_a000_transformer_00:   8%|▊         | 789/10000 [00:44<08:40, 17.70it/s]
/scratch/cat/max_a000_transformer_01:  11%|█         | 1104/10000 [01:02<08:24, 17.64it/s]
/scratch/cat/stack_a000_transformer_00:  19%|█▉        | 1883/10000 [01:44<07:29, 18.07it/s]
/scratch/cat/stack_a000_transformer_01:  12%|█▏    

# Pretrained models

- SSIB

In [32]:
# Root directory holding the pretrained encoders.
pt_root = '/home/jhyang/WORKSPACES/MODELS/fpoly/'
# Training budget; both values are forwarded to do_epoch below.
epochs = 10000
early_stop = 500

atom_net_params  = {
    'n_atom_feat':None,     'n_bond_feat':None,     'graph':None,
    'hidden_dim':64,        'output_dim':64,        'n_layer':4
}

mol_net_params  = {
    'input_dim':None,       'hidden_dims':[256, 256, 256],      'output_dim':64,
}

# Pooling head -> (fine-tuner class, decoder input width used with it).
trainers = {
    'add': (AddPoolFineTuner, 64 * 3 * 2), 
    'max': (MaxPoolFineTuner, 64 * 3 * 2), 
    'stack': (StackFineTuner, 64 * 3 * 10), 
}

for aug in [0]:#, 10, 30,]:

    # Feature widths come from the data generator built in the data-loading cell.
    atom_net_params['n_atom_feat'] = dg.n_atom_feat
    atom_net_params['n_bond_feat'] = dg.n_bond_feat
    mol_net_params['input_dim'] = dg.n_mol_feat 

    for graph in ['cg','transformer']:
        atom_net_params['graph'] = graph
        
        for trtyp, (FineTuner, decoder_input_dim) in trainers.items():
            # With augmentation, only the 'stack' head is trained.
            if aug > 1 and trtyp in ['add','max']: continue
            for amnt in ['99k','ALL']:
                pt_path = os.path.join(pt_root, f'encoders/ssib/U_wF_{amnt}/{graph}_pt')
                # Silently skip pretrained encoders that are not present on disk.
                if not os.path.isdir(pt_path):
                    continue
                for relax_after in [10, 300, 600]:
                    for n in range(1):
                        ft_path = os.path.join(model_root, 'finetune/ssib', f'{amnt}', f'{trtyp}_{graph}_r{relax_after:03d}_a{aug:03d}_{n:02d}')
                        gc.collect()
                        torch.cuda.empty_cache()
                        torch.random.manual_seed(seed + n)
                        torch.cuda.manual_seed(seed + n)

                        train_dl = DataLoader(dataset=train_scaled, batch_size=batch_size*(aug+1), shuffle=True, collate_fn=collate_fn)
                        valid_dl = DataLoader(dataset=valid_scaled, batch_size=2048, collate_fn=collate_fn)
                        test_dl  = DataLoader(dataset=test_scaled, batch_size=2048, collate_fn=collate_fn)

                        # The encoder is loaded with freeze=True; do_epoch re-enables
                        # gradients for all parameters after `relax_after` epochs.
                        encoder = CATEncoder(atom_net_params, mol_net_params)
                        encoder.load(pt_path, desc='best', freeze=True)
                        decoder = DNN(input_dim=decoder_input_dim, hidden_dims=[256, 256], output_dim=1)

                        model = FineTuner(encoder, decoder).cuda()
                        opt   = AdamW(model.parameters(), lr=3e-5)
                        trainer = Trainer(model, opt, scaler)

                        # `epochs` / `early_stop` were previously assigned but never passed
                        # on, so editing them had no effect; forward them explicitly.
                        do_epoch(train_dl=train_dl, valid_dl=valid_dl, test_dl=test_dl, trainer=trainer, 
                                 path=ft_path, epochs=epochs, early_stop=early_stop,
                                 relax_after=relax_after)

/finetune/ssib/99k/add_cg_r010_a000_00:   9%|▉         | 933/10000 [00:34<05:33, 27.22it/s]
/finetune/ssib/99k/add_cg_r300_a000_00:   8%|▊         | 759/10000 [00:24<04:56, 31.20it/s]
/finetune/ssib/99k/add_cg_r600_a000_00:   8%|▊         | 759/10000 [00:20<04:11, 36.78it/s]
/finetune/ssib/ALL/add_cg_r010_a000_00:   6%|▌         | 570/10000 [00:20<05:47, 27.16it/s]
/finetune/ssib/ALL/add_cg_r300_a000_00:   7%|▋         | 654/10000 [00:20<04:52, 31.93it/s]
/finetune/ssib/ALL/add_cg_r600_a000_00:   7%|▋         | 654/10000 [00:16<03:58, 39.18it/s]
/finetune/ssib/99k/max_cg_r010_a000_00:   7%|▋         | 734/10000 [00:27<05:44, 26.87it/s]
/finetune/ssib/99k/max_cg_r300_a000_00:   7%|▋         | 698/10000 [00:22<04:58, 31.17it/s]
/finetune/ssib/99k/max_cg_r600_a000_00:   7%|▋         | 698/10000 [00:18<04:07, 37.59it/s]
/finetune/ssib/ALL/max_cg_r010_a000_00:  17%|█▋        | 1666/10000 [01:01<05:09, 26.93it/s]
/finetune/ssib/ALL/max_cg_r300_a000_00:   6%|▋         | 632/10000 [00:19<04:55

In [33]:
# Root directory holding the pretrained encoders.
pt_root = '/home/jhyang/WORKSPACES/MODELS/fpoly/'

# Training budget; both values are forwarded to do_epoch below.
epochs = 10000
early_stop = 500

atom_net_params  = {
    'n_atom_feat':None,     'n_bond_feat':None,     'graph':None,
    'hidden_dim':64,        'output_dim':64,        'n_layer':4
}

mol_net_params  = {
    'input_dim':None,       'hidden_dims':[256, 256, 256],      'output_dim':64,
}

# Pooling head -> (fine-tuner class, decoder input width used with it).
trainers = {
    'add': (AddPoolFineTuner, 64 * 3 * 2), 
    'max': (MaxPoolFineTuner, 64 * 3 * 2), 
    'stack': (StackFineTuner, 64 * 3 * 10), 
}

# NOTE: only aug=0 is supported as written. Running with augmentation > 0 requires
# regenerating `dg` with `augmentation=aug`, re-splitting with the stored
# `test_smiles` / `valid_smiles`, and refitting `scaler` / `*_scaled` inside this
# loop. The earlier commented-out draft of that step referenced `to_torch` and
# `train_scale`, neither of which exists in this notebook (`to_tensor` and
# `train_scaled` are the names in use).
for aug in [0]:#, 10, 30]:

    # Feature widths come from the data generator built in the data-loading cell.
    atom_net_params['n_atom_feat'] = dg.n_atom_feat
    atom_net_params['n_bond_feat'] = dg.n_bond_feat
    mol_net_params['input_dim'] = dg.n_mol_feat 

    for graph in ['cg','transformer']:
        atom_net_params['graph'] = graph
        
        for trtyp, (FineTuner, decoder_input_dim) in trainers.items():
            # With augmentation, only the 'stack' head is trained.
            if aug > 1 and trtyp in ['add','max']: continue
            for amnt in ['99k','ALL']:
                for n in range(2):
                    pt_path = os.path.join(pt_root, f'encoders/cat/U_wF_{amnt}/{graph}_{n:02d}')
                    # Silently skip pretrained encoders that are not present on disk.
                    if not os.path.isdir(pt_path):
                        continue
                    for relax_after in [0, 300, 600]:
                        ft_path = os.path.join(model_root, 'finetune/cat', f'{amnt}', f'{trtyp}_{graph}_r{relax_after:03d}_a{aug:03d}_{n:02d}')
                        gc.collect()
                        torch.cuda.empty_cache()
                        torch.random.manual_seed(seed + n)
                        torch.cuda.manual_seed(seed + n)

                        train_dl = DataLoader(dataset=train_scaled, batch_size=batch_size*(aug+1), shuffle=True, collate_fn=collate_fn)
                        valid_dl = DataLoader(dataset=valid_scaled, batch_size=2048, collate_fn=collate_fn)
                        test_dl  = DataLoader(dataset=test_scaled, batch_size=2048, collate_fn=collate_fn)

                        # The encoder is loaded with freeze=True; do_epoch re-enables
                        # gradients for all parameters after `relax_after` epochs.
                        encoder = CATEncoder(atom_net_params, mol_net_params)
                        encoder.load(pt_path, desc='best', freeze=True)
                        decoder = DNN(input_dim=decoder_input_dim, hidden_dims=[256, 256], output_dim=1)

                        model = FineTuner(encoder, decoder).cuda()
                        opt   = AdamW(model.parameters(), lr=3e-5)
                        trainer = Trainer(model, opt, scaler)

                        # `epochs` / `early_stop` were previously assigned but never passed
                        # on, so editing them had no effect; forward them explicitly.
                        do_epoch(train_dl=train_dl, valid_dl=valid_dl, test_dl=test_dl, trainer=trainer, 
                                 path=ft_path, epochs=epochs, early_stop=early_stop,
                                 relax_after=relax_after)

/finetune/cat/99k/add_cg_r000_a000_00:  15%|█▍        | 1452/10000 [00:54<05:18, 26.88it/s]
/finetune/cat/99k/add_cg_r300_a000_00:  13%|█▎        | 1286/10000 [00:43<04:55, 29.47it/s]
/finetune/cat/99k/add_cg_r600_a000_00:  23%|██▎       | 2306/10000 [01:17<04:18, 29.77it/s]
/finetune/cat/99k/add_cg_r000_a000_01:   6%|▌         | 565/10000 [00:21<05:52, 26.75it/s]
/finetune/cat/99k/add_cg_r300_a000_01:   5%|▌         | 532/10000 [00:15<04:41, 33.67it/s]
/finetune/cat/99k/add_cg_r600_a000_01:   6%|▌         | 600/10000 [00:14<03:52, 40.50it/s]
/finetune/cat/ALL/add_cg_r000_a000_00:  22%|██▏       | 2240/10000 [01:22<04:47, 27.04it/s]
/finetune/cat/ALL/add_cg_r300_a000_00:  20%|██        | 2040/10000 [01:11<04:39, 28.53it/s]
/finetune/cat/ALL/add_cg_r600_a000_00:   8%|▊         | 805/10000 [00:22<04:14, 36.19it/s]
/finetune/cat/ALL/add_cg_r000_a000_01:  12%|█▏        | 1235/10000 [00:45<05:24, 26.98it/s]
/finetune/cat/ALL/add_cg_r300_a000_01:  23%|██▎       | 2287/10000 [01:20<04:30, 28.