<a href="https://colab.research.google.com/github/kkhaledaawad/Roomify-AI/blob/main/Roomify_GAN_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Environment setup**

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!python --version

Python 3.11.12


In [None]:
from google.colab import files
files.upload()
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets list

Saving kaggle.json to kaggle.json
ref                                                                  title                                                     size  lastUpdated                 downloadCount  voteCount  usabilityRating  
-------------------------------------------------------------------  --------------------------------------------------  ----------  --------------------------  -------------  ---------  ---------------  
atharvasoundankar/chocolate-sales                                    Chocolate Sales Data 📊🍫                                  14473  2025-03-19 03:51:40.270000          11523        203  1.0              
adilshamim8/student-depression-dataset                               Student Depression Dataset                              467020  2025-03-13 03:12:30.423000           3881         63  1.0              
abdulmalik1518/mobiles-dataset-2025                                  Mobiles Dataset (2025)                                   20314  2025-02-18 06

# **COCO2017 dataset preperation**

In [None]:
import json
import pandas as pd

# Path to your annotations file
coco_caption_path = "/content/drive/MyDrive/Roomify/data/coco/annotations/captions_train2017.json"

# Load JSON
with open(coco_caption_path, 'r') as f:
    coco_data = json.load(f)

# Map image_id to file_name
image_id_to_file = {img['id']: img['file_name'] for img in coco_data['images']}

# Extract caption entries
rows = []
for ann in coco_data['annotations']:
    image_id = ann['image_id']
    caption = ann['caption']
    image_name = image_id_to_file[image_id]
    rows.append([image_name, caption, ""])  # No mask path for COCO

# Save to CSV
df = pd.DataFrame(rows, columns=["image_name", "prompt_text", "mask_path"])
csv_path = "/content/drive/MyDrive/Roomify/data/coco/prompts.csv"
df.to_csv(csv_path, index=False)

print(f"✅ Saved {len(df)} caption entries to {csv_path}")


✅ Saved 591753 caption entries to /content/drive/MyDrive/Roomify/data/coco/prompts.csv


In [None]:
df_coco = pd.read_csv("/content/drive/MyDrive/Roomify/data/coco/prompts.csv")
print(df_coco.head(3))

         image_name                                        prompt_text  \
0  000000203564.jpg  A bicycle replica with a clock as the front wh...   
1  000000322141.jpg  A room with blue walls and a white sink and door.   
2  000000016977.jpg  A car that seems to be parked illegally behind...   

   mask_path  
0        NaN  
1        NaN  
2        NaN  


In [None]:
!ls "/content/drive/MyDrive/Roomify/data/coco/images/train2017" | head -n 5

000000000009.jpg
000000000025.jpg
000000000030.jpg
000000000034.jpg
000000000049.jpg


In [None]:
import os
import pandas as pd

# COCO image folder
coco_dir = "/content/drive/MyDrive/Roomify/data/coco/images/train2017"
output_csv = "/content/drive/MyDrive/Roomify/data/coco/prompts.csv"

rows = []
for file in os.listdir(coco_dir):
    if file.endswith(".jpg"):
        image_name = f"coco/images/train2017/{file}"
        prompt = "A realistic indoor scene"  # (optional: you can update this later)
        rows.append([image_name, prompt, ""])

df = pd.DataFrame(rows, columns=["image_name", "prompt_text", "mask_path"])
df.to_csv(output_csv, index=False)

print(f"✅ Fixed COCO prompts.csv with {len(df)} entries saved to: {output_csv}")

✅ Fixed COCO prompts.csv with 76783 entries saved to: /content/drive/MyDrive/Roomify/data/coco/prompts.csv


# **ADE20K dataset downloading and preperation**

In [None]:
!kaggle datasets download -d awsaf49/ade20k-dataset

Dataset URL: https://www.kaggle.com/datasets/awsaf49/ade20k-dataset
License(s): unknown


In [None]:
!mkdir -p /content/drive/MyDrive/Roomify/data/ade20k/
!mv /content/ade20k-dataset.zip /content/drive/MyDrive/Roomify/data/ade20k/

In [None]:
!unzip -q /content/drive/MyDrive/Roomify/data/ade20k/ade20k-dataset.zip -d /content/drive/MyDrive/Roomify/data/ade20k/

In [None]:
!ls /content/drive/MyDrive/Roomify/data/ade20k/ADEChallengeData2016/images/training | head

ADE_train_00000001.jpg
ADE_train_00000002.jpg
ADE_train_00000003.jpg
ADE_train_00000004.jpg
ADE_train_00000005.jpg
ADE_train_00000006.jpg
ADE_train_00000007.jpg
ADE_train_00000008.jpg
ADE_train_00000009.jpg
ADE_train_00000010.jpg


In [None]:
import os
import pandas as pd

# Paths
images_dir = "/content/drive/MyDrive/Roomify/data/ade20k/ADEChallengeData2016/images/training"
masks_dir = "/content/drive/MyDrive/Roomify/data/ade20k/ADEChallengeData2016/annotations/training"
output_csv = "/content/drive/MyDrive/Roomify/data/ade20k/prompts.csv"

# Build rows
data = []
for file in os.listdir(images_dir):
    if file.endswith(".jpg"):
        image_name = file
        mask_name = file.replace(".jpg", ".png")
        mask_path = f"ADEChallengeData2016/annotations/training/{mask_name}"
        prompt_text = "A scene with detailed semantic layout."  # Placeholder
        data.append([image_name, prompt_text, mask_path])

# Save CSV
df = pd.DataFrame(data, columns=["image_name", "prompt_text", "mask_path"])
df.to_csv(output_csv, index=False)

print(f"✅ Saved {len(df)} entries to prompts.csv for ADE20K")


✅ Saved 20210 entries to prompts.csv for ADE20K


In [None]:
df_ade20k = pd.read_csv("/content/drive/MyDrive/Roomify/data/ade20k/prompts.csv")
print(df_coco.head(3))

         image_name                                        prompt_text  \
0  000000203564.jpg  A bicycle replica with a clock as the front wh...   
1  000000322141.jpg  A room with blue walls and a white sink and door.   
2  000000016977.jpg  A car that seems to be parked illegally behind...   

   mask_path  
0        NaN  
1        NaN  
2        NaN  


In [None]:
import os
import pandas as pd

img_dir = "/content/drive/MyDrive/Roomify/data/ade20k/ADEChallengeData2016/images/training"
mask_dir = "/content/drive/MyDrive/Roomify/data/ade20k/ADEChallengeData2016/annotations/training"
output_csv = "/content/drive/MyDrive/Roomify/data/ade20k/prompts.csv"

rows = []
for file in os.listdir(img_dir):
    if file.endswith(".jpg"):
        image_path = f"ade20k/ADEChallengeData2016/images/training/{file}"
        mask_path = f"ade20k/ADEChallengeData2016/annotations/training/{file.replace('.jpg', '_seg.png')}"
        prompt = "A scene with detailed semantic layout"
        rows.append([image_path, prompt, mask_path])

df = pd.DataFrame(rows, columns=["image_name", "prompt_text", "mask_path"])
df.to_csv(output_csv, index=False)

print(f"✅ Fixed ADE20K prompts.csv with {len(df)} entries saved to: {output_csv}")


✅ Fixed ADE20K prompts.csv with 20210 entries saved to: /content/drive/MyDrive/Roomify/data/ade20k/prompts.csv


# **Furniture Dataset preperation**

In [None]:
!kaggle datasets download -d udaysankarmukherjee/furniture-image-dataset -p /content/drive/MyDrive/Roomify/data/furniture

Dataset URL: https://www.kaggle.com/datasets/udaysankarmukherjee/furniture-image-dataset
License(s): apache-2.0


In [None]:
!unzip -q /content/drive/MyDrive/Roomify/data/furniture/furniture-image-dataset.zip -d /content/drive/MyDrive/Roomify/data/furniture/

In [None]:
!ls /content/drive/MyDrive/Roomify/data/furniture/ | head

almirah_dataset
chair_dataset
fridge dataset
furniture-image-dataset.zip
table dataset
tv dataset


In [None]:
import os
import pandas as pd

root_dir = "/content/drive/MyDrive/Roomify/data/furniture"
output_csv = f"{root_dir}/prompts.csv"

rows = []
for category in os.listdir(root_dir):
    category_path = os.path.join(root_dir, category)
    if os.path.isdir(category_path) and not category.endswith(".zip"):
        for file in os.listdir(category_path):
            if file.lower().endswith((".jpg", ".jpeg", ".png")):
                image_path = f"{category}/{file}"
                clean_category = category.replace("_", " ").replace("dataset", "").strip()
                prompt = f"A modern {clean_category}"
                rows.append([image_path, prompt, ""])

df = pd.DataFrame(rows, columns=["image_name", "prompt_text", "mask_path"])
df.to_csv(output_csv, index=False)

print(f"✅ Saved {len(df)} entries to: {output_csv}")


