### GNNS Final Project

In [2]:
!pip install kornia
!pip install git+https://github.com/openai/CLIP.git

Collecting kornia
  Downloading kornia-0.8.0-py2.py3-none-any.whl.metadata (17 kB)
Collecting kornia_rs>=0.1.0 (from kornia)
  Downloading kornia_rs-0.1.8-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.9.1->kornia)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.9.1->kornia)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.9.1->kornia)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.9.1->kornia)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.9.1->kornia)
  Downloading nvidia_cublas

In [3]:
from google.colab import drive
drive.mount('/content/drive')
checkpoint_path = '/content/drive/MyDrive/checkpoints/lambda_model.pth'
dataset_path = '/content/drive/MyDrive/dataset_medium.pkl'
# Make sure the directory exists
#import os
#os.makedirs(checkpoint_path, exist_ok=True)

Mounted at /content/drive


In [4]:
import torch
import torch.nn as nn
import clip
import numpy as np
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import AutoTokenizer, AutoModel
import kornia
import zipfile
import pickle
import torchvision.transforms as transforms
from PIL import Image
from diffusers import StableDiffusionPipeline



DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

"""
    CLIP embedder classes adapted from https://github.com/UCSB-NLP-Chang/DiffusionDisentanglement/blob/main/ldm/modules/encoders/modules.py#L5
"""

class FrozenClipTextEmbedder(nn.Module):
    """
    Uses the CLIP transformer encoder for text.
    """
    def __init__(self, version='ViT-L/14', device=DEVICE, max_length=77, n_repeat=1, normalize=True):
        super().__init__()
        self.model, _ = clip.load(version, jit=False, device=DEVICE)
        self.device = DEVICE
        self.max_length = max_length
        self.n_repeat = n_repeat
        self.normalize = normalize

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, text):
        tokens = clip.tokenize(text).to(self.device)
        z = self.model.encode_text(tokens)
        if self.normalize:
            z = z / torch.linalg.norm(z, dim=1, keepdim=True)
        return z

    def encode(self, text):
        z = self(text)
        if z.ndim==2:
            z = z[:, None, :]
        z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
        return z


class FrozenClipImageEmbedder(nn.Module):
    """Uses the CLIP image encoder."""
    def __init__(self, model="ViT-L/14", device=DEVICE, antialias=False):
        super().__init__()
        self.model, _ = clip.load(model, device=DEVICE)
        self.device = DEVICE
        self.antialias = antialias
        self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
        self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)

    def forward(self, x):
        """Encodes the image into embeddings."""
        x = self.preprocess(x).to(self.device)
        return self.model.encode_image(x).to(self.device)

    def encode(self, x):
        """Encodes an image into CLIP embedding."""
        return self(x)

    def preprocess(self, x):
            """ Resize and normalize image for CLIP """
            if isinstance(x, np.ndarray):
                x = torch.tensor(x, dtype=torch.float32)  # Convert NumPy to Tensor
            if isinstance(x, Image.Image):
                x = transforms.ToTensor()(x)  # Converts to (3, H, W), normalized to [0,1]
            print('ndims:', x.ndim, ' shape:' , x.shape)
            if x.ndim == 2:  # Grayscale images (H, W)
                x = x.unsqueeze(0)  # Add channel dimension → (1, H, W)
                x = x.repeat(3, 1, 1)  # Convert to 3 channels → (3, H, W)
                print('first if ndims:', x.ndim, ' shape:', x.shape)

            if x.shape[0] == 1:  # Convert grayscale to RGB by repeating channels
                print('second if ndims:', x.ndim, ' shape:', x.shape)
                x = x.repeat(3, 1, 1)  # Now shape is (3, H, W)

            x = x / 255.0  # Normalize pixel values to [0,1]

            # Make sure x is on the same device as self.mean
            x = x.to(self.mean.device)

            x = kornia.geometry.resize(x.unsqueeze(0), (224, 224),
                                        interpolation='bicubic', align_corners=True,
                                        antialias=self.antialias)  # Add batch dim

            # Move self.mean and self.std to the same device as x
            x = (x - self.mean.to(x.device)[:, None, None]) / self.std.to(x.device)[:, None, None]

            return x


