In [None]:
import os
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch
import torchvision.transforms as T
from threelayerRBM import extract_features
import pickle
import numpy as np



class ChestXRayDataset(Dataset):
    def __init__(self, csv_path, root_dir, rbm_model, findings_list, hidden_features_dict,
                 transform=None, vectorizer_v=None, hidden_vectorizers=None):
        """
        csv_path: path to output.csv
        root_dir: './files-1024'
        rbm_model: an instance of ThreeLayerRBM already trained and loaded
        findings_list, hidden_features_dict: from your threelayerRBM.py setup
        transform: torchvision transforms for images
        """
        self.df = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.transform = transform
        self.rbm = rbm_model
        self.findings_list = findings_list
        self.hidden_features_dict = hidden_features_dict

        # Store vectorizers
        self.vectorizer_v = vectorizer_v
        self.hidden_vectorizers = hidden_vectorizers

        # Pre-load vectorizers from your threelayerRBM code (assuming you saved them)
        # If not saved, you must re-create them as done in extract_features()
        # For simplicity, let's assume you have functions or preloaded vectorizers:
        # self.vectorizer_v, self.hidden_vectorizers = load_vectorizers(...)

        # To avoid repeated processing, we can store relevant info
        # We'll store each row's path and text for lazy processing
        self.samples = []
        for _, row in self.df.iterrows():
            level1 = row['Level1']  # e.g. 'p10'
            level2 = row['Level2']  # e.g. 'p10000032'
            file_ = row['File']     # e.g. 's50414267'
            # Construct the image directory
            img_dir = os.path.join(self.root_dir, level1, level2, file_)

            # Extract text
            findings = str(row['FINDINGS']) if not pd.isnull(row['FINDINGS']) else ''
            impression = str(row['IMPRESSION']) if not pd.isnull(row['IMPRESSION']) else ''
            text = findings + ' ' + impression

            self.samples.append((img_dir, text))

    def __len__(self):
        return len(self.samples)

    def preprocess_text(self, text):
        # As in threelayerRBM.py
        import re
        text = text.lower()
        text = re.sub(r'[^\w\s]', '', text)
        return text

    def get_features_from_text(self, clean_text, vectorizer_v, hidden_vectorizers):
        # vectorizer_v is a CountVectorizer for findings
        X_visible = vectorizer_v.transform([clean_text]).toarray()

        hidden_features = []
        for category, vec in hidden_vectorizers.items():
            X_hidden_cat = vec.transform([clean_text]).toarray()
            hidden_features.append(X_hidden_cat)
        X_hidden = np.concatenate(hidden_features, axis=1) if len(hidden_features) > 0 else np.array([])

        return X_visible, X_hidden
    
    def text_to_cond_z(self, text):
        # Convert text to cond_z using RBM
        # 1. Preprocess text
        clean_text = self.preprocess_text(text)

        # 2. Vectorize text to get X_visible and X_hidden (as in threelayerRBM.py extract_features step)
        # Assuming you have vectorizers prepared:
        # X_visible = self.vectorizer_v.transform([clean_text]).toarray()
        # For hidden features:
        # hidden_features = []
        # for category, terms in self.hidden_features_dict.items():
        #     X_hidden_cat = self.hidden_vectorizers[category].transform([clean_text]).toarray()
        #     hidden_features.append(X_hidden_cat)
        # X_hidden = np.concatenate(hidden_features, axis=1)

        # If you do not have them preloaded, you must create vectorizers here or store them from training phase.

        # For demonstration, assume you have a function get_features_from_text that returns X_visible and X_hidden:
        X_visible, X_hidden = self.get_features_from_text(clean_text, self.vectorizer_v, self.hidden_vectorizers)
        # 3. cond_z = rbm.transform(X_visible)
        cond_z = torch.tensor(self.rbm.transform(X_visible), dtype=torch.float32)
        
        return cond_z[0]  # Since batch size = 1 here


    def load_image(self, img_dir):
        # Load one or multiple images from img_dir
        # If there are multiple .jpg images, decide how to handle them
        imgs = [f for f in os.listdir(img_dir) if f.lower().endswith('.jpg')]
        if len(imgs) == 0:
            raise FileNotFoundError(f"No image found in {img_dir}")

        # For simplicity, load the first image
        img_path = os.path.join(img_dir, imgs[0])
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

    def __getitem__(self, idx):
        img_dir, text = self.samples[idx]

        img = self.load_image(img_dir)
        cond_z = self.text_to_cond_z(text)
        return img, cond_z

