# UMAP embeddings for BALM-unpaired

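The code below extracts the final-layer embeddings from one of our sequence models (as written, `model_path` points to BALM-unpaired), averages them over the sequence dimension, reduces the result with UMAP, and plots the embeddings for heavy and light chains (colored by V-gene family and by V-gene mutation count).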

## setup

In [None]:
from dataclasses import dataclass
import pickle
from tqdm.notebook import tqdm

import numpy as np
import pandas as pd

import torch
from transformers import RobertaTokenizer, RobertaForMaskedLM

import umap
import abutils
import seaborn as sns
import matplotlib as mpl
from scatterplot import scatter
from natsort import natsorted

## load model

In [None]:
# Directory containing the pretrained model weights and config.
# Replace with the actual path to the model on your system.
model_path = './models/BALM-unpaired/'

In [None]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = RobertaForMaskedLM.from_pretrained(model_path).to(device)

## tokenizer

In [None]:
# Load the tokenizer files from a local directory. The path is relative to the
# working directory, so adjust it if the tokenizer lives elsewhere.
tokenizer = RobertaTokenizer.from_pretrained('../tokenizer/')

## load data

In [None]:
# Path to the input CSV; replace with the actual path to the data.
# The file is a subset of 20k paired sequences from the AIRR-annotated test dataset.
data_path = './test-20kembeddings_annotated.csv'

In [None]:
df = pd.read_csv(data_path)
# Label every row by chain type: the IGH locus is heavy, anything else is light.
df['chain'] = np.where(df['locus'] == 'IGH', 'heavy', 'light')

In [None]:
# Pull the per-row columns out as plain lists; all three share df's row order.
seqs = df['sequence_aa'].tolist()
seq_names = df['pair_id'].tolist()
chains = df['chain'].tolist()

## tokenize data

In [None]:
# Tokenize each sequence on its own (batch of one) and place the tensors on the
# same device as the model, so inputs and weights can never end up on
# different devices.
input_device = model.device
tokenized_data = [
    tokenizer(s, return_tensors='pt').to(input_device) for s in tqdm(seqs)
]

## inference & extract embeddings

In [None]:
@dataclass
class ModelOutput:
    '''
    Per-sequence result of running the model.

    Attributes
    ----------
    name : str
        Identifier of the sequence (filled from the ``pair_id`` column in
        this notebook).
    chain : str
        Chain label of the sequence, ``'heavy'`` or ``'light'``.
    mean_final_layer_embedding : np.ndarray
        Final-layer hidden states averaged over the sequence dimension.
    '''
    name: str
    chain: str
    mean_final_layer_embedding: np.ndarray

In [None]:
inputs = list(zip(seq_names, seqs, chains, tokenized_data))

outputs = []
with torch.no_grad():
    # the raw sequence string is carried in `inputs` but not needed here
    for name, _seq, chain, encoded in tqdm(inputs):
        model_out = model(
            **encoded,
            output_hidden_states=True,
            return_dict=False,
        )

        # model_out is a tuple of (logits, hidden_states); pick the last
        # layer's hidden states for the single sequence in the batch
        last_layer = model_out[1][-1][0]
        token_embeddings = np.array(last_layer.cpu())

        # collapse the sequence dimension by averaging, then store the result
        mean_embedding = token_embeddings.mean(axis=0)
        outputs.append(ModelOutput(name, chain, mean_embedding))

In [None]:
# save the per-sequence embeddings to disk
with open('./unpaired_outputs_20k.pkl', 'wb') as pkl_file:
    pickle.dump(outputs, pkl_file)

## process outputs

In [None]:
# heavy chains only: map sequence name -> mean embedding, then build a table
# with one row per sequence and one column per embedding dimension
unpaired_hdata = {
    out.name: out.mean_final_layer_embedding
    for out in outputs
    if out.chain == 'heavy'
}
unpaired_hdf = pd.DataFrame(unpaired_hdata).transpose()

In [None]:
# light chains only: map sequence name -> mean embedding, then build a table
# with one row per sequence and one column per embedding dimension
unpaired_ldata = {
    out.name: out.mean_final_layer_embedding
    for out in outputs
    if out.chain == 'light'
}
unpaired_ldf = pd.DataFrame(unpaired_ldata).transpose()

Reformat heavy and light chain annotations for coloring plots:

In [None]:
# Re-read the annotated CSV with abutils so each record carries its annotations.
# NOTE: this rebinds `seqs`, which earlier held the list of amino-acid strings;
# re-running the inference cells after this one would pick up the new value.
seqs = abutils.io.read_csv(data_path)
# Group the records into pairs using the 'pair_id' field.
pairs = abutils.pair.assign_pairs(seqs, id_key='pair_id')
# Index the pairs by name for lookup (presumably the name equals the pair_id,
# since the lookups below use the embedding tables' index -- TODO confirm).
pdict = {p.name: p for p in pairs}

In [None]:
# fetch the annotated chain for every row of the embedding tables, in row
# order, so the annotations line up with the UMAP coordinates
heavies = [pdict[pair_id].heavy for pair_id in unpaired_hdf.index.values]
lights = [pdict[pair_id].light for pair_id in unpaired_ldf.index.values]

## UMAP

In [None]:
# Fix the seed so the 2D layout (and therefore the figures) is reproducible
# from run to run; UMAP is stochastic when no random_state is given.
unpaired_hreducer = umap.UMAP(random_state=42)
unpaired_hembedding = unpaired_hreducer.fit_transform(unpaired_hdf)

In [None]:
# Fix the seed so the 2D layout (and therefore the figures) is reproducible
# from run to run; UMAP is stochastic when no random_state is given.
unpaired_lreducer = umap.UMAP(random_state=42)
unpaired_lembedding = unpaired_lreducer.fit_transform(unpaired_ldf)

## plot - VH gene

In [None]:
# V-gene family = everything before the first '-' in the gene name
vh_fams = [heavy['v_gene'].partition('-')[0] for heavy in heavies]
# fixed family order IGHV1 .. IGHV7 so colors are assigned consistently
vh_order = ['IGHV' + str(n) for n in range(1, 8)]

scatter(
    x=unpaired_hembedding[:, 0], y=unpaired_hembedding[:, 1],
    hue=vh_fams, hue_order=vh_order,
    size=15, alpha=0.1,
    hide_legend=True,
    xlabel='UMAP1', ylabel='UMAP2',
    xlabel_fontsize=14, ylabel_fontsize=14,
    figsize=[5, 5], equal_axes=False,
    #figfile='./figures/umap_unpaired-model_heavy-chains_vgene-colored_scatterplot.pdf'
)

## plot - VH mutations

In [None]:
def get_grey_zero_cmap(cmap):
    '''
    Build a colormap whose lowest value is light grey.

    The base colormap is sampled at 255 evenly spaced points over its
    0.1-1.0 range (dropping the palest 10%), and a light grey
    ``(0.8, 0.8, 0.8, 1.0)`` is prepended as the first color.

    Parameters
    ----------
    cmap : str, matplotlib.colors.Colormap or None
        Name of a registered colormap, a colormap instance, or ``None`` to
        use the default from ``rcParams['image.cmap']``.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        The new colormap, named ``"mycmap"``.

    Raises
    ------
    ValueError
        If ``cmap`` is a name that is not a registered colormap.
    '''
    if isinstance(cmap, mpl.colors.Colormap):
        base_cmap = cmap
    else:
        cmap_name = mpl.rcParams['image.cmap'] if cmap is None else cmap
        # `mpl.cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in
        # 3.9; the `mpl.colormaps` registry (3.5+) is its replacement.
        registry = getattr(mpl, 'colormaps', None)
        if registry is None:
            # older Matplotlib without the registry
            base_cmap = mpl.cm.get_cmap(cmap_name)
        else:
            try:
                base_cmap = registry[cmap_name]
            except KeyError as err:
                # keep the exception type that get_cmap used to raise
                raise ValueError(
                    f'{cmap_name!r} is not a known colormap name'
                ) from err
    # calling a colormap only reads from it, so no copy is needed
    cropped_cmap = base_cmap(np.linspace(0.1, 1, 255))
    cmap_colors = [np.array([0.8, 0.8, 0.8, 1.0])] + list(cropped_cmap)
    return mpl.colors.LinearSegmentedColormap.from_list("mycmap", cmap_colors)

In [None]:
# cap the V-gene mutation count at 50 before using it as the color value
# (presumably so a few highly mutated sequences do not stretch the color scale)
vh_muts = [float(min(50, h['v_mutation_count'])) for h in heavies]
# 'Reds' palette with light grey for the lowest value
# ('YlOrRd' is a usable alternative palette)
mut_cmap = get_grey_zero_cmap('Reds')

scatter(
    x=unpaired_hembedding[:, 0],
    y=unpaired_hembedding[:, 1],
    hue=vh_muts,
    cmap=mut_cmap,
    size=15,
    alpha=0.1,
    hide_legend=True,
    hide_cbar=True,
    xlabel='UMAP1',
    ylabel='UMAP2',
    xlabel_fontsize=14,
    ylabel_fontsize=14,
    figsize=[5, 5],
    equal_axes=False,
    #figfile='./figures/umap_unpaired-model_heavy-chains_mutation-colored_scatterplot.pdf'
)

## plot - VL gene

In [None]:
# V-gene family = text before the first '-', with any trailing 'D' removed
vl_fams = [light['v_gene'].partition('-')[0].rstrip('D') for light in lights]
# natural sort of the families actually present in the data
vl_order = natsorted(set(vl_fams))

scatter(
    x=unpaired_lembedding[:, 0], y=unpaired_lembedding[:, 1],
    hue=vl_fams, hue_order=vl_order,
    # one evenly spaced HLS color per family
    color=sns.hls_palette(len(vl_order)),
    size=15, alpha=0.1,
    hide_legend=True,
    xlabel='UMAP1', ylabel='UMAP2',
    xlabel_fontsize=14, ylabel_fontsize=14,
    figsize=[5, 5], equal_axes=False,
    #figfile='./figures/umap_unpaired-model_light-chains_vgene-colored_scatterplot.pdf'
)

## plot - VL mutations

In [None]:
# cap the V-gene mutation count at 50 before using it as the color value
# (presumably so a few highly mutated sequences do not stretch the color scale)
vl_muts = [float(min(50, l['v_mutation_count'])) for l in lights]
# 'Reds' palette with light grey for the lowest value
# ('YlOrRd' is a usable alternative palette)
mut_cmap = get_grey_zero_cmap('Reds')

scatter(
    x=unpaired_lembedding[:, 0],
    y=unpaired_lembedding[:, 1],
    hue=vl_muts,
    cmap=mut_cmap,
    size=15,
    alpha=0.2,
    hide_legend=True,
    hide_cbar=True,
    xlabel='UMAP1',
    ylabel='UMAP2',
    xlabel_fontsize=14,
    ylabel_fontsize=14,
    figsize=[5, 5],
    equal_axes=False,
    #figfile='./figures/umap_unpaired-model_light-chains_mutation-colored_scatterplot.pdf'
)