In [25]:
from model import Model
from inter_model import InterpretationModel
import pickle

In [5]:
with open('./tokenizer.pickle', 'rb') as f :
    tokenizer = pickle.load(f)

In [3]:
model_path = f'./model.ckpt'

config = {
    'ah': 2,
    'dr': 0.1,
    'beta': 0.59,
    'output_dims': [7, 72, 268, 4255]
}

model = Model(config)

model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['callbacks']['StochasticWeightAveraging']['average_model_state'])
model.eval();

In [18]:
sequence = 'AAAAA'

In [19]:
sequence = tokenizer.texts_to_sequences([sequence])
sequence[0] = [22] + sequence[0]
sequence[0] += [0 for _ in range(1024-len(sequence[0]))]
sequence = torch.Tensor(sequence).int()

In [21]:
model(sequence)

tensor([[3.8450e-01, 5.8693e-03, 1.6082e-01,  ..., 1.5032e-09, 1.4772e-09,
         3.3178e-10]], grad_fn=<AddBackward0>)

In [35]:
model_path = f'./inter_model.ckpt'

config = {
    'ah': 2,
    'dr': 0.1,
    'beta': 0.59,
    'output_dims': [7, 72, 268, 4255]
}

model = InterpretationModel(config)

model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['state_dict'])
model.eval();

In [36]:
def avg_heads(cam, grad):
    cam = cam.reshape(-1, cam.shape[-2], cam.shape[-1])
    grad = grad.reshape(-1, grad.shape[-2], grad.shape[-1])
    cam = grad * cam
    cam = cam.clamp(min=0).mean(dim=0)
    return cam

# rule 6 from paper
def apply_self_attention_rules(R_ss, cam_ss):
    R_ss_addition = torch.matmul(cam_ss, R_ss)
    return R_ss_addition

def generate_relevance(model, sequence, index=None):
    
    sequence = tokenizer.texts_to_sequences([sequence])
    sequence[0] = [22] + sequence[0]
    sequence[0] += [0 for _ in range(1024-len(sequence[0]))]
    sequence = torch.Tensor(sequence).int()

    output = model(sequence)
    if index == None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)

    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot_vector = one_hot
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    one_hot = torch.sum(one_hot * output)
    model.zero_grad()
    one_hot.backward(retain_graph=True)

    num_tokens = 1024
    R = torch.eye(num_tokens, num_tokens)
    for blk in [model.model.enc_1, model.model.enc_2, model.model.enc_3, model.model.enc_4]:
        grad = blk.attention.get_attn_gradients()
        cam = blk.attention.get_attn()
        cam = avg_heads(cam, grad)
        R += apply_self_attention_rules(R, cam)
    return R[0, 1:]

In [37]:
sequence = 'AAA'

In [38]:
exp = generate_relevance(model, sequence, index=None).detach()

In [40]:
kernel_size = 6
kernel = np.ones(kernel_size) / kernel_size
exp = np.convolve(exp, kernel, mode='same')

exp = exp - exp.min()
exp = exp / exp.max()