In [4]:
import argparse
import datetime
import os
import pickle
import time
from collections import OrderedDict

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as io
import scipy.ndimage as nd
import skimage.measure as sk
import torch
from mpl_toolkits import mplot3d
from torch import nn
from torch import optim
from torch.autograd import Variable
from torch.utils import data
from tqdm import tqdm

import binvox_rw as binvox
import model
import params
from model import net_G, net_D


ModuleNotFoundError: No module named 'matplotlib'

In [None]:
# Display the notebook's current working directory (relative paths such as ./data resolve from here).
os.path.abspath(os.curdir)


In [None]:
# PARAMS
# Training configuration: every tunable value used by the cells below lives here.
# NOTE: this rebinds the name `params`, shadowing the `params` module imported in the first
# cell; from here on `params` is this dict. Any values defined in params.py (e.g. a latent
# size) are no longer reachable under that name — confirm nothing below needs them.
params = {
    "batch_size": 64,          # samples per training step
    "learning_rate": 0.0002,   # Adam learning rate for both G and D
    "epochs": 100,             # full passes over the dataset
    "z_dim": 100,              # length of the noise vector fed to the generator
    "log_interval": 50,        # print losses every N steps
    "seed": 42,                # RNG seed for reproducible runs
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "model_save_path": "./models",  # not read by any cell in this section
    "data_path": "./data",          # used by the commented-out dataset example
}

# Seed the RNGs so weight init, noise sampling and DataLoader shuffling are reproducible.
torch.manual_seed(params["seed"])
# Also seed NumPy in case data loading/augmentation draws from it.
np.random.seed(params["seed"])

In [None]:
# Initialize generator and discriminator on the configured device.
# Instantiate via the `model` module: the assignments below rebind the names `net_G`/`net_D`
# from the classes to the instances, so calling `net_G()` by that name would fail if this
# cell were run a second time.
net_G = model.net_G().to(params["device"])  # generator
net_D = model.net_D().to(params["device"])  # discriminator

# Define loss function: binary cross-entropy on D's real/fake prediction.
# NOTE(review): nn.BCELoss expects inputs already in [0, 1], so net_D presumably ends with a
# sigmoid — confirm in model.py.
criterion = nn.BCELoss()

# Define optimizers: one Adam per network, sharing the configured learning rate.
# `betas` may be overridden from the config (GAN setups often use (0.5, 0.999)); the fallback
# is Adam's own default, so behavior is unchanged when the key is absent.
adam_betas = params.get("betas", (0.9, 0.999))
optimizer_G = optim.Adam(net_G.parameters(), lr=params["learning_rate"], betas=adam_betas)
optimizer_D = optim.Adam(net_D.parameters(), lr=params["learning_rate"], betas=adam_betas)


In [None]:
# Load dataset
# TODO: replace the placeholder below with a real torch.utils.data.Dataset. The imports at the
# top (binvox_rw, scipy.io) suggest voxel data, in which case the image-folder example further
# down would not apply as-is — confirm.
dataset = ...
if dataset is Ellipsis:
    # Fail with an explicit message rather than an obscure error from inside DataLoader.
    raise NotImplementedError(
        "`dataset` is still a placeholder; assign a torch.utils.data.Dataset before building the DataLoader."
    )
dataloader = torch.utils.data.DataLoader(dataset, batch_size=params["batch_size"], shuffle=True)

# Example: ImageFolder (image data; assumes params["data_path"] is laid out as one subfolder per class)
# from torchvision import datasets, transforms
# transform = transforms.Compose([
#     transforms.Resize(64),
#     transforms.ToTensor(),
#     transforms.Normalize([0.5], [0.5])
# ])
# dataset = datasets.ImageFolder(root=params["data_path"], transform=transform)
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=params["batch_size"], shuffle=True)


In [1]:
# Adversarial training loop: each step first updates D on one real and one fake batch,
# then updates G so that D classifies its fakes as real.
device = params["device"]
# Fallbacks reproduce the original hard-coded values when the config lacks these keys.
# NOTE(review): z_dim must match the latent size net_G expects — confirm against model.py.
z_dim = params.get("z_dim", 100)
log_interval = params.get("log_interval", 50)

for epoch in range(params["epochs"]):
    # `batch` (not `data`) so the `torch.utils.data` module imported at the top is not shadowed.
    for i, batch in enumerate(dataloader):
        # Tuple/list batches (e.g. ImageFolder's (inputs, labels)) carry the inputs first;
        # a dataset that yields bare tensors is collated into a single tensor batch.
        real_imgs = (batch[0] if isinstance(batch, (list, tuple)) else batch).to(device)
        batch_size = real_imgs.size(0)  # the last batch may be smaller than params["batch_size"]

        # --- Train the discriminator ---
        optimizer_D.zero_grad()
        outputs_real = net_D(real_imgs)
        # Targets are built from D's own output so their shape always matches, as BCELoss requires.
        d_loss_real = criterion(outputs_real, torch.ones_like(outputs_real))
        d_loss_real.backward()

        z = torch.randn(batch_size, z_dim, device=device)  # random noise input to the generator
        fake_imgs = net_G(z)
        # detach(): the discriminator update must not backpropagate into the generator.
        outputs_fake = net_D(fake_imgs.detach())
        d_loss_fake = criterion(outputs_fake, torch.zeros_like(outputs_fake))
        d_loss_fake.backward()

        optimizer_D.step()

        # --- Train the generator ---
        optimizer_G.zero_grad()
        outputs = net_D(fake_imgs)
        # The generator wants the discriminator to label its fakes as real.
        g_loss = criterion(outputs, torch.ones_like(outputs))
        g_loss.backward()
        optimizer_G.step()

        # Print and log losses
        if i % log_interval == 0:
            d_loss = d_loss_real.item() + d_loss_fake.item()
            print(
                f"Epoch [{epoch + 1}/{params['epochs']}], Step [{i}/{len(dataloader)}], "
                f"d_loss: {d_loss:.4f}, g_loss: {g_loss.item():.4f}"
            )


NameError: name 'params' is not defined