"""
    Create text descriptions
"""
def summarize_labels(labels, max_items=3):
    """ Summarizes a long label list while keeping key findings. This is because there is a limit to the input of the CLIP text embedder

    TODO: review if this is the best approach to deal with long descriptions """

    key_conditions = labels[:max_items]  # Take first 3 labels
    other_count = max(0, len(labels) - max_items)

    if other_count > 0:
        return f"{', '.join(key_conditions)}, and {other_count} other findings"
    else:
        return ", ".join(key_conditions)

def create_neutral_desc(sample):
    """Creates a neutral medical descriptor from a dataset entry
     sample: Shape [filename, img_array, orientation, labels] """
    filename, img_array, orientation, labels = sample

    # Convert label string to a proper list
    labels = eval(labels) if isinstance(labels, str) else labels
    # Join labels into a sentence
    label_text = summarize_labels(labels)  # Summarize findings
    if label_text == "normal":
        return f"A chest X-ray of a patient with no findings."
    if label_text == "unchanged":
        return f"A chest X-ray of a patient with unchanged findings."
    return f"A chest X-ray of a patient with {label_text}."

def create_style_rich_desc(sample):
    """
    Creates a style-rich descriptor with more context
    sample: Shape [filename, img_array, orientation, labels]
    """
    filename, img_array, orientation, labels = sample

    # Convert label string to a proper list
    labels = eval(labels) if isinstance(labels, str) else labels
    # Join labels into a sentence
    label_text = summarize_labels(labels)  # Summarize findings
    if label_text == "normal":
        return f"A chest X-ray of a patient with no findings taken in {orientation} orientation."
    if label_text == "unchanged":
        return f"A chest X-ray of a patient with unchanged findings taken in {orientation} orientation."
    return f"A chest X-ray of a patient with {label_text}, taken in {orientation} orientation."

def load_data_add_descriptions(pickle_filename):
    """ Adds description strings to the dataset which are later turned into embeddings """
    with open(pickle_filename, "rb") as f:
        dataset = pickle.load(f)
        imgs_w_desc = list()
        for sample in dataset[:10]: # just process 10 for faster testing
            neutral = create_neutral_desc(sample)
            style_rich = create_style_rich_desc(sample)
            imgs_w_desc.append([sample[0], sample[1], neutral, style_rich]) #filename, img, neutral desc, syle rich desc

    for sample in imgs_w_desc[:10]:
        print(sample[0], sample[2], sample[3])
    return imgs_w_desc

"""
    Create embeddings
"""
def add_embeddings(imgs_w_desc, batch_size):
    """ turn the description strings and images into embeddings
    imgs_w_desc: Shape [[filename, img_array, neutral_desc, style_rich_desc]]
    """
    clip_image_embedder = FrozenClipImageEmbedder()
    clip_text_embedder = FrozenClipTextEmbedder()

    # Freeze models for inference
    clip_image_embedder.eval()
    clip_text_embedder.eval()

    embedded_data = []

    with torch.no_grad():  # No gradients needed
        for i in range(0, len(imgs_w_desc), batch_size):
            batch_samples = imgs_w_desc[i:i + batch_size]
            filenames = [sample[0] for sample in batch_samples]
            images = [sample[1] for sample in batch_samples]
            neutral_descs = [sample[2] for sample in batch_samples]
            style_descs = [sample[3] for sample in batch_samples]

            # Process the batch
            neutral_embeddings =  []
            for neutral_desc in neutral_descs:
              print(neutral_desc)
              neutral_embedding = clip_text_embedder.encode(neutral_desc).to(DEVICE)
              neutral_embeddings.append(neutral_embedding)
            style_embeddings = []
            for style_desc in style_descs:
              print(style_desc)
              style_embedding = clip_text_embedder.encode(style_desc).to(DEVICE)
              style_embeddings.append(style_embedding)
            # Process each image individually
            img_embeddings = []
            for image in images:
              # Process each image separately
              img_embedding = clip_image_embedder.encode(image).to(DEVICE)
              img_embeddings.append(img_embedding)
            print(len(batch_samples))
            for idx, sample in enumerate(batch_samples):
              print(sample)
              print(len(img_embeddings), len(neutral_embeddings), len(style_embeddings))
              embedded_data.append([filenames[idx], img_embeddings[idx], neutral_embeddings[idx], style_embeddings[idx]])
              print(embedded_data[idx])

    return embedded_data


