In [None]:
from model import Model
from inter_model import InterpretationModel
import pickle
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm
import re
import sklearn.metrics

In [None]:
# Restore the tokenizer used to turn raw sequences into integer ids
# (it exposes `texts_to_sequences`, used by the encoding cells below).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('./tokenizer.pickle', mode='rb') as tokenizer_file:
    tokenizer = pickle.load(tokenizer_file)

In [None]:
# Base classifier: build it from the hyper-parameters below, then load the
# Stochastic Weight Averaging weights that the checkpoint keeps under its
# 'callbacks' entry. Everything is mapped onto the CPU.
# NOTE(review): newer PyTorch releases may change the `torch.load` default for
# `weights_only`; confirm this full training checkpoint still loads.
model_path = './model.ckpt'
config = dict(
    ah=2,
    dr=0.1,
    beta=0.59,
    output_dims=[7, 72, 268, 4255],
)

model = Model(config)
swa_state = torch.load(model_path, map_location='cpu')['callbacks']['StochasticWeightAveraging']['average_model_state']
model.load_state_dict(swa_state)
del swa_state  # do not keep a second copy of the weights alive in the kernel
model.eval();

In [None]:
# Toy input for the smoke test below: five 'A' residues.
sequence = 'A' * 5

In [None]:
# Encode the demo string into the fixed-length model input, the same way the
# evaluation loop does:
#   - a leading token id 22;
#   - the tokenized residues, truncated to 1023;
#   - zero padding up to 1024 ids.
# Truncation keeps the padding length non-negative, so the input is always
# exactly 1024 ids long.
# NOTE: this cell rebinds `sequence` from str to tensor. Re-run the previous
# cell before re-running this one.
token_ids = [22] + tokenizer.texts_to_sequences([sequence[:1023]])[0][:1023]
token_ids += [0] * (1024 - len(token_ids))
sequence = torch.tensor([token_ids], dtype=torch.int32)

In [None]:
# Smoke test: run the base model on the encoded demo input. The returned tensor
# is shown as the cell output. No `torch.no_grad()` here, so autograd
# bookkeeping is still recorded.
model(sequence)

In [None]:
# Interpretation variant used for relevance maps. It reuses the base model's
# hyper-parameters. Its checkpoint keeps the weights under a plain 'state_dict'
# key. Note that this rebinds `model`.
model_path = './inter_model.ckpt'
config = dict(ah=2, dr=0.1, beta=0.59, output_dims=[7, 72, 268, 4255])

model = InterpretationModel(config)
inter_state = torch.load(model_path, map_location='cpu')['state_dict']
model.load_state_dict(inter_state)
del inter_state  # drop the extra reference to the loaded tensors
model.eval();

In [None]:
def avg_heads(cam, grad):
    """Fuse an attention map with its gradient and average over heads.

    Both tensors are flattened to ``(-1, rows, cols)``, so every leading dim
    (heads, and the batch if present) is merged. They are then multiplied
    element-wise and negative values are clipped to zero. The mean over the
    merged leading dim is returned, with shape ``(rows, cols)``.
    """
    flat_cam = cam.reshape(-1, *cam.shape[-2:])
    flat_grad = grad.reshape(-1, *grad.shape[-2:])
    weighted = flat_grad * flat_cam
    return torch.clamp(weighted, min=0).mean(dim=0)

def apply_self_attention_rules(R_ss, cam_ss):
    """Propagate relevance through one self-attention layer.

    Returns the matrix product ``cam_ss @ R_ss``. The caller adds it to the
    running relevance matrix.
    """
    return cam_ss @ R_ss

def generate_relevance(model, sequence, index=None):
    """Gradient-weighted attention rollout for a single input sequence.

    The string is encoded the same way the evaluation loop encodes it:
      - truncated to 1023 characters;
      - prefixed with token id 22;
      - zero-padded to 1024 ids.
    After a forward pass, the selected output column is back-propagated. For
    each encoder block (``model.model.enc_1`` .. ``enc_4``), the attention map
    and its gradient are fused with ``avg_heads``. The result is accumulated
    into the rollout matrix ``R`` via ``R += apply_self_attention_rules(R, cam)``.

    Args:
        model: interpretation model whose encoder blocks expose
            ``attention.get_attn()`` and ``attention.get_attn_gradients()``.
            These are presumably filled by hooks during the forward/backward
            pass below -- TODO confirm.
        sequence: raw input string, encoded with the global ``tokenizer``.
        index: output column to explain. ``None`` picks the arg-max of the
            model output.

    Returns:
        Row 0 of ``R`` without its first entry: a 1-D tensor of 1023 scores,
        one per input position 1..1023. It may still carry autograd history,
        which is why callers ``.detach()`` it.
    """
    max_len = 1024  # fixed model input length: leading token + up to 1023 residues

    # Truncate before padding. An over-long input would otherwise yield more than
    # `max_len` ids and break the (max_len x max_len) rollout below.
    token_ids = [22] + tokenizer.texts_to_sequences([sequence[:max_len - 1]])[0][:max_len - 1]
    token_ids += [0] * (max_len - len(token_ids))
    input_ids = torch.tensor([token_ids], dtype=torch.int32)

    output = model(input_ids)
    # `is None` rather than `== None`: with an array index, `==` is elementwise
    # and `if` on the result would raise.
    if index is None:
        index = np.argmax(output.cpu().data.numpy(), axis=-1)

    # Back-propagate only the selected output score.
    one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
    one_hot[0, index] = 1
    one_hot = torch.from_numpy(one_hot).requires_grad_(True)
    target_score = torch.sum(one_hot * output)
    model.zero_grad()
    target_score.backward(retain_graph=True)

    R = torch.eye(max_len, max_len)
    for blk in (model.model.enc_1, model.model.enc_2, model.model.enc_3, model.model.enc_4):
        grad = blk.attention.get_attn_gradients()
        cam = avg_heads(blk.attention.get_attn(), grad)
        R += apply_self_attention_rules(R, cam)
    return R[0, 1:]

