In [1]:
import os, sys

# Put the directory one level above the current working directory on sys.path
# so that the project package `src` can be imported from this notebook.
parent_dir = os.path.abspath(os.pardir)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from src import utils

# Module-level logger obtained from the project's helper.
log = utils.get_logger(__name__)

In [2]:
# Re-import edited modules (e.g. under src/) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

## Data Modules

In [188]:
import torch
import numpy as np
from glob import glob


from ptls.data_load import IterableChain
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from ptls.data_load.datasets.parquet_dataset import ParquetDataset, ParquetFiles

import os
import warnings

# Column names and filtering thresholds for the age_pred transaction data.
dataset_conf = {
    'min_seq_len': 10,           # shorter sequences are dropped by SeqLenFilter
    'event_col': 'small_group',  # event-type column
    'time_col': 'event_time',    # timestamp column (rescaled by TimeProc)
    'event_cnt_col': 'trx_count',
    'num_types': 250,
    'target_col': 'target',
}

# Data location. Override with the AGE_PRED_DATA_DIR environment variable rather than
# editing a machine-specific absolute path; the default keeps the original location.
DATA_DIR = os.environ.get('AGE_PRED_DATA_DIR',
                          '/home/morlov/ptls-experiments/scenario_age_pred/data')

train_data = glob(os.path.join(DATA_DIR, 'train_trx_file.parquet'))
valid_data = glob(os.path.join(DATA_DIR, 'test_trx_file.parquet'))

# glob() silently returns [] for a missing path; surface that here instead of later
# as (presumably) empty datasets.
for _split, _files in (('train', train_data), ('valid', valid_data)):
    if not _files:
        warnings.warn(f'No {_split} parquet data found under {DATA_DIR!r}')

            
class TimeProc(IterableProcessingDataset):
    """Min-max rescale the timestamp column of every record.

    Each incoming record is either a feature mapping or a tuple whose first
    element is the feature mapping (e.g. ``(features, target)``). In both cases
    the time column of the feature mapping is replaced by
    ``(t - tmin) / (tmax - tmin)`` and the record itself is yielded.

    Args:
        time_col: key of the timestamp sequence inside the feature mapping.
        tmin: value mapped to 0.
        tmax: value mapped to 1; must differ from ``tmin``.

    Raises:
        ValueError: if ``tmax == tmin`` (the rescaling would divide by zero).
    """

    def __init__(self, time_col, tmin, tmax):
        super().__init__()
        if tmax == tmin:
            raise ValueError(f'tmax must differ from tmin, got tmin=tmax={tmin!r}')
        self._time_col = time_col
        self.tmin, self.tmax = tmin, tmax

    def __iter__(self):
        scale = self.tmax - self.tmin
        for rec in self._src:
            features = rec[0] if isinstance(rec, tuple) else rec
            # Write into the feature mapping, not the record: tuples are immutable,
            # so assigning to rec[...] would raise TypeError for (features, target) records.
            features[self._time_col] = np.asarray(
                (features[self._time_col] - self.tmin) / scale)
            yield rec
            

def pp_collate_fn(time_col, event_col, event_cnt_col, return_len=False):
    """Create a DataLoader ``collate_fn`` for event sequences.

    The returned function takes a list of records (mappings keyed by the given
    column names) and returns the list ``[times, types]``: time and event-type
    sequences right-padded with zeros to the longest one (``batch_first``),
    with event types cast to ``long``.

    With ``return_len=True`` it returns ``([times, types], last_idx)`` instead,
    where ``last_idx[i] = batch[i][event_cnt_col] - 1``.
    """
    pad = torch.nn.utils.rnn.pad_sequence

    def fn(batch):
        # One (time, type, count - 1) triple per record, read in that order.
        rows = [(rec[time_col], rec[event_col], rec[event_cnt_col] - 1) for rec in batch]
        padded = [
            pad([row[0] for row in rows], batch_first=True),
            pad([row[1] for row in rows], batch_first=True).long(),
        ]
        if not return_len:
            return padded
        return padded, torch.tensor([row[2] for row in rows])

    return fn

# Per-record pipeline shared by the train and validation datasets:
# drop sequences shorter than `min_seq_len`, rescale timestamps with
# TimeProc(tmin=0, tmax=1000), then apply ToTorch().
process = IterableChain(
    SeqLenFilter(min_seq_len=dataset_conf['min_seq_len']),
    TimeProc(dataset_conf['time_col'], 0, 1000),
    ToTorch(),
)

train_ds, valid_ds = (ParquetDataset(files, post_processing=process)
                      for files in (train_data, valid_data))

collate_fn = pp_collate_fn(*(dataset_conf[key] for key in ('time_col', 'event_col', 'event_cnt_col')))

# Both loaders share the same batching setup.
_loader_kwargs = dict(collate_fn=collate_fn, num_workers=4, batch_size=32)
train_dl = torch.utils.data.DataLoader(dataset=train_ds, **_loader_kwargs)
valid_dl = torch.utils.data.DataLoader(dataset=valid_ds, **_loader_kwargs)

## COTIC

