Purpose: run VAE+AA for Harvard-GDP datasets

# Import libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import matplotlib.pyplot as plt
import matplotlib

# Global figure text style, applied in a single update: bold 18 pt Arial,
# with figure titles at the same size.
matplotlib.rcParams.update({
    'font.family': 'Arial',
    'font.weight': 'bold',
    'font.size': 18,
    'figure.titlesize': 18,
})
# Italic text is intentionally left disabled:
# matplotlib.rcParams['font.style'] = 'italic'

# utility
import time
import copy as cp
from tqdm import tqdm
from collections import defaultdict
import glob
import os
import sys
from sys import stderr
import warnings
# Silence every warning for the rest of the session.
warnings.simplefilter('ignore')

# Random numbers: one seed shared by the NumPy generator and by the
# `random_state` arguments passed around further down.
random_state = 42
rng = np.random.RandomState(random_state)

# %cd your path
# %pwd

# Dataloader for VAE

In [None]:
oct_name = "pRNFL_Harvard"
df_oct = pd.read_csv("datasets/data_all_sample.csv")

# Derived columns.
# "mrt" is the row-wise mean of the 676 values stored in columns 7..682
# (presumably the mean thickness of the 26x26 map - confirm with the data
# dictionary); "ageclass" buckets age by integer division by 5.
map_values = df_oct.iloc[:, 7:683].astype(float)
df_oct["mrt"] = map_values.mean(axis=1)
df_oct["ageclass"] = df_oct["age"].map(lambda age: age // 5)

print(df_oct.shape)
df_oct.head()

In [None]:
from sklearn.model_selection import train_test_split

# No split: train and test are deliberately the very same frame, so every
# "validation" number computed later is measured on the training data.
df_train = df_test = df_oct

for frame, csv_path in (
    (df_train, "datasets/data_train.csv"),
    (df_test, "datasets/data_test.csv"),
):
    frame.to_csv(csv_path, index=False)

for split_name, frame in (("train", df_train), ("test", df_test)):
    print(split_name, frame.shape)

# Race proportions of both splits (evaluated, but only displayed when this
# expression is the last one in the cell).
df_train.race.value_counts(normalize=True), df_test.race.value_counts(normalize=True)

df_train.head(1)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

class CustomDataset(Dataset):
    """Row-wise view of a DataFrame as (image, ID, ageclass, sex, mrt) samples.

    Each sample's image is built from the 676 values found in column
    positions 7..682 of the row, min-max scaled with the fixed range
    (0, 240), clipped to [0, 1] and reshaped to ``(1, 26, 26)``.
    The remaining four fields are read by column name and cast to ``float``.
    """

    def __init__(self, dataframe):
        # The frame is referenced, not copied.
        self.data = dataframe

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        # Fixed scaling range used for every sample.
        lower, upper = 0, 240
        raw = row.iloc[7:683].values.astype(float)
        image = np.clip((raw - lower) / (upper - lower), 0, 1).reshape(1, 26, 26)

        metadata = tuple(float(row[name]) for name in ('ID', 'ageclass', 'sex', 'mrt'))
        return (image, *metadata)

def custom_collate(batch):
    """Collate dataset samples into float32 batch tensors.

    Args:
        batch: sequence of ``(X, ID, ageclass, sex, mrt)`` tuples, where
            ``X`` is a NumPy array and the other fields are scalars.

    Returns:
        Tuple ``(X, ID, ageclass, sex, mrt)`` of ``torch.float32`` tensors;
        ``X`` gains a leading batch dimension, the others are 1-D.
    """
    X, ID, ageclass, sex, mrt = zip(*batch)

    # Stack the per-sample arrays in NumPy first: handing a tuple of
    # ndarrays straight to torch.tensor converts element by element, which
    # is very slow (and triggers a PyTorch UserWarning).
    X = torch.as_tensor(np.stack(X), dtype=torch.float32)

    # NOTE(review): float32 represents integers exactly only up to 2**24;
    # larger IDs would be rounded here - confirm the ID range.
    ID = torch.tensor(ID, dtype=torch.float32)
    ageclass = torch.tensor(ageclass, dtype=torch.float32)
    sex = torch.tensor(sex, dtype=torch.float32)
    mrt = torch.tensor(mrt, dtype=torch.float32)
    return X, ID, ageclass, sex, mrt


def dataframe2dataloader(df, batch_size, shuffle=False, random_state=42):
    """Wrap a DataFrame in a ``CustomDataset`` and a matching ``DataLoader``.

    Args:
        df: DataFrame handed to ``CustomDataset``.
        batch_size: number of samples per batch.
        shuffle: whether the loader reshuffles the samples every epoch.
        random_state: seed for the shuffling order. Only used when
            ``shuffle`` is true; pass ``None`` to fall back to PyTorch's
            global random state.

    Returns:
        Tuple ``(dataset, dataloader)``.
    """
    dataset = CustomDataset(df)

    # Previously `random_state` was accepted but ignored, so the shuffle
    # order changed from run to run. A dedicated seeded generator makes
    # the batch order reproducible.
    generator = None
    if shuffle and random_state is not None:
        generator = torch.Generator()
        generator.manual_seed(int(random_state))

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=custom_collate,
        generator=generator,
    )
    return dataset, dataloader

