In [1]:
import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix, classification_report


# Checking CUDA availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [2]:
data_root = "/home/kgulbarg/thesis/INR_SCHIC/Nagano_Data"
LABEL_TO_ID = {"G1": 0, "early_S": 1, "mid_S": 2, "late_S": 3}
emb_dim = 256

In [3]:
X, y = [], []

for label in sorted(os.listdir(data_root)):
    label_dir = os.path.join(data_root, label)

    # only real label folders
    if not os.path.isdir(label_dir) or label not in LABEL_TO_ID:
        continue

    # iterate files inside the label folder
    for name in sorted(os.listdir(label_dir)):
        if not (name.endswith("_latent.npy") and
                os.path.isfile(os.path.join(label_dir, name))):
            continue

        latent_path = os.path.join(label_dir, name)
        vec = np.load(latent_path).ravel()
        X.append(vec)
        y.append(LABEL_TO_ID[label])

    print(f"Collected latent representations for {label} cells.")


X = np.vstack(X)
y = np.array(y)     

print(X.shape)
print(y.shape)

Collected latent representations for G1 cells.
Collected latent representations for early_S cells.
Collected latent representations for late_S cells.
Collected latent representations for mid_S cells.
(1171, 256)
(1171,)


In [4]:
def train_classifier(latent_vecs, labels, epochs=500, lr=5e-3, val_split=0.2, seed=0):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    X = torch.as_tensor(latent_vecs, dtype=torch.float32, device=device)
    y = torch.as_tensor(labels,     dtype=torch.long,   device=device).view(-1)

    class Classifier(nn.Module):
        def __init__(self, emb_dim=256):
            super().__init__()
            self.norm = nn.LayerNorm(emb_dim)
            self.fc1  = nn.Linear(emb_dim, 128)
            self.fc2  = nn.Linear(128, 64)
            self.fc3  = nn.Linear(64, 4)
            self.drop = nn.Dropout(0.1)    # avoid overfitting
        def forward(self, x):
            x = self.norm(x)
            x = F.relu(self.fc1(x)); x = self.drop(x)
            x = F.relu(self.fc2(x))
            return self.fc3(x)   # logits

    # stratified split
    idx_tr_list, idx_val_list = [], []
    cpu_y = y.cpu().numpy()
    for c in range(4):
        idxs = np.where(cpu_y == c)[0]
        rng = np.random.default_rng(seed)
        rng.shuffle(idxs)
        cut_c = int((1 - val_split) * len(idxs))
        idx_tr_list.append(torch.as_tensor(idxs[:cut_c], device=device))
        idx_val_list.append(torch.as_tensor(idxs[cut_c:], device=device))
    idx_tr = torch.cat(idx_tr_list); idx_val = torch.cat(idx_val_list)
    Xtr, Xval = X[idx_tr], X[idx_val]; ytr, yval = y[idx_tr], y[idx_val]

    # train
    model = Classifier().to(device)
    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        opt.zero_grad()
        loss = loss_fn(model(Xtr), ytr)
        loss.backward(); opt.step()
        # print(f"epoch {epoch}, loss {loss.item():.4f}")

    # evaluate
    with torch.no_grad():
        train_acc = (model(Xtr).argmax(1) == ytr).float().mean().item()
        val_logits = model(Xval)
        val_acc = (val_logits.argmax(1) == yval).float().mean().item()
        yp = val_logits.argmax(1).cpu().numpy()

    print(confusion_matrix(yval.cpu().numpy(), yp))
    print(classification_report(yval.cpu().numpy(), yp, digits=3))

    return model, train_acc, val_acc


In [5]:
model, train_acc, val_acc = train_classifier(X, y)
print(f"train acc = {train_acc:.3f}, val acc = {val_acc:.3f}")

[[14  9 16 17]
 [ 9 15 17 20]
 [13  8 13 19]
 [18 17 13 18]]
              precision    recall  f1-score   support

           0      0.259     0.250     0.255        56
           1      0.306     0.246     0.273        61
           2      0.220     0.245     0.232        53
           3      0.243     0.273     0.257        66

    accuracy                          0.254       236
   macro avg      0.257     0.253     0.254       236
weighted avg      0.258     0.254     0.255       236

train acc = 1.000, val acc = 0.254
