In [1]:
!pip install torch torchvision pywavelets tqdm




In [12]:
#GANの学習
import os
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm


# ========== Dataset ==========
class AE2CleanDataset(Dataset):
    """Paired image dataset: an input ("AE") image and its clean target.

    Pairs are matched by file name: item ``idx`` is ``ae_dir/<name>`` together
    with ``clean_dir/<name>``, both converted to RGB.

    Args:
        ae_dir: Directory containing the input images.
        clean_dir: Directory containing the target images under the same names.
        transform: Optional callable applied to each image of the pair.
    """

    # Only these (case-insensitive) extensions are treated as images, so stray
    # directory entries such as thumbnails caches or sub-folders are never
    # handed to Image.open.
    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, ae_dir, clean_dir, transform=None):
        self.ae_dir = ae_dir
        self.clean_dir = clean_dir
        self.transform = transform
        self.filenames = sorted(
            name for name in os.listdir(ae_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(ae_img, clean_img)`` for the ``idx``-th file name."""
        name = self.filenames[idx]

        ae_img = Image.open(os.path.join(self.ae_dir, name)).convert("RGB")
        clean_img = Image.open(os.path.join(self.clean_dir, name)).convert("RGB")

        if self.transform:
            ae_img = self.transform(ae_img)
            clean_img = self.transform(clean_img)

        return ae_img, clean_img


# ========== Generator (U-Net) ==========
class UNetGenerator(nn.Module):
    """U-Net style generator with five stride-2 encoder stages and skip connections.

    Each encoder stage halves the spatial size; the decoder doubles it again
    and concatenates the matching encoder feature map after every up-stage.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        features: Width of the first encoder stage (later stages use multiples).
        use_tanh: Apply the decoder's final ``nn.Tanh`` to the output.
            The original ``forward`` returned the raw ``ConvTranspose2d``
            output and never ran the declared Tanh layer. The default
            (``False``) keeps that behaviour so checkpoints trained that way
            produce unchanged outputs; pass ``True`` for new training runs
            that should emit values bounded to [-1, 1].
    """

    def __init__(self, in_channels=3, out_channels=3, features=64, use_tanh=False):
        super().__init__()
        self.use_tanh = use_tanh

        self.encoder = nn.Sequential(
            self.block(in_channels, features, normalize=False),
            self.block(features, features * 2),
            self.block(features * 2, features * 4),
            self.block(features * 4, features * 8),
            self.block(features * 8, features * 8),
        )

        # Input widths are doubled from the second stage on because each
        # up-stage output is concatenated with an encoder skip.
        self.decoder = nn.Sequential(
            self.upblock(features * 8, features * 8),
            self.upblock(features * 8 * 2, features * 4),
            self.upblock(features * 4 * 2, features * 2),
            self.upblock(features * 2 * 2, features),
            nn.ConvTranspose2d(features * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def block(self, in_c, out_c, normalize=True):
        """Build one encoder stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2)."""
        layers = [nn.Conv2d(in_c, out_c, 4, 2, 1)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_c))
        layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    def upblock(self, in_c, out_c):
        """Build one decoder stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    def forward(self, x):
        """Run the encoder, then the decoder with skip concatenation."""
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)

        # The deepest feature map is the decoder input, not a skip; the rest
        # are consumed deepest-first.
        for stage, skip in zip(self.decoder[:-2], reversed(skips[:-1])):
            x = torch.cat([stage(x), skip], 1)

        x = self.decoder[-2](x)
        if self.use_tanh:
            x = self.decoder[-1](x)
        return x


# ========== Discriminator ==========
class PatchDiscriminator(nn.Module):
    """Conditional patch discriminator producing a map of per-patch probabilities.

    The two images passed to ``forward`` are concatenated on the channel axis,
    run through three stride-2 conv stages and a final stride-1 conv, and
    squashed with a sigmoid.

    Args:
        in_channels: Channel count of the concatenated pair.
        features: Width of the first conv stage.
    """

    def __init__(self, in_channels=6, features=64):
        super().__init__()
        widths = [features, features * 2, features * 4]

        stages = [self.block(in_channels, widths[0], normalize=False)]
        for narrow, wide in zip(widths, widths[1:]):
            stages.append(self.block(narrow, wide))
        stages.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=1))
        stages.append(nn.Sigmoid())

        self.net = nn.Sequential(*stages)

    def block(self, in_c, out_c, normalize=True):
        """Return a 4x4 stride-2 conv stage with optional BatchNorm and LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_c, out_c, 4, 2, 1)
        if normalize:
            return nn.Sequential(conv, nn.BatchNorm2d(out_c), nn.LeakyReLU(0.2))
        return nn.Sequential(conv, nn.LeakyReLU(0.2))

    def forward(self, x, y):
        """Score the pair ``(x, y)``; both are concatenated along dim 1."""
        pair = torch.cat([x, y], 1)
        return self.net(pair)


# ========== 学習ループ ==========
def train():
    """Train the U-Net generator against the patch discriminator.

    Runs 100 epochs over the paired dataset with batch size 8. After every
    epoch the generator weights are written to ``./gen_weights_epoch{N}.pth``
    and the generator output for the last batch of that epoch is written to
    ``./outputs/fake_epoch{N}.png``. The ``./outputs`` directory is not created
    here; the ``__main__`` guard creates it before calling this function.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resize to 256x256 and map pixel values from [0, 1] to [-1, 1].
    transform = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize([0.5]*3, [0.5]*3),
    ])

    dataset = AE2CleanDataset("C:\\Users\\sit\\wavelet_CGAN\\train\\AE_wavelet","C:\\Users\\sit\\wavelet_CGAN\\train\\normal256" , transform)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    gen = UNetGenerator().to(device)
    disc = PatchDiscriminator().to(device)

    opt_g = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))

    # BCELoss expects probabilities; PatchDiscriminator ends in a Sigmoid.
    bce = nn.BCELoss()
    l1 = nn.L1Loss()

    for epoch in range(100):
        loop = tqdm(loader, desc=f"Epoch {epoch+1}/100")
        for ae_img, clean_img in loop:
            ae_img, clean_img = ae_img.to(device), clean_img.to(device)

            # === Discriminator ===
            # The generated image is detached so the discriminator loss does
            # not back-propagate into the generator.
            fake_img = gen(ae_img).detach()
            real_pred = disc(ae_img, clean_img)
            fake_pred = disc(ae_img, fake_img)

            real_loss = bce(real_pred, torch.ones_like(real_pred))
            fake_loss = bce(fake_pred, torch.zeros_like(fake_pred))
            d_loss = (real_loss + fake_loss) / 2

            opt_d.zero_grad()
            d_loss.backward()
            opt_d.step()

            # === Generator ===
            # Second generator pass, this time keeping the graph for opt_g.
            fake_img = gen(ae_img)
            pred = disc(ae_img, fake_img)
            adv_loss = bce(pred, torch.ones_like(pred))
            # L1 reconstruction term is weighted 100x relative to the adversarial term.
            l1_loss = l1(fake_img, clean_img) * 100

            g_loss = adv_loss + l1_loss

            opt_g.zero_grad()
            g_loss.backward()
            opt_g.step()

            loop.set_postfix(G_loss=g_loss.item(), D_loss=d_loss.item())

        # Save checkpoints
        torch.save(gen.state_dict(), f"./gen_weights_epoch{epoch+1}.pth")
        # fake_img is the last batch of the epoch; rescale it back to [0, 1] for saving.
        save_image(fake_img * 0.5 + 0.5, f"./outputs/fake_epoch{epoch+1}.png")


# Script entry point: make sure the sample-image directory exists, then train.
if __name__ == "__main__":
    # train() writes ./outputs/fake_epoch{N}.png but does not create the directory.
    os.makedirs("outputs", exist_ok=True)
    train()


