-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
76fe507
commit bb68473
Showing
14 changed files
with
970 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,5 +17,6 @@ Desktop.ini | |
*.so | ||
*.egg | ||
*.egg-info | ||
*.pypirc | ||
dist | ||
build |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
include docs/source/index.rst | ||
include requirements.txt | ||
include LICENSE |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# coding=utf-8 | ||
import os, torch | ||
from torch.nn import CrossEntropyLoss | ||
from torch.autograd import Variable | ||
from jdit.trainer.gan.generate import GanTrainer | ||
from jdit.model import Model | ||
from jdit.optimizer import Optimizer | ||
from jdit.dataset import Cifar10 | ||
from mypackage.tricks import gradPenalty | ||
from mypackage.model.Tnet import NLayer_D, TWnet_G | ||
|
||
|
||
# from mypackage.tricks import jcbClamp | ||
|
||
|
||
class GenerateGanTrainer(GanTrainer):
    """GAN trainer that scores samples with ``netD`` and logs via ``watcher``.

    The discriminator loss is mean(D(fake)) - mean(D(real)) plus the term
    returned by ``gradPenalty``; the generator loss is -mean(D(fake)).
    NOTE(review): this looks like the WGAN-GP objective -- confirm against
    the implementation of ``gradPenalty``.
    """

    # Class-level settings; presumably read by ``GanTrainer`` -- TODO confirm.
    mode = "RGB"
    every_epoch_checkpoint = 50  # 2
    every_epoch_changelr = 2  # 1
    d_turn = 5

    def __init__(self, logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset, latent_shape):
        """Initialise the parent trainer, then log the generator graph and a data embedding."""
        super().__init__(
            logdir, nepochs, gpu_ids_abs, netG, netD, optG, optD, dataset,
            latent_shape=latent_shape,
        )

        # Graph is traced with an input shaped (4, *latent_shape).
        graph_input_shape = (4, *self.latent_shape)
        self.watcher.graph(netG, graph_input_shape, self.use_gpu)

        # The training samples are passed twice to ``embedding`` (as both the
        # first and second argument), followed by their labels.
        samples, labels = self.datasets.samples_train
        self.watcher.embedding(samples, samples, labels)

    def compute_d_loss(self):
        """Return ``(loss_d, var_dic)`` for one discriminator step.

        ``var_dic`` holds "GP" (gradient-penalty term), "WD" (detached
        mean(D(real)) - mean(D(fake))) and "LOSS_D" (the returned loss).
        """
        # Fake batch is detached so this loss does not reach the generator.
        fake_score = self.netD(self.fake.detach()).mean()
        real_score = self.netD(self.ground_truth).mean()

        penalty = gradPenalty(self.netD, self.ground_truth, self.fake, input=None,
                              use_gpu=self.use_gpu)
        loss_d = fake_score - real_score + penalty

        var_dic = {
            "GP": penalty,
            "WD": (real_score - fake_score).detach(),
            "LOSS_D": loss_d,
        }
        return loss_d, var_dic

    def compute_g_loss(self):
        """Return ``(loss_g, var_dic)`` where ``loss_g`` is -mean(D(fake))."""
        # An extra ``jcbClamp`` term was previously sketched here (see the
        # commented-out import at the top of the file); it is not active.
        loss_g = -self.netD(self.fake).mean()
        return loss_g, {"LOSS_G": loss_g}

    # NOTE: a ``compute_valid`` override that reported only the detached
    # "WD" value (mean(D(real)) - mean(D(fake))) used to be sketched here as
    # commented-out code; it is intentionally not defined.

    def valid(self):
        """Render images from a fixed latent batch and hand them to the watcher."""
        if self.fixed_input is None:
            # Create the fixed latent batch once so later calls reuse it.
            holder = Variable()
            if self.use_gpu:
                holder = holder.cuda()
            self.fixed_input = holder
            noise_cpu = Variable(torch.randn((32, *self.latent_shape)))
            # ``mv_inplace`` comes from the parent class; presumably it copies
            # the CPU noise into ``self.fixed_input`` -- TODO confirm.
            self.mv_inplace(noise_cpu, self.fixed_input)

        self.netG.eval()
        with torch.no_grad():
            generated = self.netG(self.fixed_input).detach()
        self.watcher.image(generated, self.current_epoch, tag="Valid/Fixed_fake", grid_size=(4, 4),
                           shuffle=False)
        self.watcher.set_training_progress_images(generated, grid_size=(4, 4))

        # Optional FID logging, currently disabled:
        # var_dic = {}
        # var_dic["FID_SCORE"] = self.metric.evaluate_model_fid(self.netG, (256, *self.latent_shape), amount=8)
        # self.watcher.scalars(var_dic, self.step, tag="Valid")
        self.netG.train()