In [4]:
from src.models.components.cont_cnn import CCNN
from src.models.components.cont_cnn import Kernel
from src.models.components.cont_cnn import PredictionHead


# Seed torch's global RNG so the model's random weight initialisation is
# reproducible on Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Channel width shared by the conv stack, the kernel network and the prediction head.
nb_filters = 16
num_types = dataset_conf['num_types']

kernel = Kernel(hidden1=8, hidden2=4, hidden3=8, in_channels=nb_filters, out_channels=nb_filters)

head = PredictionHead(in_channels=nb_filters, num_types=num_types)

# NOTE(review): in_channels=32 is set independently of nb_filters — confirm what it sizes in CCNN.
net = CCNN(in_channels=32, kernel_size=5, nb_filters=nb_filters, nb_layers=9,
           num_types=num_types, kernel=kernel, head=head)

## Event Module

In [5]:
from src.models.base_model import BaseEventModule
from src.metrics.cont_cnn import CCNNMetrics
from src.utils.metrics import MetricsCore
from src.utils.metrics import MAE, Accuracy
import torch

from omegaconf import OmegaConf,open_dict


# Optimiser / LR-scheduler settings; presumably interpreted by BaseEventModule
# (see src.models.base_model for the expected keys).
optimizer_conf = {"name": "adam", "params": {"lr": 0.01, "weight_decay": 1e-8}}
scheduler_conf = {"milestones": [40, 75], "gamma": 0.1, "step": None}
train_conf = OmegaConf.create({'optimizer': optimizer_conf, 'scheduler': scheduler_conf})

# Loss/metric bundle: MAE on return times, accuracy on event types,
# with the time loss weighted 10x relative to the type loss.
metrics = CCNNMetrics(
    return_time_metric=MAE(),
    event_type_metric=Accuracy(),
    type_loss_coeff=1,
    time_loss_coeff=10,
    sim_size=40,
    reductions={'log_likelihood': 'mean', 'type': 'sum', 'time': 'mean'},
)

