In [None]:
import os
import random
import sys
from pathlib import Path
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [None]:
class ColorizationDataset(Dataset):
    """Paired (grayscale, color) image dataset for training a colorization model.

    Files are matched by filename stem, case-insensitively:
    * a trailing ``_color`` is stripped from color-file stems,
    * a trailing ``_gray`` or ``_czb`` is stripped from grayscale-file stems.
    Grayscale files without a matching color file are reported and skipped.

    Args:
        gray_dir: directory with grayscale (input) images.
        color_dir: directory with color (target) images.
        transform_gray: optional callable applied to the 1-channel PIL image.
        transform_color: optional callable applied to the RGB PIL image.

    Raises:
        NotADirectoryError: if ``gray_dir`` or ``color_dir`` is not a directory.
    """

    def __init__(self, gray_dir, color_dir, transform_gray=None, transform_color=None):
        self.gray_dir = Path(gray_dir)
        self.color_dir = Path(color_dir)
        self.transform_gray = transform_gray
        self.transform_color = transform_color

        # A wrong path would otherwise silently yield an empty dataset.
        for directory in (self.gray_dir, self.color_dir):
            if not directory.is_dir():
                raise NotADirectoryError(f"Not a directory: {directory}")

        # sorted() makes the pairing order (and therefore dataset indices) deterministic
        # across runs and platforms; raw glob() order is filesystem-dependent.
        gray_files = sorted(p for p in self.gray_dir.glob("*.*") if p.is_file())
        color_files = sorted(p for p in self.color_dir.glob("*.*") if p.is_file())

        # Map normalized stem -> color image path.
        color_dict = {}
        for f in color_files:
            name = f.stem.lower()
            if name.endswith("_color"):
                name = name[: -len("_color")]  # drop the '_color' suffix
            color_dict[name] = f

        self.paired_files = []
        for gray_path in gray_files:
            stem = gray_path.stem.lower()
            if stem.endswith("_gray") or stem.endswith("_czb"):
                stem = stem.rsplit("_", 1)[0]

            if stem in color_dict:
                self.paired_files.append((gray_path, color_dict[stem]))
            else:
                print(f"[!] Brak koloru dla: {gray_path.name}")

    def __len__(self):
        """Number of matched (grayscale, color) pairs."""
        return len(self.paired_files)

    def __getitem__(self, idx):
        """Return ``(gray, color)`` for pair ``idx``, after the optional transforms."""
        gray_path, color_path = self.paired_files[idx]

        # `with` closes the file handles deterministically; convert() returns a new,
        # fully loaded image that stays valid after the context exits.
        with Image.open(gray_path) as src:
            gray_img = src.convert("L")      # 1 channel
        with Image.open(color_path) as src:
            color_img = src.convert("RGB")   # 3 channels

        if self.transform_gray:
            gray_img = self.transform_gray(gray_img)
        if self.transform_color:
            color_img = self.transform_color(color_img)

        return gray_img, color_img



# UNet model
class UNet(nn.Module):
    """U-Net that maps a grayscale image to a (by default RGB) color image.

    Three encoder stages (64/128/256 channels, each followed by 2x max-pooling),
    a 512-channel bottleneck with dropout, then three decoder stages that
    upsample with transposed convolutions and concatenate the matching encoder
    features (skip connections). A final sigmoid keeps outputs in (0, 1), the
    same range as tensors produced by ``transforms.ToTensor()``.

    Args:
        in_channels: channels of the input image (1 for grayscale).
        out_channels: channels of the predicted image. Defaults to 3 so the
            prediction has the same shape as an RGB target; a 1-channel head
            cannot represent color, and ``nn.MSELoss`` would only broadcast it
            against the 3-channel target. Pass ``out_channels=1`` to load
            checkpoints that were saved with a single-channel head.

    Input height and width must be divisible by 8 (three 2x poolings);
    otherwise the skip-connection concatenations fail on mismatched sizes.
    """

    def __init__(self, in_channels=1, out_channels=3):
        super(UNet, self).__init__()

        def conv_block(in_c, out_c):
            # Two 3x3 conv -> BatchNorm -> LeakyReLU layers; padding=1 keeps H and W.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(0.2, inplace=True),
            )

        self.enc1 = conv_block(in_channels, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = conv_block(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = conv_block(128, 256)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(256, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.5),
        )

        # Decoder inputs are doubled because the upsampled features are
        # concatenated with the same-resolution encoder features.
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = conv_block(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = conv_block(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = conv_block(128, 64)
        self.final = nn.Conv2d(64, out_channels, kernel_size=1)
        self.activation = nn.Sigmoid()

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, out_channels, H, W)."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b = self.bottleneck(self.pool3(e3))
        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))
        out = self.final(d1)
        return self.activation(out)

In [None]:
# Transforms for grayscale (input) and color (target) images.
# Both use the same size so the model output and the target have matching shapes.
IMAGE_SIZE = (512, 512)

transform_gray = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor()
])

transform_color = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor()
])

# === Paths ===
color_dir = ""      # folder with color flower images
bw_dir = ""         # folder with grayscale flower images

# Dataset. Pass the directories by keyword: the constructor signature is
# (gray_dir, color_dir, ...), so passing (color_dir, bw_dir) positionally would
# swap input and target.
dataset = ColorizationDataset(
    gray_dir=bw_dir,
    color_dir=color_dir,
    transform_gray=transform_gray,
    transform_color=transform_color,
)

print(f"Liczba sparowanych plików w dataset: {len(dataset)}")
if len(dataset) == 0:
    raise RuntimeError("No grayscale/color pairs found - check color_dir and bw_dir.")