Epoch 1/100: 100%|██████████| 625/625 [00:33<00:00, 18.50it/s, D_loss=0.163, G_loss=9.36] 
Epoch 2/100: 100%|██████████| 625/625 [00:34<00:00, 18.31it/s, D_loss=0.0113, G_loss=11.4] 
Epoch 3/100: 100%|██████████| 625/625 [00:35<00:00, 17.72it/s, D_loss=0.0121, G_loss=12.4] 
Epoch 4/100: 100%|██████████| 625/625 [00:35<00:00, 17.62it/s, D_loss=0.0122, G_loss=10.4] 
Epoch 5/100: 100%|██████████| 625/625 [00:35<00:00, 17.85it/s, D_loss=0.00796, G_loss=11.4]
Epoch 6/100: 100%|██████████| 625/625 [00:34<00:00, 17.93it/s, D_loss=0.531, G_loss=6.19]  
Epoch 7/100: 100%|██████████| 625/625 [00:35<00:00, 17.67it/s, D_loss=0.773, G_loss=5.09] 
Epoch 8/100: 100%|██████████| 625/625 [00:34<00:00, 18.34it/s, D_loss=0.234, G_loss=7.48] 
Epoch 9/100: 100%|██████████| 625/625 [00:37<00:00, 16.68it/s, D_loss=0.521, G_loss=5.48] 
Epoch 10/100: 100%|██████████| 625/625 [00:38<00:00, 16.23it/s, D_loss=0.368, G_loss=6.92]
Epoch 11/100: 100%|██████████| 625/625 [00:37<00:00, 16.65it/s, D_loss=0.404, G_loss=

In [3]:
#AE画像からwavelet変換し、cleanに保存
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pywt
import numpy as np
import os
from tqdm import tqdm

# Switch to the project working directory; the relative paths used later in
# this cell ('train/AE', 'generator.pth', ...) are resolved against it.
os.chdir("C:\\Users\\sit\\wavelet_CGAN")

# ======== Dataset ========
class AEDataset(Dataset):
    """Paired image dataset: an input ("AE") image and its clean target.

    Pairs are matched by file name between ``ae_dir`` and ``clean_dir``.

    Args:
        ae_dir: Directory containing the input images.
        clean_dir: Directory containing the target images under the same names.
        transform: Optional callable applied to each image. When omitted, a
            default pipeline resizes to 256x256, converts to a tensor and
            normalizes every channel with mean 0.5 / std 0.5.
    """

    def __init__(self, ae_dir, clean_dir, transform=None):
        self.ae_dir = ae_dir
        self.clean_dir = clean_dir
        # Sorted so indices are stable across runs and platforms; the suffix
        # check is case-insensitive so files such as "IMG.PNG" are not dropped.
        self.names = sorted(
            f for f in os.listdir(ae_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        self.transform = transform or transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*3, [0.5]*3)
        ])

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.names)

    def __getitem__(self, idx):
        """Return ``(ae_img, clean_img)`` for the ``idx``-th file name."""
        name = self.names[idx]
        ae_img = Image.open(os.path.join(self.ae_dir, name)).convert('RGB')
        clean_img = Image.open(os.path.join(self.clean_dir, name)).convert('RGB')
        if self.transform:
            ae_img = self.transform(ae_img)
            clean_img = self.transform(clean_img)
        return ae_img, clean_img

# ======== Waveletノイズ除去 ========
def _shrink_detail_band(band, threshold):
    """Threshold one detail sub-band relative to its largest absolute coefficient."""
    peak = np.max(np.abs(band))
    if peak == 0:
        # An all-zero band has nothing to shrink. Returning it unchanged also
        # avoids calling pywt.threshold with a zero threshold on zero-valued
        # coefficients (a 0/0 in the soft-threshold formula).
        return band
    return pywt.threshold(band, threshold * peak)


def wavelet_denoise(img_tensor, wavelet='haar', level=1, threshold=0.00001):
    """Threshold the wavelet detail coefficients of every image plane.

    Each (batch, channel) plane is decomposed with ``pywt.wavedec2``, the
    detail sub-bands are thresholded with ``pywt.threshold`` and the plane is
    reconstructed with ``pywt.waverec2``. The approximation band is untouched.

    Args:
        img_tensor: 4-D tensor indexed as (batch, channel, row, column).
        wavelet: Wavelet name passed to PyWavelets.
        level: Decomposition level.
        threshold: Fraction of each detail sub-band's largest *absolute*
            coefficient used as that band's threshold. (Using the signed
            maximum, as before, gives a negative threshold for a band whose
            coefficients are all negative.)

    Returns:
        A float tensor on ``img_tensor``'s device, cropped to the input's
        spatial size and clamped to [-1, 1].
    """
    with torch.no_grad():
        img_np = img_tensor.cpu().numpy()
        height, width = img_np.shape[2], img_np.shape[3]
        denoised = []
        for sample in img_np:
            channels = []
            for plane in sample:
                coeffs = pywt.wavedec2(plane, wavelet=wavelet, level=level)
                shrunk = [coeffs[0]]
                for detail_bands in coeffs[1:]:
                    shrunk.append(tuple(_shrink_detail_band(band, threshold) for band in detail_bands))
                restored = pywt.waverec2(shrunk, wavelet)
                # Reconstruction can be larger than the input; crop back.
                channels.append(restored[:height, :width])
            denoised.append(np.stack(channels))
        denoised = np.stack(denoised)
        return torch.tensor(denoised).to(img_tensor.device).float().clamp(-1, 1)

def save_wavelet_clean_images(ae_dir, clean_dir):
    """Write a wavelet-processed copy of every .png/.jpg in ``ae_dir`` to ``clean_dir``.

    Each image is resized to 256x256, normalized to [-1, 1], passed through
    ``wavelet_denoise`` with its default arguments, rescaled to [0, 1] and
    saved under the same file name.

    NOTE: a function with the same name is defined again further down in this
    cell; when the whole cell runs, that later definition replaces this one.
    """
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    os.makedirs(clean_dir, exist_ok=True)

    image_names = [entry for entry in os.listdir(ae_dir) if entry.endswith(('.png', '.jpg'))]
    for image_name in tqdm(image_names, desc="Wavelet保存"):
        source = Image.open(os.path.join(ae_dir, image_name)).convert('RGB')
        cleaned = wavelet_denoise(preprocess(source).unsqueeze(0))

        # [-1, 1] -> [0, 1], then channels-last for PIL.
        unit_range = cleaned.squeeze(0).cpu() * 0.5 + 0.5
        hwc = unit_range.permute(1, 2, 0).numpy()
        pixels = (hwc * 255).astype(np.uint8)
        Image.fromarray(pixels).save(os.path.join(clean_dir, image_name))

# ======== Wavelet Layer ========
class DWTForward(nn.Module):
    """Single-level 2-D DWT applied to every (batch, channel) plane via PyWavelets.

    The transform runs on the CPU through NumPy, so no gradient flows through
    this layer and the input must not require grad (``Tensor.numpy()`` refuses
    such tensors).

    Args:
        wave: Wavelet name passed to ``pywt.dwt2``.
    """

    def __init__(self, wave='haar'):
        super().__init__()
        self.wave = wave

    def forward(self, x):
        """Return a tensor indexed (batch, channel, 4, h, w) holding cA, cH, cV, cD."""
        planes = x.cpu().numpy()  # one host transfer instead of one per plane
        batch_coeffs = []
        for sample in planes:
            per_channel = []
            for plane in sample:
                cA, (cH, cV, cD) = pywt.dwt2(plane, self.wave)
                per_channel.append(np.stack([cA, cH, cV, cD]))
            batch_coeffs.append(np.stack(per_channel))
        # Convert to a single ndarray first: building a tensor directly from a
        # list of ndarrays is very slow and triggers a PyTorch warning.
        return torch.from_numpy(np.asarray(batch_coeffs)).to(x.device)

