In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
from skimage import io, color
import numpy as np

# Define the model
class ColorizationModel(nn.Module):
    """Convolutional encoder/decoder mapping one lightness map to two chroma maps.

    ``forward`` takes a tensor shaped (N, 1, H, W) and returns a tensor shaped
    (N, 2, H', W') whose values lie in [-128, 128] (a ``tanh`` output scaled
    by 128). Three stride-2 convolutions shrink each spatial side to
    ceil(side / 2) per step and three 2x upsamplings grow it back, so
    H' == H and W' == W only when H and W are multiples of 8.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride) for each 3x3 encoder conv; every
        # conv is followed by a ReLU. Order matters: it fixes both the
        # Sequential indices (state_dict keys) and the weight-init order.
        encoder_spec = (
            (1, 8, 2),
            (8, 8, 1),
            (8, 16, 1),
            (16, 16, 2),
            (16, 32, 1),
            (32, 32, 2),
        )
        encoder_layers = []
        for c_in, c_out, step in encoder_spec:
            encoder_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=step, padding=1)
            )
            encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # Each decoder stage: nearest-neighbour 2x upsample, 3x3 conv, then an
        # activation (ReLU for hidden stages, Tanh to bound the final output).
        decoder_spec = (
            (32, 32, nn.ReLU),
            (32, 16, nn.ReLU),
            (16, 2, nn.Tanh),
        )
        decoder_layers = []
        for c_in, c_out, activation in decoder_spec:
            decoder_layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
            decoder_layers.append(
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)
            )
            decoder_layers.append(activation())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Predict the two output channels for lightness input ``x``."""
        features = self.encoder(x)
        squashed = self.decoder(features)
        # Stretch tanh's [-1, 1] range to [-128, 128].
        return 128 * squashed

# Load the image and split it into the model input (L) and the target (ab).
image = io.imread('woman.jpg')
# rgb2lab needs a 3-channel RGB image: promote grayscale, drop any alpha.
if image.ndim == 2:
    image = color.gray2rgb(image)
elif image.shape[-1] == 4:
    image = color.rgba2rgb(image)
lab_image = color.rgb2lab(image)

# The encoder halves H and W three times and the decoder doubles them three
# times, so the prediction only lines up with the target when both sides are
# multiples of 8. Crop the few trailing rows/columns that break this.
DOWNSAMPLE_FACTOR = 8
height = lab_image.shape[0] - lab_image.shape[0] % DOWNSAMPLE_FACTOR
width = lab_image.shape[1] - lab_image.shape[1] % DOWNSAMPLE_FACTOR
if height == 0 or width == 0:
    raise ValueError(
        f"image must be at least {DOWNSAMPLE_FACTOR}x{DOWNSAMPLE_FACTOR} pixels, "
        f"got {lab_image.shape[0]}x{lab_image.shape[1]}"
    )
lab_image = lab_image[:height, :width]

# X: L channel as (1, 1, H, W).
# Y: ab channels divided by 128 (roughly [-1, 1]) as (1, 2, H, W).
# The ab slice is (H, W, 2), so it must be *transposed* to channels-first;
# a plain reshape would reinterpret memory and interleave a/b values.
X = lab_image[:, :, 0][np.newaxis, np.newaxis]
Y = np.transpose(lab_image[:, :, 1:] / 128.0, (2, 0, 1))[np.newaxis]

# Convert to PyTorch float32 tensors.
X = torch.from_numpy(np.ascontiguousarray(X)).float()
Y = torch.from_numpy(np.ascontiguousarray(Y)).float()

# Initialize the model, loss function, and optimizer
model = ColorizationModel()
criterion = nn.MSELoss()
optimizer = optim.RMSprop(model.parameters())

# Training loop (full batch: the single image is the whole dataset).
num_epochs = 1000
model.train()
for epoch in range(num_epochs):
    # The model returns tanh * 128, i.e. values in [-128, 128]; divide by 128
    # so the prediction is compared with Y on the same normalized scale.
    outputs = model(X)
    loss = criterion(outputs / 128.0, Y)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Predict ab for the image and rebuild a full 3-channel Lab image by putting
# the input L channel back in front (lab2rgb cannot take ab alone).
model.eval()
with torch.no_grad():
    predicted_ab = model(X)[0].numpy()           # (2, H, W), in [-128, 128]
output = np.transpose(predicted_ab, (1, 2, 0))   # (H, W, 2)
lab_result = np.concatenate(
    [X[0, 0].numpy()[:, :, np.newaxis], output], axis=-1
).astype(np.float64)                             # (H, W, 3)
output_rgb = color.lab2rgb(lab_result)

# Write 8-bit PNGs; converting explicitly avoids skimage's lossy-float path.
result_u8 = (np.clip(output_rgb, 0.0, 1.0) * 255).round().astype(np.uint8)
gray_u8 = (np.clip(color.rgb2gray(output_rgb), 0.0, 1.0) * 255).round().astype(np.uint8)
io.imsave("img_result_pytorch.png", result_u8)
io.imsave("img_gray_version_pytorch.png", gray_u8)