"""
    Soft combination of embeddings according to Disentanglement paper (Wu et al., 2022)
"""
def soft_combine_embeddings(c0, c1, lambda_t):
    """
    c0 (Tensor) The neutral text embedding. Shape: [batch, dim]
    c1 (Tensor): The style-rich text embedding. Shape: [batch, dim]
    lambda_t (Tensor): The combination weight (0 to 1). Shape: [T] or [T, 1]

    c_t (Tensor): The combined embedding over time. Shape: [T, batch, dim]
    """
    # Ensure lambda_t has correct shape for broadcasting
    lambda_t = lambda_t.view(-1, 1, 1)  # Shape: [T, 1, 1]

    # Linearly combine the embeddings over time
    c_t = lambda_t * c1 + (1 - lambda_t) * c0  # Shape: [T, batch, dim]

    return c_t

def get_lambda_schedule(T, mode="linear"):
    """
    Generates a lambda schedule over T timesteps
    """
    if mode == "linear":
        return torch.linspace(0, 1, steps=T)  # Linearly increasing
    elif mode == "sigmoid":
        x = torch.linspace(-6, 6, steps=T)  # Sigmoid range
        return torch.sigmoid(x)  # Smooth start and end
    elif mode == "cosine":
        return (1 - torch.cos(torch.linspace(0, 3.1416, steps=T))) / 2  # Cosine ease-in-out
    else:
        raise ValueError("Invalid mode! Choose 'linear', 'sigmoid', or 'cosine'.")


def test_soft_combined_embeddings(embedded_data_list, T=50):
    """
    embedded_data_list: Shape [[filename, img_embedding, neutral_desc_embedding, style_rich_desc_embedding]]
    """
    for sample in embedded_data_list:
        lambda_t = get_lambda_schedule(T, mode="sigmoid")  # TODO: Try different schedules!

        # Compute the soft combination of embeddings
        c_t = soft_combine_embeddings(sample[2], sample[3], lambda_t)
        #print(c_t)


In [5]:
import torch.nn.functional as F
import torchvision.transforms as transforms
import torch
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

def compute_losses(img_neutral, img_interpolated, img_stylized, text_neutral_emb, text_stylized_emb):
    """
    Computes CLIP loss + perceptual loss for disentanglement training.
    """
    # this loss makes no sense because we only have one image

    # Define transformation
    to_tensor = transforms.ToTensor()

    # Convert images to tensors if they are not already
    if not isinstance(img_neutral, torch.Tensor):
        img_neutral = to_tensor(img_neutral).to(DEVICE)

    if not isinstance(img_interpolated, torch.Tensor):
        img_interpolated = to_tensor(img_interpolated).to(DEVICE)

    if not isinstance(img_stylized, torch.Tensor):
        img_stylized = to_tensor(img_stylized).to(DEVICE)

    # Compute CLIP loss
    def clip_loss(img_emb_neutral, img_emb_interpolated, img_emb_stylized, text_neutral_emb, text_stylized_emb):
        direction_text = text_stylized_emb - text_neutral_emb
        direction_image = img_emb_stylized - img_emb_interpolated
        return -F.cosine_similarity(direction_text, direction_image, dim=-1).mean()

    # Compute perceptual loss (ensuring Xₜ remains semantically close to X₀)
    def perceptual_loss(img_neutral, img_interpolated):
        return F.l1_loss(img_neutral, img_interpolated)

    beta = 0.5  # Adjust as needed
    loss = clip_loss(img_neutral, img_interpolated, img_stylized, text_neutral_emb, text_stylized_emb) + beta * perceptual_loss(img_neutral, img_interpolated)

    return loss

