# 🧬 Multi-Class Bacterial Species Identification from High-Dimensional Gene Expression Data

This notebook presents an end-to-end machine learning pipeline for classifying **10 bacterial species** from high-dimensional genetic expression data (286 gene features per sample).

### Objectives
- Perform exploratory data analysis (EDA) on bacterial genetic data
- Identify highly correlated gene pairs within and across species
- Apply **Box-Cox transformation** to normalize skewed gene distributions
- Use **Principal Component Analysis (PCA)** to reduce dimensionality and visualize species clusters
- Train and evaluate an **Extra Trees Classifier** under two configurations:
  - Using **100 principal components** (PCA-reduced)
  - Using all **286 original features** (full feature space)
- Compare model performance and generate test set predictions

### Dataset Summary
| Split | Raw Samples | Features | After Deduplication |
|-------|------------|----------|---------------------|
| Train | 200,000 | 287 (286 genes + target) | 123,993 |
| Test  | 100,000 | 286 | — |


## 1. Setup & Data Loading

We begin by importing all necessary libraries and loading the train/test CSV files.

- **Pandas / NumPy** — data wrangling and numerical operations
- **Seaborn / Matplotlib** — static visualization
- **Plotly** — interactive charts
- **SciPy** — statistical transformations (Box-Cox)
- **scikit-learn** — preprocessing, PCA, modeling, and evaluation


In [None]:
# ── Standard library ────────────────────────────────────────────────────────
import os, warnings
warnings.filterwarnings("ignore")

# ── Numerical & data wrangling ───────────────────────────────────────────────
import numpy as np
import pandas as pd

# ── Static visualization ─────────────────────────────────────────────────────
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors as mpl

# ── Interactive visualization ────────────────────────────────────────────────
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from plotly.offline import init_notebook_mode

# ── Statistical transformation ───────────────────────────────────────────────
from scipy.special import boxcox1p
from scipy.stats import boxcox_normmax

# ── Machine learning ─────────────────────────────────────────────────────────
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold
from sklearn.metrics import (accuracy_score, roc_curve, roc_auc_score,
                             auc, classification_report)

# ── Load data ─────────────────────────────────────────────────────────────────
# Update these paths to point to your local data directory
train = pd.read_csv('data/train.csv', index_col=0)
test  = pd.read_csv('data/test.csv',  index_col=0)
sub   = pd.read_csv('data/sample_submission.csv')

# ── Quality check ─────────────────────────────────────────────────────────────
# Report shape, missing values, and duplicate rows for both splits
print('Train Shape: {}  |  Missing Data: {}  |  Duplicates: {}'.format(
      train.shape, train.isna().sum().sum(), train.duplicated().sum()))
print('Test  Shape: {}  |  Missing Data: {}  |  Duplicates: {}'.format(
      test.shape,  test.isna().sum().sum(),  test.duplicated().sum()))

# ── Remove duplicates from training data ─────────────────────────────────────
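# ~38% of training rows are exact duplicates (200,000 → 123,993). Dropping them keeps
# identical rows from landing in both the training and validation splits (which would
# inflate validation scores) and also speeds up training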
train_d = train.drop_duplicates()
print('\nAfter deduplication → Train Shape: {}'.format(train_d.shape))
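
To see why duplicates matter for validation, here is a small self-contained sketch on hypothetical toy data (not the bacterial dataset). It repeats each row three times, makes a random 80/20 split, and counts validation rows whose exact feature values also occur in the training split. `rows_shared` is a helper name introduced only for this illustration.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical toy data: 50 unique rows, each repeated 3 times (150 rows total)
rng = np.random.default_rng(0)
unique_rows = pd.DataFrame(rng.normal(size=(50, 4)), columns=list('ABCD'))
df = pd.concat([unique_rows] * 3, ignore_index=True)


def rows_shared(df, seed=21):
    """Count validation rows whose exact feature values also appear in the training split."""
    tr, va = train_test_split(df, test_size=0.2, random_state=seed)
    # Inner merge on every column finds exact matches; deduplicating train avoids double counting
    shared = va.merge(tr.drop_duplicates(), how='inner', on=list(df.columns))
    return len(shared)


print('Validation rows also in train (with duplicates):   ', rows_shared(df))
print('Validation rows also in train (after deduplication):', rows_shared(df.drop_duplicates()))
```

On the real data, the same check could be run on `train` versus `train_d` before the split in Section 5.1.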


## 2. Exploratory Data Analysis

### 2.1 Summary Statistics by Species

We compute descriptive statistics (mean, std, min, max, quartiles) grouped by the target species label.
This gives us a high-level sense of how gene expression levels differ across bacterial species.


In [None]:
train_d.groupby('target').describe()

### 2.2 Species Class Distribution

Understanding the class balance is important before modeling. Imbalanced classes can bias a classifier
towards the majority class. The bar chart below shows each species as a percentage of the deduplicated
training set.


In [None]:
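import plotly.io as pio
# "colab" renders figures in Google Colab; use e.g. "notebook" or "notebook_connected"
# when running in a local Jupyter environment
pio.renderers.default = "colab"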

# ── Color palette and layout template ────────────────────────────────────────
pal  = sns.color_palette("mako_r", 12).as_hex()[:10]
temp = dict(layout=go.Layout(font=dict(family="Franklin Gothic", size=12)))

# ── Aggregate class proportions ──────────────────────────────────────────────
bact = train_d.target.value_counts(normalize=True).reset_index()
bact['proportion'] = bact['proportion'].mul(100)
bact['target']     = bact['target'].str.replace('_', ' ')
bact = bact.sort_values(by='proportion', ascending=False)

# ── Plot ─────────────────────────────────────────────────────────────────────
fig = px.bar(bact, x='target', y='proportion', text='proportion',
             color='target', color_discrete_sequence=pal, opacity=0.8)
fig.update_traces(texttemplate='%{text:,.2f}%', textposition='outside',
                  marker_line=dict(width=1, color='#28221D'))
fig.update_yaxes(visible=False, showticklabels=False)
fig.update_layout(template=temp,
                  title_text='Distribution of Bacterial Species in Training Data',
                  xaxis=dict(title='', tickangle=25, showline=True),
                  height=450, width=700, showlegend=False)
fig.show()


### 2.3 Inter-Gene Correlation Matrix

We compute pairwise Pearson correlations across all 286 gene features to identify:
- **Redundant features:** Highly correlated gene pairs carry overlapping information
- **Feature clusters:** Groups of co-expressed genes may represent biological pathways

A color-gradient heatmap allows visual inspection of the full correlation structure.
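
One way to turn the correlation matrix into explicit gene clusters is hierarchical clustering on the distance `1 - |r|`. The sketch below assumes average linkage and a fixed number of clusters. Both are illustrative choices, and the toy data and the `correlation_clusters` helper are hypothetical.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform

# Hypothetical toy data: genes 0-2 share one latent signal, genes 3-5 share another
rng = np.random.default_rng(1)
n = 500
base1, base2 = rng.normal(size=n), rng.normal(size=n)
X = np.column_stack([base1 + 0.1 * rng.normal(size=n) for _ in range(3)] +
                    [base2 + 0.1 * rng.normal(size=n) for _ in range(3)])


def correlation_clusters(X, n_clusters):
    """Group the columns of X by average-linkage clustering on the distance 1 - |r|."""
    r = np.corrcoef(X, rowvar=False)          # (n_features, n_features) Pearson matrix
    dist = 1.0 - np.abs(r)
    dist = (dist + dist.T) / 2.0              # enforce exact symmetry
    np.fill_diagonal(dist, 0.0)               # each gene is at distance 0 from itself
    Z = linkage(squareform(dist, checks=False), method='average')
    return fcluster(Z, t=n_clusters, criterion='maxclust')


labels = correlation_clusters(X, n_clusters=2)
print(labels)   # one cluster label per column (gene)
```

With the notebook's objects, `correlation_clusters(train_d.drop('target', axis=1).to_numpy(), n_clusters=...)` would return one label per gene; the number of clusters would have to be chosen by inspection.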


In [None]:
# ── Compute full correlation matrix (genes only, exclude target) ──────────────
cor = train_d.drop('target', axis=1).corr()

# Render as a styled heatmap (may be slow for large matrices)
cor.style.background_gradient(cmap='viridis')


### 2.4 Most Correlated Gene Pairs (Global)

We extract all gene pairs with an **absolute correlation of at least 0.75**, which indicates strong
linear dependence. These pairs are candidates for removal or consolidation via PCA.
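
If the goal is removal rather than consolidation, a simple greedy rule is to drop the second gene of every pair at or above the threshold. This is a sketch on hypothetical toy data; `redundant_columns` is an illustrative helper, and the greedy rule can drop more genes than strictly necessary when correlations form chains.

```python
import numpy as np
import pandas as pd


def redundant_columns(df, threshold=0.75):
    """Greedy sketch: for every pair with |r| >= threshold, mark the later column for dropping."""
    corr = df.corr().abs()
    # Keep only the strict upper triangle so each pair is considered exactly once
    upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
    # NaN comparisons evaluate to False, so masked cells never trigger a drop
    return [col for col in upper.columns if (upper[col] >= threshold).any()]


# Hypothetical toy data: 'b' is almost a linear copy of 'a'; 'c' is independent
rng = np.random.default_rng(0)
a = rng.normal(size=200)
df = pd.DataFrame({'a': a,
                   'b': 2 * a + 0.01 * rng.normal(size=200),
                   'c': rng.normal(size=200)})

to_drop = redundant_columns(df)
print(to_drop)
```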


In [None]:
# ── Keep each gene pair once: mask all but the strict upper triangle ─────────
cor_abs = cor.abs()
upper = cor_abs.where(np.triu(np.ones(cor_abs.shape, dtype=bool), k=1))
c = (upper.stack()
          .dropna()        # masked (lower-triangle and diagonal) cells are NaN
          .reset_index()
          .rename(columns={'level_0': 'Gene 1', 'level_1': 'Gene 2', 0: 'Correlation'}))

# Filter to strong correlations; self-correlations are already excluded by the mask
c = (c.query('Correlation >= 0.75')
      .sort_values(by='Correlation', ascending=False)
      .reset_index(drop=True))

c.style.background_gradient(cmap='flare_r')


### 2.5 Most Correlated Gene Pairs Per Species

For each bacterial species independently, we find the single most correlated gene pair.
Species-specific correlations can reveal genetic relationships that are masked when all
species are pooled together.
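
The masking effect can be reproduced with hypothetical toy data. Below, two "species" have genes that are uncorrelated within each group, but because both gene means shift together between the groups, the pooled correlation is strong. The reverse can also happen, so per-species and pooled correlations are worth comparing.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: within each species the two genes are independent,
# but species sp2 has higher means on both genes
rng = np.random.default_rng(2)
n = 500
df = pd.DataFrame({
    'gene_a': np.concatenate([rng.normal(0, 1, n), rng.normal(10, 1, n)]),
    'gene_b': np.concatenate([rng.normal(0, 1, n), rng.normal(10, 1, n)]),
    'target': ['sp1'] * n + ['sp2'] * n,
})

pooled_r = df['gene_a'].corr(df['gene_b'])
within_r = {sp: g['gene_a'].corr(g['gene_b']) for sp, g in df.groupby('target')}

print(f'Pooled r = {pooled_r:.2f}')
print({sp: round(r, 2) for sp, r in within_r.items()})
```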


In [None]:
from IPython.display import display   # available by default in Jupyter; explicit for clarity

# ── Per-species top correlated gene pair ─────────────────────────────────────
for species_name in train_d.target.unique():
    # Subset to this species and drop the target label column
    subset = train_d[train_d.target == species_name].drop('target', axis=1)

    # Absolute correlation matrix for this species
    cor_sp = subset.corr().abs()

    # Keep each pair once (strict upper triangle, no self-correlations), then sort
    mask = np.triu(np.ones(cor_sp.shape, dtype=bool), k=1)
    c_sp = (cor_sp.where(mask)
                  .stack()
                  .dropna()
                  .reset_index()
                  .rename(columns={'level_0': 'Gene 1', 'level_1': 'Gene 2', 0: 'Correlation'})
                  .sort_values(by='Correlation', ascending=False)
                  .reset_index(drop=True))

    # Display only the top pair with a caption
    display(
        c_sp.iloc[:1, :].style
            .background_gradient(cmap='flare')
            .set_caption('Most correlated gene pair in {}'.format(
                species_name.replace('_', ' ')))
    )


## 3. Feature Engineering

### 3.1 Box-Cox Transformation for Skewed Features

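Many gene expression values follow a right-skewed distribution. Skew matters most for **PCA**, which
is built on variances and linear correlations, so a few extreme values can dominate the leading
components. Standard decision-tree splits depend only on the ordering of values, so a monotonic
transform such as Box-Cox has little direct effect on tree-based models (Extra Trees draws random
split thresholds, so it is not strictly invariant). We apply the **Box-Cox power transformation**
to every feature with skewness > 0.75.

- `boxcox_normmax(x + 1)` estimates the lambda (λ) that makes `x + 1` as close to normally distributed as possible
- `boxcox1p(x, λ)` computes the Box-Cox transform of `1 + x`, so it is defined for zeros (and for any `x > -1`)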
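
For reference, SciPy documents `boxcox1p` as the Box-Cox transform of `1 + x`; once `boxcox_normmax` has chosen λ, this formula is applied elementwise:

$$
\operatorname{boxcox1p}(x, \lambda) =
\begin{cases}
\dfrac{(1 + x)^{\lambda} - 1}{\lambda}, & \lambda \neq 0,\\
\ln(1 + x), & \lambda = 0,
\end{cases}
\qquad x > -1
$$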


In [None]:
# ── Identify skewed features ──────────────────────────────────────────────────
skew_cols = (train_d.select_dtypes(exclude='object')
                     .skew()
                     .sort_values(ascending=False))
skew_cols = (pd.DataFrame(skew_cols.loc[skew_cols > 0.75])
               .rename(columns={0: 'Skew Before'}))

print(f'Features with skew > 0.75: {len(skew_cols)}')

# ── Apply Box-Cox transformation ──────────────────────────────────────────────
t = train_d.copy()
for col in skew_cols.index.tolist():
    t[col] = boxcox1p(t[col], boxcox_normmax(t[col] + 1))

# ── Compare skewness before vs after ─────────────────────────────────────────
skew_df = (pd.concat([skew_cols, t[skew_cols.index].skew()], axis=1)
             .rename(columns={0: 'Skew After'}))
print('\nSkewness reduction (first 5 features):')
skew_df.head()
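
The loop above estimates λ on the full deduplicated training data and does not keep the values, so the test set cannot be transformed the same way (which is why Section 5 trains on untransformed features). Below is a minimal sketch of one fix, assuming λ should be fit on training data only and reused unchanged for every other split. `fit_boxcox_lambdas` and `apply_boxcox` are hypothetical helper names, and the data are toy values.

```python
import numpy as np
import pandas as pd
from scipy.special import boxcox1p
from scipy.stats import boxcox_normmax


def fit_boxcox_lambdas(train_df, skew_threshold=0.75):
    """Estimate one Box-Cox lambda per skewed numeric column, from training data only."""
    skew = train_df.select_dtypes(include='number').skew()
    skewed_cols = skew[skew > skew_threshold].index
    # boxcox_normmax needs strictly positive input, hence the +1 shift (assumes values > -1)
    return {col: boxcox_normmax(train_df[col].to_numpy() + 1) for col in skewed_cols}


def apply_boxcox(df, lambdas):
    """Transform any split (train, validation, or test) with the stored training lambdas."""
    out = df.copy()
    for col, lmbda in lambdas.items():
        out[col] = boxcox1p(out[col], lmbda)
    return out


# Hypothetical toy splits: one right-skewed column, one roughly symmetric column
rng = np.random.default_rng(3)
train_toy = pd.DataFrame({'skewed': rng.exponential(size=2000), 'symmetric': rng.normal(size=2000)})
test_toy = pd.DataFrame({'skewed': rng.exponential(size=500), 'symmetric': rng.normal(size=500)})

lambdas = fit_boxcox_lambdas(train_toy)
train_bc = apply_boxcox(train_toy, lambdas)
test_bc = apply_boxcox(test_toy, lambdas)   # same lambda as train, no refitting

print(lambdas)
print('Skew before:', round(train_toy['skewed'].skew(), 2),
      '| after:', round(train_bc['skewed'].skew(), 2))
```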


### 3.2 Feature Standardization

Before applying PCA or distance-sensitive algorithms, we standardize all features to
**zero mean and unit variance** using `StandardScaler`. This ensures that features with
large absolute values do not dominate the principal components.
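
Concretely, `StandardScaler` maps each feature to z = (x - μ) / σ, where μ and σ are the training mean and the population standard deviation (ddof = 0). The toy check below, on synthetic data, confirms that the fitted scaler applies the *training* statistics to new data, which is the behavior the train/validation/test workflow in Section 5 relies on.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical toy train/test matrices
rng = np.random.default_rng(4)
X_train_toy = rng.normal(loc=5.0, scale=3.0, size=(200, 3))
X_test_toy = rng.normal(loc=5.0, scale=3.0, size=(50, 3))

scaler = StandardScaler().fit(X_train_toy)

# Manual version: training mean and population std (numpy's default ddof=0)
mu, sigma = X_train_toy.mean(axis=0), X_train_toy.std(axis=0)
manual_test = (X_test_toy - mu) / sigma

print(np.allclose(scaler.transform(X_test_toy), manual_test))
```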


In [None]:
# ── Separate features from target ────────────────────────────────────────────
X = t.drop('target', axis=1)

# ── Fit and apply StandardScaler ─────────────────────────────────────────────
# This scaler (fit on the full deduplicated, Box-Cox-transformed data) is used only for
# the exploratory PCA in Section 4; Section 5 fits a separate scaler on the training split only
X_scaled = pd.DataFrame(StandardScaler().fit_transform(X), columns=X.columns)

print(f'Scaled feature matrix shape: {X_scaled.shape}')
X_scaled.head()


## 4. Principal Component Analysis (PCA)

### 4.1 Full PCA — Explained Variance

We first fit PCA with all 286 components to understand how much variance is captured
cumulatively. The animated chart below shows:
- **Individual variance** (blue): Contribution of each successive principal component
- **Cumulative variance** (teal): Total variance explained as we include more components

This helps us choose a reasonable number of components for the modeling step, trading
information retention against dimensionality reduction.
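
Rather than reading the cut-off off the chart, the number of components can be computed from the cumulative ratio. The sketch below (toy data, hypothetical helper `n_components_for`) picks the smallest k whose cumulative explained variance exceeds a threshold, and checks the result against scikit-learn's own rule for a fractional `n_components`.

```python
import numpy as np
from sklearn.decomposition import PCA


def n_components_for(explained_variance_ratio, threshold):
    """Smallest k whose cumulative explained variance exceeds `threshold` (a fraction in (0, 1))."""
    cumulative = np.cumsum(explained_variance_ratio)
    # side='right' + 1 mirrors the "strictly greater than" rule scikit-learn documents
    return int(np.searchsorted(cumulative, threshold, side='right') + 1)


# Hypothetical correlated toy features
rng = np.random.default_rng(5)
X_toy = rng.normal(size=(300, 10)) @ rng.normal(size=(10, 10))

ratios = PCA(svd_solver='full').fit(X_toy).explained_variance_ratio_
k = n_components_for(ratios, 0.85)
k_sklearn = PCA(n_components=0.85, svd_solver='full').fit(X_toy).n_components_
print(k, k_sklearn)
```

In the notebook this would be `n_components_for(pca_full.explained_variance_ratio_, 0.90)`; the 0.90 threshold is only an example.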


In [None]:
# ── Fit full PCA (n_components = number of features) ─────────────────────────
pca_full = PCA(n_components=286).fit(X_scaled)

# Cumulative and individual variance series (in %)
pca_cum = pd.Series(np.cumsum(pca_full.explained_variance_ratio_)).mul(100)
pca_ind = pd.Series(pca_full.explained_variance_ratio_).mul(100)

# Re-index from 1 for readability
pca_cum.index = np.arange(1, len(pca_cum) + 1)
pca_ind.index = np.arange(1, len(pca_ind) + 1)

print(f'Variance explained by PC1:  {pca_ind[1]:.2f}%')
print(f'Variance explained by PC2:  {pca_ind[2]:.2f}%')
print(f'Variance explained by top 10 PCs:  {pca_cum[10]:.2f}%')
print(f'Variance explained by top 100 PCs: {pca_cum[100]:.2f}%')

# ── Animated variance chart ──────────────────────────────────────────────────
fig = go.Figure(
    layout=go.Layout(
        updatemenus=[dict(type="buttons", direction="left", x=0.15, y=1.2, showactive=False)],
        xaxis=dict(range=[1, 287], autorange=False, tickwidth=2),
        yaxis=dict(range=[0, 100], autorange=False)))

fig.add_trace(go.Scatter(x=pca_ind.index[:1], y=pca_ind[:1],
                         line=dict(color='#5758A3', width=3), visible=True,
                         fill='tozeroy', opacity=0.8,
                         hovertemplate='Variance Explained = %{y:.2f}%<br>Principal Component %{x:.0f}',
                         name='Individual'))
fig.add_trace(go.Scatter(x=pca_cum.index[:1], y=pca_cum[:1],
                         line=dict(color='#57A0A3', width=3), visible=True,
                         fill='tonexty', opacity=0.7,
                         hovertemplate='Cumulative Variance = %{y:.1f}%<br>Components = %{x:.0f}',
                         name='Cumulative'))

fig.update(frames=[
    go.Frame(data=[go.Scatter(x=pca_ind.index[:i], y=pca_ind[:i]),
                   go.Scatter(x=pca_cum.index[:i], y=pca_cum[:i])])
    for i in range(1, 287)])

fig.update_yaxes(title='Variance Explained (%)', showline=True, ticksuffix='%', range=[0, 105])
fig.update_layout(template=temp,
                  title='Cumulative & Individual Variance Explained by Principal Components',
                  xaxis_title='Number of Principal Components',
                  hovermode='x unified', width=700,
                  legend=dict(orientation='v', yanchor='bottom', y=1.08,
                              xanchor='right', x=.99, title=''),
                  updatemenus=[dict(
                      buttons=[dict(label='▶ Play', method='animate',
                                    args=[None, {'frame': {'duration': 15, 'redraw': False}},
                                          {'fromcurrent': True}]),
                               dict(label='⏸ Pause', method='animate',
                                    args=[[None], {'frame': {'duration': 0, 'redraw': False},
                                                   'mode': 'immediate', 'transition': {'duration': 0}}])],
                      direction='left', x=0.15, y=1.2)])
fig.show()


### 4.2 Gene Importance Across Top 10 Principal Components

Each principal component is a linear combination of the original 286 genes. The **loading**
of a gene on a component reflects how much that gene contributes to that component.

For each of the top 10 PCs, we take the gene with the largest absolute loading and compute its **weighted importance**:

```
weighted_importance = |loading| × variance_explained_by_PC
```

This tells us which genes carry the most meaningful signal in the compressed representation.
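
The notebook's loop keeps only the top gene per PC. As a sketch, the same formula can be evaluated for every (gene, PC) pair at once with broadcasting; the toy loadings below are hypothetical, and summing across PCs is just one possible way to aggregate.

```python
import numpy as np
import pandas as pd


def weighted_importance(components, variance_ratio, gene_names):
    """|loading| x explained-variance fraction for every (gene, PC) pair.

    components:     shape (n_components, n_genes), e.g. PCA.components_
    variance_ratio: shape (n_components,), e.g. PCA.explained_variance_ratio_
    """
    w = np.abs(components).T * variance_ratio   # broadcasts over columns -> (n_genes, n_components)
    cols = [f'PC{i + 1}' for i in range(components.shape[0])]
    return pd.DataFrame(w, index=gene_names, columns=cols)


# Hypothetical 2-component, 3-gene example
comps = np.array([[0.8, -0.6, 0.0],
                  [0.0,  0.0, 1.0]])
ratio = np.array([0.7, 0.3])
imp = weighted_importance(comps, ratio, ['g1', 'g2', 'g3'])

print(imp)
print(imp.idxmax())      # top gene per PC (what the notebook's loop keeps)
print(imp.sum(axis=1))   # one possible aggregate across PCs
```

With the notebook's objects: `weighted_importance(pca_full.components_[:10], pca_full.explained_variance_ratio_[:10], X_scaled.columns)`.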


In [None]:
# ── Extract PCA component loadings ───────────────────────────────────────────
# Shape: (n_features, n_components) — rows=genes, cols=PCs
loadings = pd.DataFrame(
    abs(pca_full.components_.T),
    columns=['PC' + str(i + 1) for i in range(286)],
    index=X_scaled.columns
)

# ── Compute weighted gene importance per PC ───────────────────────────────────
# For each PC, find the gene with the highest loading and weight it by
# that PC's explained variance fraction
pca_var_ratio = pd.Series(pca_full.explained_variance_ratio_)
var_pca = []
for pc_name, var_frac in zip(loadings.columns, pca_var_ratio):
    top_gene = loadings[pc_name].nlargest(1)           # gene with highest loading
    weighted  = top_gene * var_frac                    # weight by variance explained
    var_pca.append(pd.DataFrame({
        'Principal Component': str(int(pc_name[2:])),
        'Gene': weighted.index,
        'Weighted Importance': weighted.values
    }))

var_pca = pd.concat(var_pca).reset_index(drop=True)

# ── Plot top 10 PCs ───────────────────────────────────────────────────────────
plot_df = var_pca.iloc[:10, :]
pal_v   = sns.color_palette("viridis", 14).as_hex()[1:11]

fig = px.bar(plot_df, x='Gene', y='Weighted Importance', text='Weighted Importance',
             color='Principal Component', color_discrete_sequence=pal_v, opacity=0.7)
fig.update_traces(texttemplate='%{text:,.3f}', textposition='outside',
                  marker_line=dict(width=1, color='#28221D'))
fig.update_layout(template=temp,
                  title_text='Gene Importance in the Top 10 Principal Components',
                  xaxis_title='Gene Segment', xaxis_tickangle=28,
                  yaxis_title='Weighted Importance', legend_title='Principal<br>Component',
                  height=500, width=700)
fig.show()


### 4.3 Top Gene Loadings in PC1 – PC4

Each subplot below shows the **5 genes with the highest absolute loadings** on the first
four principal components. High loadings indicate that a gene strongly defines that component's
direction in the feature space — these genes are the primary "drivers" of each PC.
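
The loadings plotted here are the rows of `PCA.components_`. A short toy check on synthetic data (not the gene matrix) shows what they mean: each loading vector has unit length, and a sample's score on a PC is its centered feature vector dotted with that vector. Because flipping the sign of a loading vector gives an equally valid component, the notebook ranks genes by absolute loading.

```python
import numpy as np
from sklearn.decomposition import PCA

# Hypothetical correlated toy features
rng = np.random.default_rng(6)
X_toy = rng.normal(size=(200, 5)) @ rng.normal(size=(5, 5))

pca = PCA(n_components=3).fit(X_toy)

# Each row of components_ is a unit-length loading vector over the original features
norms = np.linalg.norm(pca.components_, axis=1)

# A PC score is the centered sample dotted with that loading vector (no whitening)
scores_manual = (X_toy - pca.mean_) @ pca.components_.T

print(norms.round(6))
print(np.allclose(scores_manual, pca.transform(X_toy)))
```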


In [None]:
# ── Subplot: top-5 gene loadings for PC1, PC2, PC3, PC4 ─────────────────────
fig = make_subplots(rows=2, cols=2,
                    subplot_titles=('Principal Component 1', 'Principal Component 2',
                                    'Principal Component 3', 'Principal Component 4'))

for idx, (pc, row, col) in enumerate([('PC1',1,1), ('PC2',1,2), ('PC3',2,1), ('PC4',2,2)]):
    top5 = loadings[pc].sort_values(ascending=False)[:5]
    fig.add_trace(
        go.Bar(x=top5.index, y=top5, name=pc, showlegend=False,
               marker_color=pal_v[idx], opacity=0.8,
               marker_line=dict(width=1, color='#28221D'),
               hovertemplate=f'Gene %{{x}} loading on {pc} = %{{y:.3f}}<extra></extra>'),
        row=row, col=col)

fig.update_layout(template=temp,
                  title_text='Top-5 Gene Loadings on the First Four Principal Components',
                  height=900, width=700)
fig.show()


### 4.4 Species Clusters in 2D PCA Space

Reducing the data to 10 principal components and plotting the first two gives us an
intuitive 2D view of how well the species separate in the compressed feature space.

- **Well-separated clusters** → PCA successfully captures species-discriminating variance
- **Overlapping clusters** → The classifier will need more components to resolve fine-grained differences

PC1 and PC2 together explain ~52% of the total variance.
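
Visual separation can be complemented with a single number. One option, assumed here rather than used in the original analysis, is the silhouette score on the 2D coordinates: values near 1 indicate compact, well-separated groups and values near 0 indicate overlap. The blobs below are a hypothetical stand-in for (PC1, PC2) and the species labels.

```python
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Hypothetical stand-in for the (PC1, PC2) coordinates and species labels
X_2d, labels = make_blobs(n_samples=600, centers=[[0, 0], [8, 0], [0, 8]],
                          cluster_std=1.0, random_state=7)

# sample_size keeps the O(n^2) pairwise-distance cost manageable on ~124k rows
score = silhouette_score(X_2d, labels, sample_size=500, random_state=7)
print(f'Silhouette (2D projection) = {score:.2f}')
```

On the notebook's projection this could be `silhouette_score(pca_df[['PC1', 'PC2']], pca_df['target'], sample_size=20000, random_state=0)`; the sample size is arbitrary.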


In [None]:
# ── Project training data onto first 10 PCs ───────────────────────────────────
pca_10    = PCA(n_components=10).fit_transform(X_scaled)
pca_df    = pd.DataFrame(pca_10, columns=['PC' + str(i + 1) for i in range(10)])
species   = train_d.target.reset_index(drop=True).str.replace('_', ' ')

pca_df = pd.concat([species, pca_df], axis=1)

# Sort by species frequency so more common species are drawn first (z-order)
pca_df['_freq'] = pca_df['target'].map(pca_df['target'].value_counts())
pca_df = pca_df.sort_values('_freq', ascending=False).drop('_freq', axis=1)

# ── Scatter plot (PC1 vs PC2) ─────────────────────────────────────────────────
fig = px.scatter(pca_df, x='PC1', y='PC2', color='target',
                 color_discrete_sequence=pal, opacity=0.4)
fig.update_traces(marker_size=7,
                  hovertemplate='PC1 = %{x:.2f}<br>PC2 = %{y:.2f}')
fig.update_layout(template=temp,
                  title='Bacterial Species Projected onto Principal Components 1 and 2',
                  legend_title='Species', width=700, height=600,
                  xaxis_title=f'Component 1 (~{pca_ind[1]:.1f}% variance)',
                  yaxis_title=f'Component 2 (~{pca_ind[2]:.1f}% variance)')
fig.show()


## 5. Classification Modeling

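We train an **Extra Trees Classifier** (extremely randomized trees) because:
- Each split picks the best of a few *randomly drawn* thresholds instead of searching for the optimal one, and averaging many such trees keeps variance low, which tends to work well on high-dimensional data
- It is computationally efficient with many features, since thresholds are sampled rather than exhaustively searched
- `class_weight='balanced'` reweights samples inversely to class frequency, compensating for any residual class imbalance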

We compare two feature configurations:
1. **PCA-reduced (100 components):** Tests whether a compressed representation is sufficient
2. **Full feature space (286 features):** Serves as the uncompressed reference baseline
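
For reference, scikit-learn defines the `'balanced'` weight of class c as n_samples / (n_classes × count_c). The toy labels below are hypothetical; the check compares this formula with `compute_class_weight`.

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

y_toy = np.array([0] * 60 + [1] * 30 + [2] * 10)   # hypothetical imbalanced labels
classes = np.unique(y_toy)

# 'balanced' weight for class c = n_samples / (n_classes * count_c)
manual = len(y_toy) / (len(classes) * np.bincount(y_toy))
sk_weights = compute_class_weight(class_weight='balanced', classes=classes, y=y_toy)

print(manual.round(3))
print(np.allclose(manual, sk_weights))
```

Rare classes receive proportionally larger weights, so each class contributes roughly equally to the split criterion.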

### 5.1 Data Splitting & Scaling


In [None]:
# ── Encode species labels as integers (required by sklearn) ───────────────────
enc = LabelEncoder()
y   = enc.fit_transform(train_d.target)

# Features: the deduplicated, *untransformed* data. The Box-Cox frame `t` from Section 3.1
# is not used here, because the test set was never Box-Cox transformed
X_full = train_d.drop('target', axis=1)

# ── Stratified 80/20 train/validation split ───────────────────────────────────
# stratify=y ensures each species is proportionally represented in both splits
X_train, X_val, y_train, y_val = train_test_split(
    X_full, y, test_size=0.2, shuffle=True, stratify=y, random_state=21)

# ── Fit scaler on train; transform train, val, and test ───────────────────────
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled   = scaler.transform(X_val)
X_test_scaled  = scaler.transform(test)

print(f'Train:      {X_train_scaled.shape}  |  Labels: {y_train.shape}')
print(f'Validation: {X_val_scaled.shape}  |  Labels: {y_val.shape}')
print(f'Test:       {X_test_scaled.shape}')


### 5.2 Model A — Extra Trees with 100 Principal Components

Here we fit PCA with 100 components on the **training set** and apply the learned
transformation to the validation and test sets. The classifier is then trained on the
compressed representations.

> **Why 100 components?** From the variance chart above, 100 PCs capture a large
> fraction of the total variance while cutting the feature count by ~65%.


In [None]:
# ── PCA: fit on train, transform all splits ────────────────────────────────────
pca_model = PCA(n_components=100)
X_train_pca = pca_model.fit_transform(X_train_scaled)
X_val_pca   = pca_model.transform(X_val_scaled)
X_test_pca  = pca_model.transform(X_test_scaled)

print(f'PCA Train shape: {X_train_pca.shape} | Cumulative variance: '
      f'{np.sum(pca_model.explained_variance_ratio_)*100:.1f}%')

# ── Train Extra Trees on PCA-compressed features ───────────────────────────────
et_pca = ExtraTreesClassifier(
    n_estimators=500,
    class_weight='balanced',   # compensates for any class imbalance
    random_state=92
).fit(X_train_pca, y_train)

print('\nModel trained:', et_pca)

# ── Evaluate on validation set ────────────────────────────────────────────────
y_preds_pca = et_pca.predict(X_val_pca)
y_probs_pca = et_pca.predict_proba(X_val_pca)

val_acc_pca = accuracy_score(y_true=y_val, y_pred=y_preds_pca)
val_auc_pca = roc_auc_score(y_true=y_val, y_score=y_probs_pca,
                             average='weighted', multi_class='ovr')

# Per-class metrics
report_pca = classification_report(y_val, y_preds_pca,
                                    target_names=enc.classes_, output_dict=True)
c_pca = (pd.DataFrame(report_pca).T.iloc[:10, :]
          [['f1-score', 'precision', 'recall', 'support']])
val_f1_pca = c_pca['f1-score'].mean()   # unweighted mean over the 10 classes (macro F1)

print(f'\n── PCA Model Performance ───────────────────────────────')
print(f'Accuracy  = {val_acc_pca*100:.2f}%')
print(f'F1-Score  = {val_f1_pca*100:.2f}%')
print(f'AUC (OvR) = {val_auc_pca:.4f}')

# Styled per-class breakdown
c_pca[['f1-score', 'precision', 'recall']] = c_pca[['f1-score', 'precision', 'recall']].mul(100)
(c_pca.sort_values('f1-score', ascending=False).style
   .background_gradient(cmap='flare_r', subset=['f1-score'])
   .format({'f1-score': '{:,.1f}%', 'precision': '{:,.1f}%',
            'recall': '{:,.1f}%', 'support': '{:,.0f}'}))


### 5.3 ROC Curves — PCA Model

The **Receiver Operating Characteristic (ROC)** curve plots the true positive rate (sensitivity)
against the false positive rate (1 − specificity) at various decision thresholds.

For multi-class classification we use the **One-vs-Rest (OvR)** strategy: each species is
treated as the positive class while the rest are negative. The **Area Under the Curve (AUC)**
summarizes performance in a single number — a perfect classifier achieves AUC = 1.0.
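
Below is a sketch of how the One-vs-Rest AUC is assembled, on random toy scores. Each class gets a binary AUC from its own probability column, and the multi-class value is their mean (`average='macro'`) or their prevalence-weighted mean (`average='weighted'`, as used in this notebook). The key detail is that column c of `predict_proba` belongs to encoded class c.

```python
import numpy as np
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize

# Hypothetical toy labels and slightly informative probability scores
rng = np.random.default_rng(8)
n, k = 300, 3
y_true = rng.integers(0, k, size=n)
scores = rng.random((n, k)) + 0.5 * np.eye(k)[y_true]
probs = scores / scores.sum(axis=1, keepdims=True)   # rows must sum to 1

# One-vs-Rest: column c scores "class c vs everything else"
Y = label_binarize(y_true, classes=np.arange(k))
per_class = np.array([roc_auc_score(Y[:, c], probs[:, c]) for c in range(k)])

macro = per_class.mean()
weighted = np.average(per_class, weights=Y.sum(axis=0))   # weighted by class prevalence
print(per_class.round(3), round(macro, 3), round(weighted, 3))
```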


In [None]:
# ── Compute per-species ROC curves ────────────────────────────────────────────
fpr, tpr, roc_auc_dict = {}, {}, {}

# Species sorted by F1 (best to worst) for legend ordering
species_sorted = c_pca.sort_values('f1-score', ascending=False).index

for sp in species_sorted:
    # Map the species name to its LabelEncoder index, which is also its column in predict_proba
    k = enc.transform([sp])[0]
    fpr[sp], tpr[sp], _ = roc_curve(y_val, y_probs_pca[:, k], pos_label=k)
    roc_auc_dict[sp] = auc(fpr[sp], tpr[sp])

# ── Plot ──────────────────────────────────────────────────────────────────────
fig = go.Figure()
for sp, color in zip(species_sorted, pal):
    fig.add_trace(go.Scatter(
        x=fpr[sp], y=tpr[sp],
        line=dict(color=color, width=3), opacity=0.75,
        hovertemplate='TPR = %{y:.3f}  |  FPR = %{x:.3f}',
        name=f"{sp.replace('_', ' ')}  (AUC = {roc_auc_dict[sp]:.3f})"))

fig.update_layout(
    template=temp,
    title='Multiclass ROC Curves — PCA Model (One-vs-Rest)',
    hovermode='x unified',
    xaxis_title='False Positive Rate (1 − Specificity)',
    yaxis_title='True Positive Rate (Sensitivity)',
    legend=dict(y=0.1, x=0.98, xanchor='right',
                bordercolor='black', borderwidth=0.5, font=dict(size=11)),
    height=550, width=700)
fig.show()


### 5.4 PCA Model Test Predictions

We apply the trained PCA-based model to the held-out test set and visualize the
predicted species distribution.


In [None]:
# ── Generate test predictions ─────────────────────────────────────────────────
test_preds_pca = et_pca.predict(X_test_pca)
target_pca     = enc.inverse_transform(test_preds_pca)

sub_pca = pd.DataFrame({
    'row_id': test.index,   # row IDs from the test CSV (read with index_col=0)
    'target': target_pca
})

# ── Visualize predicted distribution ─────────────────────────────────────────
bact_pca = sub_pca.target.value_counts(normalize=True).reset_index()
bact_pca['proportion'] = bact_pca['proportion'].mul(100)
bact_pca['target']     = bact_pca['target'].str.replace('_', ' ')
bact_pca = bact_pca.sort_values('proportion', ascending=False)

fig = px.bar(bact_pca, x='target', y='proportion', text='proportion',
             color='target', color_discrete_sequence=pal, opacity=0.8)
fig.update_traces(texttemplate='%{text:,.2f}%', textposition='outside',
                  marker_line=dict(width=1, color='#28221D'))
fig.update_yaxes(visible=False, showticklabels=False)
fig.update_layout(template=temp,
                  title_text='Predicted Species Distribution — PCA Model (100 Components)',
                  xaxis=dict(title='', tickangle=25, showline=True),
                  height=450, width=700, showlegend=False)
fig.show()

# ── Save PCA submission ───────────────────────────────────────────────────────
sub_pca.to_csv('submission_pca.csv', index=False)
print('PCA submission saved → submission_pca.csv')
sub_pca.head()


### 5.5 Model B — Extra Trees with All 286 Features

As a comparison baseline, we train the same Extra Trees classifier on the **full
feature space** (all 286 original gene features after standardization; as in Model A, the
Box-Cox transform from Section 3.1 is not applied).

This answers the question: *Does PCA-based compression cost us meaningful predictive
accuracy, or does it preserve the discriminative signal?*
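
A single 80/20 split, with different random seeds for the two models (`random_state=92` vs `21`), leaves some noise in this comparison. Below is a minimal sketch of a paired comparison, assuming cross-validation is affordable: both pipelines are scored on the same stratified folds, and scaling and PCA are refit inside each fold. The data are synthetic stand-ins from `make_classification`, kept small so the sketch runs quickly.

```python
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

# Hypothetical stand-in for the gene matrix
X_toy, y_toy = make_classification(n_samples=600, n_features=30, n_informative=10,
                                   n_redundant=10, n_classes=3, class_sep=2.0,
                                   random_state=0)

# Pipelines refit the scaler (and PCA) inside every fold, so no fold sees its own validation data
full_model = make_pipeline(StandardScaler(),
                           ExtraTreesClassifier(n_estimators=100, random_state=0))
pca_model = make_pipeline(StandardScaler(), PCA(n_components=10),
                          ExtraTreesClassifier(n_estimators=100, random_state=0))

# Same shuffled folds for both models (fixed random_state)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
scores_full = cross_val_score(full_model, X_toy, y_toy, cv=cv, scoring='accuracy')
scores_pca = cross_val_score(pca_model, X_toy, y_toy, cv=cv, scoring='accuracy')

print(f'Full: {scores_full.mean():.3f} ± {scores_full.std():.3f}')
print(f'PCA : {scores_pca.mean():.3f} ± {scores_pca.std():.3f}')
print('Per-fold difference:', (scores_full - scores_pca).round(3))
```

Looking at the per-fold differences, rather than only the two means, shows whether one configuration wins consistently or only on average.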


In [None]:
# ── Train Extra Trees on all 286 scaled features ──────────────────────────────
et_all = ExtraTreesClassifier(
    n_estimators=500,
    class_weight='balanced',
    random_state=21
).fit(X_train_scaled, y_train)

print('Model trained:', et_all)

# ── Evaluate on validation set ────────────────────────────────────────────────
y_preds_all = et_all.predict(X_val_scaled)
y_probs_all = et_all.predict_proba(X_val_scaled)

val_acc_all = accuracy_score(y_true=y_val, y_pred=y_preds_all)
val_auc_all = roc_auc_score(y_true=y_val, y_score=y_probs_all,
                             average='weighted', multi_class='ovr')
report_all  = classification_report(y_val, y_preds_all,
                                     target_names=enc.classes_, output_dict=True)
c_all   = (pd.DataFrame(report_all).T.iloc[:10, :]
             [['f1-score', 'precision', 'recall', 'support']])
val_f1_all = c_all['f1-score'].mean()   # unweighted mean over the 10 classes (macro F1)

print(f'\n── Full Feature Model Performance ───────────────────────')
print(f'Accuracy  = {val_acc_all*100:.2f}%')
print(f'F1-Score  = {val_f1_all*100:.2f}%')
print(f'AUC (OvR) = {val_auc_all:.4f}')

c_all[['f1-score', 'precision', 'recall']] = c_all[['f1-score', 'precision', 'recall']].mul(100)
(c_all.sort_values('f1-score', ascending=False).style
   .background_gradient(cmap='flare_r', subset=['f1-score'])
   .format({'f1-score': '{:,.1f}%', 'precision': '{:,.1f}%',
            'recall': '{:,.1f}%', 'support': '{:,.0f}'}))


## 6. Summary & Final Predictions

### Model Comparison

| Configuration | Features | Accuracy | F1-Score | AUC |
|---------------|----------|----------|----------|-----|
| Extra Trees + PCA | 100 PCs | *run to see* | *run to see* | *run to see* |
| Extra Trees (Full) | 286 genes | *run to see* | *run to see* | *run to see* |

### Key Takeaways
- **PCA dimensionality reduction** compresses 286 gene features into 100 principal components
  while retaining most of the variance, which reduces the number of features the classifier must process.
- The **2D PCA projection** (PC1 vs PC2) shows visibly distinct clusters for most bacterial
  species, suggesting that the gene expression patterns carry substantial species-discriminating signal.
- **Standardization** keeps PCA components from being dominated by genes with large raw scales, and
  **Box-Cox transformation** reduces the pull of extreme values in the exploratory PCA (the
  classifiers in Section 5 use standardized but not Box-Cox-transformed features).
- **Gene loadings analysis** reveals which specific gene segments drive each principal
  component, bridging the gap between the statistical model and biological interpretation.


In [None]:
# ── Generate final test predictions (full-feature model) ──────────────────────
test_preds_all = et_all.predict(X_test_scaled)
target_all     = enc.inverse_transform(test_preds_all)

sub_all = pd.DataFrame({
    'row_id': test.index,   # row IDs from the test CSV (read with index_col=0)
    'target': target_all
})

# ── Visualize predicted distribution ─────────────────────────────────────────
bact_all = sub_all.target.value_counts(normalize=True).reset_index()
bact_all['proportion'] = bact_all['proportion'].mul(100)
bact_all['target']     = bact_all['target'].str.replace('_', ' ')
bact_all = bact_all.sort_values('proportion', ascending=False)

fig = px.bar(bact_all, x='target', y='proportion', text='proportion',
             color='target', color_discrete_sequence=pal, opacity=0.8)
fig.update_traces(texttemplate='%{text:,.2f}%', textposition='outside',
                  marker_line=dict(width=1, color='#28221D'))
fig.update_yaxes(visible=False, showticklabels=False)
fig.update_layout(template=temp,
                  title_text='Predicted Species Distribution — Full Feature Model (286 Genes)',
                  xaxis=dict(title='', tickangle=25, showline=True),
                  height=450, width=700, showlegend=False)
fig.show()

# ── Save full-feature submission ──────────────────────────────────────────────
sub_all.to_csv('submission_full.csv', index=False)
print('Full-feature submission saved → submission_full.csv')
sub_all.head()
