<a href="https://colab.research.google.com/github/lauren1turner/DS4002_LAM/blob/main/PIx2pix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

base_dir = "/content/dataset"

folders = [
    "train/grayscale",
    "train/color",
    "test/grayscale",
    "test/color"
]

for folder in folders:
    path = os.path.join(base_dir, folder)
    os.makedirs(path, exist_ok=True)

print("Folders created successfully:")
for folder in folders:
    print(os.path.join(base_dir, folder))


Folders created successfully:
/content/dataset/train/grayscale
/content/dataset/train/color
/content/dataset/test/grayscale
/content/dataset/test/color


In [2]:
import os
from PIL import Image

# Set the folder path (where your .jpg images are located)
folder_path = "/content/"  # Replace with your actual image path
output_folder = "/content/dataset/test/grayscale"  # Folder to save grayscaled images

# Ensure the output folder exists
os.makedirs(output_folder, exist_ok=True)

# Loop through each file in the folder
for filename in os.listdir(folder_path):
    if filename.lower().endswith(".jpg"):
        image_path = os.path.join(folder_path, filename)

        # Open the image and convert it to grayscale
        img = Image.open(image_path).convert("L")

        # Modify filename to add .grayscale before .jpg
        name, ext = os.path.splitext(filename)
        new_filename = f"{name}.grayscale.jpg"

        # Save the grayscale image
        save_path = os.path.join(output_folder, new_filename)
        img.save(save_path)

        print(f"Grayscaled: {new_filename}")


Grayscaled: image-291-a.grayscale.jpg
Grayscaled: image-293-a.grayscale.jpg
Grayscaled: image-178-a.grayscale.jpg
Grayscaled: image-194-a.grayscale.jpg
Grayscaled: image-219-a.grayscale.jpg
Grayscaled: image-198-a.grayscale.jpg
Grayscaled: image-294-a.grayscale.jpg
Grayscaled: image-188-a.grayscale.jpg
Grayscaled: image-216-a.grayscale.jpg
Grayscaled: image-295-a.grayscale.jpg
Grayscaled: image-229-c.grayscale.jpg
Grayscaled: image-214-c.grayscale.jpg
Grayscaled: image-217-c.grayscale.jpg
Grayscaled: image-279-a.grayscale.jpg
Grayscaled: image-228-a.grayscale.jpg
Grayscaled: image-280-a.grayscale.jpg
Grayscaled: image-207-a.grayscale.jpg
Grayscaled: image-210-a.grayscale.jpg
Grayscaled: image-276-a.grayscale.jpg
Grayscaled: image-273-a.grayscale.jpg
Grayscaled: image-189-a.grayscale.jpg
Grayscaled: image-205-c.grayscale.jpg
Grayscaled: image-220-c.grayscale.jpg
Grayscaled: image-278-a.grayscale.jpg
Grayscaled: image-225-a.grayscale.jpg
Grayscaled: image-286-a.grayscale.jpg
Grayscaled: 

In [3]:
import os
import shutil

source_dir = "/content"
target_dir = "/content/dataset/test/color"

# Create the folder if it doesn't exist
os.makedirs(target_dir, exist_ok=True)

# Move all .jpg files
for filename in os.listdir(source_dir):
    if filename.lower().endswith(".jpg"):
        shutil.move(os.path.join(source_dir, filename), os.path.join(target_dir, filename))

print("✅ All .jpg files moved to /content/dataset/test/color")


✅ All .jpg files moved to /content/dataset/test/color


In [5]:
import shutil
import os

# Source directories
color_dir = '/content/dataset/test/color'
grayscale_dir = '/content/dataset/test/grayscale'

# Destination directories
color_copy_dir = '/content/dataset/train/color'
grayscale_copy_dir = '/content/dataset/train/grayscale'

# Create destination directories if they don't exist
os.makedirs(color_copy_dir, exist_ok=True)
os.makedirs(grayscale_copy_dir, exist_ok=True)

# Copy files from color directory
for filename in os.listdir(color_dir):
    src_path = os.path.join(color_dir, filename)
    dst_path = os.path.join(color_copy_dir, filename)
    shutil.copy2(src_path, dst_path)