def clip_loss(img_emb_interpolated, text_neutral_emb, text_stylized_emb, alpha=0.5):
        # Define transformation
    to_tensor = transforms.ToTensor()

    # Convert images to tensors if they are not already
    if not isinstance(text_neutral_emb, torch.Tensor):
        img_neutral = to_tensor(img_neutral).to(DEVICE)

    if not isinstance(img_emb_interpolated, torch.Tensor):
        img_interpolated = to_tensor(img_interpolated).to(DEVICE)

    if not isinstance(text_stylized_emb, torch.Tensor):
        img_stylized = to_tensor(img_stylized).to(DEVICE)
    sim_neutral = F.cosine_similarity(img_emb_interpolated, text_neutral_emb, dim=-1)
    sim_stylized = F.cosine_similarity(img_emb_interpolated, text_stylized_emb, dim=-1)

    # Encourage similarity to both, but prioritize stylized alignment
    return -((1 - alpha) * sim_neutral.mean() + alpha * sim_stylized.mean())



In [6]:
import sys
import os
from diffusers import StableDiffusionPipeline
import torch
from matplotlib import pyplot as plt
import random
from sklearn.model_selection import train_test_split
from torchvision import transforms

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Disable NSFW checker in pipeline since some of the chest xrays are accidentally flagged which returns a black image
def dummy_checker(images, **kwargs):
    return images, [False] * len(images)# Always return images without flagging them

def generate_images_from_embeddings_visualize(embedded_data_list, T=20, model_id="Nihirc/Prompt2MedImage"):
    """
    Generates and visualizes images using soft-combined embeddings across denoising steps.
    embedded_data_list: [[filename, img_embedding, neutral_desc_embedding, style_rich_desc_embedding]]
    """

    # Load the pre-trained diffusion model
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to(DEVICE)
    pipe.safety_checker = dummy_checker

    all_generated_images = {}  # Store images per sample
    for sample in embedded_data_list:
        filename = sample[0]  # Get filename for reference
        lambda_t = get_lambda_schedule(T, mode="sigmoid").to(DEVICE)
        c_t = soft_combine_embeddings(sample[2], sample[3], lambda_t).to(DEVICE)  # Soft combination

        generated_images = []  # Store images for this sample

        for t in range(T):  # Simulating diffusion steps
            with torch.no_grad():  # No gradients needed for inference
                empty_negative_prompt = torch.zeros_like(c_t[t]).to(DEVICE) # dummy because pipeline expects 2 embeddings

                img = pipe(prompt_embeds=c_t[t].unsqueeze(0), negative_prompt_embeds=empty_negative_prompt.unsqueeze(0), num_inference_steps=20).images[0]
                generated_images.append(img)  # Store generated image

        all_generated_images[filename] = generated_images  # Store all images

        # Visualizing the generated images
        plt.figure(figsize=(10, 2))
        for i in range(min(5, T)):  # Show up to 5 images for preview
            plt.subplot(1, 5, i + 1)
            plt.imshow(generated_images[i])
            plt.axis("off")
        plt.suptitle(f"Generated Images for {filename}")
        plt.show()

    return all_generated_images  # Return all images for further processing


def generate_images_from_embeddings(embedded_data_list, T, model_id="Nihirc/Prompt2MedImage"):
    """
    Generates images using soft-combined embeddings across denoising steps.
    embedded_data_list: [[filename, img_embedding, neutral_desc_embedding, style_rich_desc_embedding]]
    """

    # Load the pre-trained diffusion model
    pipe = StableDiffusionPipeline.from_pretrained(model_id).to(DEVICE)
    pipe.safety_checker = dummy_checker

    all_generated_images = {}  # Store images per sample
    for sample in embedded_data_list:
        filename = sample[0]  # Get filename for reference
        print("generate images: processing file ", filename)
        lambda_t = get_lambda_schedule(T, mode="sigmoid").to(DEVICE)
        c_t = soft_combine_embeddings(sample[2], sample[3], lambda_t).to(DEVICE)  # Soft combination

        generated_images = []  # Store images for this sample

        for t in range(T):  # Simulating diffusion steps
            with torch.no_grad():  # No gradients needed for inference
                empty_negative_prompt = torch.zeros_like(c_t[t]).to(DEVICE) # dummy because pipeline expects 2 embeddings

                img = pipe(prompt_embeds=c_t[t].unsqueeze(0), negative_prompt_embeds=empty_negative_prompt.unsqueeze(0), num_inference_steps=20).images[0]
                generated_images.append(img)  # Store generated image

        all_generated_images[filename] = generated_images  # Store all images

        return generated_images  # Return list of images across diffusion steps

