In [1]:
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
import matplotlib.pyplot as plt

In [2]:
import os

# Run from the project root so the relative paths used later in this notebook
# (config.json, models/, genimgs/, figs/) resolve. The default is the original
# machine's location; set the PROJECT_ROOT environment variable to run elsewhere.
new_path = os.environ.get('PROJECT_ROOT', '/root/autodl-tmp/project')
os.chdir(new_path)

print(os.getcwd())

/root/autodl-tmp/project


In [3]:
import json
import os
from types import SimpleNamespace

def load_config(config_path='config.json'):
    """Load a JSON config file and expose its top-level keys as attributes.

    Args:
        config_path: path to a JSON file whose top level is an object.

    Returns:
        A SimpleNamespace with one attribute per top-level key (e.g. ``args.zdim``).
    """
    # Decode explicitly instead of relying on the platform's locale encoding.
    # 'utf-8-sig' reads plain UTF-8 unchanged but also strips a leading BOM
    # (as written by e.g. Windows Notepad), which json.load would reject.
    with open(config_path, encoding='utf-8-sig') as f:
        config = json.load(f)
    return SimpleNamespace(**config)

In [4]:
# Module-level config object; the Generator / Discriminator constructors and the
# `transform` pipeline read their sizes from this `args`.
args = load_config(config_path='config.json')