# Copy files from grayscale directory
for filename in os.listdir(grayscale_dir):
    src_path = os.path.join(grayscale_dir, filename)
    dst_path = os.path.join(grayscale_copy_dir, filename)
    shutil.copy2(src_path, dst_path)

print("Files copied successfully.")


Files copied successfully.


In [6]:
!pip install torchvision matplotlib

import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# === Dataset ===
class GrayscaleToColorDataset(Dataset):
    def __init__(self, grayscale_dir, color_dir, transform=None):
        self.grayscale_dir = grayscale_dir
        self.color_dir = color_dir
        self.transform = transform
        self.image_names = os.listdir(grayscale_dir)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        gray_path = os.path.join(self.grayscale_dir, self.image_names[idx])
        color_path = os.path.join(self.color_dir, self.image_names[idx])

        gray_image = Image.open(gray_path).convert('L')
        color_image = Image.open(color_path).convert('RGB')

        if self.transform:
            gray_image = self.transform(gray_image)
            color_image = self.transform(color_image)

        return gray_image, color_image

# === Transforms ===
img_size = 256
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor()
])

train_dataset = GrayscaleToColorDataset(
    grayscale_dir="/content/dataset/train/grayscale",
    color_dir="/content/dataset/train/color",
    transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# === Generator ===
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=3):
        super(UNetGenerator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, out_channels, 4, 2, 1), nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# === Discriminator ===
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(4, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, 4, 1, 1), nn.Sigmoid()
        )

    def forward(self, x, y):
        # Concatenate input and target image
        return self.net(torch.cat([x, y], 1))

# === Training ===
generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)

criterion = nn.BCELoss()
l1_loss = nn.L1Loss()

optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

epochs = 100

for epoch in range(epochs):
    for i, (gray, color) in enumerate(train_loader):
        gray, color = gray.to(device), color.to(device)

        # === Train Discriminator ===
        fake_color = generator(gray)
        real_label = torch.ones((gray.size(0), 1, 30, 30), device=device)
        fake_label = torch.zeros((gray.size(0), 1, 30, 30), device=device)

        optimizer_D.zero_grad()
        real_output = discriminator(gray, color)
        fake_output = discriminator(gray, fake_color.detach())
        d_loss = (criterion(real_output, real_label) + criterion(fake_output, fake_label)) * 0.5
        d_loss.backward()
        optimizer_D.step()

        # === Train Generator ===
        optimizer_G.zero_grad()
        fake_output = discriminator(gray, fake_color)
        g_loss = criterion(fake_output, real_label) + l1_loss(fake_color, color) * 100
        g_loss.backward()
        optimizer_G.step()

    print(f"Epoch [{epoch+1}/{epochs}] - D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")
    if (epoch + 1) % 10 == 0:
        save_image(fake_color, f"/content/fake_epoch_{epoch+1}.png")


Using device: cpu


FileNotFoundError: [Errno 2] No such file or directory: '/content/dataset/train/color/image-188-a.grayscale.jpg'

In [None]:
import shutil
import os

# Define paths
paths = [
    ("/content/dataset/test/color", "/content/dataset/train/color"),
    ("/content/dataset/test/grayscale", "/content/dataset/train/grayscale")
]

# Copy files
for src, dst in paths:
    os.makedirs(dst, exist_ok=True)  # Make sure the target folder exists
    for filename in os.listdir(src):
        src_file = os.path.join(src, filename)
        dst_file = os.path.join(dst, filename)
        if os.path.isfile(src_file):
            shutil.copy(src_file, dst_file)

print("✅ Files copied from test to train folders.")


In [7]:
print("Train Color:", len(os.listdir('/content/dataset/train/color')))
print("Train Grayscale:", len(os.listdir('/content/dataset/train/grayscale')))
print("Test Color:", len(os.listdir('/content/dataset/test/color')))
print("Test Grayscale:", len(os.listdir('/content/dataset/test/grayscale')))



import os

color_files = os.listdir("/content/dataset/train/color")
print("Sample color files:", color_files[:5])


Train Color: 84
Train Grayscale: 84
Test Color: 84
Test Grayscale: 84
Sample color files: ['image-291-a.jpg', 'image-293-a.jpg', 'image-178-a.jpg', 'image-194-a.jpg', 'image-219-a.jpg']