In [None]:
batch_size = 16
train_dataset, train_loader = dataframe2dataloader(
    df_train, batch_size, shuffle=True, random_state=random_state
)
valid_dataset, valid_loader = dataframe2dataloader(
    df_test, batch_size, shuffle=False, random_state=random_state
)

# Sanity check: print the tensor shapes of the first batch of each loader
# (training loader first, then validation loader).
for loader in (train_loader, valid_loader):
    for X, ID, ageclass, sex, mrt in loader:
        labelled = (
            ("X Shape", X),
            ("ID", ID),
            ("Ageclass", ageclass),
            ("Sex", sex),
            ("mRT", mrt),
        )
        for label, tensor in labelled:
            print(f"{label}: {tensor.shape}")
        break  # only the first batch is inspected

print()
print("train", len(train_dataset))
print("valid", len(valid_dataset))

# VAE model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.autograd as autograd
import torch.nn.functional as F
from torchvision import datasets, transforms, models

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.utils import shuffle
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from typing import Tuple

# NOTE: the NumPy generator is re-created here with seed 1234, while
# `random_state` stays 42.
random_state = 42
rng = np.random.RandomState(1234)

# Pick the compute device: CUDA first, then Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

class CNNEncoder(nn.Module):
    """Convolutional encoder producing the mean and log-variance of the latent.

    Four 3x3, stride-1, padding-1 convolutions (without bias), each followed
    by LayerNorm and ReLU, keep the spatial size unchanged. The flattened
    feature map then goes through one hidden linear layer with dropout and
    two linear heads.

    Args:
        z_dim: size of the latent vector.
        h_dim: width of the hidden linear layer.
        indim: (height, width) of the input image.
        c_init: number of input channels.
        outdim_conv: number of channels after the last convolution.
    """

    def __init__(self, z_dim, h_dim=128, indim=(26, 26), c_init=1, outdim_conv=8):
        super().__init__()
        rows, cols = indim[0], indim[1]
        conv_cfg = dict(kernel_size=3, stride=1, padding=1, bias=False)

        # Attribute names and creation order are kept so that existing
        # checkpoints still load.
        self.conv1 = nn.Conv2d(c_init, 2, **conv_cfg)
        self.ln1 = nn.LayerNorm([2, rows, cols])

        self.conv2 = nn.Conv2d(2, 4, **conv_cfg)
        self.ln2 = nn.LayerNorm([4, rows, cols])

        self.conv3 = nn.Conv2d(4, 4, **conv_cfg)
        self.ln3 = nn.LayerNorm([4, rows, cols])

        self.conv4 = nn.Conv2d(4, outdim_conv, **conv_cfg)
        self.ln4 = nn.LayerNorm([outdim_conv, rows, cols])

        self.fc1 = nn.Linear(outdim_conv * rows * cols, h_dim)
        self.dropout = nn.Dropout(p=0.2)

        self.fc_mean = nn.Linear(h_dim, z_dim)
        self.fc_var = nn.Linear(h_dim, z_dim)

    def forward(self, x):
        """Return ``(mean, log_var)`` for a batch of images ``x``."""
        conv_stages = (
            (self.conv1, self.ln1),
            (self.conv2, self.ln2),
            (self.conv3, self.ln3),
            (self.conv4, self.ln4),
        )
        for conv, norm in conv_stages:
            x = F.relu(norm(conv(x)))

        flat = x.view(x.size(0), -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))

        return self.fc_mean(hidden), self.fc_var(hidden)