✅ Saved 15000 entries to: /content/drive/MyDrive/Roomify/data/furniture/prompts.csv


In [None]:
import os
import pandas as pd

root_dir = "/content/drive/MyDrive/Roomify/data/furniture"
output_csv = f"{root_dir}/prompts.csv"

rows = []
for category in os.listdir(root_dir):
    category_path = os.path.join(root_dir, category)
    if os.path.isdir(category_path) and not category.endswith(".zip"):
        for file in os.listdir(category_path):
            if file.lower().endswith((".jpg", ".jpeg", ".png")):
                image_path = f"furniture/{category}/{file}"
                prompt = f"A modern {category.replace('_', ' ').replace('dataset', '').strip()}"
                rows.append([image_path, prompt, ""])

df = pd.DataFrame(rows, columns=["image_name", "prompt_text", "mask_path"])
df.to_csv(output_csv, index=False)

print(f"✅ Fixed Furniture prompts.csv with {len(df)} entries saved to: {output_csv}")


✅ Fixed Furniture prompts.csv with 15000 entries saved to: /content/drive/MyDrive/Roomify/data/furniture/prompts.csv


# **Interior Design dataset**

In [None]:
!kaggle datasets download -d aishahsofea/interior-design -p /content/drive/MyDrive/Roomify/data/interior_design

Dataset URL: https://www.kaggle.com/datasets/aishahsofea/interior-design
License(s): copyright-authors


In [None]:
!unzip -q /content/drive/MyDrive/Roomify/data/interior_design/interior-design.zip -d /content/drive/MyDrive/Roomify/data/interior_design

In [None]:
!ls /content/drive/MyDrive/Roomify/data/interior_design | head

interior-design.zip
resized_images


In [None]:
import os
import pandas as pd

img_dir = "/content/drive/MyDrive/Roomify/data/interior_design/resized_images"
output_csv = "/content/drive/MyDrive/Roomify/data/interior_design/prompts.csv"

rows = []
for file in os.listdir(img_dir):
    if file.lower().endswith((".jpg", ".jpeg", ".png")):
        image_name = f"interior_design/resized_images/{file}"
        prompt = "A beautifully designed interior room"
        rows.append([image_name, prompt, ""])

df = pd.DataFrame(rows, columns=["image_name", "prompt_text", "mask_path"])
df.to_csv(output_csv, index=False)

print(f"✅ Fixed Interior Design prompts.csv with {len(df)} entries saved to: {output_csv}")


✅ Fixed Interior Design prompts.csv with 4147 entries saved to: /content/drive/MyDrive/Roomify/data/interior_design/prompts.csv


# **SUN_RGBD Dataset**

In [None]:
!kaggle datasets download -d thanhbnhphan/sun-rgbd-2d -p /content/drive/MyDrive/Roomify/data/sun_rgbd

Dataset URL: https://www.kaggle.com/datasets/thanhbnhphan/sun-rgbd-2d
License(s): MIT


In [None]:
!unzip -q /content/drive/MyDrive/Roomify/data/sun_rgbd/sun-rgbd-2d.zip -d /content/drive/MyDrive/Roomify/data/sun_rgbd

In [None]:
!ls /content/drive/MyDrive/Roomify/data/sun_rgbd/MYSUN | head

depth
depth_bfx
image
info.json


In [None]:
import os
import pandas as pd

img_dir = "/content/drive/MyDrive/Roomify/data/sun_rgbd/MYSUN/image"
output_csv = "/content/drive/MyDrive/Roomify/data/sun_rgbd/prompts.csv"

rows = []
for file in os.listdir(img_dir):
    if file.lower().endswith((".jpg", ".jpeg", ".png")):
        image_name = f"sun_rgbd/MYSUN/image/{file}"
        prompt = "An indoor scene with depth and layout"
        rows.append([image_name, prompt, ""])

df = pd.DataFrame(rows, columns=["image_name", "prompt_text", "mask_path"])
df.to_csv(output_csv, index=False)

print(f"✅ Fixed SUN RGB-D prompts.csv with {len(df)} entries saved to: {output_csv}")

# **Merge all datasets prompts**

In [None]:
import pandas as pd

paths = {
    "coco": "/content/drive/MyDrive/Roomify/data/coco/prompts.csv",
    "ade20k": "/content/drive/MyDrive/Roomify/data/ade20k/prompts.csv",
    "furniture": "/content/drive/MyDrive/Roomify/data/furniture/prompts.csv",
    "interior": "/content/drive/MyDrive/Roomify/data/interior_design/prompts.csv",
    "sunrgbd": "/content/drive/MyDrive/Roomify/data/sun_rgbd/prompts.csv",
}

df_all = []

for name, path in paths.items():
    df = pd.read_csv(path)
    df = df.dropna(subset=["image_name", "prompt_text"])
    df["source"] = name
    df_all.append(df)

df_merged = pd.concat(df_all, ignore_index=True)

# Final path
final_csv = "/content/drive/MyDrive/Roomify/data/unified_prompts.csv"
df_merged.to_csv(final_csv, index=False)

print(f"✅ Unified prompts.csv created with {len(df_merged)} entries at: {final_csv}")

✅ Unified prompts.csv created with 641445 entries at: /content/drive/MyDrive/Roomify/data/unified_prompts.csv


# **Preprocessing data**

In [None]:
import os
from PIL import Image
import pandas as pd
from tqdm import tqdm

# ==== CONFIG ====
input_csv = "/content/drive/MyDrive/Roomify/data/unified_prompts.csv"
input_base = "/content/drive/MyDrive/Roomify/data"
output_base = "/content/drive/MyDrive/Roomify/data/processed"
output_images = os.path.join(output_base, "images")
output_masks = os.path.join(output_base, "masks")
output_csv = os.path.join(output_base, "unified_prompts.csv")
target_size = (256, 256)

# ==== SETUP ====
os.makedirs(output_images, exist_ok=True)
os.makedirs(output_masks, exist_ok=True)

# ==== LOAD DATA ====
df = pd.read_csv(input_csv)
processed_rows = []

# ==== LOOP ====
for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
    image_rel = row["image_name"]
    prompt = row["prompt_text"]
    mask_rel = row["mask_path"]

    # === IMAGE ===
    image_path = os.path.join(input_base, image_rel)
    try:
        img = Image.open(image_path).convert("RGB").resize(target_size)
        img_name = f"img_{idx:06d}.jpg"
        img.save(os.path.join(output_images, img_name), "JPEG")
    except Exception as e:
        print(f"[Image Error] Skipped row {idx}: {e}")
        continue

    # === MASK ===
    mask_name = ""
    if isinstance(mask_rel, str) and mask_rel.strip():
        mask_path = os.path.join(input_base, mask_rel)
        try:
            mask = Image.open(mask_path).convert("L").resize(target_size)
            mask_name = f"mask_{idx:06d}.png"
            mask.save(os.path.join(output_masks, mask_name), "PNG")
        except Exception as e:
            print(f"[Mask Error] Row {idx} mask skipped: {e}")
            mask_name = ""

    # === RECORD ROW ===
    processed_rows.append([
        f"images/{img_name}",
        prompt,
        f"masks/{mask_name}" if mask_name else ""
    ])

# ==== SAVE CSV ====
df_out = pd.DataFrame(processed_rows, columns=["image_name", "prompt_text", "mask_path"])
df_out.to_csv(output_csv, index=False)

print(f"\n✅ Preprocessing complete.")
print(f"📁 Final CSV: {output_csv}")
print(f"📸 Processed images: {len(df_out)}")

🟢 Resuming from index: 97171


Processing (resumed): 100%|██████████| 24474/24474 [2:16:05<00:00,  3.00it/s]


✅ Resumed preprocessing complete! CSV now has 24474 rows.


In [None]:
df = pd.read_csv("/content/drive/MyDrive/Roomify/data/unified_prompts.csv")
print(df.head(3))
print(df.iloc[0]['image_name'])
print(os.path.exists("/content/drive/MyDrive/Roomify/data/" + df.iloc[0]['image_name']))

         image_name                                        prompt_text  \
0  000000203564.jpg  A bicycle replica with a clock as the front wh...   
1  000000322141.jpg  A room with blue walls and a white sink and door.   
2  000000016977.jpg  A car that seems to be parked illegally behind...   

  mask_path source  
0       NaN   coco  
1       NaN   coco  
2       NaN   coco  
000000203564.jpg
False


  df = pd.read_csv("/content/drive/MyDrive/Roomify/data/unified_prompts.csv")


In [None]:
import pandas as pd

# Original prompts.csv files with correct image paths
paths = {
    "coco": "/content/drive/MyDrive/Roomify/data/coco/prompts.csv",
    "ade20k": "/content/drive/MyDrive/Roomify/data/ade20k/prompts.csv",
    "furniture": "/content/drive/MyDrive/Roomify/data/furniture/prompts.csv",
    "interior": "/content/drive/MyDrive/Roomify/data/interior_design/prompts.csv",
    "sunrgbd": "/content/drive/MyDrive/Roomify/data/sun_rgbd/prompts.csv",
}

