# regularized_inverse_scattering_MNIST.py — 168 lines (133 loc), 6.33 KB
"""
Regularized inverse of a scattering transform on MNIST
======================================================
Description:
This example trains a convolutional network to invert the scattering transform at scale 2 of MNIST digits.
After only two epochs, it produces a network that transforms a linear interpolation in the scattering space into a
nonlinear interpolation in the image space.
Remarks:
The model after two epochs and the path (which consists of a sequence of images) are stored in the cache directory.
The two epochs take roughly 5 minutes on a Quadro M6000 GPU.
Reference:
https://arxiv.org/abs/1805.06621
"""
import argparse
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from kymatio import Scattering2D as Scattering
from kymatio.caching import get_cache_dir
from kymatio.datasets import get_dataset_dir
class Generator(nn.Module):
    """Convolutional decoder mapping scattering coefficients back to images.

    Two bilinear x2 upsamplings take the 7x7 scattering maps up to 28x28;
    the final Tanh keeps output pixel values in [-1, 1].
    """

    def __init__(self, num_input_channels, num_hidden_channels, num_output_channels=1, filter_size=3):
        super(Generator, self).__init__()
        self.num_input_channels = num_input_channels
        self.num_hidden_channels = num_hidden_channels
        self.num_output_channels = num_output_channels
        self.filter_size = filter_size
        self.build()

    def build(self):
        """Assemble self.main: (pad->conv->BN->ReLU->upsample) x2, then pad->conv->BN->Tanh."""
        pad = (self.filter_size - 1) // 2

        def conv_bn(in_ch, out_ch):
            # Reflection padding preserves spatial size; the conv omits its bias
            # because the following BatchNorm provides the affine shift.
            return [
                nn.ReflectionPad2d(pad),
                nn.Conv2d(in_ch, out_ch, self.filter_size, bias=False),
                nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.9),
            ]

        upsample = lambda: nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

        layers = conv_bn(self.num_input_channels, self.num_hidden_channels)
        layers += [nn.ReLU(inplace=True), upsample()]
        layers += conv_bn(self.num_hidden_channels, self.num_hidden_channels)
        layers += [nn.ReLU(inplace=True), upsample()]
        layers += conv_bn(self.num_hidden_channels, self.num_output_channels)
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input_tensor):
        """Run the decoder on a batch of scattering coefficients."""
        return self.main(input_tensor)
def main():
    """Train (or load) a generator that inverts the scattering transform of MNIST
    digits, then render a linear interpolation in scattering space as a nonlinear
    interpolation in image space, saving the frames to the cache directory."""
    parser = argparse.ArgumentParser(description='Regularized inverse scattering')
    # type=int is required: argparse yields strings by default and range() below
    # would raise TypeError on a user-supplied value.
    parser.add_argument('--num_epochs', default=2, type=int, help='Number of epochs to train')
    # store_true: the previous `default=False` with no type made ANY supplied
    # value truthy, including the literal string 'False'.
    parser.add_argument('--load_model', action='store_true', help='Load a trained model?')
    parser.add_argument('--dir_save_images', default='interpolation_images', help='Dir to save the sequence of images')
    args = parser.parse_args()

    num_epochs = args.num_epochs
    load_model = args.load_model
    dir_save_images = args.dir_save_images

    # Fall back to CPU when no GPU is present instead of crashing on .cuda().
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    dir_to_save = get_cache_dir('reg_inverse_example')

    transforms_to_apply = transforms.Compose([
        transforms.ToTensor(),
        # MNIST is single-channel; 3-channel mean/std tuples do not broadcast.
        transforms.Normalize((0.5,), (0.5,))  # Pixel values should be in [-1, 1]
    ])

    mnist_dir = get_dataset_dir("MNIST", create=True)
    dataset = datasets.MNIST(mnist_dir, train=True, download=True, transform=transforms_to_apply)
    # pin_memory only helps (and only applies) when transferring to a GPU.
    dataloader = DataLoader(dataset, batch_size=128, shuffle=True, pin_memory=use_cuda)

    fixed_dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
    fixed_batch = next(iter(fixed_dataloader))
    fixed_batch = fixed_batch[0].float().to(device)

    scattering = Scattering(J=2, shape=(28, 28))
    if use_cuda:
        scattering.cuda()

    scattering_fixed_batch = scattering(fixed_batch).squeeze(1)
    num_input_channels = scattering_fixed_batch.shape[1]
    num_hidden_channels = num_input_channels

    generator = Generator(num_input_channels, num_hidden_channels)
    generator.to(device)
    generator.train()

    # Either train the network or load a trained model
    ##################################################
    if load_model:
        filename_model = os.path.join(dir_to_save, 'model.pth')
        # map_location lets a GPU-trained checkpoint load on a CPU-only machine.
        generator.load_state_dict(torch.load(filename_model, map_location=device))
    else:
        criterion = torch.nn.L1Loss()
        optimizer = optim.Adam(generator.parameters())

        for idx_epoch in range(num_epochs):
            print('Training epoch {}'.format(idx_epoch))
            for _, current_batch in enumerate(dataloader):
                generator.zero_grad()
                # Variable is deprecated: tensors track gradients natively.
                batch_images = current_batch[0].float().to(device)
                batch_scattering = scattering(batch_images).squeeze(1)
                batch_inverse_scattering = generator(batch_scattering)
                loss = criterion(batch_inverse_scattering, batch_images)
                loss.backward()
                optimizer.step()

        print('Saving results in {}'.format(dir_to_save))
        torch.save(generator.state_dict(), os.path.join(dir_to_save, 'model.pth'))

    generator.eval()

    # We create the batch containing the linear interpolation points in the scattering space
    ########################################################################################
    z0 = scattering_fixed_batch.cpu().numpy()[[0]]
    z1 = scattering_fixed_batch.cpu().numpy()[[1]]
    num_samples = 32
    # Stack (1 - t) * z0 + t * z1 for t evenly spaced in [0, 1]; t = 0 yields z0.
    batch_z = np.concatenate([(1 - t) * z0 + t * z1
                              for t in np.linspace(0, 1, num_samples)])

    z = torch.from_numpy(batch_z).float().to(device)
    # `.data` is deprecated; disable autograd explicitly for inference instead.
    with torch.no_grad():
        path = generator(z).cpu().numpy().squeeze(1)
    path = (path + 1) / 2  # The pixels are now in [0, 1]

    # We show and store the nonlinear interpolation in the image space
    ##################################################################
    dir_path = os.path.join(dir_to_save, dir_save_images)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(dir_path, exist_ok=True)

    for idx_image in range(num_samples):
        current_image = np.uint8(path[idx_image] * 255.0)
        filename = os.path.join(dir_path, '{}.png'.format(idx_image))
        Image.fromarray(current_image).save(filename)
        plt.imshow(current_image, cmap='gray')
        plt.axis('off')
        plt.pause(0.1)
        plt.draw()


if __name__ == '__main__':
    main()