In [1]:
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from latentmi import lmi
import os
import numpy as np

In [2]:
# load held-out labels and images
val_data = pd.read_csv("../embeddings/MNIST/held_out.csv")
val_labels = val_data["label"].values

# function to compute accuracy of a linear probe!
def compute_linear_probe_accuracy(epoch):
    # load embeddings for the given epoch
    embeddings_file = f"../embeddings/MNIST/AE_epoch{epoch}.csv"
    embeddings = pd.read_csv(embeddings_file, header=None).values

    # split into train and test
    X_train, X_test, y_train, y_test = train_test_split(embeddings[:, 1:], embeddings[:, 0], test_size=0.2, random_state=42)

    # train logistic regression
    clf = LogisticRegression(max_iter=1000, solver='lbfgs', multi_class='multinomial')
    clf.fit(X_train, y_train)

    # evaluate on test set
    y_pred = clf.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    print(f"Accuracy of linear probe for epoch {epoch}: {acc:.4f}")
    return acc

def MI_probe(epoch):
    embeddings_file = f"../embeddings/MNIST/AE_epoch{epoch}.csv"
    embeddings = pd.read_csv(embeddings_file, header=None).values
    X, y = embeddings[:, 1:], embeddings[:, 0]
    y_onehot = np.random.normal(size=(len(y), 10))*0.1
    y_onehot[np.arange(len(y), dtype=int), y.astype(int)] = 1
    return np.nanmean(lmi.estimate(X, y_onehot)[0])


MIs = []
linears = []

for epoch in range(1, 51, 3):  # assuming up to 50 epochs
    embeddings_file = f"../embeddings/MNIST/AE_epoch{epoch}.csv"
    if not os.path.exists(embeddings_file):
        print(f"No embeddings found for epoch {epoch}. Stopping.")
        break
    linears.append(compute_linear_probe_accuracy(epoch))
    MIs.append(MI_probe(epoch))


STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Accuracy of linear probe for epoch 1: 0.5342
epoch 299 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Accuracy of linear probe for epoch 4: 0.6597
epoch 299 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Accuracy of linear probe for epoch 7: 0.7445
epoch 299 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Accuracy of linear probe for epoch 10: 0.7633
epoch 292 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Accuracy of linear probe for epoch 13: 0.7780
epoch 299 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


Accuracy of linear probe for epoch 16: 0.7925
epoch 256 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻Accuracy of linear probe for epoch 19: 0.8063
epoch 299 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻Accuracy of linear probe for epoch 22: 0.8158
epoch 299 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻Accuracy of linear probe for epoch 25: 0.8213
epoch 299 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻Accuracy of linear probe for epoch 28: 0.8248
epoch 299 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻Accuracy of linear probe for epoch 31: 0.8297
epoch 299 (of max 300) 🌻🌻🌻🌻🌻🌻🌻🌻🌻No embeddings found for epoch 34. Stopping.


In [3]:
d = {"epoch": range(1, len(MIs)+1), "MI": MIs, "linear": linears}
df = pd.DataFrame(d)
df

Unnamed: 0,epoch,MI,linear
0,1,1.50029,0.534167
1,2,2.325634,0.659667
2,3,2.655746,0.7445
3,4,2.753342,0.763333
4,5,2.769338,0.778
5,6,2.855195,0.7925
6,7,2.873935,0.806333
7,8,2.893156,0.815833
8,9,2.905273,0.821333
9,10,2.915986,0.824833


In [4]:
df.to_csv("../results/MI_vs_linear_probe.csv", index=False)