from threelayerRBM import ThreeLayerRBM

# Load your RBM model and vectorizers
findings_list = [
    "left", "right", "atelectasis", "bronchiectasis", "bulla", "consolidation", "dextrocardia", "effusion", "emphysema",
    "fracture clavicle", "fracture rib", "groundglass opacity", "interstitial opacification",
    "mass paraspinal", "mass soft tissue", "nodule", "opacity", "pneumomediastinum", "pneumonia",
    "pneumoperitoneum", "pneumothorax", "pleural effusion", "pulmonary edema", "scoliosis",
    "tuberculosis", "volume loss", "rib", "mass", "infiltration", "other findings"
]

hidden_features_dict = {
    'location': [
        "left lung", "right lung", "upper lobe", "lower lobe", "cardiac region",
        "pleural space", "diaphragm", "mediastinum", "thoracic spine", "abdominal region"
    ],
    'organ_system': [
        "respiratory system", "cardiovascular system", "musculoskeletal system", "digestive system"
    ],
    'mode_of_pathology': [
        "congenital", "acquired", "infection", "inflammation", "tumor", "degenerative", "vascular"
    ],
    'severity': [
        "mild", "moderate", "severe",
    ],
}
df = pd.read_csv('output.csv')
X_visible, X_hidden, vectorizer_v = extract_features(df, findings_list, hidden_features_dict)

n_visible = X_visible.shape[1]
n_hidden_middle = 30
n_hidden_top = X_hidden.shape[1]
rbm = ThreeLayerRBM(n_visible, n_hidden_middle, n_hidden_top)
rbm.load_model('rbm_model.pkl')
with open('vectorizer_v.pkl', 'rb') as f:
    vectorizer_v = pickle.load(f)
with open('hidden_vectorizers.pkl', 'rb') as f:
    hidden_vectorizers = pickle.load(f)

# transforms for images
transform = T.Compose([
    T.Resize((256,256)),
    T.ToTensor(),
])

dataset = ChestXRayDataset(
    csv_path='output.csv',
    root_dir='./files-1024',
    rbm_model=rbm,
    findings_list=findings_list,
    hidden_features_dict=hidden_features_dict,
    transform=transform,
    vectorizer_v=vectorizer_v,
    hidden_vectorizers=hidden_vectorizers
)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloader = DataLoader(dataset, batch_size=8, shuffle=True).to(device)


Model loaded from rbm_model.pkl


In [91]:
import torch
import pandas as pd
import os
from PIL import Image
import torchvision.transforms as T
import torchvision.utils as vutils
import numpy as np
import torch.nn.functional as F
from model import AutoEncoder
import re
from distributions import Normal, DiscMixLogistic, NormalDecoder

def preprocess_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    return text.strip()

def get_features_from_text(clean_text, vectorizer_v, hidden_vectorizers):
    # vectorizer_v is a CountVectorizer for findings
    X_visible = vectorizer_v.transform([clean_text]).toarray()

    hidden_features = []
    for category, vec in hidden_vectorizers.items():
        X_hidden_cat = vec.transform([clean_text]).toarray()
        hidden_features.append(X_hidden_cat)
    X_hidden = np.concatenate(hidden_features, axis=1) if len(hidden_features) > 0 else np.array([])

    return X_visible, X_hidden

