## Mount Drive

In [None]:
# Mount Google Drive; the checkpoint, CSV and output paths used later in this
# notebook all live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## Install Necessary Packages

In [None]:
# lpips: LPIPS metric, used to measure the perceptual similarity between two images.
# imagehash: computes perceptual hashes of images, which are useful for
# identifying similar or duplicate images.
!pip -q install lpips imagehash

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm as SN
from torch.nn.utils import weight_norm
import math
from typing import Optional, Tuple

class PixelNorm(nn.Module):
    """Per-pixel feature normalisation across the channel axis (StyleGAN).

    Every spatial position is divided by the root-mean-square of its channel
    vector, so the mean squared activation over `dim=1` becomes ~1.
    """

    def forward(self, x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:  # type: ignore[override]
        # `eps` keeps the division finite for all-zero feature vectors.
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        rms = (mean_square + eps).sqrt()
        return x / rms


class NoiseInjection(nn.Module):
    """Adds a single-channel noise map scaled by a learnable per-channel weight.

    The weight has shape (1, channels, 1, 1) and starts at zero, so a freshly
    constructed layer is the identity.  When no noise is supplied, a new
    Gaussian map of shape (B, 1, H, W) is drawn on every call (one map per
    sample, broadcast over channels).
    """

    def __init__(self, channels: int):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, channels, 1, 1))

    def forward(self, x: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:  # type: ignore[override]
        if noise is not None:
            return x + self.weight * noise
        n, _, height, width = x.shape
        fresh = torch.randn(n, 1, height, width, device=x.device, dtype=x.dtype)
        return x + self.weight * fresh


# 3×3 blur kernel (StyleGAN2): outer product of the binomial taps [1, 2, 1].
# NOTE: the result is divided by the 1-D tap sum (4), not by the 2-D sum (16),
# so the kernel sums to 4 and `blur` amplifies a constant input by 4.  This is
# left as is because trained checkpoints depend on the exact kernel values.
_blur_kernel = torch.tensor([1, 2, 1], dtype=torch.float32)
_blur_kernel = torch.outer(_blur_kernel, _blur_kernel) / _blur_kernel.sum()


def blur(x: torch.Tensor) -> torch.Tensor:
    """Apply the 3×3 blur kernel to every channel of `x` independently.

    Uses a depthwise convolution with zero padding of 1, so the spatial size
    of the output equals that of the input.
    """
    channels = x.size(1)
    kernel = _blur_kernel.to(device=x.device, dtype=x.dtype)
    depthwise = kernel.view(1, 1, 3, 3).repeat(channels, 1, 1, 1)
    return F.conv2d(x, depthwise, padding=1, groups=channels)


class SelfAttn2d(nn.Module):
    """Non-local self-attention over all spatial positions (SAGAN).

    Queries and keys use `in_ch // 8` channels, values keep `in_ch`; all three
    projections are spectral-normalised 1×1 convolutions without bias.  The
    attention output is added to the input through the learnable scalar
    `gamma`, which starts at zero, so the block is initially an identity.
    """

    def __init__(self, in_ch: int) -> None:
        super().__init__()
        reduced = in_ch // 8
        self.q = SN(nn.Conv1d(in_ch, reduced, 1, bias=False))
        self.k = SN(nn.Conv1d(in_ch, reduced, 1, bias=False))
        self.v = SN(nn.Conv1d(in_ch, in_ch, 1, bias=False))
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        batch, channels, height, width = x.shape
        # Flatten the spatial grid into a sequence of height*width positions.
        tokens = x.view(batch, channels, -1)
        queries = self.q(tokens).permute(0, 2, 1)           # [B, N, C/8]
        keys = self.k(tokens)                                # [B, C/8, N]
        scores = torch.matmul(queries, keys)                 # [B, N, N]
        # Row i holds the weights position i assigns to every position j.
        weights = F.softmax(scores / math.sqrt(channels / 8), dim=-1)
        attended = torch.matmul(self.v(tokens), weights.permute(0, 2, 1))
        return x + self.gamma * attended.view(batch, channels, height, width)


class GResBlock(nn.Module):
    """Generator residual block that doubles the spatial resolution.

    Main path : nearest ×2 upsample → blur → two stages of
                (3×3 conv → noise injection → pixel norm → LeakyReLU 0.2).
    Skip path : the same upsampled input (not blurred) → 1×1 conv.
    The two paths are summed and scaled by 1/sqrt(2).
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.skip = nn.Conv2d(in_ch, out_ch, 1, bias=False)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.noise1 = NoiseInjection(out_ch)
        self.noise2 = NoiseInjection(out_ch)
        self.pn1 = PixelNorm()
        self.pn2 = PixelNorm()
        for layer in (self.conv1, self.conv2, self.skip):
            nn.init.kaiming_normal_(layer.weight, a=0, mode="fan_in", nonlinearity="leaky_relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        # Both paths start from the same upsampled tensor; it is never
        # modified in place, so it can be shared.
        upsampled = F.interpolate(x, scale_factor=2, mode="nearest")

        main = self.conv1(blur(upsampled))
        main = F.leaky_relu(self.pn1(self.noise1(main)), 0.2, inplace=True)
        main = self.conv2(main)
        main = F.leaky_relu(self.pn2(self.noise2(main)), 0.2, inplace=True)

        shortcut = self.skip(upsampled)
        return (main + shortcut) * (1 / math.sqrt(2))


class DResBlock(nn.Module):
    """Discriminator residual block that halves the spatial resolution.

    Main path : two spectral-normalised 3×3 convs, each followed by
                LeakyReLU 0.2, then 2×2 average pooling.
    Skip path : spectral-normalised 1×1 conv, then 2×2 average pooling.
    The two paths are summed and scaled by 1/sqrt(2).
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv1 = SN(nn.Conv2d(in_ch, out_ch, 3, padding=1))
        self.conv2 = SN(nn.Conv2d(out_ch, out_ch, 3, padding=1))
        self.skip = SN(nn.Conv2d(in_ch, out_ch, 1, bias=False))
        self.avg_pool = nn.AvgPool2d(2)
        for layer in (self.conv1, self.conv2, self.skip):
            nn.init.kaiming_normal_(layer.weight, a=0, mode="fan_in", nonlinearity="leaky_relu")

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        main = F.leaky_relu(self.conv1(x), 0.2, inplace=True)
        main = F.leaky_relu(self.conv2(main), 0.2, inplace=True)
        pooled_main = self.avg_pool(main)
        pooled_skip = self.avg_pool(self.skip(x))
        return (pooled_main + pooled_skip) * (1 / math.sqrt(2))



class Generator(nn.Module):
    """256×256 conditional generator with clamped to-RGB dynamic range.

    A latent vector and three label embeddings (EXP, ICM, TE) are projected to
    a `base_ch`×4×4 feature map, upsampled six times (4 → 256) by `GResBlock`s
    with self-attention at 32×32, and mapped to RGB by a weight-normalised
    3×3 conv whose gain is clamped to [0, 0.05] before every forward pass.

    Args:
        z_dim: size of the latent vector.
        label_dim: embedding size of each of the three labels.
        base_ch: channel count at 4×4; halved by every block, so it must be a
            positive multiple of 64.
        num_cls: number of ICM / TE classes; EXP gets `num_cls + 1` classes.

    Raises:
        ValueError: if `base_ch` is not a positive multiple of 64.
    """
    def __init__(self,
                 z_dim: int = 128,
                 label_dim: int = 50,
                 base_ch: int = 1024,
                 num_cls: int = 4):
        super().__init__()
        if base_ch < 64 or base_ch % 64 != 0:
            raise ValueError(f"base_ch must be a positive multiple of 64, got {base_ch}")

        # ---------------- label embeddings -----------------
        self.exp_embed = nn.Embedding(num_cls + 1, label_dim)   # EXP indices 0..num_cls
        self.icm_embed = nn.Embedding(num_cls,     label_dim)   # ICM indices 0..num_cls-1
        self.te_embed  = nn.Embedding(num_cls,     label_dim)   # TE  indices 0..num_cls-1
        lbl_tot = 3 * label_dim

        # ---------------- latent → 4×4 ---------------------
        self.fc      = nn.Linear(z_dim + lbl_tot, base_ch * 4 * 4)
        self.lbl_fc4 = nn.Linear(lbl_tot,          base_ch * 4 * 4)

        # ---------------- upsample backbone ----------------
        self.b8   = GResBlock(base_ch,       base_ch // 2)   # 4→8
        self.b16  = GResBlock(base_ch // 2,  base_ch // 4)   # 8→16
        self.b32  = GResBlock(base_ch // 4,  base_ch // 8)   # 16→32
        self.att32 = SelfAttn2d(base_ch // 8)
        self.b64  = GResBlock(base_ch // 8,  base_ch // 16)  # 32→64
        self.b128 = GResBlock(base_ch // 16, base_ch // 32)  # 64→128
        self.b256 = GResBlock(base_ch // 32, base_ch // 64)  # 128→256

        # ---------------- clamped to-RGB -------------------
        # weight-norm separates direction (v) & magnitude (g); the parameter
        # names weight_g / weight_v are part of the checkpoint format.
        self.to_rgb = weight_norm(nn.Conv2d(base_ch // 64, 3, 3, padding=1))

        # initialise direction + *small* gain
        nn.init.normal_(self.to_rgb.weight_v, 0.0, 1.0)
        nn.init.normal_(self.to_rgb.weight_g, 0.0, 0.02)

        # clamp g *every* forward pass (pre-hook).  weight_norm recomputes
        # `weight` from g and v in its own forward pre-hook, and hooks run in
        # registration order, so the clamp must be prepended; otherwise it
        # would only affect the *next* forward pass.
        def _clamp_gain(_, inp):
            self.to_rgb.weight_g.data.clamp_(0.0, 0.05)
        self.to_rgb.register_forward_pre_hook(_clamp_gain, prepend=True)

        nn.init.zeros_(self.to_rgb.bias)

    # -------------------------------------------------------
    def _lbl_vec(self, exp, icm, te):
        """Concatenate the three label embeddings into one [B, 3*label_dim] vector."""
        return torch.cat([self.exp_embed(exp),
                          self.icm_embed(icm),
                          self.te_embed(te)], dim=1)

    def forward(self, z, exp, icm, te):
        """
        z   : [B, z_dim]
        exp : [B]   expansion grade index, 0..num_cls
        icm : [B]   ICM grade index, 0..num_cls-1 (0=A,1=B,2=C,3=D if present)
        te  : [B]   TE grade index,  0..num_cls-1 (0=A,1=B,2=C,3=D)

        Returns a [B, 3, 256, 256] image tensor in [-1, 1].
        """
        lbl = self._lbl_vec(exp, icm, te)                         # [B, 3*label_dim]
        x = self.fc(torch.cat([z, lbl], dim=1))                   # [B, base_ch*4*4]
        # Infer the channel count from the projection instead of assuming 1024,
        # so any valid `base_ch` works.
        x = x.view(x.size(0), -1, 4, 4)
        x = x + self.lbl_fc4(lbl).view_as(x)
        x = F.leaky_relu(x, 0.2, inplace=True)

        x = self.b8(x)
        x = self.b16(x)
        x = self.b32(x)
        x = self.att32(x)
        x = self.b64(x)
        x = self.b128(x)
        x = self.b256(x)

        rgb = self.to_rgb(x)                      # clamped gain here
        return torch.tanh(rgb)                    # [-1,1] image

In [None]:
# === 0. imports & helpers ======================================
import torch, lpips, imagehash, PIL.Image as PIL
from pathlib import Path
from torchvision.utils import save_image
import numpy as np
import random

CKPT_FOLDER = "/content/drive/MyDrive/Msc in AI/Deep Learning/Blastocyst_Dataset/cgan_checkpoints/"

# ---- sampling configuration -----------------------------------
device    = "cuda"
z_dim     = 128             # latent size handed to the generator
batch     = 32              # images generated per forward pass
psi_trunc = 0.8             # truncation trick for quality

# pick 3–5 diverse ones
checkpoints = [
    f"{CKPT_FOLDER}gan_epoch_{epoch}.pth"
    for epoch in (825, 850, 875, 900)
]

# VGG-based LPIPS network, used only for inference (eval mode).
lpips_fn = lpips.LPIPS(net='vgg').to(device).eval()

# ---- simple dup filter (phash + LPIPS) ------------------------
def _pil_to_lpips_input(img):
    """Convert an H×W×3 PIL image to a 1×3×H×W float tensor in [-1, 1] on `device`."""
    chw = np.array(img).transpose(2, 0, 1)
    return torch.from_numpy(chw).float().div(255).unsqueeze(0).to(device) * 2 - 1


def is_novel(img, seen_hashes, thresh_lpips=0.7):
    """Return True if `img` is not a near-duplicate of any registered image.

    Args:
        img: PIL image to test (H×W×3).
        seen_hashes: mapping {perceptual hash: [PIL images with that hash]}.
            It is only read here; callers register accepted images themselves.
        thresh_lpips: an LPIPS distance below this value counts as duplicate.

    The check has two stages: a cheap perceptual-hash filter selects every
    registered hash within 4 bits of the image's hash, then LPIPS is computed
    against the images stored under those hashes.
    """
    h = imagehash.phash(img)
    # 1) hash quick-reject: collect the candidates of *every* near hash.
    #    Indexing only `seen_hashes[h]` would raise KeyError whenever a
    #    neighbouring hash matches but `h` itself was never registered.
    near = [cands for h0, cands in seen_hashes.items() if h - h0 <= 4]   # ≤4 bits diff
    if not any(near):
        return True

    # 2) LPIPS exact check (inference only, so no autograd graph is needed)
    img_t = _pil_to_lpips_input(img)
    with torch.no_grad():
        for cands in near:
            for cand in cands:                               # list of PILs
                d = lpips_fn(img_t, _pil_to_lpips_input(cand)).item()
                if d < thresh_lpips:
                    return False
    return True


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]




Loading model from: /usr/local/lib/python3.11/dist-packages/lpips/weights/v0.1/vgg.pth


In [None]:
import pandas as pd
import torch, random, itertools

CSV_PATH = "/content/drive/MyDrive/Msc in AI/Deep Learning/Blastocyst_Dataset/Gardner_train_silver.csv"          # ← change if necessary
df = pd.read_csv(CSV_PATH, delimiter=";")

# The CSV must provide the integer-valued columns 'EXP_silver', 'ICM_silver'
# and 'TE_silver'.  Collect every distinct (exp, icm, te) combination and how
# often it occurs.
grade_columns = ["EXP_silver", "ICM_silver", "TE_silver"]
grade_matrix = df[grade_columns].astype(int).values
triples, freqs = np.unique(grade_matrix, axis=0, return_counts=True)

# Plain Python ints, one tuple per distinct combination.
valid_triples = [tuple(int(grade) for grade in row) for row in triples]
# Counts → probabilities (normalised to sum to 1).
freqs = freqs / freqs.sum()

class TripleSampler:
    """
    Draw (exp, icm, te) ONLY from the list supplied in `triples`.

    If `probs` is given, triple k is drawn with weight `probs[k]` (e.g. its
    empirical frequency); if it is None, all triples are equally likely.
    Sampling is with replacement.
    """
    def __init__(self, triples, probs=None, device="cpu"):
        self.device = device
        self.triples = torch.tensor(triples, device=device)
        count = len(triples)
        self.probs = (
            torch.ones(count, device=device) / count
            if probs is None
            else torch.tensor(probs, device=device)
        )

    def __call__(self, batch):
        """Return three [batch] tensors: exp, icm and te labels."""
        rows = torch.multinomial(self.probs, batch, replacement=True)
        exp, icm, te = self.triples[rows].unbind(dim=1)   # columns of [B,3]
        return exp, icm, te

# Sampler over the label triples found in the CSV, weighted by `freqs`.
valid_sampler = TripleSampler(valid_triples, freqs, device=device)

In [None]:
# === 1. target counters ========================================
# Remaining number of images to save per category and label index.
targets = {
    "EXP": dict.fromkeys(range(5), 1000),
    "ICM": dict.fromkeys(range(4), 1000),     # label indices 0..3 (0:A,1:B,2:C)
    "TE" : dict.fromkeys(range(4), 1000),
}

out_root = Path("/content/drive/MyDrive/Msc in AI/Deep Learning/Blastocyst_Dataset/Synthetic_GAN")
out_root.mkdir(parents=True, exist_ok=True)

# One sub-folder per category and label: <out_root>/<category>/<label>/
for cat, quota in targets.items():
    for lab in quota:
        label_dir = out_root / cat / str(lab)
        label_dir.mkdir(parents=True, exist_ok=True)

seen_hash = {}          # {hash : [PILs]} for dup filter


In [None]:
# === 2. harvest loop ===========================================
from itertools import cycle

# Load every checkpoint once.  Batches are then drawn from the generators in
# round-robin order so that all checkpoints contribute images; looping
# "until all quotas are met" per checkpoint would let the first checkpoint
# fill everything and leave the others unused.
generators = []
for ckpt in checkpoints:
    G = Generator(z_dim=z_dim).to(device)
    G.load_state_dict(torch.load(ckpt, map_location="cpu")["G"])
    G.eval()
    generators.append(G)

# Position of each category inside an (exp, icm, te) triple.
label_col = {"EXP": 0, "ICM": 1, "TE": 2}
# Labels the sampler can actually produce.  A quota for any other label can
# never be filled and must not keep the loop alive forever.
reachable = {cat: {trip[col] for trip in valid_triples} for cat, col in label_col.items()}
for cat, quota in targets.items():
    stuck = sorted(lab for lab, left in quota.items() if left > 0 and lab not in reachable[cat])
    if stuck:
        print(f"[harvest] {cat} labels {stuck} never occur in valid_triples; their quota is skipped")


def _quota_left():
    """True while a label that can be sampled still needs images."""
    return any(left > 0
               for cat, quota in targets.items()
               for lab, left in quota.items()
               if lab in reachable[cat])


for G in cycle(generators):
    if not _quota_left():
        break

    z  = torch.randn(batch, z_dim, device=device) * psi_trunc
    exp, icm, te = valid_sampler(batch)

    with torch.no_grad():
        imgs = (G(z, exp, icm, te) + 1) / 2     # [0,1]

    # Convert the whole batch to uint8 HWC once instead of once per bucket.
    frames = imgs.mul(255).byte().permute(0, 2, 3, 1).contiguous().cpu().numpy()

    for frame, e, i, t in zip(frames, exp.tolist(), icm.tolist(), te.tolist()):
        labels = {"EXP": e, "ICM": i, "TE": t}
        # Buckets that still need this image's label (.get: a sampled label
        # without a target entry is simply not wanted).
        wanted = [cat for cat, lab in labels.items() if targets[cat].get(lab, 0) > 0]
        if not wanted:
            continue

        pil = PIL.fromarray(frame)
        h = imagehash.phash(pil)
        # Make sure the image's own hash bucket exists before the check, so
        # is_novel can always look it up, then test novelty ONCE per image.
        bucket = seen_hash.setdefault(h, [])
        if not is_novel(pil, seen_hash):
            continue
        # Register the accepted image; without this the duplicate filter
        # never has anything to compare against.  NOTE: this keeps every
        # accepted image in memory for the rest of the run.
        bucket.append(pil)

        # The same image is stored in every category folder that needs it.
        for cat in wanted:
            lab = labels[cat]
            fname = out_root/cat/str(lab)/f"{targets[cat][lab]:04d}.png"
            pil.save(fname)
            targets[cat][lab] -= 1

    # optional progress print
    if random.random()<.05:
        print({k:sum(v.values()) for k,v in targets.items()})


  WeightNorm.apply(module, name, dim)


{'EXP': 4520, 'ICM': 3520, 'TE': 3520}
{'EXP': 3752, 'ICM': 2752, 'TE': 2752}
{'EXP': 3112, 'ICM': 2305, 'TE': 2112}
{'EXP': 3056, 'ICM': 2281, 'TE': 2057}
{'EXP': 2686, 'ICM': 1974, 'TE': 1657}
{'EXP': 1315, 'ICM': 952, 'TE': 880}
{'EXP': 1176, 'ICM': 947, 'TE': 867}
{'EXP': 786, 'ICM': 929, 'TE': 838}
{'EXP': 699, 'ICM': 926, 'TE': 833}
{'EXP': 693, 'ICM': 926, 'TE': 831}
{'EXP': 640, 'ICM': 925, 'TE': 828}
{'EXP': 570, 'ICM': 920, 'TE': 820}
{'EXP': 272, 'ICM': 905, 'TE': 779}
{'EXP': 217, 'ICM': 899, 'TE': 772}
{'EXP': 198, 'ICM': 897, 'TE': 767}
{'EXP': 150, 'ICM': 894, 'TE': 752}
{'EXP': 113, 'ICM': 888, 'TE': 740}
{'EXP': 110, 'ICM': 888, 'TE': 739}
{'EXP': 80, 'ICM': 886, 'TE': 729}
{'EXP': 34, 'ICM': 880, 'TE': 709}
{'EXP': 0, 'ICM': 878, 'TE': 687}
{'EXP': 0, 'ICM': 876, 'TE': 682}
{'EXP': 0, 'ICM': 876, 'TE': 680}
{'EXP': 0, 'ICM': 874, 'TE': 678}
{'EXP': 0, 'ICM': 874, 'TE': 676}
{'EXP': 0, 'ICM': 871, 'TE': 660}
{'EXP': 0, 'ICM': 870, 'TE': 658}
{'EXP': 0, 'ICM': 868, 'TE'