In [8]:
class GrayscaleToColorDataset(Dataset):
    def __init__(self, grayscale_dir, color_dir, transform=None):
        self.grayscale_dir = grayscale_dir
        self.color_dir = color_dir
        self.transform = transform

        self.grayscale_files = os.listdir(grayscale_dir)
        self.color_files = os.listdir(color_dir)

        # Normalize filenames (remove .grayscale, lowercased)
        def normalize(name):
            return name.replace(".grayscale", "").lower()

        # Build matched pairs
        self.pairs = []
        color_map = {normalize(f): f for f in self.color_files}
        for gray in self.grayscale_files:
            norm = normalize(gray)
            if norm in color_map:
                self.pairs.append((gray, color_map[norm]))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        gray_name, color_name = self.pairs[idx]
        gray_path = os.path.join(self.grayscale_dir, gray_name)
        color_path = os.path.join(self.color_dir, color_name)

        gray_image = Image.open(gray_path).convert('L')
        color_image = Image.open(color_path).convert('RGB')

        if self.transform:
            gray_image = self.transform(gray_image)
            color_image = self.transform(color_image)

        return gray_image, color_image


In [9]:
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from PIL import Image
import matplotlib.pyplot as plt
from torch import nn

# === Device ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# === Dataset with Automatic Matching ===
class GrayscaleToColorDataset(Dataset):
    def __init__(self, grayscale_dir, color_dir, transform=None):
        self.grayscale_dir = grayscale_dir
        self.color_dir = color_dir
        self.transform = transform

        self.grayscale_files = os.listdir(grayscale_dir)
        self.color_files = os.listdir(color_dir)

        def normalize(name):
            return name.replace(".grayscale", "").lower()

        # Build matched grayscale-color pairs
        color_map = {normalize(f): f for f in self.color_files}
        self.pairs = []
        for gray in self.grayscale_files:
            norm = normalize(gray)
            if norm in color_map:
                self.pairs.append((gray, color_map[norm]))

        print(f"✅ Matched {len(self.pairs)} grayscale-color pairs.")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        gray_name, color_name = self.pairs[idx]
        gray_path = os.path.join(self.grayscale_dir, gray_name)
        color_path = os.path.join(self.color_dir, color_name)

        gray_image = Image.open(gray_path).convert('L')
        color_image = Image.open(color_path).convert('RGB')

        if self.transform:
            gray_image = self.transform(gray_image)
            color_image = self.transform(color_image)

        return gray_image, color_image

# === Transforms ===
img_size = 256
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor()
])

# === DataLoader ===
train_dataset = GrayscaleToColorDataset(
    grayscale_dir="/content/dataset/train/grayscale",
    color_dir="/content/dataset/train/color",
    transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# === Generator (U-Net) ===
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=3):
        super(UNetGenerator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, out_channels, 4, 2, 1), nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# === Discriminator (PatchGAN) ===
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(4, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, 4, 1, 1), nn.Sigmoid()
        )

    def forward(self, x, y):
        return self.net(torch.cat([x, y], 1))  # Concatenate grayscale and color images

# === Model Init ===
generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)

# === Loss and Optimizers ===
criterion = nn.BCELoss()
l1_loss = nn.L1Loss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