def sample(model, num_samples, t, cond_z=None):
    scale_ind = 0
    z0_size = [num_samples] + model.z0_size
    dist = Normal(mu=torch.zeros(z0_size), log_sigma=torch.zeros(z0_size), temp=t)
    z, _ = dist.sample()

    idx_dec = 0
    s = model.prior_ftr0.unsqueeze(0)
    batch_size = z.size(0)
    s = s.expand(batch_size, -1, -1, -1)

    # If conditioning is provided
    if cond_z is not None and model.cond_z_dim > 0:
        # cond_z shape: [num_samples, cond_z_dim]
        # We'll apply this shift whenever we form mu, log_sigma from dec_sampler
        shift_all = model.cond_mapper(cond_z)  # [B, 2*latent_per_group]
        shift_all = shift_all.unsqueeze(-1).unsqueeze(-1)  # [B, 2*latent_per_group, 1, 1]

    for cell in model.dec_tower:
        if cell.cell_type == 'combiner_dec':
            if idx_dec > 0:
                # form prior
                param = model.dec_sampler[idx_dec - 1](s)
                mu, log_sigma = torch.chunk(param, 2, dim=1)
                dist = Normal(mu, log_sigma, t)
                z, _ = dist.sample()

            # 'combiner_dec'
            s = cell(s, z)
            idx_dec += 1
        else:
            s = cell(s)
            if cell.cell_type == 'up_dec':
                scale_ind += 1

    if model.vanilla_vae:
        s = model.stem_decoder(z)

    for cell in model.post_process:
        s = cell(s)

    logits = model.image_conditional(s)
    return logits