class CNNDecoder(nn.Module):
    """Decoder mapping a latent vector back to an image with values in (0, 1).

    Two linear layers expand the latent vector to an
    ``indim_deconv x H x W`` feature map, which four 3x3 transposed
    convolutions (stride 1, padding 1, so the spatial size is unchanged)
    reduce to ``c_init`` channels. A sigmoid is applied to the output.

    Args:
        z_dim: size of the latent vector.
        h_dim: width of the hidden linear layer.
        indim: (height, width) of the reconstructed image.
        c_init: number of output channels.
        indim_deconv: channels of the feature map fed to the first
            transposed convolution.
    """

    def __init__(self, z_dim, h_dim=128, indim=(26, 26), c_init=1, indim_deconv=8):
        super().__init__()

        # Remember the geometry so forward() reshapes consistently with the
        # layers built here. These are plain attributes, not parameters, so
        # the state_dict layout is unchanged.
        self.indim = (indim[0], indim[1])
        self.indim_deconv = indim_deconv

        self.fc1 = nn.Linear(z_dim, h_dim)
        self.fc2 = nn.Linear(h_dim, indim_deconv * indim[0] * indim[1])
        self.dropout = nn.Dropout(p=0.2)

        self.deconv1 = nn.ConvTranspose2d(indim_deconv, 4, kernel_size=3, stride=1, padding=1)
        self.ln1 = nn.LayerNorm([4, indim[0], indim[1]])

        self.deconv2 = nn.ConvTranspose2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.ln2 = nn.LayerNorm([4, indim[0], indim[1]])

        self.deconv3 = nn.ConvTranspose2d(4, 2, kernel_size=3, stride=1, padding=1)
        self.ln3 = nn.LayerNorm([2, indim[0], indim[1]])

        self.deconv4 = nn.ConvTranspose2d(2, c_init, kernel_size=3, stride=1, padding=1)

    def forward(self, z, indim=None, indim_deconv=None):
        """Decode latent vectors ``z`` of shape ``(batch, z_dim)``.

        Args:
            z: batch of latent vectors.
            indim: optional (height, width) override; defaults to the value
                given to the constructor.
            indim_deconv: optional channel-count override; defaults to the
                value given to the constructor.

        Returns:
            Tensor of shape ``(batch, c_init, height, width)``.
        """
        # The former hard-coded defaults (26, 26) / 8 made every decoder
        # built with other sizes fail at the reshape below.
        if indim is None:
            indim = self.indim
        if indim_deconv is None:
            indim_deconv = self.indim_deconv

        z = F.relu(self.fc1(z))
        z = F.relu(self.fc2(z))
        z = self.dropout(z)

        z = z.view(z.size(0), indim_deconv, indim[0], indim[1])

        z = F.relu(self.ln1(self.deconv1(z)))
        z = F.relu(self.ln2(self.deconv2(z)))
        z = F.relu(self.ln3(self.deconv3(z)))
        reconstruction = torch.sigmoid(self.deconv4(z))

        return reconstruction