# DataLoader
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# For the flower dataset this notebook was written for, the number of paired files should be 21625.

Liczba sparowanych plików w dataset: 21625


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python's and torch's RNGs so subset sampling, shuffling, dropout and
# weight initialization are reproducible between runs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

model = UNet().to(device)

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

epo = 5              # training cycles, each on a freshly drawn random subset
num_epochs = 4       # epochs per cycle
subset_size = 500    # images drawn per cycle

# Checkpoints are written here after every cycle.
save_dir = Path("C:\\zdjęcia na chwile\\kwiaty_same_model_i_data")
save_dir.mkdir(parents=True, exist_ok=True)

# Full dataset (created once). The constructor's grayscale argument is `gray_dir`.
full_dataset = ColorizationDataset(

    # === Paths ===
    color_dir="",         # folder with color flower images
    gray_dir="",          # folder with grayscale flower images

    transform_color=transform_color,
    transform_gray=transform_gray
)
if len(full_dataset) == 0:
    raise RuntimeError("No grayscale/color pairs found - check color_dir and gray_dir.")

for epos in range(epo):
    # === Draw a fresh random subset of images for this training cycle ===
    subset_indices = random.sample(range(len(full_dataset)), min(subset_size, len(full_dataset)))
    subset = Subset(full_dataset, subset_indices)
    # NOTE: this rebinds the global `dataloader`, which the fine-tuning cell reuses.
    dataloader = DataLoader(subset, batch_size=8, shuffle=True)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for bw, color in dataloader:
            bw, color = bw.to(device), color.to(device)

            output = model(bw)
            loss = criterion(output, color)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch mean is exact when the last batch is smaller.
            running_loss += loss.item() * bw.size(0)

        # Mean loss over the whole epoch, not just the last batch.
        epoch_loss = running_loss / len(subset)
        print(f"[Epo {epos+1}/{epo}] Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    # === Save the model after every cycle (i.e. after num_epochs epochs) ===
    save_path = save_dir / f"model_gray_to_color_flowers_{epos+1}_{epoch_loss:.4f}.pth"
    torch.save(model.state_dict(), save_path)
    print(f"Zapisano model: {save_path}")


[Epo 1/5] Epoch [1/4], Loss: 0.0481
[Epo 1/5] Epoch [2/4], Loss: 0.0383


## Dodatkowy trening (dotrenowanie zapisanego modelu)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)

model_path = ""    # checkpoint (state_dict .pth) you want to continue training
zapis_path = ""    # optional explicit output path (overwritten after every cycle);
                   # if empty, "<model stem>_tren_<cycle>_<loss>.pth" is written next to model_path

if not model_path:
    raise ValueError("Set model_path to the checkpoint you want to continue training.")

# Load previously saved weights.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location=device))

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# NOTE(review): this cell reuses the global `dataloader` from an earlier cell
# (the full-dataset loader, or the last random subset from the training cell,
# depending on which cell ran last).
epo = 5
num_epochs = 4
for i in range(epo):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for bw, color in dataloader:
            bw, color = bw.to(device), color.to(device)

            optimizer.zero_grad()
            output = model(bw)
            loss = criterion(output, color)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch mean is exact when the last batch is smaller.
            running_loss += loss.item() * bw.size(0)

        epoch_loss = running_loss / len(dataloader.dataset)
        print(f"[Epo {i+1}/{epo}] Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    # One checkpoint per cycle; the derived name keeps earlier cycles from being overwritten.
    if zapis_path:
        save_path = Path(zapis_path)
    else:
        base = Path(model_path)
        save_path = base.with_name(f"{base.stem}_tren_{i+1}_{epoch_loss:.4f}.pth")
    torch.save(model.state_dict(), save_path)
    print(f"Zapisano model: {save_path}")

## Wczytanie modelu i koloryzacja jednego zdjęcia

In [None]:
# === Paths ===
input_path = ""
output_path = ""
model_path = ""

# Side length of the square image the model sees. It must match the Resize used
# during training (512x512); a different size shifts the input distribution.
MODEL_INPUT_SIZE = 512

# === Check the input before spending time on loading the model ===
# Raise instead of sys.exit(): in Jupyter, sys.exit() only surfaces as a bare
# SystemExit traceback, while FileNotFoundError states the actual problem.
if not os.path.exists(input_path):
    raise FileNotFoundError(f"Plik {input_path} nie istnieje.")

# === Load the model ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNet().to(device)
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# === Load the grayscale image ===
with Image.open(input_path) as src:
    img = src.convert("L")
# PIL's img.size is (width, height), but torchvision's resize expects (height, width).
original_hw = [img.height, img.width]

# === Scale to the model's input size ===
transform_bw = transforms.Compose([
    transforms.Resize((MODEL_INPUT_SIZE, MODEL_INPUT_SIZE)),
    transforms.ToTensor()
])

bw_tensor = transform_bw(img).unsqueeze(0).to(device)

# === Predict colors ===
with torch.no_grad():
    output = model(bw_tensor)

# === Resize the prediction back to the original resolution ===
output_resized = transforms.functional.resize(output.squeeze(0), original_hw)

# === Save the result ===
save_image(output_resized, output_path)
print(f"Zapisano kolorowy obraz o oryginalnej rozdzielczości: {output_path}")


Zapisano kolorowy obraz o oryginalnej rozdzielczości: C:\zdjęcia na chwile\kwiaty_same_model_i_data\wynik_kolor_fullres.png
