## Simple Multimodal Medical Dataset (Toy Model)

### Latent variables

| Symbol | Meaning |
|:--:|--|
| $a$ | age |
| $d$ | breast density |
| $r$ | risk factor (e.g., family history) |
| $s$ | scanner quality |
| $z \in \{0,1\}$ | true disease state |

---

### Core equations

#### 1) Disease probability (patient-factor dependent)

Let $\sigma(x)=\dfrac{1}{1+e^{-x}}$ be the logistic (sigmoid) function.  
Define $\boldsymbol{\alpha}=(\alpha_0,\alpha_1,\alpha_2,\alpha_3)$. Then:

$$
P(z=1)=\sigma(\alpha_0+\alpha_1 a+\alpha_2 d+\alpha_3 r)
$$
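
As a minimal sketch of the formula above, the snippet below computes $P(z=1)$ and draws $z$ from it. The function names are hypothetical and not part of `ClinicalSynth`. The coefficient values are copied from the configuration cell later in this notebook. The input scales (age in years, density in $[0,1]$, binary risk factor) are assumptions made for illustration.

```python
import math
import random


def sigmoid(x: float) -> float:
    """Logistic function sigma(x) = 1 / (1 + exp(-x))."""
    return 1.0 / (1.0 + math.exp(-x))


def disease_probability(a: float, d: float, r: float, alpha) -> float:
    """P(z=1) = sigma(alpha0 + alpha1*a + alpha2*d + alpha3*r)."""
    a0, a1, a2, a3 = alpha
    return sigmoid(a0 + a1 * a + a2 * d + a3 * r)


def sample_disease(a: float, d: float, r: float, alpha, rng: random.Random) -> int:
    """Draw z in {0, 1} as a Bernoulli sample with probability P(z=1)."""
    return 1 if rng.random() < disease_probability(a, d, r, alpha) else 0


# Coefficients as in the config cell (a0, a_age, a_den, a_risk).
alpha = (-4.5, 0.02, 0.35, 0.6)

# Assumed input scales: age in years, density in [0, 1], risk factor in {0, 1}.
p = disease_probability(a=50.0, d=0.5, r=1.0, alpha=alpha)
z = sample_disease(a=50.0, d=0.5, r=1.0, alpha=alpha, rng=random.Random(0))
print(f"P(z=1) = {p:.3f}, sampled z = {z}")
```

Under these assumed scales the linear term is $-4.5 + 1.0 + 0.175 + 0.6 = -2.725$, so the probability comes out at roughly 6%.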

#### 2) Lesion parameters (conditional on disease)

For $z=1$ only, lesion **size** and **contrast** are defined as:

$$
\text{size} = \beta_0 + \beta_1 a + \beta_2 d + \varepsilon_s
$$

$$
\text{contrast} = \gamma_0 - \gamma_1 d + \varepsilon_c
$$

with noise terms (e.g.) $\varepsilon_s \sim \mathcal{N}(0,\sigma_s^2)$ and $\varepsilon_c \sim \mathcal{N}(0,\sigma_c^2)$.
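
The two lesion equations can be sketched as follows. The helper name `sample_lesion` is hypothetical. Mapping the config names `size_base`, `size_age`, `size_den` to $\beta_0,\beta_1,\beta_2$ and `con_base`, `con_den` to $\gamma_0,\gamma_1$ is an assumption based on their names, as are the noise standard deviations.

```python
import random


def sample_lesion(a, d, beta, gamma, sigma_s, sigma_c, rng):
    """Sample (size, contrast) for a diseased case (z = 1).

    size     = beta0 + beta1*a + beta2*d + eps_s,  eps_s ~ N(0, sigma_s^2)
    contrast = gamma0 - gamma1*d + eps_c,          eps_c ~ N(0, sigma_c^2)
    """
    b0, b1, b2 = beta
    g0, g1 = gamma
    size = b0 + b1 * a + b2 * d + rng.gauss(0.0, sigma_s)
    contrast = g0 - g1 * d + rng.gauss(0.0, sigma_c)
    return size, contrast


# Assumed mapping from the config cell: (size_base, size_age, size_den), (con_base, con_den).
beta = (10.0, 0.05, 2.0)
gamma = (0.9, 0.6)

# With zero noise the result is the deterministic mean of each equation.
mean_size, mean_contrast = sample_lesion(50.0, 0.5, beta, gamma, 0.0, 0.0, random.Random(0))

# With noise (sigma values are illustrative assumptions).
size, contrast = sample_lesion(50.0, 0.5, beta, gamma, 1.0, 0.05, random.Random(0))
print(mean_size, mean_contrast, size, contrast)
```

Note the minus sign on $\gamma_1$: higher density lowers the expected contrast, while the size equation adds the density term.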

#### 3) Generate data

$$
\text{Image: } I = \text{background}(d,s) + z \cdot \text{lesion}(\text{size},\text{contrast})
$$

$$
\text{Radiomics: } R = f(\text{size},\text{contrast},d) + \varepsilon_{R}
$$
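
The notebook does not spell out `background`, `lesion`, or $f$, so the following is only an illustrative sketch of the two generation equations. Every functional form here is an assumption: a flat background whose brightness scales with density and whose noise shrinks with scanner quality, a Gaussian blob for the lesion, and a three-element feature vector for the radiomics. The actual `ClinicalSynth` implementation may differ.

```python
import math
import random


def make_image(h, w, d, s, z, size, contrast, rng):
    """I = background(d, s) + z * lesion(size, contrast) on an h x w grid."""
    cy, cx = h // 2, w // 2
    image = []
    for y in range(h):
        row = []
        for x in range(w):
            # Hypothetical background: density sets brightness, quality s > 0 reduces noise.
            background = 0.25 * d + rng.gauss(0.0, 0.05 / s)
            # Hypothetical lesion: centered Gaussian blob, width `size`, peak `contrast`.
            r2 = (y - cy) ** 2 + (x - cx) ** 2
            lesion = contrast * math.exp(-r2 / (2.0 * size ** 2))
            row.append(background + z * lesion)
        image.append(row)
    return image


def radiomics(size, contrast, d, sigma_r, rng):
    """R = f(size, contrast, d) + eps_R, with a hypothetical feature map f."""
    features = [size, contrast, contrast / (1.0 + d)]
    return [f + rng.gauss(0.0, sigma_r) for f in features]


# Same seed for both images, so they differ only by the lesion term.
healthy = make_image(32, 32, d=0.5, s=1.0, z=0, size=4.0, contrast=0.6, rng=random.Random(0))
diseased = make_image(32, 32, d=0.5, s=1.0, z=1, size=4.0, contrast=0.6, rng=random.Random(0))
R = radiomics(size=4.0, contrast=0.6, d=0.5, sigma_r=0.01, rng=random.Random(1))
print(diseased[16][16] - healthy[16][16], R)
```

Because $z$ multiplies the lesion term, a healthy case ($z=0$) yields background only, which is why the two images above differ by the lesion's peak contrast at the center pixel.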

#### 4) Observed clinical data

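$$
y =
\begin{cases}
1, & \text{if } z = 1 \text{ and } \text{random()} > \eta_{FN} \quad \text{(true positive)} \\[4pt]
0, & \text{if } z = 1 \text{ and } \text{random()} \le \eta_{FN} \quad \text{(false negative)} \\[4pt]
0, & \text{if } z = 0 \text{ and } \text{random()} > \eta_{FP} \quad \text{(true negative)} \\[4pt]
1, & \text{if } z = 0 \text{ and } \text{random()} \le \eta_{FP} \quad \text{(false positive)}
\end{cases}
$$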
where:
- $\eta_{FN}$ — false negative rate  
- $\eta_{FP}$ — false positive rate  
- `random()` — uniform sample in $[0,1]$
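
The label-noise rule can be sketched in a few lines. The function name `observe_label` is hypothetical. The rates used in the usage example are the `eta_fn` and `eta_fp` values from the configuration cell later in this notebook.

```python
import random


def observe_label(z: int, eta_fn: float, eta_fp: float, rng: random.Random) -> int:
    """Return the noisy clinical label y given the true state z.

    A diseased case (z = 1) is missed with probability eta_fn.
    A healthy case (z = 0) is overcalled with probability eta_fp.
    """
    u = rng.random()  # one uniform draw per patient
    if z == 1:
        return 1 if u > eta_fn else 0
    return 1 if u <= eta_fp else 0


rng = random.Random(0)
n = 20000

# Empirical miss rate among diseased cases and overcall rate among healthy cases.
miss_rate = sum(observe_label(1, 0.25, 0.20, rng) == 0 for _ in range(n)) / n
overcall_rate = sum(observe_label(0, 0.25, 0.20, rng) == 1 for _ in range(n)) / n
print(f"miss rate ~ {miss_rate:.3f}, overcall rate ~ {overcall_rate:.3f}")
```

With enough samples the two empirical rates should approach $\eta_{FN}$ and $\eta_{FP}$ respectively.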


#### 5) Summary

| Variable | Generated by | Notes |
|:--|:--|:--|
| $z$ | logistic(age, density, risk) | true hidden condition |
| *image* | function of (density, scanner, lesion params) | visual data |
| *radiomics* | derived from lesion params and density | tabular data |
| $y$ | noisy version of $z$ | clinical diagnosis |

In [None]:
from pathlib import Path
import sys

ROOT = Path(__file__).resolve().parents[1] if "__file__" in globals() else Path().resolve().parents[0]
sys.path.append(str(ROOT))
from scripts.ClinicalSynth import ClinicalSynth, Config

# --- demanding, realistic setup ---
cfg = Config(
    H=256, W=256,
    seed=42,

    # --- latent disease model (rare + weak but real correlations) ---
    a0=-4.5,            # low baseline prevalence (~5–8%)
    a_age=0.02,         # slightly higher risk with age
    a_den=0.35,         # dense breasts -> risk + imaging challenge
    a_risk=0.6,         # family history increases risk

    # --- diagnostic noise (radiologist uncertainty) ---
    eta_fn=0.25,        # 25% missed cancers
    eta_fp=0.20,        # 20% overcalls

    # --- lesion characteristics (subtle) ---
    size_base=10.0,     # moderately small lesion
    size_age=0.05,      # slight size increase with age
    size_den=2.0,       # density term in the size equation (beta_2)
    con_base=0.9,       # moderately low contrast
    con_den=0.6,        # denser breast lowers contrast

    # --- imaging conditions ---
    bg_base=0.25,                       # noisy parenchyma texture
    device_blur={"A": 1, "B": 3, "C": 6},  # varied scanner blur (domain shift)
)

# --- generate splits ---
gen = ClinicalSynth(cfg)
train = ClinicalSynth(Config(**{**cfg.__dict__, "seed": 123}))
val   = ClinicalSynth(Config(**{**cfg.__dict__, "seed": 456}))
test  = ClinicalSynth(Config(**{**cfg.__dict__, "seed": 789}))

Xtr = train.sample_batch(1000)  # 1k train
Xva = val.sample_batch(100)     # 100 val
Xte = test.sample_batch(100)    # 100 test


In [4]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class ClinicalDataset(Dataset):
    """Wraps arrays from ClinicalSynth into a PyTorch dataset."""
    def __init__(self, imgs, tabs, y):
        self.imgs = torch.tensor(imgs, dtype=torch.float32).unsqueeze(1)
        self.tabs = torch.tensor(tabs, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        return self.imgs[i], self.tabs[i], self.y[i]


train_ds = ClinicalDataset(*Xtr[:3])
val_ds   = ClinicalDataset(*Xva[:3])
test_ds  = ClinicalDataset(*Xte[:3])

train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=16)
test_loader  = DataLoader(test_ds, batch_size=16)


class SimpleModel(nn.Module):
    def __init__(self, n_tab: int):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 8, 5, padding=2), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(8, 16, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1)
        )
        self.tab = nn.Sequential(
            nn.Linear(n_tab, 32), nn.ReLU(),
            nn.Linear(32, 32), nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(16 + 32, 32), nn.ReLU(),
            nn.Linear(32, 2)
        )

    def forward(self, x_img, x_tab):
        vi = self.cnn(x_img).flatten(1)
        vt = self.tab(x_tab)
        v = torch.cat([vi, vt], dim=1)
        return self.fc(v)
    

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleModel(n_tab=Xtr[1].shape[1]).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()


def run_epoch(loader, train_mode=True):
    if train_mode:
        model.train()
    else:
        model.eval()

    total, correct, loss_sum = 0, 0, 0.0
    loop = tqdm(loader, leave=False)
    for img, tab, y in loop:
        img, tab, y = img.to(device), tab.to(device), y.to(device)

        with torch.set_grad_enabled(train_mode):
            out = model(img, tab)
            loss = criterion(out, y)

        if train_mode:
            opt.zero_grad()
            loss.backward()
            opt.step()

        pred = out.argmax(1)
        total += y.size(0)
        correct += (pred == y).sum().item()
        loss_sum += loss.item() * y.size(0)

        loop.set_description(f"{'Train' if train_mode else 'Eval'}")
        loop.set_postfix(loss=loss.item(), acc=correct/total)
    return loss_sum / total, correct / total


for epoch in range(10):
    tr_loss, tr_acc = run_epoch(train_loader, train_mode=True)
    va_loss, va_acc = run_epoch(val_loader, train_mode=False)
    print(f"Epoch {epoch+1:02d} | Train loss {tr_loss:.3f} acc {tr_acc:.3f} | "
          f"Val loss {va_loss:.3f} acc {va_acc:.3f}")


                                                                             

Epoch 01 | Train loss 0.728 acc 0.802 | Val loss 0.692 acc 0.680
Epoch 02 | Train loss 0.746 acc 0.802 | Val loss 0.677 acc 0.740
Epoch 03 | Train loss 0.623 acc 0.798 | Val loss 0.683 acc 0.740
Epoch 04 | Train loss 0.578 acc 0.796 | Val loss 0.602 acc 0.740
Epoch 05 | Train loss 0.590 acc 0.796 | Val loss 0.643 acc 0.740
Epoch 06 | Train loss 0.783 acc 0.794 | Val loss 0.603 acc 0.740
Epoch 07 | Train loss 0.588 acc 0.792 | Val loss 0.679 acc 0.740
Epoch 08 | Train loss 0.661 acc 0.796 | Val loss 0.617 acc 0.740
Epoch 09 | Train loss 0.524 acc 0.792 | Val loss 0.620 acc 0.700
Epoch 10 | Train loss 0.505 acc 0.802 | Val loss 0.611 acc 0.730




In [5]:
test_loss, test_acc = run_epoch(test_loader, train_mode=False)
print(f"\nTest accuracy: {test_acc:.3f}")


                                                                          


Test accuracy: 0.740
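
Accuracy alone does not separate missed cases from overcalls, which matters when one class is much rarer than the other. As a minimal sketch, the helpers below (hypothetical names, not part of `ClinicalSynth` or PyTorch) compute sensitivity and specificity from plain Python lists of labels and predictions. The toy labels are made up for illustration and are not results from this notebook.

```python
def confusion_counts(y_true, y_pred):
    """Count (tp, fp, tn, fn) for binary labels in {0, 1}."""
    tp = fp = tn = fn = 0
    for t, p in zip(y_true, y_pred):
        if t == 1 and p == 1:
            tp += 1
        elif t == 0 and p == 1:
            fp += 1
        elif t == 0 and p == 0:
            tn += 1
        else:
            fn += 1
    return tp, fp, tn, fn


def sensitivity_specificity(y_true, y_pred):
    """Sensitivity = tp / (tp + fn); specificity = tn / (tn + fp)."""
    tp, fp, tn, fn = confusion_counts(y_true, y_pred)
    sens = tp / (tp + fn) if (tp + fn) else float("nan")
    spec = tn / (tn + fp) if (tn + fp) else float("nan")
    return sens, spec


# Toy illustration: a classifier that always predicts "negative".
y_true = [0, 0, 0, 1, 0, 0, 1, 0, 0, 0]
y_pred = [0] * len(y_true)

acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
sens, spec = sensitivity_specificity(y_true, y_pred)
print(f"accuracy {acc:.2f} | sensitivity {sens:.2f} | specificity {spec:.2f}")
```

In this toy case the always-negative classifier reaches 80% accuracy while detecting none of the positives. To apply the same check to the model above, one could collect `out.argmax(1)` and `y` over `test_loader` into lists and pass them to `sensitivity_specificity`.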


