-
Notifications
You must be signed in to change notification settings - Fork 0
/
distil.py
129 lines (114 loc) · 3.26 KB
/
distil.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
from itertools import chain
import torch
from src.callbacks.gan import (
CycleGANLoss,
GANLoss,
IdenticalGANLoss,
PrepareGeneratorPhase,
GeneratorOptimizerCallback,
PrepareDiscriminatorPhase,
DiscriminatorLoss,
DiscriminatorOptimizerCallback
)
from src.callbacks.distillation import (
HiddenStateLoss,
TeacherStudentLoss,
)
from src.callbacks.visualization import LogImageCallback
from src.dataset import UnpairedDataset
from src.modules.generator import Generator
from src.modules.discriminator import NLayerDiscriminator, PixelDiscriminator
from src.runner import DistillRunner
from src.modules.loss import LSGanLoss
from src.utils.init_student import initialize_pretrained, transfer_student
from torchvision import transforms as T
from PIL import Image
# Unpaired training data: domain A (Monet paintings) vs. domain B (photos).
# Augmentation: resize to 300x300, random 256x256 crop, random horizontal
# flip — randomness here is the training-time data augmentation.
_train_transforms = T.Compose([
    T.Resize((300, 300)),
    T.RandomCrop((256, 256)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
])
train_ds = UnpairedDataset(
    "./datasets/monet2photo/trainA_preprocessed",
    "./datasets/monet2photo/trainB_preprocessed",
    transforms=_train_transforms,
)
# batch_size=1 with shuffling, so the two unpaired domains are re-paired
# randomly every epoch.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=1, shuffle=True)
# Deterministic eval-time transform for the fixed visualization photos
# (no random crop/flip, unlike the training augmentations above).
# Fix: the original line read `tr = transforms=T.Compose(...)` — a chained
# assignment left over from copy-pasting a keyword argument, which also
# bound a stray, unused `transforms` name. Plain assignment is intended.
tr = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
])
# Fixed reference images, logged each epoch by the LogImageCallbacks below.
mipt_photo = tr(Image.open("./datasets/mipt.jpg"))
zinger_photo = tr(Image.open("./datasets/vk.jpg"))
# Teacher CycleGAN generators (9 residual blocks each) plus a compact
# 3-block student generator, and one pixel-level discriminator per domain.
model = {
    "generator_ab": Generator(3, 3, n_blocks=9),
    "generator_ba": Generator(3, 3, n_blocks=9),
    "generator_s": Generator(3, 3, n_blocks=3),
    "discriminator_a": PixelDiscriminator(3),
    "discriminator_b": PixelDiscriminator(3),
}

# One Adam over both teacher generators, one over both discriminators.
# NOTE(review): "generator_s" has no optimizer here — presumably the
# runner/optimizer callbacks handle the student's updates; confirm.
_LR = 0.0002
_generator_params = chain(
    model["generator_ab"].parameters(),
    model["generator_ba"].parameters(),
)
_discriminator_params = chain(
    model["discriminator_a"].parameters(),
    model["discriminator_b"].parameters(),
)
optimizer = {
    "generator": torch.optim.Adam(_generator_params, lr=_LR),
    "discriminator": torch.optim.Adam(_discriminator_params, lr=_LR),
}
# Callback pipeline; concatenation order below reproduces the original
# list order exactly.
_generator_phase = [
    PrepareGeneratorPhase(),
    GANLoss(),
    CycleGANLoss(),
    IdenticalGANLoss(ba_key="generator_s"),
    # Weighted sum of the generator-side losses; keys/weights pair up as
    # gan:1, cycle:10, identical:5, hidden_state:1, ts_difference:10.
    GeneratorOptimizerCallback(
        keys=[
            "gan_loss",
            "cycle_loss",
            "identical_loss",
            "hidden_state_loss",
            "ts_difference",
        ],
        weights=[1, 10, 5, 1, 10],
    ),
]
_discriminator_phase = [
    PrepareDiscriminatorPhase(),
    DiscriminatorLoss(),
    DiscriminatorOptimizerCallback(),
]
# NOTE(review): the distillation losses are listed after the optimizer
# callbacks — presumably the runner computes losses before stepping;
# confirm the execution order against DistillRunner.
_distillation = [
    HiddenStateLoss(transfer_layer=[8]),
    TeacherStudentLoss(),
]
_logging = [
    LogImageCallback(model_key="generator_s"),
    LogImageCallback(key="mipt", img=mipt_photo, model_key="generator_s"),
    LogImageCallback(key="vk", img=zinger_photo, model_key="generator_s"),
]
callbacks = _generator_phase + _discriminator_phase + _distillation + _logging
# Per-loss criteria consumed by the callbacks above.
criterion = {
    "gan": LSGanLoss(),
    "cycle": torch.nn.L1Loss(),
    "identical": torch.nn.L1Loss(),
    "hidden_state_loss": torch.nn.MSELoss(),
    "teacher_student": torch.nn.L1Loss(),
}

# Load pretrained teacher weights, then seed the student from the same
# checkpoint (see src.utils.init_student for the exact transfer rules).
_TEACHER_CKPT = "teacher/checkpoints/last.pth"
initialize_pretrained(_TEACHER_CKPT, model)
transfer_student(_TEACHER_CKPT, model)

runner = DistillRunner(buffer_size=50, student_key="generator_s")
runner.train(
    model=model,
    optimizer=optimizer,
    loaders={"train": train_dl},
    callbacks=callbacks,
    criterion=criterion,
    num_epochs=100,
    verbose=True,
    logdir="student",
    main_metric="identical_loss",
)