In [None]:
import os
import pprint
import pandas as pd
import matplotlib.pyplot as plt
import random
import torch
torch.cuda.is_available()
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay

from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import Dataset, CacheDataset, DataLoader, ThreadDataLoader
from monai.metrics import ROCAUCMetric
from monai.networks.nets import DenseNet169
from monai.transforms import (
    Activations,
    AsDiscrete,
    EnsureChannelFirstd,
    EnsureTyped,
    Compose,
    ConcatItemsd,
    LoadImaged,
    RandFlipd,
    RandRotate90d,
    RandZoomd,
    ScaleIntensityRangePercentilesd,
)
from monai.utils import set_determinism
from sklearn.preprocessing import LabelEncoder

print_config()

In [None]:
# Check whether PyTorch can see a CUDA device; as the last expression of the
# cell, the boolean result is displayed as the cell output.
torch.cuda.is_available()

In [None]:
# Dataset root for one nuclear protein. The "train" and "val" sub-directories
# are expected to hold one folder per class.
base_path = "./QUAC_cortex_types/H3K4me1/"            # ./QUAC_cortex_types/mH2A1/
train_dir = os.path.join(base_path, "train")
val_dir = os.path.join(base_path, "val")

# Class names are the visible sub-directories of train_dir. Stray files and
# hidden entries (e.g. .DS_Store, .ipynb_checkpoints) are skipped so they are
# never mistaken for a class; sorting keeps the name -> index mapping stable.
classes = sorted(
    entry
    for entry in os.listdir(train_dir)
    if not entry.startswith(".") and os.path.isdir(os.path.join(train_dir, entry))
)
class_map = {c: idx for idx, c in enumerate(classes)}
num_classes = len(classes)

def encode_label(path):
    """Return the integer class index for the image stored at ``path``.

    The class is the name of the directory that directly contains the file;
    it is looked up in the module-level ``class_map``. A directory name that
    is not a key of ``class_map`` raises ``KeyError``.
    """
    parent_dir = os.path.dirname(path)
    _, folder_name = os.path.split(parent_dir)
    return class_map[folder_name]

def build_datalist(data_dir):
    """Collect ``{"image": path, "label": index}`` records for one dataset split.

    ``data_dir`` is expected to contain one sub-directory per class, each
    holding the image files of that class.

    Args:
        data_dir: Directory of a split (for example the train or val folder).

    Returns:
        A list of dicts with the keys ``"image"`` (file path) and ``"label"``
        (the value returned by ``encode_label`` for that path), ordered by
        class folder name and then by file name.

    Raises:
        KeyError: Propagated from ``encode_label`` when a visible class
            folder is not a key of ``class_map``.
    """
    datalist = []
    # os.listdir returns entries in arbitrary order; sorting makes the
    # datalist (and any seeded shuffling built on it) reproducible.
    for class_name in sorted(os.listdir(data_dir)):
        class_dir = os.path.join(data_dir, class_name)
        # Skip stray files and hidden entries such as .DS_Store or
        # .ipynb_checkpoints instead of treating them as classes.
        if class_name.startswith(".") or not os.path.isdir(class_dir):
            continue
        for image_name in sorted(os.listdir(class_dir)):
            image_path = os.path.join(class_dir, image_name)
            # Only regular, non-hidden files are treated as images.
            if image_name.startswith(".") or not os.path.isfile(image_path):
                continue
            datalist.append({"image": image_path, "label": encode_label(image_path)})
    return datalist

# Build the record lists for both splits with the same helper
# (train first, then val).
train_datalist, val_datalist = map(build_datalist, (train_dir, val_dir))

In [None]:
# Quick sanity check: pretty-print the first three training records so the
# image-path / label pairing can be inspected by eye.
pp = pprint.PrettyPrinter()
_preview_records = train_datalist[:3]
pp.pprint(_preview_records)

In [None]:
# Human-readable names for the numeric class folders of this dataset.
folder_to_class = {
    "0": "CTX-Glut",
    "5": "CTX-Olig",
    "8": "CTX-GABA",
    "11": "CTX-Astro"
}

# Map folder names to sequential indices. Only visible sub-directories count
# as classes (stray files / hidden entries such as .ipynb_checkpoints are
# skipped). Names are sorted as strings, so '11' comes before '5'.
classes = sorted(
    entry
    for entry in os.listdir(train_dir)
    if not entry.startswith(".") and os.path.isdir(os.path.join(train_dir, entry))
)  # expected for this dataset: ['0', '11', '5', '8']
class_map = {c: idx for idx, c in enumerate(classes)}  # '0'->0, '11'->1, '5'->2, '8'->3

