In [5]:
import html

import matplotlib
import numpy as np
from captum.attr import IntegratedGradients
from IPython.display import HTML

from models import MODEL_DIR, load_model
from data_loader import ReviewsDataset, EssaysDataset

def _create_color_map(color_coords, color_bounds):
    """Build a piecewise-linear matplotlib colormap from RGB anchor colors.

    Adapted from https://databasecamp.de/en/ml/integrated-gradients-nlp
    (reference for creating the HTML display).

    Args:
        color_coords: sequence of (r, g, b) tuples, channels in 0-255.
        color_bounds: one anchor value per color, in non-decreasing order. They are
            rescaled linearly so min(color_bounds) -> 0.0 and max(color_bounds) -> 1.0.

    Returns:
        A ``matplotlib.colors.LinearSegmentedColormap`` named 'cmap'.
    """
    lo, hi = min(color_bounds), max(color_bounds)
    # NOTE(review): if lo == hi (all scores equal) these positions no longer span
    # [0, 1] and matplotlib will reject the segment data.
    positions = [np.interp(b, xp=[lo, hi], fp=[0, 1]) for b in color_bounds]

    def channel(idx):
        # Segment entries are (x, y_left, y_right); left == right means no jumps.
        return tuple(
            (pos,
             np.interp(rgb[idx], xp=[0, 255], fp=[0, 1]),
             np.interp(rgb[idx], xp=[0, 255], fp=[0, 1]))
            for rgb, pos in zip(color_coords, positions)
        )

    segmentdata = {'red': channel(0), 'green': channel(1), 'blue': channel(2)}
    return matplotlib.colors.LinearSegmentedColormap('cmap', segmentdata=segmentdata)


def _build_html(text, c_map, norm, encoding, scores):
    """Render `text` as HTML with every token highlighted by its attribution score.

    Args:
        text: the original, untokenized string.
        c_map: colormap applied to normalized scores.
        norm: callable mapping a raw score into the colormap's [0, 1] range.
        encoding: tokenizer output supporting ``len()`` and exposing ``tokens`` and
            ``offsets`` (``(start, end)`` character spans into `text`).
        scores: per-token attribution scores, indexed in step with `encoding`.

    Returns:
        An ``IPython.display.HTML`` object. Text after the last token's end offset
        is not rendered.
    """
    def highlight(token, score):
        color = matplotlib.colors.rgb2hex(c_map(norm(score)))
        # Escape so tokens such as "<" or "&" cannot break or inject markup.
        return f"<mark style=\"margin: 0; padding: 0; background-color:{color}\">{html.escape(token)}</mark>"

    pieces = []
    cursor = 0  # furthest character offset already emitted
    for i in range(len(encoding)):
        start, end = encoding.offsets[i]
        # Copy the untokenized gap (e.g. whitespace) preceding this token. Using the
        # furthest end seen, rather than the previous token's end, prevents offsets
        # that jump backwards (such as (0, 0), which HuggingFace tokenizers use for
        # special tokens) from re-emitting earlier text.
        pieces.append(html.escape(text[cursor:start]))
        pieces.append(highlight(encoding.tokens[i], scores[i]))
        cursor = max(cursor, end)
    return HTML("".join(pieces))


def show_IG(model, dataset, ex_id=None):
    """Show Integrated Gradients token attributions for one dataset example.

    Prints the model's predicted class and the dataset label, then returns the
    example's text as HTML. Each token is shaded red (most negative attribution),
    through white (median), to green (most positive).

    Args:
        model: module with ``eval()``, a forward pass on ``dataset[i][0]`` whose
            ``argmax()`` is the prediction, plus ``get_embeddings`` and
            ``forward_emb`` (the function Integrated Gradients differentiates).
        dataset: supports ``len()``, ``dataset[i]``, ``dataset.__getitem__(i, raw=True)``
            returning ``(encoding, label)``, and ``dataset.get_text(i)``.
        ex_id: index of the example to explain. ``None`` picks one at random.

    Returns:
        ``IPython.display.HTML`` with the highlighted text.
    """
    model.eval()
    # `is None` rather than `not ex_id`: index 0 is a valid example.
    if ex_id is None:
        ex_id = np.random.randint(len(dataset))

    ig = IntegratedGradients(model.forward_emb)
    inputs = dataset[ex_id][0]
    pred = model(inputs).argmax()
    attributions = ig.attribute(inputs=model.get_embeddings(inputs), baselines=None, target=pred)
    # Average over axis 2 to get one score per token.
    # Presumably attributions are (batch=1, seq_len, emb_dim); TODO confirm.
    scores = np.mean(attributions.detach().numpy(), axis=2).squeeze()

    encoding, label = dataset.__getitem__(ex_id, raw=True)

    # Anchor colors: RED at the minimum, WHITE at the median, GREEN at the maximum.
    color_mapping = [
        ((206, 35, 35), scores.min()),
        ((255, 255, 255), np.median(scores)),
        ((22, 206, 16), scores.max()),
    ]
    c_map = _create_color_map([rgb for rgb, _ in color_mapping],
                              [bound for _, bound in color_mapping])
    norm = matplotlib.colors.Normalize(vmin=scores.min(), vmax=scores.max())

    print(f"{'Predicted Rating:' : <18}", pred.item())
    print(f"{'Actual Rating:' : <18}", label)

    return _build_html(dataset.get_text(ex_id), c_map, norm, encoding, scores)

In [6]:
# Integrated Gradients view of example index 202 under the 'reviews_trans_cat' model.
# Other model/dataset combinations examined with the same helper:
#   show_IG(load_model('reviews_dan_cat'), ReviewsDataset(score_type='categorical'), ex_id=22)
#   show_IG(load_model('essays_dan_bin'), EssaysDataset(score_type='binary'), ex_id=22)
show_IG(
    model=load_model('reviews_trans_cat'),
    dataset=ReviewsDataset(score_type='categorical'),
    ex_id=202,
)

Predicted Rating:  4
Actual Rating:     3
