In [1]:
import sys, torch, os
sys.path.append('..')
from src.networks import TransformerDecoderBlock
from src.data import ReactionDataset
from src.feature import NUM_LABEL, EOS_LABEL, SOS_LABEL
from src.trainer import BaseTrainer
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import matplotlib.pyplot as plt

In [2]:
DS = ReactionDataset(feat_type='cgcnn')
DS.from_file('../data/screened_conditional_reaction.pkl.gz', 
             heat_temp_key=('heat_temp','median'))
DS.to('cuda')

In [3]:
years = np.array([d.year for d in DS])
train_mask = years < 2016
valid_mask = (years >= 2016) & years < 2018
test_mask = years >= 2018

train_dl = DataLoader(DS, batch_size=64, sampler=SubsetRandomSampler(np.where(train_mask)[0]), collate_fn=DS.cfn)
valid_dl = DataLoader(DS, batch_size=256, sampler=np.where(train_mask)[0], collate_fn=DS.cfn)
test_dl = DataLoader(DS, batch_size=256, sampler=np.where(train_mask)[0], collate_fn=DS.cfn)

In [10]:
feat, info = next(iter(valid_dl))
for k,v in feat.items():
    print(k, v.shape, v.device)

target torch.Size([256, 7]) cuda:0
label torch.Size([256, 7]) cuda:0
context torch.Size([256, 92]) cuda:0
weight torch.Size([1792]) cuda:0
mask torch.Size([1792]) cuda:0


In [38]:
class SequenceTrainer(BaseTrainer):
    def __init__(self, model, lr, device='cuda', crit=torch.nn.CrossEntropyLoss(reduction='none')):
        super().__init__(model, lr, device, crit)
    
    def _eval_batch(self, batch, compute_loss=True):
        feat, _ = batch
        n_batch, l_seq = feat['label'].shape
        pred = self.model(**feat)
        if compute_loss:
            _loss = self.crit(pred.reshape(n_batch * l_seq, -1), feat['label'].reshape(-1))
#            loss = (_loss * feat['weight'])[feat['mask']].mean()
            loss = (_loss * feat['weight']).mean()
            return loss, pred.detach().cpu().numpy()
        else:
            return pred.detach().cpu().numpy()
    
    def _parse_output(self, batch, output):
        feat, info = batch
        if self._output is None:
            self._output = {
                'info' : info,
                'pred' : output
            }
            if feat['weight'] is not None:
                self._output.update({
                    'label': feat['label'].cpu().numpy(),
                    'weight': feat['weight'].cpu().numpy(),
                    'mask': feat['mask'].cpu().numpy(),
                })
        else:
            self._output['info'].extend(info)
            self._output['pred'] = np.vstack([self._output['pred'], output])
            if feat['weight'] is not None:
                self._output['label'] = np.vstack([self._output['label'], feat['label'].cpu().numpy()])
                self._output['weight'] = np.hstack([self._output['weight'], feat['weight'].cpu().numpy()])
                self._output['mask'] = np.hstack([self._output['mask'], feat['mask'].cpu().numpy()])

#    def predict(self, dataloader, n_samples=1000):


In [29]:
model = TransformerDecoderBlock(DS.num_condition_feat, num_heads=4, hidden_dim=64, hidden_layers=4)
tr = SequenceTrainer(model, lr=1e-4)
for i in range(100):
    train_loss = tr.train(train_dl)
    valid_loss, out = tr.test(valid_dl)
    test_loss, out = tr.test(test_dl)
    print('{:4d} {:9.5f} {:9.5f} {:9.5f}'.format(i, train_loss, valid_loss, test_loss))

   0   3.13832   2.52148   2.52187
   1   2.47638   2.18071   2.18079
   2   2.21579   1.98539   1.98508
   3   2.05256   1.85142   1.85024
   4   1.94766   1.75977   1.75768
   5   1.86269   1.68644   1.68594
   6   1.79632   1.62716   1.62804
   7   1.74248   1.57574   1.57592
   8   1.69888   1.53373   1.53657
   9   1.66051   1.49665   1.49627
  10   1.62338   1.46147   1.45949
  11   1.59392   1.43348   1.43367
  12   1.56763   1.40294   1.40394
  13   1.54363   1.38331   1.38248
  14   1.52056   1.36490   1.36676
  15   1.50192   1.33983   1.34271
  16   1.48459   1.32154   1.32077
  17   1.46405   1.30720   1.31126
  18   1.45042   1.29123   1.29500
  19   1.43399   1.27776   1.27806
  20   1.42522   1.26168   1.26235
  21   1.40874   1.25202   1.25043
  22   1.39563   1.24982   1.25063
  23   1.38589   1.22967   1.23000
  24   1.37682   1.21290   1.21432
  25   1.36640   1.20665   1.20634
  26   1.35416   1.19996   1.20039
  27   1.34338   1.18648   1.18534
  28   1.33754   1.1

In [39]:
model = TransformerDecoderBlock(DS.num_condition_feat, num_heads=4, hidden_dim=64, hidden_layers=4)
tr = SequenceTrainer(model, lr=1e-4)
for i in range(100):
    train_loss = tr.train(train_dl)
    valid_loss, out = tr.test(valid_dl)
    test_loss, out = tr.test(test_dl)
    print('{:4d} {:9.5f} {:9.5f} {:9.5f}'.format(i, train_loss, valid_loss, test_loss))

   0   1.99951   1.44498   1.44493
   1   1.32844   1.18909   1.18908
   2   1.15493   1.06664   1.06657
   3   1.05857   0.98836   0.98805
   4   0.99526   0.93284   0.93303
   5   0.94666   0.88761   0.88701
   6   0.90799   0.85305   0.85327
   7   0.87553   0.82254   0.82096
   8   0.85084   0.79534   0.79469
   9   0.82855   0.77896   0.77786
  10   0.80872   0.75424   0.75411
  11   0.79169   0.73576   0.73554
  12   0.77698   0.72410   0.72531
  13   0.76392   0.70815   0.70855
  14   0.75060   0.69569   0.69644
  15   0.73886   0.68417   0.68380
  16   0.73034   0.67255   0.67398
  17   0.72001   0.66320   0.66334
  18   0.71236   0.65547   0.65561
  19   0.70242   0.64979   0.64865
  20   0.69550   0.64075   0.64030
  21   0.69033   0.63113   0.63170
  22   0.68170   0.62679   0.62562
  23   0.67514   0.61784   0.61814
  24   0.66914   0.61233   0.61180
  25   0.66434   0.60926   0.60969
  26   0.65751   0.60311   0.60178
  27   0.65465   0.59412   0.59472
  28   0.64847   0.5

In [47]:
out['pred'].argmax(-1)[:30, :5] - out['label'][:30, :5]

array([[3, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [0, 0, 0, 0, 0],
       [3, 0, 0, 0, 0],
       [3, 0, 0, 0, 0]])