In [None]:
# Toy input whose per-position relevance is computed below: three 'A' residues.
sequence = 'A' * 3

In [None]:
# Relevance for the arg-max output column (the default `index`), detached from
# the autograd graph so it can be post-processed with numpy.
exp = generate_relevance(model=model, sequence=sequence).detach()

In [None]:
# Smooth the relevance scores with a moving average of width `kernel_size`.
# mode='same' keeps the length. Then min-max scale the result into [0, 1].
# NOTE: this cell rebinds `exp`. Re-running it without re-running the previous
# cell would smooth the already-smoothed scores again.
kernel_size = 6
kernel = np.ones(kernel_size) / kernel_size
exp = np.convolve(exp, kernel, mode='same')

exp = exp - exp.min()
exp_span = exp.max()
# A constant relevance vector would otherwise give 0 / 0 and an all-NaN result.
exp = exp / exp_span if exp_span > 0 else np.zeros_like(exp)

In [None]:
# Display the smoothed, min-max-scaled relevance scores, one per position after
# the leading token.
exp

## Evaluate the base model on the New-392 test set (micro-averaged F1)

In [None]:
# `model` was rebound to the InterpretationModel above, so restore the base
# classifier with its Stochastic Weight Averaging weights (CPU) for evaluation.
# NOTE(review): newer PyTorch releases may change the `torch.load` default for
# `weights_only`; confirm this full training checkpoint still loads.
model_path = './model.ckpt'
config = dict(
    ah=2,
    dr=0.1,
    beta=0.59,
    output_dims=[7, 72, 268, 4255],
)

model = Model(config)
swa_state = torch.load(model_path, map_location='cpu')['callbacks']['StochasticWeightAveraging']['average_model_state']
model.load_state_dict(swa_state)
del swa_state  # do not keep a second copy of the weights alive in the kernel
model.eval();

In [None]:
# Restore the label encoder. The evaluation below builds each label vector in
# `le.classes_` order, and assumes that order matches the model's EC output
# columns -- TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('./label_encoder.pkl', mode='rb') as encoder_file:
    le = pickle.load(encoder_file)

In [None]:
# Evaluate the base model on New-392. Each EC-class score is thresholded to get
# a multi-label prediction, and predictions are scored with micro-averaged F1
# against the ';'-separated 'EC number' column.
new392 = pd.read_csv('./new.csv', sep='\t')

MAX_LEN = 1024       # model input length: leading token id 22 + up to 1023 residues
THRESHOLD = 0.4      # decision threshold applied to each EC-class output
N_EC_CLASSES = 4255  # width of the last output head (output_dims[-1] above)

# Label vectors follow `le.classes_` order and predictions are the last
# N_EC_CLASSES outputs. Fail early and clearly if the two cannot line up.
assert len(le.classes_) == N_EC_CLASSES, (
    f"label encoder has {len(le.classes_)} classes, expected {N_EC_CLASSES}"
)

labels = []
preds = []

with torch.no_grad():
    for _, row in tqdm(new392.iterrows(), total=len(new392)):
        # Same encoding as elsewhere in the notebook: truncate, prefix 22, zero-pad.
        token_ids = [22] + tokenizer.texts_to_sequences([row['Sequence'][:MAX_LEN - 1]])[0][:MAX_LEN - 1]
        token_ids += [0] * (MAX_LEN - len(token_ids))
        sequence = torch.tensor([token_ids], dtype=torch.int32)

        output = (model(sequence)[:, -N_EC_CLASSES:] > THRESHOLD).int()[0]

        # Strip whitespace so entries like "1.1.1.1; 2.7.1.1" still match
        # the encoder's class names.
        true_ecs = {ec.strip() for ec in row['EC number'].split(';')}

        labels.append([int(ec in true_ecs) for ec in le.classes_])
        preds.append(output.tolist())


print(f"Micro-averaged F1-score: {sklearn.metrics.f1_score(labels, preds, average='micro'):.2f}")