# Inverse mapping for visualization: sequential index -> display name.
# A folder without an entry in folder_to_class falls back to its own name
# instead of raising KeyError, so other datasets still get usable labels.
class_map_inv = {
    idx: folder_to_class.get(folder_name, folder_name)
    for folder_name, idx in class_map.items()
}
# For this dataset: {0: 'CTX-Glut', 1: 'CTX-Astro', 2: 'CTX-Olig', 3: 'CTX-GABA'}

num_classes = len(classes)
# Built by explicit index so that position i always names class index i.
class_names = [class_map_inv[idx] for idx in range(num_classes)]

In [None]:
print(class_map)
print(class_names)

In [None]:
# Visualize one shuffled batch of *training* images after intensity
# normalization, titled with their class names.
set_determinism(seed=0)  # fixed seed so the same samples are drawn on every run
transforms_visualize = Compose(
    [
        LoadImaged(keys=['image'], reader="PILReader", image_only=True),
        EnsureChannelFirstd(keys=['image']),
        # Map the 1st..99th intensity percentiles onto [0, 1], clipping outliers.
        ScaleIntensityRangePercentilesd(
            keys=['image'], lower=1.0, upper=99.0, b_min=0.0, b_max=1.0, clip=True
        ),
    ]
)

batch_size_viz = 6
viz_ds = Dataset(train_datalist, transform=transforms_visualize)
viz_loader = DataLoader(viz_ds, batch_size=batch_size_viz, shuffle=True, num_workers=0)
batch_data = next(iter(viz_loader))
# The batch holds fewer than batch_size_viz images when the dataset is small.
num_shown = len(batch_data['image'])

# squeeze=False always returns a 2-D axes array (even for one column), so
# taking row 0 gives an indexable 1-D sequence for any batch_size_viz >= 1.
fig, axs = plt.subplots(1, batch_size_viz, figsize=(15, 5), squeeze=False)
axs = axs[0]
for idx, ax in enumerate(axs):
    if idx < num_shown:
        img = batch_data['image'][idx].squeeze().numpy()  # drop singleton dims (e.g. channel)
        ax.imshow(img, cmap="gray")
        label = int(batch_data['label'][idx])
        label_name = class_map_inv.get(label, f"Unknown ({label})")
        ax.set_title(label_name)
    # Panels without an image are simply left blank.
    ax.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# Number of output classes, derived from the class folders found in the
# training directory so it cannot drift from the data (previously a
# hard-coded 4 that silently overrode the computed value).
num_classes = len(class_map)

# Dictionary keys the image transforms operate on
keys = ["image"]

def _deterministic_preprocessing():
    """Return a fresh list of the non-random transforms shared by both splits.

    Order: load with PILReader, move channels first, rescale the 1st-99th
    intensity percentiles to [0, 1] with clipping, then concatenate the
    entries of ``keys`` into "image" along dim 0.
    """
    return [
        LoadImaged(keys=keys, reader="PILReader", image_only=True),
        EnsureChannelFirstd(keys=keys),
        ScaleIntensityRangePercentilesd(
            keys=keys, lower=1.0, upper=99.0, b_min=0.0, b_max=1.0, clip=True
        ),
        ConcatItemsd(keys=keys, name="image", dim=0),
    ]


# Training-only steps: type conversion followed by random augmentation
# (90-degree rotations, flips, mild zoom).
_train_only_steps = [
    EnsureTyped(keys=["image", "label"], track_meta=False),
    RandRotate90d(keys=["image"], prob=0.75),
    RandFlipd(keys=["image"], spatial_axis=[0, 1], prob=0.5),
    RandZoomd(keys=["image"], min_zoom=0.9, max_zoom=1.1, prob=0.5),
]

# Training pipeline = shared preprocessing + training-only steps.
transforms_train = Compose(_deterministic_preprocessing() + _train_only_steps)

# Validation pipeline = shared preprocessing only (no augmentation).
transforms_val = Compose(_deterministic_preprocessing())

# Post-processing pipelines: softmax for model outputs and one-hot encoding
# (with num_classes categories) for labels.
_softmax_activation = Activations(softmax=True)
_onehot_encoder = AsDiscrete(to_onehot=num_classes)
y_pred_trans = Compose([_softmax_activation])
y_trans = Compose([_onehot_encoder])