# === Training Loop ===
epochs = 100
for epoch in range(epochs):
    for i, (gray, color) in enumerate(train_loader):
        gray, color = gray.to(device), color.to(device)

        # === Train Discriminator ===
        fake_color = generator(gray)
        real_label = torch.ones((gray.size(0), 1, 30, 30), device=device)
        fake_label = torch.zeros((gray.size(0), 1, 30, 30), device=device)

        optimizer_D.zero_grad()
        real_output = discriminator(gray, color)
        fake_output = discriminator(gray, fake_color.detach())
        d_loss = (criterion(real_output, real_label) + criterion(fake_output, fake_label)) * 0.5
        d_loss.backward()
        optimizer_D.step()

        # === Train Generator ===
        optimizer_G.zero_grad()
        fake_output = discriminator(gray, fake_color)
        g_loss = criterion(fake_output, real_label) + l1_loss(fake_color, color) * 100
        g_loss.backward()
        optimizer_G.step()

    print(f"Epoch [{epoch+1}/{epochs}] - D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    # === Save sample output ===
    if (epoch + 1) % 10 == 0:
        save_image(fake_color, f"/content/fake_epoch_{epoch+1}.png")


Using device: cpu
✅ Matched 84 grayscale-color pairs.


ValueError: Using a target size (torch.Size([4, 1, 30, 30])) that is different to the input size (torch.Size([4, 1, 63, 63])) is deprecated. Please ensure they have the same size.

In [10]:
class GrayscaleToColorDataset(Dataset):
    def __init__(self, grayscale_dir, color_dir, transform=None, target_size=(256, 256)):
        self.grayscale_dir = grayscale_dir
        self.color_dir = color_dir
        self.transform = transform
        self.target_size = target_size

        self.grayscale_files = os.listdir(grayscale_dir)
        self.color_files = os.listdir(color_dir)

        # Normalize filenames
        def normalize(name):
            return name.replace(".grayscale", "").lower()

        # Build matched grayscale-color pairs
        color_map = {normalize(f): f for f in self.color_files}
        self.pairs = []
        for gray in self.grayscale_files:
            norm = normalize(gray)
            if norm in color_map:
                gray_path = os.path.join(self.grayscale_dir, gray)
                color_path = os.path.join(self.color_dir, color_map[norm])

                # Check sizes
                gray_image = Image.open(gray_path)
                color_image = Image.open(color_path)

                if gray_image.size == color_image.size == self.target_size:
                    self.pairs.append((gray, color_map[norm]))

        print(f"✅ Matched {len(self.pairs)} grayscale-color pairs with target size {self.target_size}.")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        gray_name, color_name = self.pairs[idx]
        gray_path = os.path.join(self.grayscale_dir, gray_name)
        color_path = os.path.join(self.color_dir, color_name)

        gray_image = Image.open(gray_path).convert('L')
        color_image = Image.open(color_path).convert('RGB')

        if self.transform:
            gray_image = self.transform(gray_image)
            color_image = self.transform(color_image)

        return gray_image, color_image


In [11]:
import os
from PIL import Image
import shutil

# Directories for grayscale and color images
grayscale_dir = "/content/dataset/train/grayscale"
color_dir = "/content/dataset/train/color"
target_size = (1024, 768)  # Target size for filtering

# Function to filter images
def filter_images(input_dir, output_dir, target_size):
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through all files in the directory
    for filename in os.listdir(input_dir):
        image_path = os.path.join(input_dir, filename)
        image = Image.open(image_path)

        # Check if the image size matches the target size
        if image.size == target_size:
            # If size matches, copy the image to the output directory
            shutil.copy(image_path, os.path.join(output_dir, filename))
        else:
            print(f"Skipping {filename} (size {image.size})")

# Filter grayscale images
filter_images(grayscale_dir, "/content/dataset/train/grayscale_filtered", target_size)

# Filter color images
filter_images(color_dir, "/content/dataset/train/color_filtered", target_size)



Skipping image-270-a.grayscale.jpg (size (1024, 608))
Skipping image-273-a.grayscale.jpg (size (1024, 568))
Skipping image-278-a.grayscale.jpg (size (1024, 729))
Skipping image-298-a.grayscale.jpg (size (1024, 683))
Skipping image-288-a.grayscale.jpg (size (890, 768))
Skipping image-290-a.grayscale.jpg (size (817, 768))
Skipping image-271-a.grayscale.jpg (size (790, 768))
Skipping image-279-a.grayscale.jpg (size (1024, 683))
Skipping image-279-a.jpg (size (1024, 683))
Skipping image-273-a.jpg (size (1024, 568))
Skipping image-278-a.jpg (size (1024, 729))
Skipping image-271-a.jpg (size (790, 768))
Skipping image-288-a.jpg (size (890, 768))
Skipping image-270-a.jpg (size (1024, 608))
Skipping image-290-a.jpg (size (817, 768))
Skipping image-298-a.jpg (size (1024, 683))


In [12]:
import os
from PIL import Image
import shutil

# Directories for test grayscale and color images
grayscale_test_dir = "/content/dataset/test/grayscale"
color_test_dir = "/content/dataset/test/color"
target_size = (1024, 768)  # Target size for filtering

# Function to filter images
def filter_images(input_dir, output_dir, target_size):
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Loop through all files in the directory
    for filename in os.listdir(input_dir):
        image_path = os.path.join(input_dir, filename)
        image = Image.open(image_path)

        # Check if the image size matches the target size
        if image.size == target_size:
            # If size matches, copy the image to the output directory
            shutil.copy(image_path, os.path.join(output_dir, filename))
        else:
            print(f"Skipping {filename} (size {image.size})")

# Filter test grayscale images
filter_images(grayscale_test_dir, "/content/dataset/test/grayscale_filtered", target_size)

# Filter test color images
filter_images(color_test_dir, "/content/dataset/test/color_filtered", target_size)


Skipping image-270-a.grayscale.jpg (size (1024, 608))
Skipping image-273-a.grayscale.jpg (size (1024, 568))
Skipping image-278-a.grayscale.jpg (size (1024, 729))
Skipping image-298-a.grayscale.jpg (size (1024, 683))
Skipping image-288-a.grayscale.jpg (size (890, 768))
Skipping image-290-a.grayscale.jpg (size (817, 768))
Skipping image-271-a.grayscale.jpg (size (790, 768))
Skipping image-279-a.grayscale.jpg (size (1024, 683))
Skipping image-279-a.jpg (size (1024, 683))
Skipping image-273-a.jpg (size (1024, 568))
Skipping image-278-a.jpg (size (1024, 729))
Skipping image-271-a.jpg (size (790, 768))
Skipping image-288-a.jpg (size (890, 768))
Skipping image-270-a.jpg (size (1024, 608))
Skipping image-290-a.jpg (size (817, 768))
Skipping image-298-a.jpg (size (1024, 683))


In [None]:

import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from PIL import Image
import torch.nn.functional as F
from torch import nn

# === Device ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# === Dataset with Automatic Matching ===
class GrayscaleToColorDataset(Dataset):
    def __init__(self, grayscale_dir, color_dir, transform=None):
        self.grayscale_dir = grayscale_dir
        self.color_dir = color_dir
        self.transform = transform

        self.grayscale_files = os.listdir(grayscale_dir)
        self.color_files = os.listdir(color_dir)

        def normalize(name):
            return name.replace(".grayscale", "").lower()

        # Build matched grayscale-color pairs
        color_map = {normalize(f): f for f in self.color_files}
        self.pairs = []
        for gray in self.grayscale_files:
            norm = normalize(gray)
            if norm in color_map:
                self.pairs.append((gray, color_map[norm]))

        print(f"✅ Matched {len(self.pairs)} grayscale-color pairs.")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        gray_name, color_name = self.pairs[idx]
        gray_path = os.path.join(self.grayscale_dir, gray_name)
        color_path = os.path.join(self.color_dir, color_name)

        gray_image = Image.open(gray_path).convert('L')
        color_image = Image.open(color_path).convert('RGB')

        if self.transform:
            gray_image = self.transform(gray_image)
            color_image = self.transform(color_image)

        return gray_image, color_image

# === Transforms ===
transform = transforms.Compose([
    transforms.ToTensor()  # Only convert images to tensors, no resizing
])

# === DataLoader ===
train_dataset = GrayscaleToColorDataset(
    grayscale_dir="/content/dataset/train/grayscale_filtered",
    color_dir="/content/dataset/train/color_filtered",
    transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

# === Generator (U-Net) ===
class UNetGenerator(nn.Module):
    def __init__(self, in_channels=1, out_channels=3):
        super(UNetGenerator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 64, 4, 2, 1), nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2, True)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),
            nn.ConvTranspose2d(64, out_channels, 4, 2, 1), nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