# Init lightning model
model = BaseEventModule(
    net=net,
    metrics=metrics,
    optimizer=train_conf.optimizer,
    scheduler=train_conf.scheduler,
    head_start=1,
)

  rank_zero_warn(


## WaveNet baseline (disabled)

In [6]:
# WaveNet baseline — disabled (every line below is commented out).
# Uncommenting it rebinds the module-level `net` and `metrics` names that the
# COTIC cells above also define, and builds a separate `wn_model`; unlike the
# COTIC module, no scheduler is passed to it.
# from src.models.base_model import BaseEventModule
# from src.metrics.baselines.wavenet import WNMetrics
# from src.models.components.baselines.wavenet import WaveNet
# from src.utils.metrics import MetricsCore
# from src.utils.metrics import MAE, Accuracy
# import torch

# from omegaconf import OmegaConf,open_dict



# metrics = WNMetrics(return_time_metric = MAE(), event_type_metric = Accuracy())
        
# wn_conf = OmegaConf.create({'optimizer': {"name": "adam", "params": {"lr": 0.01, "weight_decay": 1e-5}}})

# net = WaveNet(in_channels=32, num_types=250, hyperparams={'nb_layers': 9, 'kernel_size': 5, 'nb_filters': 16})

# wn_model = BaseEventModule(net = net,
#                            metrics = metrics,
#                            optimizer = wn_conf.optimizer,
#                            head_start = 1)

## Callbacks and Loggers

In [7]:
# Init lightning callbacks
from pytorch_lightning.callbacks import RichModelSummary, RichProgressBar, EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger

# Checkpointing and early stopping both track the validation log-likelihood (mode="max").
monitor_metric = "val/log_likelihood"

model_checkpoint = ModelCheckpoint(
    monitor=monitor_metric,
    mode="max",
    save_top_k=1,
    save_last=True,
    verbose=False,
    dirpath="checkpoints/",
    filename="epoch_{epoch:03d}",
    auto_insert_metric_name=False,
)
early_stopping = EarlyStopping(monitor=monitor_metric, mode="max", patience=100, min_delta=0)
model_summary = RichModelSummary(max_depth=-1)
rich_progress_bar = RichProgressBar()

callbacks = [model_checkpoint, early_stopping, model_summary, rich_progress_bar]

tensorboard = TensorBoardLogger(save_dir="tensorboard", prefix="", default_hp_metric=True, log_graph=False)
logger = [tensorboard]

## Trainer

In [None]:
from pytorch_lightning import Trainer

# `gpus=` is deprecated from PyTorch Lightning 1.7 and removed in 2.0;
# accelerator/devices is the supported way to select GPU 0.
# NOTE(review): `val_check_interval=1000` counts training *batches*, while `max_steps`
# counts optimizer steps (10 batches each with accumulate_grad_batches=10). With an
# iterable dataset, validation only fires once an epoch reaches 1000 batches — confirm
# that epochs are long enough for checkpointing/early stopping to ever trigger.
trainer = Trainer(accelerator="gpu",
                  devices=[0],
                  max_steps=10000,
                  limit_val_batches=100,
                  val_check_interval=1000,
                  accumulate_grad_batches=10,
                  gradient_clip_val=1,
                  callbacks=callbacks,
                  logger=logger)

# Train the model
log.info("Starting training!")
print(f'logger.version = {trainer.logger.version}')
trainer.fit(model, train_dl, valid_dl)

Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")


logger.version = 6


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Output()

In [9]:
# Persist only the learned weights (state_dict); restore with model.load_state_dict(...).
torch.save(obj=model.state_dict(), f="cotic-age.pt")

## Downstream: autoregressive next-event generation

In [10]:
# map_location="cpu" lets the checkpoint load even if it was written from CUDA tensors
# on a machine without that GPU; load_state_dict then copies the values into the
# model's existing parameters on whatever device they live on.
model.load_state_dict(torch.load("cotic-age.pt", map_location="cpu"))

<All keys matched successfully>

In [189]:
# Re-create the validation pipeline with a larger batch and a collate_fn that also
# returns the per-row `event_cnt_col - 1` values (return_len=True).
valid_ds = ParquetDataset(valid_data, post_processing=process)

valid_dl = torch.utils.data.DataLoader(
    dataset=valid_ds,
    collate_fn=pp_collate_fn(*(dataset_conf[key] for key in ('time_col', 'event_col', 'event_cnt_col')),
                             return_len=True),
    num_workers=8,
    batch_size=128,
)

In [190]:
def generate_next(batch, model):
    """Sample one more event per sequence and write it right after each row's last event.

    Args:
        batch: ``((times, types), lens)`` as produced by ``pp_collate_fn(..., return_len=True)``
            or by a previous call of this function. ``times`` / ``types`` are padded
            ``(B, L)`` tensors; ``lens[i]`` is the index of the last event in row ``i``.
        model: callable; ``model((times, types))[1]`` must be
            ``(event_times_pred, event_types_pred)``. ``event_times_pred`` must be
            reshapeable to ``(B, -1)`` and ``event_types_pred`` is ``(B, T, num_types)``
            logits. Position ``lens + 1`` of both is read as the next-event prediction.

    Returns:
        ``((times, types), lens + 1)``: copies of the inputs padded by one extra column,
        so the new event always fits. The inter-event time is drawn as
        ``-pred * log(1 - U)`` (an exponential with mean ``pred``). The type is sampled
        from ``softmax`` of the logits. The input tensors are not modified.
    """
    inputs, lens = batch
    times, types = inputs
    n = times.size(0)
    device = times.device
    idx = torch.arange(n, device=device)
    lens = lens.to(device)

    # NOTE(review): the model is called in its current mode; put it in eval() first
    # if it contains dropout/batch-norm.
    with torch.no_grad():
        event_times_pred, event_types_pred = model(inputs)[1]

    # Drop trailing singleton dims but never the batch dim: a bare .squeeze()
    # collapses a batch of one to 1-D and breaks the [idx, pos] indexing below.
    event_times_pred = event_times_pred.reshape(n, -1)

    # Extra column so position lens + 1 exists even for the longest row.
    pad = torch.nn.functional.pad
    times = pad(times, (0, 1), mode='constant', value=0)
    types = pad(types, (0, 1), mode='constant', value=0)

    next_pos = lens + 1
    uniform = torch.rand(n, device=device)
    next_times = -event_times_pred[idx, next_pos] * torch.log(1 - uniform) + times[idx, lens]
    type_probs = torch.softmax(event_types_pred[idx, next_pos, :], dim=1)
    next_types = torch.multinomial(type_probs, num_samples=1).squeeze(1)

    times[idx, next_pos] = next_times
    types[idx, next_pos] = next_types
    return (times, types), next_pos

In [191]:
# Take the first validation batch and unpack ((times, types), lens) for inspection.
batch_, lens = batch = next(iter(valid_dl))
times, types = batch_

In [192]:
# Padded (batch, max_len) shapes of the sampled validation batch.
tuple(tensor.shape for tensor in (times, types))

(torch.Size([128, 1147]), torch.Size([128, 1147]))

In [224]:
# Autoregressively extend every sequence in the batch by three sampled events.
new_batch = batch
for _step in range(3):
    new_batch = generate_next(new_batch, model)
new_batch_, new_lens = new_batch
new_times, new_types = new_batch_

In [225]:
# Each generate_next call adds one padded column.
tuple(tensor.shape for tensor in (new_times, new_types))

(torch.Size([128, 1150]), torch.Size([128, 1150]))

In [231]:
# Row of the batch to inspect in the cells below.
n = 9

In [232]:
# Last observed event of row n plus the next four slots: before vs. after generation.
window = slice(lens[n], lens[n] + 5)
times[n, window], new_times[n, window]

(tensor([0.7290, 0.0000, 0.0000, 0.0000, 0.0000]),
 tensor([0.7290, 0.7320, 0.7322, 0.7338, 0.0000]))

In [233]:
# Event types in the same window: before vs. after generation.
window = slice(lens[n], lens[n] + 5)
types[n, window], new_types[n, window]

(tensor([2, 0, 0, 0, 0]), tensor([ 2,  7, 11, 13,  0]))