# Load and tag each one with its source
df_all = []
for name, path in paths.items():
    df = pd.read_csv(path)
    df["source"] = name
    df_all.append(df)

# Merge into one big CSV again
df_merged = pd.concat(df_all, ignore_index=True)

# Save the clean, fixed version
output_path = "/content/drive/MyDrive/Roomify/data/unified_prompts.csv"
df_merged.to_csv(output_path, index=False)

print(f"✅ Rebuilt unified_prompts.csv with {len(df_merged)} rows.")
print(f"📄 Saved to: {output_path}")


✅ Rebuilt unified_prompts.csv with 641445 rows.
📄 Saved to: /content/drive/MyDrive/Roomify/data/unified_prompts.csv


In [None]:
import os
import pandas as pd

# Load the rebuilt CSV
df = pd.read_csv("/content/drive/MyDrive/Roomify/data/unified_prompts.csv")

# Show first few rows
print(df.head(3))

# Test if the first file exists
test_path = "/content/drive/MyDrive/Roomify/data/" + df.iloc[0]["image_name"]
print("Checking path:", test_path)
print("Exists?", os.path.exists(test_path))


         image_name                                        prompt_text  \
0  000000203564.jpg  A bicycle replica with a clock as the front wh...   
1  000000322141.jpg  A room with blue walls and a white sink and door.   
2  000000016977.jpg  A car that seems to be parked illegally behind...   

  mask_path source  
0       NaN   coco  
1       NaN   coco  
2       NaN   coco  
Checking path: /content/drive/MyDrive/Roomify/data/000000203564.jpg
Exists? False


  df = pd.read_csv("/content/drive/MyDrive/Roomify/data/unified_prompts.csv")


In [None]:
import pandas as pd

paths = {
    "coco": "/content/drive/MyDrive/Roomify/data/coco/prompts.csv",
    "ade20k": "/content/drive/MyDrive/Roomify/data/ade20k/prompts.csv",
    "furniture": "/content/drive/MyDrive/Roomify/data/furniture/prompts.csv",
    "interior": "/content/drive/MyDrive/Roomify/data/interior_design/prompts.csv",
    "sunrgbd": "/content/drive/MyDrive/Roomify/data/sun_rgbd/prompts.csv",
}

df_all = []
for name, path in paths.items():
    df = pd.read_csv(path)
    df = df.dropna(subset=["image_name", "prompt_text"])  # drop broken rows
    df["source"] = name
    df_all.append(df)

df_merged = pd.concat(df_all, ignore_index=True)

# ✅ Make sure paths are relative to `/Roomify/data/`
# DO NOT strip folders!
df_merged.to_csv("/content/drive/MyDrive/Roomify/data/unified_prompts.csv", index=False)

print(f"✅ Fixed unified_prompts.csv written with {len(df_merged)} valid rows.")


✅ Fixed unified_prompts.csv written with 641445 valid rows.


In [None]:
df = pd.read_csv("/content/drive/MyDrive/Roomify/data/unified_prompts.csv")

first_image = df.iloc[0]["image_name"]
full_path = "/content/drive/MyDrive/Roomify/data/" + first_image

print("Image:", first_image)
print("Exists?", os.path.exists(full_path))


Image: 000000203564.jpg
Exists? False


  df = pd.read_csv("/content/drive/MyDrive/Roomify/data/unified_prompts.csv")


In [None]:
# from PIL import Image
# import os
# from torchvision import transforms

# source_dir = "/content/drive/MyDrive/Roomify/data/processed/images"
# target_dir = "/content/drive/MyDrive/Roomify/data/processed_512/images"
# os.makedirs(target_dir, exist_ok=True)

# transform = transforms.Compose([
#     transforms.Resize((512, 512)),
#     transforms.ToTensor()
# ])

# for file in os.listdir(source_dir):
#     if file.endswith((".jpg", ".png")):
#         try:
#             img_path = os.path.join(source_dir, file)
#             image = Image.open(img_path).convert("RGB")
#             img_tensor = transform(image)
#             img_output_path = os.path.join(target_dir, file)
#             transforms.ToPILImage()(img_tensor).save(img_output_path)
#         except Exception as e:
#             print(f"Error processing {file}: {e}")

# **Debugging**

In [None]:
import os
import pandas as pd

# Load CSV
csv_path = "/content/drive/MyDrive/Roomify/data/unified_prompts.csv"
base_dir = "/content/drive/MyDrive/Roomify/data"
df = pd.read_csv(csv_path)

# Validate image and mask paths
missing_images = []
missing_masks = []

for i, row in df.iterrows():
    img_path = os.path.join(base_dir, row["image_name"])
    if not os.path.exists(img_path):
        missing_images.append(row["image_name"])

    mask = row["mask_path"]
    if isinstance(mask, str) and mask.strip():
        mask_path = os.path.join(base_dir, mask)
        if not os.path.exists(mask_path):
            missing_masks.append(mask)

print(f"✅ Valid image paths: {len(df) - len(missing_images)} / {len(df)}")
print(f"🟥 Missing images: {len(missing_images)}")
print(f"🟨 Missing masks: {len(missing_masks)}")


✅ Valid image paths: 126475 / 126475
🟥 Missing images: 0
🟨 Missing masks: 20210


In [None]:
missing_masks = []

for i, row in df.iterrows():
    mask_path = row["mask_path"]
    if isinstance(mask_path, str) and mask_path.strip():
        full_path = os.path.join(base_dir, mask_path)
        if not os.path.exists(full_path):
            missing_masks.append(mask_path)

print(f"❌ Missing masks: {len(missing_masks)}")


❌ Missing masks: 20210


In [None]:
def file_exists(row):
    return os.path.exists(os.path.join(base_dir, row["image_name"]))

df_valid = df[df.apply(file_exists, axis=1)]
df_valid.to_csv(csv_path, index=False)
print(f"✅ Filtered CSV with {len(df_valid)} valid rows saved.")


✅ Filtered CSV with 0 valid rows saved.


# **RoomifyDataset Class**

In [None]:
!pip install transformers



In [None]:
import os
import pandas as pd
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class RoomifyDataset(Dataset):
    def __init__(self, csv_path, root_dir, transform=None, return_mask=True):
        self.data = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.transform = transform
        self.return_mask = return_mask

        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")

        self.default_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        image_path = os.path.join(self.root_dir, row["image_name"])
        prompt_text = row["prompt_text"]
        mask_path = row["mask_path"]

        # Load image
        image = Image.open(image_path).convert("RGB")
        image = self.transform(image) if self.transform else self.default_transform(image)

        # Load mask or return empty tensor
        if self.return_mask and isinstance(mask_path, str) and len(mask_path.strip()) > 0:
            full_mask_path = os.path.join(self.root_dir, mask_path)
            try:
                mask = Image.open(full_mask_path).convert("L")
                mask = transforms.ToTensor()(mask)  # 1xHxW
            except:
                mask = torch.zeros((1, 256, 256))  # fallback
        else:
            mask = torch.zeros((1, 256, 256))

        return {
            "image": image,
            "text": prompt_text,
            "mask": mask,
            "index": idx
        }

In [None]:
import torch

def custom_collate(batch):
    # Filter out any None samples
    batch = [item for item in batch if item is not None]

    # If batch is empty, raise an error
    if len(batch) == 0:
        raise ValueError("All items in the batch are None.")

    images = torch.stack([item["image"] for item in batch])
    texts = [item["text"] for item in batch]

    # Handle masks
    if batch[0]["mask"] is not None:
        masks = torch.stack([item["mask"] for item in batch])
    else:
        masks = torch.zeros((len(batch), 1, 256, 256))  # fallback dummy masks

    return {
        "image": images,
        "text": texts,
        "mask": masks
    }

**data Loader**

In [None]:
from torch.utils.data import DataLoader

# Initialize the dataset
dataset = RoomifyDataset(
    csv_path="/content/drive/MyDrive/Roomify/data/processed/unified_prompts.csv",
    root_dir="/content/drive/MyDrive/Roomify/data/processed/"
)

# Create DataLoader
roomify_loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=2,
    collate_fn=custom_collate
)

# Preview a batch
batch = next(iter(roomify_loader))
print("✅ Batch loaded!")
print("Image batch shape:", batch["image"].shape)     # [B, 3, 256, 256]
print("Text batch:", batch["text"][:2])               # Two sample prompts
print("Mask batch shape:", batch["mask"].shape)       # [B, 1, 256, 256]


✅ Batch loaded!
Image batch shape: torch.Size([16, 3, 256, 256])
Text batch: ['A beautifully designed interior room', 'A modern table']
Mask batch shape: torch.Size([16, 1, 256, 256])


**data module phase**