class DWTInverse(nn.Module):
    """Single-level inverse 2-D DWT applied per (batch, channel) via PyWavelets.

    The transform runs on the CPU through NumPy, so no gradient flows through
    this layer and the input must not require grad (``Tensor.numpy()`` refuses
    such tensors) — call it under ``torch.no_grad()`` or on detached tensors.

    Args:
        wave: Wavelet name passed to ``pywt.idwt2``.
    """

    def __init__(self, wave='haar'):
        super().__init__()
        self.wave = wave

    def forward(self, coeffs):
        """Reconstruct images from a tensor indexed (batch, channel, 4, h, w).

        The size-4 axis holds cA, cH, cV, cD in that order.
        """
        bands = coeffs.cpu().numpy()  # one host transfer instead of four per plane
        images = []
        for sample in bands:
            planes = [
                pywt.idwt2((cA, (cH, cV, cD)), self.wave)
                for cA, cH, cV, cD in sample
            ]
            images.append(np.stack(planes))
        # Convert to a single ndarray first: building a tensor directly from a
        # list of ndarrays is very slow and triggers a PyTorch warning.
        return torch.from_numpy(np.asarray(images)).to(coeffs.device)

# ======== WaveCNet Generator ========
class WaveCNet(nn.Module):
    """Small conv network operating on wavelet coefficients.

    The input is decomposed with ``DWTForward``, the four sub-bands of every
    channel are stacked on the channel axis, passed through a conv
    encoder/decoder, and the result is turned back into an image with
    ``DWTInverse``.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the reconstructed image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        self.dwt = DWTForward()
        self.iwt = DWTInverse()

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels * 4, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, out_channels * 4, 3, padding=1)
        )

    def forward(self, x):
        """Map an image batch to an image batch through the wavelet domain."""
        coeffs = self.dwt(x)
        # coeffs is (batch, channel, 4 sub-bands, h, w). The sub-band axis is
        # discarded explicitly instead of being bound to W and overwritten.
        B, C, _, H, W = coeffs.size()
        coeffs = coeffs.view(B, C * 4, H, W)
        enc = self.encoder(coeffs)
        dec = self.decoder(enc)
        # -1 lets the channel count follow out_channels rather than assuming
        # it equals the input channel count.
        dec = dec.view(B, -1, 4, H, W)
        return self.iwt(dec)

# ======== PatchGAN Discriminator ========
class Discriminator(nn.Module):
    """Conditional patch discriminator returning raw (un-squashed) patch scores.

    Four stride-2 conv stages (64, 128, 256, 512 channels; BatchNorm on all but
    the first) followed by a 4x4 conv to a single channel.

    Args:
        in_channels: Channel count of the two concatenated input images.
    """

    def __init__(self, in_channels=6):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)

        layers = []
        for stage_idx, (src, dst) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(src, dst, 4, 2, 1))
            if stage_idx > 0:  # the first stage has no normalization
                layers.append(nn.BatchNorm2d(dst))
            layers.append(nn.LeakyReLU(0.2))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """Score the pair; the images are concatenated along dim 1."""
        pair = torch.cat((img_A, img_B), 1)
        return self.model(pair)


# ======== AE → wavelet変換画像を保存 =========
def save_wavelet_clean_images(ae_dir, clean_dir, wavelet='haar', level=1, threshold=0.00001):
    """Write a wavelet-processed copy of every image in ``ae_dir`` to ``clean_dir``.

    Files ending in .png, .jpg or .jpeg are resized to 256x256, normalized to
    [-1, 1], passed through ``wavelet_denoise`` and saved under the same name.

    Args:
        ae_dir: Directory with the source images.
        clean_dir: Output directory; created when it does not exist.
        wavelet: Forwarded to ``wavelet_denoise``.
        level: Forwarded to ``wavelet_denoise``.
        threshold: Forwarded to ``wavelet_denoise``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    if not os.path.exists(clean_dir):
        os.makedirs(clean_dir)

    image_names = [entry for entry in os.listdir(ae_dir) if entry.endswith(('.png', '.jpg', '.jpeg'))]

    for image_name in tqdm(image_names, desc="Wavelet変換中"):
        source = Image.open(os.path.join(ae_dir, image_name)).convert('RGB')
        batch = preprocess(source).unsqueeze(0)

        cleaned = wavelet_denoise(batch, wavelet=wavelet, level=level, threshold=threshold)
        # Undo the normalization: [-1, 1] -> [0, 1].
        unit_range = cleaned.squeeze(0).cpu() * 0.5 + 0.5
        transforms.ToPILImage()(unit_range).save(os.path.join(clean_dir, image_name))

