In [2]:
%load_ext autoreload
%autoreload 2

import logging
import torch
import pytorch_lightning as pl
import warnings

warnings.filterwarnings('ignore')
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

In [3]:
from glob import glob
import numpy as np
import logging
import pytorch_lightning as pl
import torch
from tqdm.auto import tqdm
import os
import pandas as pd
from pytorch_lightning.loggers import TensorBoardLogger


logger = logging.getLogger(__name__)

In [4]:
import hydra
from omegaconf import OmegaConf

conf = OmegaConf.load('mles_params.yaml')
model = hydra.utils.instantiate(conf.pl_module)
model.load_state_dict(torch.load("models_alpha_battle/coles_pretrain.pth"))

<All keys matched successfully>

In [5]:
from glob import glob
from ptls.data_load.iterable_processing_dataset import IterableProcessingDataset
from ptls.data_load.iterable_processing.target_move import TargetMove
from ptls.data_load.iterable_processing.target_empty_filter import TargetEmptyFilter
from ptls.data_load import padded_collate, padded_collate_wo_target
from ptls.data_load.iterable_processing.to_torch_tensor import ToTorch
from ptls.data_load.datasets import MemoryMapDataset
from tqdm.auto import tqdm

from ptls.data_load import IterableChain
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load.datasets.parquet_dataset import ParquetDataset, ParquetFiles
from ptls.data_load.utils import collate_feature_dict


from ptls.frames import PtlsDataModule

train_data = glob('train_data_not_agg.parquet')
valid_data = glob('valid_data_not_agg.parquet')

feature_cols = ['mcc', 'amnt', 'hour_diff']

target_cols = ['mcc', 'amnt', 'hour_diff']

dataset_conf = {
    'min_seq_len':0,
}



class SeqToTargetMultiheadDataset(torch.utils.data.Dataset):
    def __init__(self,
                 data,
                 feature_cols,
                 target_cols,
                 target_dtype=None,
                 *args, **kwargs,
                 ):
        super().__init__(*args, **kwargs)

        self.data = data
        self.feature_cols = feature_cols
        self.target_cols = target_cols
        
        if type(target_dtype) is str:
            self.target_dtype = getattr(torch, target_dtype)
        else:
            self.target_dtype = target_dtype

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        feature_arrays = self.data[item]
        return feature_arrays

    def __iter__(self):
        for feature_arrays in self.data:
            yield feature_arrays


    def collate_fn(self, batch):
        
        targets = []
        values = []
        
        for target_col in target_cols:
            targets.append(torch.tensor([rec[target_col][-1] for rec in batch]).to(self.target_dtype[target_col]))
        
        for rec in batch:
            values.append({k: v[:-1] for k, v in rec.items() if k in feature_cols})
    
        return padded_collate_wo_target(values), targets

process = IterableChain(
            SeqLenFilter(min_seq_len=dataset_conf['min_seq_len']),
            ToTorch()
            )
    
def get_dataset(data):
    ds = MemoryMapDataset(ParquetDataset(data, post_processing=process))
    return SeqToTargetMultiheadDataset(ds, feature_cols, target_cols, target_dtype = {'mcc': torch.long, 'amnt': torch.float, 'hour_diff': torch.float})

train_ds = get_dataset(train_data)
valid_ds = get_dataset(valid_data)

dm = PtlsDataModule(
    train_data=train_ds,
#     valid_data=valid_ds,
    train_num_workers=4,
    train_batch_size=64)

In [6]:
import logging
from copy import deepcopy
from typing import List

import pandas as pd
import pytorch_lightning as pl
import torch
import torchmetrics
from omegaconf import DictConfig

from ptls.data_load.padded_batch import PaddedBatch

logger = logging.getLogger(__name__)


class SequenceToTargetMultihead(pl.LightningModule):


    def __init__(self,
                 seq_encoder: torch.nn.Module,
                 heads: List[torch.nn.Module],
                 losses: List[torch.nn.Module],
                 metric_list: torchmetrics.Metric=None,
                 optimizer_partial=None,
                 lr_scheduler_partial=None,
                 pretrained_lr=None,
                 train_update_n_steps=None,
                 ):
        super().__init__()

        self.save_hyperparameters(ignore=[
            'seq_encoder', 'heads', 'losses', 'metric_list', 'optimizer_partial', 'lr_scheduler_partial'])

        self.seq_encoder = seq_encoder
        self.heads = heads
        self.losses = losses
        self.n_heads = len(heads)

        self.optimizer_partial = optimizer_partial
        self.lr_scheduler_partial = lr_scheduler_partial

    def forward(self, x):
        x = self.seq_encoder(x)
        xs = [head(x) for head in self.heads]
        return xs

    def training_step(self, batch, _):
        x, y = batch
        y_hs = self(x)
        loss = sum([loss(y_hs[i], y[i]) for i, loss in enumerate(self.losses)])
        self.log('loss', loss)
        return loss

    def validation_step(self, batch, _):
        x, y = batch
        y_hs = self(x)
        loss = sum([loss(y_hs[i], y[i]) for i, loss in enumerate(self.losses)])
        self.log('val_loss', loss)


    def configure_optimizers(self):
        if self.hparams.pretrained_lr is not None:
            if self.hparams.pretrained_lr == 'freeze':
                for p in self.seq_encoder.parameters():
                    p.requires_grad = False
                parameters = self.parameters()
            else:
                parameters = [
                    {'params': self.seq_encoder.parameters(), 'lr': self.hparams.pretrained_lr},
                ] + [{'params': head.parameters()} for head in self.heads]
        else:
            parameters = self.parameters()

        optimizer = self.optimizer_partial(parameters)
        scheduler = self.lr_scheduler_partial(optimizer)
        return [optimizer], [scheduler]