|
||
|
||
if __name__ == '__main__':
    # ---- run-wide settings -------------------------------------------------
    gpus = [2, 3]  # handed to Model/trainer as ``gpu_ids_abs``
    batch_shape = (128, 3, 32, 32)
    image_channel = batch_shape[1]
    nepochs = 200
    mid_channel = 8  # not referenced by the active configuration below
    latent_shape = (256, 1, 1)

    # ---- optimiser settings shared by both networks ------------------------
    lr = 1e-3
    lr_decay = 0.94  # 0.94
    weight_decay = 0  # 2e-5
    betas = (0.9, 0.999)
    momentum = 0

    # ---- generator settings ------------------------------------------------
    opt_G_name = "Adam"
    depth_G = 8
    G_mid_channel = 8

    # ---- discriminator settings --------------------------------------------
    opt_D_name = "RMSprop"
    depth_D = 64
    D_mid_channel = 16  # only used by the commented-out NThickLayer_D variant

    print('===> Build dataset')
    cifar10 = Cifar10(batch_shape=batch_shape)
    torch.backends.cudnn.benchmark = True

    print('===> Building model')
    # Alternative discriminator kept for reference:
    # D_net = NThickLayer_D(input_nc=image_channel, mid_channels=D_mid_channel, depth=depth_D, norm_type=None,
    #                       active_type="ReLU")
    D_net = NLayer_D(
        input_nc=image_channel,
        depth=depth_D,
        use_sigmoid=False,
        use_liner=False,
        norm_type="batch",
        active_type="ReLU",
    )
    D = Model(D_net, gpu_ids_abs=gpus, init_method="kaiming")

    G_net = TWnet_G(
        input_nc=latent_shape[0],
        mid_channels=G_mid_channel,
        output_nc=image_channel,
        depth=depth_G,
        norm_type="batch",
        active_type="LeakyReLU",
    )
    G = Model(G_net, gpu_ids_abs=gpus, init_method="kaiming")

    print('===> Building optimizer')
    # Both optimisers get the same numeric settings; only the algorithm name differs.
    opt_D = Optimizer(D.parameters(), lr, lr_decay, weight_decay, momentum, betas, opt_D_name)
    opt_G = Optimizer(G.parameters(), lr, lr_decay, weight_decay, momentum, betas, opt_G_name)

    print('===> Training')
    trainer = GenerateGanTrainer("log", nepochs, gpus, G, D, opt_G, opt_D, cifar10, latent_shape)
    trainer.train()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import torch.utils.model_zoo as model_zoo | ||