# ======== Generator (U-Net) ========
class UNetDown(nn.Module):
    """Encoder stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2), optional dropout.

    Args:
        in_size: Input channels.
        out_size: Output channels.
        normalize: Insert ``BatchNorm2d`` after the convolution.
        dropout: Channel-dropout probability; 0 disables dropout.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        self.model = nn.Sequential(*layers)
        self.dropout = dropout

    def forward(self, x):
        """Apply the stage; dropout is active only while the module is in training mode."""
        x = self.model(x)
        if self.dropout:
            # F.dropout2d defaults to training=True; without forwarding
            # self.training it would keep zeroing channels after .eval().
            x = F.dropout2d(x, p=self.dropout, training=self.training)
        return x

class UNetUp(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU, optional dropout, skip concat.

    Args:
        in_size: Input channels.
        out_size: Output channels of the transposed convolution (before the
            skip tensor is concatenated).
        dropout: Channel-dropout probability; 0 disables dropout.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        self.model = nn.Sequential(*layers)
        self.dropout = dropout

    def forward(self, x, skip_input):
        """Upsample ``x`` and concatenate ``skip_input`` along the channel axis."""
        x = self.model(x)
        if self.dropout:
            # F.dropout2d defaults to training=True; without forwarding
            # self.training it would keep zeroing channels after .eval().
            x = F.dropout2d(x, p=self.dropout, training=self.training)
        return torch.cat((x, skip_input), 1)

class GeneratorUNet(nn.Module):
    """Eight-level U-Net generator built from ``UNetDown`` / ``UNetUp`` stages.

    Eight stride-2 encoder stages (``down1`` .. ``down8``) are mirrored by
    seven decoder stages (``up1`` .. ``up7``), each of which concatenates the
    matching encoder output, followed by a final transposed conv and Tanh.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
    """

    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        # (in, out, normalize, dropout) for down1 .. down8
        encoder_cfg = [
            (in_channels, 64, False, 0.0),
            (64, 128, True, 0.0),
            (128, 256, True, 0.0),
            (256, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, True, 0.5),
            (512, 512, False, 0.5),
        ]
        for n, (src, dst, norm, drop) in enumerate(encoder_cfg, start=1):
            setattr(self, f"down{n}", UNetDown(src, dst, normalize=norm, dropout=drop))

        # (in, out, dropout) for up1 .. up7; inputs after up1 are doubled by
        # the skip concatenation of the previous stage.
        decoder_cfg = [
            (512, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 512, 0.5),
            (1024, 256, 0.0),
            (512, 128, 0.0),
            (256, 64, 0.0),
        ]
        for n, (src, dst, drop) in enumerate(decoder_cfg, start=1):
            setattr(self, f"up{n}", UNetUp(src, dst, dropout=drop))

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, x):
        """Encode ``x`` through all eight stages, then decode with skip connections."""
        encoder = (self.down1, self.down2, self.down3, self.down4,
                   self.down5, self.down6, self.down7, self.down8)
        decoder = (self.up1, self.up2, self.up3, self.up4,
                   self.up5, self.up6, self.up7)

        feats = []
        h = x
        for down in encoder:
            h = down(h)
            feats.append(h)

        # feats[-1] is the bottleneck; the remaining maps are used as skips,
        # deepest first (d7, d6, ..., d1).
        for up, skip in zip(decoder, reversed(feats[:-1])):
            h = up(h, skip)

        return self.final(h)

# ======== Discriminator (PatchGAN) ========
class Discriminator(nn.Module):
    """Conditional PatchGAN discriminator returning raw (un-squashed) patch scores.

    Four stride-2 conv stages (64, 128, 256, 512 channels; BatchNorm on all but
    the first), a one-pixel zero pad on the left and top, and a final 4x4 conv
    to a single channel.

    Args:
        in_channels: Channel count of the two concatenated input images.
    """

    def __init__(self, in_channels=6):
        super().__init__()
        widths = (in_channels, 64, 128, 256, 512)

        layers = []
        for stage_idx, (src, dst) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(src, dst, 4, 2, 1))
            if stage_idx > 0:  # the first stage has no normalization
                layers.append(nn.BatchNorm2d(dst))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.ZeroPad2d((1, 0, 1, 0)))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img_A, img_B):
        """Score the pair; the images are concatenated along dim 1."""
        return self.model(torch.cat((img_A, img_B), 1))

# ======== 損失関数 ========
# Adversarial loss. BCEWithLogitsLoss applies the sigmoid itself, so it is fed
# the discriminator's raw output.
criterion_GAN = nn.BCEWithLogitsLoss()
# Pixel-wise reconstruction loss (mean absolute error).
criterion_L1 = nn.L1Loss()

# ======== 学習 ========
def train():
    """Train ``GeneratorUNet`` against ``Discriminator`` on wavelet-processed inputs.

    Reads image pairs from ``train/AE`` and ``train/CLEAN`` (relative to the
    current working directory), runs 100 epochs with batch size 4 and writes
    ``generator.pth`` and ``discriminator.pth`` once, after the last epoch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # No transform is passed, so AEDataset falls back to its default pipeline.
    dataset = AEDataset('train/AE', 'train/CLEAN')
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=0)

    generator = GeneratorUNet().to(device)
    discriminator = Discriminator().to(device)

    optimizer_G = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
    optimizer_D = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

    # Target values for the adversarial loss.
    real_label = 1.
    fake_label = 0.

    epochs = 100
    # Weight of the L1 reconstruction term relative to the adversarial term.
    lambda_L1 = 100

    for epoch in range(epochs):
        loop = tqdm(dataloader)
        for i, (ae_imgs, clean_imgs) in enumerate(loop):
            ae_imgs = ae_imgs.to(device)
            clean_imgs = clean_imgs.to(device)

            # The wavelet-processed batch is both the generator input and the
            # conditioning image for the discriminator.
            ae_denoised = wavelet_denoise(ae_imgs)

            # Discriminator
            optimizer_D.zero_grad()
            fake_clean = generator(ae_denoised)
            pred_real = discriminator(ae_denoised, clean_imgs)
            # Detached so the discriminator loss does not reach the generator.
            pred_fake = discriminator(ae_denoised, fake_clean.detach())

            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real) * real_label)
            loss_D_fake = criterion_GAN(pred_fake, torch.ones_like(pred_fake) * fake_label)
            loss_D = (loss_D_real + loss_D_fake) * 0.5
            loss_D.backward()
            optimizer_D.step()

            # Generator
            optimizer_G.zero_grad()
            # Reuses fake_clean from above (not detached), so gradients flow
            # back into the generator through the updated discriminator.
            pred_fake = discriminator(ae_denoised, fake_clean)
            loss_GAN = criterion_GAN(pred_fake, torch.ones_like(pred_fake) * real_label)
            loss_L1_value = criterion_L1(fake_clean, clean_imgs) * lambda_L1
            loss_G = loss_GAN + loss_L1_value
            loss_G.backward()
            optimizer_G.step()

            loop.set_description(f"Epoch [{epoch+1}/{epochs}]")
            loop.set_postfix(loss_D=loss_D.item(), loss_G=loss_G.item())

    # Saved once, after all epochs have finished; there are no per-epoch checkpoints.
    torch.save(generator.state_dict(), 'generator.pth')
    torch.save(discriminator.state_dict(), 'discriminator.pth')

if __name__ == "__main__":
    # Write wavelet-processed copies of train/AE2 into train/CLEAN2, then train.
    # NOTE(review): train() reads 'train/AE' and 'train/CLEAN', not the
    # AE2/CLEAN2 directories used here — confirm this is intended.
    save_wavelet_clean_images('train/AE2', 'train/CLEAN2')
    train()


Wavelet変換中: 0it [00:00, ?it/s]
Epoch [1/100]:  54%|█████▎    | 669/1250 [01:20<01:10,  8.26it/s, loss_D=0.0136, loss_G=10.2] 


KeyboardInterrupt: 

In [5]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import pywt
from tqdm import tqdm

# ========== Waveletノイズ除去関数 ==========
def _shrink_detail_band(band, threshold):
    """Threshold one detail sub-band relative to its largest absolute coefficient."""
    peak = np.max(np.abs(band))
    if peak == 0:
        # An all-zero band has nothing to shrink. Returning it unchanged also
        # avoids calling pywt.threshold with a zero threshold on zero-valued
        # coefficients (a 0/0 in the soft-threshold formula).
        return band
    return pywt.threshold(band, threshold * peak)


def wavelet_denoise(img_tensor, wavelet='haar', level=1, threshold=0.05):
    """Threshold the wavelet detail coefficients of every image plane.

    Each (batch, channel) plane is decomposed with ``pywt.wavedec2``, the
    detail sub-bands are thresholded with ``pywt.threshold`` and the plane is
    reconstructed with ``pywt.waverec2``. The approximation band is untouched.

    Args:
        img_tensor: 4-D tensor indexed as (batch, channel, row, column).
        wavelet: Wavelet name passed to PyWavelets.
        level: Decomposition level.
        threshold: Fraction of each detail sub-band's largest absolute
            coefficient used as that band's threshold.

    Returns:
        A float tensor on ``img_tensor``'s device, cropped to the input's
        spatial size and clamped to [-1, 1].
    """
    with torch.no_grad():
        img_np = img_tensor.cpu().numpy()
        height, width = img_np.shape[2], img_np.shape[3]
        denoised = []
        for sample in img_np:
            channels = []
            for plane in sample:
                coeffs = pywt.wavedec2(plane, wavelet=wavelet, level=level)
                shrunk = [coeffs[0]]
                for detail_bands in coeffs[1:]:
                    shrunk.append(tuple(_shrink_detail_band(band, threshold) for band in detail_bands))
                restored = pywt.waverec2(shrunk, wavelet)
                # Crop back to the original size.
                channels.append(restored[:height, :width])
            denoised.append(np.stack(channels))
        denoised = np.stack(denoised)
        return torch.tensor(denoised).to(img_tensor.device).float().clamp(-1, 1)