In [5]:
class ArtDataset(Dataset):
    """Image/label dataset driven by a comma-separated manifest file.

    Each non-blank manifest line has the form ``<image_rel>,<label_rel>``. Both
    parts are paths relative to ``root_dir`` (Windows ``\\`` separators are
    accepted). The label file holds a single integer class id. Entries whose
    image or label file does not exist are skipped.

    Args:
        manifest_path: path of the manifest file.
        root_dir: directory that the relative manifest paths are joined to.
        transform: optional callable applied to each RGB PIL image.
    """

    def __init__(self, manifest_path, root_dir, transform=None):
        self.image_paths = []
        self.labels = []
        self.root_dir = root_dir
        self.transform = transform
        # 'utf-8-sig' behaves like 'utf-8' but also drops a leading BOM, which a
        # Windows-generated manifest may carry. Otherwise the BOM would be glued
        # to the first image path, and that sample would be silently skipped.
        with open(manifest_path, 'r', encoding='utf-8-sig') as f:
            for raw_line in f:
                line = raw_line.strip()
                # Blank lines (e.g. a trailing empty line) contain no comma and
                # would otherwise make the unpacking below raise ValueError.
                if not line:
                    continue
                img_rel, label_rel = line.rsplit(',', 1)
                img_path = self._resolve(img_rel)
                label_path = self._resolve(label_rel)
                if not (os.path.exists(img_path) and os.path.exists(label_path)):
                    continue
                with open(label_path, 'r', encoding='utf-8-sig') as lf:
                    label = int(lf.read().strip())
                self.image_paths.append(img_path)
                self.labels.append(label)

    def _resolve(self, rel_path):
        """Join a manifest-relative path to root_dir, normalising '\\' to '/'."""
        return os.path.join(self.root_dir, rel_path.strip().replace('\\', '/'))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is RGB and passed through ``transform`` if set."""
        # The context manager closes the underlying file; convert() returns a new,
        # fully loaded image that does not depend on the open file.
        with Image.open(self.image_paths[idx]) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]

# Training-time preprocessing: resize to a square im_size x im_size image, apply a
# random horizontal flip for augmentation, convert to a tensor, and normalise each
# channel with (x - 0.5) / 0.5 so pixel values lie in [-1, 1], the same range as
# the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize((args.im_size,) * 2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5,) * 3, (0.5,) * 3),
    ]
)

In [6]:
# Output folders written during training: sample grids (gen64 / real / gen128),
# model checkpoints, and the loss-curve figure. exist_ok makes this cell re-runnable.
for _out_dir in (
    'genimgs/Genre128GANAE/gen64',
    'genimgs/Genre128GANAE/real',
    'genimgs/Genre128GANAE/gen128',
    'models/Genre128GANAE',
    'figs/Genre128GANAE',
):
    os.makedirs(_out_dir, exist_ok=True)

In [7]:
class Generator(nn.Module):
    """Class-conditional generator: (noise ``z``, integer labels) -> 3x128x128 image in [-1, 1].

    The one-hot label is concatenated to ``z`` and projected by a linear layer to a
    512x4x4 feature map. Five stages of (nearest upsample x2 -> 3x3 conv -> BatchNorm
    -> LeakyReLU) then grow it from 4x4 to 128x128, and a final 3-channel conv with
    Tanh produces the image.

    Args:
        zdim: length of the noise vector; defaults to the notebook-level ``args.zdim``.
        n_classes: number of classes for the one-hot label; defaults to
            the notebook-level ``args.n_classes``.
    """

    def __init__(self, zdim=None, n_classes=None):
        super().__init__()
        # Falling back to the global config keeps existing `Generator()` calls working.
        # Parameter names are unchanged, so saved checkpoints still load.
        self.zdim = args.zdim if zdim is None else zdim
        self.n_classes = args.n_classes if n_classes is None else n_classes

        self.fc = nn.Sequential(
            nn.Linear(self.zdim + self.n_classes, 512 * 4 * 4),
            nn.BatchNorm1d(512 * 4 * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )

        def block(in_c, out_c):
            # Doubles the spatial size via nearest upsampling followed by a 3x3 conv.
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(0.2, inplace=True)
            )

        self.deconv = nn.Sequential(
            block(512, 512),   # 8x8
            block(512, 256),   # 16x16
            block(256, 128),   # 32x32
            block(128, 64),    # 64x64
            block(64, 32),     # 128x128
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh()
        )

    def forward(self, z, labels):
        """Generate images.

        Args:
            z: noise tensor of shape (B, zdim).
            labels: integer (int64) class ids of shape (B,), each in [0, n_classes).

        Returns:
            Tensor of shape (B, 3, 128, 128) with values in [-1, 1].
        """
        # Use the class count fixed at construction, so the one-hot width always
        # matches self.fc even if the global `args` is later rebound.
        y = F.one_hot(labels, self.n_classes).float()
        x = torch.cat([z, y], dim=1)
        x = self.fc(x).view(-1, 512, 4, 4)
        return self.deconv(x)  # 128x128


In [8]:
class Discriminator(nn.Module):
    """Discriminator with a real/fake (adversarial) head and an auxiliary class head.

    Five stride-2, 4x4 spectrally-normalised convs reduce a 3x128x128 input to a
    512x4x4 feature map. ``adv_head`` (spectral-normed 4x4 conv) maps it to one
    logit per sample; ``cls_head`` (linear) maps it to class logits.

    Args:
        n_classes: number of class logits; defaults to the notebook-level ``args.n_classes``.
    """

    def __init__(self, n_classes=None):
        super().__init__()
        # Falling back to the global config keeps existing `Discriminator()` calls working.
        self.n_classes = args.n_classes if n_classes is None else n_classes

        def down(in_c, out_c):
            # Halves the spatial size (kernel 4, stride 2, padding 1).
            return nn.Sequential(
                spectral_norm(nn.Conv2d(in_c, out_c, 4, 2, 1)),
                nn.LeakyReLU(0.2, inplace=True)
            )

        self.main = nn.Sequential(
            down(3, 64),     # 128 -> 64
            down(64, 128),   # 64 -> 32
            down(128, 256),  # 32 -> 16
            down(256, 512),  # 16 -> 8
            down(512, 512)   # 8 -> 4
        )
        self.adv_head = spectral_norm(nn.Conv2d(512, 1, 4))  # 4x4 -> 1x1 logit
        # NOTE(review): unlike the conv layers, cls_head is not spectral-normalised.
        # Wrapping it now would rename its parameters (weight -> weight_orig, ...)
        # and break loading existing D checkpoints, so it is left as a plain Linear.
        self.cls_head = nn.Linear(512 * 4 * 4, self.n_classes)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: image tensor of shape (B, 3, 128, 128). The size is fixed by
                cls_head's 512*4*4 input after five halvings.

        Returns:
            ``(adv_out, cls_logits)`` with shapes (B,) and (B, n_classes).
        """
        feat = self.main(x)
        adv_out = self.adv_head(feat).view(x.size(0))  # [B]
        cls_logits = self.cls_head(feat.flatten(1))    # [B, n_classes]
        return adv_out, cls_logits


