In [1]:
import sys, torch, os
sys.path.append('..')
from src.networks import TransformerDecoderBlock
from src.data import ReactionDataset
from src.feature import PrecursorDataset
from src.trainer import BaseTrainer, SequenceTrainer
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import matplotlib.pyplot as plt

In [2]:
DS = ReactionDataset(feat_type='cgcnn')
DS.from_file('../data/scrxn.pkl.gz', 
             heat_temp_key=('heat_temp','median'))

years = np.array([d.year for d in DS])
train_mask = years < 2016
valid_mask = (years >= 2016) & years < 2018
test_mask = years >= 2018

train_dl = DataLoader(DS, batch_size=256, sampler=SubsetRandomSampler(np.where(train_mask)[0]), collate_fn=DS.cfn)
valid_dl = DataLoader(DS, batch_size=32, sampler=np.where(train_mask)[0], collate_fn=DS.cfn)
test_dl = DataLoader(DS, batch_size=32, sampler=np.where(train_mask)[0], collate_fn=DS.cfn)

In [3]:
feat, info = next(iter(valid_dl))
for k,v in feat.items():
    print(k, v.shape, type(v))

target torch.Size([32, 7, 91]) <class 'torch.Tensor'>
precursor_feat torch.Size([32, 8]) <class 'torch.Tensor'>
label torch.Size([20384]) <class 'torch.Tensor'>
context torch.Size([32, 92]) <class 'torch.Tensor'>
weight torch.Size([224]) <class 'torch.Tensor'>
mask torch.Size([224]) <class 'torch.Tensor'>


#### tested classes

In [17]:
from src.trainer import BaseTrainer
from src.networks import BaseNetwork
from typing import Iterable, List, Dict

class PositionalEncoding:
    def __init__(self, num_dim, max_len=100):
        pe = torch.zeros((max_len, num_dim)).float()
        pos = torch.arange(0, max_len).reshape(-1,1).float()
        div = torch.exp(- torch.arange(0, num_dim, 2) * torch.log(torch.tensor([10000])) / num_dim).float()
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.pe = torch.nn.Parameter(pe, requires_grad=False)

    def __call__(self, seq_len):
        return self.pe[:seq_len].unsqueeze(0)

