## Model Interpretation of NeoPrecis-Immuno

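This notebook interprets the trained NeoPrecis-Immuno model. It covers the residue and positional embeddings, the per-allele, per-position affine transformation used for motif enrichment, and the resulting scaling factors and allele benefit scores.

The results are shown in Figure 3.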

In [None]:
from imm_utils import *

In [None]:
### path

# Reference data and trained checkpoint, both under the same CRD directory
crd_dir = '../src/CRD'
ref_file = f'{crd_dir}/ref.h5'
ckpt_file = f'{crd_dir}/PeptCRD_checkpoint.ckpt'

In [None]:
### model interpretation object

# Interpretation helper built from the reference file and the model checkpoint
MI = ModelInterpretation(ref_file, ckpt_file)

# Preparation steps whose results later cells read back from MI:
#   BuildAnchorMapping            -> anchor_pos_series, anchor_residue_dict
#   MotifEnrichmentInterpretation -> enrich_embs, enrich_emb_affine_matrices, enrich_emb_max_changed_axes
#   AlleleSummarization           -> allele_pos_annot_df, allele_annot_df
for _prep_step in (MI.BuildAnchorMapping, MI.MotifEnrichmentInterpretation):
    _prep_step()

# kept as the cell's last expression so its displayed output is unchanged
MI.AlleleSummarization()

### Embedding

#### Residue embedding
- BLOSUM62
- Sub. embedding
- Sub. + motif enrichment

In [None]:
### embedding df

# Example: allele B*40:01 at peptide position 2 (1-based)
allele = 'B*40:01'
pos = 2
motif_idx = MI.ref_allele_list.index(allele)
pos_idx = pos - 1  # 1-based position -> 0-based array index

# Only the first 20 entries of the reference residue list are plotted
aa20 = MI.ref_aa_list[:20]
emb_cols = ['emb1', 'emb2']

# BLOSUM62 baseline: the reference 2-column encoding (ref_aa_pc2_encode)
blosum_emb = MI.ref_aa_pc2_encode[:20, :]
blosum_emb_df = pd.DataFrame(blosum_emb, index=aa20, columns=emb_cols)

# NP-Immuno residue embedding, used as provided by MI
sub_emb_df = MI.aa_emb_df

# NP-Immuno residue embedding after motif enrichment for (allele, pos)
enrich_emb = MI.enrich_embs[motif_idx, pos_idx, :20, :]
enrich_emb_df = pd.DataFrame(enrich_emb, index=aa20, columns=emb_cols)

# Affine transformation stored for this (allele, pos), and its interpretation
affine_matrix = MI.enrich_emb_affine_matrices[motif_idx, pos_idx]
affine_intp = MI._affine_interpretation(affine_matrix)
print(affine_intp)

In [None]:
### embedding plot

# One panel per embedding: BLOSUM62 baseline, NP-Immuno residue embedding,
# and the residue embedding after motif enrichment.
emb_panels = [
    (blosum_emb_df, 'BLOSUM62'),
    (sub_emb_df, 'NP-Immuno - residue'),
    (enrich_emb_df, 'NP-Immuno - residue + motif'),
]

fig, ax = plt.subplots(1, 3, figsize=(10, 3), dpi=dpi)
for panel_ax, (panel_df, panel_title) in zip(ax, emb_panels):
    EmbeddingPlot(panel_df, ax=panel_ax)
    panel_ax.set_title(panel_title)

# keep the x-label only on the middle panel and the y-label only on the left one
for panel_ax in (ax[0], ax[2]):
    panel_ax.set_xlabel('')
for panel_ax in (ax[1], ax[2]):
    panel_ax.set_ylabel('')

# a single shared legend, moved outside the right-most panel
for panel_ax in (ax[0], ax[1]):
    panel_ax.get_legend().remove()
sns.move_legend(ax[2], loc='lower left', bbox_to_anchor=(1, 0.1))

fig.tight_layout()

#### Positional embedding

In [None]:
# position factors
# One factor per peptide position (1-based) for MHC-I and MHC-II. The positions
# are derived from the factor length instead of a hard-coded 9. Pandas raises
# on assignment if the MHC-II factors have a different length.
positions = list(range(1, len(MI.mhci_pos_facs) + 1))

plot_df = pd.DataFrame(index=positions)
plot_df['MHC-I'] = MI.mhci_pos_facs
plot_df['MHC-II'] = MI.mhcii_pos_facs
plot_df = plot_df.reset_index(names='Position')
# long format: one row per (Position, MHC) pair for seaborn's hue
plot_df = plot_df.melt(id_vars=['Position'], var_name='MHC', value_name='Factor')

# plot
fig, ax = plt.subplots(1, 1, figsize=(4,3), dpi=dpi)
sns.lineplot(data=plot_df, x='Position', y='Factor', hue='MHC', ax=ax)
# legend sits above the axes, in a single row
ax.legend(loc='lower left', bbox_to_anchor=(0, 1), ncol=2)
_ = ax.set_xticks(positions)
fig.tight_layout()

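#### Affine transformation
- The motif-enriched embedding for each (allele, position) is an affine transformation of the residue (sub.) embedding.
- Below, the enriched embedding is reconstructed from the residue embedding as $X' = X A^\top + t$. Here $[A \mid t]$ is the stored affine matrix for that allele and position.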

In [None]:
# Example: allele B*07:02 at peptide position 2 (1-based)
allele = 'B*07:02'
pos = 2
motif_idx = MI.ref_allele_list.index(allele)
pos_idx = pos - 1  # 1-based position -> 0-based array index

# residue embedding (as a plain array) and the enriched embedding for (allele, pos)
sub_emb = MI.aa_emb_df.to_numpy()
enrich_emb = MI.enrich_embs[motif_idx, pos_idx, :20, :]

# Split the stored affine matrix into its linear part A (first two columns) and
# translation t (third column). NOTE(review): assumes a 2x3 [A | t] layout, as
# implied by the slicing. Each embedding row x is mapped to A @ x + t, written
# in row form as X @ A.T + t.
affine_matrix = MI.enrich_emb_affine_matrices[motif_idx, pos_idx]
A, t = affine_matrix[:, :2], affine_matrix[:, 2]
recon_emb = sub_emb @ A.T + t

In [None]:
### embedding plot

# embedding df
# recon_emb was computed row-for-row from MI.aa_emb_df.to_numpy(), so it reuses
# that frame's index. Relabelling it with MI.ref_aa_list[:20] would mislabel
# residues whenever the two orders differ.
emb_cols = ['emb1', 'emb2']
sub_emb_df = MI.aa_emb_df
enrich_emb_df = pd.DataFrame(enrich_emb, index=MI.ref_aa_list[:20], columns=emb_cols)
recon_emb_df = pd.DataFrame(recon_emb, index=MI.aa_emb_df.index, columns=emb_cols)

# fig: residue embedding, enriched embedding, and its affine reconstruction
emb_panels = [
    (sub_emb_df, 'Residue'),
    (enrich_emb_df, 'Enriched'),
    (recon_emb_df, 'Reconstructed'),
]

fig, ax = plt.subplots(1, 3, figsize=(10, 3.5), dpi=dpi)
for panel_ax, (panel_df, panel_title) in zip(ax, emb_panels):
    EmbeddingPlot(panel_df, ax=panel_ax)
    panel_ax.set_title(panel_title)

# keep the x-label only on the middle panel and the y-label only on the left one
for panel_ax in (ax[0], ax[2]):
    panel_ax.set_xlabel('')
for panel_ax in (ax[1], ax[2]):
    panel_ax.set_ylabel('')

# a single shared legend, moved outside the right-most panel
for panel_ax in (ax[0], ax[1]):
    panel_ax.get_legend().remove()
sns.move_legend(ax[2], loc='lower left', bbox_to_anchor=(1, 0.1))

fig.suptitle(f'{allele}_{pos}')
fig.tight_layout()

### Motif enrichment

#### Scaling

In [None]:
# plot df
plot_df = MI.allele_pos_annot_df.reset_index()
plot_df['position'] = plot_df['position'].astype(str)

# fig
fig, ax = plt.subplots(1, 1, figsize=(7,3), dpi=dpi)
sns.scatterplot(data=plot_df, x='emb1_scaling', y='emb2_scaling', hue='group', palette=aa_color_map, style='residue', ax=ax)

# legend
# With both hue and style, seaborn puts every entry into one legend with
# sub-title rows: ['group', <hue entries>, 'residue', <style entries>].
# Split at the sub-title labels instead of counting offsets. The previous
# slice handles[2:n+1] skipped the first hue entry.
handles, labels = ax.get_legend_handles_labels()
hue_title_idx = labels.index('group')
style_title_idx = labels.index('residue')

color_handles = handles[hue_title_idx + 1:style_title_idx]
color_labels = labels[hue_title_idx + 1:style_title_idx]
style_handles = handles[style_title_idx + 1:]
style_labels = labels[style_title_idx + 1:]

# Add the color legend
color_legend = ax.legend(
    handles=color_handles,
    labels=color_labels,
    loc='upper left',
    bbox_to_anchor=(1, 1),  # outside the axes, top
    ncol=2
)

# Add the style legend; this replaces the axes' legend, so the color legend
# is re-attached as a separate artist below
style_legend = ax.legend(
    handles=style_handles,
    labels=style_labels,
    loc='upper left',
    bbox_to_anchor=(1, 0.55),  # outside the axes, below the color legend
    ncol=3,
)

ax.add_artist(color_legend)

ax.set_title('Motif enrichment - scaling factors')
fig.tight_layout()

#### Allele benefit score

In [None]:
# plot df
plot_df = MI.allele_annot_df.reset_index()
plot_df = plot_df.dropna()
# alleles whose name starts with 'D' are treated as MHC-II, all others as MHC-I
plot_df['MHC'] = plot_df['allele'].apply(lambda x: 'II' if x.startswith('D') else 'I')

# order: residues sorted by mean benefit score for MHC-I, then MHC-II (descending).
# reindex guarantees both MHC columns exist, so a class with no rows after
# dropna() sorts as 0 instead of raising KeyError in sort_values.
mean_benefit = (
    plot_df.groupby(['MHC', 'residue'])['benefitScore']
    .mean()
    .unstack('MHC')                  # rows: residue, columns: MHC
    .reindex(columns=['I', 'II'])
    .fillna(0)
)
orders = mean_benefit.sort_values(by=['I', 'II'], ascending=False).index.tolist()

# residue color: tick labels colored by the residue's group
colors = [aa_color_map[aa_dict[aa]] for aa in orders]

# plot
fig, ax = plt.subplots(1, 1, figsize=(10, 3), dpi=dpi)
sns.barplot(data=plot_df, x='residue', y='benefitScore', hue='MHC', order=orders, palette='muted', ax=ax)
for label, color in zip(ax.get_xticklabels(), colors):
    label.set_color(color)
sns.move_legend(ax, 'center left', bbox_to_anchor=(1, 0.5))
ax.set_xlabel('')
fig.tight_layout()