In [None]:
!pip install pytorch-lightning --quiet

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.0/823.0 kB[0m [31m38.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m107.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m87.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m49.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m36.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import sys
sys.path.append("/content/drive/MyDrive/Roomify")
from datasets.roomify_datamodule import RoomifyDataModule

dm = RoomifyDataModule(
    csv_path="/content/drive/MyDrive/Roomify/data/processed/unified_prompts.csv",
    root_dir="/content/drive/MyDrive/Roomify/data/processed/",
    batch_size=16
)

dm.setup()
loader = dm.train_dataloader()

batch = next(iter(loader))
print("✅ Batch loaded!")
print("Image shape:", batch["image"].shape)
print("Prompt sample:", batch["text"][:2])
print("Mask shape:", batch["mask"].shape)

[Image Error] Row 7070 image failed: [Errno 5] Input/output error: '/content/drive/MyDrive/Roomify/data/processed/images/img_104241.jpg'[Image Error] Row 4752 image failed: [Errno 5] Input/output error: '/content/drive/MyDrive/Roomify/data/processed/images/img_101923.jpg'

[Image Error] Row 16379 image failed: [Errno 5] Input/output error: '/content/drive/MyDrive/Roomify/data/processed/images/img_113550.jpg'[Image Error] Row 3948 image failed: [Errno 5] Input/output error: '/content/drive/MyDrive/Roomify/data/processed/images/img_101119.jpg'

[Image Error] Row 9029 image failed: [Errno 5] Input/output error: '/content/drive/MyDrive/Roomify/data/processed/images/img_106200.jpg'[Image Error] Row 8193 image failed: [Errno 5] Input/output error: '/content/drive/MyDrive/Roomify/data/processed/images/img_105364.jpg'

✅ Batch loaded!
Image shape: torch.Size([16, 3, 256, 256])
Prompt sample: ['[Corrupted image]', '[Corrupted image]']
Mask shape: torch.Size([16, 1, 256, 256])


# **Model Architecture Design**

In [None]:
# ✅ إعادة تثبيت PyTorch متوافقة مع A100 (CUDA 11.8)
!pip uninstall -y torch torchvision torchaudio
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# ✅ إعادة تثبيت CLIP من OpenAI (بعد PyTorch)
!pip install git+https://github.com/openai/CLIP.git

# ✅ إعادة تشغيل الجلسة تلقائيًا (مطلوبة بعد التثبيت)
import os
os.kill(os.getpid(), 9)

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: torchaudio 2.6.0+cu124
Uninstalling torchaudio-2.6.0+cu124:
  Successfully uninstalled torchaudio-2.6.0+cu124
Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting torch
  Downloading https://download.pytorch.org/whl/cu118/torch-2.7.0%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu118/torchvision-0.22.0%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu118/torchaudio-2.7.0%2Bcu118-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading https://download.pyto

In [1]:
!pip install wandb
import wandb
wandb.login()



<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mroomify6[0m ([33mroomify6-cairo-higher-institute-for-engineering[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
!pip install pytorch_lightning

Collecting pytorch_lightning
  Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>=0.7.0 (from pytorch_lightning)
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch_lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl (823 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.1/823.1 kB[0m [31m15.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.14.3-py3-none-any.whl (28 kB)
Downloading torchmetrics-1.7.1-py3-none-any.whl (961 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m961.5/961.5 kB[0m [31m42.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lightning-utilities, torchmetrics, pytorch_lightning
Successfully installed lightning-utilities-0.14.3 pytorch_lightning-2.5.1.post0 torchmetrics-1.7.1


In [3]:
!pip install torchmetrics



In [4]:
!pip install torch-fidelity

Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Installing collected packages: torch-fidelity
Successfully installed torch-fidelity-0.3.0


In [5]:
!pip install torch torchvision clip pillow



In [6]:
!pip install tqdm wandb ftfy regex matplotlib pillow



In [7]:
!pip install torch torchvision wandb tqdm pytorch-msssim lpips

Collecting pytorch-msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Collecting lpips
  Downloading lpips-0.1.4-py3-none-any.whl.metadata (10 kB)
Downloading pytorch_msssim-1.0.0-py3-none-any.whl (7.7 kB)
Downloading lpips-0.1.4-py3-none-any.whl (53 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.8/53.8 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pytorch-msssim, lpips
Successfully installed lpips-0.1.4 pytorch-msssim-1.0.0


In [8]:
!pip install deep-translator

Collecting deep-translator
  Downloading deep_translator-1.11.4-py3-none-any.whl.metadata (30 kB)
Downloading deep_translator-1.11.4-py3-none-any.whl (42 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.3/42.3 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: deep-translator
Successfully installed deep-translator-1.11.4


In [None]:
!python /content/drive/MyDrive/Roomify/inference/inference.py

In [None]:
import sys
sys.path.append("/content/drive/MyDrive/Roomify")
import os
import torch
from torch.utils.data import DataLoader
from datasets.roomify_dataset import RoomifyDataset
from models.generator import ConditionalUNetGenerator
from models.discriminator import RoomifyDiscriminator
from models.clip_encoder import CLIPTextEncoder
from training.trainer import RoomifyGANTrainer

# Set up directories with version number to avoid overwriting
VERSION = "v2_extreme"
CHECKPOINT_DIR = f"/content/drive/MyDrive/Roomify/checkpoints_{VERSION}"
GENERATED_DIR = f"/content/drive/MyDrive/Roomify/generated_{VERSION}"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(GENERATED_DIR, exist_ok=True)

# Set device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Load dataset with caching
dataset = RoomifyDataset(
    csv_path="/content/drive/MyDrive/Roomify/data/processed/unified_prompts.csv",
    root_dir="/content/drive/MyDrive/Roomify/data/processed",
    image_size=(256, 256),
    use_cache=True,  # Enable caching
    cache_size=200   # Cache size
)
dataloader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True  # Avoid batch size issues
)
print(f"Dataset loaded with {len(dataset)} samples")

# Create models
generator = ConditionalUNetGenerator(in_channels=6, out_channels=3).to(DEVICE)
discriminator = RoomifyDiscriminator(in_channels=6, text_dim=512).to(DEVICE)
clip_encoder = CLIPTextEncoder(device=DEVICE)

# Create trainer with extreme settings
trainer = RoomifyGANTrainer(
    generator=generator,
    discriminator=discriminator,
    text_encoder=clip_encoder,
    dataloader=dataloader,
    device=DEVICE,
    g_lr=2e-4,
    d_lr=3e-4,
    ckpt_dir=CHECKPOINT_DIR,
    image_save_dir=GENERATED_DIR,
    project_name=f"roomify-{VERSION}"
)

# Train for 40 epochs
trainer.train(num_epochs=40)

Using device: cuda


In [None]:
!python /content/drive/MyDrive/Roomify/models/compute_FID.py

📥 Loading images...
📸 Fake Images: torch.Size([97, 3, 256, 256])
🏞️ Real Images: torch.Size([500, 3, 256, 256])

📊 Calculating FID Score...

🎯 Final FID Score = 1.8786


In [None]:
import os
import sys
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from datetime import datetime
import sys
sys.path.append("/content/drive/MyDrive/Roomify")

def test_inference(
    image_path,
    prompt,
    checkpoint_path=None,
    output_path=None,
    transformation_strength=1.0,
    show_attention_maps=False,
    show_intermediate=False,
    side_by_side=True
):
    """
    Run inference on a single image with the specified text prompt.

    Args:
        image_path (str): Path to the input image
        prompt (str): Text description of the desired transformation
        checkpoint_path (str, optional): Path to a specific model checkpoint
        output_path (str, optional): Path to save the output image
        transformation_strength (float): Controls the intensity of the transformation (0.0-1.0)
        show_attention_maps (bool): Whether to visualize attention maps
        show_intermediate (bool): Whether to show intermediate transformation steps
        side_by_side (bool): Whether to display before/after comparison

    Returns:
        str: Path to the saved output image
    """
    # Validate input image
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Input image not found: {image_path}")

    valid_extensions = ['.jpg', '.jpeg', '.png']
    if not any(image_path.lower().endswith(ext) for ext in valid_extensions):
        raise ValueError(f"Unsupported image format. Please use: {valid_extensions}")

    # Set default output path if not provided
    if output_path is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_dir = os.path.join(os.path.dirname(image_path), "outputs")
        os.makedirs(output_dir, exist_ok=True)
        output_path = os.path.join(output_dir, f"transformed_{timestamp}.jpg")

    # Import necessary modules from your project
    try:
        from models.generator import Generator
        from models.text_encoder import TextEncoder
        from utils.image_processing import preprocess_image, postprocess_image
    except ImportError as e:
        print(f"Error importing project modules: {e}")
        print("Make sure your project path is correctly set in sys.path")
        return None

    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model
    print(f"Loading model checkpoint...")
    try:
        # Initialize text encoder
        text_encoder = TextEncoder().to(device)

        # Initialize generator
        generator = Generator().to(device)

        # Load checkpoint
        if checkpoint_path is None:
            # Try to find the latest checkpoint
            checkpoint_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "checkpoints")
            checkpoints = [f for f in os.listdir(checkpoint_dir) if f.endswith('.pth')]
            if not checkpoints:
                raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir}")

            latest_checkpoint = sorted(checkpoints)[-1]
            checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)

        print(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        generator.load_state_dict(checkpoint['generator'])
        text_encoder.load_state_dict(checkpoint['text_encoder'])

        # Set to evaluation mode
        generator.eval()
        text_encoder.eval()

    except Exception as e:
        print(f"Error loading model: {e}")
        return None

    # Process the image
    print(f"Processing image with prompt: '{prompt}'")
    try:
        # Load and preprocess the image
        input_image = Image.open(image_path).convert('RGB')
        processed_image = preprocess_image(input_image).to(device)

        # Encode the text prompt
        text_embedding = text_encoder(prompt)

        # Generate the transformed image
        with torch.no_grad():
            if show_intermediate:
                # For visualization of intermediate steps
                intermediate_outputs = []
                transformed_image, intermediates = generator(
                    processed_image,
                    text_embedding,
                    transformation_strength=transformation_strength,
                    return_intermediates=True
                )
                intermediate_outputs = intermediates
            else:
                transformed_image = generator(
                    processed_image,
                    text_embedding,
                    transformation_strength=transformation_strength
                )

            # Get attention maps if requested
            attention_maps = None
            if show_attention_maps:
                attention_maps = generator.get_attention_maps(processed_image, text_embedding)

        # Convert to output image
        output_image = postprocess_image(transformed_image)

        # Save the output
        output_image.save(output_path)
        print(f"Transformation complete. Output saved to: {output_path}")

        # Visualization
        if side_by_side or show_attention_maps or show_intermediate:
            plt.figure(figsize=(15, 10))

            if side_by_side:
                plt.subplot(1, 2, 1)
                plt.title("Original Image")
                plt.imshow(np.array(input_image))
                plt.axis('off')

                plt.subplot(1, 2, 2)
                plt.title(f"Transformed: {prompt}")
                plt.imshow(np.array(output_image))
                plt.axis('off')

            if show_attention_maps and attention_maps is not None:
                plt.figure(figsize=(15, 5))
                plt.title("Attention Maps")
                for i, attn_map in enumerate(attention_maps):
                    plt.subplot(1, len(attention_maps), i+1)
                    plt.imshow(attn_map.cpu().numpy(), cmap='viridis')
                    plt.axis('off')

            if show_intermediate and intermediate_outputs:
                plt.figure(figsize=(15, 5))
                plt.title("Transformation Steps")
                for i, img in enumerate(intermediate_outputs):
                    plt.subplot(1, len(intermediate_outputs), i+1)
                    plt.imshow(postprocess_image(img))
                    plt.title(f"Step {i+1}")
                    plt.axis('off')

            plt.tight_layout()
            plt.show()

        return output_path

    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback
        traceback.print_exc()
        return None


def batch_inference(
    image_folder,
    prompts_list=None,
    prompt=None,
    output_folder=None,
    checkpoint_path=None,
    transformation_strength=1.0
):
    """
    Process multiple images with corresponding prompts.

    Args:
        image_folder (str): Path to folder containing input images
        prompts_list (list, optional): List of prompts corresponding to each image
        prompt (str, optional): Single prompt to apply to all images
        output_folder (str, optional): Path to save output images
        checkpoint_path (str, optional): Path to model checkpoint
        transformation_strength (float): Transformation intensity (0.0-1.0)

    Returns:
        list: Paths to all generated images
    """
    if not os.path.exists(image_folder):
        raise FileNotFoundError(f"Image folder not found: {image_folder}")

    if prompts_list is None and prompt is None:
        raise ValueError("Either prompts_list or prompt must be provided")

    # Set default output folder
    if output_folder is None:
        output_folder = os.path.join(image_folder, "batch_outputs")

    os.makedirs(output_folder, exist_ok=True)

    # Get all images in the folder
    valid_extensions = ['.jpg', '.jpeg', '.png']
    image_files = [
        f for f in os.listdir(image_folder)
        if any(f.lower().endswith(ext) for ext in valid_extensions)
    ]

    if not image_files:
        print(f"No valid images found in {image_folder}")
        return []

    # Prepare prompts
    if prompts_list is None:
        prompts_list = [prompt] * len(image_files)
    elif len(prompts_list) < len(image_files):
        # Extend prompts list if needed
        prompts_list.extend([prompts_list[-1]] * (len(image_files) - len(prompts_list)))

    # Process each image
    output_paths = []
    for i, (image_file, img_prompt) in enumerate(zip(image_files, prompts_list)):
        print(f"\nProcessing image {i+1}/{len(image_files)}: {image_file}")
        image_path = os.path.join(image_folder, image_file)
        output_path = os.path.join(output_folder, f"transformed_{os.path.splitext(image_file)[0]}.jpg")

        result = test_inference(
            image_path=image_path,
            prompt=img_prompt,
            checkpoint_path=checkpoint_path,
            output_path=output_path,
            transformation_strength=transformation_strength,
            side_by_side=False,
            show_attention_maps=False,
            show_intermediate=False
        )

        if result:
            output_paths.append(result)

    print(f"\nBatch processing complete. {len(output_paths)} images generated in {output_folder}")
    return output_paths


# Example usage
if __name__ == "__main__":
    # Single image inference
    test_inference(
        image_path="/content/drive/MyDrive/Roomify/inference/sample_input.jpg",
        prompt="change the background of this room to something woody",
        transformation_strength=0.8,
        side_by_side=True
    )

    # Batch processing example
    # batch_inference(
    #     image_folder="/content/drive/MyDrive/Roomify/inference/sample_images",
    #     prompt="make this room look more modern and minimalist",
    #     transformation_strength=0.7
    # )

In [None]:
import sys
import os
import torch
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import clip

sys.path.append("/content/drive/MyDrive/Roomify")

# Part 1: Load and modify the model to boost text influence
class CLIPTextEncoder(torch.nn.Module):
    def __init__(self, device="cuda", model_name="ViT-B/32", use_cache=True):
        super().__init__()
        self.device = device

        # Load CLIP
        try:
            self.model, _ = clip.load(model_name, device=device)
            self.model = self.model.float()
            self.model.eval()
            self.clip_available = True
            print("✅ CLIP model loaded successfully")
        except Exception as e:
            print(f"⚠️ Could not load CLIP model: {e}")
            self.clip_available = False

        self.use_cache = use_cache
        self.embedding_cache = {}

        # Enhanced text projection
        self.text_enhancement = torch.nn.Sequential(
            torch.nn.Linear(512, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 512)
        )
        self.text_enhancement.to(device)

    def clear_cache(self):
        self.embedding_cache = {}

    @torch.no_grad()
    def forward(self, texts, temperature=2.5, enhance=True):
        """Process text with higher temperature by default"""
        if not isinstance(texts, list):
            texts = [texts]

        if self.use_cache:
            cached_results = []
            texts_to_process = []
            indices = []

            for i, text in enumerate(texts):
                if text in self.embedding_cache:
                    cached_results.append(self.embedding_cache[text])
                else:
                    texts_to_process.append(text)
                    indices.append(i)

            if len(texts_to_process) == 0:
                embeddings = torch.stack(cached_results)
                return embeddings

            if self.clip_available:
                tokenized = clip.tokenize(texts_to_process).to(self.device)
            else:
                tokenized = torch.ones((len(texts_to_process), 77), dtype=torch.long, device=self.device)
        else:
            if self.clip_available:
                tokenized = clip.tokenize(texts).to(self.device)
            else:
                tokenized = torch.ones((len(texts), 77), dtype=torch.long, device=self.device)

        if self.clip_available:
            embeddings = self.model.encode_text(tokenized)
        else:
            embeddings = torch.randn(len(tokenized), 512, device=self.device)

        if temperature != 1.0:
            noise_scale = (temperature - 1.0) * 0.2  # Increased noise scale
            embeddings = embeddings + torch.randn_like(embeddings) * noise_scale

        if enhance:
            embeddings = self.text_enhancement(embeddings)

        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)

        if self.use_cache and len(texts_to_process) > 0:
            for i, text in enumerate(texts_to_process):
                self.embedding_cache[text] = embeddings[i]

            if len(cached_results) > 0:
                all_embeddings = torch.zeros(len(texts), embeddings.shape[1], device=embeddings.device)

                for i, orig_idx in enumerate(indices):
                    all_embeddings[orig_idx] = embeddings[i]

                cache_idx = 0
                for i in range(len(texts)):
                    if i not in indices:
                        all_embeddings[i] = cached_results[cache_idx]
                        cache_idx += 1

                embeddings = all_embeddings

        return embeddings

# Step 1: Set up directories
output_dir = "/content/drive/MyDrive/Roomify/boosted_test_results"
os.makedirs(output_dir, exist_ok=True)

# Step 2: Boost the model weights
from models.generator import ConditionalUNetGenerator
checkpoint_path = "/content/drive/MyDrive/Roomify/checkpoints_enhanced/generator_epoch50.pth"
boosted_checkpoint = "/content/drive/MyDrive/Roomify/checkpoints_enhanced/generator_epoch50_boosted.pth"

# Load model
print("Loading and boosting the model...")
model = ConditionalUNetGenerator().to("cuda")
model.load_state_dict(torch.load(checkpoint_path, map_location="cuda"))

# Boost text influence dramatically
count = 0
for name, param in model.named_parameters():
    if "text_to_feature" in name or "text_proj" in name:
        if "weight" in name:
            print(f"Boosting {name}")
            param.data *= 10.0  # 10x stronger!
            count += 1

print(f"Boosted {count} text-related parameter weights")

# Save boosted model
torch.save(model.state_dict(), boosted_checkpoint)
print(f"Saved boosted model to {boosted_checkpoint}")

# Step 3: Run inference with boosted model
print("\nRunning inference with boosted model...")

# Load image
image_path = "/content/drive/MyDrive/Roomify/inference/sample_input.jpg"
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
image = transform(Image.open(image_path).convert("RGB")).unsqueeze(0).to("cuda")

# Create mask (central area)
mask = torch.zeros_like(image)
h, w = image.shape[2:]
y1, y2 = int(h * 0.2), int(h * 0.8)
x1, x2 = int(w * 0.2), int(w * 0.8)
mask[:, :, y1:y2, x1:x2] = 1.0
mask = mask.to("cuda")

# Initialize encoders and load boosted model
text_encoder = CLIPTextEncoder(device="cuda")
model = ConditionalUNetGenerator().to("cuda")
model.load_state_dict(torch.load(boosted_checkpoint))
model.eval()

# List of test prompts
test_prompts = [
    "transform this room into a cabin with wooden walls",
    "add rustic wooden planks to the wall",
    "convert to dark stone walls",
    "make the walls bright blue",
    "change to industrial style with exposed brick"
]

# Process each prompt
results = []
for i, prompt in enumerate(test_prompts):
    print(f"Processing prompt: '{prompt}'")

    # Get text embedding with high temperature
    with torch.no_grad():
        text_embedding = text_encoder(prompt, temperature=2.5)

        # Generate image
        output = model(image, mask, text_embedding)
        output = torch.clamp(output, 0, 1)

    # Save result
    prompt_filename = prompt.replace(" ", "_").replace("'", "")[:30]
    output_path = os.path.join(output_dir, f"boosted_{i}_{prompt_filename}.png")
    save_image(output, output_path)

    # Save mask for reference (first prompt only)
    if i == 0:
        mask_path = os.path.join(output_dir, "mask.png")
        save_image(mask, mask_path)
        original_path = os.path.join(output_dir, "original.png")
        save_image(image, original_path)

    # Calculate difference
    diff = torch.abs(output - image).mean().item()
    print(f"Average pixel difference: {diff:.6f}")
    print(f"Saved to: {output_path}")

    results.append((prompt, output[0], diff))

# Create and save comparison grid
all_images = [image[0]] + [r[1] for r in results]
grid = torch.stack(list(all_images))
grid_path = os.path.join(output_dir, "comparison_grid.png")
save_image(grid, grid_path, nrow=3)
print(f"\nSaved comparison grid to: {grid_path}")

print("\n✅ Inference complete! Check the results in:", output_dir)

Loading and boosting the model...
Boosting text_proj_inc.0.weight
Boosting text_proj_inc.2.weight
Boosting text_proj_down1.0.weight
Boosting text_proj_down1.2.weight
Boosting text_proj_down2.0.weight
Boosting text_proj_down2.2.weight
Boosting text_proj_down3.0.weight
Boosting text_proj_down3.2.weight
Boosting text_proj_down4.0.weight
Boosting text_proj_down4.2.weight
Boosting up1.text_proj.0.weight
Boosting up1.text_proj.2.weight
Boosting up1.text_proj.4.weight
Boosting up1.text_to_feature.0.weight
Boosting up1.text_to_feature.2.weight
Boosting up2.text_proj.0.weight
Boosting up2.text_proj.2.weight
Boosting up2.text_proj.4.weight
Boosting up2.text_to_feature.0.weight
Boosting up2.text_to_feature.2.weight
Boosting up3.text_proj.0.weight
Boosting up3.text_proj.2.weight
Boosting up3.text_proj.4.weight
Boosting up3.text_to_feature.0.weight
Boosting up3.text_to_feature.2.weight
Boosting up4.text_proj.0.weight
Boosting up4.text_proj.2.weight
Boosting up4.text_proj.4.weight
Boosting up4.text_

In [None]:
import sys
sys.path.append("/content/drive/MyDrive/Roomify")
import torch
from torchvision.utils import save_image
import os
from PIL import Image
import torchvision.transforms as transforms

# Create a test directory
test_dir = "/content/drive/MyDrive/Roomify/diagnostic_test"
os.makedirs(test_dir, exist_ok=True)

# 1. Load image and create masks
image_path = "/content/drive/MyDrive/Roomify/inference/sample_input.jpg"
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
image = transform(Image.open(image_path).convert("RGB")).unsqueeze(0)

# 2. Create a different mask for each test
masks = [
    torch.ones_like(image),  # Full image
    torch.zeros_like(image),  # Empty mask
]
# Add a middle region mask
mask3 = torch.zeros_like(image)
mask3[:, :, 50:200, 50:200] = 1.0
masks.append(mask3)

# 3. Create artificially modified outputs to test save functionality
modifications = [
    lambda img: img * 0.7,                           # Darker
    lambda img: img * 0 + torch.tensor([1,0,0]).view(1,3,1,1),  # Pure red
    lambda img: img * (mask3 * 0.5 + (1-mask3))      # Darken masked area only
]

# 4. Save all test images
for i, (mask, mod_fn) in enumerate(zip(masks, modifications)):
    # Save mask
    mask_path = os.path.join(test_dir, f"mask_{i}.png")
    save_image(mask, mask_path)
    print(f"✓ Saved mask to {mask_path}")

    # Create and save modified image
    modified = mod_fn(image)
    mod_path = os.path.join(test_dir, f"modified_{i}.png")
    save_image(modified, mod_path)
    print(f"✓ Saved modified image to {mod_path}")

    # Save original for reference
    if i == 0:
        orig_path = os.path.join(test_dir, "original.png")
        save_image(image, orig_path)
        print(f"✓ Saved original to {orig_path}")

print("\nDiagnostic test complete! Please check the images in:", test_dir)

✓ Saved mask to /content/drive/MyDrive/Roomify/diagnostic_test/mask_0.png
✓ Saved modified image to /content/drive/MyDrive/Roomify/diagnostic_test/modified_0.png
✓ Saved original to /content/drive/MyDrive/Roomify/diagnostic_test/original.png
✓ Saved mask to /content/drive/MyDrive/Roomify/diagnostic_test/mask_1.png
✓ Saved modified image to /content/drive/MyDrive/Roomify/diagnostic_test/modified_1.png
✓ Saved mask to /content/drive/MyDrive/Roomify/diagnostic_test/mask_2.png
✓ Saved modified image to /content/drive/MyDrive/Roomify/diagnostic_test/modified_2.png

Diagnostic test complete! Please check the images in: /content/drive/MyDrive/Roomify/diagnostic_test


In [None]:
import torch
import sys
sys.path.append("/content/drive/MyDrive/Roomify")
from models.generator import ConditionalUNetGenerator

# 1. Load your trained model
checkpoint_path = "/content/drive/MyDrive/Roomify/checkpoints_enhanced/generator_epoch50.pth"
model = ConditionalUNetGenerator().to("cuda")
model.load_state_dict(torch.load(checkpoint_path))

# 2. Dramatically increase text influence in the model
for name, param in model.named_parameters():
    if "text_to_feature" in name and "weight" in name:
        print(f"Boosting weights for {name}")
        # Multiply text feature weights by 10
        param.data *= 10.0

# 3. Save the modified model
modified_path = "/content/drive/MyDrive/Roomify/checkpoints_enhanced/generator_epoch50_boosted.pth"
torch.save(model.state_dict(), modified_path)
print(f"Saved modified model to {modified_path}")

Boosting weights for up1.text_to_feature.0.weight
Boosting weights for up1.text_to_feature.2.weight
Boosting weights for up2.text_to_feature.0.weight
Boosting weights for up2.text_to_feature.2.weight
Boosting weights for up3.text_to_feature.0.weight
Boosting weights for up3.text_to_feature.2.weight
Boosting weights for up4.text_to_feature.0.weight
Boosting weights for up4.text_to_feature.2.weight
Saved modified model to /content/drive/MyDrive/Roomify/checkpoints_enhanced/generator_epoch50_boosted.pth


In [None]:
import sys
sys.path.append("/content/drive/MyDrive/Roomify")
import os
import torch
from torch.utils.data import DataLoader
from datasets.roomify_dataset import RoomifyDataset
from models.generator import ConditionalUNetGenerator
from models.discriminator import RoomifyDiscriminator
from models.clip_encoder import CLIPTextEncoder
from training.trainer import RoomifyGANTrainer
import gc

# Free up memory
gc.collect()
torch.cuda.empty_cache()

# Set up new directories
VERSION = "extreme_v2"
CHECKPOINT_DIR = f"/content/drive/MyDrive/Roomify/checkpoints_{VERSION}"
GENERATED_DIR = f"/content/drive/MyDrive/Roomify/generated_{VERSION}"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
os.makedirs(GENERATED_DIR, exist_ok=True)

# Set device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# Load dataset
dataset = RoomifyDataset(
    csv_path="/content/drive/MyDrive/Roomify/data/processed/unified_prompts.csv",
    root_dir="/content/drive/MyDrive/Roomify/data/processed",
    image_size=(256, 256),
    use_cache=True,
    cache_size=100
)

dataloader = DataLoader(
    dataset,
    batch_size=6,  # Smaller batch size for Colab
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True
)
print(f"Dataset loaded with {len(dataset)} samples")

# Create models
generator = ConditionalUNetGenerator(in_channels=6, out_channels=3).to(DEVICE)
discriminator = RoomifyDiscriminator(in_channels=6, text_dim=512).to(DEVICE)
clip_encoder = CLIPTextEncoder(device=DEVICE)

# Create trainer
trainer = RoomifyGANTrainer(
    generator=generator,
    discriminator=discriminator,
    text_encoder=clip_encoder,
    dataloader=dataloader,
    device=DEVICE,
    g_lr=2e-4,
    d_lr=3e-4,
    ckpt_dir=CHECKPOINT_DIR,
    image_save_dir=GENERATED_DIR,
    project_name=f"roomify-{VERSION}"
)

# Train for 10 epochs at a time
TOTAL_EPOCHS = 40
EPOCHS_PER_RUN = 10

for start_epoch in range(0, TOTAL_EPOCHS, EPOCHS_PER_RUN):
    end_epoch = min(start_epoch + EPOCHS_PER_RUN, TOTAL_EPOCHS)
    print(f"\n{'='*50}")
    print(f"Training epochs {start_epoch+1} to {end_epoch}")
    print(f"{'='*50}")

    # Clear cache
    gc.collect()
    torch.cuda.empty_cache()

    # Train
    trainer.train(num_epochs=EPOCHS_PER_RUN, resume_epoch=start_epoch)

Using device: cuda
[Dataset] Dropped 0 empty rows, 0 corrupt/missing files.
[Dataset] Loaded 24474 valid samples.
Dataset loaded with 24474 samples


[34m[1mwandb[0m: Currently logged in as: [33mroomify6[0m ([33mroomify6-cairo-higher-institute-for-engineering[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin



Training epochs 1 to 10

🚀 Starting training...

=== Epoch 1/10 ===
Current loss weights: L1=0.01, Adv=5.00


Training Epoch 1: 100%|██████████| 4079/4079 [20:42<00:00,  3.28it/s, g_loss=16.8350, d_loss=0.6935, adv_loss=0.6734, l1_loss=0.1726, perc_loss=0.0154, l1_weight=0.0100, adv_weight=5.0000]



📊 Epoch 1 Summary:
Generator Loss: 19.4164
Discriminator Loss: 0.9299
Adversarial Loss: 0.7766
L1 Loss: 0.1213
L1 Weight: 0.0100
💾 Checkpoints saved for epoch 1

=== Epoch 2/10 ===
Current loss weights: L1=0.01, Adv=4.85


Training Epoch 2: 100%|██████████| 4079/4079 [20:00<00:00,  3.40it/s, g_loss=16.0649, d_loss=0.6944, adv_loss=0.6625, l1_loss=0.0589, perc_loss=0.0053, l1_weight=0.0145, adv_weight=4.8500]



📊 Epoch 2 Summary:
Generator Loss: 16.8166
Discriminator Loss: 0.6934
Adversarial Loss: 0.6935
L1 Loss: 0.1018
L1 Weight: 0.0145
💾 Checkpoints saved for epoch 2

=== Epoch 3/10 ===
Current loss weights: L1=0.02, Adv=4.70


Training Epoch 3: 100%|██████████| 4079/4079 [20:04<00:00,  3.39it/s, g_loss=16.2316, d_loss=0.6932, adv_loss=0.6907, l1_loss=0.0634, perc_loss=0.0054, l1_weight=0.0190, adv_weight=4.7000]



📊 Epoch 3 Summary:
Generator Loss: 18.6496
Discriminator Loss: 0.8075
Adversarial Loss: 0.7936
L1 Loss: 0.0954
L1 Weight: 0.0190
💾 Checkpoints saved for epoch 3

=== Epoch 4/10 ===
Current loss weights: L1=0.02, Adv=4.55


Training Epoch 4: 100%|██████████| 4079/4079 [22:07<00:00,  3.07it/s, g_loss=16.2180, d_loss=0.6944, adv_loss=0.7129, l1_loss=0.0432, perc_loss=0.0035, l1_weight=0.0235, adv_weight=4.5500]



📊 Epoch 4 Summary:
Generator Loss: 20.3350
Discriminator Loss: 0.9502
Adversarial Loss: 0.8938
L1 Loss: 0.0479
L1 Weight: 0.0235
💾 Checkpoints saved for epoch 4

=== Epoch 5/10 ===
Current loss weights: L1=0.03, Adv=4.40


Training Epoch 5: 100%|██████████| 4079/4079 [24:49<00:00,  2.74it/s, g_loss=15.5786, d_loss=0.6933, adv_loss=0.7081, l1_loss=0.0220, perc_loss=0.0019, l1_weight=0.0280, adv_weight=4.4000]



📊 Epoch 5 Summary:
Generator Loss: 15.3269
Discriminator Loss: 0.6967
Adversarial Loss: 0.6967
L1 Loss: 0.0288
L1 Weight: 0.0280
💾 Checkpoints saved for epoch 5

=== Epoch 6/10 ===
Current loss weights: L1=0.03, Adv=4.25


Training Epoch 6: 100%|██████████| 4079/4079 [25:07<00:00,  2.71it/s, g_loss=14.7792, d_loss=0.6932, adv_loss=0.6955, l1_loss=0.0118, perc_loss=0.0011, l1_weight=0.0325, adv_weight=4.2500]



📊 Epoch 6 Summary:
Generator Loss: 14.7316
Discriminator Loss: 0.6933
Adversarial Loss: 0.6932
L1 Loss: 0.0209
L1 Weight: 0.0325
💾 Checkpoints saved for epoch 6

=== Epoch 7/10 ===
Current loss weights: L1=0.04, Adv=4.10


Training Epoch 7: 100%|██████████| 4079/4079 [25:10<00:00,  2.70it/s, g_loss=13.0531, d_loss=0.7002, adv_loss=0.6367, l1_loss=0.0277, perc_loss=0.0021, l1_weight=0.0370, adv_weight=4.1000]



📊 Epoch 7 Summary:
Generator Loss: 20.2158
Discriminator Loss: 1.2155
Adversarial Loss: 0.9861
L1 Loss: 0.0176
L1 Weight: 0.0370
💾 Checkpoints saved for epoch 7

=== Epoch 8/10 ===
Current loss weights: L1=0.04, Adv=3.95


Training Epoch 8: 100%|██████████| 4079/4079 [25:02<00:00,  2.71it/s, g_loss=13.5114, d_loss=0.6938, adv_loss=0.6841, l1_loss=0.0121, perc_loss=0.0010, l1_weight=0.0415, adv_weight=3.9500]



📊 Epoch 8 Summary:
Generator Loss: 17.2888
Discriminator Loss: 0.8703
Adversarial Loss: 0.8754
L1 Loss: 0.0148
L1 Weight: 0.0415
💾 Checkpoints saved for epoch 8

=== Epoch 9/10 ===
Current loss weights: L1=0.05, Adv=3.80


Training Epoch 9: 100%|██████████| 4079/4079 [25:38<00:00,  2.65it/s, g_loss=13.7038, d_loss=0.6937, adv_loss=0.7212, l1_loss=0.0131, perc_loss=0.0011, l1_weight=0.0460, adv_weight=3.8000]



📊 Epoch 9 Summary:
Generator Loss: 13.2006
Discriminator Loss: 0.6945
Adversarial Loss: 0.6948
L1 Loss: 0.0133
L1 Weight: 0.0460
💾 Checkpoints saved for epoch 9

=== Epoch 10/10 ===
Current loss weights: L1=0.05, Adv=3.65


Training Epoch 10: 100%|██████████| 4079/4079 [25:24<00:00,  2.67it/s, g_loss=12.7174, d_loss=0.6932, adv_loss=0.6968, l1_loss=0.0182, perc_loss=0.0014, l1_weight=0.0505, adv_weight=3.6500]



📊 Epoch 10 Summary:
Generator Loss: 13.6448
Discriminator Loss: 0.7453
Adversarial Loss: 0.7477
L1 Loss: 0.0120
L1 Weight: 0.0505
💾 Checkpoints saved for epoch 10

✅ Training completed!

Training epochs 11 to 20

🚀 Starting training...

=== Epoch 11/20 ===
Current loss weights: L1=0.06, Adv=3.50


Training Epoch 11: 100%|██████████| 4079/4079 [24:54<00:00,  2.73it/s, g_loss=12.0708, d_loss=0.6932, adv_loss=0.6898, l1_loss=0.0057, perc_loss=0.0005, l1_weight=0.0550, adv_weight=3.5000]



📊 Epoch 11 Summary:
Generator Loss: 12.1326
Discriminator Loss: 0.6933
Adversarial Loss: 0.6933
L1 Loss: 0.0087
L1 Weight: 0.0550
💾 Checkpoints saved for epoch 11

=== Epoch 12/20 ===
Current loss weights: L1=0.06, Adv=3.35


Training Epoch 12: 100%|██████████| 4079/4079 [24:47<00:00,  2.74it/s, g_loss=11.6083, d_loss=0.6932, adv_loss=0.6930, l1_loss=0.0078, perc_loss=0.0007, l1_weight=0.0595, adv_weight=3.3500]



📊 Epoch 12 Summary:
Generator Loss: 11.6716
Discriminator Loss: 0.6970
Adversarial Loss: 0.6968
L1 Loss: 0.0081
L1 Weight: 0.0595
💾 Checkpoints saved for epoch 12

=== Epoch 13/20 ===
Current loss weights: L1=0.06, Adv=3.20


Training Epoch 13:  83%|████████▎ | 3404/4079 [20:27<04:12,  2.67it/s, g_loss=11.1696, d_loss=0.6932, adv_loss=0.6981, l1_loss=0.0079, perc_loss=0.0007, l1_weight=0.0640, adv_weight=3.2000]

In [None]:
!mkdir -p /content/drive/MyDrive/Roomify/test_images

In [None]:
import sys
sys.path.append("/content/drive/MyDrive/Roomify")
import torch
import os
from PIL import Image
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from models.generator import ConditionalUNetGenerator
from models.clip_encoder import CLIPTextEncoder
import matplotlib.pyplot as plt
import numpy as np

# Setup directories
test_dir = "/content/drive/MyDrive/Roomify/test_results"
os.makedirs(test_dir, exist_ok=True)

# Load your best model
model_path = "/content/drive/MyDrive/Roomify/checkpoints_extreme_v2/generator_epoch8.pth"  # Use your latest epoch
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize models
generator = ConditionalUNetGenerator(in_channels=6, out_channels=3).to(device)
generator.load_state_dict(torch.load(model_path, map_location=device))
generator.eval()

text_encoder = CLIPTextEncoder(device=device)

# Image preprocessing
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

# Function to generate images based on prompts
def generate_room_variants(image_path, prompts, output_dir, use_mask=True):
    os.makedirs(output_dir, exist_ok=True)

    # Load and preprocess image
    img = Image.open(image_path).convert("RGB")
    img_tensor = transform(img).unsqueeze(0).to(device)

    # Create mask (default: full image)
    if use_mask:
        mask = torch.ones_like(img_tensor)
    else:
        # Create mask for just walls (example)
        mask = torch.zeros_like(img_tensor)
        # Middle rectangular area (typical walls area)
        mask[:, :, 30:226, 30:226] = 1.0

    # Save original image
    original_path = os.path.join(output_dir, "original.png")
    save_image(img_tensor, original_path)

    # Save mask visualization
    mask_path = os.path.join(output_dir, "mask.png")
    save_image(mask, mask_path)

    results = []
    all_images = [img_tensor[0]]

    # Process each prompt
    for i, prompt in enumerate(prompts):
        print(f"Processing prompt: '{prompt}'")

        # Get text embedding with high temperature for diversity
        text_embedding = text_encoder(prompt, temperature=2.0)

        # Generate image
        with torch.no_grad():
            output = generator(img_tensor, mask, text_embedding)
            output = torch.clamp(output, 0, 1)

        # Save output
        output_path = os.path.join(output_dir, f"{i+1}_{prompt.replace(' ', '_')[:30]}.png")
        save_image(output, output_path)

        # Calculate difference
        diff = torch.abs(output - img_tensor).mean().item()
        print(f"Average difference: {diff:.4f}")

        results.append((prompt, output[0], diff))
        all_images.append(output[0])

    # Create comparison grid
    grid = make_grid(all_images, nrow=3)
    grid_path = os.path.join(output_dir, "comparison_grid.png")
    save_image(grid, grid_path)

    # Create difference visualization
    diff_images = []
    diff_images.append(img_tensor[0])  # Original
    for _, img, _ in results:
        # Amplify differences for visibility
        diff_img = torch.abs(img - img_tensor[0]) * 5
        diff_images.append(diff_img)

    diff_grid = make_grid(diff_images, nrow=3)
    diff_grid_path = os.path.join(output_dir, "difference_grid.png")
    save_image(diff_grid, diff_grid_path)

    print(f"Results saved to {output_dir}")
    return results

# Test with different prompt categories
test_images = [
    "/content/drive/MyDrive/Roomify/test_images/023c66ab118a2c487f82c3ac145c69c9.jpg",
    "/content/drive/MyDrive/Roomify/test_images/2bb101db004211f248fb7f1c0e254fee.jpg",
    "/content/drive/MyDrive/Roomify/test_images/e216c5d3ba2356676429bcf10bc5245a.jpg",
    "/content/drive/MyDrive/Roomify/test_images/e216c5d3ba2356676429bcf10bc5245a.jpg"
    # Add more test images
]

# Define test prompt sets
color_prompts = [
    "change the wall color to light blue",
    "paint the walls dark green",
    "make the walls bright yellow",
    "change to white walls with black accents"
]

material_prompts = [
    "add wooden panels to the walls",
    "convert to exposed brick walls",
    "change to marble wall texture",
    "add stone texture to the walls"
]

style_prompts = [
    "transform to modern minimalist style",
    "convert to rustic farmhouse style",
    "change to luxury penthouse style",
    "redesign as industrial style room"
]

# Run tests on each image with different prompt types
for img_path in test_images:
    img_name = os.path.basename(img_path).split('.')[0]
    base_dir = os.path.join(test_dir, img_name)

    # Test with color prompts
    generate_room_variants(
        image_path=img_path,
        prompts=color_prompts,
        output_dir=os.path.join(base_dir, "colors"),
        use_mask=False  # Target walls
    )

    # Test with material prompts
    generate_room_variants(
        image_path=img_path,
        prompts=material_prompts,
        output_dir=os.path.join(base_dir, "materials"),
        use_mask=False
    )

    # Test with style prompts
    generate_room_variants(
        image_path=img_path,
        prompts=style_prompts,
        output_dir=os.path.join(base_dir, "styles"),
        use_mask=True  # Full image for style changes
    )

print("Testing complete!")

100%|███████████████████████████████████████| 338M/338M [00:15<00:00, 23.1MiB/s]


Processing prompt: 'change the wall color to light blue'
Average difference: 0.1753
Processing prompt: 'paint the walls dark green'
Average difference: 0.1816
Processing prompt: 'make the walls bright yellow'
Average difference: 0.1713
Processing prompt: 'change to white walls with black accents'
Average difference: 0.1755
Results saved to /content/drive/MyDrive/Roomify/test_results/023c66ab118a2c487f82c3ac145c69c9/colors
Processing prompt: 'add wooden panels to the walls'
Average difference: 0.1799
Processing prompt: 'convert to exposed brick walls'
Average difference: 0.1731
Processing prompt: 'change to marble wall texture'
Average difference: 0.1720
Processing prompt: 'add stone texture to the walls'
Average difference: 0.1739
Results saved to /content/drive/MyDrive/Roomify/test_results/023c66ab118a2c487f82c3ac145c69c9/materials
Processing prompt: 'transform to modern minimalist style'
Average difference: 0.1147
Processing prompt: 'convert to rustic farmhouse style'
Average differe

In [None]:
!python /content/drive/MyDrive/Roomify/train_stylegan.py

[34m[1mwandb[0m: Currently logged in as: [33mroomify6[0m ([33mroomify6-cairo-higher-institute-for-engineering[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Tracking run with wandb version 0.19.10
[34m[1mwandb[0m: Run data is saved locally in [35m[1m/content/wandb/run-20250505_052315-d4xryxzp[0m
[34m[1mwandb[0m: Run [1m`wandb offline`[0m to turn off syncing.
[34m[1mwandb[0m: Syncing run [33mold-droid-1[0m
[34m[1mwandb[0m: ⭐️ View project at [34m[4mhttps://wandb.ai/roomify6-cairo-higher-institute-for-engineering/roomify-stylegan[0m
[34m[1mwandb[0m: 🚀 View run at [34m[4mhttps://wandb.ai/roomify6-cairo-higher-institute-for-engineering/roomify-stylegan/runs/d4xryxzp[0m


In [2]:
!mkdir -p /content/drive/MyDrive/Roomify/checkpoints_stylegan