forked from EForoumandi/DroGAN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
109 lines (90 loc) · 3.52 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import config
from dataset import MapDataset
from Generator import Generator
from Discriminator import Discriminator
from utils import load_checkpoint
# Turn on the cuDNN autotuner: cuDNN benchmarks the available convolution
# algorithms and reuses the fastest one it finds.
# NOTE(review): this typically pays off when input shapes stay constant
# between iterations — confirm that holds for the batches MapDataset yields.
torch.backends.cudnn.benchmark = True
def train_fn(
    disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler,
):
    """
    Train DroGAN for one epoch.

    For every batch the discriminator is updated first (real pair vs.
    detached generated pair), then the generator is updated using the
    adversarial loss plus an L1 term weighted by ``config.L1_LAMBDA``.

    Args:
        disc (nn.Module): The discriminator model; called as ``disc(x, y)``.
        gen (nn.Module): The generator model; called as ``gen(x)``.
        loader (Iterable): Yields ``(x, y)`` batches (input, target).
        opt_disc (torch.optim.Optimizer): Optimizer for the discriminator.
        opt_gen (torch.optim.Optimizer): Optimizer for the generator.
        l1_loss (Callable): Reconstruction loss applied to ``(y_fake, y)``.
        bce (Callable): Adversarial loss applied to the discriminator output
            and an all-ones / all-zeros target of the same shape.
            NOTE(review): the progress bar applies ``sigmoid`` to the
            discriminator output, so it is presumably raw logits — confirm
            that ``bce`` is a with-logits loss.
        g_scaler (torch.cuda.amp.GradScaler): Gradient scaler for the generator.
        d_scaler (torch.cuda.amp.GradScaler): Gradient scaler for the discriminator.

    Returns:
        None. The models and optimizers are updated in place.
    """
    # The parameter is named `loader`; the previous body iterated an
    # undefined name (`dataloader`) and failed with NameError.
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        # ---- Discriminator step ----
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            # detach() keeps this step's gradients out of the generator.
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            # Average of the real and fake terms.
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generator step ----
        with torch.cuda.amp.autocast():
            # Re-score the (non-detached) fake with the just-updated
            # discriminator so gradients reach the generator.
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * config.L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Refresh the progress-bar statistics every 10 batches.
        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )
def main():
    """
    Main function to set up the models, load data, and start training.

    Builds the discriminator and generator on ``config.DEVICE``, creates an
    Adam optimizer for each, optionally restores both from the checkpoint
    paths in ``config`` (when ``config.LOAD_MODEL`` is set), then runs
    ``train_fn`` for ``config.NUM_EPOCHS`` epochs. When ``config.SAVE_MODEL``
    is set, both models are checkpointed after every epoch.
    """
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    gen = Generator(in_channels=5, features=64).to(config.DEVICE)

    opt_disc = optim.Adam(
        disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999),
    )
    opt_gen = optim.Adam(
        gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999),
    )

    bce = nn.BCEWithLogitsLoss()
    l1_loss = nn.L1Loss()

    if config.LOAD_MODEL:
        load_checkpoint(
            config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
        )
        load_checkpoint(
            config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
        )

    train_dataset = MapDataset(
        input_dir=config.Input_DIR, target_dir=config.Target_DIR,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
    )

    # One gradient scaler per optimizer for mixed-precision training.
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    # Training loop
    for epoch in range(config.NUM_EPOCHS):
        train_fn(
            disc, gen, train_loader, opt_disc, opt_gen,
            l1_loss, bce, g_scaler, d_scaler,
        )

        # `epoch % 1` is always 0, i.e. save after every epoch; raise the
        # modulus to checkpoint less often.
        if config.SAVE_MODEL and epoch % 1 == 0:
            # The module-level import only brings in load_checkpoint, so the
            # calls below used to fail with NameError. Importing here keeps
            # runs with SAVE_MODEL disabled unaffected.
            # NOTE(review): assumes utils exposes
            # save_checkpoint(model, optimizer, filename=...) — confirm.
            from utils import save_checkpoint

            save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
# Start training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()