In [2]:
import pickle as pkl
import torch
from torch import nn
from easy_tpp.preprocess.event_tokenizer import EventTokenizer
from easy_tpp.config_factory import DataSpecConfig
from models.encoders.gru import GRUTPPEncoder
from models.decoders.rmtpp import RMTPPDecoder, RMTPPLoss

In [3]:
class TPPModel(nn.Module):
    """Temporal point process model: a GRU encoder feeding an RMTPP decoder.

    The module also owns the RMTPP loss, so the time, mark and total losses
    can be obtained from a batch together with the decoder output.

    Args:
        config: Data specification. ``num_event_types`` and ``pad_token_id``
            are read from it here; the whole object is handed to the encoder.
        hidden_dim: Hidden size passed to both the encoder and the decoder.
        mlp_dim: ``mlp_dim`` argument of the decoder.
        device: Device passed to the decoder and to the loss.
    """

    def __init__(self, config, hidden_dim, mlp_dim, device):
        super().__init__()
        self.encoder = GRUTPPEncoder(config, hidden_dim=hidden_dim)
        self.decoder = RMTPPDecoder(
            hidden_dim=hidden_dim,
            num_event_types=config.num_event_types,
            mlp_dim=mlp_dim,
            device=device,
        )
        # Padded positions are excluded from the loss through ``ignore_index``.
        self.criterion = RMTPPLoss(
            device=device,
            ignore_index=config.pad_token_id,
        )

    def forward(self, batch):
        """Encode ``batch`` and return the decoder's output for it."""
        return self.decoder(self.encoder(batch))

    def compute_loss(self, batch, decoder_output):
        """Evaluate the criterion on ``decoder_output`` against ``batch``.

        Args:
            batch: Mapping that provides ``"time_delta_seqs"``,
                ``"type_seqs"`` and ``"sequence_length"``.
            decoder_output: Value previously returned by :meth:`forward`.

        Returns:
            The tuple ``(time_loss, mark_loss, total_loss)``.
        """
        target_deltas = batch["time_delta_seqs"]
        target_types = batch["type_seqs"]
        lengths = batch["sequence_length"]
        time_loss, mark_loss, total_loss = self.criterion(
            decoder_output, target_deltas, target_types, lengths
        )
        return time_loss, mark_loss, total_loss

In [7]:
# Load the dev / test / train splits of the earthquake dataset.
# Each file is opened in a context manager so its handle is closed as soon as
# the pickle has been read, instead of being left open for the kernel's life.
# NOTE(review): pickle.load can execute arbitrary code while unpickling;
# only load files from a trusted source.
with open('data/earthquake/dev.pkl', 'rb') as pkl_file:
    dev_dict = pkl.load(pkl_file)
with open('data/earthquake/test.pkl', 'rb') as pkl_file:
    test_dict = pkl.load(pkl_file)
with open('data/earthquake/train.pkl', 'rb') as pkl_file:
    train_dict = pkl.load(pkl_file)

def prepare_data(raw_data, config):
    """Turn raw event sequences into one right-padded batch.

    Args:
        raw_data: Sequences of events. Each event is indexed by the keys
            ``"time_since_start"``, ``"type_event"`` and
            ``"time_since_last_event"``.
        config: Configuration used to build the ``EventTokenizer``.

    Returns:
        The result of ``tokenizer.pad`` over the non-empty sequences
        (requested with ``return_tensors='pt'``), with an extra
        ``'sequence_length'`` entry: a tensor of the unpadded lengths.
    """
    tokenizer = EventTokenizer(config)
    tokenizer.padding_side = 'right'

    # Output key -> field read from every event of every sequence.
    event_fields = (
        ('time_seqs', "time_since_start"),
        ('type_seqs', "type_event"),
        ('time_delta_seqs', "time_since_last_event"),
    )
    input_data = {}
    for key, field in event_fields:
        input_data[key] = [[event[field] for event in seq] for seq in raw_data]

    # Empty sequences are dropped from every field.
    filtered_data = {}
    for key, seqs in input_data.items():
        filtered_data[key] = [values for values in seqs if values]

    # Lengths are recorded before padding.
    lengths = list(map(len, filtered_data['type_seqs']))
    sequence_length = torch.tensor(lengths)

    batch = tokenizer.pad(
        filtered_data,
        return_tensors='pt',
        return_attention_mask=False,
    )
    batch['sequence_length'] = sequence_length
    return batch


# Build the data spec from the training split's metadata. The padding id is
# set to the same value as the number of event types.
# NOTE(review): presumably real event types are 0..dim_process-1, so this id
# never collides with one — confirm against the dataset.
dim_process = train_dict['dim_process']
config = DataSpecConfig.parse_from_yaml_config(
    dict(num_event_types=dim_process, pad_token_id=dim_process)
)


In [8]:
HIDDEN_DIM = 128
MLP_DIM = 64
device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_data = train_dict['train']
# The batch has to live on the same device as the model; otherwise the forward
# pass fails with a device mismatch whenever CUDA is available. This mirrors
# what the test cell does with its batch.
processed_data = {
    key: value.to(device)
    for key, value in prepare_data(train_data, config).items()
}

model = TPPModel(config, hidden_dim=HIDDEN_DIM, mlp_dim=MLP_DIM, device=device).to(device)

# Single forward pass on the freshly constructed model; no optimisation step
# is taken in this cell.
decoder_output = model(processed_data)
time_loss, mark_loss, total_loss = model.compute_loss(processed_data, decoder_output)

print("Time Loss:", time_loss.item())
print("Mark Loss:", mark_loss.item())
print("Total Loss:", total_loss.item())

Time Loss: 0.9263450503349304
Mark Loss: 2.2181396484375
Total Loss: 3.144484758377075


In [9]:
# Evaluate on the test split: switch the model to eval mode, build the batch
# on the model's device, and compute the losses without tracking gradients.
model.eval()

test_data = test_dict['test']
processed_test_data = {
    key: value.to(device)
    for key, value in prepare_data(test_data, config).items()
}

with torch.no_grad():
    test_decoder_output = model(processed_test_data)
    test_time_loss, test_mark_loss, test_total_loss = model.compute_loss(
        processed_test_data,
        test_decoder_output,
    )

print("Test Time Loss:", test_time_loss.item())
print("Test Mark Loss:", test_mark_loss.item())
print("Test Total Loss:", test_total_loss.item())

Test Time Loss: 0.9271682500839233
Test Mark Loss: 2.234062671661377
Test Total Loss: 3.16123104095459
