In [8]:
import os

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from skimage import color
from torchvision.datasets import CIFAR10
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
#from google.colab import files


In [2]:


class CIFAR10_Lab(Dataset):
  """CIFAR-10 images served as (L, ab) tensor pairs in Lab colour space.

  Each item is converted from RGB to Lab on access and split into the
  lightness channel and the two colour channels.

  Args:
    train: forwarded to ``CIFAR10``; selects the training or the test split.
    root: directory handed to ``CIFAR10`` as its ``root`` argument.
    download: forwarded to ``CIFAR10`` as its ``download`` argument.
  """

  def __init__(self, train=True, root='./data', download=True):
    # Underlying torchvision dataset object.
    self.data = CIFAR10(root=root, train=train, download=download)
    # Raw image array taken from the torchvision dataset's ``.data`` attribute.
    self.images = self.data.data

  def __len__(self):
    """Return the number of images in the selected split."""
    return len(self.images)

  def __getitem__(self, idx):
    """Return ``(L, ab)`` float32 tensors for the image at ``idx``.

    ``L`` is the lightness channel divided by 100 with a leading channel
    axis added; ``ab`` holds the two colour channels divided by 128, moved
    to channel-first layout.
    """
    # Scale pixel values by 1/255 before the Lab conversion.
    img_rgb = self.images[idx] / 255.0
    img_lab = color.rgb2lab(img_rgb)

    lightness = img_lab[:, :, 0] / 100.0
    chroma = img_lab[:, :, 1:] / 128.0

    # (H, W) -> (1, H, W)
    L_tensor = torch.tensor(lightness, dtype=torch.float32).unsqueeze(0)
    # (H, W, 2) -> (2, H, W)
    ab_tensor = torch.tensor(chroma, dtype=torch.float32).permute(2, 0, 1)

    return L_tensor, ab_tensor

In [4]:
# Lab-space views of the CIFAR-10 train and test splits (train is built first).
train_dataset, test_dataset = (
    CIFAR10_Lab(train=is_train) for is_train in (True, False)
)

# Training batches are shuffled and loaded by two worker processes; the test
# loader keeps the DataLoader defaults for everything except the batch size.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)
test_loader = DataLoader(dataset=test_dataset, batch_size=32)

Files already downloaded and verified
Files already downloaded and verified


In [5]:


class ColorizationCNN(nn.Module):
    """Convolutional encoder-decoder that maps 1 input channel to 2 output channels.

    The encoder halves the spatial size twice with max-pooling while widening
    the channels to 256; the decoder doubles the spatial size twice with
    transposed convolutions and ends in a Tanh, so outputs lie in [-1, 1].
    Shape comments below assume a 32x32 input.
    """

    def __init__(self):
        super().__init__()

        # Encoder: shrinks the spatial size while increasing the channel depth.
        encoder_layers = [
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1),     # (1, 32, 32) -> (64, 32, 32)
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),   # (128, 32, 32)
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                                   # (128, 16, 16)
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),                                   # (256, 8, 8)
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: restores the original 32x32 size, but with 2 channels.
        decoder_layers = [
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2),  # (128, 16, 16)
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2),   # (64, 32, 32)
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=2, kernel_size=3, padding=1),             # (2, 32, 32)
            nn.Tanh(),  # keeps the output in [-1, 1], like the ab channels
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run the encoder and then the decoder on ``x`` and return the result."""
        return self.decoder(self.encoder(x))


In [6]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Build the network and move it to the chosen device before creating the optimizer.
model = ColorizationCNN()
model = model.to(device)
# To resume from saved weights:
#model.load_state_dict(torch.load('colorization_weights.pth'))

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
# To resume the optimizer state:
#optimizer.load_state_dict(torch.load('optimizer_weights.pth'))

# Mean-squared-error loss used for both training and evaluation.
criterion = nn.MSELoss()

cpu


In [None]:
num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    # Per-epoch progress bar; leave=False removes it once the epoch finishes.
    progress = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{num_epochs}]", leave=False)
    for L_batch, ab_batch in progress:
        L_batch = L_batch.to(device)
        ab_batch = ab_batch.to(device)

        # Standard optimisation step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        predicted_ab = model(L_batch)
        loss = criterion(predicted_ab, ab_batch)
        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it for the running sum and the bar.
        batch_loss = loss.item()
        running_loss += batch_loss
        progress.set_postfix(loss=batch_loss)

    # Average of the per-batch losses over this epoch.
    avg_loss = running_loss / len(train_loader)
    print(f"Эпоха {epoch+1}/{num_epochs}, Потеря: {avg_loss:.4f}")


Epoch [1/20]:   0%|          | 0/1563 [00:00<?, ?it/s]

In [None]:
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before writing into it.
os.makedirs("saved_models", exist_ok=True)

# Persist the model weights and the optimizer state as separate files.
torch.save(model.state_dict(), "saved_models/colorization_weights_lab_V1.pth")
torch.save(optimizer.state_dict(), 'saved_models/optimizer_weights_lab_V1.pth')

#files.download("colorization_weights_lab_V1.pth")
#files.download("optimizer_weights_lab_V1.pth")

In [None]:
# Evaluate on the test loader with gradients disabled and the model in eval mode.
model.eval()
test_loss = 0.0   # sum of (batch mean loss * batch size) over all batches
num_samples = 0   # number of samples seen so far

with torch.no_grad():
    for L, ab in tqdm(test_loader, desc="Testing"):
        L, ab = L.to(device), ab.to(device)

        output = model(L)
        loss = criterion(output, ab)

        # Weight each batch by its size so a smaller final batch does not
        # count as much as a full one (assumes `criterion` returns the batch
        # mean, as nn.MSELoss does by default).
        batch_size = L.size(0)
        test_loss += loss.item() * batch_size
        num_samples += batch_size

print(f"Test Loss: {test_loss / num_samples:.6f}")