In [9]:
# Superseded first version of train(), presumably kept for reference. It is a bare
# string literal, so nothing below is defined or executed; the live train() is
# defined in the next cell. Its checkpoint/sample paths use the
# `Genre128GANAE_mod` prefix (including a `gen/` sub-folder), which the
# directory-setup cell does not create.
'''def train(args):
    losses_D, losses_G, iterations = [], [], []
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = Generator().to(device)
    D = Discriminator().to(device)
    
    opt_G = optim.Adam(G.parameters(), lr=args.lr_init, betas=(0.0, 0.9))   # >>> MOD <<<
    opt_D = optim.Adam(D.parameters(), lr=args.lr_init, betas=(0.0, 0.9))
    
    train_set = ArtDataset(os.path.join(args.data_root, 'genre-train-index.csv'), args.data_root, transform)
    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True)
    
    # -----------------------
    #  6. 损失函数
    # -----------------------
    ce_loss = nn.CrossEntropyLoss()
    
    # -----------------------
    #  7. 训练循环
    # -----------------------
    iter_count = 0
    for epoch in range(args.num_epochs):
        for real_imgs, real_labels in train_loader:
            iter_count += 1
            real_imgs, real_labels = real_imgs.to(device), real_labels.to(device)
            batch_size = real_imgs.size(0)
    
            #  ===== 1.  训练 Discriminator =====
            D.zero_grad()
            # real
            real_adv, real_cls = D(real_imgs)
            loss_D_real = F.relu(1.0 - real_adv).mean()
            loss_cls_real = ce_loss(real_cls, real_labels)
    
            # fake
            z = torch.randn(batch_size, args.zdim, device=device)
            rand_labels = torch.randint(0, args.n_classes, (batch_size,), device=device)  # >>> MOD <<<
            fake_imgs = G(z, rand_labels).detach()
            fake_adv, _ = D(fake_imgs)
            loss_D_fake = F.relu(1.0 + fake_adv).mean()
    
            loss_D = loss_D_real + loss_D_fake + 0.1 * loss_cls_real  # 权重平衡
            loss_D.backward()
            opt_D.step()
    
            #  ===== 2.  训练 Generator =====
            G.zero_grad()
            z = torch.randn(batch_size, args.zdim, device=device)
            rand_labels = torch.randint(0, args.n_classes, (batch_size,), device=device)
            gen_imgs = G(z, rand_labels)
            adv_out, cls_out = D(gen_imgs)
            loss_G_adv = -adv_out.mean()
            loss_G_cls = ce_loss(cls_out, rand_labels)
            loss_G = loss_G_adv + 0.1 * loss_G_cls  # 权重平衡
            loss_G.backward()
            opt_G.step()
    
            #  ===== 3.  Logging / save =====
            if iter_count % args.display_iter == 0:
                losses_D.append(loss_D.item())
                losses_G.append(loss_G.item())
                iterations.append(iter_count)
                print(f"Iter {iter_count} | D_loss: {loss_D.item():.4f} | G_loss: {loss_G.item():.4f}")
            if iter_count % args.save_iter == 0:
                torch.save(G.state_dict(), f'models/Genre128GANAE_mod/G_{iter_count}.pth')
                torch.save(D.state_dict(), f'models/Genre128GANAE_mod/D_{iter_count}.pth')
                with torch.no_grad():
                    vis_img = (gen_imgs[:64] + 1) / 2  # 反归一化
                    save_image(vis_img, f'genimgs/Genre128GANAE_mod/gen/{iter_count}.png', nrow=8)
                    vis_real = (real_imgs[:64] + 1) / 2
                    save_image(vis_real, f'genimgs/Genre128GANAE_mod/real/{iter_count}.png', nrow=8)
                    
    
    print("Training finished!")'''

