In [None]:
# Mount Google Drive into the Colab filesystem at /content/drive so that
# files stored on Drive can be addressed through that mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
from os import path
# Directory on the mounted Drive that the feature checkpoints are read from.
# NOTE(review): this assignment rebinds `path`, replacing the `os.path` module
# that was imported under the same name just above. `os.path.exists` below
# still works because it is reached through `os`, not through `path`.
path = "/content/drive/MyDrive/Colab Notebooks/DISSERTATION/MSN/Logistic_Eval/logs_1%"
# Sanity check: prints True when the directory is reachable.
print(os.path.exists(path))

True


In [None]:
import torch, numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns

# load features
# File names of the saved train / test feature checkpoints, relative to `path`.
train_file = '/train-features-pathmnist-train-frac0.01-seed42-msn-pathmnist-train-latest.pth.tar.pth.tar'
test_file = '/test-features-msn-pathmnist-train-latest.pth.tar.pth.tar'

train_ckpt = torch.load(path + train_file, map_location='cpu')
test_ckpt = torch.load(path + test_file, map_location='cpu')


def _as_arrays(ckpt):
    """Return (embeddings as float32, labels as int64) NumPy arrays from a checkpoint dict."""
    feats = ckpt['embs'].numpy().astype(np.float32)
    labs = ckpt['labs'].numpy().astype(np.int64)
    return feats, labs


X_train, y_train = _as_arrays(train_ckpt)
X_test, y_test = _as_arrays(test_ckpt)

# Report the array shapes of both splits.
for _tag, _feats, _labs in (('Train:', X_train, y_train), ('Test :', X_test, y_test)):
    print(_tag, _feats.shape, _labs.shape)

Train: (900, 384) (900,)
Test : (7180, 384) (7180,)


In [None]:
# normalization to mirror HPC run (row-wise centering + L2)
NORMALIZE = True
if NORMALIZE:
    def row_center_l2(X):
        """Subtract each row's mean, then scale each row to unit L2 norm.

        A small epsilon (1e-12) is added to the norm so that an all-constant
        row does not cause a division by zero.
        """
        centered = X - X.mean(axis=1, keepdims=True)
        row_norms = np.linalg.norm(centered, axis=1, keepdims=True) + 1e-12
        return centered / row_norms

    # Apply the same per-row transform to both splits.
    X_train, X_test = row_center_l2(X_train), row_center_l2(X_test)

# logistic regression
# `C` is scikit-learn's inverse regularization strength; here it is derived
# from the regularization weight `lambd` and the number of training rows.
lambd = 0.0025
C = max(1e-6, len(X_train) / lambd)

# Notes on the estimator configuration:
# - `multi_class` is intentionally not passed: it is deprecated in recent
#   scikit-learn releases, and with the 'saga' solver on a multiclass target
#   the default already fits a multinomial (softmax) model.
# - `random_state` is fixed because 'saga' shuffles the data while fitting,
#   so an unseeded run is not reproducible across kernel restarts.
clf = LogisticRegression(
    solver='saga',
    fit_intercept=False,
    max_iter=2000,
    tol=1e-3,
    C=C,
    n_jobs=-1,
    penalty='l2',
    random_state=42,
)
clf.fit(X_train, y_train)

In [None]:
y_pred = clf.predict(X_test)

# confusion matrix
# Pin the label order explicitly: the sorted set of labels that occur in
# y_test or y_pred (the same set confusion_matrix uses when `labels` is not
# given). The tick labels are then the real class ids, instead of positional
# indices that would be wrong if a class id were missing from both arrays.
cm_labels = np.unique(np.concatenate([y_test, y_pred]))
cm = confusion_matrix(y_test, y_pred, labels=cm_labels)  # raw counts
classes = [f"{i}" for i in cm_labels]                    # 0-8 for PathMNIST

def plot_cm_sns(mat, title, cmap, fmt):
    """Render a matrix as an annotated seaborn heatmap and display it.

    Args:
        mat: 2-D array of values to draw (one cell per entry).
        title: Text placed above the figure.
        cmap: Colormap passed through to seaborn.
        fmt: Format spec used for the per-cell annotations (e.g. "d").

    The axis tick labels come from the notebook-level `classes` list.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    heatmap_opts = dict(
        annot=True,
        fmt=fmt,
        cmap=cmap,
        cbar=True,
        xticklabels=classes,
        yticklabels=classes,
        square=False,
    )
    sns.heatmap(mat, ax=ax, **heatmap_opts)

    # Columns are predictions, rows are ground truth.
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")
    ax.set_title(title)
    fig.tight_layout()
    plt.show()

# Counts heatmap (blue)
plot_cm_sns(cm, "MSN - Confusion Matrix (1% of Training Labels Used)", cmap="Blues", fmt="d")

# metrics
# `y_pred` computed at the top of this cell is reused here; the classifier and
# test features have not changed since, so predicting again would be redundant.
acc = accuracy_score(y_test, y_pred)
print(f"Test accuracy: {acc:.6f}\n")
print("MSN - Classification report (1% of Labels):\n", classification_report(y_test, y_pred, digits=6))

In [None]:
from sklearn.manifold import TSNE

# PathMNIST has 9 tissue classes
# Tissue names in label order: the position in this tuple is the class id.
_tissues = (
    "Adipose",
    "Background",
    "Debris",
    "Lymphocytes",
    "Mucus",
    "Smooth Muscle",
    "Normal Colon Mucosa",
    "Cancer-Associated Stroma",
    "Colorectal Adenocarcinoma Epithelium",
)
# Legend text of the form "Label <id>: <tissue>", indexed by class id.
class_names = [f"Label {idx}: {tissue}" for idx, tissue in enumerate(_tissues)]

# Row-wise centering + L2
def row_center_l2(A):
    """Center every row of `A` on its own mean, then scale it to unit L2 norm.

    Args:
        A: 2-D array; each row is transformed independently.

    Returns:
        A new array of the same shape. A tiny epsilon (1e-12) is added to each
        row norm, so a constant row maps to zeros rather than dividing by zero.
    """
    centered = A - A.mean(axis=1, keepdims=True)
    scale = np.linalg.norm(centered, axis=1, keepdims=True) + 1e-12
    return np.divide(centered, scale)

# Features and labels used for the visualisation (test split).
X_vis = row_center_l2(X_test)
y_vis = y_test

# t-SNE
# The iteration count is left at the library default (1000, the value that
# was previously passed explicitly). The keyword for it was renamed from
# `n_iter` to `max_iter` in newer scikit-learn releases, so omitting it keeps
# this cell working on both older and newer versions with the same behaviour.
X_embedded = TSNE(
    n_components=2,
    perplexity=30,
    metric='cosine',
    random_state=42,   # fixed seed so the layout is reproducible
    init='pca',
    learning_rate='auto',
).fit_transform(X_vis)

# Plot
fig, ax = plt.subplots(figsize=(12, 8))
sc = ax.scatter(X_embedded[:, 0], X_embedded[:, 1], c=y_vis, cmap="tab10", s=5)

# Legend with class names
# legend_elements() returns mathtext-wrapped labels; stripping the wrapper
# characters leaves the numeric class id, which indexes into `class_names`.
handles, raw_labels = sc.legend_elements()
labels = [class_names[int(raw.strip('$\\mathdefault{}'))] for raw in raw_labels]
ax.legend(handles, labels, title="Classes", bbox_to_anchor=(1.01, 1), loc="upper left")

ax.set_title("t-SNE Visualization of MSN Test Feature Embeddings by Class")
fig.tight_layout()
plt.show()