class VAE_CNN(nn.Module):
    """Variational auto-encoder built from ``CNNEncoder`` and ``CNNDecoder``.

    Args:
        z_dim: size of the latent vector.
        h_dim: width of the hidden linear layers of encoder and decoder.
    """

    def __init__(self, z_dim=64, h_dim=128):
        super().__init__()
        self.encoder = CNNEncoder(z_dim, h_dim)
        self.decoder = CNNDecoder(z_dim, h_dim)

    def reparameterize(self, mean, log_var):
        """Sample ``z = mean + eps * std`` with ``eps`` drawn from N(0, I)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mean + eps * std

    def forward(self, x):
        """Return ``(reconstruction, mean, log_var)`` for a batch ``x``."""
        mean, log_var = self.encoder(x)
        z = self.reparameterize(mean, log_var)
        return self.decoder(z), mean, log_var

    def loss(self, x):
        """Run one forward pass and return ``(KLD, BCE)``.

        ``BCE`` is the binary cross-entropy *summed* over every element of
        the batch; ``KLD`` is the KL term *averaged* over every latent
        element of the batch. The two terms are therefore on different
        scales, and the caller decides how to weight them.
        """
        recon_x, mean, log_var = self(x)
        # The previous call passed reduction='mean' together with the
        # deprecated size_average=False; the legacy flag takes precedence
        # and yields a sum. Spell that out to keep the numbers identical
        # without the deprecation warning or the misleading argument.
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        KLD = -0.5 * torch.mean(1 + log_var - mean.pow(2) - log_var.exp())
        return KLD, BCE

# Train VAE

Loading pretrained VAE model

In [None]:
z_dim = 32
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
save_path = "params/VAE/pretrained_for_pRNFL.pth"

model_vae = VAE_CNN(z_dim = z_dim).to(device)
# map_location moves the stored tensors onto the current device; without it
# a checkpoint saved on a GPU cannot be loaded on a CPU-only or MPS machine.
model_vae.load_state_dict(torch.load(save_path, map_location=device))

In [None]:
import copy

import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from tqdm import tqdm

# Training configuration
n_epochs = 100
beta = 100  # weight of the KL term in the training objective
tolerance = 3  # Early stopping tolerance
optimizer = optim.Adam(model_vae.parameters(), lr=0.001)

# Per-epoch lists of per-batch losses (consumed by the logging cell below).
train_losses = []
train_KL_losses = []
train_reconstruction_losses = []
val_losses = []

best_val_loss = float('inf')
epochs_no_improve = 0
best_model_state = None

for epoch in range(n_epochs):
    losses = []
    KL_losses = []
    reconstruction_losses = []

    model_vae.train()
    with tqdm(total=len(train_loader), leave=False) as pbar:
        # Only the image tensor is needed; the metadata fields are ignored.
        for x, *_ in train_loader:
            x = x.float().to(device)
            optimizer.zero_grad()

            KL_loss, reconstruction_loss = model_vae.loss(x)
            loss = beta * KL_loss + reconstruction_loss

            loss.backward()
            optimizer.step()

            # save loss (as plain Python floats)
            losses.append(loss.item())
            KL_losses.append(KL_loss.item())
            reconstruction_losses.append(reconstruction_loss.item())

            # progress bar
            pbar.set_postfix({'loss': losses[-1], 'KL_loss': KL_losses[-1], 'reconstruction_loss': reconstruction_losses[-1]})
            pbar.update(1)

    train_losses.append(losses)
    train_KL_losses.append(KL_losses)
    train_reconstruction_losses.append(reconstruction_losses)

    losses_val = []
    model_vae.eval()
    # No gradients are needed for validation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad(), tqdm(total=len(valid_loader), leave=False) as pbar:
        for x, *_ in valid_loader:
            x = x.float().to(device)
            KL_loss, reconstruction_loss = model_vae.loss(x)
            # NOTE(review): the validation objective is unweighted
            # (KL + BCE) whereas training uses beta * KL + BCE, so the two
            # curves are not directly comparable - confirm this is intended.
            loss = KL_loss + reconstruction_loss

            # save loss
            losses_val.append(loss.item())
            pbar.update(1)

    val_losses.append(losses_val)

    # loss
    avg_train_loss = sum(losses) / len(losses)
    avg_val_loss = sum(losses_val) / len(losses_val)
    print(f"Epoch {epoch + 1}/{n_epochs} - Train Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

    # Early stopping
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        # state_dict().copy() is only a shallow copy: its tensors are the
        # live parameters, which the optimizer keeps updating in place, so
        # restoring it later would be a no-op. Take a real snapshot.
        best_model_state = copy.deepcopy(model_vae.state_dict())
        print(f"Best validation loss updated to {best_val_loss:.4f}")
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= tolerance:
        print(f"Early stopping triggered after {tolerance} epochs without improvement")
        break

if best_model_state is not None:
    model_vae.load_state_dict(best_model_state)
    print(f"Restored model_vae to best state with validation loss: {best_val_loss:.4f}")


In [None]:
import torch
import datetime
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt

# Output folders for weights, figures and tables.
for out_dir in ("params/VAE", "figures", "tables"):
    os.makedirs(out_dir, exist_ok=True)

# Minute-resolution timestamp used in the figure/table file names.
current_time = datetime.datetime.now().strftime('%Y%m%d%H%M')


def _epoch_means(per_epoch_losses):
    """Average the per-batch losses of every epoch."""
    return [sum(batch_losses) / len(batch_losses) for batch_losses in per_epoch_losses]


avg_train_losses = _epoch_means(train_losses)
avg_train_KL_losses = _epoch_means(train_KL_losses)
avg_train_reconstruction_losses = _epoch_means(train_reconstruction_losses)
avg_val_losses = _epoch_means(val_losses)

# Zero-based index of the best epoch: the last recorded epoch minus the
# number of trailing epochs without improvement (never below 0).
best_epoch = max(0, len(avg_val_losses) - 1 - epochs_no_improve)

save_path = "params/VAE/finetunig_for_pRNFL.pth"
torch.save(model_vae.state_dict(), save_path)

# figure: train/validation curves with the best epoch marked
epochs = list(range(1, len(avg_train_losses) + 1))
plt.figure(figsize=(12, 6))
for curve, curve_label in (
    (avg_train_losses, 'Train Loss'),
    (avg_val_losses, 'Validation Loss'),
):
    plt.plot(epochs, curve, label=curve_label)

plt.axvline(
    x=best_epoch + 1,
    color='r',
    linestyle='--',
    label=f'Best Model (Epoch {best_epoch+1})',
)

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.xticks(epochs)
plt.legend()
plt.tight_layout()
plt.savefig(f'figures/LOSS_model_{current_time}.pdf', transparent=True)

# csv: one row per epoch, flagging the best one
epochs = list(range(1, len(train_losses) + 1))
is_best = [position == best_epoch for position in range(len(train_losses))]
if not is_best:
    # Mirror the original indexing behaviour when nothing was recorded.
    raise IndexError("list assignment index out of range")

df_loss = pd.DataFrame({
    "epoch": epochs,
    "train_loss": avg_train_losses,
    "train_KL_loss": avg_train_KL_losses,
    "train_reconstruction_loss": avg_train_reconstruction_losses,
    "val_loss": avg_val_losses,
    "is_best_model": is_best,
})

df_loss.to_csv(f"tables/LOSS_model_{current_time}.csv", index=False)

# txt: short summary of the run
model_info_lines = [
    "Model: VAE\n",
    f"OCT Name: {oct_name}\n",
    f"z_dim: {z_dim}\n",
    f"beta: {beta}\n",
    f"Max Epochs: {n_epochs}\n",
    f"Actual Epochs: {len(train_losses)}\n",
    f"Best Epoch: {best_epoch+1}\n",
    f"Best Validation Loss: {avg_val_losses[best_epoch]:.6f}\n",
]
with open(f"tables/MODEL_INFO_{current_time}.txt", 'w') as f:
    f.writelines(model_info_lines)

# Inference: reconstructed images

In [None]:
import matplotlib.colors as mcolors

def visualize(arr, vmin=0, vmax=1, figname=None):
    """Save ``arr`` as a 26x26 heat map under ``figures/samples_VAE``.

    Args:
        arr: array-like with 676 values; reshaped to (26, 26).
        vmin: lower bound of the colour scale.
        vmax: upper bound of the colour scale.
        figname: file name (without extension) of the PDF to write.
    """
    grid = np.reshape(arr, (26, 26))

    # Discrete copy of the reversed Red-Yellow-Blue colour map.
    source_cmap = plt.cm.RdYlBu_r
    palette = mcolors.ListedColormap(source_cmap(np.arange(source_cmap.N)))

    fig, ax = plt.subplots(figsize=(5, 5))
    im = ax.imshow(grid, cmap=palette, vmin=vmin, vmax=vmax)

    # Bare image: no ticks, empty title.
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    plt.title("")

    plt.savefig(f"figures/samples_VAE/{figname}.pdf", transparent=True)
    plt.close()

def show_ranges(arr):
    """Return ``(minimum, maximum)`` of ``arr``."""
    lowest = arr.min()
    highest = arr.max()
    return lowest, highest

In [None]:
import os
os.makedirs("figures/samples_VAE", exist_ok=True)

model_vae.eval()

n = len(valid_dataset)
# n = 100

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for i in tqdm(range(n)):
        x, ID, ageclass, sex, mrt = valid_dataset[i]

        # (1, 26, 26) NumPy image -> (1, 1, 26, 26) float tensor on the device
        x = torch.from_numpy(x).float().unsqueeze(0).to(device)
        ID = int(ID)

        y, z_mu, z_logvar = model_vae(x)
        # Score the reconstruction that is actually saved. Calling
        # model_vae.loss(x) would run a second forward pass with a fresh
        # random latent sample, so the number in the file name would belong
        # to a different reconstruction. Summed BCE, as in VAE_CNN.loss.
        BCE = torch.nn.functional.binary_cross_entropy(y, x, reduction='sum')

        x = x.cpu().numpy()
        y = y.cpu().numpy()
        BCE = int(BCE.item())

        visualize(x.reshape(676), vmin=0, vmax=1, figname=f"{oct_name}_{ID}_origin_{BCE}")
        visualize(y.reshape(676), vmin=0, vmax=1, figname=f"{oct_name}_{ID}_recon_{BCE}")



# Inference: datasets with Z

In [None]:
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import os

def extract_features_from_dataset(dataset, model, device, z_dim=32, prefix='data'):
    """Encode every sample of ``dataset`` and collect the results in a table.

    Args:
        dataset: iterable yielding ``(x, ID, ageclass, sex, mrt)`` samples,
            where ``x`` is a NumPy array or a tensor.
        model: VAE exposing ``model(x) -> (recon, mean, log_var)`` and
            ``model.loss(x) -> (KLD, BCE)``. The caller is responsible for
            putting it in eval mode.
        device: device the input is moved to before the forward pass.
        z_dim: number of latent dimensions; must match the model, since it
            determines the ``z1..z{z_dim}`` column names.
        prefix: label shown in the progress bar.

    Returns:
        DataFrame with columns ``ID, ageclass, sex, mrt, bce_loss,
        z1..z{z_dim}``, one row per sample. The ``z`` columns hold the
        latent mean.
    """
    rows = []

    for x, sample_id, ageclass, sex, mrt in tqdm(dataset, desc=f"Processing {prefix} data"):
        x = torch.from_numpy(x).float() if isinstance(x, np.ndarray) else x.float()
        x = x.unsqueeze(0) if len(x.shape) == 3 else x  # add the batch dimension if it is missing
        x = x.to(device)

        with torch.no_grad():
            # Use the model that was passed in; the previous version
            # silently ignored it in favour of the global `model_vae`.
            _, z_mu, _ = model(x)
            _, bce_loss = model.loss(x)

        z = z_mu.cpu().detach().numpy()[0]
        bce_loss = float(bce_loss.cpu().detach().numpy())

        rows.append([int(sample_id), int(ageclass), int(sex), float(mrt), bce_loss] + list(z))

    columns = ['ID', 'ageclass', 'sex', 'mrt', 'bce_loss'] + [f'z{i}' for i in range(1, z_dim + 1)]

    return pd.DataFrame(rows, columns=columns)

# Inference mode for the whole extraction.
model_vae.eval()
z_dim = 32

# training data feature
df_train = extract_features_from_dataset(
    train_dataset, model_vae, device, z_dim=z_dim, prefix="train"
)

# save
train_path = 'datasets/data_train_zdim.csv'
df_train.to_csv(train_path, index=False)
print(f"Saved training features to {train_path}")


# validation data feature
df_valid = extract_features_from_dataset(
    valid_dataset, model_vae, device, z_dim=z_dim, prefix="validation"
)

# save
valid_path = 'datasets/data_test_zdim.csv'
df_valid.to_csv(valid_path, index=False)
print(f"Saved validation features to {valid_path}")



# AA model (with data filtering)

In [None]:
from archetypes import AA
from tqdm import tqdm

def perform_archetypal_analysis(X_train, X_test, n_archetypes=10, n_init=1, max_iter=1000, random_state=42):
    """Fit an ``AA`` model on ``X_train`` and transform ``X_test`` with it.

    Args:
        X_train: data the archetypes are fitted on.
        X_test: data passed to the fitted model's ``transform``.
        n_archetypes, n_init, max_iter, random_state: forwarded unchanged
            to the ``AA`` constructor.

    Returns:
        Tuple ``(A, Z2, RSS)``: the fitted ``archetypes_``, the transformed
        ``X_test`` and the fitted model's ``rss_``.
    """
    aa_settings = dict(
        n_archetypes=n_archetypes,
        n_init=n_init,
        max_iter=max_iter,
        random_state=random_state,
    )
    model_aa = AA(**aa_settings)
    model_aa.fit(X_train)

    # Read the fitted attributes before transforming the test data.
    archetypes = model_aa.archetypes_
    residual = model_aa.rss_
    test_weights = model_aa.transform(X_test)

    print("A shape:", archetypes.shape, "RSS:", residual)

    return archetypes, test_weights, residual

filtering OCTs by BCE loss

In [None]:
oct_name = "pRNFL_Harvard"

# Reload the latent-feature tables written by the extraction step.
df_train, df_test = (
    pd.read_csv(csv_path)
    for csv_path in ('datasets/data_train_zdim.csv', 'datasets/data_test_zdim.csv')
)

def exclude_abnormal(df):
    """Drop rows whose ``bce_loss`` is an outlier (|z-score| >= 1.96).

    The z-score is computed against the mean and sample standard deviation
    of ``df['bce_loss']``.

    Args:
        df: DataFrame with a numeric ``bce_loss`` column. It is not
            modified.

    Returns:
        A new DataFrame with the same columns, containing only the rows
        whose z-score lies strictly between -1.96 and 1.96. When the
        standard deviation is zero or undefined (constant column, fewer
        than two rows) there is nothing to call an outlier and a copy of
        every row is returned.
    """
    bce = df['bce_loss']
    std_bce = bce.std()

    # `not (std > 0)` is true for both 0 and NaN; dividing by either would
    # turn every z-score into NaN and silently discard the whole table.
    if not std_bce > 0:
        return df.copy()

    # Keep the z-score in a local Series instead of writing a temporary
    # column into the caller's frame.
    z_score = (bce - bce.mean()) / std_bce
    return df[(z_score > -1.96) & (z_score < 1.96)].copy()

df_train_filtered = exclude_abnormal(df_train)
# Filter the test table that was just loaded from disk. The previous code
# used `df_valid`, a leftover variable from the feature-extraction cell,
# which is undefined when the notebook is resumed from the saved CSV files.
df_valid_filtered = exclude_abnormal(df_test)

df_train_filtered.to_csv(f'datasets/data_train_zdim_filtered.csv', index=False)
df_valid_filtered.to_csv(f'datasets/data_test_zdim_filtered.csv', index=False)

# Shapes before and after filtering.
df_train.shape, df_test.shape, df_train_filtered.shape, df_valid_filtered.shape

In [None]:
import os
for out_dir in ("tables/AT", "datasets/AT_data", "datasets/AT_info"):
    os.makedirs(out_dir, exist_ok=True)

# From here on, work with the filtered tables only.
df_train = df_train_filtered
df_test = df_valid_filtered

# Columns from position 5 onwards are used as the feature matrix.
X_train = df_train.iloc[:, 5:].values
X_test = df_test.iloc[:, 5:].values

# AA model: sweep the number of archetypes k = 3..14 and record the RSS.
RSS_lists = []
for k in range(3, 15):
    A, Z, RSS = perform_archetypal_analysis(X_train, X_test, n_archetypes=k, max_iter=1000)
    RSS_lists.append(RSS)

    archetype_names = [f'A{number}' for number in range(1, k + 1)]
    A_df = pd.DataFrame(A.T, columns=archetype_names)
    Z_df = pd.DataFrame(Z, columns=archetype_names)

    # save as csv: archetypes (one column each), and the test table with
    # the transform output appended column-wise
    A_df.to_csv(f"datasets/AT_info/{oct_name}_AT{k}_info.csv", index=False)
    test_with_weights = pd.concat([df_test.reset_index(drop=True), Z_df], axis=1)
    test_with_weights.to_csv(f"datasets/AT_data/data_test_AT{k}.csv", index=False)

pd.DataFrame(RSS_lists, columns=['RSS']).to_csv(f"datasets/AT_info/{oct_name}_RSS.csv", index=False)


RSS plots

In [None]:
import os
os.makedirs("figures/AT", exist_ok=True)

# Read the RSS table from where the AA sweep writes it. The previous path
# ("tables_revise/AT/...") is not produced anywhere in this notebook; it is
# kept only as a fallback for an existing, manually maintained copy.
rss_path = f"datasets/AT_info/{oct_name}_RSS.csv"
if not os.path.exists(rss_path):
    rss_path = f"tables_revise/AT/{oct_name}_RSS.csv"
RSS_lists = pd.read_csv(rss_path).values.flatten()

# The sweep starts at k = 3; derive the (exclusive) upper bound from the
# number of recorded values so that x and y always have the same length.
min_k = 3
max_k = min_k + len(RSS_lists)

# RSS plots
figure = plt.figure(figsize=(6, 4))
plt.plot(range(min_k, max_k), RSS_lists[:], marker='o', color="blue", lw=2, markersize=10)
plt.xlabel('Number of Archetypes')
plt.ylabel('RSS')
plt.xticks(range(min_k, max_k))

title_name = "mGCIPL" if oct_name == "mGCLP" else oct_name
plt.title(f'{title_name}', fontsize=20)

# plt.tight_layout()
plt.subplots_adjust(left=0.17, right=0.97, bottom=0.15, top=0.9)
plt.savefig(f"figures/AT/RSS_{oct_name}.pdf", transparent=True)

# Visualize AT

In [None]:
def visualize(arr, vmin=0, vmax=1, figname=None, title=None, out_dir=None):
    """Save ``arr`` as a titled 26x26 heat map (archetype maps).

    Args:
        arr: array-like with 676 values; reshaped to (26, 26).
        vmin: lower bound of the colour scale.
        vmax: upper bound of the colour scale.
        figname: file name (without extension) of the PDF to write.
        title: text drawn above the map; ``None`` leaves the title empty.
        out_dir: folder the PDF is written to. Defaults to
            ``figures/AT/AT_maps_K={k}``, where ``k`` is the module-level
            variable, as before.
    """
    import os

    if out_dir is None:
        # Backward-compatible default; it depends on the global `k` being
        # defined before the call.
        out_dir = f"figures/AT/AT_maps_K={k}"
    os.makedirs(out_dir, exist_ok=True)

    matrix = np.reshape(arr, (26, 26))
    base_cmap = plt.cm.RdYlBu_r

    colors = base_cmap(np.arange(base_cmap.N))
    new_cmap = mcolors.ListedColormap(colors)

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(matrix, cmap=new_cmap, vmin=vmin, vmax=vmax)

    ax.set_xticks([])
    ax.set_yticks([])

    # Formatting None through an f-string would print the literal "None".
    plt.title("" if title is None else f"{title}", fontsize=40)

    plt.savefig(f"{out_dir}/{figname}.pdf", transparent=True)
    plt.close()

def show_ranges(arr):
    """Return the smallest and largest value of ``arr`` as a tuple."""
    bounds = (arr.min(), arr.max())
    return bounds

In [None]:
k =  12
oct_name = "pRNFL_Harvard"

import os
os.makedirs(f'figures/AT/AT_maps_K={k}', exist_ok=True)

# One column per archetype, one row per latent dimension. The file name is
# built from `oct_name` so it follows the name used when it was written.
A_df = pd.read_csv(f"datasets/AT_info/{oct_name}_AT{k}_info.csv")

# Decode deterministically: eval mode switches the decoder's dropout off
# (a freshly constructed model is in training mode), and no gradients are
# needed for plotting.
model_vae.eval()
with torch.no_grad():
    for idx in range(k):
        z = A_df.values[:, idx]
        z = torch.from_numpy(z).float()
        z = z.unsqueeze(0)  # add the batch dimension
        z = z.to(device)

        y = model_vae.decoder(z)
        y = y.cpu().numpy()

        visualize(y.reshape(676), vmin=0, vmax=1, figname=f"{oct_name}_A{idx+1}", title=f"A{idx+1}")