In [9]:
def _save_snapshot(G, D, gen_imgs, real_imgs, iter_count):
    """Save G/D weights and sample grids (up to 64 images, 8 per row) for ``iter_count``.

    Image tensors are assumed to be in [-1, 1] (Tanh output / Normalize(0.5, 0.5))
    and are shifted to [0, 1] before saving.
    """
    torch.save(G.state_dict(), f'models/Genre128GANAE/G_{iter_count}.pth')
    torch.save(D.state_dict(), f'models/Genre128GANAE/D_{iter_count}.pth')
    with torch.no_grad():
        # generated images at full 128x128 resolution
        vis128 = (gen_imgs[:64] + 1) / 2
        save_image(vis128, f'genimgs/Genre128GANAE/gen128/{iter_count}.png', nrow=8)
        # generated images 2x average-pooled (64x64)
        vis64 = (F.avg_pool2d(gen_imgs, 2)[:64] + 1) / 2
        save_image(vis64, f'genimgs/Genre128GANAE/gen64/{iter_count}.png', nrow=8)
        # the real batch used in this iteration
        vis_real = (real_imgs[:64] + 1) / 2
        save_image(vis_real, f'genimgs/Genre128GANAE/real/{iter_count}.png', nrow=8)


def _save_loss_curve(iterations, losses_D, losses_G,
                     path='figs/Genre128GANAE/loss_curve.png'):
    """Plot the logged D/G losses against iteration and write the figure to ``path``."""
    fig, ax = plt.subplots()
    ax.plot(iterations, losses_D, label='D_loss')
    ax.plot(iterations, losses_G, label='G_loss')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')
    ax.legend()
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)