# Load model, rbm, vectorizers
# Assume you have a function load_model and it returns a model that can do model.sample(num_samples, cond_z=...)
def load_model(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    args = checkpoint['args']
    arch_instance = checkpoint['arch_instance']
    model = AutoEncoder(args, None, arch_instance, cond_z_dim=24)  # fill in cond_z_dim
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    return model, args

# Assume you have vectorizer_v, hidden_vectorizers, rbm loaded
with open('vectorizer_v.pkl', 'rb') as f:
    vectorizer_v = pickle.load(f)
with open('hidden_vectorizers.pkl', 'rb') as f:
    hidden_vectorizers = pickle.load(f)
rbm = ThreeLayerRBM(n_visible, n_hidden_middle, n_hidden_top)
rbm.load_model('rbm_model.pkl')

model, args = load_model('./eval-exp64-12000-7epoch/checkpoint.pt')

# Read cond_z from output.csv
df = pd.read_csv('output.csv')
# We'll assume df has Level1, Level2, File, FINDINGS, IMPRESSION columns
# We'll generate cond_z for first 5 samples
cond_zs = []
for i in range(25):
    row = df.iloc[i]
    findings = str(row['FINDINGS']) if pd.notnull(row['FINDINGS']) else ''
    impression = str(row['IMPRESSION']) if pd.notnull(row['IMPRESSION']) else ''
    text = findings + ' ' + impression
    clean_text = preprocess_text(text)
    X_visible, X_hidden = get_features_from_text(clean_text, vectorizer_v, hidden_vectorizers)
    cond_z = torch.tensor(rbm.transform(X_visible), dtype=torch.float32).unsqueeze(0)  # [1, cond_z_dim]
    cond_zs.append(cond_z)

cond_zs = torch.cat(cond_zs, dim=0)  # shape [25, cond_z_dim]

# Generate first 5 images from cond_z
# Assume model.sample(num_samples, t=1.0, cond_z=...) can handle condition
# If model needs cond_z per image, loop or modify model to accept a batch of cond_z
with torch.no_grad():
    # For demonstration, let's say model.sample(num_samples=5, cond_z=cond_zs) works
    logits = sample(model,num_samples=25, t=1.0, cond_z=cond_zs)  
    # Decode logits
    output = model.decoder_output(logits)
    # output.sample() to get images
    gen_images = output.sample().cpu()  # [25, C, H, W]

# Assume original images are from the dataset. Load first 5 corresponding originals
# You must know how to map them. We assume df also has paths or we replicate the logic from dataset.
def load_original_image(row):
    level1 = row['Level1']
    level2 = row['Level2']
    file_ = row['File']
    img_dir = os.path.join('./files-1024', str(level1), str(level2), str(file_))
    imgs = [f for f in os.listdir(img_dir) if f.lower().endswith('.jpg')]
    img_path = os.path.join(img_dir, imgs[0])
    img = Image.open(img_path).convert('L') # assume grayscale
    transform = T.Compose([
        T.Resize((64,64)),  # match training resolution
        T.ToTensor(),
    ])
    img = transform(img)
    return img

#originals = []
#for i in range(25):
#    row = df.iloc[i]
#    orig_img = load_original_image(row)
#    originals.append(orig_img)
#originals = torch.stack(originals, dim=0)  # [5, 1, 32,32]

# 4) Create a 5x5 montage for original images
#original_montage = vutils.make_grid(originals[:25], nrow=5, padding=2, normalize=True)
#vutils.save_image(original_montage, 'original_montage.png')

# 5) Create a 5x5 montage for generated images
generated_montage = vutils.make_grid(gen_images[:25], nrow=5, padding=2, normalize=True)
vutils.save_image(generated_montage, 'generated_montage.png')

# 6) calculate KL divergence of the first 500 original and associated generated images
# We must define a KL divergence measure. We'll assume both original and generated are normalized
# distributions over pixels. This is simplistic and not necessarily meaningful, but as an example:

def kl_divergence(p, q, eps=1e-8):
    # p, q are [C,H,W]. Ensure sum to 1 if treating as distributions
    # We'll sum over all pixels. Let's flatten them.
    # This is not a standard approach for images, but just a demonstration.
    p = p.flatten()
    q = q.flatten()
    # Normalize to sum=1
    p = p / (p.sum() + eps)
    q = q / (q.sum() + eps)
    kl = (p * (torch.log(p+eps) - torch.log(q+eps))).sum()
    return kl.item()

# We'll load first 500 samples (or as many as we have)
num_samples_kl = min(len(df), 500)
kls = []
orig_list = []
gen_list = []
with torch.no_grad():
    for i in range(num_samples_kl):
        row = df.iloc[i]
        # get cond_z
        findings = str(row['FINDINGS']) if pd.notnull(row['FINDINGS']) else ''
        impression = str(row['IMPRESSION']) if pd.notnull(row['IMPRESSION']) else ''
        text = findings + ' ' + impression
        clean_text = preprocess_text(text)
        X_visible, X_hidden = get_features_from_text(clean_text, vectorizer_v, hidden_vectorizers)
        cond_z = torch.tensor(rbm.transform(X_visible), dtype=torch.float32).unsqueeze(0)

        # Generate image for this cond_z
        logits = sample(model,num_samples=1, t=1.0, cond_z=cond_z)
        out = model.decoder_output(logits)
        gen_img = out.sample()[0]  # [C,H,W]

        orig_img = load_original_image(row) # [1,H,W]
        # Convert both to CPU float
        orig_img = orig_img.cpu().float()
        gen_img = gen_img.cpu().float()
        orig_list.append(orig_img)
        gen_list.append(torch.mean(gen_img,dim=0))

        kl_value = kl_divergence(orig_img, torch.mean(gen_img,dim=0))
        kls.append(kl_value)

# 7) report an average of calculated KL divergence
avg_kl = sum(kls) / len(kls)
print("Average KL divergence over first 500 samples:", avg_kl)


Model loaded from rbm_model.pkl


  checkpoint = torch.load(checkpoint_path, map_location='cpu')


len log norm: 128
len bn: 92
Average KL divergence over first 500 samples: 0.33895698574185373


In [92]:
import torch
import pandas as pd
import os
from PIL import Image
import torchvision.transforms as T
import torchvision.utils as vutils
import numpy as np
import torch.nn.functional as F
from torchvision.models import inception_v3
from torch.utils.data import DataLoader, Dataset
import math

from scipy.linalg import sqrtm

# --------------- Calculate Inception Score for generated images ----------------
# Inception Score only on gen_batch

def inception_score(imgs, splits=10):
    # imgs: [N,3,299,299]
    N = imgs.size(0)
    batch_size = 50
    preds = []
    with torch.no_grad():
        for i in range(0, N, batch_size):
            batch = imgs[i:i+batch_size]
            out = inception(batch)
            # out: logits before softmax. Get probabilities:
            p_yx = F.softmax(out, dim=1)
            preds.append(p_yx)
    preds = torch.cat(preds, dim=0) # [N,1000]
    # Compute IS
    # split into 10 groups
    split_size = N // splits
    is_scores = []
    for k in range(splits):
        part = preds[k*split_size:(k+1)*split_size, :]
        py = part.mean(dim=0)
        # KL divergence: mean over part of sum p(y|x)*log(p(y|x)/p(y))
        scores = (part * (torch.log(part+1e-8) - torch.log(py+1e-8))).sum(dim=1).mean()
        is_scores.append(torch.exp(scores))
    return sum(is_scores)/len(is_scores)

# --------------- Calculate FID ----------------
# For FID, we need activations from a layer of inception (often pool3 features)
from torch.nn.functional import adaptive_avg_pool2d

def get_activations(imgs, model, batch_size=50):
    # Extract pool3 features
    # Modify inception to output features before fc:
    # For simplicity, use model until pool3:
    model.Mixed_7c.register_forward_hook(lambda m,i,o: setattr(model,'_hidden',o))
    activations = []
    with torch.no_grad():
        for i in range(0, imgs.size(0), batch_size):
            batch = imgs[i:i+batch_size]
            _ = model(batch) 
            # model._hidden now has features [B,2048,H',W']
            # pool to 1x1
            feat = adaptive_avg_pool2d(model._hidden, (1,1)).squeeze(-1).squeeze(-1)
            activations.append(feat)
    activations = torch.cat(activations, dim=0)
    return activations.cpu().numpy()

def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    # from official FID formula
    diff = mu1 - mu2
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        covmean = np.nan_to_num(covmean)
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2*covmean)


