In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

## Load data

Annotated sequences: keep only the sequence ID and the heavy- and light-chain CDR masks.

In [None]:
# Annotation table reduced to the join key plus the two CDR mask columns.
test = pd.read_csv('./data/TTE-ds/annotated/paired-1k-annotated.csv').loc[
    :, ['sequence_id', 'cdr_mask_heavy', 'cdr_mask_light']
]

Inference results: per-position loss for each model on the paired 1k set.

In [None]:
# Model names: {paired, unpaired} x {absolute, rotary}, paired models first.
models = [
    f'{pairing}-{encoding}'
    for pairing in ('paired', 'unpaired')
    for encoding in ('absolute', 'rotary')
]

In [None]:
# One frame per model: per-position losses joined (inner merge) with the CDR masks.
results = [
    pd.read_parquet(f'./results/per-position/{model}_paired1k-perpos-loss.parquet')
      .merge(test, on='sequence_id')
    for model in models
]

In [None]:
# One frame per model is guaranteed by construction; the informative number is
# how many rows per model survived the (default inner) merge on sequence_id.
assert len(results) == len(models)
{model_name: len(frame) for model_name, frame in zip(models, results)}

## Processing

In [None]:
# separate heavy and light chain losses
def extract(df):
    """Split each row's per-position loss into heavy- and light-chain parts.

    Assumed layout of ``loss`` (TODO confirm against the tokenizer): heavy
    residues, then the separator tokens, then light residues. The heavy part
    is the first ``len(heavy)`` positions. The light part starts after the
    separator, whose token count is taken as the number of ``'<'`` characters
    in ``sep``.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``heavy`` (str), ``sep`` (str) and ``loss`` (sliceable sequence
        of per-position losses) columns.

    Returns
    -------
    pandas.DataFrame
        A new frame equal to ``df`` plus ``heavy_loss`` and ``light_loss``
        columns; ``df`` itself is not modified.
    """
    h_loss = []
    l_loss = []

    for _, r in df.iterrows():
        hlen = len(r['heavy'])
        seplen = r['sep'].count('<')
        h_loss.append(r['loss'][:hlen])
        # Skip the separator tokens so light-chain losses line up with cdr_mask_light.
        l_loss.append(r['loss'][hlen + seplen:])

    return df.assign(heavy_loss=h_loss, light_loss=l_loss)

In [None]:
# Add heavy_loss / light_loss columns to every model's frame.
results = list(map(extract, results))

In [None]:
# extract loss by region
def region_processing(df, ppl_data, model):
    """Append one loss record per (row, chain, region) to ``ppl_data``.

    Each chain's ``cdr_mask_<chain>`` string is cut into maximal runs of one
    repeated character. There must be exactly seven runs (asserted); they are
    labelled fwr1, cdr1, ..., fwr4 in order. The matching slice of
    ``<chain>_loss`` is stored with its mean and median.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``heavy_loss``, ``light_loss``, ``cdr_mask_heavy`` and
        ``cdr_mask_light`` columns.
    ppl_data : list
        Accumulator; records are appended to it in place.
    model : str
        Copied into every record's ``model`` field.

    Returns
    -------
    list
        ``ppl_data`` itself (the same object), after appending.
    """
    region_names = ['fwr1', 'cdr1', 'fwr2', 'cdr2', 'fwr3', 'cdr3', 'fwr4']

    for _, row in df.iterrows():
        # Heavy chain first, then light, for every row.
        for chain in ('heavy', 'light'):
            chain_loss = row[f'{chain}_loss']
            mask = row[f'cdr_mask_{chain}']

            # Start index of every run of identical mask characters.
            run_starts = [0]
            current = mask[0]
            for pos, symbol in enumerate(mask):
                if symbol != current:
                    run_starts.append(pos)
                    current = symbol
            run_ends = run_starts[1:] + [len(mask)]
            spans = list(zip(run_starts, run_ends))

            assert len(spans) == len(region_names)

            for name, (lo, hi) in zip(region_names, spans):
                segment = chain_loss[lo:hi]
                ppl_data.append({
                    'region': name,
                    'model': model,
                    'chain': chain,
                    'loss': segment,
                    'mean_loss': np.mean(segment),
                    'median_loss': np.median(segment),
                })

    return ppl_data

In [None]:
# Accumulate region-level loss records across all models into one list.
ppl_data = []
for dataset, model in zip(results, models):
    ppl_data = region_processing(df=dataset, ppl_data=ppl_data, model=model)

In [None]:
# One row per (sequence, chain, region, model) record.
ppl_df = pd.DataFrame.from_records(ppl_data)
ppl_df

## Plot

In [None]:
# Display names for the seven regions of each chain, in sequence order (used as x tick labels).
H_REGIONS = 'FRH1 CDRH1 FRH2 CDRH2 FRH3 CDRH3 FRH4'.split()
L_REGIONS = 'FRL1 CDRL1 FRL2 CDRL2 FRL3 CDRL3 FRL4'.split()

In [None]:
# Only the unpaired models are plotted below.
unpaired_models = [f'unpaired-{encoding}' for encoding in ('absolute', 'rotary')]

In [None]:
# Entries 1 and 0 (in that order) of an 8-colour HLS palette, one per plotted model.
color_palette = sns.color_palette("hls", 8)
unpaired_colors = [color_palette[idx] for idx in (1, 0)]

In [None]:
# Boxen plots of per-sequence median loss by region for the unpaired models,
# heavy chain on top and light chain below.
# Explicit `order`/`hue_order` tie the x positions to the tick labels set below
# and the palette to `unpaired_models`, instead of relying on row order in ppl_df.
region_order = ['fwr1', 'cdr1', 'fwr2', 'cdr2', 'fwr3', 'cdr3', 'fwr4']

fig, ax = plt.subplots(2, 1, figsize=(8, 6))

for i, chain in enumerate(["heavy", "light"]):
    subset = ppl_df[(ppl_df['chain'] == chain) & (ppl_df['model'].isin(unpaired_models))]

    # boxplot
    sns.boxenplot(
        data=subset,
        x='region',
        y='median_loss',
        order=region_order,
        hue='model',
        hue_order=unpaired_models,
        palette=unpaired_colors,
        dodge=True,
        showfliers=False,
        k_depth='proportion',
        outlier_prop=0.1,
        width=0.7,
        saturation=1,
        ax=ax[i],
    )

    # ticks
    ax[i].tick_params(axis='x', labelsize=11)
    ax[i].set_xticks(range(len(region_order)))
    ax[i].set_xticklabels(L_REGIONS if chain == 'light' else H_REGIONS)

    # labels
    ax[i].set_xlabel('', fontsize=0)
    ax[i].set_ylabel(f'{chain.title()} Chain \n Per-position CE Loss', fontsize=12)

    # Keep a single legend (top panel) so the two model colours stay identifiable.
    if i > 0:
        ax[i].get_legend().remove()

fig.savefig("./results/unpaired-model_paired-loss.png", bbox_inches='tight', dpi=300)
plt.show()