def train(args):
    """Train the class-conditional hinge-loss GAN with an auxiliary classifier on ArtDataset.

    Config fields read from ``args``: seed, lr_init, data_root, batch_size, zdim,
    n_classes, num_epochs, display_iter, save_iter. Optional fields: resume_G and
    resume_D (checkpoint paths) and init_iter, which is required when both resume
    paths are set.

    Every ``display_iter`` iterations the losses are printed and logged; every
    ``save_iter`` iterations checkpoints and sample grids are written. The loss
    curve is saved when the loop ends, including when it is interrupted
    (e.g. KeyboardInterrupt), so a stopped run keeps its logged history.
    """
    losses_D, losses_G, iterations = [], [], []

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Create the output folders here so training does not depend on the
    # directory-setup cell having been run first.
    for out_dir in ('genimgs/Genre128GANAE/gen64', 'genimgs/Genre128GANAE/real',
                    'genimgs/Genre128GANAE/gen128', 'models/Genre128GANAE',
                    'figs/Genre128GANAE'):
        os.makedirs(out_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = Generator().to(device)
    D = Discriminator().to(device)

    # beta1 = 0 disables Adam's first-moment momentum; beta2 = 0.9.
    opt_G = torch.optim.Adam(G.parameters(), lr=args.lr_init, betas=(0.0, 0.9))
    opt_D = torch.optim.Adam(D.parameters(), lr=args.lr_init, betas=(0.0, 0.9))

    # NOTE: `transform` is the notebook-level pipeline (sized from the module-level
    # `args.im_size`); it is not rebuilt from the `args` passed to this function.
    train_set = ArtDataset(
        os.path.join(args.data_root, 'genre-train-index.csv'),
        args.data_root, transform
    )
    train_loader = DataLoader(
        train_set, batch_size=args.batch_size,
        shuffle=True, num_workers=4, drop_last=True
    )
    ce_loss = nn.CrossEntropyLoss()

    # resume_G / resume_D are optional config keys; a missing key means "start fresh".
    resume_G = getattr(args, 'resume_G', None)
    resume_D = getattr(args, 'resume_D', None)
    if resume_G and resume_D:
        G.load_state_dict(torch.load(resume_G, map_location=device))
        D.load_state_dict(torch.load(resume_D, map_location=device))
        iter_count = args.init_iter
        print(f"Resumed from iter {iter_count}: loaded {resume_G} and {resume_D}")
    else:
        if resume_G or resume_D:
            print("Warning: only one of resume_G / resume_D is set; training from scratch.")
        iter_count = 0
    # NOTE(review): optimizer states are never checkpointed, so Adam's running
    # second-moment estimates restart from zero after a resume.

    try:
        for _epoch in range(args.num_epochs):
            for real_imgs, real_labels in train_loader:
                iter_count += 1
                real_imgs, real_labels = real_imgs.to(device), real_labels.to(device)
                bsz = real_imgs.size(0)

                # ===== 1. Discriminator step =====
                D.zero_grad()
                # real: hinge loss is zero once D's logit is >= +1; plus auxiliary class loss
                real_adv, real_cls = D(real_imgs)
                loss_D_real = F.relu(1.0 - real_adv).mean()
                loss_cls_real = ce_loss(real_cls, real_labels)

                # fake: G's output is detached so this step only updates D
                z = torch.randn(bsz, args.zdim, device=device)
                rand_labels = torch.randint(0, args.n_classes, (bsz,), device=device)
                fake_imgs = G(z, rand_labels).detach()
                fake_adv, _ = D(fake_imgs)
                loss_D_fake = F.relu(1.0 + fake_adv).mean()

                # 0.1 = relative weight of the auxiliary classification loss
                loss_D = loss_D_real + loss_D_fake + 0.1 * loss_cls_real
                loss_D.backward()
                opt_D.step()

                # ===== 2. Generator step =====
                G.zero_grad()
                z = torch.randn(bsz, args.zdim, device=device)
                rand_labels = torch.randint(0, args.n_classes, (bsz,), device=device)
                gen_imgs = G(z, rand_labels)
                adv_out, cls_out = D(gen_imgs)
                loss_G_adv = -adv_out.mean()
                loss_G_cls = ce_loss(cls_out, rand_labels)
                loss_G = loss_G_adv + 0.1 * loss_G_cls
                # This backward also writes D's .grad; D.zero_grad() at the start of
                # the next iteration clears it before opt_D.step() is called again.
                loss_G.backward()
                opt_G.step()

                # ===== 3. Logging & save =====
                if iter_count % args.display_iter == 0:
                    d_val, g_val = loss_D.item(), loss_G.item()
                    losses_D.append(d_val)
                    losses_G.append(g_val)
                    iterations.append(iter_count)
                    print(f"Iter {iter_count} | D_loss: {d_val:.4f} | G_loss: {g_val:.4f}")

                if iter_count % args.save_iter == 0:
                    _save_snapshot(G, D, gen_imgs, real_imgs, iter_count)
    finally:
        # ===== 4. Plot and save the loss curve =====
        # Runs on normal completion and on interruption alike.
        _save_loss_curve(iterations, losses_D, losses_G)

    print("Training finished!")


In [None]:
if __name__ == "__main__":
    # Re-read config.json and rebind the module-level `args` (which Generator and
    # Discriminator read when constructed with no arguments), then train with it.
    args = load_config(config_path='config.json')
    train(args)

Resumed from iter 70000: loaded models/Genre128GANAE/G_70000.pth and models/Genre128GANAE/D_70000.pth
Iter 70100 | D_loss: 1.8538 | G_loss: 0.6718
Iter 70200 | D_loss: 1.7938 | G_loss: 0.4876
Iter 70300 | D_loss: 1.7765 | G_loss: 0.1833
Iter 70400 | D_loss: 1.7578 | G_loss: 0.3190
Iter 70500 | D_loss: 2.1508 | G_loss: 1.4386
Iter 70600 | D_loss: 1.6656 | G_loss: 0.4800
Iter 70700 | D_loss: 1.7859 | G_loss: 0.5015
Iter 70800 | D_loss: 1.6681 | G_loss: 0.3286
Iter 70900 | D_loss: 1.8388 | G_loss: 0.6184
Iter 71000 | D_loss: 2.0600 | G_loss: -0.5961
Iter 71100 | D_loss: 1.7951 | G_loss: 0.4537
Iter 71200 | D_loss: 1.9431 | G_loss: 1.4342
Iter 71300 | D_loss: 1.7684 | G_loss: 0.1655
Iter 71400 | D_loss: 1.8232 | G_loss: 0.1735
Iter 71500 | D_loss: 1.7191 | G_loss: 0.4928
Iter 71600 | D_loss: 1.7808 | G_loss: 0.6730
Iter 71700 | D_loss: 1.7620 | G_loss: 0.3916
Iter 71800 | D_loss: 1.6733 | G_loss: 0.5845
Iter 71900 | D_loss: 1.7601 | G_loss: 0.6060
Iter 72000 | D_loss: 1.8017 | G_loss: 0.29