# === Discriminator (PatchGAN) ===
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(4, 64, 4, 2, 1), nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2),
            nn.Conv2d(128, 1, 4, 1, 1), nn.Sigmoid()
        )

    def forward(self, x, y):
        return self.net(torch.cat([x, y], 1))  # Concatenate grayscale and color images

# === Model Init ===
generator = UNetGenerator().to(device)
discriminator = Discriminator().to(device)

# === Loss and Optimizers ===
criterion = nn.BCELoss()
l1_loss = nn.L1Loss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=2e-4)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

# === Training Loop ===
epochs = 100
for epoch in range(epochs):
    for i, (gray, color) in enumerate(train_loader):
        gray, color = gray.to(device), color.to(device)

        # === Train Discriminator ===
        fake_color = generator(gray)

        # Calculate the output size of the discriminator (patch size)
        output_size = discriminator(gray, color).size()[2:]  # (H, W)

        # Resize the labels to match the output size of the discriminator
        real_label = torch.ones((gray.size(0), 1) + output_size, device=device)
        fake_label = torch.zeros((gray.size(0), 1) + output_size, device=device)

        # Forward pass
        optimizer_D.zero_grad()
        real_output = discriminator(gray, color)
        fake_output = discriminator(gray, fake_color.detach())

        # Compute the loss
        d_loss = (criterion(real_output, real_label) + criterion(fake_output, fake_label)) * 0.5
        d_loss.backward()
        optimizer_D.step()

        # === Train Generator ===
        optimizer_G.zero_grad()
        fake_output = discriminator(gray, fake_color)

        # Generator loss
        g_loss = criterion(fake_output, real_label) + l1_loss(fake_color, color) * 100
        g_loss.backward()
        optimizer_G.step()

    print(f"Epoch [{epoch+1}/{epochs}] - D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

    # === Save sample output ===
    if (epoch + 1) % 10 == 0:
        save_image(fake_color, f"/content/fake_epoch_{epoch+1}.png")



Using device: cpu
✅ Matched 76 grayscale-color pairs.
Epoch [1/100] - D Loss: 0.6065 | G Loss: 22.6594
Epoch [2/100] - D Loss: 0.4729 | G Loss: 15.0845
Epoch [3/100] - D Loss: 0.4074 | G Loss: 14.3153
Epoch [4/100] - D Loss: 0.3214 | G Loss: 12.7364
Epoch [5/100] - D Loss: 0.4404 | G Loss: 12.1699
Epoch [6/100] - D Loss: 0.3967 | G Loss: 11.2210
Epoch [7/100] - D Loss: 0.4728 | G Loss: 12.1115
Epoch [8/100] - D Loss: 0.4385 | G Loss: 10.5087
Epoch [9/100] - D Loss: 0.4835 | G Loss: 11.7314
Epoch [10/100] - D Loss: 0.6453 | G Loss: 10.0990
Epoch [11/100] - D Loss: 0.5929 | G Loss: 10.7499
Epoch [12/100] - D Loss: 0.5897 | G Loss: 10.6210
Epoch [13/100] - D Loss: 0.5722 | G Loss: 9.5800
Epoch [14/100] - D Loss: 0.5220 | G Loss: 9.9277
Epoch [15/100] - D Loss: 0.4992 | G Loss: 13.6085
Epoch [16/100] - D Loss: 0.4397 | G Loss: 13.5181
Epoch [17/100] - D Loss: 0.4559 | G Loss: 9.4412
Epoch [18/100] - D Loss: 0.4604 | G Loss: 8.6245


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# === Visualization ===
def show_samples(generator, dataset, num_samples=4):
    generator.eval()
    fig, axs = plt.subplots(num_samples, 3, figsize=(12, 3 * num_samples))

    for i in range(num_samples):
        gray_img, real_color = dataset[i]
        with torch.no_grad():
            fake_color = generator(gray_img.unsqueeze(0).to(device))[0].cpu()

        # Convert to NumPy arrays for plotting
        gray_np = gray_img.squeeze().numpy()
        fake_np = (fake_color.permute(1, 2, 0).numpy() + 1) / 2  # Rescale from [-1, 1] to [0, 1]
        real_np = real_color.permute(1, 2, 0).numpy()

        axs[i, 0].imshow(gray_np, cmap='gray')
        axs[i, 0].set_title("Grayscale Input")
        axs[i, 1].imshow(fake_np)
        axs[i, 1].set_title("Generated Color")
        axs[i, 2].imshow(real_np)
        axs[i, 2].set_title("Ground Truth")

        for ax in axs[i]:
            ax.axis("off")

    plt.tight_layout()
    plt.show()

# Run visualization
show_samples(generator, train_dataset, num_samples=4)
