## IMAGE COLORIZER

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [None]:
# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# CIFAR-10 data pipeline: both splits are downloaded into ./data on first use
# and every image goes through the same ToTensor-only transform.
transform = transforms.Compose([transforms.ToTensor()])

train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Training batches are reshuffled on every pass over the data.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)
# The held-out split is iterated in its stored order.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
    num_workers=2,
)

In [None]:
# Colorization network: a small U-Net-style encoder/decoder with skip
# connections and a linear (no sigmoid) output layer.
class ColorizationNet(nn.Module):
    """Map a 1-channel image batch to a 3-channel image batch.

    The encoder has three conv stages (64, 128, 256 channels) with 2x2
    max-pooling between them. The decoder upsamples twice with transposed
    convolutions; after each upsampling step the result is concatenated
    with the encoder feature map of the same stage. A final 1x1 convolution
    produces the 3 output channels with no activation applied.
    """

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return a 3x3 conv -> batch norm -> ReLU stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _match_size(skip, reference):
        """Crop ``skip`` (top-left anchored) to ``reference``'s spatial size.

        The skip tensor is returned untouched when the sizes already agree.
        """
        target_hw = reference.shape[2:]
        if skip.shape[2:] == target_hw:
            return skip
        return skip[:, :, :target_hw[0], :target_hw[1]]

    def __init__(self):
        super().__init__()

        # Encoder stages; the single pooling layer is reused between them.
        self.enc1 = self._conv_block(1, 64)
        self.enc2 = self._conv_block(64, 128)
        self.enc3 = self._conv_block(128, 256)
        self.pool = nn.MaxPool2d(2, 2)

        # Decoder: each conv stage takes upsampled + skip channels concatenated.
        self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec1 = self._conv_block(256, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec2 = self._conv_block(128, 64)

        # 1x1 projection to the three output channels.
        self.final = nn.Conv2d(64, 3, kernel_size=1)

    def forward(self, x):
        """Run the encoder/decoder and return the raw 3-channel prediction.

        When a spatial dimension is halved with rounding during pooling, the
        upsampled map is smaller than the matching encoder map; the encoder
        map is then cropped, so the output can be slightly smaller than
        the input.
        """
        # Encoder path.
        skip_full = self.enc1(x)
        skip_half = self.enc2(self.pool(skip_full))
        bottleneck = self.enc3(self.pool(skip_half))

        # First decoder stage (half resolution).
        up_half = self.up1(bottleneck)
        merged_half = torch.cat(
            [up_half, self._match_size(skip_half, up_half)], dim=1
        )
        dec_half = self.dec1(merged_half)

        # Second decoder stage (full resolution).
        up_full = self.up2(dec_half)
        merged_full = torch.cat(
            [up_full, self._match_size(skip_full, up_full)], dim=1
        )
        dec_full = self.dec2(merged_full)

        return self.final(dec_full)

In [None]:
# Instantiate the network and place it on the device selected earlier.
model = ColorizationNet()
model = model.to(device)

# Resume from a saved checkpoint when one is present next to the notebook.
import os
weights_path = "colorization_model_weights.pth"
if os.path.exists(weights_path):
    model.load_state_dict(
        torch.load(weights_path, map_location=device)
    )
    print(f"Loaded model weights from {weights_path}")

# Objective and optimizer: mean absolute error, minimized with Adam.
criterion = nn.L1Loss()
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)

# Convert RGB image to grayscale
def rgb_to_gray(img):
    """Collapse the channel axis of an image batch by plain averaging.

    Args:
        img: Tensor whose dimension 1 is the channel axis, e.g. an RGB batch
            of shape (N, 3, H, W).

    Returns:
        Tensor with the same layout but a single channel (N, 1, H, W); the
        channel dimension is kept so the result can be fed to a conv layer.

    Note:
        This is an unweighted mean over the channels, not a perceptual
        (luma-weighted) grayscale conversion.
    """
    return img.mean(dim=1, keepdim=True)

# Notes on the optimizer configured above:
# - model.parameters() gives the optimizer access to the model's adjustable parameters.
# - lr=0.0001 sets the learning rate, i.e. the step size used when updating the
#   parameters. A smaller learning rate leads to slower but potentially more
#   stable training.


In [None]:
# Only train if not loading weights: when a checkpoint file already exists at
# weights_path, training is skipped entirely.
skip_training = os.path.exists(weights_path)
if not skip_training:
    # Training loop
    EPOCHS = 50
    # Make sure BatchNorm layers use batch statistics and update their running
    # averages, even if the model was switched to eval mode by an earlier run
    # of another cell.
    model.train()
    for epoch in range(EPOCHS):
        for i, (images, _) in enumerate(train_loader):
            # The grayscale input is derived from the colour target itself;
            # the dataset labels are not used.
            grayscale_images = rgb_to_gray(images).to(device)
            images = images.to(device)

            # Forward pass
            outputs = model(grayscale_images)
            # NOTE(review): clamping before the loss gives zero gradient for
            # every output value outside [0, 1]; consider computing the loss
            # on the raw outputs and clamping only for display.
            outputs = torch.clamp(outputs, 0.0, 1.0)
            loss = criterion(outputs, images)

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Print statistics and gradient info every 100 steps. The reported
            # values are averages of the per-parameter-tensor gradient mean
            # and standard deviation.
            if i % 100 == 0:
                grads = [p.grad for p in model.parameters() if p.grad is not None]
                grad_mean = sum(g.mean().item() for g in grads) / len(grads) if grads else 0.0
                grad_std = sum(g.std().item() for g in grads) / len(grads) if grads else 0.0
                print(f"Epoch [{epoch+1}/{EPOCHS}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}, Grad mean: {grad_mean:.6f}, Grad std: {grad_std:.6f}")

    print("Finished Training")
    # Save model weights after training, to the same path the skip check and
    # the loader read from, so the two can never drift apart.
    torch.save(model.state_dict(), weights_path)
    print(f"Model weights saved to {weights_path}")
else:
    print("Weights loaded, skipping training.")

In [None]:
# --- Colorize your own black and white image ---
from PIL import Image
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Path to your grayscale image (change this to your file)
img_path = r"C:\Users\thedy\Downloads\girl-nikon-z-series-tips-nikon-cameras-lenses-accessories.webp"

# Load image as grayscale ('L' mode); the context manager closes the file
# handle once the converted copy has been made.
with Image.open(img_path) as source_img:
    img = source_img.convert('L')

# Convert to tensor and normalize to [0,1]
to_tensor = transforms.ToTensor()  # outputs shape (1, H, W)
gray_tensor = to_tensor(img).unsqueeze(0)  # shape (1, 1, H, W)

# Move to device
gray_tensor = gray_tensor.to(device)

# Run through model in eval mode so BatchNorm uses its learned running
# statistics instead of the statistics of this single image. The previous
# mode is restored afterwards so a later training run is not affected.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        colorized = model(gray_tensor)
finally:
    model.train(was_training)

# The network has no output activation, so clamp to the valid [0, 1] range
# (the same clamp that is applied during training) before display.
colorized = colorized.squeeze(0).clamp(0.0, 1.0).cpu()  # shape (3, H', W')

# Convert to numpy for display (channels last)
colorized_np = colorized.permute(1, 2, 0).numpy()

# Show original and colorized side by side
fig, axs = plt.subplots(1, 2, figsize=(8, 4))
axs[0].imshow(img, cmap='gray')
axs[0].set_title('Input Grayscale')
axs[0].axis('off')
axs[1].imshow(colorized_np)
axs[1].set_title('Colorized Output')
axs[1].axis('off')
plt.show()