from torch import nn | ||
import math | ||
|
||
|
||
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with one pixel of padding.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (default 1).
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer
|
||
|
||
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm + 2D-dropout stages with a skip connection.

    Args:
        in_planes: channels of the incoming tensor.
        out_planes: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to form the skip
            branch; when ``None`` the input itself is used.
        drop_rate: probability given to the shared ``nn.Dropout2d``.
    """

    expansion = 1

    def __init__(self, in_planes, out_planes, stride=1, downsample=None, drop_rate=0.2):
        super(BasicBlock, self).__init__()
        # First stage carries the stride; both convs are 3x3, padding 1, no bias.
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_planes)
        self.lrelu = nn.LeakyReLU(0.1)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.downsample = downsample
        self.stride = stride
        # A single dropout module is reused after each batch-norm.
        self.drop2d = nn.Dropout2d(drop_rate)

    def forward(self, x):
        """Run both stages, add the skip branch, and apply the final LeakyReLU."""
        first = self.lrelu(self.drop2d(self.bn1(self.conv1(x))))
        out = self.drop2d(self.bn2(self.conv2(first)))

        # Skip branch is evaluated after the main path, as in the original order.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.lrelu(out)
|
||
|
||
class ResNet(nn.Module):
    """Three-stage residual classifier for single-channel images.

    The stem is a stride-1 3x3 convolution (``in_channels=1``) followed by
    three residual stages of width ``inplanes``, ``2 * inplanes`` and
    ``4 * inplanes``; the last two stages are built with stride 2. The
    result goes through ``AvgPool2d(8)`` and a linear classifier, so the
    final feature map must be 8x8 (e.g. a 32x32 input, assuming ``block``
    applies the stride it is given).

    Args:
        block: residual block class/callable with an ``expansion`` attribute
            and signature ``block(in_planes, planes, stride=1, downsample=None)``.
        layers: sequence whose first three entries give the number of blocks
            in stages 1-3 (any further entries are ignored).
        inplanes: base channel width of the network.
        num_classes: size of the classifier output.
    """

    def __init__(self, block, layers, inplanes=64, num_classes=10):
        super(ResNet, self).__init__()
        # ``self.inplanes`` is updated by ``_make_layer`` to track the channel
        # count entering the next stage, so stage widths must be derived from
        # a fixed base value; using ``self.inplanes`` for them compounds the
        # multipliers and leaves ``fc`` with the wrong input size.
        base_planes = inplanes
        self.inplanes = base_planes

        self.conv1 = nn.Conv2d(1, base_planes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(base_planes)
        self.relu = nn.ReLU()
        # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, base_planes, layers[0])
        self.layer2 = self._make_layer(block, base_planes * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(block, base_planes * 4, layers[2], stride=2)
        # self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(8)
        self.fc = nn.Linear(base_planes * 4 * block.expansion, num_classes)

        # Conv weights ~ N(0, sqrt(2 / (k_h * k_w * out_channels)));
        # batch-norm starts as identity (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks of width ``planes``.

        Only the first block receives ``stride`` and, when the shape of the
        skip branch would not match, a 1x1 conv + batch-norm ``downsample``.
        Updates ``self.inplanes`` to the stage's output channel count.
        """
        out_planes = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        stage = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_planes
        for _ in range(1, blocks):
            stage.append(block(self.inplanes, planes))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features) for the classifier
        x = self.fc(x)

        return x
|
||
|
||
# Download locations of the pretrained checkpoints, keyed by architecture name.
model_urls = dict(
    resnet18='https://download.pytorch.org/models/resnet18-5c106cde.pth',
    resnet34='https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    resnet50='https://download.pytorch.org/models/resnet50-19c8e357.pth',
    resnet101='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    resnet152='https://download.pytorch.org/models/resnet152-b121ed2d.pth',
)
|
||
|
||
def Resnet18(depth=64, pretrained=False):
    """Construct the ResNet-18-style model defined in this file.

    Args:
        depth (int): base channel width, passed to ``ResNet`` as ``inplanes``.
        pretrained (bool): if True, download the checkpoint at
            ``model_urls['resnet18']`` and copy in every tensor whose name
            and shape match this model; all other parameters keep their
            fresh initialisation.

    Returns:
        ResNet: the constructed model.
    """
    # Four stage sizes are listed, but ``ResNet`` only reads the first three.
    model = ResNet(BasicBlock, [2, 2, 2, 2], inplanes=depth)
    if pretrained:
        checkpoint = model_zoo.load_url(model_urls['resnet18'])
        own_state = model.state_dict()
        # ``strict=False`` tolerates missing/unexpected keys only; a tensor
        # with a matching name but a different shape (e.g. the stem conv,
        # which here takes 1 input channel with a 3x3 kernel) still makes
        # ``load_state_dict`` raise, so incompatible entries are dropped first.
        compatible = {
            name: tensor
            for name, tensor in checkpoint.items()
            if name in own_state and tensor.shape == own_state[name].shape
        }
        model.load_state_dict(compatible, strict=False)
    return model
Empty file.
Oops, something went wrong.