def plot_losses(train_losses, test_losses):
    """
    Plots training and testing loss curves over epochs.
    """
    plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss", marker='o', color='b')
    plt.plot(range(1, len(test_losses) + 1), test_losses, label="Test Loss", marker='s', color='r')

    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("Training vs Testing Loss")
    plt.legend()
    plt.grid(True)
    plt.show()

def train_lambda(embedded_data_list, T, epochs=10, lr=1e-3):
    """
    Optimizes lambda_t for better disentanglement.
    """
    # Split data into training and testing sets
    clip_image_embedder = FrozenClipImageEmbedder()
    # Trainable λₜ
    lambda_t = get_lambda_schedule(T, mode="sigmoid").to(DEVICE).requires_grad_()

    # Optimizer
    optimizer = torch.optim.Adam([lambda_t], lr=lr)

    # Store losses for plotting
    train_losses = []
    test_losses = []
    print("Start training loop")
    for epoch in range(epochs):
        total_train_loss = 0
        total_test_loss = 0

        # Training loop
        for sample in embedded_data_list:
            print("Training loop: process sample ", sample[0])
            print(sample)
            generated_images = generate_images_from_embeddings([sample], T)
            img_interpolated = generated_images[-1]
            desc_neutral, desc_stylized = sample[2], sample[3]

            # Define a transformation to convert PIL image -> Tensor
            transform = transforms.Compose([
                transforms.ToTensor(),  # Converts PIL image to tensor
                transforms.Resize((224, 224)),  # Resize to CLIP input size
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize
            ])

            # Apply transformation
            img_interpolated = transform(img_interpolated).to(DEVICE)

            # Now pass it to the CLIP embedder
            img_interpolated = clip_image_embedder.forward(img_interpolated).squeeze(0)

            # TODO fix this
            #loss = compute_losses(sample[1], img_interpolated_emb, sample[1], sample[2], sample[3])
            loss = clip_loss(img_interpolated, desc_neutral, desc_stylized)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_train_loss:.4f}")

    return lambda_t


