# U-Net pour la colorisation d'images
Projet de colorisation d'images à niveaux de gris par deep learning.  
Implémentation d'un d'un réseau CNN simple au format U-Net pour la prédiction des canaux RGB à partir d'images à niveaux de gris.  
Utilisation de PyTorch avec le dataset CIFAR-10.

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import numpy as np

## Création du modèle U-Net

In [18]:
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2)
        )
        
        self.middle = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        )
        
        self.final = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 3, kernel_size=1)
        )

    def forward(self, x):
        enc = self.encoder(x)
        middle = self.middle(enc)
        dec = self.decoder(middle)
        out = self.final(dec)
        return out

## Traitement du dataset CIFAR-10

In [9]:
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor()
])

# Charger les données CIFAR-10
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Préparer les labels (images originales)
train_labels = datasets.CIFAR10(root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_labels = datasets.CIFAR10(root='./data', train=False, transform=transforms.ToTensor(), download=True)

train_labels_loader = DataLoader(train_labels, batch_size=32, shuffle=True)
test_labels_loader = DataLoader(test_labels, batch_size=32, shuffle=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100.0%


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


## Fonction d'entraînement

In [15]:
def train_model(model, train_loader, train_labels_loader, criterion, optimizer, num_epochs=25):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for (inputs, _), (targets, _) in zip(train_loader, train_labels_loader):
            inputs = inputs.cuda()
            targets = targets.cuda()
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * inputs.size(0)
        
        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch}/{num_epochs - 1}, Loss: {epoch_loss:.4f}')

## Fonction de validation

In [16]:
def validate_model(model, test_loader, test_labels_loader, criterion):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for (inputs, _), (targets, _) in zip(test_loader, test_labels_loader):
            inputs = inputs.cuda()
            targets = targets.cuda()
            
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            val_loss += loss.item() * inputs.size(0)
    
    val_loss /= len(test_loader.dataset)
    print(f'Validation Loss: {val_loss:.4f}')

## Entraînement et validation du modèle

In [19]:
# Initialiser le modèle, la perte et l'optimiseur
model = UNet().cuda()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Entraîner le modèle
train_model(model, train_loader, train_labels_loader, criterion, optimizer, num_epochs=25)

# Valider le modèle
validate_model(model, test_loader, test_labels_loader, criterion)

Epoch 0/24, Loss: 0.0638


KeyboardInterrupt: 

## Test du modèle

In [None]:
def get_test_images(test_loader, test_labels_loader, num_images=5):
    test_images = []
    test_labels = []
    for (inputs, _), (targets, _) in zip(test_loader, test_labels_loader):
        test_images.extend(inputs[:num_images])
        test_labels.extend(targets[:num_images])
        if len(test_images) >= num_images:
            break
    return torch.stack(test_images), torch.stack(test_labels)

num_test_images = 5
test_images, test_labels = get_test_images(test_loader, test_labels_loader, num_images=num_test_images)

In [None]:
model.eval()
with torch.no_grad():
    predicted_images = model(test_images)

In [None]:
def imshow(img):
    img = img / 2 + 0.5  # dénormaliser
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# Afficher les résultats
for i in range(num_test_images):
    print("Image en niveaux de gris:")
    imshow(test_images[i])
    
    print("Image colorée réelle:")
    imshow(test_labels[i])
    
    print("Image colorisée prédite:")
    imshow(predicted_images[i])