def compute_statistics(acts):
    mu = np.mean(acts, axis=0)
    sigma = np.cov(acts, rowvar=False)
    return mu, sigma

def inception_preprocess(img):
    # img is a torch.Tensor with shape either:
    # - [H,W]
    # - [1,H,W]
    # We want to end up with [3,299,299]

    # If img is [H,W], add a channel dimension
    if img.dim() == 2:
        # shape: [H,W] -> [1,H,W]
        img = img.unsqueeze(0)

    # If img has only one channel, replicate it to get 3 channels
    if img.size(0) == 1:
        img = img.repeat(3,1,1)  # [1,H,W] -> [3,H,W]

    # Now we have [3,H,W]. Convert to PIL, then apply resizing & normalization
    pil = T.ToPILImage()(img)  # Convert tensor to PIL
    transform = T.Compose([
        T.Resize((299,299)),
        T.ToTensor(),
        T.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5])
    ])
    img = transform(pil)
    return img

orig_batch = torch.stack([inception_preprocess(im) for im in orig_list], dim=0)
gen_batch = torch.stack([inception_preprocess(im) for im in gen_list], dim=0)

inception = inception_v3(pretrained=True, transform_input=False).eval()

IS = inception_score(gen_batch, splits=10)
print("Inception Score:", IS)

activations_real = get_activations(orig_batch, inception)
activations_fake = get_activations(gen_batch, inception)

mu_r, sigma_r = compute_statistics(activations_real)
mu_g, sigma_g = compute_statistics(activations_fake)

FID = calculate_fid(mu_r, sigma_r, mu_g, sigma_g)
print("FID:", FID)



Inception Score: tensor(1.5833)
FID: 328.8058996039571