In [7]:
from functools import partial
import torch
import torchmetrics
from ptls.nn import Head


head_mcc = Head(input_size=model.seq_encoder.embedding_size, 
                use_batch_norm=True,
                hidden_layers_sizes=[128],
                objective='classification',
                num_classes=109).to('cuda:0')

head_amnt = Head(input_size=model.seq_encoder.embedding_size, 
                 use_batch_norm=True,
                 hidden_layers_sizes=[128],
                 objective='softplus').to('cuda:0')

head_hour_diff = Head(input_size=model.seq_encoder.embedding_size, 
                      use_batch_norm=True,
                      hidden_layers_sizes=[128],
                      objective='softplus').to('cuda:0')

model_multihead = SequenceToTargetMultihead(
    seq_encoder=model.seq_encoder,
    heads=[head_mcc, head_amnt, head_hour_diff],
    losses=[torch.nn.NLLLoss(), torch.nn.L1Loss(), torch.nn.L1Loss()],
#     metric_list=torchmetrics.Accuracy(compute_on_step=False),
    pretrained_lr=0.00001,
    optimizer_partial=partial(torch.optim.Adam, lr=0.001), # , weight_decay=1e-5
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=1, gamma=0.9),
)

In [None]:
logger = TensorBoardLogger('src/ptls-experiments/scenario_alpha_battle/lightning_logs',
                           name='coles-baseline-alpha-battle')

trainer = pl.Trainer(
    num_sanity_val_steps=0,
    gpus=1,
    auto_select_gpus=False,
    max_epochs=10,
    enable_checkpointing=False,
    deterministic=True,
    logger=logger,
)

trainer.fit(model_multihead, dm)
print(trainer.logged_metrics)

Training: 0it [00:00, ?it/s]

In [10]:
torch.save(model_multihead.state_dict(), "models_alpha_battle/cpc-multi-head.pth")

In [11]:
model_multihead.load_state_dict(torch.load("models_alpha_battle/cpc-multi-head.pth"))

<All keys matched successfully>

In [14]:
# %%time
import tqdm

def inference(model, dl, device='cuda:0'):
    
    model.to(device)
    X = []
    for batch in tqdm.tqdm(dl):
        with torch.no_grad():
            features = batch[0]
            targets = [t.unsqueeze(dim=1).to(device) for t in batch[1]]
            x = model(features.to(device))
            mcc = torch.argmax(x[0], dim=1, keepdim=True)
            amnt = x[1].unsqueeze(dim=1)
            hour_diff = x[2].unsqueeze(dim=1)
            predicted = [mcc, amnt, hour_diff]
            X += [torch.cat(predicted + targets, dim=1)]
    return X


valid_dl = torch.utils.data.DataLoader(dataset=valid_ds, 
                                       collate_fn=valid_ds.collate_fn,
                                       num_workers=0,
                                       batch_size=128)

In [15]:
preds = torch.vstack(inference(model_multihead, valid_dl)).cpu().numpy()

100%|██████████| 3358/3358 [00:48<00:00, 68.73it/s]


In [16]:
import numpy as np

df_valid = pd.DataFrame(preds, columns = ['predicted_mcc', 'predicted_amnt', 'predicted_hour_diff', 'mcc', 'amnt', 'hour_diff'])
df_valid.head()

Unnamed: 0,predicted_mcc,predicted_amnt,predicted_hour_diff,mcc,amnt,hour_diff
0,2.0,0.43145,28.729313,9.0,0.352101,0.0
1,1.0,0.309665,23.632116,2.0,0.539584,5.0
2,61.0,0.304868,13.139742,61.0,0.232139,0.0
3,1.0,0.308173,5.65425,1.0,0.352706,96.0
4,5.0,0.350544,4.952467,2.0,0.477476,258.0


In [17]:
from sklearn.metrics import accuracy_score

print("Accuracy:", {accuracy_score(df_valid['mcc'],  df_valid['predicted_mcc'])})

Accuracy: {0.3485738089087782}


In [18]:
from sklearn.metrics import mean_absolute_error

print("Mae amnt:", {mean_absolute_error(df_valid['amnt'],  df_valid['predicted_amnt'])})

Mae amnt: {0.0768877}


In [19]:
from sklearn.metrics import mean_absolute_error

print("MAE:", {mean_absolute_error(df_valid['hour_diff'],  df_valid['predicted_hour_diff'])})

MAE: {41.097492}