def train_lambda_train_test(embedded_data_list, T, save_path, epochs=10, lr=1e-3, batch_size=8):
    """
    Optimizes lambda_t using train/test split and batches.
    """
    # Split data into training and testing sets
    clip_image_embedder = FrozenClipImageEmbedder()
    clip_text_embedder = FrozenClipTextEmbedder()

    # Trainable λₜ
    lambda_t = get_lambda_schedule(T, mode="sigmoid").to(DEVICE).requires_grad_()

    # Optimizer
    optimizer = torch.optim.Adam([lambda_t], lr=lr)

    # Store losses for plotting
    train_losses = []
    test_losses = []
    print("Start training loop")

    # Split the embedded data into train and test sets
    num_samples = len(embedded_data_list)
    train_size = int(0.8 * num_samples)  # 80% for training
    test_size = num_samples - train_size  # 20% for testing

    # Split the embedded data into train and test sets
    train_data = embedded_data_list[:train_size]
    test_data = embedded_data_list[train_size:]

    # Define a transformation to convert PIL image -> Tensor
    transform = transforms.Compose([
         transforms.ToTensor(),  # Converts PIL image to tensor
          transforms.Resize((224, 224)),  # Resize to CLIP input size
          transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize
            ])
    # Train loop with batch processing
    for epoch in range(epochs):
        total_train_loss = 0
        total_test_loss = 0

        # Batching for training
        for i in range(0, len(train_data), batch_size):
            batch_samples = train_data[i:i + batch_size]
            filenames = [sample[0] for sample in batch_samples]
            neutral_desc_embeddings = [sample[2] for sample in batch_samples]
            stylized_desc_embeddings = [sample[3] for sample in batch_samples]

            # Generate embeddings for the batch (process each sample one by one)
            batch_lambda_t = get_lambda_schedule(T, mode="sigmoid").to(DEVICE)
            batch_c_t = []
            for sample in batch_samples:
                neutral_embedding = sample[2]
                stylized_embedding = sample[3]
                c_t = soft_combine_embeddings(neutral_embedding, stylized_embedding, batch_lambda_t)
                batch_c_t.append(c_t.to(DEVICE))  # Process one at a time

            # Generate images for each sample in the batch (individual call to generate_images_from_embeddings)
            generated_images = []
            print("Train: generate images")
            for sample in batch_samples:
                # Generate images for each sample as done previously (using the generate_images_from_embeddings function)
                generated_images_for_sample = generate_images_from_embeddings([sample], T)
                img_interpolated = generated_images_for_sample[-1]
                img_interpolated = transform(img_interpolated).to(DEVICE)

                generated_images.append(img_interpolated)

            # Calculate loss for the batch
            batch_loss = 0
            print("Train: calculate loss")
            for idx, sample in enumerate(batch_samples):
                img_interpolated = generated_images[idx]
                desc_neutral, desc_stylized = sample[2], sample[3]

                #img_interpolated = transforms.ToTensor()(img_interpolated).to_device(DEVICE)#unsqueeze(0).to(DEVICE)
                img_interpolated = clip_image_embedder.forward(img_interpolated).squeeze(0).to(DEVICE)

                # Compute loss for this sample
                loss = clip_loss(img_interpolated, desc_neutral, desc_stylized)
                batch_loss += loss.item()

            # Backpropagate
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            total_train_loss += batch_loss.item()

        # Testing loop (Batch processing for testing)
        for i in range(0, len(test_data), batch_size):
            batch_samples = test_data[i:i + batch_size]
            filenames = [sample[0] for sample in batch_samples]
            neutral_desc_embeddings = [sample[2] for sample in batch_samples]
            stylized_desc_embeddings = [sample[3] for sample in batch_samples]

            # Generate embeddings for the batch (process each sample one by one)
            batch_lambda_t = get_lambda_schedule(T, mode="sigmoid").to(DEVICE)
            batch_c_t = []
            for sample in batch_samples:
                neutral_embedding = sample[2]
                stylized_embedding = sample[3]
                c_t = soft_combine_embeddings(neutral_embedding, stylized_embedding, batch_lambda_t)
                batch_c_t.append(c_t.to(DEVICE))  # Process one at a time

            # Generate images for each sample in the batch (individual call to generate_images_from_embeddings)
            generated_images = []
            print("Test: generate images")
            for sample in batch_samples:
                # Generate images for each sample as done previously (using the generate_images_from_embeddings function)
                generated_images_for_sample = generate_images_from_embeddings([sample], T)
                img_interpolated = generated_images_for_sample[-1]
                img_interpolated = transform(img_interpolated).to(DEVICE)
                generated_images.append(img_interpolated)

            # Calculate loss for the batch
            batch_loss = 0
            print("Test: calculate loss")
            for idx, sample in enumerate(batch_samples):
                img_interpolated = generated_images[idx]
                desc_neutral, desc_stylized = sample[2], sample[3]

                #img_interpolated = transforms.ToTensor()(img_interpolated).to(DEVICE)#unsqueeze(0).to(DEVICE)
                img_interpolated = clip_image_embedder.forward(img_interpolated).squeeze(0).to(DEVICE)

                # Compute loss for this sample
                loss = clip_loss(img_interpolated, desc_neutral, desc_stylized)
                batch_loss += loss.item()

            total_test_loss += batch_loss.item()

        # Print training and testing losses
        avg_train_loss = total_train_loss / len(train_data)
        avg_test_loss = total_test_loss / len(test_data)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}")

        # Store average losses per epoch
        train_losses.append(avg_train_loss)
        test_losses.append(avg_test_loss)

    # Save model and losses
    save_dict = {
        "lambda_t": lambda_t.detach().cpu(),  # Move to CPU before saving
        "optimizer_state": optimizer.state_dict(),
        "train_losses": train_losses,
        "test_losses": test_losses
    }
    torch.save(save_dict, save_path)
    print(f"Model saved at: {save_path}")
    plot_losses(train_losses, test_losses)

    return lambda_t, train_losses, test_losses