# ========== GeneratorUNetモデル ==========
class UNetDown(nn.Module):
    """Encoder stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2), optional dropout.

    Args:
        in_size: Input channels.
        out_size: Output channels.
        normalize: Insert ``BatchNorm2d`` after the convolution.
        dropout: Channel-dropout probability; 0 disables dropout.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        self.model = nn.Sequential(*layers)
        self.dropout = dropout

    def forward(self, x):
        """Apply the stage; dropout is active only while the module is in training mode."""
        x = self.model(x)
        if self.dropout:
            # F.dropout2d defaults to training=True; without forwarding
            # self.training it would keep zeroing channels after .eval().
            x = F.dropout2d(x, p=self.dropout, training=self.training)
        return x

class UNetUp(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU, optional dropout, skip concat.

    Args:
        in_size: Input channels.
        out_size: Output channels of the transposed convolution (before the
            skip tensor is concatenated).
        dropout: Channel-dropout probability; 0 disables dropout.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        self.model = nn.Sequential(*layers)
        self.dropout = dropout

    def forward(self, x, skip_input):
        """Upsample ``x`` and concatenate ``skip_input`` along the channel axis."""
        x = self.model(x)
        if self.dropout:
            # F.dropout2d defaults to training=True; without forwarding
            # self.training it would keep zeroing channels after .eval().
            x = F.dropout2d(x, p=self.dropout, training=self.training)
        return torch.cat((x, skip_input), 1)

class UNetGenerator(nn.Module):
    """U-Net style generator with five stride-2 encoder stages and skip connections.

    Each encoder stage halves the spatial size; the decoder doubles it again
    and concatenates the matching encoder feature map after every up-stage.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        features: Width of the first encoder stage (later stages use multiples).
        use_tanh: Apply the decoder's final ``nn.Tanh`` to the output.
            The original ``forward`` returned the raw ``ConvTranspose2d``
            output and never ran the declared Tanh layer. The default
            (``False``) keeps that behaviour so checkpoints trained that way
            produce unchanged outputs; pass ``True`` only together with
            weights that were trained with the Tanh applied.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64, use_tanh=False):
        super().__init__()
        self.use_tanh = use_tanh

        self.encoder = nn.Sequential(
            self.block(in_channels, features, normalize=False),
            self.block(features, features * 2),
            self.block(features * 2, features * 4),
            self.block(features * 4, features * 8),
            self.block(features * 8, features * 8),
        )

        # Input widths are doubled from the second stage on because each
        # up-stage output is concatenated with an encoder skip.
        self.decoder = nn.Sequential(
            self.upblock(features * 8, features * 8),
            self.upblock(features * 8 * 2, features * 4),
            self.upblock(features * 4 * 2, features * 2),
            self.upblock(features * 2 * 2, features),
            nn.ConvTranspose2d(features * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def block(self, in_c, out_c, normalize=True):
        """Build one encoder stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2)."""
        layers = [nn.Conv2d(in_c, out_c, 4, 2, 1)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_c))
        layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    def upblock(self, in_c, out_c):
        """Build one decoder stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    def forward(self, x):
        """Run the encoder, then the decoder with skip concatenation."""
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)

        # The deepest feature map is the decoder input, not a skip; the rest
        # are consumed deepest-first.
        for stage, skip in zip(self.decoder[:-2], reversed(skips[:-1])):
            x = torch.cat([stage(x), skip], 1)

        x = self.decoder[-2](x)
        if self.use_tanh:
            x = self.decoder[-1](x)
        return x


# ========== 推論処理 ==========
def load_generator(model_path, device):
    """Create a ``UNetGenerator`` on ``device`` and load its weights.

    Args:
        model_path: Path to a state-dict file written with ``torch.save``.
        device: Device the model is moved to and the weights are mapped to.

    Returns:
        The generator, switched to evaluation mode.
    """
    net = UNetGenerator().to(device)
    # torch.load unpickles the file; only load checkpoints from a trusted source.
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net

def denoise_and_correct(generator, input_path, output_path, device):
    """Wavelet-process one image, run it through ``generator`` and save the result.

    Args:
        generator: Callable model mapping a (1, 3, H, W) batch to an image batch.
        input_path: Image file to read.
        output_path: Where the restored image is written.
        device: Device on which the generator is evaluated.
    """
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    source = Image.open(input_path).convert('RGB')
    batch = preprocess(source).unsqueeze(0).to(device)  # (1,3,H,W)

    with torch.no_grad():
        restored = generator(wavelet_denoise(batch))

    # [-1,1] -> [0,1]; clamp because the generator output is not guaranteed
    # to stay inside that range.
    unit_range = (restored.squeeze(0).cpu() * 0.5) + 0.5
    transforms.ToPILImage()(unit_range.clamp(0,1)).save(output_path)


def wavelet_denoise_and_save(input_path, output_path, device):
    """Wavelet-process one image and save it.

    Args:
        input_path: Image file to read.
        output_path: Where the processed image is written.
        device: Device the input tensor is placed on before processing.
    """
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    source = Image.open(input_path).convert('RGB')
    batch = preprocess(source).unsqueeze(0).to(device)  # (1,3,H,W)

    with torch.no_grad():
        cleaned = wavelet_denoise(batch)

    # wavelet_denoise returns values in [-1,1]; map them to [0,1] for PIL.
    unit_range = (cleaned.squeeze(0).cpu() * 0.5) + 0.5
    transforms.ToPILImage()(unit_range.clamp(0,1)).save(output_path)

    

# ========== メイン処理 ==========
if __name__ == "__main__":
    # Two-pass batch job: (1) write wavelet-processed copies of every input
    # image, (2) run the trained generator on those copies.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = load_generator("C:\\Users\\sit\\wavelet_CGAN\\gen_weights_epoch100.pth", device)

    input_dir = "C:\\Users\\sit\\wavelet_CGAN\\train\\AE"
    wavelet_dir = "C:\\Users\\sit\\wavelet_CGAN\\train\\AE_wavelet"
    output_dir = "C:\\Users\\sit\\wavelet_CGAN\\train\\AE_wavelet_GAN"

    os.makedirs(wavelet_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    # The preprocessing pipeline does not depend on the file, so it is built
    # once here instead of once per image inside the loop.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    # 1. Save the wavelet-processed images to their own folder.
    for fname in tqdm(os.listdir(input_dir)):
        if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        input_path = os.path.join(input_dir, fname)
        wavelet_path = os.path.join(wavelet_dir, fname)
        wavelet_denoise_and_save(input_path, wavelet_path, device)

    # 2. Restore with the GAN, using the wavelet-processed images as input.
    for fname in tqdm(os.listdir(wavelet_dir)):
        if not fname.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        input_path = os.path.join(wavelet_dir, fname)
        output_path = os.path.join(output_dir, fname)

        img = Image.open(input_path).convert('RGB')
        input_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            output = generator(input_tensor)
            output_img = (output.squeeze(0).cpu() * 0.5) + 0.5  # [-1,1] -> [0,1]

        output_pil = transforms.ToPILImage()(output_img.clamp(0,1))
        output_pil.save(output_path)

    print(f"Waveletノイズ除去画像を {wavelet_dir} に保存しました。")
    print(f"復元画像を {output_dir} に保存しました。")



100%|██████████| 5000/5000 [01:12<00:00, 69.40it/s]
100%|██████████| 5000/5000 [01:07<00:00, 74.34it/s]

Waveletノイズ除去画像を C:\Users\sit\wavelet_CGAN\train\AE_wavelet に保存しました。
復元画像を C:\Users\sit\wavelet_CGAN\train\AE_wavelet_GAN に保存しました。





In [16]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import numpy as np
import pywt
from tqdm import tqdm

# ========== Waveletノイズ除去関数 ==========
def _shrink_detail_band(band, threshold):
    """Threshold one detail sub-band relative to its largest absolute coefficient."""
    peak = np.max(np.abs(band))
    if peak == 0:
        # An all-zero band has nothing to shrink. Returning it unchanged also
        # avoids calling pywt.threshold with a zero threshold on zero-valued
        # coefficients (a 0/0 in the soft-threshold formula).
        return band
    return pywt.threshold(band, threshold * peak)


def wavelet_denoise(img_tensor, wavelet='haar', level=1, threshold=0.00001):
    """Threshold the wavelet detail coefficients of every image plane.

    Each (batch, channel) plane is decomposed with ``pywt.wavedec2``, the
    detail sub-bands are thresholded with ``pywt.threshold`` and the plane is
    reconstructed with ``pywt.waverec2``. The approximation band is untouched.

    Args:
        img_tensor: 4-D tensor indexed as (batch, channel, row, column).
        wavelet: Wavelet name passed to PyWavelets.
        level: Decomposition level.
        threshold: Fraction of each detail sub-band's largest absolute
            coefficient used as that band's threshold.

    Returns:
        A float tensor on ``img_tensor``'s device, cropped to the input's
        spatial size and clamped to [-1, 1].
    """
    with torch.no_grad():
        img_np = img_tensor.cpu().numpy()
        height, width = img_np.shape[2], img_np.shape[3]
        denoised = []
        for sample in img_np:
            channels = []
            for plane in sample:
                coeffs = pywt.wavedec2(plane, wavelet=wavelet, level=level)
                shrunk = [coeffs[0]]
                for detail_bands in coeffs[1:]:
                    shrunk.append(tuple(_shrink_detail_band(band, threshold) for band in detail_bands))
                restored = pywt.waverec2(shrunk, wavelet)
                # Crop back to the original size.
                channels.append(restored[:height, :width])
            denoised.append(np.stack(channels))
        denoised = np.stack(denoised)
        return torch.tensor(denoised).to(img_tensor.device).float().clamp(-1, 1)

# ========== GeneratorUNetモデル ==========
class UNetDown(nn.Module):
    """Encoder stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2), optional dropout.

    Args:
        in_size: Input channels.
        out_size: Output channels.
        normalize: Insert ``BatchNorm2d`` after the convolution.
        dropout: Channel-dropout probability; 0 disables dropout.
    """

    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super().__init__()
        layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_size))
        layers.append(nn.LeakyReLU(0.2))
        self.model = nn.Sequential(*layers)
        self.dropout = dropout

    def forward(self, x):
        """Apply the stage; dropout is active only while the module is in training mode."""
        x = self.model(x)
        if self.dropout:
            # F.dropout2d defaults to training=True; without forwarding
            # self.training it would keep zeroing channels after .eval().
            x = F.dropout2d(x, p=self.dropout, training=self.training)
        return x

class UNetUp(nn.Module):
    """Decoder stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU, optional dropout, skip concat.

    Args:
        in_size: Input channels.
        out_size: Output channels of the transposed convolution (before the
            skip tensor is concatenated).
        dropout: Channel-dropout probability; 0 disables dropout.
    """

    def __init__(self, in_size, out_size, dropout=0.0):
        super().__init__()
        layers = [
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(inplace=True),
        ]
        self.model = nn.Sequential(*layers)
        self.dropout = dropout

    def forward(self, x, skip_input):
        """Upsample ``x`` and concatenate ``skip_input`` along the channel axis."""
        x = self.model(x)
        if self.dropout:
            # F.dropout2d defaults to training=True; without forwarding
            # self.training it would keep zeroing channels after .eval().
            x = F.dropout2d(x, p=self.dropout, training=self.training)
        return torch.cat((x, skip_input), 1)

class UNetGenerator(nn.Module):
    """U-Net style generator with five stride-2 encoder stages and skip connections.

    Each encoder stage halves the spatial size; the decoder doubles it again
    and concatenates the matching encoder feature map after every up-stage.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the generated image.
        features: Width of the first encoder stage (later stages use multiples).
        use_tanh: Apply the decoder's final ``nn.Tanh`` to the output.
            The original ``forward`` returned the raw ``ConvTranspose2d``
            output and never ran the declared Tanh layer. The default
            (``False``) keeps that behaviour so checkpoints trained that way
            produce unchanged outputs; pass ``True`` only together with
            weights that were trained with the Tanh applied.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64, use_tanh=False):
        super().__init__()
        self.use_tanh = use_tanh

        self.encoder = nn.Sequential(
            self.block(in_channels, features, normalize=False),
            self.block(features, features * 2),
            self.block(features * 2, features * 4),
            self.block(features * 4, features * 8),
            self.block(features * 8, features * 8),
        )

        # Input widths are doubled from the second stage on because each
        # up-stage output is concatenated with an encoder skip.
        self.decoder = nn.Sequential(
            self.upblock(features * 8, features * 8),
            self.upblock(features * 8 * 2, features * 4),
            self.upblock(features * 4 * 2, features * 2),
            self.upblock(features * 2 * 2, features),
            nn.ConvTranspose2d(features * 2, out_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def block(self, in_c, out_c, normalize=True):
        """Build one encoder stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2)."""
        layers = [nn.Conv2d(in_c, out_c, 4, 2, 1)]
        if normalize:
            layers.append(nn.BatchNorm2d(out_c))
        layers.append(nn.LeakyReLU(0.2))
        return nn.Sequential(*layers)

    def upblock(self, in_c, out_c):
        """Build one decoder stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
            nn.BatchNorm2d(out_c),
            nn.ReLU()
        )

    def forward(self, x):
        """Run the encoder, then the decoder with skip concatenation."""
        skips = []
        for stage in self.encoder:
            x = stage(x)
            skips.append(x)

        # The deepest feature map is the decoder input, not a skip; the rest
        # are consumed deepest-first.
        for stage, skip in zip(self.decoder[:-2], reversed(skips[:-1])):
            x = torch.cat([stage(x), skip], 1)

        x = self.decoder[-2](x)
        if self.use_tanh:
            x = self.decoder[-1](x)
        return x


# ========== 推論処理 ==========
def load_generator(model_path, device):
    """Create a ``UNetGenerator`` on ``device`` and load its weights.

    Args:
        model_path: Path to a state-dict file written with ``torch.save``.
        device: Device the model is moved to and the weights are mapped to.

    Returns:
        The generator, switched to evaluation mode.
    """
    net = UNetGenerator().to(device)
    # torch.load unpickles the file; only load checkpoints from a trusted source.
    weights = torch.load(model_path, map_location=device)
    net.load_state_dict(weights)
    net.eval()
    return net

def denoise_and_correct(generator, input_path, output_path, device):
    """Wavelet-denoise one image, run it through ``generator`` and save the result.

    The image at ``input_path`` is converted to RGB, resized to 256x256 and
    normalised with mean/std 0.5 per channel. The generator output is mapped
    back with ``* 0.5 + 0.5``, clamped to [0, 1] and written to ``output_path``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    rgb = Image.open(input_path).convert('RGB')
    batch = preprocess(rgb).unsqueeze(0).to(device)  # add a leading batch dim

    with torch.no_grad():
        restored = generator(wavelet_denoise(batch))
        # Undo the 0.5/0.5 normalisation on the CPU copy.
        restored = (restored.squeeze(0).cpu() * 0.5) + 0.5

    to_pil = transforms.ToPILImage()
    to_pil(restored.clamp(0,1)).save(output_path)


def wavelet_denoise_and_save(input_path, output_path, device):
    """Apply ``wavelet_denoise`` to one image and save the denoised picture.

    The image at ``input_path`` is converted to RGB, resized to 256x256 and
    normalised with mean/std 0.5 per channel before denoising. The result is
    mapped back with ``* 0.5 + 0.5``, clamped to [0, 1] and written to
    ``output_path``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    rgb = Image.open(input_path).convert('RGB')
    batch = preprocess(rgb).unsqueeze(0).to(device)  # add a leading batch dim

    with torch.no_grad():
        cleaned = wavelet_denoise(batch)

    # The tensor is in the normalised range; shift it back before export.
    cleaned = (cleaned.squeeze(0).cpu() * 0.5) + 0.5
    to_pil = transforms.ToPILImage()
    to_pil(cleaned.clamp(0,1)).save(output_path)

    

# ========== メイン処理 ==========
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = load_generator("C:\\Users\\sit\\wavelet_CGAN\\gen_weights_epoch100.pth", device)

    input_dir = "C:\\Users\\sit\\wavelet_CGAN\\testdata2\\renamed_FGSM"
    wavelet_dir = "C:\\Users\\sit\\wavelet_CGAN\\testdata2\\clean"
    output_dir = "C:\\Users\\sit\\wavelet_CGAN\\testdata2\\corrected"

    for target_dir in (wavelet_dir, output_dir):
        os.makedirs(target_dir, exist_ok=True)

    # Stage 1: write the wavelet-denoised version of every input image into
    # its own folder.
    for fname in tqdm(os.listdir(input_dir)):
        if fname.lower().endswith((".png", ".jpg", ".jpeg")):
            wavelet_denoise_and_save(
                os.path.join(input_dir, fname),
                os.path.join(wavelet_dir, fname),
                device,
            )

    # Stage 2: feed the saved wavelet-denoised images straight to the
    # generator (no second wavelet pass) and store the restored images.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    to_pil = transforms.ToPILImage()

    for fname in tqdm(os.listdir(wavelet_dir)):
        if fname.lower().endswith((".png", ".jpg", ".jpeg")):
            input_path = os.path.join(wavelet_dir, fname)
            output_path = os.path.join(output_dir, fname)

            img = Image.open(input_path).convert('RGB')
            input_tensor = transform(img).unsqueeze(0).to(device)

            with torch.no_grad():
                output = generator(input_tensor)
                # Undo the 0.5/0.5 normalisation on the CPU copy.
                output_img = (output.squeeze(0).cpu() * 0.5) + 0.5

            to_pil(output_img.clamp(0,1)).save(output_path)

    print(f"Waveletノイズ除去画像を {wavelet_dir} に保存しました。")
    print(f"復元画像を {output_dir} に保存しました。")



100%|██████████| 998/998 [00:06<00:00, 145.87it/s]
100%|██████████| 998/998 [00:04<00:00, 204.50it/s]

Waveletノイズ除去画像を C:\Users\sit\wavelet_CGAN\testdata2\clean に保存しました。
復元画像を C:\Users\sit\wavelet_CGAN\testdata2\corrected に保存しました。





In [11]:
#元画像を256*256にするコード
import os
from PIL import Image
from tqdm import tqdm

input_dir ="C:\\Users\\sit\\wavelet_CGAN\\train\\normal"
output_dir = "C:\\Users\\sit\\wavelet_CGAN\\train\\normal256"

os.makedirs(output_dir, exist_ok=True)

# Convert every PNG/JPEG in input_dir to a 256x256 RGB image saved under the
# same file name in output_dir; other directory entries are ignored.
for fname in tqdm(os.listdir(input_dir)):
    if fname.lower().endswith((".png", ".jpg", ".jpeg")):
        src_path = os.path.join(input_dir, fname)
        dst_path = os.path.join(output_dir, fname)

        rgb = Image.open(src_path).convert("RGB")
        rgb.resize((256, 256), Image.BILINEAR).save(dst_path)


100%|██████████| 5000/5000 [00:24<00:00, 201.20it/s]


In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
from tqdm import tqdm

# ===============================
# Dataset
# ===============================
class AE2CleanDataset(Dataset):
    """Paired dataset of adversarial-example (AE) images and clean images.

    Samples are discovered by listing ``ae_dir``; the clean counterpart of a
    sample is the file with the same name inside ``clean_dir``.

    Args:
        ae_dir: Directory holding the AE images.
        clean_dir: Directory holding the clean images under identical names.
        transform: Optional callable applied to both images of a pair.
        extensions: File extensions (case-insensitive) accepted as images.
            Pass ``None`` to accept every regular file in ``ae_dir``.

    Only regular files are used as samples, so sub-directories (for example
    ``.ipynb_checkpoints``) and, with the default ``extensions``, stray
    non-image files (for example ``Thumbs.db``) no longer end up in the
    sample list and crash ``Image.open`` in the middle of an epoch.
    """

    # Extensions treated as images when scanning ``ae_dir``.
    IMAGE_EXTENSIONS = (
        ".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff", ".webp",
    )

    def __init__(self, ae_dir, clean_dir, transform=None,
                 extensions=IMAGE_EXTENSIONS):
        self.ae_dir = ae_dir
        self.clean_dir = clean_dir
        self.transform = transform

        if extensions is not None:
            # str.endswith needs a tuple; lower-case for case-insensitive match.
            extensions = tuple(ext.lower() for ext in extensions)

        self.filenames = sorted(
            name
            for name in os.listdir(ae_dir)
            if os.path.isfile(os.path.join(ae_dir, name))
            and (extensions is None or name.lower().endswith(extensions))
        )

    def __len__(self):
        """Return the number of AE images found."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return the ``(ae_img, clean_img)`` pair for sample ``idx``.

        Both images are opened as RGB. A missing clean counterpart surfaces
        as the error raised by ``Image.open``.
        """
        name = self.filenames[idx]
        ae_path = os.path.join(self.ae_dir, name)
        clean_path = os.path.join(self.clean_dir, name)

        ae_img = Image.open(ae_path).convert("RGB")
        clean_img = Image.open(clean_path).convert("RGB")

        if self.transform:
            ae_img = self.transform(ae_img)
            clean_img = self.transform(clean_img)

        return ae_img, clean_img

# ===============================
# Pix2Pix Generator (U-Net)
# ===============================
class UNetGenerator(nn.Module):
    """Pix2Pix-style U-Net generator.

    Five stride-2 convolution stages (``down1``..``down5``) are mirrored by
    four stride-2 transposed-convolution stages (``up1``..``up4``). Each up
    stage's output is concatenated along the channel axis with the matching
    down-stage output. ``final`` is a transposed convolution followed by
    ``Tanh``.
    """

    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        f = features

        # Contracting path.
        self.down1 = self.block(in_channels, f, normalize=False)
        self.down2 = self.block(f, f*2)
        self.down3 = self.block(f*2, f*4)
        self.down4 = self.block(f*4, f*8)
        self.down5 = self.block(f*8, f*8)

        # Expanding path; inputs after up1 are doubled by skip concatenation.
        self.up1 = self.upblock(f*8, f*8)
        self.up2 = self.upblock(f*8*2, f*4)
        self.up3 = self.upblock(f*4*2, f*2)
        self.up4 = self.upblock(f*2*2, f)

        head = nn.ConvTranspose2d(f*2, out_channels, 4, 2, 1)
        self.final = nn.Sequential(head, nn.Tanh())

    def block(self, in_c, out_c, normalize=True):
        """Return a down stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_c, out_c, 4, 2, 1)
        norm = [nn.BatchNorm2d(out_c)] if normalize else []
        return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2, inplace=True))

    def upblock(self, in_c, out_c):
        """Return an up stage: 4x4 stride-2 transposed conv, BatchNorm, ReLU."""
        stages = [
            nn.ConvTranspose2d(in_c, out_c, 4, 2, 1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x``, decode with skip connections and apply ``final``."""
        skips = []
        h = x
        for down in (self.down1, self.down2, self.down3, self.down4):
            h = down(h)
            skips.append(h)

        h = self.down5(h)  # bottleneck

        # Popping from the end pairs each up stage with the deepest remaining
        # encoder feature map.
        for up in (self.up1, self.up2, self.up3, self.up4):
            h = torch.cat([up(h), skips.pop()], 1)

        return self.final(h)

# ===============================
# CRU-Net Generator
# ===============================
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped in a clamped residual connection.

    ``forward`` returns ``x + block(x)`` limited to the range [-1, 1]. The
    channel count is unchanged by the block.
    """

    def __init__(self, channels):
        super().__init__()
        layers = [
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Add the block's output to ``x`` and clamp the sum to [-1, 1]."""
        summed = x + self.block(x)
        return summed.clamp(min=-1, max=1)

class CRUNetGenerator(nn.Module):
    """Residual encoder/decoder generator.

    The encoder applies two stride-2 convolutions, each followed by ReLU and
    a ``ResidualBlock``; the decoder applies two stride-2 transposed
    convolutions and ends in ``Tanh``. The decoder output is added to the
    network input and the sum is clamped to [-1, 1].
    """

    def __init__(self, in_channels=3, out_channels=3, features=64):
        super().__init__()
        f = features

        enc_layers = [
            nn.Conv2d(in_channels, f, 4, 2, 1),
            nn.ReLU(inplace=True),
            ResidualBlock(f),
            nn.Conv2d(f, f*2, 4, 2, 1),
            nn.ReLU(inplace=True),
            ResidualBlock(f*2),
        ]
        self.encoder = nn.Sequential(*enc_layers)

        dec_layers = [
            nn.ConvTranspose2d(f*2, f, 4, 2, 1),
            nn.ReLU(inplace=True),
            ResidualBlock(f),
            nn.ConvTranspose2d(f, out_channels, 4, 2, 1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Return ``clamp(x + decoder(encoder(x)), -1, 1)``."""
        # The network predicts a correction that is added back onto the input.
        correction = self.decoder(self.encoder(x))
        return (x + correction).clamp(min=-1, max=1)

# ===============================
# Patch Discriminator
# ===============================
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator scoring a channel-concatenated image pair.

    Three stride-2 conv stages are followed by a stride-1 convolution to a
    single channel and a ``Sigmoid``, so the output is a map of values in
    [0, 1] rather than a single scalar per sample.
    """

    def __init__(self, in_channels=6, features=64):
        super().__init__()
        f = features
        stages = [
            self.block(in_channels, f, normalize=False),
            self.block(f, f*2),
            self.block(f*2, f*4),
            nn.Conv2d(f*4, 1, 4, 1, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*stages)

    def block(self, in_c, out_c, normalize=True):
        """Return a stage: 4x4 stride-2 conv, optional BatchNorm, LeakyReLU(0.2)."""
        conv = nn.Conv2d(in_c, out_c, 4, 2, 1)
        norm = [nn.BatchNorm2d(out_c)] if normalize else []
        return nn.Sequential(conv, *norm, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` along dim 1 and score the pair."""
        pair = torch.cat([x, y], 1)
        return self.model(pair)

# ===============================
# Training Loop
# ===============================
def train_model(model_name, generator_class, ae_dir, clean_dir, epochs=50):
    """Train a generator against a ``PatchDiscriminator`` on AE/clean image pairs.

    Args:
        model_name: Label used in the progress bar and in the output
            directory name ``./outputs_{model_name}``.
        generator_class: Class called with no arguments to build the generator.
        ae_dir: Directory with the input (AE) images.
        clean_dir: Directory with the target (clean) images.
        epochs: Number of passes over the dataset.

    Images are resized to 256x256 and normalised with mean/std 0.5 per
    channel. The discriminator is trained with BCE on real/fake pairs; the
    generator is trained with BCE against the discriminator plus an L1 term
    weighted by 100. Every 10 epochs the last generated batch and the
    generator weights are written to the output directory.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])
    dataset = AE2CleanDataset(ae_dir, clean_dir, transform)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)

    gen = generator_class().to(device)
    disc = PatchDiscriminator().to(device)
    opt_g = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = torch.optim.Adam(disc.parameters(), lr=2e-4, betas=(0.5, 0.999))
    bce = nn.BCELoss()
    l1 = nn.L1Loss()
    l1_weight = 100  # weight of the reconstruction term in the generator loss

    out_dir = f"./outputs_{model_name}"
    os.makedirs(out_dir, exist_ok=True)

    for epoch in range(epochs):
        loop = tqdm(loader, desc=f"{model_name} Epoch {epoch+1}/{epochs}")
        # Stays None when the loader yields no batch, so the periodic save
        # below cannot hit an unbound name.
        fake_img = None
        for ae_img, clean_img in loop:
            ae_img, clean_img = ae_img.to(device), clean_img.to(device)

            # One generator forward pass per iteration. The generator weights
            # do not change between the two steps below, so a second pass
            # would yield the same tensor while costing a full forward pass
            # (and updating BatchNorm running statistics twice per batch).
            fake_img = gen(ae_img)

            # --- Discriminator ---
            # detach() keeps the discriminator loss from back-propagating
            # into the generator.
            real_pred = disc(ae_img, clean_img)
            fake_pred = disc(ae_img, fake_img.detach())
            d_loss = (bce(real_pred, torch.ones_like(real_pred)) +
                      bce(fake_pred, torch.zeros_like(fake_pred))) / 2

            opt_d.zero_grad()
            d_loss.backward()
            opt_d.step()

            # --- Generator ---
            # Score the same fake batch with the just-updated discriminator.
            pred = disc(ae_img, fake_img)
            adv_loss = bce(pred, torch.ones_like(pred))
            l1_loss = l1(fake_img, clean_img) * l1_weight
            g_loss = adv_loss + l1_loss

            opt_g.zero_grad()
            g_loss.backward()
            opt_g.step()

            loop.set_postfix(G_loss=g_loss.item(), D_loss=d_loss.item())

        # Save outputs every 10 epochs
        if (epoch+1) % 10 == 0:
            if fake_img is not None:
                # Map the normalised output back before writing the image.
                save_image(fake_img.detach() * 0.5 + 0.5,
                           f"{out_dir}/fake_epoch{epoch+1}.png")
            torch.save(gen.state_dict(), f"{out_dir}/gen_epoch{epoch+1}.pth")

# ===============================
# Main
# ===============================
if __name__ == "__main__":
    ae_dir = "C:\\Users\\sit\\wavelet_CGAN\\train\\AE_wavelet"
    clean_dir = "C:\\Users\\sit\\wavelet_CGAN\\train\\normal256"

    # Train both generator variants on the same data, one after the other.
    for run_name, gen_cls in (
        ("pix2pix_unet", UNetGenerator),
        ("crunet_residual", CRUNetGenerator),
    ):
        train_model(run_name, gen_cls, ae_dir, clean_dir, epochs=50)


pix2pix_unet Epoch 1/50: 100%|██████████| 1250/1250 [00:32<00:00, 38.41it/s, D_loss=1.02, G_loss=7.11] 
pix2pix_unet Epoch 2/50: 100%|██████████| 1250/1250 [00:26<00:00, 47.02it/s, D_loss=0.0764, G_loss=7.17] 
pix2pix_unet Epoch 3/50: 100%|██████████| 1250/1250 [00:26<00:00, 46.83it/s, D_loss=0.11, G_loss=7.99]  
pix2pix_unet Epoch 4/50: 100%|██████████| 1250/1250 [00:26<00:00, 46.54it/s, D_loss=0.0152, G_loss=9.01] 
pix2pix_unet Epoch 5/50: 100%|██████████| 1250/1250 [00:26<00:00, 47.07it/s, D_loss=0.666, G_loss=4.67]  
pix2pix_unet Epoch 6/50: 100%|██████████| 1250/1250 [00:26<00:00, 47.07it/s, D_loss=0.685, G_loss=5.17]
pix2pix_unet Epoch 7/50: 100%|██████████| 1250/1250 [00:26<00:00, 47.04it/s, D_loss=0.468, G_loss=6.3] 
pix2pix_unet Epoch 8/50: 100%|██████████| 1250/1250 [00:26<00:00, 47.02it/s, D_loss=0.308, G_loss=7.04] 
pix2pix_unet Epoch 9/50: 100%|██████████| 1250/1250 [00:26<00:00, 46.89it/s, D_loss=1.04, G_loss=6.8]   
pix2pix_unet Epoch 10/50: 100%|██████████| 1250/1250 [0