class TestTransformerDecoderBlock(BaseNetwork):
    def __init__(self,
                 feature_dim:int,
                 context_dim:int,
                 vocab_dim:int, 
                 num_heads:int = 4,
                 hidden_dim:int = 32,
                 ff_dim_mul:int = 2,
                 hidden_layers:int = 2,
                 positional_encoding:bool = True,
                 batch_norm:bool = False,
                 dropout = 0.2,
                 activation:str = 'LeakyReLU',
                 negative_slope:float = 0.1,
                 ):
        
        super().__init__(context_dim = context_dim,
                         vocab_dim = vocab_dim,
                         num_heads = num_heads,
                         hidden_dim = hidden_dim,
                         hidden_layers = hidden_layers,
                         positional_encoding = positional_encoding,
                         batch_norm = batch_norm,
                         dropout = dropout,
                         activation = activation,
                         negative_slope = negative_slope,
        )

        self.vocab_dim = vocab_dim
        try:
            activation = eval(f'torch.nn.{activation}({negative_slope})')
        except:
            activation = eval(f'torch.nn.{activation}()')

        self.positional_encoding = PositionalEncoding(hidden_dim) if positional_encoding else False

        self.feature_embed_layer = torch.nn.Sequential(
            torch.nn.Linear(feature_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
#            torch.nn.Dropout(dropout),
            activation,            
        )

        self.context_embed_layer = torch.nn.Sequential(
            torch.nn.Linear(context_dim, hidden_dim),
            torch.nn.BatchNorm1d(hidden_dim),
#            torch.nn.Dropout(dropout),
            activation,
        )

        self.transformer_decoder = torch.nn.TransformerDecoder(
            torch.nn.TransformerDecoderLayer(
                d_model = hidden_dim,
                nhead = num_heads,
                dim_feedforward = hidden_dim * 2,
                dropout = dropout,
                activation = activation,
                batch_first = True,
            ), 
            num_layers = hidden_layers, 
            norm = torch.nn.BatchNorm1d(hidden_dim) if batch_norm else None,
        )
        self.output_layer = torch.nn.Linear(hidden_dim, vocab_dim)

    def forward(self, target, context, *args, **kwargs):
        n = target.shape[1]
        target_mask = torch.nn.Transformer.generate_square_subsequent_mask(n)

        if self.positional_encoding: 
            target += self.positional_encoding(n)
        context_embed = self.context_embed_layer(context).unsqueeze(1)#.repeat(1, n, 1)

        h = self.transformer_decoder(target, context_embed, tgt_mask=target_mask)
        out = self.output_layer(h)
        return out
    
    def generate(self, context, embed_fn, max_len=20, *args, **kwargs):
        output_seq = torch.ones(context.shape[0], 1).long().to(self.device) + self.vocab_dim 
        for _ in range(max_len):
            output = self.forward(output_seq, context)
            output_seq = torch.hstack([output_seq, output.argmax(-1)[:, -1:]])
        return output_seq.cpu().numpy()[:, 1:]

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        if self.positional_encoding:
            self.positional_encoding.pe = self.positional_encoding.pe.to(*args, **kwargs)

    def save(self, path, prefix, overwrite=True):
        self._save(path, f'{prefix}.model', overwrite)

class TestSequenceTrainer(BaseTrainer):
    def __init__(self, model, lr, device='cuda', crit=torch.nn.CrossEntropyLoss(reduction='none')):
        super().__init__(model, lr, device, crit)
    
    def _eval_batch(self, batch, compute_loss=True):
        _feat, _ = batch
        feat = {k:v.to(self.device) for k,v in _feat.items()}
        pred = self.model(**feat)
        if compute_loss:
            _loss = self.crit(pred.view(feat['label'].shape[0], -1), feat['label'])[feat['mask']]
#            print(_loss.shape, feat['weight'].shape, feat['mask'].shape)
            loss = (_loss * feat['weight'][feat['mask']]).mean()
            return loss, pred.detach().cpu().numpy()
        else:
            return pred.detach().cpu().numpy()
    
    def _parse_output(self, batch, output):
        feat, info = batch
        if self._output is None:
            self._output = {
                'info' : info,
                'pred' : output
            }
            if feat['weight'] is not None:
                n = feat['context'].shape[0]
                self._output.update({
                    'label': feat['label'].cpu().numpy().reshape(n, -1),
                    'weight': feat['weight'].cpu().numpy().reshape(n, -1)[:, 0],
                })
        else:
            self._output['info'].extend(info)
            self._output['pred'] = np.vstack([self._output['pred'], output])
            if feat['weight'] is not None:
                n = feat['context'].shape[0]
                self._output['label'] = np.vstack([self._output['label'], feat['label'].cpu().numpy().reshape(n, -1)])
                self._output['weight'] = np.hstack([self._output['weight'], feat['weight'].cpu().numpy().reshape(n, -1)[:,0]])

# Test

In [18]:
model = TestTransformerDecoderBlock(DS.num_condition_feat, DS.precursor_dataset.NUM_LABEL, num_heads=4, hidden_dim=64, hidden_layers=4)
model.forward(target  = feat['target'], 
              context = feat['context'],)
#tr = TestSequenceTrainer(model, lr=1e-4)
#for i in range(5):
#    train_loss = tr.train(train_dl)
#    valid_loss, out = tr.test(valid_dl)
#    test_loss, out = tr.test(test_dl)
#    print('{:4d} {:9.5f} {:9.5f} {:9.5f}'.format(i, train_loss, valid_loss, test_loss))

tensor([[[-4.5063e-01,  5.7888e-01,  5.0418e-01,  ..., -3.0448e-01,
          -3.8087e-01, -7.0146e-02],
         [ 2.0862e-01,  8.1354e-01,  5.6663e-01,  ...,  2.0481e-01,
           2.8978e-01,  9.8414e-02],
         [-2.5659e-01,  8.2889e-01,  1.0078e+00,  ..., -4.5890e-01,
           9.8797e-01, -2.6617e-01],
         ...,
         [-1.0256e-01,  1.0854e+00,  7.8896e-01,  ..., -5.8150e-01,
           7.1035e-01,  2.4221e-01],
         [ 8.1696e-02,  8.7025e-01,  9.6989e-01,  ..., -8.9164e-01,
           3.5854e-01,  5.2547e-01],
         [-4.7880e-02,  1.0109e+00,  8.1174e-01,  ..., -3.8134e-01,
           9.7241e-01,  4.4252e-01]],

        [[-6.0453e-01,  3.7663e-01,  4.2093e-01,  ..., -3.4342e-01,
          -4.8789e-01, -3.8142e-01],
         [-5.7945e-01,  4.0511e-01,  5.2471e-01,  ...,  6.0483e-01,
           8.2484e-01,  2.2529e-01],
         [-6.9531e-01,  1.0187e-01,  1.3416e+00,  ..., -1.5803e-01,
           1.0385e+00, -3.0737e-01],
         ...,
         [-4.5443e-02,  8

In [28]:
pred = out['pred'].argmax(-1)
label = out['label']
weight = out['weight']
n_data, l_seq = pred.shape
mask = np.hstack([np.ones((n_data, 1), dtype=bool), (label != EOS_LABEL)[..., :-1]])
acc_rxn = []
for p, l, m, w in zip(pred, label, mask, weight):
    hit = (p[m] != l[m]).sum() == 0
    acc_rxn.append(hit)
len(acc_rxn), weight.shape

(27967, (27967,))

In [39]:
model = TransformerDecoderBlock(DS.num_condition_feat, num_heads=4, hidden_dim=64, hidden_layers=4)
tr = SequenceTrainer(model, lr=1e-4)
for i in range(100):
    train_loss = tr.train(train_dl)
    valid_loss, out = tr.test(valid_dl)
    test_loss, out = tr.test(test_dl)
    print('{:4d} {:9.5f} {:9.5f} {:9.5f}'.format(i, train_loss, valid_loss, test_loss))

   0   1.99951   1.44498   1.44493
   1   1.32844   1.18909   1.18908
   2   1.15493   1.06664   1.06657
   3   1.05857   0.98836   0.98805
   4   0.99526   0.93284   0.93303
   5   0.94666   0.88761   0.88701
   6   0.90799   0.85305   0.85327
   7   0.87553   0.82254   0.82096
   8   0.85084   0.79534   0.79469
   9   0.82855   0.77896   0.77786
  10   0.80872   0.75424   0.75411
  11   0.79169   0.73576   0.73554
  12   0.77698   0.72410   0.72531
  13   0.76392   0.70815   0.70855
  14   0.75060   0.69569   0.69644
  15   0.73886   0.68417   0.68380
  16   0.73034   0.67255   0.67398
  17   0.72001   0.66320   0.66334
  18   0.71236   0.65547   0.65561
  19   0.70242   0.64979   0.64865
  20   0.69550   0.64075   0.64030
  21   0.69033   0.63113   0.63170
  22   0.68170   0.62679   0.62562
  23   0.67514   0.61784   0.61814
  24   0.66914   0.61233   0.61180
  25   0.66434   0.60926   0.60969
  26   0.65751   0.60311   0.60178
  27   0.65465   0.59412   0.59472
  28   0.64847   0.5

In [65]:
def generate(self, context, max_len=20):
        output_seq = torch.ones(context.shape[0], 1).long().to(self.device) * 443
        for _ in range(max_len - 1):
            output = self.forward(output_seq, context)
            output_seq = torch.hstack([output_seq, output.argmax(-1)[:, -1:]])
        seq = output_seq.cpu().numpy()[:, 1:]
        j = (seq != EOS_LABEL).sum(1).max()
        return seq[:, :j]

generate(model, feat['context'][:3].to('cuda'))

array([[5, 5, 5, 5, 8, 5, 5, 8, 5, 8, 5, 8, 5, 8, 5, 8, 5, 8, 5],
       [5, 5, 5, 5, 8, 5, 5, 8, 5, 8, 5, 8, 5, 8, 5, 8, 5, 8, 5],
       [5, 5, 5, 8, 5, 5, 5, 8, 5, 8, 5, 8, 5, 8, 5, 8, 5, 8, 5]])

In [64]:
for _prd, _lbl in zip(out_gen, feat['label'].cpu().numpy()):
    prd = np.array(sorted(_prd[_prd != EOS_LABEL]))
    lbl = np.array(sorted(_lbl[_lbl != EOS_LABEL]))
    if len(prd) != len(lbl):
        print(prd, lbl)
    elif np.sum(prd != lbl) != 0:
        print(prd, lbl)

[2 5] [  5 148]
[2 5] [  5 148]
[2 5] [  2 107]
[2 5] [  5 108]
[5 8] [  8 124]
[  5 106] [  5 189]
[ 0 16 29] [  0  16 113]
[ 0 16 29] [  0  16 113]
[ 0 16 29] [  0  16 113]
[ 0 16 29] [  0  16 182]
[ 0 16 29] [  0  16 182]
[  8  14 126 161] [ 14  66 108 250]
[  8  14 126 161] [ 14  66 108 250]
[  8  14 126 161] [ 14  66 108 250]
[ 8 51] [ 8 24 51]
[ 8 51] [ 8 24 51]
[ 8 51] [ 8 24 51]
[ 7 18 29] [  7  29 177]
[ 7 18 29] [  7  29 177]
[ 7 18 29] [  7  29 177]
[ 7 18 29] [  7  29 177]
[ 7 18 29] [  7  29 177]
[ 7 18 29] [  7  88 303]
[ 7 18 29] [  7  88 113]
[ 7 18 29] [  7  88 113]
[ 7 18 29] [  7  88 113]


In [70]:
out_gen

array([[  5,   2, 444, 444],
       [  5,   2, 444, 444],
       [  5,   2, 444, 444],
       ...,
       [ 12,   9, 444, 444],
       [ 12,   9, 444, 444],
       [ 12,   9, 444, 444]])

In [48]:
for l, m in zip(feat['label'], feat['target'] != 444):
    print(l[m].cpu(), l)

tensor([  0,  13,  14,   4, 444]) tensor([  0,  13,  14,   4, 444, 444, 444], device='cuda:0')
tensor([  5,  65, 444]) tensor([  5,  65, 444, 444, 444, 444, 444], device='cuda:0')
tensor([  4,  76, 444]) tensor([  4,  76, 444, 444, 444, 444, 444], device='cuda:0')
tensor([ 27,  49, 444]) tensor([ 27,  49, 444, 444, 444, 444, 444], device='cuda:0')
tensor([  6, 197,  18, 444]) tensor([  6, 197,  18, 444, 444, 444, 444], device='cuda:0')
tensor([ 50,   2,   5,   0, 444]) tensor([ 50,   2,   5,   0, 444, 444, 444], device='cuda:0')
tensor([  1,  32,  25, 444]) tensor([  1,  32,  25, 444, 444, 444, 444], device='cuda:0')
tensor([  3,   0,   1, 444]) tensor([  3,   0,   1, 444, 444, 444, 444], device='cuda:0')
tensor([  2,   7,  13, 444]) tensor([  2,   7,  13, 444, 444, 444, 444], device='cuda:0')
tensor([ 79,  27,  49, 444]) tensor([ 79,  27,  49, 444, 444, 444, 444], device='cuda:0')
tensor([  6,   1,  30, 444]) tensor([  6,   1,  30, 444, 444, 444, 444], device='cuda:0')
tensor([ 10,   

# profiling

In [2]:
import sys
sys.path.append('..')
from src.data import ReactionDataset
from src.networks import TransformerDecoderBlock
import torch, gc
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch.profiler import profile, record_function, ProfilerActivity
import numpy as np

In [3]:
DS = ReactionDataset(feat_type='cgcnn')
DS.from_file('../data/screened_conditional_reaction.pkl.gz', 
             heat_temp_key=('heat_temp','median'))
#DS.to('cuda')

years = np.array([d.year for d in DS])
train_mask = years < 2016
valid_mask = (years >= 2016) & years < 2018
test_mask = years >= 2018

In [40]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.cuda.init()

train_dl = DataLoader(DS, batch_size=64, sampler=SubsetRandomSampler(np.where(train_mask)[0]), 
                      collate_fn=DS.cfn)#num_workers=1, prefetch_factor=4, collate_fn=DS.cfn)

valid_dl = DataLoader(DS, batch_size=2048, sampler=np.where(train_mask)[0], collate_fn=DS.cfn)
test_dl = DataLoader(DS, batch_size=2048, sampler=np.where(train_mask)[0], collate_fn=DS.cfn)

model = TransformerDecoderBlock(DS.num_condition_feat)
model.to('cuda')
opt = torch.optim.AdamW(model.parameters(), lr=1e-4)
crit = torch.nn.CrossEntropyLoss(reduction='none')

with profile(
    activities=[
        ProfilerActivity.CPU,
        ProfilerActivity.CUDA,
    ],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler('../dump/log_dir'),
    record_shapes=True,
    profile_memory=True,
    with_stack=True
    ) as prof:


    for _feat, _ in train_dl:
        feat = {k:v.to('cuda') for k,v in _feat.items()}
        n_batch, l_seq = feat['label'].shape
        with record_function('model_inference'):
            pred = model(**feat)
#        with record_function('0_compute_loss'):
            _loss = crit(pred.reshape(n_batch * l_seq, -1), feat['label'].view(-1))
#            loss = (_loss * weight)[mask].mean()
            loss = (_loss * feat['weight']).mean()
        with record_function('backward_pass'):
            opt.zero_grad()
            loss.backward()
            opt.step()
        prof.step()
gc.collect()
torch.cuda.empty_cache()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=100))
#print(prof.key_averages().table(row_limit=100))

STAGE:2024-06-14 14:23:23 787949:787949 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-06-14 14:23:23 787949:787949 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-06-14 14:23:23 787949:787949 ActivityProfilerController.cpp:324] Completed Stage: Post Processing
STAGE:2024-06-14 14:23:24 787949:787949 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
STAGE:2024-06-14 14:23:24 787949:787949 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2024-06-14 14:23:24 787949:787949 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         0.90%     322.000us        79.97%      28.589ms       9.530ms       0.000us         0.00%       1.259ms     419.667us      29.19 Kb     -61.88 Kb           0 b     -97.50 K

In [11]:
print(prof.key_averages().table(sort_by="cpu_time", row_limit=100))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*         0.58%     434.000us        90.70%      68.182ms      22.727ms       0.000us         0.00%       3.002ms       1.001ms     116.75 Kb    -233.50 Kb    -725.00 Kb           0 

In [13]:
with open('../dump/profiler_B64_NW1_PF0_indexed.txt','w') as f: 
    f.write(prof.key_averages().table(sort_by="cpu_time_total", row_limit=100))