In [1]:
import torch
import torch.nn as nn

In [2]:
# Pick the compute device: the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("We are using {} now!".format(device))

We are using cuda now!


In [3]:
# Report the CUDA version string this PyTorch build exposes via torch.version.cuda.
cuda_version = torch.version.cuda
print(cuda_version)

9.2


In [4]:
class Network(nn.Module):
    """Placeholder model skeleton; add layers in ``__init__`` and implement ``forward``."""

    def __init__(self):
        # nn.Module.__init__ must run before the instance is usable: it creates the
        # internal parameter/submodule/hook registries that printing, calling and
        # parameter iteration rely on.
        super().__init__()

    def forward(self, x):
        # Fail loudly instead of silently returning None from an unfinished stub.
        raise NotImplementedError("Network.forward is not implemented yet")

In [5]:
class AE(nn.Module):
    """Convolutional auto-encoder mapping an image tensor to an image tensor.

    ``self.net`` is a 15-module ``nn.Sequential`` laid out as:
      * a 3x3, stride-1 stem conv ``in_channels -> nf``;
      * two encoder stages (3x3 stride-2 conv + BatchNorm + ReLU),
        ``nf -> 2*nf -> 4*nf``, each halving H and W;
      * two decoder stages (4x4 stride-2 transposed conv + BatchNorm + ReLU),
        ``4*nf -> 2*nf -> nf``, each doubling H and W;
      * a 3x3, stride-1 head conv ``nf -> out_channels`` followed by Sigmoid,
        which squashes outputs into the 0-1 range.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels in the output tensor.
        nf: base filter count; the bottleneck has ``4 * nf`` channels.

    Output H and W equal input H and W only when both are multiples of 4
    (e.g. 256 -> 128 -> 64 -> 128 -> 256); other sizes come back with a
    different spatial shape (e.g. 10 -> 12).

    NOTE(review): the stem conv feeds straight into the first encoder conv with
    no normalization or activation in between -- confirm this is intended.
    """

    def __init__(self, in_channels, out_channels, nf):
        super().__init__()

        def encoder_stage(c_in, c_out):
            # 3x3 kernel, stride 2, padding 1: halves the spatial size.
            return [nn.Conv2d(c_in, c_out, 3, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()]

        def decoder_stage(c_in, c_out):
            # 4x4 kernel, stride 2, padding 1 transposed conv: doubles the spatial size.
            return [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()]

        # Modules are created in execution order so random initialization and the
        # module indices (net.0 ... net.14) are the same as a flat Sequential literal.
        layers = [nn.Conv2d(in_channels, nf, 3, 1, 1)]
        layers.extend(encoder_stage(nf, nf * 2))
        layers.extend(encoder_stage(nf * 2, nf * 4))
        layers.extend(decoder_stage(nf * 4, nf * 2))
        layers.extend(decoder_stage(nf * 2, nf))
        layers.append(nn.Conv2d(nf, out_channels, 3, 1, 1))
        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the encoder-decoder stack and return the reconstruction."""
        reconstruction = self.net(x)
        return reconstruction

In [9]:
# Build an RGB -> RGB auto-encoder with 8 base filters and print its layer stack.
ae = AE(nf=8, in_channels=3, out_channels=3)
print(ae)

AE(
  (net): Sequential(
    (0): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU()
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (8): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU()
    (10): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (11): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU()
    (13): Conv2d(8, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): Sigmoid()
  )
)


In [10]:
# Sanity check: a random 256x256 RGB batch should come back with the same shape.
x = torch.randn(size=(1, 3, 256, 256))
# Shape check only, so skip building the autograd graph.
# NOTE: `ae` is still in training mode, so this pass updates its BatchNorm running stats.
with torch.no_grad():
    out = ae(x)
print(out.size())
# Expected output:
# torch.Size([1, 3, 256, 256])
assert out.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(out.shape)}"

torch.Size([1, 3, 256, 256])