def load_lambda_model(load_path=checkpoint_path):
    """
    Loads the saved lambda model.
    """
    if not os.path.exists(load_path):
        print(f"No saved model found at {load_path}!")
        return None, None

    checkpoint = torch.load(load_path, map_location=DEVICE)
    lambda_t = checkpoint["lambda_t"].to(DEVICE).requires_grad_()
    optimizer = torch.optim.Adam([lambda_t])  # Recreate optimizer
    optimizer.load_state_dict(checkpoint["optimizer_state"])

    print(f"Model loaded from: {load_path}")

    return lambda_t, optimizer



In [None]:
pickle_filename = 'dataset_small.pkl'
batch_size = 4
data_with_desc = load_data_add_descriptions(pickle_filename)
data_w_embeddings = add_embeddings(data_with_desc, batch_size=batch_size)
T=5
test = data_w_embeddings[0]
#(test[0])
#print(test[1])
#print(test[2])
#print(test[3])
#generate_images_from_embeddings_visualize(data_w_embeddings)
train_lambda_train_test(data_w_embeddings, T, batch_size = batch_size, save_path=checkpoint_path)
#train_lambda([data_w_embeddings[0]], T) # test only one for faster testing

20536686640136348236148679891455886468_k6ga29.png A chest X-ray of a patient with no findings. A chest X-ray of a patient with no findings taken in PA orientation.
135803415504923515076821959678074435083_fzis7d.png A chest X-ray of a patient with pulmonary fibrosis, chronic changes, kyphosis, and 2 other findings. A chest X-ray of a patient with pulmonary fibrosis, chronic changes, kyphosis, and 2 other findings, taken in L orientation.
135803415504923515076821959678074435083_fzis7b.png A chest X-ray of a patient with pulmonary fibrosis, chronic changes, kyphosis, and 2 other findings. A chest X-ray of a patient with pulmonary fibrosis, chronic changes, kyphosis, and 2 other findings, taken in PA orientation.
113855343774216031107737439268243531979_3k951l.png A chest X-ray of a patient with chronic changes. A chest X-ray of a patient with chronic changes, taken in PA orientation.
113855343774216031107737439268243531979_3k951n.png A chest X-ray of a patient with chronic changes. A chest

100%|███████████████████████████████████████| 890M/890M [00:14<00:00, 66.3MiB/s]


A chest X-ray of a patient with no findings.
A chest X-ray of a patient with pulmonary fibrosis, chronic changes, kyphosis, and 2 other findings.
A chest X-ray of a patient with pulmonary fibrosis, chronic changes, kyphosis, and 2 other findings.
A chest X-ray of a patient with chronic changes.
A chest X-ray of a patient with no findings taken in PA orientation.
A chest X-ray of a patient with pulmonary fibrosis, chronic changes, kyphosis, and 2 other findings, taken in L orientation.
A chest X-ray of a patient with pulmonary fibrosis, chronic changes, kyphosis, and 2 other findings, taken in PA orientation.
A chest X-ray of a patient with chronic changes, taken in PA orientation.
ndims: 2  shape: torch.Size([1728, 1872])
first if ndims: 3  shape: torch.Size([3, 1728, 1872])
ndims: 2  shape: torch.Size([3296, 3236])
first if ndims: 3  shape: torch.Size([3, 3296, 3236])
ndims: 2  shape: torch.Size([3572, 3732])
first if ndims: 3  shape: torch.Size([3, 3572, 3732])
ndims: 2  shape: torch

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model_index.json:   0%|          | 0.00/577 [00:00<?, ?B/s]

Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]

scheduler_config.json:   0%|          | 0.00/341 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/5.00k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/518 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/912 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.60k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/629 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/167M [00:00<?, ?B/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

generate images: processing file  20536686640136348236148679891455886468_k6ga29.png


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

generate images: processing file  135803415504923515076821959678074435083_fzis7d.png


  0%|          | 0/20 [00:00<?, ?it/s]