# Training dataset and loader. The dataset is built with 10 worker threads
# (num_workers); batches are reshuffled every epoch.
batch_size_train = 8
train_ds = CacheDataset(data=train_datalist, transform=transforms_train, num_workers=10)
train_loader = ThreadDataLoader(train_ds, batch_size=batch_size_train, shuffle=True)

# Validation dataset and loader. Shuffling brings no benefit for evaluation
# and only makes the iteration order non-deterministic, so it is disabled.
val_ds = CacheDataset(data=val_datalist, transform=transforms_val, num_workers=10)
val_loader = ThreadDataLoader(val_ds, batch_size=batch_size_train, shuffle=False)

In [None]:
# Create a DenseNet169 classifier (2-D, single-channel input) for training.
# Fall back to the CPU when no CUDA device is available instead of failing
# later on the first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = DenseNet169(spatial_dims=2, in_channels=1, out_channels=num_classes, pretrained=True)

In [None]:
# Alternative to creating a fresh model: resume training from saved weights.

# Use the GPU when present, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Re-create the architecture the checkpoint was saved from; the number of
# outputs follows num_classes so it stays in sync with the dataset.
model = DenseNet169(
    spatial_dims=2,            # 2D images
    in_channels=1,             # single-channel images
    out_channels=num_classes,
    pretrained=False           # no pretrained weights: they are replaced by the checkpoint below
)

# map_location remaps the stored tensors onto the selected device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
model.load_state_dict(torch.load("QUAC_model/all_types_CTX_model/H3K4me1_55epochs_model.pth", map_location=device))

# Move to the device and switch to training mode. Assigning the result keeps
# the very long model repr out of the cell output.
model = model.to(device).train()

In [None]:
# Training configuration: place the model on the target device, then set up
# the schedule, loss, optimizer and AUC metric object.
model.to(device)
max_epochs = 50   # number of epochs to train
val_interval = 1
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)   # learning rate: 5e-5
auc_metric = ROCAUCMetric()

In [None]:
# Bookkeeping. Only epoch_loss_values is filled by the loop below;
# best_metric, best_metric_epoch and metric_values are initialised here but
# not updated in this cell (no validation pass is run during training).
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []

# Number of batches per epoch. len(train_loader) counts the final partial
# batch as well, unlike `len(train_ds) // batch_size`, which made the last
# progress line read e.g. "13/12".
epoch_len = len(train_loader)

for epoch in range(max_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        # Read the scalar once; it is reused for the running sum and the log line.
        step_loss = loss.item()
        epoch_loss += step_loss
        print(f"{step}/{epoch_len}, " f"train_loss: {step_loss:.4f}")
    # max(..., 1) avoids a ZeroDivisionError if the loader yielded no batches.
    epoch_loss /= max(step, 1)
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

In [None]:
# Evaluate the trained classification model on the validation split.

batch_size_test = 32
test_ds = Dataset(val_datalist, transform=transforms_val)
# The report aggregates over all samples, so the order is irrelevant: do not shuffle.
test_loader = DataLoader(test_ds, batch_size=batch_size_test, shuffle=False, num_workers=4)

model.eval()
y_true = []
y_pred = []
with torch.no_grad():
    for test_data in test_loader:
        test_images = test_data["image"].to(device)
        test_labels = test_data["label"]
        pred = model(test_images).argmax(dim=1)
        # Convert each batch in one call instead of one .item() per element.
        y_true.extend(test_labels.cpu().tolist())
        y_pred.extend(pred.cpu().tolist())

# Passing `labels` explicitly keeps target_names aligned with the class
# indices even when a class never occurs in y_true / y_pred (without it,
# scikit-learn raises because the number of names and found classes differ).
print(
    classification_report(
        y_true,
        y_pred,
        labels=list(range(len(class_names))),
        target_names=class_names,
        digits=4,
    )
)

In [None]:
# Save the trained weights (state_dict only). The output folder is created
# first so that saving does not fail on a fresh checkout.
model_path = "QUAC_model/all_types_CTX_model/H3K4me1_model.pth"
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

In [None]:
# Convert the model to TorchScript for QuAC.
# Switch to eval mode explicitly so the exported module does not depend on
# whether the evaluation cell was run beforehand.
model.eval()
model_script = torch.jit.script(model)

# Create the output folder if needed, then write the scripted module.
jit_path = "QUAC_model/all_types_CTX_model/H3K4me1_jit.pt"
os.makedirs(os.path.dirname(jit_path), exist_ok=True)
model_script.save(jit_path)