In [1]:
import sys
sys.path.append('../Utils')

import pandas as pd
import numpy as np
from sklearn import pipeline, preprocessing, model_selection, base, compose, metrics
import torch
import torch.nn as nn
import torch.optim as optim
import skorch
import os
import joblib

from utils import ReorderTransformer, DTypeTransformer, LabelingTransformer, RestoreMoveCheckpoint, DimensionTransformer, TransformerModelv4, build_inference_pipe

# Single seed value shared by every random number generator seeded below.
SEED = 42

# Seed PyTorch's and NumPy's global generators with the same value.
# np.random.seed stays last: it returns None, so the cell displays no output.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [2]:
# Directory the notebook reads the dataset files from and writes the
# extracted attention files to.
DATASET_DIR = '../Dataset'
# Path of the serialized pipeline that is loaded with joblib further down.
MODEL = '../Models/transformer_v4/transformer_v4_baseline.jl'

# Load datasets

In [3]:
# Read the train and dev splits from DATASET_DIR, in that order.
train_df, dev_df = (
    pd.read_csv(os.path.join(DATASET_DIR, csv_name))
    for csv_name in ('n_train.csv', 'n_dev.csv')
)

In [4]:
# Stack the train rows followed by the dev rows into one frame.
# ignore_index=True gives the result a fresh 0..n-1 index. Without it each
# frame keeps its own row labels, which both start at 0 and therefore collide,
# and the labels would not line up with the positional train/dev index ranges
# used to slice the per-row outputs.
df = pd.concat((train_df, dev_df), ignore_index=True)

In [5]:
# Positional row ranges of each split inside the concatenated frame:
# train rows come first, dev rows follow immediately after them.
train_indices = np.arange(len(train_df))
dev_indices = len(train_df) + np.arange(len(dev_df))

In [6]:
# Dataset metadata saved alongside the CSV files.
ds_info = joblib.load(os.path.join(DATASET_DIR, 'ds_info.jl'))

# Unpack the entries this notebook uses into individual names.
columns, numerical_cols, categorical_cols, label_col = (
    ds_info[key]
    for key in (
        'columns',
        'numerical_columns',
        'categorical_columns',
        'target_column',
    )
)

# Load model

In [7]:
# Deserialize the saved pipeline from MODEL.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code, so only load model files from a trusted source.
# Presumably the custom classes imported from `utils` must be importable for
# the load to succeed (TODO confirm).
pipe = joblib.load(MODEL)

In [8]:
# Score the loaded pipeline on the dev split. The features are every column
# except the target; the target is passed as a float32 array of shape (n_rows, 1).
pipe.score(
    dev_df.drop(columns=label_col),
    dev_df[label_col].values.astype(np.float32).reshape(-1, 1),
)

0.8602

# Hook registration and attention extraction

In [9]:
# Notebook-level buffer the forward hook appends to: one NumPy array of
# attention weights per hooked forward call.
attn_weights = []

In [10]:
def attention_extraction(self, input, output):
    """Forward hook that records the attention weights of one module call.

    Intended to be passed to ``register_forward_hook``. Every invocation
    appends one NumPy array to the notebook-level ``attn_weights`` list.

    Parameters
    ----------
    self
        The module the hook is attached to (unused).
    input
        The inputs of the forward call (unused).
    output
        The value returned by the forward call; ``output[1]`` is read as the
        attention weights.

    Raises
    ------
    RuntimeError
        If ``output[1]`` is ``None``, meaning the module returned no weights.
    """
    weights = output[1]
    if weights is None:
        # An attention module can legitimately return None in the weights slot
        # (e.g. torch.nn.MultiheadAttention when called with
        # need_weights=False). Fail with an explicit message instead of an
        # AttributeError on None.
        raise RuntimeError(
            'The hooked module returned no attention weights (output[1] is None); '
            'it must be called so that it returns them.'
        )
    # Detach from the autograd graph and move to host memory before converting.
    attn_weights.append(weights.detach().cpu().numpy())

In [11]:
# Re-running this cell must not stack a second set of hooks on the same
# layers, because each duplicate would record every forward pass again.
# Hooks registered by an earlier run of the cell are therefore removed first.
for _stale_handle in globals().get('attn_hook_handles', []):
    _stale_handle.remove()

# Attach the extraction hook to the self-attention module of every encoder
# layer, keeping the returned handles so the hooks can be detached later.
attn_hook_handles = []
for enc_layer in pipe['classifier'].module_.transformer_encoder.layers:
    attn_hook_handles.append(
        enc_layer.self_attn.register_forward_hook(attention_extraction)
    )

In [12]:
# Start from an empty buffer so that only this prediction pass is recorded.
attn_weights = list()

# Run the pipeline over the train + dev rows (target column removed); the
# registered hooks fill `attn_weights` as a side effect of the forward passes.
preds = pipe.predict(df.drop(columns=label_col))

In [13]:
# Collapse the arrays gathered by the hook into one array along the first axis.
# The guard keeps the cell idempotent: once `attn_weights` is already an
# ndarray, stacking it again would treat each of its rows as a separate array
# and silently change its shape.
if isinstance(attn_weights, list):
    attn_weights = np.vstack(attn_weights)

In [14]:
# There must be exactly one block of attention weights per input row.
assert len(attn_weights) == len(df), 'Shapes does not match'

In [15]:
# Persist one (attention weights, labels) tuple per split into DATASET_DIR.
for out_name, row_idx, split_df in (
    ('attn_train.jl', train_indices, train_df),
    ('attn_dev.jl', dev_indices, dev_df),
):
    _ = joblib.dump(
        (attn_weights[row_idx], split_df[label_col]),
        os.path.join(DATASET_DIR, out_name),
    )